diff --git a/.beads/.gitignore b/.beads/.gitignore new file mode 100644 index 0000000..d27a1db --- /dev/null +++ b/.beads/.gitignore @@ -0,0 +1,44 @@ +# SQLite databases +*.db +*.db?* +*.db-journal +*.db-wal +*.db-shm + +# Daemon runtime files +daemon.lock +daemon.log +daemon.pid +bd.sock +sync-state.json +last-touched + +# Local version tracking (prevents upgrade notification spam after git ops) +.local_version + +# Legacy database files +db.sqlite +bd.db + +# Worktree redirect file (contains relative path to main repo's .beads/) +# Must not be committed as paths would be wrong in other clones +redirect + +# Merge artifacts (temporary files from 3-way merge) +beads.base.jsonl +beads.base.meta.json +beads.left.jsonl +beads.left.meta.json +beads.right.jsonl +beads.right.meta.json + +# Sync state (local-only, per-machine) +# These files are machine-specific and should not be shared across clones +.sync.lock +sync_base.jsonl + +# NOTE: Do NOT add negation patterns (e.g., !issues.jsonl) here. +# They would override fork protection in .git/info/exclude, allowing +# contributors to accidentally commit upstream issue databases. +# The JSONL files (issues.jsonl, interactions.jsonl) and config files +# are tracked by git by default since no pattern above ignores them. diff --git a/.beads/README.md b/.beads/README.md new file mode 100644 index 0000000..50f281f --- /dev/null +++ b/.beads/README.md @@ -0,0 +1,81 @@ +# Beads - AI-Native Issue Tracking + +Welcome to Beads! This repository uses **Beads** for issue tracking - a modern, AI-native tool designed to live directly in your codebase alongside your code. + +## What is Beads? + +Beads is issue tracking that lives in your repo, making it perfect for AI coding agents and developers who want their issues close to their code. No web UI required - everything works through the CLI and integrates seamlessly with git. 
+ +**Learn more:** [github.com/steveyegge/beads](https://github.com/steveyegge/beads) + +## Quick Start + +### Essential Commands + +```bash +# Create new issues +bd create "Add user authentication" + +# View all issues +bd list + +# View issue details +bd show <issue-id> + +# Update issue status +bd update <issue-id> --status in_progress +bd update <issue-id> --status done + +# Sync with git remote +bd sync +``` + +### Working with Issues + +Issues in Beads are: +- **Git-native**: Stored in `.beads/issues.jsonl` and synced like code +- **AI-friendly**: CLI-first design works perfectly with AI coding agents +- **Branch-aware**: Issues can follow your branch workflow +- **Always in sync**: Auto-syncs with your commits + +## Why Beads? + +✨ **AI-Native Design** +- Built specifically for AI-assisted development workflows +- CLI-first interface works seamlessly with AI coding agents +- No context switching to web UIs + +🚀 **Developer Focused** +- Issues live in your repo, right next to your code +- Works offline, syncs when you push +- Fast, lightweight, and stays out of your way + +🔧 **Git Integration** +- Automatic sync with git commits +- Branch-aware issue tracking +- Intelligent JSONL merge resolution + +## Get Started with Beads + +Try Beads in your own projects: + +```bash +# Install Beads +curl -sSL https://raw.githubusercontent.com/steveyegge/beads/main/scripts/install.sh | bash + +# Initialize in your repo +bd init + +# Create your first issue +bd create "Try out Beads" +``` + +## Learn More + +- **Documentation**: [github.com/steveyegge/beads/docs](https://github.com/steveyegge/beads/tree/main/docs) + +- **Quick Start Guide**: Run `bd quickstart` +- **Examples**: [github.com/steveyegge/beads/examples](https://github.com/steveyegge/beads/tree/main/examples) + +--- + +*Beads: Issue tracking that moves at the speed of thought* ⚡ diff --git a/.beads/config.yaml b/.beads/config.yaml new file mode 100644 index 0000000..1de3590 --- /dev/null +++ b/.beads/config.yaml @@ -0,0 +1,62 @@ +# Beads 
Configuration File +# This file configures default behavior for all bd commands in this repository +# All settings can also be set via environment variables (BD_* prefix) +# or overridden with command-line flags + +# Issue prefix for this repository (used by bd init) +# If not set, bd init will auto-detect from directory name +# Example: issue-prefix: "myproject" creates issues like "myproject-1", "myproject-2", etc. +# issue-prefix: "" + +# Use no-db mode: load from JSONL, no SQLite, write back after each command +# When true, bd will use .beads/issues.jsonl as the source of truth +# instead of SQLite database +# no-db: false + +# Disable daemon for RPC communication (forces direct database access) +# no-daemon: false + +# Disable auto-flush of database to JSONL after mutations +# no-auto-flush: false + +# Disable auto-import from JSONL when it's newer than database +# no-auto-import: false + +# Enable JSON output by default +# json: false + +# Default actor for audit trails (overridden by BD_ACTOR or --actor) +# actor: "" + +# Path to database (overridden by BEADS_DB or --db) +# db: "" + +# Auto-start daemon if not running (can also use BEADS_AUTO_START_DAEMON) +# auto-start-daemon: true + +# Debounce interval for auto-flush (can also use BEADS_FLUSH_DEBOUNCE) +# flush-debounce: "5s" + +# Git branch for beads commits (bd sync will commit to this branch) +# IMPORTANT: Set this for team projects so all clones use the same sync branch. +# This setting persists across clones (unlike database config which is gitignored). +# Can also use BEADS_SYNC_BRANCH env var for local override. +# If not set, bd sync will require you to run 'bd config set sync.branch <branch>'. +sync-branch: "beads-sync" + +# Multi-repo configuration (experimental - bd-307) +# Allows hydrating from multiple repositories and routing writes to the correct JSONL +# repos: +# primary: "." 
# Primary repo (where this database lives) +# additional: # Additional repos to hydrate from (read-only) +# - ~/beads-planning # Personal planning repo +# - ~/work-planning # Work planning repo + +# Integration settings (access with 'bd config get/set') +# These are stored in the database, not in this file: +# - jira.url +# - jira.project +# - linear.url +# - linear.api-key +# - github.org +# - github.repo \ No newline at end of file diff --git a/.beads/interactions.jsonl b/.beads/interactions.jsonl new file mode 100644 index 0000000..e69de29 diff --git a/.beads/issues.jsonl b/.beads/issues.jsonl new file mode 100644 index 0000000..c4f066f --- /dev/null +++ b/.beads/issues.jsonl @@ -0,0 +1,3 @@ +{"id":"nfp-g0j","title":"Refactor HeadBounds reduction helpers","description":"Deduplicate chunked reduction task helpers in Nfp/Sound/Induction/HeadBounds.lean to reduce proof churn while preserving behavior.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T16:21:11.583135+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:22:20.621327+01:00","closed_at":"2026-01-12T16:22:20.621327+01:00","close_reason":"Completed"} +{"id":"nfp-g78","title":"Refactor CoreSound lemmas for stability","description":"Refactor Nfp/Sound/Induction/CoreSound.lean to reduce proof churn and improve maintainability; keep proofs explicit and stable.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T16:12:25.347409+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:16:04.322186+01:00","closed_at":"2026-01-12T16:16:04.322186+01:00","close_reason":"Completed","labels":["in_progress"]} +{"id":"nfp-snh","title":"Refactor MatrixNorm interval bounds proofs","description":"Refactor Nfp/Sound/Bounds/MatrixNorm/Interval.lean to reduce proof churn and simplify interval bound 
lemmas.","status":"closed","priority":2,"issue_type":"task","owner":"robin.gieseke@me.com","created_at":"2026-01-12T15:54:50.752468+01:00","created_by":"TheDarkchip","updated_at":"2026-01-12T16:01:12.070535+01:00","closed_at":"2026-01-12T16:01:12.070535+01:00","close_reason":"Closed","comments":[{"id":1,"issue_id":"nfp-snh","author":"TheDarkchip","text":"Refactored foldl_pair helper for pair foldl projections in MatrixNorm interval bounds; adjusted proofs to use Prod.fst/Prod.snd. Fixed CoreSound cdot lint warnings encountered during build.","created_at":"2026-01-12T15:01:08Z"}]} diff --git a/.beads/metadata.json b/.beads/metadata.json new file mode 100644 index 0000000..c787975 --- /dev/null +++ b/.beads/metadata.json @@ -0,0 +1,4 @@ +{ + "database": "beads.db", + "jsonl_export": "issues.jsonl" +} \ No newline at end of file diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..807d598 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,3 @@ + +# Use bd merge for beads JSONL files +.beads/issues.jsonl merge=beads diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 364b4e3..06d8d82 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -33,33 +33,3 @@ jobs: echo "Quiet build failed; retrying verbose..." 
lake build nfp -v --no-ansi --wfail } - - - name: SOUND cache regression test (tiny fixture) - run: | - set -euo pipefail - lake exe nfp sound_cache_check --scalePow10 9 --maxTokens 0 tests/fixtures/tiny_sound_model.nfpt - - - name: Python setup - uses: actions/setup-python@v5 - with: - python-version: "3.11" - - - name: Cache HF model files - uses: actions/cache@v4 - with: - path: | - ~/.cache/huggingface - ~/.cache/torch - key: ${{ runner.os }}-hf-tiny-gpt2-v1 - - - name: Install PyTorch + Transformers (CPU) - run: | - set -euo pipefail - python -m pip install --upgrade pip - python -m pip install --index-url https://download.pytorch.org/whl/cpu torch - python -m pip install transformers numpy - - - name: Sanity check (Lean vs PyTorch forward step) - run: | - set -euo pipefail - python scripts/ci_sanity_forward_step.py --model sshleifer/tiny-gpt2 --seqLen 8 --layer 0 --pos 0 --take 16 diff --git a/.gitignore b/.gitignore index ed0dfed..8d7a9be 100644 --- a/.gitignore +++ b/.gitignore @@ -1,7 +1,4 @@ /.lake -models -*.nfpt -!tests/fixtures/*.nfpt .DS_Store logs/ sound_cache/ @@ -9,5 +6,4 @@ lean-reference-manual/ inductionhead_papers/ nfp.zip reports/ -tests/fixtures/tiny_sound_binary.nfpt nfp_agent_book/ diff --git a/AGENTS.md b/AGENTS.md index 1f5d40b..7945cf0 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -14,28 +14,23 @@ but keep the core invariants and the “no fake proofs” ethos. --- +**Use 'bd' for task tracking** + ## 0. 
Quick Start (What to run) -### Build (warnings are errors) -- `lake build -q --wfail` +### Build +- `lake build --wfail` ### Build the CLI -- `lake build nfp -q --wfail` +- `lake build nfp --wfail` ### Run the CLI (preferred integration path) One of these typically works (depending on your Lake setup): - `lake exe nfp --help` -- `./.lake/build/bin/nfp --help` - -If you add or change CLI behavior, validate at least: -- `nfp --help` -- `nfp analyze --help` -- `nfp induction --help` -- `nfp --version` (if supported) -Before you finish any change: -- `lake build -q --wfail` -- `lake build nfp -q --wfail` +### Search tips +Note: `models/` is gitignored, so `rg` will skip it unless you pass `--no-ignore` +or `-uuu` (or equivalent) when searching. --- @@ -43,24 +38,17 @@ Before you finish any change: ### 1.1 No fake proofs - **Forbidden:** `sorry` -- **Forbidden:** introducing new nontrivial axioms beyond what mathlib already uses. -- If you can’t prove a lemma as stated: - - reconsider the statement (missing assumptions? wrong generality?), - - introduce helper lemmas, - - or refactor the structure so the proof becomes natural. - - Do **not** “paper over” gaps. - -> Lean 4.26+ exploration tools (`finish?`, `try?`, `grind => finish?`, etc.) may *suggest* scripts that -> contain `sorry` (useful for debugging). Treat those suggestions as **scratch** only. -> **No `sorry` may reach the branch.** ### 1.2 Linting stays on - **Never** disable linters globally or locally. - **Forbidden:** any `set_option linter.* false` (including e.g. `linter.unnecessarySimpa`). - Fix the code/proofs instead. +- If linters warn about line length or file length, prefer principled refactors + (split modules, extract helpers) and keep docstrings with their code; avoid + squashing whitespace or formatting. ### 1.3 Clean build -- `lake build -q --wfail` must succeed. +- `lake build --wfail` must succeed. - Any warning is treated as an error: resolve it, do not ignore it. 
### 1.4 Core invariants must remain true @@ -72,14 +60,15 @@ The library’s claims rest on these being preserved (preferably with explicit l - Finiteness assumptions (`[Fintype _]`) are used intentionally and consistently ### 1.5 Trusted Code Verification (Total Soundness) -**All code** in trusted namespaces (e.g., `Nfp.Sound.*`) must be **verified**. +**All code** in trusted namespaces (see §6) must be **verified**. - **Requirement:** Every pure definition in the trusted scope must be characterized by a theorem or return a proof-carrying structure. - *Example (Bad):* `def addOne (x : Nat) := x + 1` (Unverified logic) - *Example (Good):* `def addOne (x : Nat) : { y // y > x } := ⟨x + 1, Nat.lt_succ_self _⟩` - *Example (Good):* `def addOne ...` followed immediately by `theorem addOne_gt_input ...` - **Scope:** This applies to **everything**: parsers, converters, arithmetic helpers, and bound computations. -- **IO Exception:** Low-level IO primitives (reading bytes/files) cannot be "proven" correct but must be kept **logic-free**. +- **IO Exception:** Low-level IO primitives (reading bytes/files) cannot be "proven" + correct but must be kept **logic-free**. - IO code should only read data and pass it to verified Pure code. - No mathematical transformations or complex branching allowed in IO functions. @@ -91,30 +80,30 @@ The library’s claims rest on these being preserved (preferably with explicit l - Prefer `NNReal` for masses/capacities/probabilities. - Prefer finite types (`[Fintype ι]`) where possible. -### 2.2 Keep proofs readable and local -- Prefer: `simp`, `rw`, `linarith`/`nlinarith` when appropriate, small `calc` blocks, - and restrained `aesop` usage backed by named helper lemmas. -- Avoid huge opaque “mega proofs”. If a proof is long, factor it. - -> Lean 4.26+ note for agents: -> - Use stronger automation (`simp?`, `finish?`, `grind`, `try?`) primarily as **proof discovery** tools. 
-> - The final committed proof should be **explicit, minimal, and stable** (often: a small lemma + `simp [..]` / `rw [..]`). - -### 2.3 Don’t duplicate mathlib +### 2.2 Don’t duplicate mathlib - Search for existing lemmas before inventing new ones. - If you introduce a lemma that feels “standard”, consider whether mathlib already has it (or whether it belongs in a more general file in this repo). -### 2.4 Verify, Don't Trust -- Distinguish between **witness generation** (untrusted, can use heuristics) and **verification** (trusted, must contain proofs). -- The trusted kernel should only check that a candidate witness is valid; it should not be responsible for finding it if the search is complex. +### 2.3 Verify, Don't Trust +- Distinguish between **witness generation** (untrusted, can use heuristics) and + **verification** (trusted, must contain proofs). +- The trusted kernel should only check that a candidate witness is valid; it + should not be responsible for finding it if the search is complex. + +### 2.4 Prefer principled redesigns +When forced to choose between: +- “slightly breaking but conceptually clean redesign” +- vs “preserve an awkward design forever” + +prefer the **clean redesign**, but do it consciously and document the rationale. --- ## 3. Workflow Expectations (How to make changes) ### 3.1 Before coding -- Identify the right module (see §5 Module Map). +- Identify the right module. - Skim the top docstring / main definitions in that module. - Look for existing lemmas and naming patterns to match. @@ -128,10 +117,7 @@ The library’s claims rest on these being preserved (preferably with explicit l - small lemmas, smaller proof terms, fewer global simp rules. ### 3.3 After coding -- Ensure `lake build -q --wfail` passes. -- Ensure no `sorry`. -- Ensure no linter toggles were introduced. -- If you changed module responsibilities/structure, update §5 in the same commit. +- Ensure the Definition of Done checklist is satisfied. 
--- @@ -139,7 +125,7 @@ The library’s claims rest on these being preserved (preferably with explicit l ### 4.1 Naming and organization - Prefer consistent, descriptive names: - - `_lemma`, `_iff`, `_eq`, `_mono`, `_nonneg`, `_sum_one`, etc. + - `_iff`, `_eq`, `_mono`, `_nonneg`, `_sum_one`, `_def`, etc. - Keep namespaces coherent: - attach lemmas to the structure/namespace they conceptually belong to. @@ -149,199 +135,51 @@ The library’s claims rest on these being preserved (preferably with explicit l - non-explosive, - and broadly safe. - Prefer `simp [foo]` over global simp-set growth. -- Prefer `simp?` **only to discover** what `simp [..]` should be. -### 4.3 Tactic usage -- `aesop` is allowed, but: - - avoid relying on “magic” if it makes failures hard to debug, - - extract key steps into named lemmas so proofs stay stable. +### 4.3 Proof automation discipline (Aesop-aware) -### 4.4 Refactors are allowed—but must be principled -- You may do nontrivial refactors to improve conceptual cleanliness. -- If you rename/reshape core APIs: - - update all call sites, - - leave a brief comment (or commit message rationale), - - keep the module map (§5) accurate. - -### 4.5 Lean 4.26+ proof exploration toolkit (for LLM agents) -These tools can dramatically reduce “stuck time” for lemma discovery. Use them like a **search assistant**. -They are *not* a substitute for readable proofs. - -**Allowed for exploration (scratch / development):** -- `simp?` (optionally with suggestions, if available) -- `finish?` -- `grind` / `grind?`, and `grind => finish?` -- `try?` (as a hint generator) - -**Rules for using exploration tools:** -1. **Never commit generated `sorry`.** If an exploration tactic suggests a script with `sorry`, treat it as debugging output and delete it. -2. **Never commit giant opaque scripts.** If a generated script is long: - - identify the key lemmas it used, - - create named helper lemmas, - - replace the script with a small proof built from those lemmas. -3. 
**Minimize lemma sets.** - - If `simp?` / `finish?` / `grind` suggests many lemmas, shrink to the smallest stable subset. -4. **Prefer stable shapes:** - - a short `calc` block, - - or a couple of `simp [..]` / `rw [..]` steps, - - plus one helper lemma if necessary. -5. **Keep it local.** Prefer adding lemmas to the local simp set (`simp [foo, bar]`) over tagging globally `[simp]`. - -**Agent “proof playbook” (recommended loop):** -- Step A: Try the obvious: `simp`, `simp [defs]`, `rw [defs]`, `linarith`, `nlinarith`, `ring`, `field_simp` (as appropriate). -- Step B: If stuck, run `simp?` to discover missing rewrite/simp lemmas. -- Step C: If still stuck, use `finish?` or `grind => finish?` to learn the *shape* of the proof and which lemmas matter. -- Step D: Replace the discovered script with: - - a helper lemma (named + documented) capturing the crucial step, - - and a short final proof using `simp`/`rw`/`calc`. -- Step E: Re-run `lake build -q --wfail`. +- Use `aesop?` (and `simp?`) to *discover* a proof or a minimal set of steps, then + prefer to: + - write a small explicit proof (`simp only [...]`, `constructor`, `cases`, `refine`, `exact`, etc.), or + - keep Aesop but constrain it with a local ruleset and/or targeted rules. ---- +- Avoid unconditional `by aesop` in core/trusted library code unless: + - the goal is genuinely routine, + - it stays fast and stable under small refactors, and + - it does not rely on a large implicit rule universe. + - Note: mathlib itself uses `by aesop` widely; we are stricter in trusted/core code here. -## Lean 4 performance & scalability (use when justified) +- Prefer local rules over global rules: + - If a lemma is meant to be reused by Aesop, tag it deliberately (e.g. `@[aesop safe]`) + and explain why it is safe. + - Avoid tagging “utility” lemmas as Aesop rules unless you want them participating + in search broadly. -Default: write the simplest correct thing first. 
Use the levers below only when there is a clear payoff -(hot path, large workload, or expensive work that’s often unused). Add a short comment explaining the trigger. +- If Aesop succeeds but produces a long/fragile search: + - extract a helper lemma that expresses the key reasoning step, + - prove that lemma explicitly (or with tightly scoped automation), + - then let Aesop use the helper. -### Parallelism: `Task` (opt-in, deterministic-by-construction) -Use `Task` when work is independent and CPU-heavy (e.g., per-candidate / per-layer computations). -- Prefer *pure* tasks: `Task.spawn (fun () => ...)` and later `Task.get`. - Tasks cache their result; subsequent `get`s do not recompute. (Tasks are like “opportunistic thunks”.) -- Use `IO.asTask` only when you truly need effects; remember a task is spawned each time the returned `IO` action is executed. -- Keep results deterministic: never depend on completion order; aggregate by stable keys. -- Keep granularity coarse enough to amortize scheduling overhead. -- Cancellation: pure tasks stop when dropped; `IO.asTask` tasks must check for cancellation (`IO.checkCanceled`), and can be canceled via `IO.cancel`. -- If benchmarking, note that the runtime task thread pool size is controlled by `LEAN_NUM_THREADS` (or defaults to logical CPU count). +- Keep automation predictable: + - prefer `simp only [...]` and small custom simp sets locally, + - avoid growing the global simp set and global Aesop rule set without strong reason. -### Laziness: `Thunk` / delayed computations (opt-in, for expensive-but-often-unused work) -Use `Thunk` to defer work that is expensive and frequently unused (debug traces, optional certificates, rare branches). -- Prefer `Thunk` over “manual caching”: the runtime forces at most once and caches the value. -- Force explicitly at the boundary (`Thunk.get`), not “deep inside” unrelated logic. 
-- If a thunk is forced from multiple threads, other threads will wait while one thread computes it—avoid forcing in places where blocking could deadlock. -### Compile-time / elaboration performance nudge -When proofs or declarations get large, prefer factoring them into smaller independent theorems/lemmas when it improves clarity. -Lean can elaborate theorem bodies in parallel, so smaller independent units can help the compiler do more work concurrently. - -### Transparency / unfolding control (use sparingly) -Unfolding choices affect performance of simplification and typeclass search. -- The simplifier unfolds *reducible* definitions by default; semireducible/irreducible require explicit rewrite rules or different settings. -- `opaque` definitions are not δ-reduced in the kernel; use them to prevent expensive kernel reduction when unfolding is not needed for reasoning. -- Avoid cargo-culting reducibility attributes: use `local`/`scoped` when possible, and leave a short comment about why. - -Note: Recent Lean versions changed the story around well-founded recursion transparency; don’t rely on old recipes like making well-founded recursion “reducible” via attributes. - ---- - -## 5. Module Map (Where Things Live) - -This is a *map*, not a prison. You may reshuffle if a better design emerges, -but you **must** update this list in the same commit. - -### 5.1 Core probability + mixing -- `Prob.lean` - - Probability vectors (`ProbVec`) on finite types; normalization + basic lemmas. -- `Mixer.lean` - - Row-stochastic operators (“mixers”), composition/pushforward/support tools. -- `Influence.lean` - - Influence specifications/families, capacities, scaling, and conversion into mixers. -- `Reroute/Partition.lean` - - Finite partitions + reroute planning structures. -- `Reroute/Heat.lean` - - Weighted reroute plans and induced “heat” distributions. -- `PCC.lean` - - Tracer/contribution utilities; discrete AUC / interval machinery. 
-- `Uniqueness.lean` - - `LocalSystem` for finite DAG mixing systems; uniqueness theorem(s) for tracers. -- `MixerLocalSystem.lean` - - Bridges mixers-on-DAGs to `LocalSystem` (interpreters using a topo order). -- `Appendix.lean` - - Supplemental lemmas and wrappers that don’t belong elsewhere. - -### 5.2 Interpretability / NN-oriented layers (mathematical, mostly proofs) -- `Layers.lean` - - Neural-network layer operations modeled as mixers; attribution/ablation/reachability laws. -- `Attribution.lean` - - Interpretability axioms and bridges from tracer-based notions. -- `Induction.lean` - - True induction head definitions and certification theorems (pattern + faithfulness + functional effect). -- `SignedMixer.lean` - - Signed/real-weight generalization (negative weights, affine maps, etc.). -- `Linearization.lean` - - Jacobian-based linearizations, decomposition results, deep composition/error theorems. -- `Abstraction.lean` - - Causal-consistency / intervention correspondence between “real” networks and abstract DAG views. - -### 5.3 Executable analysis & CLI surface -- `Discovery.lean` - - Executable discovery + bound computations and verification pipeline. - - May be performance-sensitive; keep proofs minimal and move them to proof modules when possible. -- `Sound/Decimal.lean` - - Exact parsing of decimal/scientific numerals into `Rat` for sound mode. -- `Sound/Activation.lean` - - Activation-derivative metadata + header parsing helpers for SOUND mode. -- `Sound/ModelHeader.lean` - - Pure header parsing helpers for SOUND metadata (e.g., LayerNorm epsilon). -- `Sound/BinaryPure.lean` - - Pure binary parsing/decoding helpers (IO-free, used by untrusted IO wrappers). -- `Sound/TextPure.lean` - - Pure text parsing helpers for model-weight norms used in sound verification. -- `Sound/CachePure.lean` - - Pure cache parsing/encoding helpers used by untrusted IO wrappers. -- `Untrusted/SoundBinary.lean` - - IO wrappers for the SOUND binary path (untrusted). 
-- `Untrusted/SoundCacheIO.lean` - - IO wrappers for the SOUND fixed-point cache (untrusted). -- `Sound/Bounds.lean` - - Umbrella import for SOUND bound utilities (exact `Rat` arithmetic, no Float). -- `Sound/Bounds/Basic.lean` - - Basic `Rat` helpers used across bounds modules. -- `Sound/Bounds/MatrixNorm.lean` - - Rat matrices, row-sum norms, and multiplicativity bounds. -- `Sound/Bounds/Gelu.lean` - - GeLU derivative envelopes (global). -- `Sound/Bounds/Exp.lean` - - exp lower bounds (scaled Taylor + squaring). -- `Sound/Bounds/Softmax.lean` - - Softmax Jacobian bounds and margin-derived weight helpers. -- `Sound/Bounds/Attention.lean` - - Attention pattern-term coefficient helpers. -- `Sound/Bounds/LayerNorm.lean` - - LayerNorm operator bounds (global/local). -- `Sound/Bounds/Portfolio.lean` - - Placeholder for portfolio combinators. -- `Sound/Bounds/Effort.lean` - - Placeholder for effort-tier records. -- `Sound/Affine.lean` - - Affine-form scaffolding for future local soundness improvements. -- `Sound/HeadCert.lean` - - Sound per-head contribution certificate structures. -- `Sound/Bridge.lean` - - Lemmas connecting `Rat`-level bounds to `SignedMixer` operator-norm bounds. -- `Sound/Cert.lean` - - Certificate/report structures and pretty-printing for SOUND-mode output. -- `Sound/IO.lean` - - Trusted IO wrappers: read inputs, call untrusted computation, and verify certificates. -- `Untrusted/SoundCompute.lean` - - IO-heavy witness generation for sound certificates (untrusted; verified by `Sound/IO`). -- `Sound/Demo.lean` - - Tiny end-to-end lemma demo bridging to `Linearization.operatorNormBound`. -- `Verification.lean` - - Executable **causal verification** via head ablation + runtime axiom checks (competence, control independence, energy matching). -- `IO.lean` - - Parsing/loading/tokenization/report formatting glue. - - **IO-only principle:** no heavy proofs; keep it as a bridge to filesystem/CLI. -- `Main.lean` - - CLI entrypoint and subcommand wiring. 
Keep it thin: - - argument parsing + calling into `Nfp.IO` / `Discovery` / `Nfp.Sound.*` reporting helpers, - - minimal logic, minimal proof content. -- `Nfp.lean` - - Top-level reexports and an axioms check (`#print axioms` / trust dashboard). - -If you introduce a new conceptual layer: -- either extend the closest existing file, -- or add a new module with a clear name + top docstring, -- and update this map in the same commit. +### 4.4 Refactors are allowed—but must be principled +- You may do nontrivial refactors to improve conceptual cleanliness. +- If you rename/reshape core APIs: + - update all call sites, + - leave a brief comment (or commit message rationale). + +### 4.5 Mathlib Module Structure (Local Baseline) +Based on `.lake/packages/mathlib` in this workspace: +- `Mathlib/` files start with `module`; `MathlibTest/` and `Archive/` files may not. +- Most `Mathlib/` files use `public import` for reexports, but plain `import` appears in core files too; meta helpers often use `public meta import` based on public API/compile-time reachability. +- Many `Mathlib/` files open with `@[expose] public section`; `public meta section` is common in meta/Util, while `public section` and `public noncomputable section` are rarer. +- Use `@[expose] public section` when broad unfolding is expected; use plain `public section` or `private` when you want bodies hidden. +- Prefer `*_def` lemmas for targeted definitional rewriting even when definitions are exposed. +- `import all` is rare, requires `allowImportAll`, and is used mainly for tests/private access; keep it local and documented. +- Keep aggregator modules (e.g., `Nfp.Core`, `Nfp.Sound`) as thin `public import` reexports. --- @@ -351,25 +189,48 @@ This repo treats “axioms creep” as a serious regression. - Do not add axioms. - Keep an eye on classical assumptions; they may be unavoidable, but should be explicit. -- Use `Nfp.lean` as the “trust dashboard” for `#print axioms` / dependency visibility. 
+- Trusted namespaces are `Nfp.Sound.*`, `Nfp.IO.Pure.*`, and `Nfp.IO.NfptPure`. + If another module is intended to be trusted, say so explicitly in its docstring + and treat it as in-scope here. +- Use `TheoremAxioms.lean` / `lake build theorem-axioms --wfail` as the trust dashboard for + `#print axioms` / dependency visibility. --- ## 7. Definition of Done (Checklist) -- [ ] `lake build -q --wfail` succeeds. +- [ ] `lake build --wfail` succeeds. - [ ] No `sorry`. - [ ] No new axioms were introduced. -- [ ] **Total Soundness:** Every pure definition in the trusted section is verified/proven. +- [ ] **Total Soundness:** Every pure definition in trusted namespaces is verified/proven. - [ ] No linters were disabled (`set_option linter.* false` is absent). - [ ] New nontrivial definitions/theorems have short, accurate docstrings. -- [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are preserved and, where possible, explicitly proved. -- [ ] §5 Module Map is accurate (updated in the same commit if needed). -- [ ] If CLI behavior changed: `lake build nfp -q --wfail` succeeds and basic `nfp ... --help` works. -- [ ] If you used Lean 4.26+ exploration tools, the final committed proof is short, explicit, and stable (no giant generated scripts). - -When forced to choose between: -- “slightly breaking but conceptually clean redesign” -- vs “preserve an awkward design forever” - -prefer the **clean redesign**, but do it consciously and document the rationale. +- [ ] Core invariants (nonnegativity, normalization, finiteness, acyclicity) are + preserved and, where possible, explicitly proved. +- [ ] If CLI behavior changed: `lake build nfp --wfail` succeeds and basic `nfp ... --help` works. + +## Landing the Plane (Session Completion) + +**When ending a work session**, you MUST complete ALL steps below. Work is NOT complete until `git push` succeeds. + +**MANDATORY WORKFLOW:** + +1. 
**File issues for remaining work** - Create issues for anything that needs follow-up +2. **Run quality gates** (if code changed) - Tests, linters, builds +3. **Update issue status** - Close finished work, update in-progress items +4. **PUSH TO REMOTE** - This is MANDATORY: + ```bash + git pull --rebase + bd sync + git push + git status # MUST show "up to date with origin" + ``` +5. **Clean up** - Clear stashes, prune remote branches +6. **Verify** - All changes committed AND pushed +7. **Hand off** - Provide context for next session + +**CRITICAL RULES:** +- Work is NOT complete until `git push` succeeds +- NEVER stop before pushing - that leaves work stranded locally +- NEVER say "ready to push when you are" - YOU must push +- If push fails, resolve and retry until it succeeds diff --git a/CHANGELOG_NOTES.md b/CHANGELOG_NOTES.md deleted file mode 100644 index eeb5be8..0000000 --- a/CHANGELOG_NOTES.md +++ /dev/null @@ -1,46 +0,0 @@ -# CHANGELOG / NOTES - -## 2025-12-24 -- Added attention pattern-term coefficients using max `W_Q/W_K` row-sum norms and a conservative - LayerNorm output magnitude bound; updated layer cert formulas and reports accordingly. -- Added `modelDim`/`headDim` metadata to sound certificates and threaded through the checker. - -## 2025-12-23 -- Added margin-derived softmax max-probability and Jacobian bounds for best-match pattern certificates. -- Added effort-indexed exp lower bounds (scaled Taylor + squaring) and wired them into best-match softmax bounds. -- Extended best-match head pattern certs with a recorded softmax Jacobian upper bound and wired untrusted computation to populate it. -- Noted that the exp lower-bound correctness is not yet formalized in Lean. -- Layer-level sound certificates now use a portfolio softmax Jacobian bound field, with margin-based - tightening available when margins are supplied (defaults remain worst-case today). 
-- Added `nfp certify --softmaxMargin/--softmaxExpEffort` flags and report fields to pass margin - evidence into layer-level softmax portfolio bounds. - -## 2025-12-22 -- Optimized induction head discovery by caching per-head induction scores and per-layer input norms, eliminating redundant pattern scans and repeated Frobenius norm computations. -- Tightened induction error bounds by using data-dependent V norms (Frobenius/op) in pattern-term calculations. -- Tightened per-head weight operator-norm bounds using Brauer/moment Gram candidates, improving pattern-term bounds. - -## 2025-12-21 -- Updated sound certificate algebra to include the attn*mlp cross term and surfaced it in output. -- Added CLAIMS.md and clarified soundness limitations and reproducibility documentation. -- Added operator-norm bound lemmas for SignedMixers, including a residual composition bound that takes external operator-norm bounds. -- Added a helper lemma to extract the `C` identity from `LayerAmplificationCert.Valid`. -- Added a bridge lemma to bound residual composition from component bounds plus the cast `C` identity. -- Added cast and `Valid`-based bridge lemmas to move from certificate validity to the residual bound. -- Added a `DeepLinearization` lemma that turns per-component operator-norm bounds into a layer residual bound. -- Added a certificate-to-Jacobian bridge lemma tying layer certificates to residual bounds (under component-bound assumptions). -- Added attention/MLP component bound lemmas that connect certificate identities to operator-norm bounds. -- Added an assumption-based bridge theorem combining component bounds into a full layer residual bound. -- Added a `LayerComponentNormAssumptions` structure to package the remaining component-norm obligations. -- Added an operator-norm bound lemma for diagonal mixers based on uniform entry bounds. -- Added an operator-norm bound lemma for `A ∘ diag(d) ∘ B` from component bounds. 
-- Replaced the MLP component assumption with a factored-Jacobian assumption and derived the MLP bound from it. -- Added `MLPFactorization` and `mlpFactors` to `DeepLinearization`, with MLP Jacobians derived from the factorization data. -- Added `mlpWinBound`/`mlpWoutBound` fields to the sound certificate and wired the bridge to use them for MLP coefficient bounds. -- Updated README tiny demo instructions to use the binary fixture directly and document the optional scripted path. -- Added a helper to patch missing `gelu_kind` in legacy `.nfpt` headers and wired it into the GPT-2 sound demo script. -- Made demo scripts fall back to `python3` when `python` is unavailable. -- Documented the GPT-2 header patch behavior in README. -- Updated the GPT-2 induction demo to patch legacy headers and use the current Python executable when generating rigorous models. -- Documented GPT-2 induction scan runtime notes in README. -- Added `--fast`, `--jobs`, and `--nfp-bin` options to the induction scan script for faster runs. diff --git a/CLAIMS.md b/CLAIMS.md index bef758a..cbffcfb 100644 --- a/CLAIMS.md +++ b/CLAIMS.md @@ -1,12 +1,36 @@ # CLAIMS This file lists what is formally proven in Lean, what is soundly checked by the trusted checker, -what is heuristic, and what is not yet proven. - -| Claim | Status | Where | -| --- | --- | --- | -| Definitions of mixers/signed mixers and linearizations (ReLU, GeLU, LayerNorm, softmax) with basic lemmas (composition, diagonality, etc.) 
| Proven in Lean | `Nfp/SignedMixer.lean`, `Nfp/Linearization.lean` | -| The sound certificate checker validates internal arithmetic consistency for layer bounds and the total amplification factor | Soundly checked (Lean) | `Nfp/Sound/Cert.lean` | -| Sound bound formulas use exact `Rat` arithmetic (LayerNorm/softmax/GeLU envelopes); witness values are produced in untrusted code and then checked | Soundly checked formulas; untrusted witnesses | `Nfp/Sound/Bounds.lean`, `Nfp/Untrusted/SoundCompute.lean` | -| Heuristic discovery and ranking of induction-style candidates | Heuristic | `Nfp/Discovery.lean`, CLI `induction` | -| End-to-end statement that certificate validity implies `||layerJacobian - I|| <= C` for Lean-defined Jacobians | Not yet proven | See `SOUNDNESS_LIMITATIONS.md` | +what is untrusted/heuristic, and what is not yet proven in the tabula rasa rewrite. + +## Proven in Lean + +- Circuit core definitions and semantics (typed circuits, evaluation, interfaces). +- Softmax-margin certificate soundness: `checkSoftmaxMarginCert` implies + `SoftmaxMarginBoundsOn`. +- Value-range certificate soundness: `checkValueRangeCert` implies `ValueRangeBounds`. +- Induction-head certificate soundness: `checkInductionHeadCert` implies + `InductionHeadCertBounds`. +- Logit-diff lower bound lemmas: `logitDiffLowerBound_le`, `logitDiffLowerBoundAt_le`, and + `logitDiffLowerBoundWeightedAt_le`. +- The head logit-diff equals the direction dot product of the head output + (`headLogitDiff_eq_direction_dot_headOutput`). +- Row-stochastic attention/one-hot bounds for induction heads and related interval lemmas. + +## Soundly checked by the trusted CLI + +- `nfp induction certify` verifies explicit induction-head certificates from a single cert + file, optionally enforcing minimum `active`, `margin`, `eps`, and logit-diff gates, and + optionally checking `prev`/`active` against a supplied token list. 
+ +## Untrusted / heuristic + +- Python helpers that generate explicit induction-head certificates from GPT-2 weights, + including optional direction search: `scripts/build_gpt2_induction_cert.py`. +- Any choice of prompts, directions, or candidate heads used by certificate generators. + +## Not yet proven + +- A verified extraction pipeline from model weights to explicit certificates. +- Model-level claims about logits or Jacobians derived from certificates. +- A full bridge from explicit head certificates to complete model semantics. diff --git a/Main.lean b/Main.lean index 8806542..e5d1eaa 100644 --- a/Main.lean +++ b/Main.lean @@ -1,1760 +1,13 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Cli -import Nfp.IO -import Nfp.Linearization -import Nfp.Untrusted.SoundCacheIO -import Nfp.Verification -import Nfp.Sound.IO -import Std.Time.Format +module -/-! -# NFP CLI: Circuit Verification Command-Line Tool +import Nfp.Cli -This is the main entry point for the NFP circuit verification tool. +public section -## Usage +/-- CLI entry point. -/ +def main (args : List String) : IO UInt32 := + Nfp.main args -Build the executable: -```bash -lake build nfp -``` - -List the available subcommands: -```bash -lake exe nfp --help -``` - -Example invocations (see README for full flag descriptions): -```bash -# Analyze a model and write a report to a file -lake exe nfp analyze model.nfpt --threshold 0.1 --output report.txt - -# Search for induction heads with diagnostics enabled -lake exe nfp induction model.nfpt --diagnostics --diagTop 5 --adaptive - -# Generate a sound-mode certificate report -lake exe nfp certify model.nfpt - -# Local (input-dependent) sound-mode certificate report -# (NFP_BINARY_V1 always embeds inputs; legacy text requires an EMBEDDINGS section.) 
-lake exe nfp certify model.nfpt --delta 1/100 - -# Sound per-head contribution bounds (global or local with --delta) -lake exe nfp head_bounds model.nfpt --delta 1/100 - -# Sound attention pattern bounds for a single head (binary only) -lake exe nfp head_pattern model.nfpt --layer 0 --head 0 --delta 1/100 --offset -1 - -# Sound induction head certificate (binary only) -lake exe nfp induction_cert model.nfpt --layer1 0 --head1 0 --layer2 1 --head2 0 \ - --coord 0 --delta 1/100 --offset1 -1 --offset2 -1 --target 42 --negative 17 - -# Instantiate RoPE bounds for a specific shape -lake exe nfp rope --seqLen 4 --pairs 8 - -# Show version -lake exe nfp --version -``` --/ - -open Cli Nfp - -private def fmtFloat (x : Float) : String := - toString x - -/-! ## Stdout logging -/ - -private structure StdoutLogCtx where - handle : IO.FS.Handle - pathRef : IO.Ref System.FilePath - pendingRef : IO.Ref Bool - timestamp : String - -initialize stdoutLogCtxRef : IO.Ref (Option StdoutLogCtx) ← IO.mkRef none - -private def sanitizeFileComponent (s : String) : String := - String.ofList <| - s.toList.map fun c => - if c.isAlphanum || c = '_' || c = '-' || c = '.' 
then c else '_' - -private def timestampNowForLog : IO String := do - let dt ← Std.Time.ZonedDateTime.now - let dateStr := s!"{dt.toPlainDate}" - let timeRaw := s!"{dt.toPlainTime}" - let timeNoFrac := (timeRaw.splitOn ".").getD 0 timeRaw - let timeStr := timeNoFrac.replace ":" "-" - return s!"{dateStr}-{timeStr}" - -private def openPendingStdoutLog : IO StdoutLogCtx := do - let logsDir : System.FilePath := "logs" - IO.FS.createDirAll logsDir - let ts ← timestampNowForLog - let path : System.FilePath := logsDir / s!"{ts}_pending.log" - let h ← IO.FS.Handle.mk path .write - let pathRef ← IO.mkRef path - let pendingRef ← IO.mkRef true - return { handle := h, pathRef := pathRef, pendingRef := pendingRef, timestamp := ts } - -private def mkTeeStream (out log : IO.FS.Stream) : IO.FS.Stream := - { flush := do out.flush; log.flush - read := fun n => out.read n - write := fun b => do out.write b; log.write b - getLine := out.getLine - putStr := fun s => do out.putStr s; log.putStr s - isTty := out.isTty } - -private def setStdoutLogName (name : String) : IO Unit := do - let some ctx ← stdoutLogCtxRef.get | return () - let pending ← ctx.pendingRef.get - if !pending then - return () - let oldPath ← ctx.pathRef.get - let logsDir : System.FilePath := "logs" - let safeName := sanitizeFileComponent name - let newPath : System.FilePath := logsDir / s!"{ctx.timestamp}_{safeName}.log" - try - IO.FS.rename oldPath newPath - ctx.pathRef.set newPath - ctx.pendingRef.set false - catch - | _ => - -- If rename fails, keep the pending filename but continue. - pure () - -private def setStdoutLogNameFromModelPath (modelPath : String) : IO Unit := do - let p : System.FilePath := modelPath - let stem := p.fileStem.getD (p.fileName.getD "model") - setStdoutLogName stem - -/-- Write a report to stdout or to a file if an output path is provided. -/ -private def writeReport (outputPath? : Option System.FilePath) (report : String) : IO Unit := do - match outputPath? 
with - | some path => - IO.FS.writeFile path report - IO.println s!"Report written to {path}" - | none => - IO.println report - -/-- Check whether a `.nfpt` file contains an `EMBEDDINGS` section before the first `LAYER`. - -For `NFP_BINARY_V1`, embeddings are always present, so this returns true. This is used to decide -whether `nfp certify --delta ...` can default to using the model file as its own input source -(so users don't have to pass `--input model.nfpt`). -/ -private def hasEmbeddingsBeforeLayers (path : System.FilePath) : IO Bool := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if s = "NFP_BINARY_V1" then - return true - -- Header: read until blank line (text format only). - let mut seenHeader : Bool := false - if s.startsWith "NFP_TEXT" then - seenHeader := true - while true do - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if !seenHeader then - if s.startsWith "NFP_TEXT" then - seenHeader := true - continue - if s.isEmpty then - break - -- After the header, `EMBEDDINGS` (if present) must appear before any layer payload. 
- while true do - let line ← h.getLine - if line.isEmpty then - return false - let s := line.trim - if s = "EMBEDDINGS" then - return true - if s.startsWith "LAYER" then - return false - return false - -private def printHeadDiagnostics (label : String) (data : PrecomputedHeadData) : IO Unit := do - let attnFrob : Float := Float.sqrt data.attentionFrobeniusNormSq - let patternRecon : Float := - (data.softmaxJacobianOpEst / data.scaleFactor) * - data.inputNorm * data.queryKeyAlignSchurNorm * data.valueOutputProjSchurNorm - let valueRecon : Float := attnFrob * data.valueOutputProjNorm - let epsRecon : Float := - if valueRecon < 1e-10 then Float.inf else patternRecon / valueRecon - let patternCached := data.patternTermBoundCached - let valueCached := data.valueTermNormCached - let epsCached := data.faithfulnessRatioCached - IO.println s!" {label} L{data.layerIdx}H{data.headIdx}:" - IO.println s!" softmaxOpBound = {fmtFloat data.softmaxJacobianOpEst}" - IO.println s!" softmaxParts = rowMaxP={fmtFloat data.softmaxRowMaxP}" - IO.println s!" = rowTrace={fmtFloat data.softmaxRowTraceBound}" - IO.println s!" = rowMoment={fmtFloat data.softmaxRowMomentBound}" - IO.println s!" = rowGersh={fmtFloat data.softmaxRowGershBound}" - IO.println s!" = rowUsed={fmtFloat data.softmaxRowBoundUsed}" - IO.println s!" = fallbackRows={data.softmaxRowsFallback}" - IO.println s!" scaleFactor = {fmtFloat data.scaleFactor}" - IO.println s!" inputNorm = {fmtFloat data.inputNorm}" - IO.println s!" inputOpBound = {fmtFloat data.inputOpBound}" - IO.println s!" qkOpBoundUsed = {fmtFloat data.queryKeyAlignSchurNorm}" - IO.println s!" qkActFrob(c) = {fmtFloat data.qkActFrobBound}" - IO.println s!" kqActFrob(c) = {fmtFloat data.kqActFrobBound}" - IO.println s!" qkActOp(c) = {fmtFloat data.qkActOpBound}" - IO.println s!" kqActOp(c) = {fmtFloat data.kqActOpBound}" - let qkActOpUbStr : String := - match data.patternBoundParts? 
with - | some parts => fmtFloat parts.qkActOpUb - | none => "n/a" - let kqActOpUbStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.kqActOpUb - | none => "n/a" - IO.println s!" qkActOpUbUsed = {qkActOpUbStr}" - IO.println s!" kqActOpUbUsed = {kqActOpUbStr}" - IO.println s!" qkActOpSource = {data.qkActOpBoundSource}" - IO.println s!" kqActOpSource = {data.kqActOpBoundSource}" - let vOpUbStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.vOpUb - | none => "n/a" - let vOpUbWOStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.vOpUbWO - | none => "n/a" - IO.println s!" vOpUbUsed = {vOpUbStr}" - IO.println s!" vOpUbWOUsed = {vOpUbWOStr}" - IO.println s!" qOpBoundAct = {fmtFloat data.qOpBoundAct}" - IO.println s!" kOpBoundAct = {fmtFloat data.kOpBoundAct}" - IO.println s!" vOpBoundAct = {fmtFloat data.vOpBoundAct}" - IO.println s!" qkFrob = {fmtFloat data.queryKeyAlignNorm}" - IO.println s!" wqOpGram = {fmtFloat data.wqOpGram}" - IO.println s!" wkOpGram = {fmtFloat data.wkOpGram}" - IO.println s!" qkFactorGram = {fmtFloat data.qkFactorGram}" - let qkDenseSchurStr : String := - match data.diag? with - | some diag => fmtFloat diag.qkDenseSchur.get - | none => "n/a" - IO.println s!" qkCandidates = denseSchur={qkDenseSchurStr}" - IO.println s!" = denseFrob={fmtFloat data.qkDenseFrob}" - IO.println s!" = denseGram={fmtFloat data.qkDenseGram}" - IO.println s!" = denseBrauer={fmtFloat data.qkDenseBrauer}" - let qkBrauerOk : String := - if data.qkDenseBrauer ≤ data.qkDenseGram then "true" else "false" - IO.println s!" = denseBrauer≤denseGram={qkBrauerOk}" - IO.println s!" = factorSchur={fmtFloat data.qkFactorSchur}" - IO.println s!" = factorFrob={fmtFloat data.qkFactorFrob}" - IO.println s!" = factorGram={fmtFloat data.qkFactorGram}" - IO.println s!" voOpBoundUsed = {fmtFloat data.valueOutputProjSchurNorm}" - IO.println s!" 
voFrob = {fmtFloat data.valueOutputProjNorm}" - IO.println s!" wvOpGram = {fmtFloat data.wvOpGram}" - IO.println s!" woOpGram = {fmtFloat data.woOpGram}" - IO.println s!" voFactorGram = {fmtFloat data.voFactorGram}" - let voDenseSchurStr : String := - match data.diag? with - | some diag => fmtFloat diag.voDenseSchur.get - | none => "n/a" - IO.println s!" voCandidates = denseSchur={voDenseSchurStr}" - IO.println s!" = denseFrob={fmtFloat data.voDenseFrob}" - IO.println s!" = denseGram={fmtFloat data.voDenseGram}" - IO.println s!" = denseBrauer={fmtFloat data.voDenseBrauer}" - let voBrauerOk : String := - if data.voDenseBrauer ≤ data.voDenseGram then "true" else "false" - IO.println s!" = denseBrauer≤denseGram={voBrauerOk}" - IO.println s!" = factorSchur={fmtFloat data.voFactorSchur}" - IO.println s!" = factorFrob={fmtFloat data.voFactorFrob}" - IO.println s!" = factorGram={fmtFloat data.voFactorGram}" - IO.println s!" attnFrob = {fmtFloat attnFrob}" - IO.println s!" patternCached = {fmtFloat patternCached}" - IO.println s!" valueCached = {fmtFloat valueCached}" - IO.println s!" εCached = {fmtFloat epsCached}" - let candFrobStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candFrob - | none => "n/a" - let candOpStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candOp - | none => "n/a" - let candOpWOStr : String := - match data.patternBoundParts? with - | some parts => fmtFloat parts.candOpWO - | none => "n/a" - IO.println s!" candFrob = {candFrobStr}" - IO.println s!" candOp = {candOpStr}" - IO.println s!" candOpWO = {candOpWOStr}" - IO.println s!" patternRecon = {fmtFloat patternRecon}" - IO.println s!" valueRecon = {fmtFloat valueRecon}" - IO.println s!" 
εRecon = {fmtFloat epsRecon}" - let qkOk : String := - if data.queryKeyAlignSchurNorm ≤ data.queryKeyAlignNorm then "true" else "false" - let voOk : String := - if data.valueOutputProjSchurNorm ≤ data.valueOutputProjNorm then "true" else "false" - IO.println s!" checks = qkUsed≤qkFrob={qkOk}, voUsed≤voFrob={voOk}" - IO.println s!" reconDiff = Δpattern={fmtFloat (patternRecon - patternCached)}" - IO.println s!" Δvalue={fmtFloat (valueRecon - valueCached)}" - IO.println s!" Δε={fmtFloat (epsRecon - epsCached)}" - -/-! ## Analyze command helpers -/ - -private structure AnalyzeArgs where - modelPath : System.FilePath - modelPathStr : String - threshold : Float - outputPath? : Option System.FilePath - verify : Bool - verbose : Bool - -private def parseAnalyzeArgs (p : Parsed) : IO (Option AnalyzeArgs) := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.1" - let some threshold := Nfp.parseFloat thresholdStr - | do - IO.eprintln s!"Error: Invalid threshold value '{thresholdStr}'" - return none - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let verify := p.hasFlag "verify" - let verbose := p.hasFlag "verbose" - return some { - modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - threshold := threshold - outputPath? := outputPath? - verify := verify - verbose := verbose - } - -private def runAnalyzeWithArgs (args : AnalyzeArgs) : IO UInt32 := do - setStdoutLogNameFromModelPath args.modelPathStr - if args.verbose then - IO.println s!"Loading model from {args.modelPathStr}..." 
- IO.println s!"Threshold: {args.threshold}" - if args.verify then - IO.println "Mode: Verification (with empirical validation)" - else - IO.println "Mode: Analysis (static bounds only)" - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - if args.verbose && model.seqLen ≠ model0.seqLen then - IO.println s!" Trimmed trailing zero embeddings: seqLen {model0.seqLen} -> {model.seqLen}" - if args.verbose then - IO.println s!"✓ Model loaded successfully" - IO.println s!" Layers: {model.numLayers}" - IO.println s!" Sequence Length: {model.seqLen}" - let vocabSize := - match model.unembedding with - | some u => u.numCols - | none => 0 - IO.println s!" Embedding Vocabulary: {vocabSize}" - IO.println s!" Model Dimension: {model.modelDim}" - IO.println "" - IO.println "Running analysis..." - let report ← if args.verify then - analyzeAndVerify model args.modelPathStr args.threshold none - else - analyzeModel model args.modelPathStr args.threshold - writeReport args.outputPath? (toString report) - return 0 - -/-- Run the analyze command - perform circuit analysis. -/ -def runAnalyze (p : Parsed) : IO UInt32 := do - let some args ← parseAnalyzeArgs p - | return 1 - runAnalyzeWithArgs args - -/-! ## Induction command helpers -/ - -private structure InductionArgs where - modelPath : System.FilePath - modelPathStr : String - correctOpt : Option Nat - incorrectOpt : Option Nat - minEffect : Float - verify : Bool - verbose : Bool - diagnostics : Bool - adaptive : Bool - targetSlack : Float - maxUpgrades : Nat - minRelImprove : Float - krylovSteps : Nat - adaptiveScope : Nfp.AdaptiveScope - adaptiveScopeStr : String - diagTop : Nat - -private def parseInductionArgs (p : Parsed) : IO (Option InductionArgs) := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let correctOpt := p.flag? "correct" |>.map (·.as! 
Nat) - let incorrectOpt := p.flag? "incorrect" |>.map (·.as! Nat) - let thresholdStr := p.flag? "threshold" |>.map (·.as! String) |>.getD "0.0" - let verify := p.hasFlag "verify" - let verbose := p.hasFlag "verbose" - let diagnostics := p.hasFlag "diagnostics" - let adaptive := p.hasFlag "adaptive" - let targetSlackStr := p.flag? "targetSlack" |>.map (·.as! String) |>.getD "8.0" - let maxUpgrades := p.flag? "maxUpgrades" |>.map (·.as! Nat) |>.getD 120 - let minRelImproveStr := p.flag? "minRelImprove" |>.map (·.as! String) |>.getD "0.01" - let krylovSteps := p.flag? "krylovSteps" |>.map (·.as! Nat) |>.getD 2 - let adaptiveScopeStr := p.flag? "adaptiveScope" |>.map (·.as! String) |>.getD "layernorm" - let diagTop := p.flag? "diagTop" |>.map (·.as! Nat) |>.getD 5 - let some minEffect := Nfp.parseFloat thresholdStr - | do - IO.eprintln s!"Error: Invalid threshold value '{thresholdStr}'" - return none - let (targetSlack, minRelImprove, adaptiveScope) ← - if adaptive then - let some targetSlack := Nfp.parseFloat targetSlackStr - | do - IO.eprintln s!"Error: Invalid --targetSlack '{targetSlackStr}'" - return none - let some minRelImprove := Nfp.parseFloat minRelImproveStr - | do - IO.eprintln s!"Error: Invalid --minRelImprove '{minRelImproveStr}'" - return none - let adaptiveScope? : Option Nfp.AdaptiveScope := - match adaptiveScopeStr.trim.toLower with - | "layernorm" => some .layernorm - | "all" => some .all - | _ => none - let some adaptiveScope := adaptiveScope? 
- | do - IO.eprintln <| - s!"Error: Invalid --adaptiveScope '{adaptiveScopeStr}' " ++ - "(expected layernorm|all)" - return none - pure (targetSlack, minRelImprove, adaptiveScope) - else - pure (8.0, 0.01, Nfp.AdaptiveScope.layernorm) - return some { - modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - correctOpt := correctOpt - incorrectOpt := incorrectOpt - minEffect := minEffect - verify := verify - verbose := verbose - diagnostics := diagnostics - adaptive := adaptive - targetSlack := targetSlack - maxUpgrades := maxUpgrades - minRelImprove := minRelImprove - krylovSteps := krylovSteps - adaptiveScope := adaptiveScope - adaptiveScopeStr := adaptiveScopeStr - diagTop := diagTop - } - -private def deriveInductionTarget - (model : ConcreteModel) - (W_U : ConcreteMatrix) - (correctOpt incorrectOpt : Option Nat) : Option TargetDirection := - match correctOpt, incorrectOpt with - | some correct, some incorrect => - some (TargetDirection.fromLogitDiff W_U correct incorrect) - | some _, none | none, some _ => - none - | none, none => - match model.inputTokens with - | some _ => TargetDirection.fromInductionHistory model - | none => some (TargetDirection.fromLogitDiff W_U 0 1) - -private def printInductionSearchIntro (minEffect : Float) : IO Unit := do - IO.println s!"Searching for heads with Effect > {minEffect}... 
(heuristic)" - IO.println "Ranking: highest mechScore (= kComp·indScore·prevTok) first" - IO.println " Tie-break: Effect, kComp, δ, then lowest Error" - IO.println " circuitScore = Effect · mechScore" - IO.println " Effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)" - IO.println " kComp_raw = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)" - IO.println " kComp = kComp_raw - 1/√modelDim" - -private def printInductionCandidate (verbose : Bool) (h : HeuristicInductionHead) : IO Unit := do - let c := h.candidate - let mechScore := c.kComp * c.inductionScore * c.prevTokenStrength - let circuitScore := h.effect * mechScore - if verbose then - IO.println <| - s!"L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx} | " ++ - s!"Mech: {mechScore} | Circuit: {circuitScore} | " ++ - s!"Effect: {h.effect} | kComp: {c.kComp} | " ++ - s!"indScore: {c.inductionScore} | prevTok: {c.prevTokenStrength} | " ++ - s!"δ: {h.delta} | " ++ - s!"Error: {c.combinedError} | " ++ - s!"‖X₂‖_F: {h.layer2InputNorm} | " ++ - s!"‖ln₁(X₂)‖_F: {h.layer2Ln1InputNorm} " ++ - s!"(ε₁={c.patternBound1}, ε₂={c.patternBound2})" - else - IO.println <| - s!"L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx} | " ++ - s!"Mech: {mechScore} | Effect: {h.effect} | " ++ - s!"kComp: {c.kComp} | " ++ - s!"indScore: {c.inductionScore} | prevTok: {c.prevTokenStrength} | " ++ - s!"Error: {c.combinedError} | " ++ - s!"‖X₂‖_F: {h.layer2InputNorm}" - -private def printInductionCandidates (heads : Array HeuristicInductionHead) (verbose : Bool) : - IO (Array HeuristicInductionHead) := do - let top := heads.take 50 - IO.println s!"Top Induction Head Pairs by mechScore (top {top.size} of {heads.size})" - for h in top do - printInductionCandidate verbose h - return top - -private def buildAdaptiveScheduler - (cache : PrecomputedCache) - (args : InductionArgs) : Option AdaptiveSchedulerResult := - if args.adaptive && (args.verbose || args.diagnostics) then - let cfg : Nfp.AdaptiveSchedulerConfig := - { targetSlack := args.targetSlack 
- maxUpgrades := args.maxUpgrades - minRelImprove := args.minRelImprove - krylovSteps := args.krylovSteps - scope := args.adaptiveScope - debugMonotone := args.diagnostics } - some (Nfp.runAdaptiveScheduler cache cfg none) - else - none - -private def printAdaptiveSchedulerSteps - (sched : AdaptiveSchedulerResult) - (args : InductionArgs) : IO Unit := do - IO.println "" - IO.println "ADAPTIVE SCHEDULER" - IO.println <| - s!" targetSlack={fmtFloat args.targetSlack} maxUpgrades={args.maxUpgrades} " ++ - s!"minRelImprove={fmtFloat args.minRelImprove} krylovSteps={args.krylovSteps} " ++ - s!"scope={args.adaptiveScopeStr}" - for s in sched.steps do - IO.println <| - match s.kind with - | .ubTier => - s!" it={s.iter} L{s.layerIdx}: tier {s.tierFrom}->{s.tierTo} " ++ - s!"ub {fmtFloat s.ubBefore}->{fmtFloat s.ubAfter} " ++ - s!"lb≈{fmtFloat s.lb} (k={s.kTo}) " ++ - s!"slack {fmtFloat s.slackBefore}->{fmtFloat s.slackAfter}" - | .lbSteps => - s!" it={s.iter} L{s.layerIdx}: lb-steps {s.kFrom}->{s.kTo} " ++ - s!"ub {fmtFloat s.ubBefore} " ++ - s!"lb≈{fmtFloat s.lb} slack {fmtFloat s.slackBefore}->{fmtFloat s.slackAfter}" - -private def printInductionDiagnostics - (cache : PrecomputedCache) - (top : Array HeuristicInductionHead) - (args : InductionArgs) - (sched? : Option AdaptiveSchedulerResult) : IO (Option UInt32) := do - IO.println "" - let diagN := min args.diagTop top.size - if args.adaptive then - let some sched := sched? - | do - IO.eprintln "Error: internal scheduler state missing." - return some 1 - IO.println "" - IO.println "LAYER NORM DIAGNOSTICS (LOWER vs UPPER)" - IO.println " (lb is rigorous lower bound via Krylov steps; ub is rigorous upper bound)" - for l in [:sched.ub.size] do - let ub := sched.ub[l]! - let lb := sched.lb[l]! - let ratio : Float := if lb > 1e-12 then ub / lb else Float.inf - let tier := sched.tier[l]! - let k := sched.lbK.getD l 0 - IO.println <| - s!" 
L{l}: lb≈{fmtFloat lb} ub={fmtFloat ub} " ++ - s!"ub/lb={fmtFloat ratio} tier={tier} k={k}" - let x := cache.forwardResult.getLayerInput l - let y := cache.forwardResult.getPostAttnResidual l - let ln1p := cache.model.ln1Params l - let ln2p := cache.model.ln2Params l - let ln1 := ConcreteMatrix.layerNormRowwiseOpDiag x ln1p.gamma - let ln2 := ConcreteMatrix.layerNormRowwiseOpDiag y ln2p.gamma - IO.println <| - s!" ln1: op≈{fmtFloat (ln1.gammaMaxAbs * ln1.maxInvStd)} " ++ - s!"(γmax≈{fmtFloat ln1.gammaMaxAbs}, invStdMax≈{fmtFloat ln1.maxInvStd} " ++ - s!"@r={ln1.maxInvStdRow}, varMin≈{fmtFloat ln1.minVar} @r={ln1.minVarRow})" - IO.println <| - s!" ln2: op≈{fmtFloat (ln2.gammaMaxAbs * ln2.maxInvStd)} " ++ - s!"(γmax≈{fmtFloat ln2.gammaMaxAbs}, invStdMax≈{fmtFloat ln2.maxInvStd} " ++ - s!"@r={ln2.maxInvStdRow}, varMin≈{fmtFloat ln2.minVar} @r={ln2.minVarRow})" - if hm : l < cache.model.mlps.size then - let mlp := cache.model.mlps[l]'hm - let y := cache.forwardResult.getPostAttnResidual l - let ln2Bound := cache.model.ln2OpBound l y - let (attnPart, maxHeadIdx, maxHeadContrib, maxHeadValue, maxHeadPattern) : - (Float × Nat × Float × Float × Float) := Id.run do - let layerData := cache.headData.getD l #[] - let mut a : Float := 0.0 - let mut bestIdx : Nat := 0 - let mut best : Float := 0.0 - let mut bestValue : Float := 0.0 - let mut bestPattern : Float := 0.0 - let mut idx : Nat := 0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := d.ln1OpBound * (attnOpUb * d.valueOutputProjSchurNorm) - let inputs : Nfp.PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - 
kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := d.ln1OpBound * (Nfp.computePatternTermBound inputs) - let contrib := valueTermUb + patternTermUb - a := a + contrib - if contrib > best then - bestIdx := idx - best := contrib - bestValue := valueTermUb - bestPattern := patternTermUb - idx := idx + 1 - (a, bestIdx, best, bestValue, bestPattern) - let mlpTotal : Float := max 0.0 (ub - attnPart) - let mlpOnly : Float := - if 1.0 + attnPart > 1e-12 then - mlpTotal / (1.0 + attnPart) - else - mlpTotal - let cross : Float := max 0.0 (mlpTotal - mlpOnly) - IO.println <| - s!" contrib: attn≈{fmtFloat attnPart} mlpOnly≈{fmtFloat mlpOnly} " ++ - s!"cross≈{fmtFloat cross} mlpTotal≈{fmtFloat mlpTotal} " ++ - s!"(maxHead=H{maxHeadIdx}≈{fmtFloat maxHeadContrib}, " ++ - s!"value≈{fmtFloat maxHeadValue}, pattern≈{fmtFloat maxHeadPattern})" - let mlpJacLegacy : Float := - let denom := ln2Bound * (1.0 + attnPart) - if denom > 1e-12 then - mlpTotal / denom - else - Float.nan - let geluDeriv := cache.forwardResult.getMlpGeluDeriv l - let diag := computeMLPOpAbsSchurDiag mlp geluDeriv - let chosen : Float := min diag.absSchur mlpJacLegacy - IO.println <| - s!" mlpDiag(L{l}): absSchur={fmtFloat diag.absSchur} " ++ - s!"legacy≈{fmtFloat mlpJacLegacy} chosen≈{fmtFloat chosen} " ++ - s!"dMax≈{fmtFloat diag.dMax}" - else - IO.println "LAYER NORM DIAGNOSTICS (PI vs rigorous)" - IO.println " (PI is diagnostics-only; rigorous is used for bounds)" - for l in [:cache.layerNormBounds.size] do - let ub := cache.layerNormBounds[l]! 
- let pi := estimateAttentionLayerNormHeuristicPI cache.model cache.forwardResult l true - let ratio : Float := if pi > 1e-12 then ub / pi else Float.inf - IO.println s!" L{l}: PI≈{fmtFloat pi} ub={fmtFloat ub} ub/PI={fmtFloat ratio}" - -- Rectangular Gram diagnostics for MLP weights (layer 0 only). - if h0 : 0 < cache.model.mlps.size then - let mlp0 : ConcreteMLPLayer := cache.model.mlps[0]'h0 - let dIn := mlp0.W_in.opNormUpperBoundRectGramDiag - let dOut := mlp0.W_out.opNormUpperBoundRectGramDiag - let chosenMsg (d : _) : String := - if d.usedGram then "chosen=signedGram" - else if d.usedAbsGram then "chosen=absGram" - else "chosen=cheap" - let signedGramMsg (d : _) : String := - if !d.signedGramEnabled then "signedGram=disabled" - else if d.gramDim > d.maxGramDimCap then "signedGram=skipped(maxGramDim cap)" - else if d.skippedGram then "signedGram=skipped(cost guard)" - else if d.computedGram then "signedGram=computed" - else "signedGram=not-attempted" - let absGramMsg (d : _) : String := - if !d.computedAbsGram then "absGram=disabled" - else if d.usedAbsGram then "absGram=chosen" - else "absGram=computed" - IO.println "" - IO.println "RECT-GRAM DIAGNOSTICS (MLP layer 0 weights)" - IO.println <| - s!" W_in: usedGram={dIn.usedGram} usedAbsGram={dIn.usedAbsGram} " ++ - s!"computedGram={dIn.computedGram} computedAbsGram={dIn.computedAbsGram} " ++ - s!"skippedGram={dIn.skippedGram}" - IO.println <| - s!" gramDim={dIn.gramDim} maxGramDimCap={dIn.maxGramDimCap} " ++ - s!"signedGramEnabled={dIn.signedGramEnabled}" - IO.println <| - s!" gramCost={dIn.gramCost} gramCostLimit={dIn.gramCostLimit} " ++ - s!"{chosenMsg dIn} {signedGramMsg dIn} {absGramMsg dIn}" - IO.println <| - s!" frob={fmtFloat dIn.frobBound} oneInf={fmtFloat dIn.oneInfBound} " ++ - s!"opBound={fmtFloat dIn.opBound}" - IO.println <| - s!" λ_abs_gersh={fmtFloat dIn.lambdaAbsGersh} " ++ - s!"λ_abs_brauer={fmtFloat dIn.lambdaAbsBrauer}" - IO.println s!" λ_gersh={fmtFloat dIn.lambdaGersh}" - IO.println s!" 
λ_brauer={fmtFloat dIn.lambdaBrauer}" - IO.println s!" λ_moment={fmtFloat dIn.lambdaMoment}" - IO.println s!" λ_used={fmtFloat dIn.lambdaUsed}" - IO.println <| - s!" W_out: usedGram={dOut.usedGram} usedAbsGram={dOut.usedAbsGram} " ++ - s!"computedGram={dOut.computedGram} computedAbsGram={dOut.computedAbsGram} " ++ - s!"skippedGram={dOut.skippedGram}" - IO.println <| - s!" gramDim={dOut.gramDim} maxGramDimCap={dOut.maxGramDimCap} " ++ - s!"signedGramEnabled={dOut.signedGramEnabled}" - IO.println <| - s!" gramCost={dOut.gramCost} gramCostLimit={dOut.gramCostLimit} " ++ - s!"{chosenMsg dOut} {signedGramMsg dOut} {absGramMsg dOut}" - IO.println <| - s!" frob={fmtFloat dOut.frobBound} oneInf={fmtFloat dOut.oneInfBound} " ++ - s!"opBound={fmtFloat dOut.opBound}" - IO.println <| - s!" λ_abs_gersh={fmtFloat dOut.lambdaAbsGersh} " ++ - s!"λ_abs_brauer={fmtFloat dOut.lambdaAbsBrauer}" - IO.println s!" λ_gersh={fmtFloat dOut.lambdaGersh}" - IO.println s!" λ_brauer={fmtFloat dOut.lambdaBrauer}" - IO.println s!" λ_moment={fmtFloat dOut.lambdaMoment}" - IO.println s!" λ_used={fmtFloat dOut.lambdaUsed}" - IO.println "" - IO.println s!"DIAGNOSTICS (ε decomposition) for top-{diagN} candidates" - for h in (top.take diagN) do - let c := h.candidate - IO.println s!"Candidate L{c.layer1Idx}H{c.head1Idx} -> L{c.layer2Idx}H{c.head2Idx}" - match cache.getHeadData c.layer1Idx c.head1Idx, - cache.getHeadData c.layer2Idx c.head2Idx with - | some d1, some d2 => - printHeadDiagnostics "Head1" d1 - printHeadDiagnostics "Head2" d2 - let ε1 := d1.faithfulnessRatioCached - let ε2 := d2.faithfulnessRatioCached - let combinedRecon := ε1 + ε2 + ε1 * ε2 - IO.println " Combined:" - IO.println s!" ε1 = {fmtFloat ε1} ε2 = {fmtFloat ε2}" - IO.println s!" combinedError = (1+ε1)(1+ε2)-1 = {fmtFloat combinedRecon}" - IO.println s!" combinedErrorCached = {fmtFloat c.combinedError}" - IO.println s!" 
reconDiff = {fmtFloat (combinedRecon - c.combinedError)}" - | _, _ => - IO.println " (diagnostics unavailable: missing cached head data)" - return none - -private def runInductionVerification - (model : ConcreteModel) - (heads : Array HeuristicInductionHead) - (correctOpt : Option Nat) : IO (Option UInt32) := do - IO.println "" - IO.println "Causal Verification (Head Ablation)" - IO.println "Metric: Δ = logit(target) - logit(top non-target) at last position" - IO.println "Impact = Δ_base - Δ_ablated" - IO.println "" - let targetToken? : Option Nat := - match correctOpt with - | some t => some t - | none => inductionTargetTokenFromHistory model - let some targetToken := targetToken? - | do - IO.eprintln "Error: Cannot infer induction target token (need TOKENS or --correct)." - return some 2 - match VerificationContext.build model targetToken {} with - | .error msg => - IO.eprintln s!"Error: {msg}" - return some 2 - | .ok ctx => - let verifyTop := heads.take 10 - IO.println s!"Top-{verifyTop.size} ablation checks (ranked by mechScore):" - IO.println s!" target={ctx.targetToken} | neg={ctx.negativeToken}" - IO.println s!" Δ_base={ctx.baseDelta}" - IO.println "" - IO.println "Rank | Candidate | Base Δ | Ablated Δ | Impact (Logits) | RelScore | \ - Control Impact | Axioms Verified?" 
- IO.println "-----|-----------|--------|----------|----------------|----------|\ - --------------|----------------" - let fmtOpt (x : Option Float) : String := - match x with - | some v => toString v - | none => "undef" - let mut rank : Nat := 1 - for h in verifyTop do - let c := h.candidate - let candHeads : Array HeadRef := #[ - { layerIdx := c.layer1Idx, headIdx := c.head1Idx }, - { layerIdx := c.layer2Idx, headIdx := c.head2Idx } - ] - let row := verifyCircuit ctx candHeads - let axiomsStr := - if row.axioms.verified then - "yes" - else - let reasons := String.intercalate "; " row.axioms.failures.toList - if reasons.isEmpty then "no" else s!"no ({reasons})" - IO.println s!"{rank} | {row.candidateLabel} | {row.baseDelta} | \ - {fmtOpt row.ablatedDelta} | {fmtOpt row.impact} | {fmtOpt row.relScore} | \ - {fmtOpt row.controlImpact} | {axiomsStr}" - rank := rank + 1 - return none - -/-- Run the induction command - discover induction heads ranked by effectiveness. -/ -def runInduction (p : Parsed) : IO UInt32 := do - let some args ← parseInductionArgs p - | return 1 - setStdoutLogNameFromModelPath args.modelPathStr - IO.println "Loading model..." - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - if args.verbose && model.seqLen ≠ model0.seqLen then - IO.println s!" Trimmed trailing zero embeddings: seqLen {model0.seqLen} -> {model.seqLen}" - match model.unembedding with - | none => - IO.eprintln "Error: Model is missing unembedding matrix (needed for logit directions)." - return 1 - | some W_U => - let target? := deriveInductionTarget model W_U args.correctOpt args.incorrectOpt - let some target := target? - | do - if args.correctOpt.isSome ∨ args.incorrectOpt.isSome then - IO.eprintln "Error: Use both --correct and --incorrect (or neither to auto-detect)." 
- return 1 - else - IO.eprintln "No valid induction target could be derived from TOKENS \ - (no prior repetition of last token)." - IO.eprintln "Hint: pass --correct/--incorrect to override, or export a prompt \ - where the last token repeats." - return 2 - if args.correctOpt.isNone ∧ args.incorrectOpt.isNone ∧ model.inputTokens.isNone then - IO.println "Warning: No TOKENS section found; using default target logit_diff(0-1)." - IO.println "Hint: export with TOKENS or pass --correct/--incorrect." - IO.println s!"Target: {target.description}" - printInductionSearchIntro args.minEffect - let buildLayerNormBounds := args.diagnostics && (!args.adaptive) - let (heads, cache) := - findHeuristicInductionHeadsWithCache model target args.minEffect - (minInductionScore := 0.01) - (buildLayerNormBounds := buildLayerNormBounds) - (storeDiagnostics := args.diagnostics) - let top ← printInductionCandidates heads args.verbose - let sched? := buildAdaptiveScheduler cache args - if args.adaptive && args.verbose then - match sched? with - | some sched => printAdaptiveSchedulerSteps sched args - | none => pure () - if args.diagnostics then - let err? ← printInductionDiagnostics cache top args sched? - if let some code := err? then - return code - if args.verify then - let err? ← runInductionVerification model heads args.correctOpt - if let some code := err? then - return code - return 0 - -/-! ## SOUND command helpers -/ - -private structure CertifyArgs where - modelPath : System.FilePath - modelPathStr : String - inputPath? : Option System.FilePath - soundnessBits : Nat - partitionDepth : Nat - deltaFlag? : Option String - deltaStr : String - softmaxMarginStr : String - softmaxExpEffort : Nat - outputPath? : Option System.FilePath - -private def parseCertifyArgs (p : Parsed) : CertifyArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let soundnessBits := p.flag? 
"soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let partitionDepth := p.flag? "partitionDepth" |>.map (·.as! Nat) |>.getD 0 - let deltaFlag? := p.flag? "delta" |>.map (·.as! String) - let deltaStr := deltaFlag?.getD "0" - let softmaxMarginStr := p.flag? "softmaxMargin" |>.map (·.as! String) |>.getD "0" - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - inputPath? := inputPath? - soundnessBits := soundnessBits - partitionDepth := partitionDepth - deltaFlag? := deltaFlag? - deltaStr := deltaStr - softmaxMarginStr := softmaxMarginStr - softmaxExpEffort := softmaxExpEffort - outputPath? := outputPath? } - -private def runCertifyAction (args : CertifyArgs) : ExceptT String IO Nfp.Sound.ModelCert := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{args.deltaStr}': {e}" - let softmaxMarginLowerBound ← - match Nfp.Sound.parseRat args.softmaxMarginStr with - | .ok r => pure r - | .error e => throw s!"invalid --softmaxMargin '{args.softmaxMarginStr}': {e}" - /- If `--input` is omitted but `--delta` is provided, try to use `modelPath` as the input file - (for `.nfpt` exports that embed `EMBEDDINGS` in the same file). This keeps `nfp certify` - ergonomic without changing the default behavior when `--delta` is absent. -/ - let inputPath? : Option System.FilePath ← - match args.inputPath? with - | some path => pure (some path) - | none => - match args.deltaFlag? with - | none => pure none - | some _ => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "local certification requested via --delta, but the model file has no \ -EMBEDDINGS section before the first LAYER (legacy text format). 
Pass --input \ -containing EMBEDDINGS or omit --delta for global certification." - let cert ← ExceptT.mk <| - Nfp.Sound.certifyModelFile args.modelPath args.soundnessBits - (inputPath? := inputPath?) (inputDelta := delta) (partitionDepth := args.partitionDepth) - (softmaxMarginLowerBound := softmaxMarginLowerBound) - (softmaxExpEffort := args.softmaxExpEffort) - return cert - -private structure HeadBoundsArgs where - modelPath : System.FilePath - inputPath? : Option System.FilePath - deltaFlag? : Option String - deltaStr : String - soundnessBits : Nat - scalePow10 : Nat - outputPath? : Option System.FilePath - -private def parseHeadBoundsArgs (p : Parsed) : HeadBoundsArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let deltaFlag? := p.flag? "delta" |>.map (·.as! String) - let deltaStr := deltaFlag?.getD "0" - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let outputPath? := p.flag? "output" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - inputPath? := inputPath? - deltaFlag? := deltaFlag? - deltaStr := deltaStr - soundnessBits := soundnessBits - scalePow10 := scalePow10 - outputPath? := outputPath? } - -private def formatHeadBoundsLocal - (heads : Array Nfp.Sound.HeadLocalContributionCert) - (delta : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath) - (modelPath : System.FilePath) : String := - let header := - "SOUND per-head bounds (local): " ++ - s!"delta={delta}, soundnessBits={soundnessBits}, input={inputPath?.getD modelPath}\n" - let body := - heads.foldl (fun acc h => - acc ++ - s!"Layer {h.layerIdx} Head {h.headIdx}: " ++ - s!"ln1MaxAbsGamma={h.ln1MaxAbsGamma}, " ++ - s!"ln1VarLB={h.ln1VarianceLowerBound}, " ++ - s!"ln1Bound={h.ln1Bound}, " ++ - s!"wqOp={h.wqOpBound}, wkOp={h.wkOpBound}, " ++ - s!"qk={h.qkFactorBound}, " ++ - s!"wvOp={h.wvOpBound}, woOp={h.woOpBound}, " ++ - s!"attn={h.attnJacBound}\n") "" - header ++ body - -private def formatHeadBoundsGlobal - (heads : Array Nfp.Sound.HeadContributionCert) - (scalePow10 : Nat) : String := - let header := - s!"SOUND per-head bounds (weight-only): scalePow10={scalePow10}\n" - let body := - heads.foldl (fun acc h => - acc ++ - s!"Layer {h.layerIdx} Head {h.headIdx}: " ++ - s!"wqOp={h.wqOpBound}, wkOp={h.wkOpBound}, " ++ - s!"wvOp={h.wvOpBound}, woOp={h.woOpBound}, " ++ - s!"qk={h.qkFactorBound}, vo={h.voFactorBound}\n") "" - header ++ body - -private def runHeadBoundsAction (args : HeadBoundsArgs) : ExceptT String IO String := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{args.deltaStr}': {e}" - let useLocal := (args.inputPath?.isSome || args.deltaFlag?.isSome) - if useLocal then - let inputPath? : Option System.FilePath ← - match args.inputPath? with - | some path => pure (some path) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "local head bounds requested via --delta, but the model file has no \ -EMBEDDINGS section before the first LAYER (legacy text format). Pass --input \ -containing EMBEDDINGS or omit --delta for global head bounds." - let heads ← ExceptT.mk <| - Nfp.Sound.certifyHeadBoundsLocal args.modelPath - (inputPath? := inputPath?) 
(inputDelta := delta) (soundnessBits := args.soundnessBits) - return formatHeadBoundsLocal heads delta args.soundnessBits inputPath? args.modelPath - else - let heads ← ExceptT.mk <| - Nfp.Sound.certifyHeadBounds args.modelPath (scalePow10 := args.scalePow10) - return formatHeadBoundsGlobal heads args.scalePow10 - -private structure HeadPatternArgs where - modelPath : System.FilePath - layerIdx : Nat - headIdx : Nat - offset : Int - soundnessBits : Nat - softmaxExpEffort : Nat - tightPatternLayers : Nat - tightPattern : Bool - perRowPatternLayers : Nat - bestMatch : Bool - sweep : Bool - queryPos? : Option Nat - inputPath? : Option System.FilePath - deltaStr : String - maxSeqLen : Nat - outputPath? : Option System.FilePath - -private def parseHeadPatternArgs (p : Parsed) : HeadPatternArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let layerIdx := p.flag? "layer" |>.map (·.as! Nat) |>.getD 0 - let headIdx := p.flag? "head" |>.map (·.as! Nat) |>.getD 0 - let offset := p.flag? "offset" |>.map (·.as! Int) |>.getD (-1) - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) - |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let tightPatternLayers? := p.flag? "tightPatternLayers" |>.map (·.as! Nat) - let tightPatternLayers := tightPatternLayers?.getD 1 - let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome - let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 - let bestMatch := p.hasFlag "bestMatch" - let sweep := p.hasFlag "sweep" - let queryPos? := p.flag? "queryPos" |>.map (·.as! Nat) - let inputPath? := p.flag? "input" |>.map (·.as! String) |>.map (fun s => ⟨s⟩) - let deltaStr := p.flag? "delta" |>.map (·.as! String) |>.getD "0" - let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 256 - let outputPath? := p.flag? "output" |>.map (·.as! 
String) |>.map (fun s => ⟨s⟩) - { modelPath := ⟨modelPathStr⟩ - layerIdx := layerIdx - headIdx := headIdx - offset := offset - soundnessBits := soundnessBits - softmaxExpEffort := softmaxExpEffort - tightPatternLayers := tightPatternLayers - tightPattern := tightPattern - perRowPatternLayers := perRowPatternLayers - bestMatch := bestMatch - sweep := sweep - queryPos? := queryPos? - inputPath? := inputPath? - deltaStr := deltaStr - maxSeqLen := maxSeqLen - outputPath? := outputPath? } - -private def formatHeadPatternBestMatchSweep - (layerIdx headIdx : Nat) - (offset : Int) - (certs : Array Nfp.Sound.HeadBestMatchPatternCert) : String := - let header := - "SOUND head pattern sweep (best-match): " ++ - s!"layer={layerIdx}, head={headIdx}, offset={offset}\n" - let body := - certs.foldl (fun acc cert => - acc ++ - s!"queryPos={cert.queryPos} targetTok={cert.targetToken} " ++ - s!"marginLB={cert.marginLowerBound} " ++ - s!"weightLB={cert.bestMatchWeightLowerBound}\n") "" - header ++ body - -private def formatHeadPatternBestMatch - (cert : Nfp.Sound.HeadBestMatchPatternCert) : String := - "SOUND head pattern (best-match): " ++ - s!"layer={cert.layerIdx}, head={cert.headIdx}, " ++ - s!"offset={cert.targetOffset}, queryPos={cert.queryPos}\n" ++ - s!"seqLen={cert.seqLen}, targetTok={cert.targetToken}, " ++ - s!"bestMatchLogitLB={cert.bestMatchLogitLowerBound}, " ++ - s!"bestNonmatchLogitUB={cert.bestNonmatchLogitUpperBound}\n" ++ - s!"marginLB={cert.marginLowerBound}, " ++ - s!"bestMatchWeightLB={cert.bestMatchWeightLowerBound}, " ++ - s!"softmaxExpEffort={cert.softmaxExpEffort}\n" - -private def formatHeadPatternLocal - (cert : Nfp.Sound.HeadPatternCert) : String := - "SOUND head pattern (local): " ++ - s!"layer={cert.layerIdx}, head={cert.headIdx}, offset={cert.targetOffset}\n" ++ - s!"seqLen={cert.seqLen}, " ++ - s!"targetCountLB={cert.targetCountLowerBound}, " ++ - s!"targetLogitLB={cert.targetLogitLowerBound}, " ++ - s!"otherLogitUB={cert.otherLogitUpperBound}\n" ++ - 
s!"marginLB={cert.marginLowerBound}, " ++ - s!"targetWeightLB={cert.targetWeightLowerBound}, " ++ - s!"softmaxExpEffort={cert.softmaxExpEffort}\n" - -private def runHeadPatternAction (args : HeadPatternArgs) : ExceptT String IO String := do - let delta ← - match Nfp.Sound.parseRat args.deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{args.deltaStr}': {e}" - let inputPath? : Option System.FilePath ← - match args.inputPath? with - | some path => pure (some path) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers args.modelPath - if hasEmbeddings then - pure (some args.modelPath) - else - throw <| - "head pattern bounds require EMBEDDINGS; pass --input for legacy text models." - if args.bestMatch then - if args.sweep then - let certs ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternBestMatchLocalSweep args.modelPath args.layerIdx args.headIdx - (inputPath? := inputPath?) (inputDelta := delta) - (soundnessBits := args.soundnessBits) (targetOffset := args.offset) - (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) - (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (softmaxExpEffort := args.softmaxExpEffort) - return formatHeadPatternBestMatchSweep args.layerIdx args.headIdx args.offset certs - else - let cert ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternBestMatchLocal args.modelPath args.layerIdx args.headIdx - (queryPos? := args.queryPos?) (inputPath? := inputPath?) 
- (inputDelta := delta) (soundnessBits := args.soundnessBits) - (targetOffset := args.offset) (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (softmaxExpEffort := args.softmaxExpEffort) - return formatHeadPatternBestMatch cert - else - if args.sweep then - throw "use --sweep with --bestMatch" - let cert ← ExceptT.mk <| - Nfp.Sound.certifyHeadPatternLocal args.modelPath args.layerIdx args.headIdx - (inputPath? := inputPath?) (inputDelta := delta) - (soundnessBits := args.soundnessBits) (targetOffset := args.offset) - (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (softmaxExpEffort := args.softmaxExpEffort) - return formatHeadPatternLocal cert - -/-- Run the certify command - compute conservative, exact bounds in sound mode. -/ -def runCertify (p : Parsed) : IO UInt32 := do - let args := parseCertifyArgs p - setStdoutLogNameFromModelPath args.modelPathStr - match ← (runCertifyAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok cert => - writeReport args.outputPath? (toString cert) - return 0 - -/-- Run the head-bounds command - compute per-head contribution bounds in sound mode. -/ -def runHeadBounds (p : Parsed) : IO UInt32 := do - let args := parseHeadBoundsArgs p - match ← (runHeadBoundsAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok s => - writeReport args.outputPath? s - return 0 - -/-- Run the head-pattern command - compute per-head attention pattern bounds in sound mode. -/ -def runHeadPattern (p : Parsed) : IO UInt32 := do - let args := parseHeadPatternArgs p - match ← (runHeadPatternAction args).run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok s => - writeReport args.outputPath? 
s - return 0 - -/-- Parsed arguments for `induction-cert`, with resolved input path and delta. -/ -private structure InductionCertArgs where - modelPath : System.FilePath - layer1 : Nat - head1 : Nat - layer2 : Nat - head2 : Nat - coord : Nat - offset1 : Int - offset2 : Int - targetToken? : Option Nat - negativeToken? : Option Nat - soundnessBits : Nat - softmaxExpEffort : Nat - tightPatternLayers : Nat - tightPattern : Bool - perRowPatternLayers : Nat - bestMatch : Bool - queryPos? : Option Nat - inputPath? : Option System.FilePath - delta : Rat - maxSeqLen : Nat - outputPath? : Option System.FilePath - -/-- Parse and validate `induction-cert` arguments. -/ -private def parseInductionCertArgs (p : Parsed) : ExceptT String IO InductionCertArgs := do - let modelPathStr := p.positionalArg! "model" |>.as! String - let modelPath : System.FilePath := ⟨modelPathStr⟩ - let layer1 := p.flag? "layer1" |>.map (·.as! Nat) |>.getD 0 - let head1 := p.flag? "head1" |>.map (·.as! Nat) |>.getD 0 - let layer2 := p.flag? "layer2" |>.map (·.as! Nat) |>.getD 1 - let head2 := p.flag? "head2" |>.map (·.as! Nat) |>.getD 0 - let coord := p.flag? "coord" |>.map (·.as! Nat) |>.getD 0 - let offset1 := p.flag? "offset1" |>.map (·.as! Int) |>.getD (-1) - let offset2 := p.flag? "offset2" |>.map (·.as! Int) |>.getD (-1) - let targetToken := p.flag? "target" |>.map (·.as! Nat) - let negativeToken := p.flag? "negative" |>.map (·.as! Nat) - let soundnessBits := p.flag? "soundnessBits" |>.map (·.as! Nat) |>.getD 20 - let softmaxExpEffort := - p.flag? "softmaxExpEffort" |>.map (·.as! Nat) - |>.getD Nfp.Sound.defaultSoftmaxExpEffort - let tightPatternLayers? := p.flag? "tightPatternLayers" |>.map (·.as! Nat) - let tightPatternLayers := tightPatternLayers?.getD 1 - let tightPattern := p.hasFlag "tightPattern" || tightPatternLayers?.isSome - let perRowPatternLayers := p.flag? "perRowPatternLayers" |>.map (·.as! Nat) |>.getD 0 - let bestMatch := p.hasFlag "bestMatch" - let queryPos := p.flag? 
"queryPos" |>.map (·.as! Nat) - let inputPath := p.flag? "input" |>.map (·.as! String) - let deltaStr := p.flag? "delta" |>.map (·.as! String) |>.getD "0" - let maxSeqLen := p.flag? "maxSeqLen" |>.map (·.as! Nat) |>.getD 256 - let outputPath := p.flag? "output" |>.map (·.as! String) - let delta ← - match Nfp.Sound.parseRat deltaStr with - | .ok r => pure r - | .error e => throw s!"invalid --delta '{deltaStr}': {e}" - let inputPath? : Option System.FilePath ← - match inputPath with - | some s => pure (some ⟨s⟩) - | none => - let hasEmbeddings ← - hasEmbeddingsBeforeLayers modelPath - if hasEmbeddings then - pure (some modelPath) - else - throw <| - "induction cert requires EMBEDDINGS; pass --input for legacy text models." - let outputPath? : Option System.FilePath := - outputPath.map (fun s => ⟨s⟩) - return { - modelPath := modelPath - layer1 := layer1 - head1 := head1 - layer2 := layer2 - head2 := head2 - coord := coord - offset1 := offset1 - offset2 := offset2 - targetToken? := targetToken - negativeToken? := negativeToken - soundnessBits := soundnessBits - softmaxExpEffort := softmaxExpEffort - tightPatternLayers := tightPatternLayers - tightPattern := tightPattern - perRowPatternLayers := perRowPatternLayers - bestMatch := bestMatch - queryPos? := queryPos - inputPath? := inputPath? - delta := delta - maxSeqLen := maxSeqLen - outputPath? := outputPath? - } - -/-- Format the optional logit-diff line for local induction certs. -/ -private def formatInductionLogitLine - (logit? : Option Nfp.Sound.HeadLogitDiffLowerBoundCert) : String := - match logit? with - | none => "" - | some logit => - s!"logitDiffLB={logit.logitDiffLowerBound} " ++ - s!"targetTok={logit.targetToken} negTok={logit.negativeToken}\n" ++ - s!"logitMatchLB={logit.matchLogitLowerBound} " ++ - s!"logitNonmatchLB={logit.nonmatchLogitLowerBound}\n" - -/-- Format the optional logit-diff line for best-match induction certs. -/ -private def formatInductionLogitLinePos - (logit? 
: Option Nfp.Sound.HeadLogitDiffLowerBoundPosCert) : String := - match logit? with - | none => "" - | some logit => - s!"logitDiffLB={logit.logitDiffLowerBound} " ++ - s!"targetTok={logit.targetToken} negTok={logit.negativeToken}\n" ++ - s!"logitMatchLB={logit.matchLogitLowerBound} " ++ - s!"logitNonmatchLB={logit.nonmatchLogitLowerBound}\n" - -/-- Render a best-match induction certificate report. -/ -private def formatInductionBestMatch - (cert : Nfp.Sound.InductionHeadBestMatchSoundCert) : String := - let p1 := cert.layer1Pattern - let p2 := cert.layer2Pattern - let v := cert.layer2Value - let logitLine := formatInductionLogitLinePos cert.layer2Logit? - "SOUND induction cert (best-match):\n" ++ - s!"queryPos={p2.queryPos}\n" ++ - s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ - s!"targetTok={p1.targetToken} " ++ - s!"marginLB={p1.marginLowerBound} " ++ - s!"weightLB={p1.bestMatchWeightLowerBound} " ++ - s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ - s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ - s!"targetTok={p2.targetToken} " ++ - s!"marginLB={p2.marginLowerBound} " ++ - s!"weightLB={p2.bestMatchWeightLowerBound} " ++ - s!"softmaxExpEffort={p2.softmaxExpEffort}\n" ++ - s!"coord={v.coord} matchCoordLB={v.matchCoordLowerBound} " ++ - s!"nonmatchCoordLB={v.nonmatchCoordLowerBound}\n" ++ - s!"deltaLB={cert.deltaLowerBound}\n" ++ - logitLine - -/-- Render a local induction certificate report. -/ -private def formatInductionLocal - (cert : Nfp.Sound.InductionHeadSoundCert) : String := - let p1 := cert.layer1Pattern - let p2 := cert.layer2Pattern - let v := cert.layer2Value - let logitLine := formatInductionLogitLine cert.layer2Logit? 
- "SOUND induction cert:\n" ++ - s!"layer1=L{p1.layerIdx} H{p1.headIdx} offset={p1.targetOffset} " ++ - s!"marginLB={p1.marginLowerBound} weightLB={p1.targetWeightLowerBound} " ++ - s!"softmaxExpEffort={p1.softmaxExpEffort}\n" ++ - s!"layer2=L{p2.layerIdx} H{p2.headIdx} offset={p2.targetOffset} " ++ - s!"marginLB={p2.marginLowerBound} weightLB={p2.targetWeightLowerBound} " ++ - s!"softmaxExpEffort={p2.softmaxExpEffort}\n" ++ - s!"coord={v.coord} matchCountLB={p2.targetCountLowerBound} " ++ - s!"matchCoordLB={v.matchCoordLowerBound} " ++ - s!"nonmatchCoordLB={v.nonmatchCoordLowerBound}\n" ++ - s!"deltaLB={cert.deltaLowerBound}\n" ++ - logitLine - -/-- Run the induction-cert action and return the report string. -/ -private def runInductionCertAction (args : InductionCertArgs) : ExceptT String IO String := do - if args.bestMatch then - let cert ← ExceptT.mk <| - Nfp.Sound.certifyInductionSoundBestMatch args.modelPath - args.layer1 args.head1 args.layer2 args.head2 args.coord - (queryPos? := args.queryPos?) (inputPath? := args.inputPath?) - (inputDelta := args.delta) (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) - (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (targetToken? := args.targetToken?) (negativeToken? := args.negativeToken?) - (softmaxExpEffort := args.softmaxExpEffort) - return formatInductionBestMatch cert - else - let cert ← ExceptT.mk <| - Nfp.Sound.certifyInductionSound args.modelPath - args.layer1 args.head1 args.layer2 args.head2 args.coord - (inputPath? := args.inputPath?) (inputDelta := args.delta) - (soundnessBits := args.soundnessBits) - (offset1 := args.offset1) (offset2 := args.offset2) (maxSeqLen := args.maxSeqLen) - (tightPattern := args.tightPattern) (tightPatternLayers := args.tightPatternLayers) - (perRowPatternLayers := args.perRowPatternLayers) - (targetToken? 
:= args.targetToken?) (negativeToken? := args.negativeToken?) - (softmaxExpEffort := args.softmaxExpEffort) - return formatInductionLocal cert - -/-- Run the induction-cert command - compute a sound induction head certificate. -/ -def runInductionCert (p : Parsed) : IO UInt32 := do - let action : ExceptT String IO (String × Option System.FilePath) := do - let args ← parseInductionCertArgs p - let report ← runInductionCertAction args - return (report, args.outputPath?) - match ← action.run with - | .error msg => - IO.eprintln s!"Error: {msg}" - return 1 - | .ok (s, outputPath?) => - match outputPath? with - | some path => - IO.FS.writeFile path s - IO.println s!"Report written to {path}" - | none => - IO.println s - return 0 - -/-! ## Sound cache check helpers -/ - -private structure SoundCacheCheckArgs where - modelPath : System.FilePath - scalePow10 : Nat - maxTokens : Nat - -private def parseSoundCacheCheckArgs (p : Parsed) : SoundCacheCheckArgs := - let modelPath := p.positionalArg! "model" |>.as! String - let scalePow10 := p.flag? "scalePow10" |>.map (·.as! Nat) |>.getD 9 - let maxTokens := p.flag? "maxTokens" |>.map (·.as! 
Nat) |>.getD 0 - { modelPath := ⟨modelPath⟩, scalePow10 := scalePow10, maxTokens := maxTokens } - -private def runSoundCacheCheckWithArgs (args : SoundCacheCheckArgs) : IO UInt32 := do - let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File args.modelPath - let cacheFp := Nfp.Sound.SoundCache.cachePath args.modelPath modelHash args.scalePow10 - match (← Nfp.Untrusted.SoundCacheIO.buildCacheFile args.modelPath cacheFp args.scalePow10) with - | .error e => - IO.eprintln s!"Error: cache build failed: {e}" - return 1 - | .ok _ => - let ch ← IO.FS.Handle.mk cacheFp IO.FS.Mode.read - let hdr ← Nfp.Untrusted.SoundCacheIO.readHeader ch - if hdr.modelHash ≠ modelHash then - IO.eprintln "Error: cache hash mismatch" - return 1 - match (← Nfp.Untrusted.SoundCacheIO.checkCacheFileSize cacheFp hdr) with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok _ => - match - (← Nfp.Untrusted.SoundCacheIO.checkTextTokenEnvelope - args.modelPath args.scalePow10 args.maxTokens) with - | .error e => - IO.eprintln s!"Error: {e}" - return 1 - | .ok _ => - IO.println <| - "OK: sound cache validated " ++ - s!"(scalePow10={args.scalePow10}, maxTokens={args.maxTokens})" - return 0 - -/-- Regression test for SOUND fixed-point cache soundness and consistency. - -This is intended for CI and small fixtures. It: -- builds a `.nfpc` cache (if needed), -- checks cache size matches the expected tensor stream length, -- checks the `±1`-ulp rounding envelope on up to `maxTokens` numeric tokens in the text file. --/ -def runSoundCacheCheck (p : Parsed) : IO UInt32 := do - let args := parseSoundCacheCheckArgs p - runSoundCacheCheckWithArgs args - -/-- Run the rope command - print a proof-backed RoPE operator norm certificate. -/ -def runRoPE (p : Parsed) : IO UInt32 := do - let seqLen := p.flag? "seqLen" |>.map (·.as! Nat) |>.getD 4 - let pairs := p.flag? "pairs" |>.map (·.as! 
Nat) |>.getD 8 - match seqLen, pairs with - | Nat.succ n, Nat.succ m => - -- Instantiate the theorem at concrete sizes to ensure the report is proof-backed. - let θ : Fin (Nat.succ n) → Fin (Nat.succ m) → ℝ := fun _ _ => 0 - have _ : - Nfp.operatorNormBound - (n := Fin (Nat.succ n)) (d := Nfp.RoPEDim (Fin (Nat.succ m))) - (Nfp.ropeJacobian (pos := Fin (Nat.succ n)) (pair := Fin (Nat.succ m)) θ) - ≤ (2 : ℝ) := by - simpa using - (Nfp.rope_operatorNormBound_le_two - (pos := Fin (Nat.succ n)) (pair := Fin (Nat.succ m)) θ) - IO.println "RoPE certificate (static):" - IO.println s!" seqLen={seqLen}, pairs={pairs}, dim={2 * pairs}" - IO.println " Bound: operatorNormBound(ropeJacobian θ) ≤ 2" - IO.println " Meaning: max row-sum of absolute weights (ℓ1 induced for row-vectors)." - IO.println " Proof: Nfp.rope_operatorNormBound_le_two (uses |sin|,|cos| ≤ 1 from mathlib)." - return 0 - | _, _ => - IO.eprintln "Error: --seqLen and --pairs must be positive." - return 1 - -/-! ## Dump command helpers -/ - -private structure DumpArgs where - modelPath : System.FilePath - modelPathStr : String - layer : Nat - pos : Nat - take : Nat - kind : String - -private def parseDumpArgs (p : Parsed) : DumpArgs := - let modelPathStr := p.positionalArg! "model" |>.as! String - let layer := p.flag? "layer" |>.map (·.as! Nat) |>.getD 0 - let pos := p.flag? "pos" |>.map (·.as! Nat) |>.getD 0 - let take := p.flag? "take" |>.map (·.as! Nat) |>.getD 16 - let kind := p.flag? "kind" |>.map (·.as! 
String) |>.getD "afterLayer" - { modelPath := ⟨modelPathStr⟩ - modelPathStr := modelPathStr - layer := layer - pos := pos - take := take - kind := kind } - -private def selectDumpMatrix - (kind : String) (layer : Nat) (model : ConcreteModel) (fwd : ForwardPassResult) : - ConcreteMatrix := - match kind with - | "embeddings" => model.inputEmbeddings - | "layerInput" => fwd.getLayerInput layer - | "postAttn" => fwd.getPostAttnResidual layer - | "afterLayer" => fwd.getLayerInput (layer + 1) - | _ => fwd.getLayerInput (layer + 1) - -private def collectDumpRow (X : ConcreteMatrix) (pos n : Nat) : - (Array Float × Float × Float) := Id.run do - let mut xs : Array Float := Array.mkEmpty n - let mut sum : Float := 0.0 - let mut sumSq : Float := 0.0 - for j in [:n] do - let v := X.get pos j - xs := xs.push v - sum := sum + v - sumSq := sumSq + v * v - return (xs, sum, sumSq) - -private def runDumpWithArgs (args : DumpArgs) : IO UInt32 := do - setStdoutLogNameFromModelPath args.modelPathStr - let loadResult ← loadModel args.modelPath - match loadResult with - | .error msg => - IO.eprintln s!"Error loading model: {msg}" - return 1 - | .ok model0 => - let model := model0.trimTrailingZeroEmbeddings - let fwd := model.runForward true - let X := selectDumpMatrix args.kind args.layer model fwd - if X.numRows = 0 || X.numCols = 0 then - IO.eprintln s!"Error: empty matrix for kind={args.kind}" - return 1 - if args.pos ≥ X.numRows then - IO.eprintln s!"Error: pos={args.pos} out of bounds (rows={X.numRows})" - return 1 - let n := min args.take X.numCols - let (xs, sum, sumSq) := collectDumpRow X args.pos n - IO.println <| - s!"DUMP kind={args.kind} layer={args.layer} pos={args.pos} take={n} " ++ - s!"rows={X.numRows} cols={X.numCols}" - IO.println s!"sum={sum} sumSq={sumSq}" - IO.println (String.intercalate " " (xs.toList.map (fun x => s!"{x}"))) - return 0 - -/-- Dump a small slice of a forward pass for cross-implementation sanity checks. 
-/ -def runDump (p : Parsed) : IO UInt32 := do - let args := parseDumpArgs p - runDumpWithArgs args - -/-- The analyze subcommand. -/ -def analyzeCmd : Cmd := `[Cli| - analyze VIA runAnalyze; - "Analyze a neural network model for circuit discovery and verification." - FLAGS: - t, threshold : String; "Error threshold for verification (default: 0.1)" - o, output : String; "Write report to file instead of stdout" - verify; "Run empirical verification (requires input in model)" - v, verbose; "Enable verbose output" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The induction subcommand. -/ -def inductionCmd : Cmd := `[Cli| - induction VIA runInduction; - "Discover induction head pairs ranked by mechScore (kComp·indScore·prevTok)." - FLAGS: - c, correct : Nat; "Correct token ID (manual override; requires --incorrect)" - i, incorrect : Nat; "Incorrect token ID (manual override; requires --correct)" - t, threshold : String; "Minimum normalized Effect threshold (default: 0.0)" - verify; "Run causal verification via head ablation on the top-10 candidates" - v, verbose; "Enable verbose output" - d, diagnostics; "Print diagnostic breakdown of ε bounds (pattern/value decomposition)" - diagTop : Nat; "How many top candidates get diagnostics (default: 5)" - adaptive; "Enable adaptive bound scheduler (rigorous; deterministic)" - targetSlack : String; "Stop when ub/lb ≤ targetSlack (default: 8.0)" - maxUpgrades : Nat; "Maximum adaptive upgrades (default: 120)" - minRelImprove : String; "Stop upgrading a layer if improvement < this fraction (default: 0.01)" - krylovSteps : Nat; "Krylov steps for LOWER bounds only (default: 2)" - adaptiveScope : String; "Adaptive scope: layernorm | all (default: layernorm)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The certify subcommand. -/ -def certifyCmd : Cmd := `[Cli| - certify VIA runCertify; - "SOUND mode: compute conservative bounds using exact Rat arithmetic (no Float trust). 
\ -LayerNorm epsilon and GeLU kind are read from the model header." - FLAGS: - input : String; "Optional input .nfpt file for local certification (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local certification (default: 0; \ -if --input is omitted, uses EMBEDDINGS in the model file when present)" - softmaxMargin : String; "Lower bound on softmax logit margin (default: 0)" - softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - partitionDepth : Nat; "Partition depth for input splitting (default: 0; >0 scaffold only)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The head-bounds subcommand. -/ -def headBoundsCmd : Cmd := `[Cli| - head_bounds VIA runHeadBounds; - "SOUND mode: compute per-head contribution bounds. \ -LayerNorm epsilon is read from the model header." - FLAGS: - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0; if --input is omitted, \ -uses EMBEDDINGS in the model file when present)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p for global bounds (default: 9)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The head-pattern subcommand. -/ -def headPatternCmd : Cmd := `[Cli| - head_pattern VIA runHeadPattern; - "SOUND mode: compute per-head attention pattern bounds (binary only). \ -LayerNorm epsilon is read from the model header." 
- FLAGS: - layer : Nat; "Layer index (default: 0)" - head : Nat; "Head index (default: 0)" - offset : Int; "Token-match offset (default: -1 for previous token, 0 for self)" - tightPattern; "Use tighter (slower) pattern bounds near the target layer" - tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" - perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" - bestMatch; "Use best-match (single-query) pattern bounds" - sweep; "Sweep best-match bounds across all valid query positions" - queryPos : Nat; "Query position for best-match bounds (default: last position)" - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" - maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The induction-cert subcommand. -/ -def inductionCertCmd : Cmd := `[Cli| - induction_cert VIA runInductionCert; - "SOUND mode: compute a minimal induction head certificate (binary only). \ -LayerNorm epsilon is read from the model header." 
- FLAGS: - layer1 : Nat; "Layer index for the previous-token head (default: 0)" - head1 : Nat; "Head index for the previous-token head (default: 0)" - layer2 : Nat; "Layer index for the token-match head (default: 1)" - head2 : Nat; "Head index for the token-match head (default: 0)" - coord : Nat; "Output coordinate for the value bound (default: 0)" - offset1 : Int; "Token-match offset for layer1 (default: -1)" - offset2 : Int; "Token-match offset for layer2 (default: -1)" - target : Nat; "Target token ID for logit-diff bound (optional; requires --negative)" - negative : Nat; "Negative token ID for logit-diff bound (optional; requires --target)" - tightPattern; "Use tighter (slower) pattern bounds near the target layer" - tightPatternLayers : Nat; "Number of layers using tight pattern bounds (default: 1)" - perRowPatternLayers : Nat; "Number of layers using per-row MLP propagation (default: 0)" - bestMatch; "Use best-match (single-query) pattern bounds" - queryPos : Nat; "Query position for best-match bounds (default: last position)" - input : String; "Optional input .nfpt file for local bounds (must contain EMBEDDINGS \ -for legacy text)" - delta : String; "Input ℓ∞ radius δ for local bounds (default: 0)" - soundnessBits : Nat; "Dyadic sqrt precision bits for LayerNorm bounds (default: 20)" - softmaxExpEffort : Nat; "Effort level for margin-based exp lower bounds (default: 1)" - maxSeqLen : Nat; "Maximum sequence length to analyze (default: 256)" - o, output : String; "Write report to file instead of stdout" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The sound-cache-check subcommand (CI regression test). -/ -def soundCacheCheckCmd : Cmd := `[Cli| - sound_cache_check VIA runSoundCacheCheck; - "Check SOUND fixed-point cache soundness (CI / small fixtures)." 
- FLAGS: - scalePow10 : Nat; "Fixed-point scale exponent p in S=10^p (default: 9)" - maxTokens : Nat; "Check at most this many numeric tokens (0=all; default: 0)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The rope subcommand. -/ -def ropeCmd : Cmd := `[Cli| - rope VIA runRoPE; - "Static certificate for RoPE (rotary position embedding) linearization bounds." - FLAGS: - seqLen : Nat; "Sequence length (>0) used for instantiation (default: 4)" - pairs : Nat; "Number of RoPE pairs (>0); dimension is 2*pairs (default: 8)" -] - -/-- The dump subcommand. -/ -def dumpCmd : Cmd := `[Cli| - dump VIA runDump; - "Dump a small forward-pass slice (for PyTorch sanity checking)." - FLAGS: - layer : Nat; "Layer index (default: 0)" - pos : Nat; "Token position / row index (default: 0)" - take : Nat; "How many columns to dump from the start (default: 16)" - kind : String; "embeddings | layerInput | postAttn | afterLayer (default: afterLayer)" - ARGS: - model : String; "Path to the model weights file (.nfpt)" -] - -/-- The main CLI command. -/ -def nfpCmd : Cmd := `[Cli| - nfp NOOP; - "NFP: Neural Formal Pathways verification toolkit" - SUBCOMMANDS: - analyzeCmd; - inductionCmd; - certifyCmd; - headBoundsCmd; - headPatternCmd; - inductionCertCmd; - soundCacheCheckCmd; - ropeCmd; - dumpCmd -] - -/-- Main entry point. 
-/ -def main (args : List String) : IO UInt32 := do - let ctx ← openPendingStdoutLog - stdoutLogCtxRef.set (some ctx) - let out ← IO.getStdout - let log := IO.FS.Stream.ofHandle ctx.handle - let tee := mkTeeStream out log - IO.withStdout tee <| do - try - if args.contains "--version" then - setStdoutLogName "version" - IO.println "nfp version 0.1.0" - return (0 : UInt32) - nfpCmd.validate args - finally - let pending ← ctx.pendingRef.get - if pending then - setStdoutLogName "no_model" - stdoutLogCtxRef.set none +end diff --git a/Nfp.lean b/Nfp.lean index 1967251..2dae38d 100644 --- a/Nfp.lean +++ b/Nfp.lean @@ -1,287 +1,15 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Prob -import Nfp.PCC -import Nfp.Uniqueness -import Nfp.Appendix -import Nfp.Mixer -import Nfp.Reroute.Heat -import Nfp.Reroute.Partition -import Nfp.Influence -import Nfp.MixerLocalSystem -import Nfp.Attribution -import Nfp.Layers -import Nfp.SignedMixer -import Nfp.Linearization -import Nfp.Abstraction -import Nfp.Induction -import Nfp.Discovery -import Nfp.Verification -import Nfp.IO -import Nfp.Sound.IO -import Nfp.Sound.Bridge -import Nfp.Sound.Demo +module + +public import Nfp.Core +public import Nfp.Prob +public import Nfp.Mixer +public import Nfp.System +public import Nfp.Circuit +public import Nfp.Model +public import Nfp.Sound /-! -Axioms used by key theorems/definitions -These `#print axioms` lines help ensure we only depend on a small set of axioms -(ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). +Top-level reexports and trust dashboard for the NFP rewrite. 
-/ - --- From Nfp.Prob -#print axioms Nfp.ProbVec.sum_mass -#print axioms Nfp.ProbVec.pure -#print axioms Nfp.ProbVec.mix - --- From Nfp.PCC -#print axioms Nfp.tracerOfContrib -#print axioms Nfp.tracerOfContrib_mass -#print axioms Nfp.sum_monotone_chain -#print axioms Nfp.monotone_removed_mass --- From Nfp.Uniqueness -#print axioms Nfp.LocalSystem.tracer_unique --- From Nfp.Appendix (PCC-focused lemmas) -#print axioms Nfp.drop_eq_mass -#print axioms Nfp.pcc_upper_bound_for_any_mask -#print axioms Nfp.normalized_sum_monotone --- Residuals -#print axioms Nfp.lambdaEC -#print axioms Nfp.lambdaEC_sum_one -#print axioms Nfp.residual_lambda_from_norm --- Greedy optimality -#print axioms Nfp.greedy_topk_optimal --- Budgeted PCC sup -#print axioms Nfp.normMass -#print axioms Nfp.normMass_union -#print axioms Nfp.normMass_partitionUnion -#print axioms Nfp.normMass_partition_eq_one -#print axioms Nfp.ReroutePlan.masks -#print axioms Nfp.ReroutePlan.normMass_sum_one -#print axioms Nfp.unionParts -#print axioms Nfp.disjoint_unionParts -#print axioms Nfp.feasible -#print axioms Nfp.pccArg -#print axioms Nfp.pccMax -#print axioms Nfp.pccMax_le_tau -#print axioms Nfp.pccMax_monotone -#print axioms Nfp.pccMax_dominates -#print axioms Nfp.greedy_topmass_optimal - --- PCC(t) alias and properties -#print axioms Nfp.PCC_monotone -#print axioms Nfp.PCC_upper_bounds_masks - --- Normalization utilities and uniqueness -#print axioms Nfp.normalizeOn_outside -#print axioms Nfp.normalizeOn_inside -#print axioms Nfp.normalizeOn_sum_one -#print axioms Nfp.proportional_row_unique --- Residual characterization (Appendix A.3) -#print axioms Nfp.lambdaEC_scale_invariant_global - --- Appendix A.1 consolidated wrappers and packaged theorem -#print axioms Nfp.A1_residual_unique -#print axioms Nfp.A1_normalize_unique -#print axioms Nfp.A1_global_tracer_unique -#print axioms Nfp.A1 - --- Forward mixers (Appendix A.1) -#print axioms Nfp.Mixer.push -#print axioms Nfp.Mixer.row -#print axioms 
Nfp.Mixer.push_pure -#print axioms Nfp.Mixer.push_mix -#print axioms Nfp.Mixer.comp - --- Appendix A.2 wrappers and packaged theorem -#print axioms Nfp.A2_graph_faithful_comp -#print axioms Nfp.A2_residual_energy_consistent -#print axioms Nfp.A2_residual_unique -#print axioms Nfp.A2_normalize_row_sum_one -#print axioms Nfp.A2 - --- From Nfp.Mixer (support restriction) -#print axioms Nfp.Mixer.supported_comp -#print axioms Nfp.Mixer.supp_push_subset_image - --- From Nfp.Influence -#print axioms Nfp.Mixer.ofInfluenceSpec -#print axioms Nfp.Mixer.ofInfluenceSpec_supported - --- From Nfp.MixerLocalSystem -#print axioms Nfp.LocalSystem.ofMixerIdx -#print axioms Nfp.LocalSystem.ofMixer - --- From Nfp.Reroute.Partition -#print axioms Nfp.ReroutePlan.increments -#print axioms Nfp.ReroutePlan.increments_pairwise -#print axioms Nfp.sum_unionParts_eq - --- From Nfp.Reroute.Heat -#print axioms Nfp.WeightedReroutePlan.rerouteHeat -#print axioms Nfp.WeightedReroutePlan.heatRaw_sum_increment -#print axioms Nfp.WeightedReroutePlan.rerouteHeat_sum_increment - --- From Nfp.Attribution (attribution axioms for NN interpretation) -#print axioms Nfp.Attribution.Complete -#print axioms Nfp.Attribution.CompleteScalar -#print axioms Nfp.Attribution.Sensitive -#print axioms Nfp.Attribution.Dummy -#print axioms Nfp.Attribution.Symmetric -#print axioms Nfp.attributionOfProbVec -#print axioms Nfp.tracer_attribution_complete - --- From Nfp.Layers (NN layer mixers) -#print axioms Nfp.Mixer.identity -#print axioms Nfp.Mixer.identity_comp -#print axioms Nfp.Mixer.comp_identity -#print axioms Nfp.Mixer.comp_assoc -#print axioms Nfp.Mixer.attention -#print axioms Nfp.Mixer.selfAttention -#print axioms Nfp.Mixer.residual -#print axioms Nfp.Mixer.residual_one -#print axioms Nfp.Mixer.residual_zero -#print axioms Nfp.Mixer.comp3 -#print axioms Nfp.Mixer.comp3_path_decomposition -#print axioms Nfp.Mixer.transformerBlock -#print axioms Nfp.attentionFlow -#print axioms Nfp.attentionFlow_singleton -#print 
axioms Nfp.attentionFlow_nil -#print axioms Nfp.effectiveAttention_normalized -#print axioms Nfp.pathContrib_sum -#print axioms Nfp.Mixer.push_preserves_total_mass -#print axioms Nfp.Mixer.push_comp -#print axioms Nfp.transformer_attribution_unique - --- Residual stream analysis -#print axioms Nfp.Mixer.residual_decomposition -#print axioms Nfp.Mixer.residual_skip_dominance -#print axioms Nfp.Mixer.residual_off_diag_bound -#print axioms Nfp.Mixer.residual_off_diag_scaling - --- Attention concentration -#print axioms Nfp.Mixer.maxWeight -#print axioms Nfp.Mixer.weight_le_maxWeight -#print axioms Nfp.Mixer.push_concentration_bound - --- Ablation analysis -#print axioms Nfp.Mixer.maskFn -#print axioms Nfp.Mixer.maskFn_blocked -#print axioms Nfp.Mixer.maskFn_unblocked -#print axioms Nfp.blockedContribution -#print axioms Nfp.unblockedContribution -#print axioms Nfp.Mixer.ablation_decomposition - --- Composition depth and reachability -#print axioms Nfp.Mixer.reachable -#print axioms Nfp.Mixer.reach_comp -#print axioms Nfp.Mixer.path_contrib_le_comp -#print axioms Nfp.Mixer.comp_reachable_of_path - --- Information bounds -#print axioms Nfp.Mixer.supportSize -#print axioms Nfp.Mixer.exists_nonzero -#print axioms Nfp.Mixer.supportSize_pos - --- Gradient correspondence (chain rule) -#print axioms Nfp.Mixer.chain_rule_analog -#print axioms Nfp.Mixer.chain_rule_three - --- Multi-head attention -#print axioms Nfp.Mixer.multiHead -#print axioms Nfp.Mixer.multiHead_head_contrib_bound -#print axioms Nfp.Mixer.multiHead_zero_weight -#print axioms Nfp.Mixer.multiHead_single_head -#print axioms Nfp.Mixer.multiHead_convex - --- Causal (autoregressive) attention -#print axioms Nfp.isCausal -#print axioms Nfp.causal_reachable_dir -#print axioms Nfp.causal_comp -#print axioms Nfp.causal_future_attention_zero -#print axioms Nfp.causal_past_attention_one -#print axioms Nfp.causal_first_token_self - --- Attention head analysis -#print axioms Nfp.Mixer.attentionConcentration -#print axioms 
Nfp.Mixer.attentionConcentration_upper_bound -#print axioms Nfp.Mixer.isSparseAt -#print axioms Nfp.Mixer.isDiffuseAt -#print axioms Nfp.Mixer.attentionConcentration_one_hot - --- Residual dominance analysis -#print axioms Nfp.residual_diagonal_lower -#print axioms Nfp.residual_offdiag_scale -#print axioms Nfp.residual_offdiag_sum_bound - --- Deep composition analysis -#print axioms Nfp.comp_weight_le_one -#print axioms Nfp.comp_term_bound - --- Cross-attention for encoder-decoder models -#print axioms Nfp.Mixer.crossAttention -#print axioms Nfp.Mixer.crossAttention_normalized - --- Layer-wise attribution analysis -#print axioms Nfp.layerWiseAttribution -#print axioms Nfp.layerWiseAttribution_nil -#print axioms Nfp.layerWiseAttribution_singleton -#print axioms Nfp.layerWiseAttribution_le_one -#print axioms Nfp.layerWiseAttribution_sum_one - --- From Nfp.SignedMixer (generalized signed weight mixers) -#print axioms Nfp.SignedMixer -#print axioms Nfp.SignedMixer.comp -#print axioms Nfp.SignedMixer.comp_assoc -#print axioms Nfp.SignedMixer.identity -#print axioms Nfp.SignedMixer.positivePart -#print axioms Nfp.SignedMixer.negativePart -#print axioms Nfp.SignedMixer.decompose -#print axioms Nfp.SignedMixer.rowSum -#print axioms Nfp.SignedMixer.IsRowStochastic -#print axioms Nfp.SignedMixer.toMixer -#print axioms Nfp.SignedMixer.ofMixer -#print axioms Nfp.SignedMixer.ofMixer_isRowStochastic -#print axioms Nfp.SignedMixer.influence -#print axioms Nfp.SignedMixer.totalInfluenceFrom -#print axioms Nfp.SignedMixer.totalInfluenceOn -#print axioms Nfp.SignedMixer.apply -#print axioms Nfp.SignedMixer.apply_comp -#print axioms Nfp.SignedMixer.jacobianEntry_eq_weight - --- AffineMixer (linear + bias) -#print axioms Nfp.AffineMixer -#print axioms Nfp.AffineMixer.linear -#print axioms Nfp.AffineMixer.bias -#print axioms Nfp.AffineMixer.apply -#print axioms Nfp.AffineMixer.comp - --- From Nfp.Linearization (linearizing non-linear ops) -#print axioms Nfp.Linearization -#print axioms 
Nfp.Linearization.comp -#print axioms Nfp.relu -#print axioms Nfp.reluGrad -#print axioms Nfp.reluLinearization -#print axioms Nfp.reluLinearization_diagonal -#print axioms Nfp.reluLinearization_diag_binary -#print axioms Nfp.gelu -#print axioms Nfp.geluGrad -#print axioms Nfp.geluLinearization -#print axioms Nfp.layerNorm -#print axioms Nfp.layerNormJacobian -#print axioms Nfp.layerNorm_translation_invariant -#print axioms Nfp.softmax -#print axioms Nfp.softmaxJacobian -#print axioms Nfp.softmax_nonneg -#print axioms Nfp.softmax_sum_one -#print axioms Nfp.softmaxJacobian_diag_pos -#print axioms Nfp.softmaxJacobian_off_diag_neg -#print axioms Nfp.softmax_translation_invariant -#print axioms Nfp.ropeJacobian -#print axioms Nfp.rope -#print axioms Nfp.ropeJacobian_cross_pos -#print axioms Nfp.rope_operatorNormBound_le_two -#print axioms Nfp.gradientTimesInput -#print axioms Nfp.gradientTimesInput_complete -#print axioms Nfp.composed_attribution -#print axioms Nfp.integratedGradientsLinear -#print axioms Nfp.integratedGradients_linear_complete diff --git a/Nfp/Abstraction.lean b/Nfp/Abstraction.lean deleted file mode 100644 index 72fc5f3..0000000 --- a/Nfp/Abstraction.lean +++ /dev/null @@ -1,432 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Set.Basic -import Mathlib.Logic.Equiv.Defs -import Mathlib.Data.Fin.Basic -import Nfp.Linearization -import Nfp.Uniqueness - -/-! -# Causal Consistency of Circuit Abstractions - -This module bridges the gap between **real neural network computations** -(`DeepLinearization`, containing weights and Jacobians) and **abstract causal graphs** -(`LocalSystem`, containing topology and mixing coefficients). - -## Main Results - -1. 
**Projection**: `DeepLinearization.toLocalSystem` extracts a causal DAG from a - network's `DeepValueTerm` (the "Attention Rollout" component). - -2. **Interventions**: `SignedMixer.ablate` and `DeepLinearization.ablate` formalize - node removal / path zeroing interventions. - -3. **Causal Consistency Theorem**: If the `DeepPatternTerm` (linearization error) - is bounded by ε, then the real network's output under ablation matches the - `LocalSystem`'s prediction within O(ε). - -## Significance - -This transforms the library from a descriptive tool into a **verification engine**. -Practitioners can input real model weights and receive a mathematical certificate -that a discovered "induction head" or "circuit" is a **genuine mechanism** and not -an interpretability illusion. - -The key insight is: -- `LocalSystem` computes via: T(i) = Σ_{u ∈ Pa(i)} c(i,u) · T(u) -- `DeepValueTerm` approximates the Jacobian via attention flow -- When `DeepPatternTerm` is small, interventions on the abstract graph - accurately predict interventions on the real network. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Signed Mixer Ablation -/ - -section SignedMixerAblation - -variable {S T : Type*} [Fintype S] [Fintype T] - -/-- Ablate (zero out) specific source positions in a SignedMixer. - -This models the intervention "what if we remove position i from contributing?" -Used for causal intervention analysis. 
-/ -noncomputable def SignedMixer.ablate (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] : SignedMixer S T where - w := fun i j => if blocked i then 0 else M.w i j - -@[simp] lemma SignedMixer.ablate_blocked (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] {i : S} (hi : blocked i) (j : T) : - (M.ablate blocked).w i j = 0 := by - simp [SignedMixer.ablate, hi] - -@[simp] lemma SignedMixer.ablate_unblocked (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] {i : S} (hi : ¬blocked i) (j : T) : - (M.ablate blocked).w i j = M.w i j := by - simp [SignedMixer.ablate, hi] - -/-- The effect of ablation on application to a vector. -/ -theorem SignedMixer.apply_ablate (M : SignedMixer S T) (blocked : Set S) - [DecidablePred blocked] (v : S → ℝ) (j : T) : - (M.ablate blocked).apply v j = ∑ i : S, if blocked i then 0 else v i * M.w i j := by - simp only [SignedMixer.apply_def, SignedMixer.ablate] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> ring - -/-- Ablation decomposes application into blocked and unblocked contributions. -/ -theorem SignedMixer.apply_ablate_decomposition (M : SignedMixer S T) - (blocked : Set S) [DecidablePred blocked] (v : S → ℝ) (j : T) : - M.apply v j = (M.ablate blocked).apply v j + - ∑ i : S, if blocked i then v i * M.w i j else 0 := by - simp only [SignedMixer.apply_def, SignedMixer.ablate] - rw [← Finset.sum_add_distrib] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> ring - -end SignedMixerAblation - -/-! ## Deep Linearization Ablation -/ - -section DeepLinearizationAblation - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- Ablate the `DeepValueTerm` by zeroing out contributions from blocked positions. 
-/ -noncomputable def DeepLinearization.ablateValueTerm - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - (DeepValueTerm D).ablate blocked - -/-- Ablate the full `composedJacobian` by zeroing out contributions from blocked positions. -/ -noncomputable def DeepLinearization.ablateJacobian - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - D.composedJacobian.ablate blocked - -/-- The difference between ablating the full Jacobian vs the value term. -/ -noncomputable def DeepLinearization.ablationError - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - SignedMixer (n × d) (n × d) := - D.ablateJacobian blocked - D.ablateValueTerm blocked - -/-- Ablation error equals ablated pattern term. -/ -theorem DeepLinearization.ablationError_eq_ablatedPatternTerm - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] : - D.ablationError blocked = (DeepPatternTerm D).ablate blocked := by - ext i j - simp only [ablationError, ablateJacobian, ablateValueTerm, SignedMixer.sub_w, - SignedMixer.ablate, DeepPatternTerm] - split_ifs with h - · simp - · rfl - -end DeepLinearizationAblation - -/-! ## Projection to LocalSystem -/ - -section Projection - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- Extract an NNReal coefficient from the absolute value of a real weight. - -This is used when projecting a `SignedMixer` (which has real weights) to a -`LocalSystem` (which uses NNReal coefficients for mixing). The absolute value -captures the "influence magnitude" regardless of sign. -/ -noncomputable def absToNNReal (x : ℝ) : NNReal := ⟨|x|, abs_nonneg x⟩ - -/-- A position-level signed mixer extracted from collapsing the dimension axes. 
-/ -noncomputable def positionMixer (M : SignedMixer (n × d) (n × d)) : - SignedMixer n n where - w := fun i j => ∑ di : d, ∑ dj : d, M.w (i, di) (j, dj) - -/-- The magnitude of position-to-position flow (for LocalSystem coefficients). -/ -noncomputable def positionFlowMagnitude (M : SignedMixer (n × d) (n × d)) - (i j : n) : NNReal := - absToNNReal ((positionMixer M).w i j) - -/-- Extract a `LocalSystem` from a `DeepLinearization` using its `DeepValueTerm`. - -The resulting graph has: -- Nodes: positions (type `n`) numbered by `e : n ≃ Fin k` -- Parents: positions with nonzero attention flow -- Coefficients: magnitude of position-to-position value term flow - -This represents the "attention rollout" approximation as a causal DAG. -/ -noncomputable def DeepLinearization.toLocalSystem - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) : - LocalSystem k := by - classical - let posOf : Fin k → n := e.symm - exact { - Pa := fun idx => - Finset.univ.filter fun u : Fin k => - (positionMixer (DeepValueTerm D)).w (posOf u) (posOf idx) ≠ 0 - c := fun idx u => - positionFlowMagnitude (DeepValueTerm D) (posOf u) (posOf idx) - topo := by - intro idx u hu - have hmem := Finset.mem_filter.mp hu - have hweight : (positionMixer (DeepValueTerm D)).w (posOf u) (posOf idx) ≠ 0 := hmem.2 - have htopo : e (posOf u) < e (posOf idx) := acyclic _ _ hweight - simpa [posOf] using htopo - } - -/-- The parent set of position `i` in the extracted LocalSystem. 
-/ -theorem DeepLinearization.toLocalSystem_Pa - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) - (idx : Fin k) : - (D.toLocalSystem e acyclic).Pa idx = - Finset.univ.filter fun u : Fin k => - (positionMixer (DeepValueTerm D)).w (e.symm u) (e.symm idx) ≠ 0 := rfl - -/-- The coefficient for parent `u` of position `idx` in the extracted LocalSystem. -/ -theorem DeepLinearization.toLocalSystem_c - (D : DeepLinearization (n := n) (d := d)) - {k : ℕ} (e : n ≃ Fin k) - (acyclic : ∀ i j : n, (positionMixer (DeepValueTerm D)).w i j ≠ 0 → e i < e j) - (idx u : Fin k) : - (D.toLocalSystem e acyclic).c idx u = - positionFlowMagnitude (DeepValueTerm D) (e.symm u) (e.symm idx) := rfl - -end Projection - -/-! ## Causal Consistency -/ - -section CausalConsistency - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- The error norm when comparing ablated computations. -/ -noncomputable def ablationDiscrepancy - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) (j : n × d) : ℝ := - |(D.ablateJacobian blocked).apply v j - (D.ablateValueTerm blocked).apply v j| - -/-- **Causal Consistency Bound**: The ablation discrepancy is bounded by the -pattern term's influence on the input. - -If position `i` is blocked, the discrepancy at output `j` is bounded by: - |ablated_real - ablated_abstract| ≤ Σ_{i ∉ blocked} |v_i| · |PatternTerm_{i,j}| - -This shows that when the pattern term is small, interventions on the abstract -`LocalSystem` accurately predict interventions on the real network. 
-/ -theorem causal_consistency_bound - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) (j : n × d) : - ablationDiscrepancy D blocked v j ≤ - ∑ i : n × d, if blocked i then 0 else |v i| * |(DeepPatternTerm D).w i j| := by - simp only [ablationDiscrepancy] - -- The key insight: ablation error = ablated pattern term - have h := D.ablationError_eq_ablatedPatternTerm blocked - -- The difference in applications - calc |(D.ablateJacobian blocked).apply v j - (D.ablateValueTerm blocked).apply v j| - = |(D.ablationError blocked).apply v j| := by - congr 1 - simp only [DeepLinearization.ablationError, SignedMixer.apply_def, SignedMixer.sub_w] - rw [← Finset.sum_sub_distrib] - apply Finset.sum_congr rfl - intro i _ - ring - _ = |((DeepPatternTerm D).ablate blocked).apply v j| := by - rw [h] - _ = |∑ i : n × d, if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := by - simp only [SignedMixer.apply_ablate] - _ ≤ ∑ i : n × d, |if blocked i then 0 else v i * (DeepPatternTerm D).w i j| := - abs_sum_le_sum_abs _ _ - _ = ∑ i : n × d, if blocked i then 0 else |v i| * |(DeepPatternTerm D).w i j| := by - apply Finset.sum_congr rfl - intro i _ - by_cases hb : blocked i - · simp [hb] - · simp [hb, abs_mul] - -/-- **Simplified Frobenius-style bound**: - -The total squared ablation discrepancy is bounded by the product of -input energy and pattern term Frobenius norm squared. - -This is a cleaner statement that captures the key O(ε) relationship. 
-/ -theorem causal_consistency_frobenius_simple - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - -- We use a direct bound via the pointwise bounds - calc ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 - ≤ ∑ j : n × d, (∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) ^ 2 := by - apply Finset.sum_le_sum - intro j _ - apply sq_le_sq' - · have hnn := Finset.sum_nonneg - (fun i (_ : i ∈ Finset.univ) => mul_nonneg (abs_nonneg (v i)) - (abs_nonneg ((DeepPatternTerm D).w i j))) - calc -(∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) - ≤ 0 := neg_nonpos.mpr hnn - _ ≤ ablationDiscrepancy D blocked v j := abs_nonneg _ - · have hbound := causal_consistency_bound D blocked v j - refine le_trans hbound ?_ - apply Finset.sum_le_sum - intro i _ - by_cases hb : blocked i - · have hnonneg : - 0 ≤ |v i| * |(DeepPatternTerm D).w i j| := - mul_nonneg (abs_nonneg _) (abs_nonneg _) - simpa [hb] using hnonneg - · simp [hb] - _ ≤ ∑ j : n × d, (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - apply Finset.sum_le_sum - intro j _ - -- Cauchy-Schwarz: (Σ ab)² ≤ (Σ a²)(Σ b²) - have cs : (∑ i : n × d, |v i| * |(DeepPatternTerm D).w i j|) ^ 2 ≤ - (∑ i : n × d, (|v i|) ^ 2) * (∑ i : n × d, (|(DeepPatternTerm D).w i j|) ^ 2) := - Finset.sum_mul_sq_le_sq_mul_sq Finset.univ (fun i => |v i|) - (fun i => |(DeepPatternTerm D).w i j|) - simp only [sq_abs] at cs - exact cs - _ = (∑ i : n × d, (v i) ^ 2) * (∑ j : n × d, ∑ i : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - rw [Finset.mul_sum] - _ = (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := by - congr 1 - rw [Finset.sum_comm] - -/-- **Main Causal Consistency Theorem**: - -If a network's `DeepPatternTerm` has Frobenius norm bounded by ε, then 
-interventions (ablations) on the extracted `LocalSystem` predict the -real network's behavior within O(ε) error. - -Specifically: Σⱼ (discrepancy_j)² ≤ ε² · Σᵢ (vᵢ)² - -This is the key result that turns the library into a verification engine: -- Small pattern term → attention rollout is faithful -- Faithful rollout → discovered circuits are genuine mechanisms -- Genuine mechanisms → interventions have predictable effects -/ -theorem causal_consistency_frobenius - (D : DeepLinearization (n := n) (d := d)) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) - (ε : ℝ) (hε : frobeniusNorm (DeepPatternTerm D) ≤ ε) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - ε ^ 2 * (∑ i : n × d, (v i) ^ 2) := by - have h1 := causal_consistency_frobenius_simple D blocked v - have h2 : ∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2 ≤ ε ^ 2 := by - calc ∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2 - = (frobeniusNorm (DeepPatternTerm D)) ^ 2 := by - simp only [frobeniusNorm] - rw [Real.sq_sqrt] - exact Finset.sum_nonneg (fun i _ => Finset.sum_nonneg (fun j _ => sq_nonneg _)) - _ ≤ ε ^ 2 := by - apply sq_le_sq' - · calc -ε ≤ 0 := by - by_contra hne - push_neg at hne - have : frobeniusNorm (DeepPatternTerm D) ≤ ε := hε - have hpos : 0 ≤ frobeniusNorm (DeepPatternTerm D) := by - simp only [frobeniusNorm] - exact Real.sqrt_nonneg _ - linarith - _ ≤ frobeniusNorm (DeepPatternTerm D) := by - simp only [frobeniusNorm] - exact Real.sqrt_nonneg _ - · exact hε - calc ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 - ≤ (∑ i : n × d, (v i) ^ 2) * (∑ i : n × d, ∑ j : n × d, ((DeepPatternTerm D).w i j) ^ 2) := h1 - _ ≤ (∑ i : n × d, (v i) ^ 2) * ε ^ 2 := by - apply mul_le_mul_of_nonneg_left h2 - exact Finset.sum_nonneg (fun i _ => sq_nonneg _) - _ = ε ^ 2 * (∑ i : n × d, (v i) ^ 2) := by ring - -end CausalConsistency - -/-! 
## Mechanism Certification -/ - -section MechanismCertification - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- A circuit is **causally certified** if its pattern term error is below threshold. - -This means interventions on the abstract `LocalSystem` derived from the circuit -will accurately predict the real network's behavior. -/ -def isCausallyCertified (D : DeepLinearization (n := n) (d := d)) (threshold : ℝ) : Prop := - frobeniusNorm (DeepPatternTerm D) ≤ threshold - -/-- A certified circuit's ablations are faithful within the error bound. -/ -theorem certified_ablation_faithful - (D : DeepLinearization (n := n) (d := d)) - (threshold : ℝ) - (hcert : isCausallyCertified D threshold) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - threshold ^ 2 * (∑ i : n × d, (v i) ^ 2) := - causal_consistency_frobenius D blocked v threshold hcert - -/-- A mechanism discovered via interpretability is **verified** if: -1. The extracted `LocalSystem` has the expected structure (e.g., induction head pattern) -2. The circuit is causally certified (small pattern term) - -When both hold, the mechanism is a genuine causal explanation of the network's behavior, -not an interpretability illusion. -/ -structure VerifiedMechanism (D : DeepLinearization (n := n) (d := d)) where - /-- The certification threshold -/ - threshold : ℝ - /-- The threshold is positive -/ - threshold_pos : 0 < threshold - /-- The circuit meets the certification bound -/ - certified : isCausallyCertified D threshold - /-- Description of the discovered mechanism (for documentation) -/ - description : String - -/-- Any verified mechanism satisfies causal consistency. 
-/ -theorem VerifiedMechanism.causal_consistency - {D : DeepLinearization (n := n) (d := d)} - (M : VerifiedMechanism D) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 ≤ - M.threshold ^ 2 * (∑ i : n × d, (v i) ^ 2) := - certified_ablation_faithful D M.threshold M.certified blocked v - -/-- The RMS discrepancy is bounded by threshold times RMS input. -/ -theorem VerifiedMechanism.rms_bound - {D : DeepLinearization (n := n) (d := d)} - (M : VerifiedMechanism D) - (blocked : Set (n × d)) [DecidablePred blocked] - (v : (n × d) → ℝ) : - Real.sqrt (∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2) ≤ - M.threshold * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - have h := M.causal_consistency blocked v - have hpos : 0 ≤ ∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2 := - Finset.sum_nonneg (fun j _ => sq_nonneg _) - calc Real.sqrt (∑ j : n × d, (ablationDiscrepancy D blocked v j) ^ 2) - ≤ Real.sqrt (M.threshold ^ 2 * (∑ i : n × d, (v i) ^ 2)) := Real.sqrt_le_sqrt h - _ = |M.threshold| * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - rw [Real.sqrt_mul (sq_nonneg _), Real.sqrt_sq_eq_abs] - _ = M.threshold * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - rw [abs_of_pos M.threshold_pos] - -end MechanismCertification - -end Nfp diff --git a/Nfp/Appendix.lean b/Nfp/Appendix.lean deleted file mode 100644 index ddfac2d..0000000 --- a/Nfp/Appendix.lean +++ /dev/null @@ -1,956 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Indicator -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Order.MinMax -import Mathlib.Data.Finset.Lattice.Basic -import Nfp.Prob -import Nfp.PCC -import Nfp.Uniqueness -import Nfp.Mixer -import Nfp.Influence -import Nfp.Reroute.Partition -import Nfp.Reroute.Heat - 
-namespace Nfp - -open scoped BigOperators -open Finset - -namespace List - -lemma take_succ_drop_append {α : Type*} : - ∀ {xs : List α} {k : ℕ}, - xs.take (k + 1) = xs.take k ++ (xs.drop k).take 1 - | [], 0 => by simp - | [], k+1 => by simp - | _ :: _, 0 => by simp - | x :: xs, k + 1 => by - have ih := - take_succ_drop_append (xs:=xs) (k:=k) - simp [ih] - -end List - -/-! -# Appendix A mapping (Lean formalization) - -This file collects the statements used to mirror the Appendix A results: - -- Appendix A.4 (PCC curve): `normMass`, `feasible`, `pccArg`, `pccMax`, - `PCC`, monotonicity (`PCC_monotone`), and the dominance/upper-bound lemmas - (`pccMax_dominates`, `PCC_upper_bounds_masks`), plus the existence/greedy view - `greedy_topmass_optimal`. -- Appendix A.3 (Residual rule): `lambdaEC`, `lambdaEC_sum_one`, - `residual_lambda_from_norm`, and the global characterization - `lambdaEC_scale_invariant_global`. -- Appendix A.2 (Normalization on a subset): `normalizeOn` and its properties - (`normalizeOn_outside`, `normalizeOn_inside`, `normalizeOn_sum_one`), and the - uniqueness theorem `proportional_row_unique`. - -All proofs avoid `sorry`. Some results use classical reasoning (`classical`). --/ - -section PCC - -variable {S : Type*} [Fintype S] - -variable (m : Nfp.Contrib S) - -lemma drop_eq_mass (A : Finset S) : - (A.sum m) / (∑ i, m i) = (A.sum (fun i => m i / (∑ j, m j))) := by - classical - -- Division distributes over finite sums in `NNReal`. - simp [Finset.sum_div] - -/-- Appendix A.4 (upper bound): any mask with tracer budget ≤ τ removes at most τ -of normalized logit. -/ -lemma pcc_upper_bound_for_any_mask (A : Finset S) (τ : NNReal) - (hA : (A.sum (fun i => m i / (∑ j, m j))) ≤ τ) : - (A.sum m) / (∑ j, m j) ≤ τ := by - simpa [drop_eq_mass (m:=m) (A:=A)] using hA - -/-- Appendix A.4 (monotonicity helper): normalized removed mass is monotone along -any nested mask chain. 
-/ -lemma normalized_sum_monotone (A : ℕ → Finset S) - (hchain : ∀ k, A k ⊆ A (k + 1)) : - Monotone (fun k => (A k).sum (fun i => m i / (∑ j, m j))) := by - classical - -- Apply monotonicity of raw sums to the pointwise scaled weights. - have := Nfp.sum_monotone_chain (A:=A) (w:=fun i => m i / (∑ j, m j)) hchain - simpa using this - -/-- Appendix A.4: normalized mass of a mask `A`. -/ -noncomputable def normMass (m : S → NNReal) (A : Finset S) : NNReal := - (A.sum m) / (∑ j, m j) - -@[simp] lemma normMass_empty (m : S → NNReal) : normMass m (∅ : Finset S) = 0 := by - simp [normMass] - -lemma normMass_union [DecidableEq S] (m : S → NNReal) (A B : Finset S) - (hdisj : Disjoint A B) : - normMass m (A ∪ B) = normMass m A + normMass m B := by - classical - have hsum := Finset.sum_union hdisj (f := m) - simp [normMass, hsum, add_div] - -lemma normMass_partitionUnion [DecidableEq S] (m : S → NNReal) - (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) : - normMass m (unionParts parts) = - parts.foldr (fun A acc => normMass m A + acc) 0 := by - classical - induction parts with - | nil => - simp [unionParts] - | cons A parts ih => - rcases List.pairwise_cons.mp hpair with ⟨hA, htail⟩ - have hdisj : Disjoint A (unionParts parts) := - disjoint_unionParts A parts (by - intro B hB - exact hA _ (by simpa using hB)) - have hAdd := - normMass_union (m:=m) A (unionParts parts) hdisj - simp [unionParts, hAdd, ih htail, List.foldr] - -lemma normMass_partition_eq_one [DecidableEq S] (m : S → NNReal) - (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - (hcover : unionParts parts = (Finset.univ : Finset S)) - (htot : (∑ j, m j) ≠ 0) : - parts.foldr (fun A acc => normMass m A + acc) 0 = 1 := by - classical - have hsum : - parts.foldr (fun A acc => normMass m A + acc) 0 = normMass m (unionParts parts) := by - simpa using (normMass_partitionUnion (m:=m) parts hpair).symm - simpa [hcover, normMass, htot] using hsum - -namespace 
ReroutePlan - -section - -variable {S : Type*} [Fintype S] [DecidableEq S] - -lemma normMass_sum_one (m : S → NNReal) (P : ReroutePlan (S := S)) - (htot : (∑ j, m j) ≠ 0) : - P.masks.foldr (fun A acc => normMass m A + acc) 0 = 1 := by - simpa [ReroutePlan.masks] using - normMass_partition_eq_one (S := S) m P.masks P.pairwise_disjoint P.covers_univ htot - -end - -end ReroutePlan - -lemma normMass_rerouteHeat_increment {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - normMass (fun i => (P.rerouteHeat).mass i) A = w / P.weightsSum := by - classical - have hsum := WeightedReroutePlan.rerouteHeat_sum_increment (P:=P) hmem - have hdenom : (∑ i, (P.rerouteHeat).mass i) = 1 := ProbVec.sum_mass _ - have hsum' : - (∑ i ∈ A, P.heatRaw i / P.weightsSum) = w / P.weightsSum := by - simpa [WeightedReroutePlan.rerouteHeat_mass] using hsum - have hdenom' : - (∑ i, P.heatRaw i / P.weightsSum) = 1 := by - simpa [WeightedReroutePlan.rerouteHeat_mass] using hdenom - simp [normMass, hsum', hdenom'] - - -/-- Appendix A.4: feasible masks under budget `τ` (by normalized mass). -/ -noncomputable def feasible (m : S → NNReal) (τ : NNReal) : Finset (Finset S) := - (Finset.univ : Finset S).powerset.filter (fun A => normMass m A ≤ τ) - -lemma feasible_nonempty (m : S → NNReal) (τ : NNReal) : - (feasible (S:=S) m τ).Nonempty := by - classical - refine ⟨∅, ?_⟩ - simp [feasible, normMass] - --- Appendix A.4: argmax mask under budget τ w.r.t. normalized mass. 
-noncomputable def pccArg (m : S → NNReal) (τ : NNReal) : Finset S := by - classical - exact Classical.choose - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - -private lemma pccArg_mem (m : S → NNReal) (τ : NNReal) : - pccArg (S:=S) m τ ∈ feasible (S:=S) m τ := by - classical - have h := Classical.choose_spec - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - exact h.left - -private lemma pccArg_is_max (m : S → NNReal) (τ : NNReal) : - ∀ B ∈ feasible (S:=S) m τ, normMass m B ≤ normMass m (pccArg (S:=S) m τ) := by - classical - have h := Classical.choose_spec - (Finset.exists_max_image (feasible (S:=S) m τ) (fun A => normMass m A) - (feasible_nonempty (S:=S) m τ)) - exact h.right - -/-- Appendix A.4: maximal normalized drop (the PCC value at budget `τ`). -/ -noncomputable def pccMax (m : S → NNReal) (τ : NNReal) : NNReal := - normMass m (pccArg (S:=S) m τ) - -lemma pccMax_le_tau (m : S → NNReal) (τ : NNReal) : - pccMax (S:=S) m τ ≤ τ := by - classical - have hmem := pccArg_mem (S:=S) m τ - simpa [pccMax] using (Finset.mem_filter.mp hmem).2 - -lemma pccMax_monotone (m : S → NNReal) : - Monotone (fun τ : NNReal => pccMax (S:=S) m τ) := by - classical - intro τ₁ τ₂ hle - -- Any τ₁-feasible set is τ₂-feasible when τ₁ ≤ τ₂. - have hsubset : feasible (S:=S) m τ₁ ⊆ feasible (S:=S) m τ₂ := by - intro A hA - rcases Finset.mem_filter.mp hA with ⟨hpow, hbudget⟩ - have hbudget' : normMass m A ≤ τ₂ := hbudget.trans hle - aesop (add simp [feasible]) - -- The τ₂-argmax is ≥ any τ₁-feasible value, in particular the τ₁-argmax. 
- have hAτ1 := pccArg_mem (S:=S) m τ₁ - have hAτ1_in : pccArg (S:=S) m τ₁ ∈ feasible (S:=S) m τ₂ := hsubset hAτ1 - have hmaxτ2 := pccArg_is_max (S:=S) m τ₂ (pccArg (S:=S) m τ₁) hAτ1_in - simpa [pccMax] using hmaxτ2 - - -/-- Appendix A.4 (dominance/upper bound): for any mask `B` that removes at most -`τ` of tracer mass (normalized), its normalized drop is upper-bounded by `pccMax m τ`. -Equivalently, `pccMax` is the supremum over all feasible masks at budget `τ`. -/ -lemma pccMax_dominates (m : S → NNReal) (τ : NNReal) (B : Finset S) - (hB : normMass m B ≤ τ) : - normMass m B ≤ pccMax (S:=S) m τ := by - classical - -- Feasibility of B at budget τ - have hB_mem : B ∈ feasible (S:=S) m τ := by - have hpow : B ∈ (Finset.univ : Finset S).powerset := by - simp - exact (Finset.mem_filter.mpr ⟨hpow, hB⟩) - -- Maximality of the arg at τ - have hmax := pccArg_is_max (S:=S) m τ B hB_mem - simpa [pccMax] using hmax - -/-- Appendix A.4 (PCC curve): normalized logit drop achievable at tracer-mass -budget `t`. This matches the budgeted maximum `pccMax` over feasible masks. -/ -noncomputable abbrev PCC (m : S → NNReal) (t : NNReal) : NNReal := pccMax (S:=S) m t - -@[simp] lemma PCC_def (m : S → NNReal) (t : NNReal) : PCC (S:=S) m t = pccMax (S:=S) m t := rfl - -/-- Appendix A.4 (monotonicity): the PCC curve is monotone in the budget `t`. -/ -lemma PCC_monotone (m : S → NNReal) : Monotone (fun t : NNReal => PCC (S:=S) m t) := by - simpa [PCC_def] using pccMax_monotone (S:=S) m - -/-- Appendix A.4 (upper bound, mask phrasing): for any mask `B` that -preserves at least `1 - t` tracer mass (i.e., removes ≤ `t`), the normalized -logit drop of `B` is at most `PCC m t`. 
-/ -lemma PCC_upper_bounds_masks (m : S → NNReal) (t : NNReal) (B : Finset S) - (hBudget : normMass m B ≤ t) : - normMass m B ≤ PCC (S:=S) m t := by - simpa [PCC_def] using pccMax_dominates (S:=S) m t B hBudget - -lemma rerouteHeat_pcc_drop {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) - {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - PCC (S:=S) (fun i => (P.rerouteHeat).mass i) (w / P.weightsSum) = w / P.weightsSum := by - classical - set m := fun i => (P.rerouteHeat).mass i - have hnorm' : normMass m A = w / P.weightsSum := by - simpa [m] using normMass_rerouteHeat_increment (P:=P) hmem - have hfeas : normMass m A ≤ w / P.weightsSum := by - simp [hnorm'] - have hdom := pccMax_dominates (S:=S) m (w / P.weightsSum) A hfeas - have hupper : PCC (S:=S) m (w / P.weightsSum) ≤ w / P.weightsSum := by - simpa [PCC_def] using pccMax_le_tau (S:=S) m (w / P.weightsSum) - have hdom' : w / P.weightsSum ≤ pccMax m (w / P.weightsSum) := by - simpa [hnorm'] using hdom - have hlower : w / P.weightsSum ≤ PCC (S:=S) m (w / P.weightsSum) := by - simpa [PCC_def] using hdom' - exact le_antisymm hupper hlower - -/-- Convenience corollary: the `k`-th reroute increment (mask/weight pair) -from a weighted plan achieves equality with the PCC budget that matches its -normalized weight. This version avoids spelunking through `zip` membership -and is tailored for the exported segment lists used in the Python notebooks. 
-/ -lemma rerouteHeat_pcc_drop_index {S : Type*} [Fintype S] [DecidableEq S] - (P : WeightedReroutePlan (S := S)) {k : Nat} - (hk : k < P.plan.increments.length) : - PCC (S:=S) (fun i => (P.rerouteHeat).mass i) - (P.weights.get ⟨k, by simpa [P.length_eq_increments] using hk⟩ / P.weightsSum) - = P.weights.get ⟨k, by simpa [P.length_eq_increments] using hk⟩ / P.weightsSum := by - classical - have hlen : P.plan.increments.length = P.weights.length := (P.length_eq_increments).symm - have hw : k < P.weights.length := hlen ▸ hk - have hmem : - (P.plan.increments.get ⟨k, hk⟩, - P.weights.get ⟨k, hw⟩) ∈ P.plan.increments.zip P.weights := by - simpa [hw, hlen] using - List.get_mem_zip (xs:=P.plan.increments) (ys:=P.weights) - (k:=k) hk hw - have h := rerouteHeat_pcc_drop (P:=P) (hmem := hmem) - exact h - -namespace WeightedReroutePlan - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Stage-4 prefix objective: restrict the AUC evaluation to the first `k` -intervals induced by `weights`. -/ -noncomputable def rerouteHeatObjectivePrefix - (P : WeightedReroutePlan (S := S)) (k : ℕ) : NNReal := - PCC.evalFromWeights - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum (P.weights.take k) - -/-- Objective used in Stage 4: the PCC AUC achieved by the reroute plan’s -`rerouteHeat` distribution, evaluated via the generic weights-to-interval map. 
-/ -noncomputable def rerouteHeatObjective (P : WeightedReroutePlan (S := S)) : NNReal := - PCC.evalFromWeights - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum P.weights - -lemma rerouteHeatObjectivePrefix_as_auc (P : WeightedReroutePlan (S := S)) (k : ℕ) : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals).take k) - = P.rerouteHeatObjectivePrefix k := by - classical - simpa [rerouteHeatObjectivePrefix, aucIntervals] using - PCC.AUC_intervalsFromWeights_take - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - P.weightsSum P.weights k - -lemma rerouteHeatObjective_as_auc (P : WeightedReroutePlan (S := S)) : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) - = P.rerouteHeatObjective := by - simp [rerouteHeatObjective, auc_eval] - -lemma rerouteHeatObjectivePrefix_le (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix k ≤ P.rerouteHeatObjective := by - classical - have hmono := - PCC.AUC_monotone_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) (P.aucIntervals.drop k) - have hsplit := List.take_append_drop k P.aucIntervals - have hle : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) - ≤ - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) := by - simpa [hsplit] using hmono - have hprefix := (rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k)).symm - have htotal := (rerouteHeatObjective_as_auc (P:=P)).symm - simpa [hprefix, htotal] using hle - -lemma rerouteHeatObjective_split (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjective = - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.drop k) := by - classical - have h := - PCC.AUC_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - 
(P.aucIntervals.take k) (P.aucIntervals.drop k) - have hsplit := List.take_append_drop k P.aucIntervals - have hsum : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals) - = - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.drop k) := by - simpa [hsplit] using h - have hprefix := (rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k)).symm - have htotal := (rerouteHeatObjective_as_auc (P:=P)).symm - simpa [hprefix, htotal] using hsum - -/-- Prefix addition lemma: the `(k+1)`-st prefix equals the `k`-prefix plus the -contribution coming from the `(k+1)`-st interval (if present). -/ -lemma rerouteHeatObjectivePrefix_succ (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix (k + 1) = - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := by - classical - have hsum : - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take (k + 1)) - = - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := by - simpa [List.take_succ_drop_append (xs:=P.aucIntervals) (k:=k)] using - PCC.AUC_append - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - (P.aucIntervals.take k) - ((P.aucIntervals.drop k).take 1) - have hprefix := rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k) - have hprefix_succ := rerouteHeatObjectivePrefix_as_auc (P:=P) (k:=k+1) - have hsum' := hsum - rw [hprefix_succ] at hsum' - have hsum'' := hsum' - rw [hprefix] at hsum'' - exact hsum'' - -lemma rerouteHeatObjectivePrefix_succ_le (P : WeightedReroutePlan (S := S)) (k : ℕ) : - P.rerouteHeatObjectivePrefix k ≤ 
P.rerouteHeatObjectivePrefix (k + 1) := by - classical - have h := rerouteHeatObjectivePrefix_succ (P:=P) (k:=k) - have hnonneg : - 0 ≤ - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := - PCC.AUC_nonneg - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) - have hle : - P.rerouteHeatObjectivePrefix k ≤ - P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := - le_add_of_nonneg_right hnonneg - calc - P.rerouteHeatObjectivePrefix k - ≤ P.rerouteHeatObjectivePrefix k + - PCC.AUC - (f:=fun t => PCC (S:=S) (fun i => (P.rerouteHeat).mass i) t) - ((P.aucIntervals.drop k).take 1) := hle - _ = P.rerouteHeatObjectivePrefix (k + 1) := by - simp [h] - -lemma rerouteHeatObjectivePrefix_mono (P : WeightedReroutePlan (S := S)) - {k₁ k₂ : ℕ} (hle : k₁ ≤ k₂) : - P.rerouteHeatObjectivePrefix k₁ ≤ P.rerouteHeatObjectivePrefix k₂ := by - classical - have hmono : - ∀ d k, P.rerouteHeatObjectivePrefix k ≤ P.rerouteHeatObjectivePrefix (k + d) := by - intro d - induction d with - | zero => - intro k - simp - | succ d ih => - intro k - have hsucc := - rerouteHeatObjectivePrefix_succ_le - (P:=P) (k:=k + d) - have := ih k - have := this.trans hsucc - simpa [Nat.add_comm, Nat.add_left_comm, Nat.add_assoc] using this - obtain ⟨d, rfl⟩ := Nat.exists_eq_add_of_le hle - simpa using hmono d k₁ - -/-- Feasibility predicate for Stage 4: each incremental mask carries the exact -normalized mass dictated by its weight ratio. 
-/ -def rerouteFeasible (m : S → NNReal) (P : WeightedReroutePlan (S := S)) : Prop := - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ P.plan.increments.zip P.weights → - normMass m A = w / P.weightsSum - -lemma rerouteHeat_feasible (P : WeightedReroutePlan (S := S)) : - rerouteFeasible (S:=S) (fun i => (P.rerouteHeat).mass i) P := by - intro A w hmem - simpa using normMass_rerouteHeat_increment (P:=P) (hmem:=hmem) - -/-- Stage-4 goal statement (stub): a “delta-weighted” optimal plan maximizes -`rerouteHeatObjective` among feasible plans. Proof to be supplied once the -majorization lemmas are in place. -/ -def optimal_delta_weighting_statement (m : S → NNReal) : Prop := - ∃ P : WeightedReroutePlan (S := S), - rerouteFeasible (S:=S) m P ∧ - ∀ Q : WeightedReroutePlan (S := S), - rerouteFeasible (S:=S) m Q → - Q.rerouteHeatObjective ≤ P.rerouteHeatObjective - -end WeightedReroutePlan - - -/-- Appendix A.4 (greedy optimality for mass budgets): there exists a -subset `A` (a feasible mask under budget `τ`) that maximizes normalized -removed mass among all feasible masks. Concretely we can take `A = pccArg m τ`. -This ties the constructive “pick a best feasible mask” view to our -`pccArg/pccMax` definitions; ties are handled by the argmax choice. -/ -lemma greedy_topmass_optimal (m : S → NNReal) (τ : NNReal) : - ∃ A : Finset S, - A ∈ feasible (S:=S) m τ ∧ - (∀ B : Finset S, B ∈ feasible (S:=S) m τ → - normMass m B ≤ normMass m A) := by - classical - refine ⟨pccArg (S:=S) m τ, ?_, ?_⟩ - · exact pccArg_mem (S:=S) m τ - · intro B hB - exact pccArg_is_max (S:=S) m τ B hB - - -/-- Greedy top-k optimality (helper for A.4): among all subsets with cardinality ≤ k, -there exists -an optimal set A maximizing `A.sum m` such that no outside element has larger weight -than an inside element (swap-optimality). This captures the compact majorization-style -“top-k is optimal” property without constructing an explicit sort. 
-/ -lemma greedy_topk_optimal (m : S → NNReal) (k : ℕ) : - ∃ A : Finset S, - A.card ≤ k ∧ - (∀ B : Finset S, B.card ≤ k → B.sum m ≤ A.sum m) ∧ - (∀ {i j}, i ∈ A → j ∉ A → m i ≥ m j) := by - classical - -- Candidates: all subsets of `univ` with card ≤ k. - let C : Finset (Finset S) := (Finset.univ : Finset S).powerset.filter (fun A => A.card ≤ k) - have hC_nonempty : C.Nonempty := by - refine ⟨∅, ?_⟩ - simp [C] - -- Pick A ∈ C maximizing the sum. - obtain ⟨A, hA_mem, hAmax⟩ := Finset.exists_max_image C (fun A : Finset S => A.sum m) hC_nonempty - -- Basic facts about A - have hA_card_le : A.card ≤ k := by - rcases Finset.mem_filter.mp hA_mem with ⟨hA_ps, hk⟩ - exact hk - -- Show optimality among all B with card ≤ k. - have hA_opt : ∀ B : Finset S, B.card ≤ k → B.sum m ≤ A.sum m := by - intro B hBk - have hB_mem : B ∈ C := by - -- Any B ⊆ univ is in the powerset, and `B.card ≤ k` ensures it passes the filter. - simp [C, hBk] - exact hAmax B hB_mem - -- Swap optimality: no beneficial swap exists. 
- have hswap : ∀ {i j}, i ∈ A → j ∉ A → m i ≥ m j := by - intro i j hiA hjA - by_contra hij - have hij' : m i < m j := lt_of_not_ge hij - -- Construct the swapped set B = insert j (A.erase i) - let B := insert j (A.erase i) - have hj_not : j ∉ A.erase i := by - simp [Finset.mem_erase, hjA] - have hcard_eq : B.card = A.card := by - have hIns : insert i (A.erase i) = A := by simpa using Finset.insert_erase hiA - have hi_not : i ∉ A.erase i := by simp - have hA_card' : (A.erase i).card + 1 = A.card := by - -- From card(insert i (erase i A)) = card(erase i A) + 1 - calc - (A.erase i).card + 1 = (insert i (A.erase i)).card := by - simp [Finset.card_insert_of_notMem hi_not] - _ = A.card := by simp [hIns] - have hB_card : B.card = (A.erase i).card + 1 := by - simp [B, Finset.card_insert_of_notMem hj_not] - simpa [hB_card] using hA_card' - have hB_card_le : B.card ≤ k := by simpa [hcard_eq] using hA_card_le - -- Compare sums via erase/insert decomposition - have hsumA : A.sum m = m i + (A.erase i).sum m := by - have hi_not : i ∉ A.erase i := by simp - have hIns : insert i (A.erase i) = A := by simpa using Finset.insert_erase hiA - have sumIns := Finset.sum_insert (s:=A.erase i) (a:=i) (f:=m) hi_not - simpa [hIns, add_comm] using sumIns - have hsumB : B.sum m = m j + (A.erase i).sum m := by - have sumInsB := Finset.sum_insert (s:=A.erase i) (a:=j) (f:=m) hj_not - simpa [B, add_comm] using sumInsB - have : A.sum m < B.sum m := by - have h := add_lt_add_right hij' ((A.erase i).sum m) - simpa [hsumA, hsumB, add_comm, add_left_comm, add_assoc] using h - have hB_le := hA_opt B hB_card_le - exact (not_le_of_gt this) hB_le - exact ⟨A, hA_card_le, hA_opt, hswap⟩ - -end PCC - -/-! Residual additions: energy-consistent coefficients -/ - -section Residual - -/-- Energy-consistent residual coefficient (Appendix A.3). 
-/ -noncomputable def lambdaEC (x a : NNReal) : NNReal := x / (x + a) - -lemma lambdaEC_sum_one (x a : NNReal) (hx : x + a ≠ 0) : - lambdaEC x a + a / (x + a) = 1 := by - unfold lambdaEC - have : x / (x + a) + a / (x + a) = (x + a) / (x + a) := by - simp [add_div] - simp [this, div_self hx] - -/-- If a residual coefficient `λ` satisfies `λ + a/(x+a) = 1` (with `x+a ≠ 0`), -then necessarily `λ = x/(x+a)` in `NNReal`. -/ -lemma residual_lambda_from_norm - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := by - -- Work in ℝ via coercions to use field operations. - have hR : (lcoeff x a : ℝ) + (a : ℝ) / (x + a) = 1 := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) hnorm - have hxneR : (x + a : ℝ) ≠ 0 := by exact_mod_cast hx - have hval : (lcoeff x a : ℝ) = x / (x + a) := by - -- From hR: l + a/(x+a) = 1 ⇒ l = 1 - a/(x+a) = x/(x+a) - have : (lcoeff x a : ℝ) = 1 - (a : ℝ) / (x + a) := by - exact (eq_sub_iff_add_eq).mpr hR - have hunit : (1 : ℝ) = (x + a) / (x + a) := by simp [div_self hxneR] - have hdiff : (x + a : ℝ) / (x + a) - (a : ℝ) / (x + a) = ((x + a : ℝ) - a) / (x + a) := by - simpa [sub_eq_add_neg] using (sub_div ((x + a : ℝ)) (a : ℝ) (x + a)).symm - have : (lcoeff x a : ℝ) = ((x + a : ℝ) - a) / (x + a) := by simp [this, hunit, hdiff] - simpa [sub_eq_add_neg, add_comm] - -- Conclude equality in NNReal by extensionality on coercions to ℝ. - apply Subtype.ext - simpa using hval - -/-- Global characterization: any residual rule which is pointwise row-normalized -(`lcoeff x a + a/(x+a) = 1` whenever `x+a ≠ 0`) must be `x/(x+a)` (Appendix A.3). -The Appendix mentions scale-invariance for motivation, but the normalization -equation alone determines the coefficients, so we keep the statement minimal. 
-/ -lemma lambdaEC_scale_invariant_global - (lcoeff : NNReal → NNReal → NNReal) - (hnorm : ∀ (x a : NNReal), x + a ≠ 0 → lcoeff x a + a / (x + a) = 1) : - ∀ (x a : NNReal), x + a ≠ 0 → lcoeff x a = x / (x + a) := by - intro x a hx - exact residual_lambda_from_norm lcoeff x a hx (hnorm x a hx) - -end Residual - -/-! Normalization uniqueness for local mixers -/ - -section Normalize - -variable {S : Type*} [DecidableEq S] - -/-- Appendix A.2: normalize nonnegative weights on a finite subset `A` to a probability row. -/ -noncomputable def normalizeOn (A : Finset S) (w : S → NNReal) : S → NNReal := - fun i => if i ∈ A then w i / (A.sum w) else 0 - -/-- Appendix A.2: outside the support `A`, `normalizeOn A w` is zero. -/ -@[simp] lemma normalizeOn_outside (A : Finset S) (w : S → NNReal) {i : S} (hi : i ∉ A) : - normalizeOn (S:=S) A w i = 0 := by - classical - simp [normalizeOn, hi] - -/-- Appendix A.2: inside the support `A`, `normalizeOn A w` is proportional to `w`. -/ -@[simp] lemma normalizeOn_inside (A : Finset S) (w : S → NNReal) {i : S} (hi : i ∈ A) : - normalizeOn (S:=S) A w i = w i / (A.sum w) := by - classical - simp [normalizeOn, hi] - -/-- Appendix A.2: `normalizeOn A w` is a probability row (sums to 1). -/ -lemma normalizeOn_sum_one [Fintype S] (A : Finset S) (w : S → NNReal) (h : A.sum w ≠ 0) : - (∑ i, normalizeOn (S:=S) A w i) = 1 := by - classical - -- Sum over `univ` of an indicator-style function equals the sum over `A`. 
- have hsumA := - (Finset.sum_indicator_subset (s:=A) (t:=(Finset.univ : Finset S)) - (f:=fun i => w i / (A.sum w)) (Finset.subset_univ A)) - -- simplify the indicator - have hsumA' : (∑ i, normalizeOn (S:=S) A w i) = A.sum (fun i => w i / (A.sum w)) := by - convert hsumA using 1 - simp [normalizeOn, Set.indicator] - have hsumA'' : A.sum (fun i => w i / (A.sum w)) = (A.sum w) / (A.sum w) := by - simp [Finset.sum_div] - have hsumA1 : A.sum (fun i => w i / (A.sum w)) = 1 := by - simpa [div_self h] using hsumA'' - simpa [hsumA'] using hsumA1 - --- (To be extended) A local uniqueness lemma will connect proportional rows supported on `A` --- and the `normalizeOn` construction via the budget/sum constraint. - -/-- Appendix A.2 (uniqueness): if `p` is supported on `A`, -and proportional to `w` on `A` with total mass 1, then `p` is exactly `normalizeOn A w`. -We also derive the proportionality constant as `1 / A.sum w` using the mass constraint. -/ -lemma proportional_row_unique [Fintype S] - (A : Finset S) (w p : S → NNReal) (k : NNReal) - (hAw : A.sum w ≠ 0) - (hout : ∀ i ∉ A, p i = 0) - (hin : ∀ i ∈ A, p i = k * w i) - (hsum : (∑ i, p i) = 1) : - p = normalizeOn (S:=S) A w := by - classical - -- Represent `p` as an indicator-style function using `hin`/`hout`. - have hrepr : (fun i => if i ∈ A then k * w i else 0) = p := by - funext i - by_cases hi : i ∈ A - · simp [hi, hin i hi] - · simp [hi, hout i hi] - -- The total mass constraint identifies `k`. - -- Sum of indicator over `univ` reduces to a sum over `A`. - have hsumA0 := - (Finset.sum_indicator_subset (s:=A) (t:=(Finset.univ : Finset S)) - (f:=fun i => k * w i) (Finset.subset_univ A)) - have hsumA0' : - (∑ i, ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) - = A.sum (fun i => k * w i) := by - simpa using hsumA0 - -- Identify the indicator function with the `if-then-else` representation and then with `p`. 
- have hind : (fun i => ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) - = (fun i => if i ∈ A then k * w i else 0) := by - funext i; simp [Set.indicator] - have hfun : (fun i => ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) = p := by - simpa [hrepr] using hind - have hsumA : (∑ i, p i) = A.sum (fun i => k * w i) := by - have : (∑ i, p i) = (∑ i, ((↑A : Set S).indicator (fun i => k * w i) i : NNReal)) := by - simp [hfun] - exact this.trans hsumA0' - have hsumA1 : A.sum (fun i => k * w i) = 1 := by simpa [hsumA] using hsum - -- Move to ℝ to use field-style algebra and cancel the nonzero sum. - have hsumA1R : (A.sum (fun i => ((k : ℝ) * (w i : ℝ)))) = 1 := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) hsumA1 - have hR : (k : ℝ) * (((A.sum w : NNReal) : ℝ)) = 1 := by - -- factor out the constant from the sum and identify the coerced sum - simpa [Finset.mul_sum] using hsumA1R - have hAwR : (((A.sum w : NNReal) : ℝ)) ≠ 0 := by exact_mod_cast hAw - have k_eq : (k : ℝ) = 1 / (((A.sum w : NNReal) : ℝ)) := by - -- from k * S = 1 - exact (eq_div_iff_mul_eq hAwR).2 (by simpa [mul_comm] using hR) - -- Conclude pointwise equality with normalizeOn. 
- apply funext; intro i; by_cases hiA : i ∈ A - · -- inside A: p i = k * w i = (1 / sum) * w i = w i / sum - have : p i = k * w i := hin i hiA - have : (p i : ℝ) = (k : ℝ) * (w i : ℝ) := by - simpa using congrArg (fun t : NNReal => (t : ℝ)) this - have : (p i : ℝ) = (1 / (((A.sum w : NNReal) : ℝ))) * (w i : ℝ) := by - simpa [k_eq, mul_comm, mul_left_comm, mul_assoc] - using this - -- convert back to NNReal - apply Subtype.ext - -- use commutativity to write as division - have : (p i : ℝ) = (w i : ℝ) / (((A.sum w : NNReal) : ℝ)) := by - simpa [div_eq_mul_inv, mul_comm] using this - -- now identify with normalizeOn - have hnorm : (normalizeOn (S:=S) A w i : ℝ) = (w i : ℝ) / (((A.sum w : NNReal) : ℝ)) := by - -- simplify normalizeOn since i ∈ A - have := normalizeOn_inside (S:=S) A w (i:=i) hiA - -- coerce and rewrite - simp [this] - simpa [hnorm] - · -- outside A: both are 0 - have : p i = 0 := hout i hiA - simp [normalizeOn_outside (S:=S) A w hiA, this] - -end Normalize - -end Nfp - - -/-! -## Consolidated Theorem A.1 (wrapper statements) - -We collect three core ingredients used in Appendix A.1’s uniqueness story as -named wrappers. These are restatements (aliases) of lemmas proved above or in -`Nfp.Uniqueness` so downstream developments can refer to a single place. - -- Residual coefficient uniqueness from row-normalization. -- Normalization uniqueness for proportional rows supported on `A`. -- Global uniqueness of tracer families for a linear local system on a finite DAG. - -These wrappers introduce no new proof obligations; they simply expose the -relevant facts under Appendix A.1’s umbrella. --/ - -namespace Nfp - -/-! Residual coefficient uniqueness (Appendix A.3 inside A.1) -/ - -theorem A1_residual_unique - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := - residual_lambda_from_norm lcoeff x a hx hnorm - -/-! 
Normalization uniqueness on a support (Appendix A.2 inside A.1) -/ - -theorem A1_normalize_unique - {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w p : S → NNReal) (k : NNReal) - (hAw : A.sum w ≠ 0) - (hout : ∀ i ∉ A, p i = 0) - (hin : ∀ i ∈ A, p i = k * w i) - (hsum : (∑ i, p i) = 1) : - p = normalizeOn (S:=S) A w := - proportional_row_unique (S:=S) A w p k hAw hout hin hsum - -/-! Global linear uniqueness on a DAG (Appendix A.1 via `LocalSystem`) -/ - -theorem A1_global_tracer_unique - {S : Type*} {n : ℕ} - (L : LocalSystem n) - {T T' : LocalSystem.TracerFamily (S := S) n} - (hT : LocalSystem.Satisfies (S := S) L T) - (hT' : LocalSystem.Satisfies (S := S) L T') : - T = T' := - LocalSystem.tracer_unique (S:=S) L hT hT' - -/-! -## Packaged Appendix A.1 theorem - -We bundle the three core uniqueness components of Appendix A.1 into a single -exported statement `A1` returning a triple of universally quantified facts: - -- residual coefficient uniqueness from row-normalization (A.3 used in A.1), -- normalization uniqueness on a given support (A.2 used in A.1), -- global tracer uniqueness for a linear local system on a finite DAG (A.1). --/ - -/-- Appendix A.1 (packaged): a conjunction of the three core uniqueness results -used in the Appendix narrative. Each component is universally quantified over -its parameters and follows directly from the dedicated wrapper theorems above. 
-/ -theorem A1 : - (∀ (lcoeff : NNReal → NNReal → NNReal) (x a : NNReal), - x + a ≠ 0 → lcoeff x a + a / (x + a) = 1 → - lcoeff x a = x / (x + a)) ∧ - (∀ {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w p : S → NNReal) (k : NNReal), - A.sum w ≠ 0 → - (∀ i ∉ A, p i = 0) → - (∀ i ∈ A, p i = k * w i) → - (∑ i, p i) = 1 → - p = normalizeOn (S:=S) A w) ∧ - (∀ {S : Type*} {n : ℕ} - (L : LocalSystem n) - {T T' : LocalSystem.TracerFamily (S := S) n}, - LocalSystem.Satisfies (S := S) L T → - LocalSystem.Satisfies (S := S) L T' → - T = T') := by - refine ⟨?residual, ?normalize, ?global⟩ - · intro lcoeff x a hx hnorm; exact A1_residual_unique lcoeff x a hx hnorm - · intro S _ _ A w p k hAw hout hin hsum; - exact A1_normalize_unique (S:=S) A w p k hAw hout hin hsum - · intro S n L T T' hT hT'; exact A1_global_tracer_unique (S:=S) L hT hT' - -end Nfp - -/-! -## Appendix A.2 (axioms-as-theorems wrappers) - -We expose the A.2 "axioms" as formal properties derived from our basic -constructions. Rather than assuming them as axioms, we state them as theorems -that hold for our generic primitives: - -- Graph faithfulness: composition preserves support restrictions. -- Residual additivity: the energy-consistent coefficients sum to 1 and are - uniquely determined by row-normalization. -- Row-normalization: `normalizeOn` produces rows that sum to 1 over a finite - support. - -These are lightweight restatements of previously proven lemmas so the Appendix -can reference stable names. --/ - -namespace Nfp - -/-! Graph faithfulness under composition (Axiom 1) -/ - -/- -Appendix A.2 (Graph-faithfulness under composition). -If mixers `M : S → T` and `N : T → U` are supported by relations `R` and `Q` -respectively, then their composition is supported by the composed relation. -This is a wrapper around `Mixer.supported_comp` so Appendix A can refer to a -stable name. See docs/APPENDIX.md §A.2. 
--/ -theorem A2_graph_faithful_comp - {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - (M : Mixer S T) (N : Mixer T U) - (R : S → T → Prop) (Q : T → U → Prop) - (hM : Mixer.supported (S := S) (T := T) M R) - (hN : Mixer.supported (S := T) (T := U) N Q) : - Mixer.supported (S := S) (T := U) (M.comp N) (Mixer.compSupport R Q) := - Mixer.supported_comp (S := S) (T := T) (U := U) (M := M) (N := N) hM hN - -/-! Residual additivity and energy-consistency (Axiom 2) -/ - -/- -Appendix A.2 (Residual additivity / energy consistency). -For residual coefficient `λ := lambdaEC x a = x/(x+a)`, we have -`λ + a/(x+a) = 1`. See docs/APPENDIX.md §A.2. --/ -theorem A2_residual_energy_consistent (x a : NNReal) (hx : x + a ≠ 0) : - lambdaEC x a + a / (x + a) = 1 := - lambdaEC_sum_one x a hx - -/- -Appendix A.2 (Residual uniqueness from row-normalization). -Any residual rule that satisfies the row-normalization equation is uniquely -determined as `x/(x+a)`. Wrapper of `residual_lambda_from_norm`. -See docs/APPENDIX.md §A.2. --/ -theorem A2_residual_unique - (lcoeff : NNReal → NNReal → NNReal) - (x a : NNReal) (hx : x + a ≠ 0) - (hnorm : lcoeff x a + a / (x + a) = 1) : - lcoeff x a = x / (x + a) := - residual_lambda_from_norm lcoeff x a hx hnorm - -/-! Row-normalization on a finite support (part of Axioms 1/3/6) -/ - -/- -Appendix A.2 (Row-normalization on a finite support). -The `normalizeOn` construction sums to 1 provided the total mass on `A` is -nonzero. Wrapper of `normalizeOn_sum_one`. See docs/APPENDIX.md §A.2. --/ -theorem A2_normalize_row_sum_one {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w : S → NNReal) (h : A.sum w ≠ 0) : - (∑ i, normalizeOn (S := S) A w i) = 1 := - normalizeOn_sum_one (S := S) A w h - -/-! -### Packaged Appendix A.2 statement - -We bundle three generic A.2 facets into a single theorem `A2`: - -1. Residual uniqueness from row-normalization at a residual node. -2. 
Row-normalization of proportional weights on a finite support via `normalizeOn`. -3. Preservation of support/faithfulness under composition of mixers. - -This mirrors the intention that the "axioms" are consequences of our formal -definitions rather than assumptions. --/ - -/- -Appendix A.2 (packaged statement). -Conjunction of: residual uniqueness, row-normalization via `normalizeOn`, and -graph-faithfulness of mixer composition. See docs/APPENDIX.md §A.2. --/ -theorem A2 : - (∀ (lcoeff : NNReal → NNReal → NNReal) (x a : NNReal), - x + a ≠ 0 → lcoeff x a + a / (x + a) = 1 → - lcoeff x a = x / (x + a)) ∧ - (∀ {S : Type*} [Fintype S] [DecidableEq S] - (A : Finset S) (w : S → NNReal), - A.sum w ≠ 0 → (∑ i, normalizeOn (S := S) A w i) = 1) ∧ - (∀ {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - (M : Mixer S T) (N : Mixer T U) - (R : S → T → Prop) (Q : T → U → Prop), - Mixer.supported (S := S) (T := T) M R → - Mixer.supported (S := T) (T := U) N Q → - Mixer.supported (S := S) (T := U) (M.comp N) (Mixer.compSupport R Q)) := by - refine ⟨?resid, ?row, ?faith⟩ - · intro lcoeff x a hx hnorm; exact A2_residual_unique lcoeff x a hx hnorm - · intro S _ _ A w h; exact A2_normalize_row_sum_one (S := S) A w h - · intro S T U _ _ _ M N R Q hM hN; - exact A2_graph_faithful_comp (M:=M) (N:=N) R Q hM hN - -/-- Appendix A.2 compatibility: an `InfluenceSpec` with nonzero rows yields a -graph-faithful mixer via `ofInfluenceSpec`, ready to feed into `A2` results. 
-/ -lemma A2_for_ofInfluenceSpec {Site : Type*} [Fintype Site] [DecidableEq Site] - (I : InfluenceSpec Site) (hZ : ∀ s, InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) : - Mixer.supported (S := Site) (T := Site) (Mixer.ofInfluenceSpec (Site := Site) I) I.adj := - Mixer.ofInfluenceSpec_supported (Site := Site) (I := I) hZ - -end Nfp diff --git a/Nfp/Attribution.lean b/Nfp/Attribution.lean deleted file mode 100644 index 5bd5fe3..0000000 --- a/Nfp/Attribution.lean +++ /dev/null @@ -1,236 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Nfp.Prob -import Nfp.Mixer - -/-! -# Attribution Axioms for Neural Network Interpretation - -This module formalizes key properties (axioms) that attribution methods for -neural networks should satisfy. These axioms are used to characterize and -compare different interpretation methods such as Integrated Gradients, LRP, -Shapley values, and path-based attribution. 
- -## Main definitions - -* `Attribution` – an attribution assigns a contribution score to each input feature -* `Completeness` – contributions sum to the output difference from baseline -* `Sensitivity` – if an input affects output, it receives nonzero attribution -* `Implementation Invariance` – attributions depend only on input-output behavior -* `Linearity` – attribution is linear in the function being explained - -## Key theorems - -* `completeness_of_conservation` – conservation mixers induce complete attributions -* `tracer_attribution_complete` – tracer-based attributions satisfy completeness - -## References - -* Sundararajan, Taly, Yan: "Axiomatic Attribution for Deep Networks" (ICML 2017) -* Ancona et al.: "Towards better understanding of gradient-based attribution" (2018) -* Lundstrom et al.: "A Rigorous Study of Integrated Gradients" (2022) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Attribution structure -/ - -section Attribution - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- An attribution method assigns a contribution score to each input feature -for a given output. Scores are signed reals to allow negative contributions. -/ -structure Attribution (Input Output : Type*) [Fintype Input] [Fintype Output] where - /-- The contribution of input `i` to output `o`. -/ - contrib : Input → Output → ℝ - -namespace Attribution - -variable (A : Attribution Input Output) - -/-- Total contribution to a specific output. -/ -noncomputable def totalContrib (o : Output) : ℝ := - ∑ i, A.contrib i o - -/-- An attribution is nonnegative if all contributions are nonnegative. -/ -def Nonneg : Prop := - ∀ i o, 0 ≤ A.contrib i o - -end Attribution - -end Attribution - -/-! ## Completeness axiom -/ - -section Completeness - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- The completeness axiom: contributions sum to the difference between the -output at input `x` and the output at baseline `x₀`. 
This is the core axiom -from Integrated Gradients and Shapley value attribution. - -For neural networks: `f(x) - f(x₀) = ∑ᵢ attribution(i)` -/ -def Attribution.Complete - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) - (o : Output) : Prop := - A.totalContrib o = f x o - f x₀ o - -/-- Completeness for a single-output function (common case). -/ -def Attribution.CompleteScalar - (A : Attribution Input Unit) - (f : (Input → ℝ) → ℝ) - (x x₀ : Input → ℝ) : Prop := - A.totalContrib () = f x - f x₀ - -end Completeness - -/-! ## Sensitivity axiom -/ - -section Sensitivity - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- An input feature `i` influences output `o` for function `f` at point `x` -relative to baseline `x₀` if changing just that feature changes the output. -/ -def Influences - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) - (i : Input) - (o : Output) : Prop := - ∃ (x' : Input → ℝ), - (∀ j, j ≠ i → x' j = x₀ j) ∧ - x' i = x i ∧ - f x' o ≠ f x₀ o - -/-- The sensitivity axiom: if a feature influences the output, it receives -nonzero attribution. -/ -def Attribution.Sensitive - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) : Prop := - ∀ i o, Influences f x x₀ i o → A.contrib i o ≠ 0 - -/-- Dummy axiom: features that don't influence output get zero attribution. -/ -def Attribution.Dummy - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) - (x x₀ : Input → ℝ) : Prop := - ∀ i o, ¬ Influences f x x₀ i o → A.contrib i o = 0 - -end Sensitivity - -/-! ## Conservation and tracer-based attribution -/ - -section Conservation - -variable {S : Type*} [Fintype S] - -/-- A mixer `M` is mass-conserving if the total outgoing mass equals -the total incoming mass. For row-stochastic mixers, this is automatic. 
-/ -lemma Mixer.conserves_mass (M : Mixer S S) (p : ProbVec S) : - (∑ i, (M.push p).mass i) = ∑ i, p.mass i := by - simp only [ProbVec.sum_mass] - -/-- Attribution derived from a probability distribution (tracer mass). -/ -noncomputable def attributionOfProbVec [Fintype S] - (p : ProbVec S) : Attribution S Unit where - contrib := fun i _ => (p.mass i : ℝ) - -/-- Tracer-based attributions automatically satisfy completeness with respect -to total mass (which is 1). -/ -theorem tracer_attribution_complete (p : ProbVec S) : - (attributionOfProbVec p).totalContrib () = 1 := by - simp only [Attribution.totalContrib, attributionOfProbVec] - have h := p.norm_one - simp only [← NNReal.coe_sum] at h ⊢ - exact_mod_cast h - -end Conservation - -/-! ## Linearity axiom -/ - -section Linearity - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] - -/-- A method that produces attributions for any function. -/ -def AttributionMethod (Input Output : Type*) [Fintype Input] [Fintype Output] := - ((Input → ℝ) → Output → ℝ) → (Input → ℝ) → (Input → ℝ) → Attribution Input Output - -/-- The linearity axiom: attribution of a sum of functions equals the sum -of attributions. -/ -def AttributionMethod.Linear (method : AttributionMethod Input Output) : Prop := - ∀ (f g : (Input → ℝ) → Output → ℝ) (x x₀ : Input → ℝ) (i : Input) (o : Output), - (method (fun inp => fun o => f inp o + g inp o) x x₀).contrib i o = - (method f x x₀).contrib i o + (method g x x₀).contrib i o - -/-- Scale invariance: scaling the function scales the attribution. -/ -def AttributionMethod.ScaleInvariant (method : AttributionMethod Input Output) : Prop := - ∀ (f : (Input → ℝ) → Output → ℝ) (c : ℝ) (x x₀ : Input → ℝ) (i : Input) (o : Output), - (method (fun inp => fun o => c * f inp o) x x₀).contrib i o = - c * (method f x x₀).contrib i o - -end Linearity - -/-! 
## Symmetry and efficiency -/ - -section Symmetry - -variable {Input Output : Type*} [Fintype Input] [Fintype Output] [DecidableEq Input] - -/-- Two inputs are symmetric for a function if swapping them preserves the output. -/ -def SymmetricInputs - (f : (Input → ℝ) → Output → ℝ) - (i j : Input) : Prop := - ∀ (x : Input → ℝ) (o : Output), - f x o = f (Function.swap x i j) o - where - Function.swap (x : Input → ℝ) (i j : Input) : Input → ℝ := - fun k => if k = i then x j else if k = j then x i else x k - -/-- The symmetry axiom: symmetric inputs receive equal attribution. -/ -def Attribution.Symmetric - (A : Attribution Input Output) - (f : (Input → ℝ) → Output → ℝ) : Prop := - ∀ i j o, SymmetricInputs f i j → A.contrib i o = A.contrib j o - -end Symmetry - -/-! ## Path-based attribution -/ - -section PathAttribution - -variable {Input : Type*} - -/-- A path from baseline `x₀` to input `x` parameterized by `t ∈ [0,1]`. -/ -def Path (Input : Type*) := (t : NNReal) → Input → ℝ - -/-- A straight-line (linear interpolation) path between two points. -/ -noncomputable def straightPath (x x₀ : Input → ℝ) : Path Input := - fun t i => (1 - (t : ℝ)) * x₀ i + (t : ℝ) * x i - -/-- A valid path starts at baseline and ends at input. 
-/ -structure Path.Valid (γ : Path Input) (x x₀ : Input → ℝ) : Prop where - at_zero : ∀ i, γ 0 i = x₀ i - at_one : ∀ i, γ 1 i = x i - -lemma straightPath_valid (x x₀ : Input → ℝ) : (straightPath x x₀).Valid x x₀ := by - constructor - · intro i; simp [straightPath] - · intro i; simp [straightPath] - -end PathAttribution - -end Nfp diff --git a/Nfp/Bounds.lean b/Nfp/Bounds.lean new file mode 100644 index 0000000..eeb5226 --- /dev/null +++ b/Nfp/Bounds.lean @@ -0,0 +1,15 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Bounds.Attention +public import Nfp.Bounds.Cache +public import Nfp.Bounds.Gelu +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.LayerNorm.InvStd +public import Nfp.Bounds.Mlp +public import Nfp.Bounds.UnnormRat + +/-! +Aggregator for untrusted interval bounds. +-/ diff --git a/Nfp/Bounds/Attention.lean b/Nfp/Bounds/Attention.lean new file mode 100644 index 0000000..94aa802 --- /dev/null +++ b/Nfp/Bounds/Attention.lean @@ -0,0 +1,403 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Field +public import Mathlib.Algebra.BigOperators.Ring.Finset +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Data.Real.Basic +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Core.Basic +public import Nfp.Model.Gpt2 +public import Nfp.Bounds.Cache +public import Nfp.Bounds.LayerNorm +public import Nfp.Bounds.Mlp + +/-! +Interval bounds for multi-head attention and transformer layers. +-/ + +public section + +namespace Nfp + + +namespace Bounds + +open scoped BigOperators + +/-- Real-valued attention output for a query token and model coordinate. 
-/ +noncomputable def attentionOutputReal {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) + (q : Fin seq) (i : Fin dModel) : Real := + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + (∑ h, headProj h q i) + (attnBias i : Real) + +/-- Unfolding lemma for `attentionOutputReal`. 
-/ +private theorem attentionOutputReal_def {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (x : Fin seq → Fin dModel → Real) + (q : Fin seq) (i : Fin dModel) : + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i = + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let headWeights : Fin numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + (∑ h, headProj h q i) + (attnBias i : Real) := rfl + +/-- Interval bounds for multi-head attention outputs from interval inputs. 
-/ +def attentionOutputBounds {dModel dHead numHeads : Nat} + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps ln1Gamma ln1Beta absBound + let lnLo := ln.1 + let lnHi := ln.2 + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let headLo : Fin numHeads → Fin dModel → Rat := fun h i => + dotIntervalLower (fun d => (heads h).wo i d) (vLo h) (vHi h) + let headHi : Fin numHeads → Fin dModel → Rat := fun h i => + dotIntervalUpper (fun d => (heads h).wo i d) (vLo h) (vHi h) + let sumLo : Fin dModel → Rat := fun i => ∑ h, headLo h i + let sumHi : Fin dModel → Rat := fun i => ∑ h, headHi h i + (fun i => sumLo i + attnBias i, fun i => sumHi i + attnBias i) + +private theorem sum_weighted_const {seq : Nat} (w : Fin seq → Real) (c : Real) + (hsum : ∑ k, w k = 1) : + ∑ k, w k * c = c := by + calc + ∑ k, w k * c = (∑ k, w k) * c := by + simpa using + (Finset.sum_mul (s := (Finset.univ : Finset (Fin seq))) (f := w) (a := c)).symm + _ = c := by simp [hsum] + +/-- Weighted dot-products preserve interval bounds. 
-/ +theorem dotProduct_bounds_of_weights {seq : Nat} {lo hi : Real} + {vals w : Fin seq → Real} + (hlo : ∀ k, lo ≤ vals k) (hhi : ∀ k, vals k ≤ hi) + (hnonneg : ∀ k, 0 ≤ w k) (hsum : ∑ k, w k = 1) : + lo ≤ dotProduct w vals ∧ dotProduct w vals ≤ hi := by + have hsum_lo : ∑ k, w k * lo ≤ ∑ k, w k * vals k := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hlo k) (hnonneg k) + have hsum_lo' : ∑ k, w k * lo = lo := sum_weighted_const w lo hsum + have hlow : lo ≤ dotProduct w vals := by + have hsum_le : lo ≤ ∑ k, w k * vals k := by + simpa [hsum_lo'] using hsum_lo + simpa [dotProduct] using hsum_le + have hsum_hi : ∑ k, w k * vals k ≤ ∑ k, w k * hi := by + refine Finset.sum_le_sum ?_ + intro k _ + exact mul_le_mul_of_nonneg_left (hhi k) (hnonneg k) + have hsum_hi' : ∑ k, w k * hi = hi := sum_weighted_const w hi hsum + have hhigh : dotProduct w vals ≤ hi := by + have hsum_le : ∑ k, w k * vals k ≤ hi := by + simpa [hsum_hi'] using hsum_hi + simpa [dotProduct] using hsum_le + exact ⟨hlow, hhigh⟩ + +/-- `attentionOutputBounds` soundness for real attention outputs. 
-/ +theorem attentionOutputBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ∧ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let absBound := intervalAbsBound lo hi + let lnBounds := layerNormAbsBounds eps ln1Gamma ln1Beta absBound + let lnLo := lnBounds.1 + let lnHi := lnBounds.2 + let lnOut : Fin seq → Fin dModel → Real := fun k j => + layerNormRealOfReal eps ln1Gamma ln1Beta (x k) j + let vLo : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalLower (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let vHi : Fin numHeads → Fin dHead → Rat := fun h d => + dotIntervalUpper (fun j => (heads h).wv j d) lnLo lnHi + (heads h).bv d + let headLo : Fin numHeads → Fin dModel → Rat := fun h j => + dotIntervalLower (fun d => (heads h).wo j d) (vLo h) (vHi h) + let headHi : Fin numHeads → Fin dModel → Rat := fun h j => + dotIntervalUpper (fun d => (heads h).wo j d) (vLo h) (vHi h) + let sumLo : Fin dModel → Rat := fun j => ∑ h, headLo h j + let sumHi : Fin dModel → Rat := fun j => ∑ h, headHi h j + let headValue : Fin numHeads → Fin seq → Fin dHead → Real := fun h k d => + dotProduct (fun j => ((heads h).wv j d : Real)) (lnOut k) + (heads h).bv d + let softmaxWeights : Fin numHeads → Circuit.SoftmaxWeights seq := fun h => + Circuit.softmaxWeights (scores h) + let headWeights : Fin 
numHeads → Fin seq → Fin seq → Real := fun h q k => + Circuit.softmax (scores h q) k + let headOutput : Fin numHeads → Fin seq → Fin dHead → Real := fun h q d => + dotProduct (headWeights h q) (fun k => headValue h k d) + let headProj : Fin numHeads → Fin seq → Fin dModel → Real := fun h q j => + dotProduct (fun d => ((heads h).wo j d : Real)) (fun d => headOutput h q d) + have habs : ∀ q i, |x q i| ≤ (absBound : Real) := by + intro q i + have hbound : + |x q i| ≤ max |(lo i : Real)| |(hi i : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo q i) (hhi q i) + have hsup : max |lo i| |hi i| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi i + have hsup_real : + max |(lo i : Real)| |(hi i : Real)| ≤ (absBound : Real) := by + have hsup' : ratToReal (max |lo i| |hi i|) ≤ ratToReal absBound := + ratToReal_le_of_le hsup + simpa [ratToReal_abs, ratToReal_max, ratToReal_def] using hsup' + exact le_trans hbound hsup_real + have hln_bounds : ∀ q i, (lnLo i : Real) ≤ lnOut q i ∧ lnOut q i ≤ (lnHi i : Real) := by + intro q i + have hln := layerNormAbsBounds_spec_real eps ln1Gamma ln1Beta absBound (x q) hne heps hsqrt + (fun j => habs q j) + simpa [lnBounds, lnLo, lnHi, lnOut] using hln i + have hval_bounds : + ∀ h k d, + (vLo h d : Real) ≤ headValue h k d ∧ + headValue h k d ≤ (vHi h d : Real) := by + intro h k d + have hln := hln_bounds k + have hlo' : ∀ j, (lnLo j : Real) ≤ lnOut k j := fun j => (hln j).1 + have hhi' : ∀ j, lnOut k j ≤ (lnHi j : Real) := fun j => (hln j).2 + have hlow' := + dotIntervalLower_le_dotProduct_real_add (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) (b := ((heads h).bv d : Real)) hlo' hhi' + have hhigh' := + dotProduct_le_dotIntervalUpper_real_add (v := fun j => (heads h).wv j d) + (lo := lnLo) (hi := lnHi) (x := lnOut k) (b := ((heads h).bv d : Real)) hlo' hhi' + constructor + · simpa [headValue, vLo] using hlow' + · simpa [headValue, vHi] using hhigh' + have hhead_output_bounds : + ∀ h q d, + (vLo h d 
: Real) ≤ headOutput h q d ∧ + headOutput h q d ≤ (vHi h d : Real) := by + intro h q d + have hvals := hval_bounds h + have hlo' : ∀ k, + (vLo h d : Real) ≤ headValue h k d := fun k => (hvals k d).1 + have hhi' : ∀ k, + headValue h k d ≤ (vHi h d : Real) := fun k => (hvals k d).2 + have hnonneg : ∀ k, 0 ≤ headWeights h q k := by + intro k + simpa [headWeights, softmaxWeights, Circuit.softmaxWeights_weights] using + (softmaxWeights h).nonneg q k + have hsum : ∑ k, headWeights h q k = 1 := by + simpa [headWeights, softmaxWeights, Circuit.softmaxWeights_weights] using + (softmaxWeights h).sum_one q + have h := dotProduct_bounds_of_weights (lo := (vLo h d : Real)) (hi := (vHi h d : Real)) + (vals := fun k => headValue h k d) (w := headWeights h q) + hlo' hhi' hnonneg hsum + simpa [headOutput] using h + have hproj_bounds : + ∀ h q i, + (headLo h i : Real) ≤ headProj h q i ∧ headProj h q i ≤ (headHi h i : Real) := by + intro h q i + have hout := hhead_output_bounds h q + have hlo' : ∀ d, + (vLo h d : Real) ≤ headOutput h q d := fun d => (hout d).1 + have hhi' : ∀ d, + headOutput h q d ≤ (vHi h d : Real) := fun d => (hout d).2 + have hlow := + dotIntervalLower_le_dotProduct_real (v := fun d => (heads h).wo i d) + (lo := vLo h) (hi := vHi h) + (x := fun d => headOutput h q d) hlo' hhi' + have hhigh := + dotProduct_le_dotIntervalUpper_real (v := fun d => (heads h).wo i d) + (lo := vLo h) (hi := vHi h) + (x := fun d => headOutput h q d) hlo' hhi' + constructor + · simpa [headProj, headLo] using hlow + · simpa [headProj, headHi] using hhigh + have hsum_bounds : + (sumLo i : Real) ≤ ∑ h, headProj h q i ∧ + ∑ h, headProj h q i ≤ (sumHi i : Real) := by + have hlow : (sumLo i : Real) ≤ ∑ h, headProj h q i := by + have hsum := + Finset.sum_le_sum (s := (Finset.univ : Finset (Fin numHeads))) + (fun h _ => (hproj_bounds h q i).1) + simpa [sumLo, Linear.ratToReal_sum_univ] using hsum + have hhigh : ∑ h, headProj h q i ≤ (sumHi i : Real) := by + have hsum := + Finset.sum_le_sum (s := 
(Finset.univ : Finset (Fin numHeads))) + (fun h _ => (hproj_bounds h q i).2) + simpa [sumHi, Linear.ratToReal_sum_univ] using hsum + exact ⟨hlow, hhigh⟩ + have hlow : + (sumLo i : Real) + (attnBias i : Real) ≤ + (∑ h, headProj h q i) + (attnBias i : Real) := by + simpa [add_comm] using + add_le_add_left hsum_bounds.1 (attnBias i : Real) + have hhigh : + (∑ h, headProj h q i) + (attnBias i : Real) ≤ + (sumHi i : Real) + (attnBias i : Real) := by + simpa [add_comm] using + add_le_add_left hsum_bounds.2 (attnBias i : Real) + have hreal : + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i = + (∑ h, headProj h q i) + (attnBias i : Real) := by + simp [attentionOutputReal, lnOut, headValue, headWeights, headOutput, headProj] + have hlo : + (bounds.1 i : Real) ≤ + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i := by + simpa [bounds, attentionOutputBounds, absBound, lnBounds, lnLo, lnHi, vLo, vHi, headLo, headHi, + sumLo, sumHi, hreal] using hlow + have hhi : + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + simpa [bounds, attentionOutputBounds, absBound, lnBounds, lnLo, lnHi, vLo, vHi, headLo, headHi, + sumLo, sumHi, hreal] using hhigh + exact And.intro hlo hhi + +/-- Interval bounds for the attention residual path. -/ +def attentionResidualBounds {dModel dHead numHeads : Nat} + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + (fun i => lo i + attn.1 i, fun i => hi i + attn.2 i) + +/-- `attentionResidualBounds` soundness for attention residual outputs. 
-/ +theorem attentionResidualBounds_spec {seq dModel dHead numHeads : Nat} [NeZero seq] + (eps : Rat) (ln1Gamma ln1Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ∧ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let attn := attentionOutputBounds eps ln1Gamma ln1Beta heads attnBias lo hi + have hattn := + attentionOutputBounds_spec eps ln1Gamma ln1Beta heads attnBias scores lo hi x + hne heps hsqrt hlo hhi q i + have hlow := add_le_add (hlo q i) hattn.1 + have hhigh := add_le_add (hhi q i) hattn.2 + constructor + · simpa [bounds, attentionResidualBounds, attn] using hlow + · simpa [bounds, attentionResidualBounds, attn] using hhigh + +/-- Interval bounds for a full transformer layer (attention + MLP). 
-/ +def transformerLayerBounds {dModel dHead numHeads hidden : Nat} + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : + (Fin dModel → Rat) × (Fin dModel → Rat) := + let loCached := cacheBound lo + let hiCached := cacheBound hi + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias loCached hiCached + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let out := layerNormAbsMlpResidualBounds eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut mlpBOut + attnLo attnHi + let outLo := cacheBound out.1 + let outHi := cacheBound out.2 + (outLo, outHi) + +/-- `transformerLayerBounds` soundness for full transformer-layer outputs. -/ +theorem transformerLayerBounds_spec {seq dModel dHead numHeads hidden : Nat} [NeZero seq] + (eps : Rat) + (ln1Gamma ln1Beta ln2Gamma ln2Beta : Fin dModel → Rat) + (heads : Fin numHeads → Model.Gpt2HeadWeights dModel dHead) + (attnBias : Fin dModel → Rat) + (mlpWIn : Fin dModel → Fin hidden → Rat) (mlpBIn : Fin hidden → Rat) + (mlpWOut : Fin hidden → Fin dModel → Rat) (mlpBOut : Fin dModel → Rat) + (scores : Fin numHeads → Fin seq → Fin seq → Real) + (lo hi : Fin dModel → Rat) (x : Fin seq → Fin dModel → Real) + (hne : dModel ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ q i, (lo i : Real) ≤ x q i) (hhi : ∀ q i, x q i ≤ (hi i : Real)) : + let bounds := transformerLayerBounds eps ln1Gamma ln1Beta ln2Gamma ln2Beta heads attnBias + mlpWIn mlpBIn mlpWOut mlpBOut lo hi + ∀ q i, + (bounds.1 i : Real) ≤ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i + + mlpReal mlpWIn mlpBIn mlpWOut mlpBOut + (layerNormRealOfReal eps ln2Gamma ln2Beta + (fun j => + x q j + attentionOutputReal eps ln1Gamma ln1Beta 
heads attnBias scores x q j)) i ∧ + x q i + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q i + + mlpReal mlpWIn mlpBIn mlpWOut mlpBOut + (layerNormRealOfReal eps ln2Gamma ln2Beta + (fun j => + x q j + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q j)) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds q i + let loCached := cacheBound lo + let hiCached := cacheBound hi + have hloCached : ∀ q i, (loCached i : Real) ≤ x q i := by + intro q i + simpa [loCached, cacheBound_apply] using hlo q i + have hhiCached : ∀ q i, x q i ≤ (hiCached i : Real) := by + intro q i + simpa [hiCached, cacheBound_apply] using hhi q i + let attn := attentionResidualBounds eps ln1Gamma ln1Beta heads attnBias loCached hiCached + have hattn := attentionResidualBounds_spec eps ln1Gamma ln1Beta heads attnBias scores + loCached hiCached x hne heps hsqrt hloCached hhiCached q + let attnLo := cacheBound attn.1 + let attnHi := cacheBound attn.2 + let y := fun j => x q j + attentionOutputReal eps ln1Gamma ln1Beta heads attnBias scores x q j + have hattnLo : ∀ j, (attnLo j : Real) ≤ y j := by + intro j + simpa [attnLo, cacheBound_apply, y] using (hattn j).1 + have hattnHi : ∀ j, y j ≤ (attnHi j : Real) := by + intro j + simpa [attnHi, cacheBound_apply, y] using (hattn j).2 + have hmlp := layerNormAbsMlpResidualBounds_spec eps ln2Gamma ln2Beta mlpWIn mlpBIn mlpWOut + mlpBOut attnLo attnHi y hne heps hsqrt hattnLo hattnHi + have hmlp_i := hmlp i + simpa [bounds, transformerLayerBounds, attn, loCached, hiCached, attnLo, attnHi, y, + cacheBound_apply] using hmlp_i + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/Cache.lean b/Nfp/Bounds/Cache.lean new file mode 100644 index 0000000..a4f8101 --- /dev/null +++ b/Nfp/Bounds/Cache.lean @@ -0,0 +1,222 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic + +/-! +Caching helpers for interval bounds. 
+-/ + +public section + +namespace Nfp + + +namespace Bounds + +/-- Cache a bound function in an array-backed lookup to avoid repeated evaluation. -/ +def cacheBound {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let data : Thunk (Array Rat) := Thunk.mk (fun _ => Array.ofFn f) + fun i => (Thunk.get data)[i.1]'(by + simp [Thunk.get, data, i.isLt]) + +/-- `cacheBound` preserves pointwise values. -/ +theorem cacheBound_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBound f i = f i := by + simp [cacheBound, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function using per-index thunks for lazy evaluation. -/ +def cacheBoundThunk {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let data : Array (Thunk Rat) := Array.ofFn (fun i => Thunk.mk (fun _ => f i)) + fun i => + let t := data[i.1]'(by + simp [data, i.isLt]) + Thunk.get t + +/-- `cacheBoundThunk` preserves pointwise values. -/ +theorem cacheBoundThunk_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBoundThunk f i = f i := by + simp [cacheBoundThunk, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function using tasks for parallel evaluation. -/ +def cacheBoundTask {n : Nat} (f : Fin n → Rat) : Fin n → Rat := + let tasks : Array (Task Rat) := + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f i)) + fun i => + let t := tasks[i.1]'(by + simp [tasks, i.isLt]) + t.get + +/-- `cacheBoundTask` preserves pointwise values. -/ +theorem cacheBoundTask_apply {n : Nat} (f : Fin n → Rat) (i : Fin n) : + cacheBoundTask f i = f i := by + classical + simp [cacheBoundTask, Task.spawn, Array.getElem_ofFn] + +/-- Cache a bound function on two indices. 
-/ +def cacheBound2 {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let data : Thunk (Array (Thunk (Array Rat))) := Thunk.mk (fun _ => + Array.ofFn (fun q => Thunk.mk (fun _ => Array.ofFn (f q)))) + fun q i => + let rows := Thunk.get data + let rowThunk := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) + let row := Thunk.get rowThunk + row[i.1]'(by + simp [row, rowThunk, rows, Thunk.get, data, i.isLt]) + +/-- `cacheBound2` preserves pointwise values. -/ +theorem cacheBound2_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2 f q i = f q i := by + simp [cacheBound2, Thunk.get, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using row tasks for parallel evaluation. -/ +def cacheBound2Task {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let rowTasks : Array (Task { row : Array Rat // row.size = n }) := + Array.ofFn (fun q : Fin m => + Task.spawn (fun _ => ⟨Array.ofFn (f q), by simp⟩)) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.1[i.1]'(by + simp [row.2]) + +/-- `cacheBound2Task` preserves pointwise values. -/ +theorem cacheBound2Task_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2Task f q i = f q i := by + classical + simp [cacheBound2Task, Task.spawn, Array.getElem_ofFn] + +/-- Cache a bound function on two indices using per-element tasks for parallel evaluation. -/ +def cacheBound2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat) : Fin m → Fin n → Rat := + let rowTasks : Array (Array (Task Rat)) := + Array.ofFn (fun q : Fin m => + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f q i))) + fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + simp [row, rowTasks, i.isLt]) + t.get + +/-- `cacheBound2TaskElem` preserves pointwise values. 
-/ +theorem cacheBound2TaskElem_apply {m n : Nat} (f : Fin m → Fin n → Rat) (q : Fin m) (i : Fin n) : + cacheBound2TaskElem f q i = f q i := by + classical + simp [cacheBound2TaskElem, Task.spawn, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices. -/ +def cacheBoundPair2 {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let data : Thunk (Array (Array Rat × Array Rat)) := Thunk.mk (fun _ => + Array.ofFn (fun q => + let row := f q + (Array.ofFn row.1, Array.ofFn row.2))) + let lo : Fin m → Fin n → Rat := fun q i => + let rows := Thunk.get data + let row := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) + let loRow := row.1 + loRow[i.1]'(by + simp [loRow, row, rows, Thunk.get, data, i.isLt]) + let hi : Fin m → Fin n → Rat := fun q i => + let rows := Thunk.get data + let row := rows[q.1]'(by + simp [rows, Thunk.get, data, q.isLt]) + let hiRow := row.2 + hiRow[i.1]'(by + simp [hiRow, row, rows, Thunk.get, data, i.isLt]) + (lo, hi) + +/-- `cacheBoundPair2` preserves pointwise values (first component). -/ +theorem cacheBoundPair2_apply_left {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).1 q i = (f q).1 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + +/-- `cacheBoundPair2` preserves pointwise values (second component). -/ +theorem cacheBoundPair2_apply_right {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2 f).2 q i = (f q).2 i := by + simp [cacheBoundPair2, Thunk.get, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices using row tasks. 
-/ +def cacheBoundPair2Task {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let rowTasks : + Array (Task ({ rowLo : Array Rat // rowLo.size = n } × + { rowHi : Array Rat // rowHi.size = n })) := + Array.ofFn (fun q : Fin m => + Task.spawn (fun _ => + let row := f q + (⟨Array.ofFn row.1, by simp⟩, ⟨Array.ofFn row.2, by simp⟩))) + let lo : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.1.1[i.1]'(by + simp [row.1.2, i.isLt]) + let hi : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])).get + row.2.1[i.1]'(by + simp [row.2.2, i.isLt]) + (lo, hi) + +/-- `cacheBoundPair2Task` preserves pointwise values (first component). -/ +theorem cacheBoundPair2Task_apply_left {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2Task f).1 q i = (f q).1 i := by + classical + simp [cacheBoundPair2Task, Task.spawn, Array.getElem_ofFn] + +/-- `cacheBoundPair2Task` preserves pointwise values (second component). -/ +theorem cacheBoundPair2Task_apply_right {m n : Nat} + (f : Fin m → (Fin n → Rat) × (Fin n → Rat)) (q : Fin m) (i : Fin n) : + (cacheBoundPair2Task f).2 q i = (f q).2 i := by + classical + simp [cacheBoundPair2Task, Task.spawn, Array.getElem_ofFn] + +/-- Cache a pair of bound functions on two indices using per-element tasks. 
-/ +def cacheBoundPair2TaskElem {m n : Nat} (f : Fin m → Fin n → Rat × Rat) : + (Fin m → Fin n → Rat) × (Fin m → Fin n → Rat) := + let rowTasks : Array (Array (Task (Rat × Rat))) := + Array.ofFn (fun q : Fin m => + Array.ofFn (fun i : Fin n => + Task.spawn (fun _ => f q i))) + let lo : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + simp [row, rowTasks, i.isLt]) + (t.get).1 + let hi : Fin m → Fin n → Rat := fun q i => + let row := (rowTasks[q.1]'(by + simp [rowTasks, q.isLt])) + let t := row[i.1]'(by + simp [row, rowTasks, i.isLt]) + (t.get).2 + (lo, hi) + +/-- `cacheBoundPair2TaskElem` preserves pointwise values (first component). -/ +theorem cacheBoundPair2TaskElem_apply_left {m n : Nat} + (f : Fin m → Fin n → Rat × Rat) (q : Fin m) (i : Fin n) : + (cacheBoundPair2TaskElem f).1 q i = (f q i).1 := by + classical + simp [cacheBoundPair2TaskElem, Task.spawn, Array.getElem_ofFn] + +/-- `cacheBoundPair2TaskElem` preserves pointwise values (second component). -/ +theorem cacheBoundPair2TaskElem_apply_right {m n : Nat} + (f : Fin m → Fin n → Rat × Rat) (q : Fin m) (i : Fin n) : + (cacheBoundPair2TaskElem f).2 q i = (f q i).2 := by + classical + simp [cacheBoundPair2TaskElem, Task.spawn, Array.getElem_ofFn] + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/Gelu.lean b/Nfp/Bounds/Gelu.lean new file mode 100644 index 0000000..77edfce --- /dev/null +++ b/Nfp/Bounds/Gelu.lean @@ -0,0 +1,165 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.Order.Ring.Abs +public import Mathlib.Analysis.Complex.Trigonometric +public import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic +public import Nfp.Core.Basic + +/-! +Tanh-based GELU bounds for GPT-2 style MLPs. +These bounds are used to propagate interval constraints through nonlinear gates. 
+-/ + +public section + +namespace Nfp + + +namespace Bounds + +/-- Tanh-based GELU activation used by GPT-2 (approximate form). -/ +noncomputable def geluTanh (x : Real) : Real := + let k : Real := Real.sqrt (2 / Real.pi) + let c : Real := (44715 : Real) / 1000000 + x * ((1 + Real.tanh (k * (x + c * x ^ 3))) / 2) + +/-- The hyperbolic tangent is bounded in absolute value by `1`. -/ +theorem abs_tanh_le_one (x : Real) : |Real.tanh x| ≤ 1 := by + have hpos_exp : 0 < Real.exp x := Real.exp_pos x + have hpos_exp_neg : 0 < Real.exp (-x) := Real.exp_pos (-x) + have hsum_pos : 0 < Real.exp x + Real.exp (-x) := + add_pos hpos_exp hpos_exp_neg + have hsum_nonneg : 0 ≤ Real.exp x + Real.exp (-x) := le_of_lt hsum_pos + have habs : |Real.exp x - Real.exp (-x)| ≤ Real.exp x + Real.exp (-x) := by + have h := abs_add_le (Real.exp x) (-Real.exp (-x)) + simpa [sub_eq_add_neg, abs_neg, abs_of_nonneg (le_of_lt hpos_exp), + abs_of_nonneg (le_of_lt hpos_exp_neg)] using h + calc + |Real.tanh x| = + |(Real.exp x - Real.exp (-x)) / (Real.exp x + Real.exp (-x))| := by + simp [Real.tanh_eq] + _ = |Real.exp x - Real.exp (-x)| / (Real.exp x + Real.exp (-x)) := by + simp [abs_div, abs_of_nonneg hsum_nonneg] + _ ≤ (Real.exp x + Real.exp (-x)) / (Real.exp x + Real.exp (-x)) := by + exact div_le_div_of_nonneg_right habs hsum_nonneg + _ = 1 := by + have hne : Real.exp x + Real.exp (-x) ≠ 0 := ne_of_gt hsum_pos + simp [hne] + +/-- The tanh coefficient in `geluTanh` lies in `[0, 1]`. -/ +theorem geluTanh_coeff_bounds (x : Real) : + 0 ≤ + (1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2 ∧ + (1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2 ≤ 1 := by + have habs := + abs_tanh_le_one + (Real.sqrt (2 / Real.pi) * (x + (44715 : Real) / 1000000 * x ^ 3)) + have hbounds := abs_le.mp habs + constructor <;> nlinarith + +/-- `geluTanh` outputs stay between `min x 0` and `max x 0`. 
-/ +theorem geluTanh_bounds (x : Real) : + min x 0 ≤ geluTanh x ∧ geluTanh x ≤ max x 0 := by + by_cases hx : 0 ≤ x + · have hcoeff := geluTanh_coeff_bounds x + have hnonneg : + 0 ≤ x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) := by + exact mul_nonneg hx hcoeff.1 + have hle : + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) ≤ x := by + have h := mul_le_mul_of_nonneg_left hcoeff.2 hx + simpa [mul_one] using h + simpa [geluTanh, min_eq_right hx, max_eq_left hx] using + And.intro hnonneg hle + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hcoeff := geluTanh_coeff_bounds x + have hle0 : + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) ≤ 0 := by + exact mul_nonpos_of_nonpos_of_nonneg hx' hcoeff.1 + have hxle : + x ≤ + x * + ((1 + + Real.tanh + (Real.sqrt (2 / Real.pi) * + (x + (44715 : Real) / 1000000 * x ^ 3))) / + 2) := by + have h := mul_le_mul_of_nonpos_left hcoeff.2 hx' + simpa [mul_one] using h + simpa [geluTanh, min_eq_left hx', max_eq_right hx'] using And.intro hxle hle0 + +/-- Interval bounds for GELU given input bounds. -/ +def geluInterval (lo hi : Rat) : Rat × Rat := + (if lo ≤ 0 then lo else 0, if 0 ≤ hi then hi else 0) + +/-- `geluInterval` soundly bounds `geluTanh` on a real interval. 
-/ +theorem geluInterval_bounds {lo hi : Rat} {x : Real} + (hlo : (lo : Real) ≤ x) (hhi : x ≤ (hi : Real)) : + (geluInterval lo hi).1 ≤ (geluTanh x : Real) ∧ + (geluTanh x : Real) ≤ (geluInterval lo hi).2 := by + have hgelu := geluTanh_bounds x + have hupper : geluTanh x ≤ (geluInterval lo hi).2 := by + have hmax : geluTanh x ≤ max (hi : Real) 0 := by + have hmax' : max x 0 ≤ max (hi : Real) 0 := max_le_max hhi le_rfl + exact le_trans hgelu.2 hmax' + by_cases hhi0 : 0 ≤ hi + · have hhi0r : 0 ≤ (hi : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hhi0 + simpa [geluInterval, hhi0, max_eq_left hhi0r] using hmax + · have hhi0r : (hi : Real) ≤ 0 := by + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := hi)).2 (le_of_not_ge hhi0) + have hx0 : x ≤ 0 := le_trans hhi hhi0r + have hmax' : max x 0 = 0 := max_eq_right hx0 + have hhi'' : geluTanh x ≤ (0 : Real) := by + simpa [hmax'] using hgelu.2 + simpa [geluInterval, hhi0, ratToReal_zero] using hhi'' + by_cases hlo0 : lo ≤ 0 + · have hlo0r : (lo : Real) ≤ 0 := by + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := lo)).2 hlo0 + have hmin : min (lo : Real) 0 ≤ min x 0 := min_le_min hlo le_rfl + have hlo' : (lo : Real) ≤ geluTanh x := by + have hmin' : (lo : Real) ≤ min x 0 := by + simpa [min_eq_left hlo0r] using hmin + exact le_trans hmin' hgelu.1 + constructor + · simpa [geluInterval, hlo0] using hlo' + · exact hupper + · have hlo0r : 0 ≤ (lo : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg (le_of_not_ge hlo0) + have hx0 : 0 ≤ x := le_trans hlo0r hlo + have hmin' : min x 0 = 0 := min_eq_right hx0 + have hlo' : (0 : Real) ≤ geluTanh x := by + simpa [hmin'] using hgelu.1 + constructor + · simpa [geluInterval, hlo0, ratToReal_zero] using hlo' + · exact hupper + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/LayerNorm.lean b/Nfp/Bounds/LayerNorm.lean new file mode 100644 index 0000000..28d0317 --- /dev/null +++ b/Nfp/Bounds/LayerNorm.lean @@ -0,0 +1,12 @@ +-- 
SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Bounds.LayerNorm.Basic +public import Nfp.Bounds.LayerNorm.InvStd +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds + +/-! +LayerNorm bounds and supporting lemmas. +-/ diff --git a/Nfp/Bounds/LayerNorm/Basic.lean b/Nfp/Bounds/LayerNorm/Basic.lean new file mode 100644 index 0000000..24f1a98 --- /dev/null +++ b/Nfp/Bounds/LayerNorm/Basic.lean @@ -0,0 +1,1028 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Real.Sqrt +public import Mathlib.Data.Rat.BigOperators +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds +public import Nfp.Linear.FinFold + +/-! +LayerNorm interval bounds for rational inputs. + +This module computes rational interval bounds for LayerNorm outputs and proves +those bounds sound for real-valued LayerNorm semantics. +-/ + +public section + +namespace Nfp + + +namespace Bounds + +open scoped BigOperators + +/-- Bounds for multiplying a scalar by a bounded value. -/ +def scaleInterval (x lo hi : Rat) : Rat × Rat := + if 0 ≤ x then + (x * lo, x * hi) + else + (x * hi, x * lo) + +/-- `scaleInterval` bounds a product. 
-/ +theorem scaleInterval_bounds {x lo hi y : Rat} + (hlo : lo ≤ y) (hhi : y ≤ hi) : + let bounds := scaleInterval x lo hi + bounds.1 ≤ x * y ∧ x * y ≤ bounds.2 := by + by_cases hx : 0 ≤ x + · have hbounds : x * lo ≤ x * y ∧ x * y ≤ x * hi := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx, mul_le_mul_of_nonneg_left hhi hx⟩ + simpa [scaleInterval, hx] using hbounds + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hbounds : x * hi ≤ x * y ∧ x * y ≤ x * lo := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx', mul_le_mul_of_nonpos_left hlo hx'⟩ + simpa [scaleInterval, hx] using hbounds + +/-- `scaleInterval` bounds interpreted in the reals. -/ +theorem scaleInterval_bounds_real {x lo hi : Rat} {y : Real} + (hlo : (lo : Real) ≤ y) (hhi : y ≤ (hi : Real)) : + let bounds := scaleInterval x lo hi + (bounds.1 : Real) ≤ (x : Real) * y ∧ (x : Real) * y ≤ (bounds.2 : Real) := by + by_cases hx : 0 ≤ x + · have hx' : 0 ≤ (x : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hx + have hbounds : (x : Real) * (lo : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (hi : Real) := by + exact ⟨mul_le_mul_of_nonneg_left hlo hx', mul_le_mul_of_nonneg_left hhi hx'⟩ + simpa [scaleInterval, hx] using hbounds + · have hx' : x ≤ 0 := le_of_lt (lt_of_not_ge hx) + have hx'' : (x : Real) ≤ 0 := by + simpa [ratToReal_def] using (ratToReal_nonpos_iff (x := x)).2 hx' + have hbounds : (x : Real) * (hi : Real) ≤ (x : Real) * y ∧ + (x : Real) * y ≤ (x : Real) * (lo : Real) := by + exact ⟨mul_le_mul_of_nonpos_left hhi hx'', mul_le_mul_of_nonpos_left hlo hx''⟩ + simpa [scaleInterval, hx] using hbounds + +/-- Real-valued LayerNorm output for a vector. 
-/ +noncomputable def layerNormReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) + +/-- Unfolding lemma for `layerNormReal`. -/ +theorem layerNormReal_def {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + layerNormReal eps gamma beta x = + if n = 0 then + fun _ => 0 + else + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * ((x i : Real) - μ) * invStd + (beta i : Real) := by + simp [layerNormReal] + +/-- Real-valued LayerNorm output for a real vector. -/ +noncomputable def layerNormRealOfReal {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Real) : Fin n → Real := + if n = 0 then + fun _ => 0 + else + let μ : Real := meanReal x + let varEps : Real := varianceReal x + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + fun i => (gamma i : Real) * (x i - μ) * invStd + (beta i : Real) + +/-- Interval bounds for LayerNorm outputs. 
-/ +def layerNormBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μ : Rat := mean x + let centered : Fin n → Rat := fun i => x i - μ + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let coeff : Fin n → Rat := fun i => gamma i * centered i + let lo : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdLower + else + beta i + coeff i * invStdUpper + let hi : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdUpper + else + beta i + coeff i * invStdLower + (lo, hi) + +/-- Interval bounds for LayerNorm outputs with a custom sqrt scale. -/ +def layerNormBoundsWithScale {n : Nat} (scale : Nat) + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μ : Rat := mean x + let centered : Fin n → Rat := fun i => x i - μ + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := + max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEps) + let sqrtUpperBound : Rat := sqrtUpperWithScale scale varEps + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let coeff : Fin n → Rat := fun i => gamma i * centered i + let lo : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdLower + else + beta i + coeff i * invStdUpper + let hi : Fin n → Rat := fun i => + if 0 ≤ coeff i then + beta i + coeff i * invStdUpper + else + beta i + coeff i * invStdLower + (lo, hi) + +/-- `layerNormBounds` soundness for real LayerNorm outputs. 
-/ +theorem layerNormBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := layerNormBounds eps gamma beta x + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μRat : Rat := mean x + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) + let sqrtUpperBound : Rat := sqrtUpper varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let centered : Rat := x i - μRat + let coeff : Rat := gamma i * centered + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hmu : (μRat : Real) = μ := by + simp [μRat, μ, mean_def, hne, ratRoundDown_def] + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown_def] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q 
:= eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLower varEpsRat : Real) ≤ Real.sqrt (varEpsRat : Real) := by + have h := sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg + simpa using h + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLower eps := hsqrt + have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper 
+ simpa [invStd] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real + have hlayer : + layerNormReal eps gamma beta x i = + (beta i : Real) + (coeff : Real) * invStd := by + simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] + by_cases hcoeff : 0 ≤ coeff + · have hcoeff_real : 0 ≤ (coeff : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hcoeff + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) := by + have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + · have hcoeff_lt : coeff < 0 := 
lt_of_not_ge hcoeff + have hcoeff_real : (coeff : Real) ≤ 0 := by + exact_mod_cast (le_of_lt hcoeff_lt) + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by + have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBounds, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + +/-- `layerNormBoundsWithScale` soundness for real LayerNorm outputs. 
-/ +theorem layerNormBoundsWithScale_spec {n : Nat} {scale : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLowerWithScale scale eps) + (hscale : 0 < scale) : + let bounds := layerNormBoundsWithScale scale eps gamma beta x + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μRat : Rat := mean x + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := + max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEpsRat) + let sqrtUpperBound : Rat := sqrtUpperWithScale scale varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let centered : Rat := x i - μRat + let coeff : Rat := gamma i * centered + let μ : Real := meanRat x + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hmu : (μRat : Real) = μ := by + simp [μRat, μ, mean_def, hne, ratRoundDown_def] + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown_def] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLowerWithScale scale eps : Real) 
≤ Real.sqrt varEps := by + have hsqrt_eps' : + (sqrtLowerWithScale scale eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := + sqrtLowerWithScale_le_real_sqrt (q := eps) (scale := scale) + (by exact le_of_lt heps) hscale + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hsqrt_var : + (sqrtLowerWithScale scale varEpsRat : Real) ≤ Real.sqrt varEps := by + have hsqrt_var' : + (sqrtLowerWithScale scale varEpsRat : Real) ≤ + Real.sqrt (varEpsRat : Real) := by + have h := + sqrtLowerWithScale_le_real_sqrt (q := varEpsRat) (scale := scale) + hvarEps_nonneg hscale + simpa using h + have hle : (varEpsRat : Real) ≤ varEps := by + simp [hvarEps] + exact le_trans hsqrt_var' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLowerWithScale scale eps : Real) + (sqrtLowerWithScale scale varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := + real_sqrt_le_sqrtUpperWithScale (q := varEpsRat) (scale := scale) + hvarEps_nonneg hscale + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLowerWithScale scale eps := hsqrt + have hpos' : 0 < max (sqrtLowerWithScale scale eps) (sqrtLowerWithScale scale varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + have hpos_real : 0 < (sqrtUpperWithScale scale varEpsRat : Real) := by + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := 
add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + exact lt_of_lt_of_le (Real.sqrt_pos.2 hvarEps_pos) hsqrt_upper + exact (Rat.cast_pos (K := Real) (q := sqrtUpperWithScale scale varEpsRat)).1 hpos_real + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd] using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real + have hlayer : + layerNormReal eps gamma beta x i = + (beta i : Real) + (coeff : Real) * invStd := by + simp [layerNormReal, hne, coeff, centered, μ, hmu, invStd, varEps, add_comm, mul_assoc] + by_cases hcoeff : 0 ≤ coeff + · have hcoeff_real : 0 ≤ (coeff : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hcoeff + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdLower : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonneg_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : 
Real) + (coeff : Real) * (invStdUpper : Real) := by + have hmul := mul_le_mul_of_nonneg_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + · have hcoeff_lt : coeff < 0 := lt_of_not_ge hcoeff + have hcoeff_real : (coeff : Real) ≤ 0 := by + exact_mod_cast (le_of_lt hcoeff_lt) + have hlow_raw : + (beta i : Real) + (coeff : Real) * (invStdUpper : Real) ≤ + (beta i : Real) + (coeff : Real) * invStd := by + have hmul := mul_le_mul_of_nonpos_left hinv_upper hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hhigh_raw : + (beta i : Real) + (coeff : Real) * invStd ≤ + (beta i : Real) + (coeff : Real) * (invStdLower : Real) := by + have hmul := mul_le_mul_of_nonpos_left hinv_lower hcoeff_real + simpa only [add_comm] using add_le_add_left hmul (beta i : Real) + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hlow_raw + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormBoundsWithScale, hne, μRat, centered, varRat, varEpsRat, + sqrtLowerBound, sqrtUpperBound, invStdLower, invStdUpper, coeff, hcoeff, hlayer] + using hhigh_raw + exact And.intro hlo hhi + +/-! +Local bounds for monotone multiplication in real-valued bounds. 
+-/ + +/-- Lower sqrt bound against the variance-plus-eps term. -/ +theorem sqrtLower_le_real_sqrt_varEps {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + intro varEps + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + +/-- Inverse-std upper bound from the lower sqrt bound. -/ +theorem invStd_le_invStdBound {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + invStd ≤ (invStdBound : Real) := by + intro varEps invStd invStdBound + have hsqrt_lower : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + simpa [varEps] using + (sqrtLower_le_real_sqrt_varEps (eps := eps) (x := x) hne heps) + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + exact le_trans hinv_sqrt hinv_bound 
+ +/-- Inverse-std is nonnegative. -/ +theorem invStd_nonneg {n : Nat} (eps : Rat) (x : Fin n → Rat) : + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let invStd : Real := (Real.sqrt varEps)⁻¹ + 0 ≤ invStd := by + intro varEps invStd + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + +/-- Interval bounds for LayerNorm outputs from per-coordinate intervals. -/ +def layerNormIntervalBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + if n = 0 then + (fun _ => 0, fun _ => 0) + else + let μLo := mean lo + let μHi := meanUpper hi + let centeredBound : Fin n → Rat := fun i => + max |lo i - μHi| |hi i - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound i * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs. 
-/ +theorem layerNormIntervalBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, lo i ≤ x i) (hhi : ∀ i, x i ≤ hi i) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg h0 + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + have hmean_rat : (meanRat lo : Real) ≤ (meanRat x : Real) := + meanRat_le_meanRat_real lo x hne hlo + have hdown : (μLo : Real) ≤ (meanRat lo : Real) := by + simpa [μLo, mean_def lo hne] using ratRoundDown_le_real (meanRat lo) + exact le_trans hdown hmean_rat + have hmean_hi_real : μ ≤ (μHi : Real) := by + have hmean_rat : (meanRat x : Real) ≤ (meanRat hi : Real) := + meanRat_le_meanRat_real x hi hne hhi + have hup : (meanRat hi : Real) ≤ (μHi : Real) := by + simpa [μHi, meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) + exact le_trans hmean_rat hup + have hlo' : (lo i : Real) - (μHi : Real) ≤ (x i : Real) - μ := by + have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ (x i : Real) - μ := by + exact sub_le_sub_right + (by + simpa 
[ratToReal_def] using ratToReal_le_of_le (hlo i)) + μ + exact le_trans h1 h2 + have hhi' : (x i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : (x i : Real) - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right + (by + simpa [ratToReal_def] using ratToReal_le_of_le (hhi i)) + μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound + have hinv : invStd ≤ (invStdBound : Real) := by + simpa [varEps, invStd, invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) + have hinv_nonneg : 0 ≤ invStd := by + simp [varEps, invStd] + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, 
abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi + +/-- Interval bounds for LayerNorm outputs from an absolute input bound. 
-/ +def layerNormAbsBounds {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) : + (Fin n → Rat) × (Fin n → Rat) := + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let radius : Fin n → Rat := fun i => |gamma i| * centeredBound * invStdBound + (fun i => beta i - radius i, fun i => beta i + radius i) + +/-- Bound a centered value by double the absolute bound. -/ +private theorem abs_sub_le_double_bound {a b bound : Real} + (ha : |a| ≤ bound) (hb : |b| ≤ bound) : + |a - b| ≤ bound + bound := by + have h1 : |a - b| ≤ |a| + |b| := by + simpa [sub_eq_add_neg, abs_neg] using abs_add_le a (-b) + exact le_trans h1 (add_le_add ha hb) + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs under absolute input bounds. -/ +theorem layerNormAbsBounds_spec {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ absBound) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i ∧ + layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs_real : |(meanRat x : Real)| ≤ (absBound : Real) := by + have h := + meanReal_abs_le_bound (x := fun j => (x j : Real)) (bound := absBound) hne + (by + intro j + simpa [ratToReal_def] using ratToReal_abs_le_of_le (habs j)) + simpa [meanReal_eq_meanRat] using h + have hbound_nonneg : 0 ≤ absBound := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := (varianceRat x : Real) + (eps : Real) + let μ : Real := meanRat x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_abs : |(x i : Real) - μ| ≤ (centeredBound : Real) := by + 
have hx : |(x i : Real)| ≤ (absBound : Real) := by + simpa [ratToReal_def] using ratToReal_abs_le_of_le (habs i) + have hmu : |μ| ≤ (absBound : Real) := by + simpa [μ] using hmean_abs_real + have h12 : |(x i : Real) - μ| ≤ (absBound : Real) + (absBound : Real) := + abs_sub_le_double_bound hx hmu + simpa [centeredBound, two_mul] using h12 + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hbound_nonneg + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hinv : invStd ≤ (invStdBound : Real) := by + simpa [varEps, invStd, invStdBound] using + (invStd_le_invStdBound (eps := eps) (x := x) hne heps hsqrt) + have hinv_nonneg : 0 ≤ invStd := by + simp [varEps, invStd] + have hmul1 : |(x i : Real) - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |(x i : Real) - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |(x i : Real) - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|(x i : Real) - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * ((x i : Real) - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |(x i : Real) - μ| * invStd := by + have hinv_abs : |invStd| = invStd 
:= abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + +/-- `layerNormAbsBounds` soundness for real LayerNorm outputs on real inputs. 
-/ +theorem layerNormAbsBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (absBound : Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (habs : ∀ i, |x i| ≤ (absBound : Real)) : + let bounds := layerNormAbsBounds eps gamma beta absBound + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_abs : |meanReal x| ≤ (absBound : Real) := + meanReal_abs_le_bound x absBound hne habs + have hbound_nonneg_real : 0 ≤ (absBound : Real) := by + have hposn : 0 < n := Nat.pos_of_ne_zero hne + let i0 : Fin n := ⟨0, hposn⟩ + have h0 : 0 ≤ |x i0| := abs_nonneg _ + exact le_trans h0 (habs i0) + let centeredBound : Rat := 2 * absBound + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_abs : |x i - μ| ≤ (centeredBound : Real) := by + have hx : |x i| ≤ (absBound : Real) := habs i + have hmu : |μ| ≤ (absBound : Real) := by + simpa using hmean_abs + have h12 : |x i - μ| ≤ (absBound : Real) + (absBound : Real) := + abs_sub_le_double_bound hx hmu + simpa [centeredBound, two_mul] using h12 + have hcentered_nonneg : 0 ≤ (centeredBound : Real) := by + have hsum := add_nonneg hbound_nonneg_real hbound_nonneg_real + simpa [centeredBound, two_mul] using hsum + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ 
Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound + have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound : Real) * (invStdBound : Real) := by + have hleft : + |x i - μ| * invStd ≤ (centeredBound : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound : Real) * invStd ≤ (centeredBound : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, 
mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ (beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormAbsBounds, radius, centeredBound, invStdBound, hreal] using hhigh + exact And.intro hlo hhi + +/-- `layerNormIntervalBounds` soundness for real LayerNorm outputs on real inputs. 
-/ +theorem layerNormIntervalBounds_spec_real {n : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormIntervalBounds eps gamma beta lo hi + ∀ i, + (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i ∧ + layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + have hmean_lo : (mean lo : Real) ≤ meanReal x := by + have h := + meanReal_le_meanReal (x := fun j => (lo j : Real)) (y := x) hne + (fun j => hlo j) + have hrat : (meanRat lo : Real) ≤ meanReal x := by + simpa [meanReal_eq_meanRat] using h + have hdown : (mean lo : Real) ≤ (meanRat lo : Real) := by + simpa [mean_def lo hne] using ratRoundDown_le_real (meanRat lo) + exact le_trans hdown hrat + have hmean_hi : meanReal x ≤ (meanUpper hi : Real) := by + have h := + meanReal_le_meanReal (x := x) (y := fun j => (hi j : Real)) hne + (fun j => hhi j) + have hrat : meanReal x ≤ (meanRat hi : Real) := by + simpa [meanReal_eq_meanRat] using h + have hup : (meanRat hi : Real) ≤ (meanUpper hi : Real) := by + simpa [meanUpper_def hi hne] using real_le_ratRoundUp (meanRat hi) + exact le_trans hrat hup + let μLo : Rat := mean lo + let μHi : Rat := meanUpper hi + let centeredBound : Fin n → Rat := fun j => max |lo j - μHi| |hi j - μLo| + let invStdBound : Rat := ratDivUp 1 (sqrtLower eps) + let varEps : Real := varianceReal x + (eps : Real) + let μ : Real := meanReal x + let invStd : Real := (Real.sqrt varEps)⁻¹ + have hcentered_nonneg : 0 ≤ (centeredBound i : Real) := by + have h0 : 0 ≤ centeredBound i := by + dsimp [centeredBound] + exact le_trans (abs_nonneg _) (le_max_left _ _) + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg h0 + have hcentered_abs : |x i - μ| ≤ (centeredBound i : Real) := by + have hmean_lo_real : (μLo : Real) ≤ μ := by + simpa [μLo, μ] using hmean_lo + have 
hmean_hi_real : μ ≤ (μHi : Real) := by + simpa [μHi, μ] using hmean_hi + have hlo' : (lo i : Real) - (μHi : Real) ≤ x i - μ := by + have h1 : (lo i : Real) - (μHi : Real) ≤ (lo i : Real) - μ := by + exact sub_le_sub_left hmean_hi_real (lo i : Real) + have h2 : (lo i : Real) - μ ≤ x i - μ := by + exact sub_le_sub_right (hlo i) μ + exact le_trans h1 h2 + have hhi' : x i - μ ≤ (hi i : Real) - (μLo : Real) := by + have h1 : x i - μ ≤ (hi i : Real) - μ := by + exact sub_le_sub_right (hhi i) μ + have h2 : (hi i : Real) - μ ≤ (hi i : Real) - (μLo : Real) := by + exact sub_le_sub_left hmean_lo_real (hi i : Real) + exact le_trans h1 h2 + have hbound := abs_le_max_of_bounds hlo' hhi' + simpa [centeredBound, μLo, μHi, ratToReal_abs, ratToReal_sub, + ratToReal_max] using hbound + have hvar_nonneg : 0 ≤ varianceReal x := varianceReal_nonneg x hne + have hsqrt_lower : + (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : + (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ varianceReal x + (eps : Real) := by + exact le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + have hsqrt_eps' : Real.sqrt (eps : Real) ≤ Real.sqrt varEps := by + exact Real.sqrt_le_sqrt hle + exact le_trans hsqrt_eps hsqrt_eps' + have hsqrt_lower_pos : 0 < (sqrtLower eps : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLower eps)).2 hsqrt + have hinv_sqrt : invStd ≤ (sqrtLower eps : Real)⁻¹ := by + have h := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd] using h + have hinv_bound : (sqrtLower eps : Real)⁻¹ ≤ (invStdBound : Real) := by + have hy : sqrtLower eps ≠ 0 := ne_of_gt hsqrt + have hdiv := by + simpa [ratToReal_def] using ratDivUp_ge_real (x := 1) (y := sqrtLower eps) hy + simpa [invStdBound, one_div] using hdiv + have hinv : invStd ≤ (invStdBound : Real) := by + exact le_trans hinv_sqrt hinv_bound + 
have hinv_nonneg : 0 ≤ invStd := by + have hsqrt_nonneg : 0 ≤ Real.sqrt varEps := by + exact Real.sqrt_nonneg _ + exact inv_nonneg.2 hsqrt_nonneg + have hmul1 : |x i - μ| * invStd ≤ + (centeredBound i : Real) * (invStdBound : Real) := by + have hleft : |x i - μ| * invStd ≤ (centeredBound i : Real) * invStd := by + exact mul_le_mul_of_nonneg_right hcentered_abs hinv_nonneg + have hright : + (centeredBound i : Real) * invStd ≤ (centeredBound i : Real) * (invStdBound : Real) := by + exact mul_le_mul_of_nonneg_left hinv hcentered_nonneg + exact le_trans hleft hright + have hmul2 : |(gamma i : Real)| * |x i - μ| * invStd ≤ + |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have hgamma_nonneg : 0 ≤ |(gamma i : Real)| := abs_nonneg _ + have hmul2' : |(gamma i : Real)| * (|x i - μ| * invStd) ≤ + |(gamma i : Real)| * ((centeredBound i : Real) * (invStdBound : Real)) := by + exact mul_le_mul_of_nonneg_left hmul1 hgamma_nonneg + simpa only [mul_assoc] using hmul2' + let t : Real := (gamma i : Real) * (x i - μ) * invStd + have ht_abs : + |t| ≤ |(gamma i : Real)| * (centeredBound i : Real) * (invStdBound : Real) := by + have ht : |t| = |(gamma i : Real)| * |x i - μ| * invStd := by + have hinv_abs : |invStd| = invStd := abs_of_nonneg hinv_nonneg + simp [t, abs_mul, hinv_abs, mul_assoc] + simpa [ht] using hmul2 + let radius : Fin n → Rat := fun j => |gamma j| * centeredBound j * invStdBound + have ht_abs' : |t| ≤ (radius i : Real) := by + simpa [radius, centeredBound, invStdBound] using ht_abs + have hbounds : -(radius i : Real) ≤ t ∧ t ≤ (radius i : Real) := by + exact abs_le.mp ht_abs' + have hlow : + (beta i : Real) - (radius i : Real) ≤ t + (beta i : Real) := by + have h : (beta i : Real) + -(radius i : Real) ≤ (beta i : Real) + t := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.1 (beta i : Real) + simpa only [sub_eq_add_neg, add_comm, add_left_comm, add_assoc] using h + have hhigh : + t + (beta i : Real) ≤ 
(beta i : Real) + (radius i : Real) := by + have h : (beta i : Real) + t ≤ (beta i : Real) + (radius i : Real) := by + simpa only [add_comm, add_left_comm, add_assoc] using + add_le_add_left hbounds.2 (beta i : Real) + simpa only [add_comm, add_left_comm, add_assoc] using h + have hreal : + layerNormRealOfReal eps gamma beta x i = t + (beta i : Real) := by + simp [layerNormRealOfReal, hne, μ, invStd, varEps, t, add_comm] + have hlo : (bounds.1 i : Real) ≤ layerNormRealOfReal eps gamma beta x i := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hlow + have hhi : layerNormRealOfReal eps gamma beta x i ≤ (bounds.2 i : Real) := by + simpa [bounds, layerNormIntervalBounds, hne, radius, centeredBound, invStdBound, μLo, μHi, + hreal] using hhigh + exact And.intro hlo hhi + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/LayerNorm/InvStd.lean b/Nfp/Bounds/LayerNorm/InvStd.lean new file mode 100644 index 0000000..c4dadec --- /dev/null +++ b/Nfp/Bounds/LayerNorm/InvStd.lean @@ -0,0 +1,138 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Bounds.LayerNorm.MeanVariance +public import Nfp.Bounds.LayerNorm.SqrtBounds + +/-! +Inverse-standard-deviation bounds for LayerNorm. + +This module isolates invStd bounds and their soundness proof to keep +`LayerNorm/Basic.lean` below the style linter's file-length limit. +-/ + +public section + +namespace Nfp + + +namespace Bounds + +/-- Bounds for the LayerNorm inverse standard deviation term. -/ +def invStdBounds {n : Nat} (eps : Rat) (x : Fin n → Rat) : Rat × Rat := + if n = 0 then + (0, 0) + else + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + (ratDivDown 1 sqrtUpperBound, ratDivUp 1 sqrtLowerBound) + +/-- Unfolding lemma for `invStdBounds`. 
-/ +theorem invStdBounds_def {n : Nat} (eps : Rat) (x : Fin n → Rat) : + invStdBounds eps x = + if n = 0 then + (0, 0) + else + let var : Rat := variance x + let varEps : Rat := var + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEps) + let sqrtUpperBound : Rat := sqrtUpper varEps + (ratDivDown 1 sqrtUpperBound, ratDivUp 1 sqrtLowerBound) := by + simp [invStdBounds] + +/-- `invStdBounds` soundness for real inverse-std terms. -/ +theorem invStdBounds_spec {n : Nat} (eps : Rat) (x : Fin n → Rat) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := invStdBounds eps x + let invStd : Real := (Real.sqrt ((varianceRat x : Real) + (eps : Real)))⁻¹ + (bounds.1 : Real) ≤ invStd ∧ invStd ≤ (bounds.2 : Real) := by + classical + intro bounds invStd + let varRat : Rat := variance x + let varEpsRat : Rat := varRat + eps + let sqrtLowerBound : Rat := max (sqrtLower eps) (sqrtLower varEpsRat) + let sqrtUpperBound : Rat := sqrtUpper varEpsRat + let invStdLower : Rat := ratDivDown 1 sqrtUpperBound + let invStdUpper : Rat := ratDivUp 1 sqrtLowerBound + let varEps : Real := (varianceRat x : Real) + (eps : Real) + have hvar : (varRat : Real) = (varianceRat x : Real) := by + simp [varRat, variance_def, hne, ratRoundDown_def] + have hvarEps : (varEpsRat : Real) = varEps := by + simp [varEpsRat, varEps, hvar] + have hvar_nonneg : 0 ≤ (varianceRat x : Real) := varianceRat_nonneg_real x hne + have hvar_nonneg_rat : 0 ≤ varianceRat x := by + have hvar_nonneg_real : 0 ≤ ratToReal (varianceRat x) := by + simpa [ratToReal_def] using hvar_nonneg + exact (ratToReal_nonneg_iff (x := varianceRat x)).1 hvar_nonneg_real + have hvarRat_nonneg : 0 ≤ varRat := by + have h := ratRoundDown_nonneg (q := varianceRat x) hvar_nonneg_rat + simpa [varRat, variance_def x hne] using h + have hvarEps_nonneg : 0 ≤ varEpsRat := by + exact add_nonneg hvarRat_nonneg (le_of_lt heps) + have hsqrt_var : (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := by + have h := 
sqrtLower_le_real_sqrt (q := varEpsRat) hvarEps_nonneg + simpa [hvarEps] using h + have hsqrt_lower : + (sqrtLowerBound : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps : (sqrtLower eps : Real) ≤ Real.sqrt varEps := by + have hsqrt_eps' : (sqrtLower eps : Real) ≤ Real.sqrt (eps : Real) := by + have h := sqrtLower_le_real_sqrt (q := eps) (by exact le_of_lt heps) + simpa using h + have hle : (eps : Real) ≤ varEps := by + have hle' : (eps : Real) ≤ (varianceRat x : Real) + (eps : Real) := + le_add_of_nonneg_left hvar_nonneg + simpa [varEps] using hle' + exact le_trans hsqrt_eps' (Real.sqrt_le_sqrt hle) + have hmax : + max (sqrtLower eps : Real) (sqrtLower varEpsRat : Real) ≤ Real.sqrt varEps := + (max_le_iff).2 ⟨hsqrt_eps, hsqrt_var⟩ + simpa [sqrtLowerBound, ratToReal_max] using hmax + have hsqrt_upper : + Real.sqrt varEps ≤ (sqrtUpperBound : Real) := by + have h := real_sqrt_le_sqrtUpper (q := varEpsRat) hvarEps_nonneg + simpa [sqrtUpperBound, hvarEps] using h + have hsqrt_lower_pos_rat : 0 < sqrtLowerBound := by + have hpos : 0 < sqrtLower eps := hsqrt + have hpos' : 0 < max (sqrtLower eps) (sqrtLower varEpsRat) := + lt_of_lt_of_le hpos (le_max_left _ _) + simpa [sqrtLowerBound] using hpos' + have hsqrt_lower_pos : 0 < (sqrtLowerBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtLowerBound)).2 hsqrt_lower_pos_rat + have hsqrt_upper_pos_rat : 0 < sqrtUpperBound := by + simpa [sqrtUpperBound] using sqrtUpper_pos varEpsRat + have hsqrt_upper_pos : 0 < (sqrtUpperBound : Real) := by + exact (Rat.cast_pos (K := Real) (q := sqrtUpperBound)).2 hsqrt_upper_pos_rat + have hvarEps_pos : 0 < varEps := by + have heps_real : 0 < (eps : Real) := by + exact_mod_cast heps + have hpos := add_pos_of_nonneg_of_pos hvar_nonneg heps_real + simpa [varEps] using hpos + have hsqrt_pos : 0 < Real.sqrt varEps := Real.sqrt_pos.2 hvarEps_pos + have hinv_lower_real : + (sqrtUpperBound : Real)⁻¹ ≤ invStd := by + have hle := inv_anti₀ hsqrt_pos hsqrt_upper + simpa [invStd, varEps] 
using hle + have hinv_upper_real : + invStd ≤ (sqrtLowerBound : Real)⁻¹ := by + have hle := inv_anti₀ hsqrt_lower_pos hsqrt_lower + simpa [invStd, varEps] using hle + have hupper_ne : sqrtUpperBound ≠ 0 := ne_of_gt hsqrt_upper_pos_rat + have hlower_ne : sqrtLowerBound ≠ 0 := ne_of_gt hsqrt_lower_pos_rat + have hinv_lower : (invStdLower : Real) ≤ invStd := by + simpa [invStdLower, ratDivDown_def, hupper_ne, one_div] using hinv_lower_real + have hinv_upper : invStd ≤ (invStdUpper : Real) := by + simpa [invStdUpper, ratDivUp_def, hlower_ne, one_div] using hinv_upper_real + have hbounds : bounds = (invStdLower, invStdUpper) := by + simp [bounds, invStdBounds, hne, varRat, varEpsRat, sqrtLowerBound, sqrtUpperBound, + invStdLower, invStdUpper] + constructor + · simpa [hbounds] using hinv_lower + · simpa [hbounds] using hinv_upper + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/LayerNorm/MeanVariance.lean b/Nfp/Bounds/LayerNorm/MeanVariance.lean new file mode 100644 index 0000000..ef3bf3f --- /dev/null +++ b/Nfp/Bounds/LayerNorm/MeanVariance.lean @@ -0,0 +1,245 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Rat.BigOperators +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic + +/-! +Mean/variance helpers for LayerNorm bounds. + +This module isolates the rational and real mean/variance definitions and their +basic lemmas to keep `LayerNorm` bounds modular. +-/ + +public section + +namespace Nfp + + +namespace Bounds + +open scoped BigOperators + +/-- Sum as a rational, used for exact mean/variance computations. 
-/ +def sumRat {n : Nat} (x : Fin n → Rat) : Rat := + ∑ i, (x i : Rat) + +/-- Exact mean as a rational (defaults to `0` when `n = 0`). -/ +def meanRat {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + (sumRat x) / n + +/-- Mean rounded down (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def mean {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + ratRoundDown (meanRat x) + +/-- Mean rounded up (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def meanUpper {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + ratRoundUp (meanRat x) + +/-- Unfold `mean` when `n ≠ 0`. -/ +theorem mean_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + mean x = ratRoundDown (meanRat x) := by + simp [mean, h] + +/-- Unfold `meanUpper` when `n ≠ 0`. -/ +theorem meanUpper_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + meanUpper x = ratRoundUp (meanRat x) := by + simp [meanUpper, h] + +/-- Exact variance as a rational (defaults to `0` when `n = 0`). -/ +def varianceRat {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + let μ := meanRat x + (∑ i, ((x i : Rat) - μ) ^ 2) / n + +/-- Variance rounded down (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def variance {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + ratRoundDown (varianceRat x) + +/-- Variance rounded up (identity in exact-rational mode; defaults to `0` when `n = 0`). -/ +def varianceUpper {n : Nat} (x : Fin n → Rat) : Rat := + if n = 0 then + 0 + else + ratRoundUp (varianceRat x) + +/-- Unfold `variance` when `n ≠ 0`. -/ +theorem variance_def {n : Nat} (x : Fin n → Rat) (h : n ≠ 0) : + variance x = ratRoundDown (varianceRat x) := by + simp [variance, h] + +/-! Interval helpers. -/ + +/-- Absolute value bound from endpoint bounds. 
-/ +theorem abs_le_max_of_bounds {α : Type _} [Ring α] [LinearOrder α] [IsOrderedRing α] + {a b z : α} + (hlo : a ≤ z) (hhi : z ≤ b) : + |z| ≤ max |a| |b| := by + have hleft : -max |a| |b| ≤ z := by + have hneg : -max |a| |b| ≤ a := by + have hneg' : -max |a| |b| ≤ -|a| := by + exact neg_le_neg (le_max_left _ _) + have hneg'' : -|a| ≤ a := by + have h : -a ≤ |a| := neg_le_abs a + simpa using (neg_le_neg h) + exact le_trans hneg' hneg'' + exact le_trans hneg hlo + have hright : z ≤ max |a| |b| := by + have hb : b ≤ |b| := by + exact le_abs_self b + have hb' : b ≤ max |a| |b| := by + exact le_trans hb (le_max_right _ _) + exact le_trans hhi hb' + exact (abs_le.mpr ⟨hleft, hright⟩) + +/-! Real-valued mean and variance. -/ + +/-- Mean of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def meanReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + (∑ i, x i) / n + +/-- Unfold `meanReal` when `n ≠ 0`. -/ +theorem meanReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + meanReal x = (∑ i, x i) / n := by + simp [meanReal, h] + +/-- `sumRat` agrees with the real sum after casting. -/ +theorem sumRat_cast {n : Nat} (x : Fin n → Rat) : + (sumRat x : Real) = ∑ i, (x i : Real) := by + classical + simp [sumRat, Rat.cast_sum] + +/-- `meanReal` agrees with `meanRat` when `n ≠ 0`. -/ +theorem meanReal_eq_meanRat_of_ne {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : + meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by + classical + simp [meanReal, meanRat, sumRat_cast, hne] + +/-- `meanReal` agrees with `mean` after casting. -/ +theorem meanReal_eq_meanRat {n : Nat} (x : Fin n → Rat) : + meanReal (fun i => (x i : Real)) = (meanRat x : Real) := by + by_cases h : n = 0 + · simp [meanReal, meanRat, h] + · simpa [h] using meanReal_eq_meanRat_of_ne (x := x) h + +/-- Mean is monotone under pointwise order (real inputs). 
-/ +theorem meanReal_le_meanReal {n : Nat} (x y : Fin n → Real) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : meanReal x ≤ meanReal y := by + classical + have hsum : (∑ i, x i) ≤ ∑ i, y i := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hxy i + have hden : 0 ≤ (n : Real) := by + simp + have hdiv : (∑ i, x i) / n ≤ (∑ i, y i) / n := + div_le_div_of_nonneg_right hsum hden + simpa [meanReal, hne] using hdiv + +/-- Mean monotonicity for rational inputs, interpreted in reals. -/ +theorem meanRat_le_meanRat_real {n : Nat} (x y : Fin n → Rat) (hne : n ≠ 0) + (hxy : ∀ i, x i ≤ y i) : + (meanRat x : Real) ≤ (meanRat y : Real) := by + have hreal : + meanReal (fun i => (x i : Real)) ≤ meanReal (fun i => (y i : Real)) := by + refine meanReal_le_meanReal (x := fun i => (x i : Real)) (y := fun i => (y i : Real)) hne ?_ + intro i + simpa [ratToReal_def] using ratToReal_le_of_le (hxy i) + simpa [meanReal_eq_meanRat] using hreal + +/-- Variance of a real vector (defaults to `0` when `n = 0`). -/ +noncomputable def varianceReal {n : Nat} (x : Fin n → Real) : Real := + if n = 0 then + 0 + else + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n + +/-- Unfold `varianceReal` when `n ≠ 0`. -/ +theorem varianceReal_def {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + varianceReal x = + let μ := meanReal x + (∑ i, (x i - μ) ^ 2) / n := by + simp [varianceReal, h] + +/-- Variance is nonnegative when `n ≠ 0`. 
-/ +theorem varianceReal_nonneg {n : Nat} (x : Fin n → Real) (h : n ≠ 0) : + 0 ≤ varianceReal x := by + classical + have hsum : 0 ≤ ∑ i, (x i - meanReal x) ^ 2 := by + refine Finset.sum_nonneg ?_ + intro i _ + exact sq_nonneg (x i - meanReal x) + have hden : 0 ≤ (n : Real) := by + simp + have hdiv : 0 ≤ (∑ i, (x i - meanReal x) ^ 2) / n := + div_nonneg hsum hden + simpa [varianceReal_def x h] using hdiv + +theorem varianceReal_eq_varianceRat {n : Nat} (x : Fin n → Rat) : + varianceReal (fun i => (x i : Real)) = (varianceRat x : Real) := by + by_cases h : n = 0 + · simp [varianceReal, varianceRat, h] + · classical + simp [varianceReal, varianceRat, h, meanReal_eq_meanRat_of_ne (x := x) h, Rat.cast_sum] + +/-- Variance is nonnegative when `n ≠ 0`, interpreted in reals. -/ +theorem varianceRat_nonneg_real {n : Nat} (x : Fin n → Rat) (hne : n ≠ 0) : + 0 ≤ (varianceRat x : Real) := by + have hreal := varianceReal_nonneg (x := fun i => (x i : Real)) hne + simpa [varianceReal_eq_varianceRat] using hreal + +/-- Absolute mean bound from per-coordinate bounds (real inputs). 
-/ +theorem meanReal_abs_le_bound {n : Nat} (x : Fin n → Real) (bound : Rat) + (hne : n ≠ 0) (hbound : ∀ i, |x i| ≤ (bound : Real)) : + |meanReal x| ≤ (bound : Real) := by + classical + have hsum_abs : |∑ i : Fin n, x i| ≤ ∑ i : Fin n, |x i| := by + simpa using + (Finset.abs_sum_le_sum_abs + (f := fun i : Fin n => x i) + (s := (Finset.univ : Finset (Fin n)))) + have hsum_bound : ∑ i : Fin n, |x i| ≤ ∑ i : Fin n, (bound : Real) := by + refine Finset.sum_le_sum ?_ + intro i _ + exact hbound i + have hsum_le : |∑ i : Fin n, x i| ≤ (bound : Real) * (n : Real) := by + have hsum := le_trans hsum_abs hsum_bound + simpa [Finset.sum_const, Finset.card_univ, mul_comm] using hsum + have hpos : 0 < (n : Real) := by + exact (Nat.cast_pos (α := Real)).2 (Nat.pos_of_ne_zero hne) + have hdiv : |∑ i : Fin n, x i| / (n : Real) ≤ (bound : Real) := by + exact (div_le_iff₀ hpos).2 hsum_le + have habs_mean : + |(∑ i : Fin n, x i) / (n : Real)| ≤ (bound : Real) := by + simpa [abs_div, abs_of_nonneg (le_of_lt hpos)] using hdiv + simpa [meanReal_def x hne] using habs_mean + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/LayerNorm/SqrtBounds.lean b/Nfp/Bounds/LayerNorm/SqrtBounds.lean new file mode 100644 index 0000000..c1f0828 --- /dev/null +++ b/Nfp/Bounds/LayerNorm/SqrtBounds.lean @@ -0,0 +1,829 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.Order.Field.Basic +public import Mathlib.Algebra.Order.Ring.Basic +public import Mathlib.Data.Nat.Sqrt +public import Mathlib.Data.Real.Sqrt +public import Mathlib.Data.Rat.Cast.Order +public import Nfp.Core.Basic + +/-! +Square-root bounds for LayerNorm intervals. + +This module isolates the rational sqrt lower/upper bounds and their basic +nonnegativity/positivity lemmas so the main LayerNorm bounds stay focused. +-/ + +public section + +namespace Nfp + + +namespace Bounds + +/-! Square-root bounds. 
-/ + +lemma rat_nat_cast_nonneg (n : Nat) : (0 : Rat) ≤ (n : Rat) := by + simp + +lemma rat_nat_cast_pos {n : Nat} (h : 0 < n) : (0 : Rat) < (n : Rat) := by + exact (Nat.cast_pos (α := Rat)).2 h + +/-- `ratRoundDown` preserves nonnegativity for nonnegative divisions. -/ +theorem ratRoundDown_nonneg_div {a b : Rat} (ha : 0 ≤ a) (hb : 0 ≤ b) : + 0 ≤ ratRoundDown (a / b) := by + exact ratRoundDown_nonneg (q := a / b) (by exact div_nonneg ha hb) + +/-- `ratRoundUp` preserves nonnegativity for nonnegative divisions. -/ +theorem ratRoundUp_nonneg_div {a b : Rat} (ha : 0 ≤ a) (hb : 0 ≤ b) : + 0 ≤ ratRoundUp (a / b) := by + exact ratRoundUp_nonneg (q := a / b) (by exact div_nonneg ha hb) + +/-- `ratRoundUp` preserves positivity for positive divisions. -/ +theorem ratRoundUp_pos_div {a b : Rat} (ha : 0 < a) (hb : 0 < b) : + 0 < ratRoundUp (a / b) := by + exact ratRoundUp_pos (q := a / b) (by exact div_pos ha hb) + +/-- Base rational lower bound for a square root. -/ +def sqrtLowerBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + ratRoundDown ((a : Rat) / (b + 1)) + +/-- Base rational upper bound for a square root. -/ +def sqrtUpperBase (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt num + let b := Nat.sqrt den + ratRoundUp ((a + 1 : Rat) / b) + +/-- Alternate rational lower bound for a square root. -/ +def sqrtLowerAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + ratRoundDown ((a : Rat) / den) + +/-- Alternate rational upper bound for a square root. -/ +def sqrtUpperAlt (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den) + ratRoundUp ((a + 1 : Rat) / den) + +/-- Extra precision scale for `sqrtLowerScaled`. -/ +def sqrtLowerScale : Nat := 1048576 + +/-- Unfolding lemma for `sqrtLowerScale`. 
-/ +theorem sqrtLowerScale_def : sqrtLowerScale = 1048576 := by + rfl + +/-- Scaled rational lower bound for a square root (extra precision). -/ +def sqrtLowerScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundDown ((a : Rat) / (den * scale)) + +/-- Scaled rational upper bound for a square root (extra precision). -/ +def sqrtUpperScaled (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let scale := sqrtLowerScale + let a := Nat.sqrt (num * den * scale * scale) + ratRoundUp ((a + 1 : Rat) / (den * scale)) + +/-- Scaled rational lower bound for a square root with a custom scale. -/ +def sqrtLowerScaledWith (scale : Nat) (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den * scale * scale) + ratRoundDown ((a : Rat) / (den * scale)) + +/-- Scaled rational upper bound for a square root with a custom scale. -/ +def sqrtUpperScaledWith (scale : Nat) (q : Rat) : Rat := + let num := q.num.natAbs + let den := q.den + let a := Nat.sqrt (num * den * scale * scale) + ratRoundUp ((a + 1 : Rat) / (den * scale)) + +/-- Rational lower bound for a square root (tighter of three bounds). -/ +def sqrtLower (q : Rat) : Rat := + max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) + +/-- Rational upper bound for a square root (tighter of three bounds). -/ +def sqrtUpper (q : Rat) : Rat := + min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaled q) + +/-- Rational lower bound for a square root with a custom scale. -/ +def sqrtLowerWithScale (scale : Nat) (q : Rat) : Rat := + max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaledWith scale q) + +/-- Rational upper bound for a square root with a custom scale. -/ +def sqrtUpperWithScale (scale : Nat) (q : Rat) : Rat := + min (min (sqrtUpperBase q) (sqrtUpperAlt q)) (sqrtUpperScaledWith scale q) + +/-- `sqrtLowerBase` is nonnegative. 
-/ +theorem sqrtLowerBase_nonneg (q : Rat) : 0 ≤ sqrtLowerBase q := by + classical + unfold sqrtLowerBase + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) := + rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs) + have hden : 0 ≤ (Nat.sqrt q.den : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt q.den + 1) + exact ratRoundDown_nonneg_div hnum hden + +/-- `sqrtUpperBase` is nonnegative. -/ +theorem sqrtUpperBase_nonneg (q : Rat) : 0 ≤ sqrtUpperBase q := by + classical + unfold sqrtUpperBase + have hnum : 0 ≤ (Nat.sqrt q.num.natAbs : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt q.num.natAbs + 1) + have hden : 0 ≤ (Nat.sqrt q.den : Rat) := + rat_nat_cast_nonneg (Nat.sqrt q.den) + exact ratRoundUp_nonneg_div hnum hden + +/-- `sqrtUpperBase` is always positive. -/ +theorem sqrtUpperBase_pos (q : Rat) : 0 < sqrtUpperBase q := by + classical + unfold sqrtUpperBase + have hnum_pos : (0 : Rat) < (Nat.sqrt q.num.natAbs : Rat) + 1 := by + simpa using rat_nat_cast_pos (Nat.succ_pos (Nat.sqrt q.num.natAbs)) + have hden_pos : (0 : Rat) < (Nat.sqrt q.den : Rat) := by + have hden : 0 < q.den := q.den_pos + exact rat_nat_cast_pos (Nat.sqrt_pos.2 hden) + exact ratRoundUp_pos_div hnum_pos hden_pos + +/-- `sqrtLowerAlt` is nonnegative. -/ +theorem sqrtLowerAlt_nonneg (q : Rat) : 0 ≤ sqrtLowerAlt q := by + classical + unfold sqrtLowerAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) := + rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den)) + have hden : 0 ≤ (q.den : Rat) := + rat_nat_cast_nonneg q.den + exact ratRoundDown_nonneg_div hnum hden + +/-- `sqrtUpperAlt` is nonnegative. 
-/ +theorem sqrtUpperAlt_nonneg (q : Rat) : 0 ≤ sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum : 0 ≤ (Nat.sqrt (q.num.natAbs * q.den) : Rat) + 1 := by + simpa using rat_nat_cast_nonneg (Nat.sqrt (q.num.natAbs * q.den) + 1) + have hden : 0 ≤ (q.den : Rat) := + rat_nat_cast_nonneg q.den + exact ratRoundUp_nonneg_div hnum hden + +/-- `sqrtUpperAlt` is always positive. -/ +theorem sqrtUpperAlt_pos (q : Rat) : 0 < sqrtUpperAlt q := by + classical + unfold sqrtUpperAlt + have hnum_pos : + (0 : Rat) < (Nat.sqrt (q.num.natAbs * q.den) : Rat) + 1 := by + simpa using rat_nat_cast_pos (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den))) + have hden_pos : (0 : Rat) < (q.den : Rat) := + rat_nat_cast_pos q.den_pos + exact ratRoundUp_pos_div hnum_pos hden_pos + +/-- `sqrtUpperScaled` is nonnegative. -/ +theorem sqrtUpperScaled_nonneg (q : Rat) : 0 ≤ sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum : + 0 ≤ (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) : Rat) + 1 := by + simpa using rat_nat_cast_nonneg + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) + 1) + have hden : 0 ≤ (q.den : Rat) * (sqrtLowerScale : Rat) := by + simpa [Nat.cast_mul] using rat_nat_cast_nonneg (q.den * sqrtLowerScale) + exact ratRoundUp_nonneg_div hnum hden + +/-- `sqrtUpperScaled` is always positive. 
-/ +theorem sqrtUpperScaled_pos (q : Rat) : 0 < sqrtUpperScaled q := by + classical + unfold sqrtUpperScaled + have hnum_pos : + (0 : Rat) < + (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale) : Rat) + 1 := by + simpa using rat_nat_cast_pos + (Nat.succ_pos (Nat.sqrt (q.num.natAbs * q.den * sqrtLowerScale * sqrtLowerScale))) + have hden_pos : (0 : Rat) < (q.den : Rat) * (sqrtLowerScale : Rat) := by + have hden : 0 < q.den := q.den_pos + have hscale : 0 < sqrtLowerScale := by + simp [sqrtLowerScale] + simpa [Nat.cast_mul] using rat_nat_cast_pos (Nat.mul_pos hden hscale) + exact ratRoundUp_pos_div hnum_pos hden_pos + +/-! Combined bounds. -/ + +/-- `sqrtLower` is nonnegative. -/ +theorem sqrtLower_nonneg (q : Rat) : 0 ≤ sqrtLower q := by + have hbase : 0 ≤ sqrtLowerBase q := sqrtLowerBase_nonneg q + have hmax : 0 ≤ max (sqrtLowerBase q) (sqrtLowerAlt q) := + le_trans hbase (le_max_left _ _) + exact le_trans hmax (le_max_left _ _) + +/-- `sqrtUpper` is nonnegative. -/ +theorem sqrtUpper_nonneg (q : Rat) : 0 ≤ sqrtUpper q := by + have hbase : 0 ≤ sqrtUpperBase q := sqrtUpperBase_nonneg q + have halt : 0 ≤ sqrtUpperAlt q := sqrtUpperAlt_nonneg q + have hscaled : 0 ≤ sqrtUpperScaled q := sqrtUpperScaled_nonneg q + have hmin1 : 0 ≤ min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact le_min hbase halt + exact le_min hmin1 hscaled + +/-- `sqrtUpper` is always positive. -/ +theorem sqrtUpper_pos (q : Rat) : 0 < sqrtUpper q := by + have hbase : 0 < sqrtUpperBase q := sqrtUpperBase_pos q + have halt : 0 < sqrtUpperAlt q := sqrtUpperAlt_pos q + have hscaled : 0 < sqrtUpperScaled q := sqrtUpperScaled_pos q + have hmin1 : 0 < min (sqrtUpperBase q) (sqrtUpperAlt q) := by + exact lt_min hbase halt + exact lt_min hmin1 hscaled + +/-! Real-valued bounds. -/ + +/-- Cast a nonnegative rational as `num.natAbs / den`. 
-/ +theorem rat_cast_eq_num_den {q : Rat} (hq : 0 ≤ q) : + (q : Real) = (q.num.natAbs : Real) / q.den := by + have hnum_nonneg : 0 ≤ q.num := (Rat.num_nonneg (q := q)).2 hq + have hnum_eq : (q.num.natAbs : Int) = q.num := by + exact (Int.natAbs_of_nonneg hnum_nonneg) + have hnum_cast : (q.num : Real) = (q.num.natAbs : Real) := by + exact (congrArg (fun z : Int => (z : Real)) hnum_eq).symm + have hq_rat : (q : Real) = (q.num : Real) / q.den := by + simp [Rat.cast_def] + calc + (q : Real) = (q.num : Real) / q.den := hq_rat + _ = (q.num.natAbs : Real) / q.den := by + rw [hnum_cast] + +/-- Cast a nonnegative rational as `num.natAbs * den / den^2`. -/ +theorem rat_cast_eq_num_den_mul {q : Rat} (hq : 0 ≤ q) : + (q : Real) = (q.num.natAbs : Real) * q.den / (q.den : Real) ^ 2 := by + have hq_cast : (q : Real) = (q.num.natAbs : Real) / q.den := + rat_cast_eq_num_den (q := q) hq + have hden_ne : (q.den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_eq : + (q.num.natAbs : Real) / q.den = + (q.num.natAbs : Real) * q.den / (q.den : Real) ^ 2 := by + field_simp [hden_ne] + exact hq_cast.trans hq_eq + +/-- Cast a nonnegative rational as `num.natAbs * den * scale^2 / (den * scale)^2`. 
-/ +theorem rat_cast_eq_num_den_scale {q : Rat} (hq : 0 ≤ q) {scale : Nat} (hscale : 0 < scale) : + (q : Real) = + (q.num.natAbs : Real) * q.den * (scale : Real) * (scale : Real) / + ((q.den : Real) * (scale : Real)) ^ 2 := by + have hq_cast : (q : Real) = (q.num.natAbs : Real) / q.den := + rat_cast_eq_num_den (q := q) hq + have hden_pos : 0 < (q.den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hden_scale_ne : ((q.den : Real) * (scale : Real)) ≠ 0 := by + exact ne_of_gt (mul_pos hden_pos hscale_pos) + have hq_eq : + (q.num.natAbs : Real) / q.den = + (q.num.natAbs : Real) * q.den * (scale : Real) * (scale : Real) / + ((q.den : Real) * (scale : Real)) ^ 2 := by + field_simp [hden_scale_ne] + exact hq_cast.trans hq_eq + +/-- Square-root lower bound in reals. -/ +theorem sqrtLowerBase_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerBase q : Real) ≤ Real.sqrt (q : Real) := by + classical + -- Set up numerator/denominator witnesses. 
+ set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b + 1 : Real) := by + exact_mod_cast (Nat.succ_pos b) + have hnum_le : (a ^ 2 : Real) ≤ num := by + exact_mod_cast (Nat.sqrt_le' num) + have hden_le : (den : Real) ≤ (b + 1) ^ 2 := by + exact_mod_cast (le_of_lt (Nat.lt_succ_sqrt' den)) + have hmul : (a ^ 2 : Real) * den ≤ (num : Real) * (b + 1) ^ 2 := by + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + have hnum_nonneg : 0 ≤ (num : Real) := by exact_mod_cast (Nat.zero_le num) + exact mul_le_mul hnum_le hden_le hden_nonneg hnum_nonneg + have hbpos2 : 0 < (b + 1 : Real) ^ 2 := by + simpa [pow_two] using (mul_pos hbpos hbpos) + have hdiv : (a ^ 2 : Real) / (b + 1) ^ 2 ≤ (num : Real) / den := by + exact (div_le_div_iff₀ hbpos2 hden_pos).2 hmul + have hpow : ((a : Real) / (b + 1 : Real)) ^ 2 = (a ^ 2 : Real) / (b + 1) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + simpa [num, den] using (rat_cast_eq_num_den (q := q) hq) + have hsq : ((a : Real) / (b + 1 : Real)) ^ 2 ≤ (q : Real) := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ (a : Real) / (b + 1 : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (b + 1 : Real) := by exact_mod_cast (Nat.zero_le (b + 1)) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq + have hle : (a : Real) / (b + 1 : Real) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerBase q : Real) ≤ (a : Real) / (b + 1 : Real) := by + have hdown' : + (ratRoundDown ((a : Rat) / (b + 1)) : Real) ≤ + (a : Real) / (b + 1 : Real) := by + simpa using ratRoundDown_le_real ((a : Rat) / (b + 1)) + simpa [sqrtLowerBase, 
num, den, a, b] using hdown' + exact le_trans hdown hle + +/-- Square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperBase {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperBase q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt num + set b : Nat := Nat.sqrt den + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hbpos : 0 < (b : Real) := by + have hb : 0 < b := by + have hden : 0 < den := q.den_pos + exact (Nat.sqrt_pos).2 hden + exact_mod_cast hb + have hnum_lt : (num : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' num) + have hden_le : (b ^ 2 : Real) ≤ den := by + exact_mod_cast (Nat.sqrt_le' den) + have hmul : (num : Real) * (b ^ 2) ≤ (a + 1) ^ 2 * den := by + have hb2_nonneg : 0 ≤ (b ^ 2 : Real) := by + exact sq_nonneg (b : Real) + have hsq_nonneg : 0 ≤ (a + 1 : Real) ^ 2 := by + exact sq_nonneg (a + 1 : Real) + exact mul_le_mul (le_of_lt hnum_lt) hden_le hb2_nonneg hsq_nonneg + have hbpos2 : 0 < (b : Real) ^ 2 := by + simpa [pow_two] using (mul_pos hbpos hbpos) + have hdiv : (num : Real) / den ≤ (a + 1) ^ 2 / (b : Real) ^ 2 := by + exact (div_le_div_iff₀ hden_pos hbpos2).2 hmul + have hpow : ((a + 1 : Real) / (b : Real)) ^ 2 = (a + 1) ^ 2 / (b : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hq_cast : (q : Real) = (num : Real) / den := by + simpa [num, den] using (rat_cast_eq_num_den (q := q) hq) + have hsq : (q : Real) ≤ ((a + 1 : Real) / (b : Real)) ^ 2 := by + simpa [hpow, hq_cast, den, num] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (b : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (b : Real) := by exact_mod_cast (Nat.zero_le b) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (b : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / (b : Real) ≤ (sqrtUpperBase q : Real) 
:= by + have hup' : + (a + 1 : Real) / (b : Real) ≤ + (ratRoundUp ((a + 1 : Rat) / b) : Real) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / b) + simpa [sqrtUpperBase, num, den, a, b] using hup' + exact le_trans hle hup + +/- + Real-valued soundness of the alternate and scaled square-root lower bounds. + The monotone scaling steps are carried out by hand in each proof below. +-/ +/-- Alternate square-root lower bound in reals. -/ +theorem sqrtLowerAlt_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerAlt q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_le : (a ^ 2 : Real) ≤ (num * den : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den)) + have hmul : (a ^ 2 : Real) ≤ (num : Real) * den := by + simpa [num, den, Nat.cast_mul] using hnumden_le + have hden_pos2 : 0 < (den : Real) ^ 2 := by + exact pow_pos hden_pos 2 + have hdiv : + (a ^ 2 : Real) / (den : Real) ^ 2 ≤ (num : Real) * den / (den : Real) ^ 2 := by + have hmul' : + (a ^ 2 : Real) * (den : Real) ^ 2 ≤ (num : Real) * den * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + simpa [num, den] using (rat_cast_eq_num_den_mul (q := q) hq) + have hsq : ((a : Real) / (den : Real)) ^ 2 ≤ (q : Real) := by + simpa [hq_cast, pow_two, div_mul_div_comm] using hdiv + have hnonneg : 0 ≤ (a : Real) / (den : Real) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq + have
hle : (a : Real) / (den : Real) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerAlt q : Real) ≤ (a : Real) / (den : Real) := by + have hdown' : + (ratRoundDown ((a : Rat) / den) : Real) ≤ + (a : Real) / (den : Real) := by + simpa using ratRoundDown_le_real ((a : Rat) / den) + simpa [sqrtLowerAlt, num, den, a] using hdown' + exact le_trans hdown hle + +/-- Scaled square-root lower bound in reals. -/ +theorem sqrtLowerScaled_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLowerScaled q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale_pos_nat + have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) + have hmul : + (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by + simpa [num, den, scale, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := + mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) * + ((den : Real) * (scale : Real)) ^ 2 := by + have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hnonneg + have hdiv : + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 
hmul' + have hq_cast : + (q : Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den, scale] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale_pos_nat) + have hpow : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by + calc + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 + = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow + _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hdiv + _ = (q : Real) := by simp [hq_cast] + have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq + have hle : + (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerScaled q : Real) ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hdown' : + (ratRoundDown ((a : Rat) / (den * scale)) : Real) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) + simpa [sqrtLowerScaled, num, den, scale, a] using hdown' + exact le_trans hdown hle + +/-- Scaled square-root lower bound in reals with a custom scale. 
-/ +theorem sqrtLowerScaledWith_le_real_sqrt {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + (sqrtLowerScaledWith scale q : Real) ≤ Real.sqrt (q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hnumden_le : (a ^ 2 : Real) ≤ (num * den * scale * scale : Nat) := by + exact_mod_cast (Nat.sqrt_le' (num * den * scale * scale)) + have hmul : + (a ^ 2 : Real) ≤ (num : Real) * den * (scale : Real) * (scale : Real) := by + simpa [num, den, Nat.cast_mul, mul_assoc, mul_left_comm, mul_comm] using hnumden_le + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := + mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (a ^ 2 : Real) * ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) * + ((den : Real) * (scale : Real)) ^ 2 := by + have hnonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hnonneg + have hdiv : + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 ≤ + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast : + (q : Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale) + have hpow : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + ((a : Real) / ((den : Real) * (scale : Real))) ^ 2 ≤ (q : Real) := by + calc + ((a : Real) / ((den : Real) * 
(scale : Real))) ^ 2 + = (a ^ 2 : Real) / ((den : Real) * (scale : Real)) ^ 2 := hpow + _ ≤ ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := hdiv + _ = (q : Real) := by simp [hq_cast] + have hnonneg : 0 ≤ (a : Real) / ((den : Real) * (scale : Real)) := by + have hnum_nonneg : 0 ≤ (a : Real) := by exact_mod_cast (Nat.zero_le a) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hq_nonneg : 0 ≤ (q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hq + have hle : + (a : Real) / ((den : Real) * (scale : Real)) ≤ Real.sqrt (q : Real) := + (Real.le_sqrt hnonneg hq_nonneg).2 hsq + have hdown : + (sqrtLowerScaledWith scale q : Real) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + have hdown' : + (ratRoundDown ((a : Rat) / (den * scale)) : Real) ≤ + (a : Real) / ((den : Real) * (scale : Real)) := by + simpa using ratRoundDown_le_real ((a : Rat) / (den * scale)) + simpa [sqrtLowerScaledWith, num, den, a] using hdown' + exact le_trans hdown hle + +/-- Alternate square-root upper bound in reals. 
-/ +theorem real_sqrt_le_sqrtUpperAlt {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperAlt q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hnumden_lt : (num * den : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den)) + have hmul : (num : Real) * den ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hden_pos2 : 0 < (den : Real) ^ 2 := by + exact pow_pos hden_pos 2 + have hdiv : + (num : Real) * den / (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + have hmul' : + (num : Real) * den * (den : Real) ^ 2 ≤ (a + 1 : Real) ^ 2 * (den : Real) ^ 2 := by + have hden_sq_nonneg : 0 ≤ (den : Real) ^ 2 := by + exact sq_nonneg (den : Real) + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + exact (div_le_div_iff₀ hden_pos2 hden_pos2).2 hmul' + have hden_ne : (den : Real) ≠ 0 := by + exact_mod_cast q.den_pos.ne' + have hq_cast : (q : Real) = (num : Real) * den / (den : Real) ^ 2 := by + simpa [num, den] using (rat_cast_eq_num_den_mul (q := q) hq) + have hpow : + ((a + 1 : Real) / (den : Real)) ^ 2 = + (a + 1 : Real) ^ 2 / (den : Real) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : (q : Real) ≤ ((a + 1 : Real) / (den : Real)) ^ 2 := by + simpa [hq_cast, hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / (den : Real)) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) := by exact_mod_cast (Nat.zero_le den) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : Real.sqrt (q : Real) ≤ (a + 1 : Real) / (den : Real) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / (den : Real) ≤ (sqrtUpperAlt q : Real) := by + have hup' : + (a + 1 : Real) / (den : Real) ≤ + (ratRoundUp ((a + 1 : Rat) / den) : Real) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / den) 
+ simpa [sqrtUpperAlt, num, den, a] using hup' + exact le_trans hle hup + +/-- Scaled square-root upper bound in reals. -/ +theorem real_sqrt_le_sqrtUpperScaled {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpperScaled q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set scale : Nat := sqrtLowerScale + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos_nat : 0 < scale := by + simp [scale, sqrtLowerScale] + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale_pos_nat + have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) + have hmul : + (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by + exact mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hdiv : + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + have hmul' : + (num : Real) * den * (scale : Real) * (scale : Real) * + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by + have hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul hmul (le_refl _) hden_sq_nonneg (sq_nonneg _) + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast' : + (q : Real) = + ((num : Real) * den * (scale : Real) * (scale : Real)) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den, scale] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale_pos_nat) + have hpow : + ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a + 1 : Real) ^ 2 / ((den 
: Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : + (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by + simpa [hq_cast', hpow] using hdiv + have hnonneg : 0 ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by + exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : + Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ (sqrtUpperScaled q : Real) := by + have hup' : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + (ratRoundUp ((a + 1 : Rat) / (den * scale)) : Real) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) + simpa [sqrtUpperScaled, num, den, scale, a] using hup' + exact le_trans hle hup + +/-- Scaled square-root upper bound in reals with a custom scale. 
-/ +theorem real_sqrt_le_sqrtUpperScaledWith {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + Real.sqrt (q : Real) ≤ (sqrtUpperScaledWith scale q : Real) := by + classical + set num : Nat := q.num.natAbs + set den : Nat := q.den + set a : Nat := Nat.sqrt (num * den * scale * scale) + have hden_pos : 0 < (den : Real) := by + exact_mod_cast q.den_pos + have hscale_pos : 0 < (scale : Real) := by + exact_mod_cast hscale + have hnumden_lt : (num * den * scale * scale : Real) < (a + 1) ^ 2 := by + exact_mod_cast (Nat.lt_succ_sqrt' (num * den * scale * scale)) + have hmul : + (num : Real) * den * (scale : Real) * (scale : Real) ≤ (a + 1 : Real) ^ 2 := by + exact le_of_lt hnumden_lt + have hdenScale_pos : 0 < (den : Real) * (scale : Real) := by + exact mul_pos hden_pos hscale_pos + have hdenScale_pos2 : 0 < ((den : Real) * (scale : Real)) ^ 2 := by + exact pow_pos hdenScale_pos 2 + have hmul' : + (num : Real) * den * (scale : Real) * (scale : Real) * + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 * ((den : Real) * (scale : Real)) ^ 2 := by + have hden_sq_nonneg : 0 ≤ ((den : Real) * (scale : Real)) ^ 2 := by + exact sq_nonneg _ + exact mul_le_mul_of_nonneg_right hmul hden_sq_nonneg + have hdiv : + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 ≤ + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + exact (div_le_div_iff₀ hdenScale_pos2 hdenScale_pos2).2 hmul' + have hq_cast : + (q : Real) = + (num : Real) * den * (scale : Real) * (scale : Real) / + ((den : Real) * (scale : Real)) ^ 2 := by + simpa [num, den] using + (rat_cast_eq_num_den_scale (q := q) hq (scale := scale) hscale) + have hpow : + ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 = + (a + 1 : Real) ^ 2 / ((den : Real) * (scale : Real)) ^ 2 := by + simp [pow_two, div_mul_div_comm] + have hsq : (q : Real) ≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) ^ 2 := by + simpa [hq_cast, hpow] using hdiv + have hnonneg : 0 
≤ ((a + 1 : Real) / ((den : Real) * (scale : Real))) := by + have hnum_nonneg : 0 ≤ (a + 1 : Real) := by exact_mod_cast (Nat.zero_le (a + 1)) + have hden_nonneg : 0 ≤ (den : Real) * (scale : Real) := by + exact mul_nonneg (le_of_lt hden_pos) (le_of_lt hscale_pos) + exact div_nonneg hnum_nonneg hden_nonneg + have hle : + Real.sqrt (q : Real) ≤ (a + 1 : Real) / ((den : Real) * (scale : Real)) := + (Real.sqrt_le_iff).2 ⟨hnonneg, hsq⟩ + have hup : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + (sqrtUpperScaledWith scale q : Real) := by + have hup' : + (a + 1 : Real) / ((den : Real) * (scale : Real)) ≤ + (ratRoundUp ((a + 1 : Rat) / (den * scale)) : Real) := by + simpa using real_le_ratRoundUp ((a + 1 : Rat) / (den * scale)) + simpa [sqrtUpperScaledWith, num, den, a] using hup' + exact le_trans hle hup + +theorem sqrtLowerWithScale_le_real_sqrt {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + (sqrtLowerWithScale scale q : Real) ≤ Real.sqrt (q : Real) := by + have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq + have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq + have hscaled := sqrtLowerScaledWith_le_real_sqrt (q := q) (scale := scale) hq hscale + have hmax1 : + (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ + have hmax2 : + (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaledWith scale q) : Real) ≤ + Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ + simpa [sqrtLowerWithScale] using hmax2 + +theorem real_sqrt_le_sqrtUpperWithScale {q : Rat} {scale : Nat} + (hq : 0 ≤ q) (hscale : 0 < scale) : + Real.sqrt (q : Real) ≤ (sqrtUpperWithScale scale q : Real) := by + have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq + have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq + have hscaled := real_sqrt_le_sqrtUpperScaledWith (q := q) (scale := scale) hq hscale + have hmin1 : + Real.sqrt (q : Real) ≤ min (sqrtUpperBase q 
: Real) (sqrtUpperAlt q : Real) := by + exact (le_min_iff).2 ⟨hbase, halt⟩ + have hmin2 : + Real.sqrt (q : Real) ≤ + min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) + (sqrtUpperScaledWith scale q : Real) := by + exact (le_min_iff).2 ⟨hmin1, hscaled⟩ + simpa [sqrtUpperWithScale, ratToReal_min] using hmin2 +/-- Square-root lower bound in reals (tighter of three bounds). -/ +theorem sqrtLower_le_real_sqrt {q : Rat} (hq : 0 ≤ q) : + (sqrtLower q : Real) ≤ Real.sqrt (q : Real) := by + have hbase := sqrtLowerBase_le_real_sqrt (q := q) hq + have halt := sqrtLowerAlt_le_real_sqrt (q := q) hq + have hscaled := sqrtLowerScaled_le_real_sqrt (q := q) hq + have hmax1 : + (max (sqrtLowerBase q) (sqrtLowerAlt q) : Real) ≤ Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hbase, halt⟩ + have hmax2 : + (max (max (sqrtLowerBase q) (sqrtLowerAlt q)) (sqrtLowerScaled q) : Real) ≤ + Real.sqrt (q : Real) := by + simpa [ratToReal_max] using (max_le_iff).2 ⟨hmax1, hscaled⟩ + simpa [sqrtLower] using hmax2 + +/-- Square-root upper bound in reals (tighter of three bounds). 
-/ +theorem real_sqrt_le_sqrtUpper {q : Rat} (hq : 0 ≤ q) : + Real.sqrt (q : Real) ≤ (sqrtUpper q : Real) := by + have hbase := real_sqrt_le_sqrtUpperBase (q := q) hq + have halt := real_sqrt_le_sqrtUpperAlt (q := q) hq + have hscaled := real_sqrt_le_sqrtUpperScaled (q := q) hq + have hmin1 : + Real.sqrt (q : Real) ≤ min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real) := by + exact (le_min_iff).2 ⟨hbase, halt⟩ + have hmin2 : + Real.sqrt (q : Real) ≤ + min (min (sqrtUpperBase q : Real) (sqrtUpperAlt q : Real)) (sqrtUpperScaled q : Real) := by + exact (le_min_iff).2 ⟨hmin1, hscaled⟩ + simpa [sqrtUpper, ratToReal_min] using hmin2 + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/Mlp.lean b/Nfp/Bounds/Mlp.lean new file mode 100644 index 0000000..1a443fe --- /dev/null +++ b/Nfp/Bounds/Mlp.lean @@ -0,0 +1,295 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Bounds.Gelu +public import Nfp.Bounds.LayerNorm + +/-! +Interval bounds for GPT-2 MLP blocks (linear + GELU + linear). +-/ + +public section + +namespace Nfp + + +namespace Bounds + +open scoped BigOperators + +/-- Real-valued MLP with tanh-based GELU activations. -/ +noncomputable def mlpReal {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (x : Fin dModel → Real) : Fin dModel → Real := + fun i => + let hidden : Fin hidden → Real := fun h => + geluTanh (dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real)) + dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) + +/-- Interval bounds for a tanh-GELU MLP given input intervals. 
-/ +def mlpBounds {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) : (Fin dModel → Rat) × (Fin dModel → Rat) := + let preLo : Fin hidden → Rat := fun h => + dotIntervalLower (fun j => wIn j h) lo hi + bIn h + let preHi : Fin hidden → Rat := fun h => + dotIntervalUpper (fun j => wIn j h) lo hi + bIn h + let geluBounds : Fin hidden → Rat × Rat := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Rat := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Rat := fun h => (geluBounds h).2 + let outLo : Fin dModel → Rat := fun i => + dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i + let outHi : Fin dModel → Rat := fun i => + dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i + (outLo, outHi) + +/-- `mlpBounds` soundness for real MLP outputs. -/ +theorem mlpBounds_spec {dModel hidden : Nat} + (wIn : Fin dModel → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin dModel → Rat) (bOut : Fin dModel → Rat) + (lo hi : Fin dModel → Rat) (x : Fin dModel → Real) + (hlo : ∀ j, (lo j : Real) ≤ x j) (hhi : ∀ j, x j ≤ (hi j : Real)) : + let bounds := mlpBounds wIn bIn wOut bOut lo hi + ∀ i, (bounds.1 i : Real) ≤ mlpReal wIn bIn wOut bOut x i ∧ + mlpReal wIn bIn wOut bOut x i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let preLo : Fin hidden → Rat := fun h => + dotIntervalLower (fun j => wIn j h) lo hi + bIn h + let preHi : Fin hidden → Rat := fun h => + dotIntervalUpper (fun j => wIn j h) lo hi + bIn h + let pre : Fin hidden → Real := fun h => + dotProduct (fun j => (wIn j h : Real)) x + (bIn h : Real) + have hpre_lower : ∀ h, (preLo h : Real) ≤ pre h := by + intro h + have hlow := + dotIntervalLower_le_dotProduct_real_add (v := fun j => wIn j h) + (lo := lo) (hi := hi) (x := x) (b := (bIn h : Real)) hlo hhi + simpa [pre, preLo] using hlow + have hpre_upper : ∀ h, pre h ≤ 
(preHi h : Real) := by + intro h + have hhigh := + dotProduct_le_dotIntervalUpper_real_add (v := fun j => wIn j h) + (lo := lo) (hi := hi) (x := x) (b := (bIn h : Real)) hlo hhi + simpa [pre, preHi] using hhigh + let geluBounds : Fin hidden → Rat × Rat := fun h => geluInterval (preLo h) (preHi h) + let geluLo : Fin hidden → Rat := fun h => (geluBounds h).1 + let geluHi : Fin hidden → Rat := fun h => (geluBounds h).2 + let hidden : Fin hidden → Real := fun h => geluTanh (pre h) + have hgelu : ∀ h, (geluLo h : Real) ≤ hidden h ∧ hidden h ≤ (geluHi h : Real) := by + intro h + have hbounds := geluInterval_bounds (lo := preLo h) (hi := preHi h) + (hpre_lower h) (hpre_upper h) + simpa [geluLo, geluHi, geluBounds, hidden] using hbounds + let outLo : Fin dModel → Rat := fun i => + dotIntervalLower (fun h => wOut h i) geluLo geluHi + bOut i + let outHi : Fin dModel → Rat := fun i => + dotIntervalUpper (fun h => wOut h i) geluLo geluHi + bOut i + have hout_lower : + (outLo i : Real) ≤ dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) := by + have hgelu_lo : ∀ h, (geluLo h : Real) ≤ hidden h := fun h => (hgelu h).1 + have hgelu_hi : ∀ h, hidden h ≤ (geluHi h : Real) := fun h => (hgelu h).2 + have hlow := + dotIntervalLower_le_dotProduct_real_add (v := fun h => wOut h i) + (lo := geluLo) (hi := geluHi) (x := hidden) (b := (bOut i : Real)) + hgelu_lo hgelu_hi + simpa [outLo] using hlow + have hout_upper : + dotProduct (fun h => (wOut h i : Real)) hidden + (bOut i : Real) ≤ (outHi i : Real) := by + have hgelu_lo : ∀ h, (geluLo h : Real) ≤ hidden h := fun h => (hgelu h).1 + have hgelu_hi : ∀ h, hidden h ≤ (geluHi h : Real) := fun h => (hgelu h).2 + have hhigh := + dotProduct_le_dotIntervalUpper_real_add (v := fun h => wOut h i) + (lo := geluLo) (hi := geluHi) (x := hidden) (b := (bOut i : Real)) + hgelu_lo hgelu_hi + simpa [outHi] using hhigh + have hlo' : (outLo i : Real) ≤ mlpReal wIn bIn wOut bOut x i := by + simpa [mlpReal, hidden, pre] using hout_lower + have 
hhi' : mlpReal wIn bIn wOut bOut x i ≤ (outHi i : Real) := by + simpa [mlpReal, hidden, pre] using hout_upper + simpa [bounds, mlpBounds, preLo, preHi, geluLo, geluHi, outLo, outHi] using + And.intro hlo' hhi' + +/-- Interval bounds for a LayerNorm + MLP sublayer from exact inputs. -/ +def layerNormMlpBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let ln := layerNormBounds eps gamma beta x + mlpBounds wIn bIn wOut bOut ln.1 ln.2 + +/-- `layerNormMlpBounds` soundness for real LayerNorm + MLP outputs. -/ +theorem layerNormMlpBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + ∀ i, (bounds.1 i : Real) ≤ + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let ln := layerNormBounds eps gamma beta x + have hln := layerNormBounds_spec eps gamma beta x hne heps hsqrt + have hlo : ∀ j, (ln.1 j : Real) ≤ layerNormReal eps gamma beta x j := fun j => (hln j).1 + have hhi : ∀ j, layerNormReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 + have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 + (layerNormReal eps gamma beta x) hlo hhi + simpa [bounds, layerNormMlpBounds, ln] using hmlp i + +/-- Interval bounds for LayerNorm + MLP sublayer from interval inputs. 
-/ +def layerNormAbsMlpBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps gamma beta absBound + mlpBounds wIn bIn wOut bOut ln.1 ln.2 + +/-- `layerNormAbsMlpBounds` soundness for real LayerNorm + MLP outputs. -/ +theorem layerNormAbsMlpBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + ∀ i, (bounds.1 i : Real) ≤ + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ∧ + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ (bounds.2 i : Real) := by + classical + intro bounds i + let absBound := intervalAbsBound lo hi + let ln := layerNormAbsBounds eps gamma beta absBound + have habs : ∀ j, |x j| ≤ (absBound : Real) := by + intro j + have hbound : + |x j| ≤ max |(lo j : Real)| |(hi j : Real)| := + abs_le_max_abs_abs_of_interval_real (hlo j) (hhi j) + have hsup : max |lo j| |hi j| ≤ intervalAbsBound lo hi := + max_abs_le_intervalAbsBound lo hi j + have hsup_real : + max |(lo j : Real)| |(hi j : Real)| ≤ (absBound : Real) := by + have hsup' : + ratToReal (max |lo j| |hi j|) ≤ ratToReal absBound := + ratToReal_le_of_le hsup + simpa [ratToReal_abs, ratToReal_max, ratToReal_def] using hsup' + exact le_trans hbound hsup_real + have hln := + layerNormAbsBounds_spec_real eps gamma beta absBound x hne heps hsqrt habs + have hlo_ln : ∀ j, (ln.1 j : Real) ≤ layerNormRealOfReal eps gamma beta x j := fun j 
=> (hln j).1 + have hhi_ln : ∀ j, layerNormRealOfReal eps gamma beta x j ≤ (ln.2 j : Real) := fun j => (hln j).2 + have hmlp := mlpBounds_spec wIn bIn wOut bOut ln.1 ln.2 + (layerNormRealOfReal eps gamma beta x) hlo_ln hhi_ln + simpa [bounds, layerNormAbsMlpBounds, absBound, ln] using hmlp i + +/-- Add residual inputs to interval bounds. -/ +def residualAddBounds {n : Nat} (x : Fin n → Rat) (lo hi : Fin n → Rat) : + (Fin n → Rat) × (Fin n → Rat) := + (fun i => x i + lo i, fun i => x i + hi i) + +/-- `residualAddBounds` soundness for residual addition. -/ +theorem residualAddBounds_spec {n : Nat} (x : Fin n → Rat) + (lo hi : Fin n → Rat) (y : Fin n → Real) + (hlo : ∀ i, (lo i : Real) ≤ y i) (hhi : ∀ i, y i ≤ (hi i : Real)) : + let bounds := residualAddBounds x lo hi + ∀ i, (bounds.1 i : Real) ≤ (x i : Real) + y i ∧ + (x i : Real) + y i ≤ (bounds.2 i : Real) := by + intro bounds i + have hlow : (x i : Real) + (lo i : Real) ≤ (x i : Real) + y i := by + simpa [add_comm] using + add_le_add_left (hlo i) (x i : Real) + have hhigh : (x i : Real) + y i ≤ (x i : Real) + (hi i : Real) := by + simpa [add_comm] using + add_le_add_left (hhi i) (x i : Real) + constructor + · simpa [bounds, residualAddBounds] using hlow + · simpa [bounds, residualAddBounds] using hhigh + +/-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add). -/ +def layerNormMlpResidualBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + residualAddBounds x mlp.1 mlp.2 + +/-- `layerNormMlpResidualBounds` soundness for the MLP residual path. 
-/ +theorem layerNormMlpResidualBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (x : Fin n → Rat) (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) : + let bounds := layerNormMlpResidualBounds eps gamma beta wIn bIn wOut bOut x + ∀ i, + (bounds.1 i : Real) ≤ + (x i : Real) + + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ∧ + (x i : Real) + + mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds i + let mlp := layerNormMlpBounds eps gamma beta wIn bIn wOut bOut x + have hmlp := layerNormMlpBounds_spec eps gamma beta wIn bIn wOut bOut x hne heps hsqrt + have hres := residualAddBounds_spec x mlp.1 mlp.2 + (mlpReal wIn bIn wOut bOut (layerNormReal eps gamma beta x)) + (fun j => (hmlp j).1) (fun j => (hmlp j).2) + simpa [bounds, layerNormMlpResidualBounds, mlp] using hres i + +/-- Interval bounds for a full MLP residual path (LayerNorm + MLP + residual add) from intervals. -/ +def layerNormAbsMlpResidualBounds {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) : (Fin n → Rat) × (Fin n → Rat) := + let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + (fun i => lo i + mlp.1 i, fun i => hi i + mlp.2 i) + +/-- `layerNormAbsMlpResidualBounds` soundness for the MLP residual path. 
-/ +theorem layerNormAbsMlpResidualBounds_spec {n hidden : Nat} + (eps : Rat) (gamma beta : Fin n → Rat) + (wIn : Fin n → Fin hidden → Rat) (bIn : Fin hidden → Rat) + (wOut : Fin hidden → Fin n → Rat) (bOut : Fin n → Rat) + (lo hi : Fin n → Rat) (x : Fin n → Real) + (hne : n ≠ 0) (heps : 0 < eps) (hsqrt : 0 < sqrtLower eps) + (hlo : ∀ i, (lo i : Real) ≤ x i) (hhi : ∀ i, x i ≤ (hi i : Real)) : + let bounds := layerNormAbsMlpResidualBounds eps gamma beta wIn bIn wOut bOut lo hi + ∀ i, + (bounds.1 i : Real) ≤ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ∧ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ + (bounds.2 i : Real) := by + classical + intro bounds i + let mlp := layerNormAbsMlpBounds eps gamma beta wIn bIn wOut bOut lo hi + have hmlp := + layerNormAbsMlpBounds_spec eps gamma beta wIn bIn wOut bOut lo hi x hne heps hsqrt hlo hhi + have hlo' := (hmlp i).1 + have hhi' := (hmlp i).2 + have hlow : + (lo i : Real) + (mlp.1 i : Real) ≤ + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i := by + exact add_le_add (hlo i) hlo' + have hhigh : + x i + mlpReal wIn bIn wOut bOut (layerNormRealOfReal eps gamma beta x) i ≤ + (hi i : Real) + (mlp.2 i : Real) := by + exact add_le_add (hhi i) hhi' + constructor + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hlow + · simpa [bounds, layerNormAbsMlpResidualBounds, mlp] using hhigh + +end Bounds + + +end Nfp diff --git a/Nfp/Bounds/UnnormRat.lean b/Nfp/Bounds/UnnormRat.lean new file mode 100644 index 0000000..1c3b5f6 --- /dev/null +++ b/Nfp/Bounds/UnnormRat.lean @@ -0,0 +1,62 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic +public import Nfp.Linear.FinFold + +/-! +Unnormalized rational arithmetic. + +Rat values already avoid gcd normalization, so this module provides a +lightweight alias and helper API used by older code paths. 
+-/ + +public section + +namespace Nfp + + +namespace Bounds + +/-- Unnormalized rational value (alias). -/ +abbrev UnnormRat := Rat + +/-- Interpret an unnormalized rational as a rational. -/ +def UnnormRat.toRat (q : UnnormRat) : Rat := + q + +/-- Embed a rational as an unnormalized rational. -/ +def UnnormRat.ofRat (q : Rat) : UnnormRat := + q + +/-- Unnormalized zero. -/ +def UnnormRat.zero : UnnormRat := 0 + +/-- Unnormalized addition. -/ +def UnnormRat.add (a b : UnnormRat) : UnnormRat := + a + b + +/-- Unnormalized multiplication. -/ +def UnnormRat.mul (a b : UnnormRat) : UnnormRat := + a * b + +/-- `toRat` respects multiplication. -/ +theorem UnnormRat.toRat_mul_ofRat (a b : Rat) : + UnnormRat.toRat (UnnormRat.mul (UnnormRat.ofRat a) (UnnormRat.ofRat b)) = a * b := by + rfl + +/-- Tail-recursive sum of unnormalized rationals. -/ +def UnnormRat.sumFin (n : Nat) (f : Fin n → UnnormRat) : UnnormRat := + Linear.sumFin n f + +/-- `toRat` commutes with `sumFin`. -/ +theorem UnnormRat.toRat_sumFin (n : Nat) (f : Fin n → UnnormRat) : + UnnormRat.toRat (UnnormRat.sumFin n f) = + Linear.sumFin n (fun i => UnnormRat.toRat (f i)) := by + rfl + +end Bounds + + +end Nfp diff --git a/Nfp/Circuit.lean b/Nfp/Circuit.lean new file mode 100644 index 0000000..ae32604 --- /dev/null +++ b/Nfp/Circuit.lean @@ -0,0 +1,19 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Basic +public import Nfp.Circuit.Combinators +public import Nfp.Circuit.Interface +public import Nfp.Circuit.Semantics +public import Nfp.Circuit.WellFormed +public import Nfp.Circuit.Cert +public import Nfp.Circuit.Typed +public import Nfp.Circuit.Compose +public import Nfp.Circuit.Gates +public import Nfp.Circuit.Tensor +public import Nfp.Circuit.Layers + +/-! +Circuit definitions, semantics, and equivalence checking. 
+-/ diff --git a/Nfp/Circuit/Basic.lean b/Nfp/Circuit/Basic.lean new file mode 100644 index 0000000..c796bb3 --- /dev/null +++ b/Nfp/Circuit/Basic.lean @@ -0,0 +1,45 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.System.Dag + +/-! +Circuit foundations: a DAG with designated inputs/outputs and gate semantics. +-/ + +public section + +namespace Nfp + +universe u v + +/-- A finite circuit on a DAG with designated inputs/outputs and per-node gate semantics. -/ +structure Circuit (ι : Type u) [Fintype ι] (α : Type v) where + /-- The underlying DAG that orders dependencies. -/ + dag : Dag ι + /-- Input nodes read from the external assignment. -/ + inputs : Finset ι + /-- Output nodes observed after evaluation. -/ + outputs : Finset ι + /-- Gate semantics at each node, given values of its parents. -/ + gate : ∀ i, (∀ j, dag.rel j i → α) → α + +namespace Circuit + +variable {ι : Type u} [Fintype ι] {α : Type v} + +/-- External input assignment on the circuit's input nodes. -/ +abbrev InputAssignment (C : Circuit ι α) : Type (max u v) := + { i // i ∈ C.inputs } → α + +/-- Reinterpret input assignments along an equality of input sets. -/ +def InputAssignment.cast {C₁ C₂ : Circuit ι α} (h : C₁.inputs = C₂.inputs) : + InputAssignment C₁ → InputAssignment C₂ := by + intro input i + refine input ⟨i.1, ?_⟩ + simp [h] + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert.lean b/Nfp/Circuit/Cert.lean new file mode 100644 index 0000000..e078289 --- /dev/null +++ b/Nfp/Circuit/Cert.lean @@ -0,0 +1,13 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Cert.InductionHead +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange + +/-! +Certificate definitions and checkers for circuits. 
+-/ diff --git a/Nfp/Circuit/Cert/Basic.lean b/Nfp/Circuit/Cert/Basic.lean new file mode 100644 index 0000000..7636610 --- /dev/null +++ b/Nfp/Circuit/Cert/Basic.lean @@ -0,0 +1,219 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Fold +public import Mathlib.Data.Finset.Insert +public import Mathlib.Data.Fintype.Pi +public import Nfp.Circuit.Interface +public import Nfp.Circuit.Semantics + +/-! +Circuit equivalence and a finite checker. +-/ + +public section + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- Circuits share the same input/output interface. -/ +def SameInterface (C₁ C₂ : Circuit ι α) : Prop := + C₁.inputs = C₂.inputs ∧ C₁.outputs = C₂.outputs + +/-- `SameInterface` is decidable. -/ +private instance (C₁ C₂ : Circuit ι α) : Decidable (SameInterface C₁ C₂) := by + dsimp [SameInterface] + infer_instance + +/-- Circuits agree on outputs for all input assignments on a fixed interface. -/ +def EquivOn (C₁ C₂ : Circuit ι α) (h : SameInterface C₁ C₂) : Prop := + ∀ input : C₁.InputAssignment, ∀ i ∈ C₁.outputs, + evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i + +/-- Circuits are equivalent if they share an interface and agree on all inputs. -/ +def Equiv (C₁ C₂ : Circuit ι α) : Prop := + ∃ h : SameInterface C₁ C₂, EquivOn C₁ C₂ h + +section Interface + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Circuits agree on outputs for all typed inputs on a shared interface. 
-/ +def EquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) : Prop := + ∀ input : ι_in → α, ∀ o : ι_out, I₁.eval input o = I₂.eval input o + +end Interface + +section + +local instance : Std.Commutative (α := Bool) (· && ·) := ⟨Bool.and_comm⟩ +local instance : Std.Associative (α := Bool) (· && ·) := ⟨Bool.and_assoc⟩ + +/-- Boolean `all` over a finset (tail-recursive fold over the multiset). -/ +def finsetAll {β : Type v} (s : Finset β) (p : β → Bool) : Bool := by + classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + exact Multiset.foldl (f := f) (b := true) s.1 + +theorem finsetAll_eq_true_iff {β : Type v} {s : Finset β} {p : β → Bool} : + finsetAll s p = true ↔ ∀ a ∈ s, p a = true := by + classical + let f : Bool → β → Bool := fun acc a => acc && p a + have hf : RightCommutative f := by + refine ⟨?_⟩ + intro b a c + calc + f (f b a) c = ((b && p a) && p c) := rfl + _ = (b && (p a && p c)) := by simp [Bool.and_assoc] + _ = (b && (p c && p a)) := by simp [Bool.and_comm] + _ = ((b && p c) && p a) := by simp [Bool.and_assoc] + _ = f (f b c) a := rfl + let _ : RightCommutative f := hf + have hfoldl : + ∀ (s : Multiset β) (acc : Bool), + Multiset.foldl (f := f) (b := acc) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + intro s acc + revert acc + refine Multiset.induction_on s ?h0 ?hcons + · intro acc + simp [Multiset.foldl_zero] + · intro a s ih acc + have ih_acc : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + (acc && p a) = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := acc && p 
a)) + have ih_pa : + Multiset.foldl (f := f) (b := p a) s = true ↔ + p a = true ∧ Multiset.foldl (f := f) (b := true) s = true := by + simpa using (ih (acc := p a)) + have hgoal : + Multiset.foldl (f := f) (b := acc && p a) s = true ↔ + acc = true ∧ Multiset.foldl (f := f) (b := p a) s = true := by + constructor + · intro h + have haccpa := ih_acc.mp h + have haccpa' : acc = true ∧ p a = true := by + simpa [Bool.and_eq_true] using haccpa.1 + have hacc : acc = true := haccpa'.1 + have hpa : p a = true := haccpa'.2 + have hfold : Multiset.foldl (f := f) (b := p a) s = true := + ih_pa.mpr ⟨hpa, haccpa.2⟩ + exact ⟨hacc, hfold⟩ + · intro h + rcases h with ⟨hacc, hfold⟩ + have hpa := ih_pa.mp hfold + have haccpa : (acc && p a) = true := by + simpa [Bool.and_eq_true] using And.intro hacc hpa.1 + exact ih_acc.mpr ⟨haccpa, hpa.2⟩ + simpa [Multiset.foldl_cons, f] using hgoal + induction s using Finset.induction_on with + | empty => + simp [finsetAll, Multiset.foldl_zero] + | @insert a s ha ih => + have hfold : + finsetAll (insert a s) p = true ↔ + p a = true ∧ finsetAll s p = true := by + have hval : (insert a s).1 = a ::ₘ s.1 := by + simpa using (Finset.insert_val_of_notMem (a := a) (s := s) ha) + calc + finsetAll (insert a s) p = true ↔ + Multiset.foldl (f := f) (b := true) (insert a s).1 = true := by + simp [finsetAll, f] + _ ↔ Multiset.foldl (f := f) (b := true) (a ::ₘ s.1) = true := by + simp [hval] + _ ↔ Multiset.foldl (f := f) (b := f true a) s.1 = true := by + simp [Multiset.foldl_cons] + _ ↔ Multiset.foldl (f := f) (b := p a) s.1 = true := by + simp [f] + _ ↔ p a = true ∧ Multiset.foldl (f := f) (b := true) s.1 = true := by + simpa using (hfoldl (s := s.1) (acc := p a)) + _ ↔ p a = true ∧ finsetAll s p = true := by + simp [finsetAll, f] + have hfold' : + finsetAll (insert a s) p = true ↔ p a = true ∧ finsetAll s p = true := hfold + simpa [Finset.forall_mem_insert, ih] using hfold' + +/-- Boolean check for interface equality. 
-/ +def sameInterface (C₁ C₂ : Circuit ι α) : Bool := + decide (C₁.inputs = C₂.inputs) && decide (C₁.outputs = C₂.outputs) + +theorem sameInterface_eq_true_iff (C₁ C₂ : Circuit ι α) : + sameInterface C₁ C₂ = true ↔ SameInterface C₁ C₂ := by + simp [sameInterface, SameInterface, Bool.and_eq_true] + +/-- Decide equivalence by enumerating all input assignments on a finite value type. -/ +def checkEquiv (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : Bool := + if h : SameInterface C₁ C₂ then + finsetAll (Finset.univ : Finset C₁.InputAssignment) (fun input => + finsetAll C₁.outputs (fun i => + decide (evalInput C₁ input i = evalInput C₂ (InputAssignment.cast h.1 input) i))) + else + false + +/-- `checkEquiv` is sound and complete for `Equiv`. -/ +theorem checkEquiv_eq_true_iff (C₁ C₂ : Circuit ι α) [Fintype α] [DecidableEq α] : + checkEquiv C₁ C₂ = true ↔ Equiv C₁ C₂ := by + classical + by_cases h : SameInterface C₁ C₂ + · have hcheck : checkEquiv C₁ C₂ = true ↔ EquivOn C₁ C₂ h := by + simp [checkEquiv, h, EquivOn, finsetAll_eq_true_iff] + constructor + · intro hc + exact ⟨h, hcheck.mp hc⟩ + · intro hEquiv + rcases hEquiv with ⟨h', hEq⟩ + have hh : h' = h := Subsingleton.elim _ _ + exact hcheck.mpr (by simpa [hh] using hEq) + · simp [checkEquiv, h, Equiv] + +end + +section InterfaceCheck + +variable {ι₁ : Type u} [Fintype ι₁] [DecidableEq ι₁] +variable {ι₂ : Type u'} [Fintype ι₂] [DecidableEq ι₂] +variable {ι_in : Type u_in} [Fintype ι_in] [DecidableEq ι_in] +variable {ι_out : Type u_out} [Fintype ι_out] + +/-- Decide interface-based equivalence by enumerating typed inputs. 
-/ +def checkEquivOnInterface (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : Bool := + finsetAll (Finset.univ : Finset (ι_in → α)) (fun input => + finsetAll (Finset.univ : Finset ι_out) (fun o => + decide (I₁.eval input o = I₂.eval input o))) + +/-- `checkEquivOnInterface` is sound and complete for `EquivOnInterface`. -/ +theorem checkEquivOnInterface_eq_true_iff (C₁ : Circuit ι₁ α) (C₂ : Circuit ι₂ α) + (I₁ : Interface C₁ ι_in ι_out) (I₂ : Interface C₂ ι_in ι_out) + [Fintype α] [DecidableEq α] : + checkEquivOnInterface C₁ C₂ I₁ I₂ = true ↔ EquivOnInterface C₁ C₂ I₁ I₂ := by + classical + simp [checkEquivOnInterface, EquivOnInterface, finsetAll_eq_true_iff] + +end InterfaceCheck + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/InductionHead.lean b/Nfp/Circuit/Cert/InductionHead.lean new file mode 100644 index 0000000..301b983 --- /dev/null +++ b/Nfp/Circuit/Cert/InductionHead.lean @@ -0,0 +1,310 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.Circuit.Layers.Induction + +/-! +Induction-head certificates with explicit scores, weights, and value bounds. +-/ + +public section + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Certificate payload for value-interval bounds (Rat-valued). -/ +structure ValueIntervalCert (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Lower bounds on per-key values. -/ + valsLo : Fin seq → Rat + /-- Upper bounds on per-key values. -/ + valsHi : Fin seq → Rat + /-- Exact per-key values. -/ + vals : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). 
-/ + direction : Option DirectionSpec + +/-- Internal consistency predicate for value-interval certificates. -/ +structure ValueIntervalCertBounds {seq : Nat} (c : ValueIntervalCert seq) : Prop where + /-- Interval endpoints are ordered. -/ + lo_le_hi : c.lo ≤ c.hi + /-- `lo` is below every lower bound. -/ + lo_le_valsLo : ∀ k, c.lo ≤ c.valsLo k + /-- Lower bounds are below the values. -/ + valsLo_le_vals : ∀ k, c.valsLo k ≤ c.vals k + /-- Values are below the upper bounds. -/ + vals_le_valsHi : ∀ k, c.vals k ≤ c.valsHi k + /-- Upper bounds are below `hi`. -/ + valsHi_le_hi : ∀ k, c.valsHi k ≤ c.hi + +/-- Boolean checker for value-interval certificates. -/ +def checkValueIntervalCert [NeZero seq] (c : ValueIntervalCert seq) : Bool := + decide (c.lo ≤ c.hi) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.valsLo k) && + decide (c.valsLo k ≤ c.vals k) && + decide (c.vals k ≤ c.valsHi k) && + decide (c.valsHi k ≤ c.hi)) + +/-- `checkValueIntervalCert` is sound for `ValueIntervalCertBounds`. 
-/ +theorem checkValueIntervalCert_sound [NeZero seq] (c : ValueIntervalCert seq) : + checkValueIntervalCert c = true → ValueIntervalCertBounds c := by + classical + intro hcheck + have hcheck' : + decide (c.lo ≤ c.hi) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.valsLo k) && + decide (c.valsLo k ≤ c.vals k) && + decide (c.vals k ≤ c.valsHi k) && + decide (c.valsHi k ≤ c.hi)) = true := by + simpa [checkValueIntervalCert, Bool.and_eq_true] using hcheck + rcases hcheck' with ⟨hlohi, hall⟩ + have hlohi' : c.lo ≤ c.hi := by + simpa [decide_eq_true_iff] using hlohi + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hall + have hbounds : + ∀ k, c.lo ≤ c.valsLo k ∧ c.valsLo k ≤ c.vals k ∧ + c.vals k ≤ c.valsHi k ∧ c.valsHi k ≤ c.hi := by + intro k + have hk := hall' k (by simp) + have hk' : + decide (c.lo ≤ c.valsLo k) = true ∧ + decide (c.valsLo k ≤ c.vals k) = true ∧ + decide (c.vals k ≤ c.valsHi k) = true ∧ + decide (c.valsHi k ≤ c.hi) = true := by + simpa [Bool.and_eq_true, and_assoc] using hk + rcases hk' with ⟨hlo, hloVals, hvalsHi, hhi⟩ + refine ⟨?_, ?_, ?_, ?_⟩ + · simpa [decide_eq_true_iff] using hlo + · simpa [decide_eq_true_iff] using hloVals + · simpa [decide_eq_true_iff] using hvalsHi + · simpa [decide_eq_true_iff] using hhi + refine + { lo_le_hi := hlohi' + lo_le_valsLo := fun k => (hbounds k).1 + valsLo_le_vals := fun k => (hbounds k).2.1 + vals_le_valsHi := fun k => (hbounds k).2.2.1 + valsHi_le_hi := fun k => (hbounds k).2.2.2 } + +/-- Certificate payload for induction-head bounds (Rat-valued). -/ +structure InductionHeadCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Per-query weight tolerance. -/ + epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat + /-- Score margin used to justify weight bounds. -/ + margin : Rat + /-- Active queries for which bounds are checked. 
-/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Attention weight entries. -/ + weights : Fin seq → Fin seq → Rat + /-- Value-interval certificate for direction values. -/ + values : ValueIntervalCert seq + +/-- View an induction certificate as a softmax-margin certificate. -/ +def InductionHeadCert.softmaxMargin (c : InductionHeadCert seq) : SoftmaxMarginCert seq := + { eps := c.eps + margin := c.margin + active := c.active + prev := c.prev + scores := c.scores + weights := c.weights } + +private def weightsOkAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ weights q k) && + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q))) + +private def checkOneHotAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : Bool := + weightsOkAt prev weights epsAt q && + decide (1 ≤ weights q (prev q) + epsAt q) && + decide ((∑ k, weights q k) = 1) + +private def checkWeightBoundsAt [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) (q : Fin seq) : + Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = prev q then + true + else + decide (weights q k ≤ weightBoundAt q k)) + +/-- `checkOneHotAt` yields per-query approximate one-hot bounds. 
-/ +private theorem checkOneHotAt_sound [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (epsAt : Fin seq → Rat) (q : Fin seq) : + checkOneHotAt prev weights epsAt q = true → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (epsAt q : Rat) + (fun q' => q' = q) prev weights := by + classical + intro hOneHot + have hOneHot' : + weightsOkAt prev weights epsAt q = true ∧ + decide (1 ≤ weights q (prev q) + epsAt q) = true ∧ + decide ((∑ k, weights q k) = 1) = true := by + simpa [checkOneHotAt, Bool.and_eq_true, and_assoc] using hOneHot + rcases hOneHot' with ⟨hweights, hprev, hsum⟩ + have hweights' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + cases hq' + have hk := hweights' k (by simp) + have hk' : + decide (0 ≤ weights q k) = true ∧ + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q)) = true := by + simpa [Bool.and_eq_true] using hk + simpa [decide_eq_true_iff] using hk'.1 + · intro q' hq' + cases hq' + simpa [decide_eq_true_iff] using hsum + · intro q' hq' + cases hq' + simpa [decide_eq_true_iff] using hprev + · intro q' hq' k hk + cases hq' + have hk' := hweights' k (by simp) + have hk'' : + decide (0 ≤ weights q k) = true ∧ + (if k = prev q then + true + else + decide (weights q k ≤ epsAt q)) = true := by + simpa [Bool.and_eq_true] using hk' + have hother : + decide (weights q k ≤ epsAt q) = true := by + simpa [hk] using hk''.2 + simpa [decide_eq_true_iff] using hother + +/-- `checkWeightBoundsAt` yields per-key upper bounds on non-`prev` weights. 
-/ +private theorem checkWeightBoundsAt_sound [NeZero seq] (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Rat) (weightBoundAt : Fin seq → Fin seq → Rat) (q : Fin seq) : + checkWeightBoundsAt prev weights weightBoundAt q = true → + ∀ k, k ≠ prev q → weights q k ≤ weightBoundAt q k := by + classical + intro hweights k hk + have hweights' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + have hk' := hweights' k (by simp) + have hk'' : decide (weights q k ≤ weightBoundAt q k) = true := by + simpa [hk] using hk' + simpa [decide_eq_true_iff] using hk'' + +/-- Boolean checker for induction-head certificates. -/ +def checkInductionHeadCert [NeZero seq] (c : InductionHeadCert seq) : Bool := + checkSoftmaxMarginCert c.softmaxMargin && + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + checkOneHotAt c.prev c.weights c.epsAt q && + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q + else + true) && + checkValueIntervalCert c.values + +/-- Soundness predicate for induction-head certificates. -/ +structure InductionHeadCertBounds [NeZero seq] (c : InductionHeadCert seq) : Prop where + /-- Softmax-margin bounds on active queries. -/ + softmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights + /-- Per-query one-hot bounds for the weights. -/ + oneHot_bounds_at : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (c.epsAt q : Rat) + (fun q' => q' = q) c.prev c.weights + /-- Per-key weight bounds for non-`prev` keys. -/ + weight_bounds_at : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + c.weights q k ≤ c.weightBoundAt q k + /-- Value-interval bounds are internally consistent. -/ + value_bounds : ValueIntervalCertBounds c.values + +/-- `checkInductionHeadCert` is sound for `InductionHeadCertBounds`. 
-/ +theorem checkInductionHeadCert_sound [NeZero seq] (c : InductionHeadCert seq) : + checkInductionHeadCert c = true → InductionHeadCertBounds c := by + classical + intro hcheck + have hsplit : + checkSoftmaxMarginCert c.softmaxMargin = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + checkOneHotAt c.prev c.weights c.epsAt q && + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q + else + true) = true ∧ + checkValueIntervalCert c.values = true := by + simpa [checkInductionHeadCert, Bool.and_eq_true, and_assoc] using hcheck + rcases hsplit with ⟨hsoftmax, hactive, hvalues⟩ + have hsoftmax' : + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights := + checkSoftmaxMarginCert_sound c.softmaxMargin hsoftmax + have hactive' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hactive + have honehot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Rat) (c.epsAt q : Rat) + (fun q' => q' = q) c.prev c.weights := by + intro q hq + have hq' := hactive' q (by simp) + have hq'' : + checkOneHotAt c.prev c.weights c.epsAt q = true ∧ + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := by + simpa [hq, Bool.and_eq_true] using hq' + have hOneHot : checkOneHotAt c.prev c.weights c.epsAt q = true := hq''.1 + exact checkOneHotAt_sound c.prev c.weights c.epsAt q hOneHot + have hweightBounds : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + c.weights q k ≤ c.weightBoundAt q k := by + intro q hq k hk + have hq' := hactive' q (by simp) + have hq'' : + checkOneHotAt c.prev c.weights c.epsAt q = true ∧ + checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := by + simpa [hq, Bool.and_eq_true] using hq' + have hweights : checkWeightBoundsAt c.prev c.weights c.weightBoundAt q = true := hq''.2 + exact checkWeightBoundsAt_sound c.prev c.weights c.weightBoundAt q hweights k hk + have hvals : ValueIntervalCertBounds c.values := + 
checkValueIntervalCert_sound c.values hvalues + exact + { softmax_bounds := hsoftmax' + oneHot_bounds_at := honehot + weight_bounds_at := hweightBounds + value_bounds := hvals } + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/LogitDiff.lean b/Nfp/Circuit/Cert/LogitDiff.lean new file mode 100644 index 0000000..8eecac8 --- /dev/null +++ b/Nfp/Circuit/Cert/LogitDiff.lean @@ -0,0 +1,208 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic +public import Mathlib.Data.Finset.Lattice.Fold +public import Nfp.Circuit.Layers.Induction + +/-! +Lower bounds for logit-diff contributions from induction-style heads. +-/ + +public section + +namespace Nfp + +namespace Circuit + +variable {seq : Nat} + +/-- Compute a lower bound on the logit-diff contribution over active queries. -/ +def logitDiffLowerBound (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (eps lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let gap := eps * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap + exact some (active.inf' h f) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution with per-query eps. -/ +def logitDiffLowerBoundAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q + exact some (active.inf' h f) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution using per-query eps and the global + lower value bound. 
-/ +def logitDiffLowerBoundAtLo (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo : Rat) (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let f : Fin seq → Rat := fun q => + valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) + exact some (active.inf' h f) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution using per-query eps and per-query + lower bounds for other values. -/ +def logitDiffLowerBoundAtLoAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (loAt : Fin seq → Rat) (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let f : Fin seq → Rat := fun q => + let delta := valsLo (prev q) - loAt q + valsLo (prev q) - epsAt q * max (0 : Rat) delta + exact some (active.inf' h f) + else + exact none + +/-- Compute a lower bound on the logit-diff contribution using per-key weight bounds and + per-key value lower bounds. -/ +def logitDiffLowerBoundWeightedAt (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) : Option Rat := by + classical + if h : active.Nonempty then + let gap : Fin seq → Rat := fun q => + (Finset.univ : Finset (Fin seq)).sum (fun k => + let diff := valsLo (prev q) - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + exact some (active.inf' h f) + else + exact none + +/-- Unfolding lemma for `logitDiffLowerBoundWeightedAt`. 
-/ +theorem logitDiffLowerBoundWeightedAt_def (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) : + logitDiffLowerBoundWeightedAt active prev weightBoundAt valsLo = + by + classical + if h : active.Nonempty then + let gap : Fin seq → Rat := fun q => + (Finset.univ : Finset (Fin seq)).sum (fun k => + let diff := valsLo (prev q) - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + exact some (active.inf' h f) + else + exact none := by + classical + rfl + +/-- The computed lower bound is below every active `prev` value minus the tolerance gap. -/ +theorem logitDiffLowerBound_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (eps lo hi : Rat) (vals : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBound active prev eps lo hi vals = some lb → + lb ≤ vals (prev q) - eps * (hi - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let gap := eps * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBound, hnonempty, f, gap] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, gap, hbound'] using hmin + +/-- The per-query lower bound is below every active `prev` value minus the local gap. 
-/ +theorem logitDiffLowerBoundAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo hi : Rat) (vals : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAt active prev epsAt lo hi vals = some lb → + lb ≤ vals (prev q) - epsAt q * (hi - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let gap : Fin seq → Rat := fun q => epsAt q * (hi - lo) + let f : Fin seq → Rat := fun q => vals (prev q) - gap q + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBoundAt, hnonempty, f, gap] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, gap, hbound'] using hmin + +/-- The per-query lower bound is below every active `prev` value minus the `lo`-gap. -/ +theorem logitDiffLowerBoundAtLo_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (lo : Rat) (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAtLo active prev epsAt lo valsLo = some lb → + lb ≤ valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let f : Fin seq → Rat := fun q => + valsLo (prev q) - epsAt q * (valsLo (prev q) - lo) + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBoundAtLo, hnonempty, f] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, hbound'] using hmin + +/-- The per-query lower bound is below every active `prev` value minus the local `loAt` gap. 
-/ +theorem logitDiffLowerBoundAtLoAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (epsAt : Fin seq → Rat) (loAt : Fin seq → Rat) (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundAtLoAt active prev epsAt loAt valsLo = some lb → + lb ≤ valsLo (prev q) - epsAt q * max (0 : Rat) (valsLo (prev q) - loAt q) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let f : Fin seq → Rat := fun q => + let delta := valsLo (prev q) - loAt q + valsLo (prev q) - epsAt q * max (0 : Rat) delta + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBoundAtLoAt, hnonempty, f] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, hbound'] using hmin + +/-- The weighted lower bound is below every active `prev` value minus the weighted gap. -/ +theorem logitDiffLowerBoundWeightedAt_le (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (weightBoundAt : Fin seq → Fin seq → Rat) + (valsLo : Fin seq → Rat) + (q : Fin seq) (hq : q ∈ active) : + ∀ lb, logitDiffLowerBoundWeightedAt active prev weightBoundAt valsLo = some lb → + lb ≤ + valsLo (prev q) - + (Finset.univ : Finset (Fin seq)).sum (fun k => + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) := by + classical + intro lb hbound + have hnonempty : active.Nonempty := ⟨q, hq⟩ + let gap : Fin seq → Rat := fun q => + (Finset.univ : Finset (Fin seq)).sum (fun k => + weightBoundAt q k * max (0 : Rat) (valsLo (prev q) - valsLo k)) + let f : Fin seq → Rat := fun q => valsLo (prev q) - gap q + have hbound' : active.inf' hnonempty f = lb := by + simpa [logitDiffLowerBoundWeightedAt, hnonempty, f, gap] using hbound + have hmin : active.inf' hnonempty f ≤ f q := + Finset.inf'_le (s := active) (f := f) hq + simpa [f, gap, hbound'] using hmin + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/SoftmaxMargin.lean b/Nfp/Circuit/Cert/SoftmaxMargin.lean new file 
mode 100644 index 0000000..63cd189 --- /dev/null +++ b/Nfp/Circuit/Cert/SoftmaxMargin.lean @@ -0,0 +1,165 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Layers.Induction + +/-! +Softmax-margin certificates for approximate one-hot attention weights. +-/ + +public section + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Certificate payload for softmax-margin bounds (Rat-valued). -/ +structure SoftmaxMarginCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Score margin used to justify weight bounds. -/ + margin : Rat + /-- Active queries for which bounds are checked. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Score matrix entries. -/ + scores : Fin seq → Fin seq → Rat + /-- Attention weight entries. -/ + weights : Fin seq → Fin seq → Rat + +/-- Boolean checker for softmax-margin certificates. -/ +def checkSoftmaxMarginCert [NeZero seq] (c : SoftmaxMarginCert seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) + else + true) + +/-- `checkSoftmaxMarginCert` is sound for `SoftmaxMarginBoundsOn`. 
-/ +theorem checkSoftmaxMarginCert_sound [NeZero seq] (c : SoftmaxMarginCert seq) : + checkSoftmaxMarginCert c = true → + Layers.SoftmaxMarginBoundsOn (Val := Rat) c.eps c.margin (fun q => q ∈ c.active) + c.prev c.scores c.weights := by + classical + intro hcheck + let weightsOk (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (0 ≤ c.weights q k) && + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps))) + let scoresOk (q : Fin seq) : Bool := + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + if k = c.prev q then + true + else + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q))) + have hqall : + ∀ q ∈ (Finset.univ : Finset (Fin seq)), + (if q ∈ c.active then + weightsOk q && + scoresOk q && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) + else + true) = true := by + have hcheck'' : + finsetAll (Finset.univ : Finset (Fin seq)) (fun q => + if q ∈ c.active then + weightsOk q && + scoresOk q && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) + else + true) = true := by + simpa [checkSoftmaxMarginCert, weightsOk, scoresOk] using hcheck + exact (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hcheck'' + have hqchecks {q} (hq : q ∈ c.active) : + weightsOk q = true ∧ + scoresOk q = true ∧ + decide (1 ≤ c.weights q (c.prev q) + c.eps) = true ∧ + decide ((∑ k, c.weights q k) = 1) = true := by + have hqall' := hqall q (by simp) + have hqall'' : + weightsOk q && + scoresOk q && + decide (1 ≤ c.weights q (c.prev q) + c.eps) && + decide ((∑ k, c.weights q k) = 1) = true := by + simpa [hq] using hqall' + simpa [Bool.and_eq_true, and_assoc] using hqall'' + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + rcases hqchecks hq with ⟨_, hscore, _, _⟩ + have hscoreall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hscore + have 
hscorek := hscoreall k (by simp) + have hscorek' : + decide (c.scores q k + c.margin ≤ c.scores q (c.prev q)) = true := by + simpa [hk] using hscorek + simpa [decide_eq_true_iff] using hscorek' + · intro q hq k + rcases hqchecks hq with ⟨hweights, _, _, _⟩ + have hweightsall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + have hweightsk := hweightsall k (by simp) + have hweightsk' : + decide (0 ≤ c.weights q k) = true ∧ + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps)) = true := by + simpa [Bool.and_eq_true] using hweightsk + simpa [decide_eq_true_iff] using hweightsk'.1 + · intro q hq + rcases hqchecks hq with ⟨_, _, _, hsum⟩ + simpa [decide_eq_true_iff] using hsum + · intro q hq + rcases hqchecks hq with ⟨_, _, hprev, _⟩ + simpa [decide_eq_true_iff] using hprev + · intro q hq k hk + rcases hqchecks hq with ⟨hweights, _, _, _⟩ + have hweightsall := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hweights + have hweightsk := hweightsall k (by simp) + have hweightsk' : + decide (0 ≤ c.weights q k) = true ∧ + (if k = c.prev q then + true + else + decide (c.weights q k ≤ c.eps)) = true := by + simpa [Bool.and_eq_true] using hweightsk + have hother : + decide (c.weights q k ≤ c.eps) = true := by + simpa [hk] using hweightsk'.2 + simpa [decide_eq_true_iff] using hother + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Cert/ValueRange.lean b/Nfp/Circuit/Cert/ValueRange.lean new file mode 100644 index 0000000..56be21d --- /dev/null +++ b/Nfp/Circuit/Cert/ValueRange.lean @@ -0,0 +1,75 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.Basic +public import Nfp.Circuit.Layers.Induction + +/-! +Value-range certificates for attention value vectors. 
+-/ + +public section + +namespace Nfp + +namespace Circuit + +open scoped BigOperators + +variable {seq : Nat} + +/-- Metadata describing a logit-diff direction (target minus negative token). -/ +structure DirectionSpec where + /-- Target token id for the logit-diff direction. -/ + target : Nat + /-- Negative token id for the logit-diff direction. -/ + negative : Nat + +/-- Certificate payload for value-range bounds (Rat-valued). -/ +structure ValueRangeCert (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Value entries. -/ + vals : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec + +/-- Boolean checker for value-range certificates. -/ +def checkValueRangeCert [NeZero seq] (c : ValueRangeCert seq) : Bool := + decide (c.lo ≤ c.hi) && + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.vals k) && decide (c.vals k ≤ c.hi)) + +/-- `checkValueRangeCert` is sound for `ValueRangeBounds`. 
-/ +theorem checkValueRangeCert_sound [NeZero seq] (c : ValueRangeCert seq) : + checkValueRangeCert c = true → + Layers.ValueRangeBounds (Val := Rat) c.lo c.hi c.vals := by + classical + intro hcheck + have hcheck' : + decide (c.lo ≤ c.hi) = true ∧ + finsetAll (Finset.univ : Finset (Fin seq)) (fun k => + decide (c.lo ≤ c.vals k) && decide (c.vals k ≤ c.hi)) = true := by + simpa [checkValueRangeCert, Bool.and_eq_true] using hcheck + rcases hcheck' with ⟨hlohi, hall⟩ + have hlohi' : c.lo ≤ c.hi := by + simpa [decide_eq_true_iff] using hlohi + have hall' := + (finsetAll_eq_true_iff (s := (Finset.univ : Finset (Fin seq)))).1 hall + have hbounds : ∀ k, c.lo ≤ c.vals k ∧ c.vals k ≤ c.hi := by + intro k + have hk := hall' k (by simp) + simpa [Bool.and_eq_true, decide_eq_true_iff] using hk + exact + { lo_le_hi := hlohi' + lo_le := fun k => (hbounds k).1 + le_hi := fun k => (hbounds k).2 } + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Combinators.lean b/Nfp/Circuit/Combinators.lean new file mode 100644 index 0000000..aaada71 --- /dev/null +++ b/Nfp/Circuit/Combinators.lean @@ -0,0 +1,63 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Logic.Equiv.Basic +public import Nfp.Circuit.Interface + +/-! +Circuit combinators such as relabeling. +-/ + +public section + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +variable {Node : Type u} [Fintype Node] +variable {Node' : Type u'} [Fintype Node'] +variable {Val : Type v} + +/-- Relabel the nodes of a circuit along an equivalence. 
-/ +def relabel (C : Circuit Node Val) (e : _root_.Equiv Node Node') : Circuit Node' Val := by + refine + { dag := C.dag.relabel e + inputs := C.inputs.map e.toEmbedding + outputs := C.outputs.map e.toEmbedding + gate := ?_ } + intro i rec + refine C.gate (e.symm i) ?_ + intro j h + refine rec (e j) ?_ + simpa [Dag.relabel_rel_iff] using h + +namespace Interface + +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Node' : Type u'} [Fintype Node'] [DecidableEq Node'] +variable {Val : Type v} +variable {Input : Type u_in} {Output : Type u_out} +variable {C : Circuit Node Val} + +/-- Relabel a circuit interface along an equivalence of nodes. -/ +def relabel (I : Interface C Input Output) (e : _root_.Equiv Node Node') : + Interface (C.relabel e) Input Output := by + refine { inputs := ?_, outputs := ?_ } + · refine I.inputs.trans ?_ + refine (e.subtypeEquiv ?_) + intro a + simp [Circuit.relabel] + · refine I.outputs.trans ?_ + refine (e.subtypeEquiv ?_) + intro a + simp [Circuit.relabel] + +end Interface + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Compose.lean b/Nfp/Circuit/Compose.lean new file mode 100644 index 0000000..ad692c0 --- /dev/null +++ b/Nfp/Circuit/Compose.lean @@ -0,0 +1,383 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Disjoint +public import Mathlib.Data.Fintype.Sum +public import Mathlib.Data.Sum.Order +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Typed + +/-! +Combinators for composing typed circuits (sequential and residual wiring). +-/ + +public section + +namespace Nfp + +universe u v u' u_in u_mid u_out + +namespace Circuit + +open Function + +section SumEquiv + +variable {Left : Type u} {Right : Type u'} + +/-- Embed a finset subtype into the left injection of a sum. 
-/ +def inlSubtypeEquiv (s : Finset Left) : + { i // i ∈ s } ≃ { i // i ∈ s.map (Embedding.inl : Left ↪ Left ⊕ Right) } := + { toFun := fun i => + ⟨Sum.inl i.1, by + refine Finset.mem_map.2 ?_ + exact ⟨i.1, i.2, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, ha⟩ => + let h' : a ∈ s := by + rcases (Finset.mem_map.1 ha) with ⟨a', ha', h⟩ + cases h + exact ha' + ⟨a, h'⟩ + | ⟨Sum.inr b, hb⟩ => + False.elim <| by + rcases (Finset.mem_map.1 hb) with ⟨a, _ha, h⟩ + cases h + left_inv := by + intro i + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => + rfl + | inr b => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, h⟩ + cases h + cases this } + +/-- Embed a finset subtype into the right injection of a sum. -/ +def inrSubtypeEquiv (s : Finset Right) : + { i // i ∈ s } ≃ { i // i ∈ s.map (Embedding.inr : Right ↪ Left ⊕ Right) } := + { toFun := fun i => + ⟨Sum.inr i.1, by + refine Finset.mem_map.2 ?_ + exact ⟨i.1, i.2, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr a, ha⟩ => + let h' : a ∈ s := by + rcases (Finset.mem_map.1 ha) with ⟨a', ha', h⟩ + cases h + exact ha' + ⟨a, h'⟩ + | ⟨Sum.inl b, hb⟩ => + False.elim <| by + rcases (Finset.mem_map.1 hb) with ⟨a, _ha, h⟩ + cases h + left_inv := by + intro i + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr a => + rfl + | inl b => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, h⟩ + cases h + cases this } + +end SumEquiv + +section Seq + +variable {Node₁ : Type u} [Fintype Node₁] +variable {Node₂ : Type u'} [Fintype Node₂] +variable {Val : Type v} +variable {Input : Type u_in} {Mid : Type u_mid} {Output : Type u_out} +variable {C1 : Circuit Node₁ Val} {C2 : Circuit Node₂ Val} +variable {I1 : Interface C1 Input Mid} {I2 : Interface C2 Mid Output} + +/-- Bridge edges from the outputs of `C1` to the inputs of `C2`. 
-/ +def seqBridge (j : Node₁) (i : Node₂) : Prop := + ∃ h : i ∈ C2.inputs, + j = (I1.outputs (I2.inputs.symm ⟨i, h⟩)).1 + +/-- Fixing an input witness reduces `seqBridge` to an equality. -/ +theorem seqBridge_iff_eq {j : Node₁} {i : Node₂} (hmem : i ∈ C2.inputs) : + seqBridge (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) j i ↔ + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := by + constructor + · rintro ⟨h, hEq⟩ + have hSubtype : + (⟨i, h⟩ : { i // i ∈ C2.inputs }) = ⟨i, hmem⟩ := by + ext + rfl + simpa [hSubtype] using hEq + · intro hEq + exact ⟨hmem, hEq⟩ + +/-- Adjacency for sequentially composed circuits. -/ +def seqAdj : Node₁ ⊕ Node₂ → Node₁ ⊕ Node₂ → Prop + | Sum.inl j, Sum.inl i => C1.dag.rel j i + | Sum.inl j, Sum.inr i => + seqBridge (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) j i + | Sum.inr j, Sum.inr i => C2.dag.rel j i + | _, _ => False + +variable [DecidableEq Node₁] [DecidableEq Node₂] + +/-- DAG for sequentially composed circuits. -/ +def seqDag : Dag (Node₁ ⊕ Node₂) := + { graph := { Adj := seqAdj (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) } + decAdj := by + intro j i + cases j with + | inl j => + cases i with + | inl i => + exact (inferInstance : Decidable (C1.dag.rel j i)) + | inr i => + by_cases hmem : i ∈ C2.inputs + · by_cases hEq : + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 + · exact isTrue ⟨hmem, hEq⟩ + · exact isFalse (by + intro h + have hEq' : + j = (I1.outputs (I2.inputs.symm ⟨i, hmem⟩)).1 := + (seqBridge_iff_eq (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) + (j := j) (i := i) hmem).1 h + exact hEq hEq') + · exact isFalse (by + intro h + exact hmem h.1) + | inr j => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + exact (inferInstance : Decidable (C2.dag.rel j i)) + wf := by + have hsub : + Subrelation + (seqAdj (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2)) + (Sum.Lex C1.dag.rel C2.dag.rel) := by + intro j i h + cases j with + | inl j => + cases i with + | inl i => + exact Sum.Lex.inl h + | inr i => + 
exact Sum.Lex.sep _ _ + | inr j => + cases i with + | inl _ => + exact False.elim h + | inr i => + exact Sum.Lex.inr h + have hwf : WellFounded (Sum.Lex C1.dag.rel C2.dag.rel) := + Sum.lex_wf C1.dag.wf C2.dag.wf + exact Subrelation.wf hsub hwf } + +/-- Sequential composition of circuits at the node level. -/ +def seqCircuit : Circuit (Node₁ ⊕ Node₂) Val := + { dag := seqDag (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2) + inputs := C1.inputs.map (Embedding.inl : Node₁ ↪ Node₁ ⊕ Node₂) + outputs := C2.outputs.map (Embedding.inr : Node₂ ↪ Node₁ ⊕ Node₂) + gate := by + intro i rec + cases i with + | inl i => + exact C1.gate i (fun j h => + rec (Sum.inl j) (by simpa using h)) + | inr i => + by_cases hinput : i ∈ C2.inputs + · let mid : Mid := I2.inputs.symm ⟨i, hinput⟩ + let out : Node₁ := (I1.outputs mid).1 + exact rec (Sum.inl out) (by + refine ⟨hinput, rfl⟩) + · exact C2.gate i (fun j h => + rec (Sum.inr j) (by simpa using h)) } + +/-- Interface for sequentially composed circuits. -/ +def seqInterface : + Interface + (seqCircuit (C1 := C1) (C2 := C2) (I1 := I1) (I2 := I2)) Input Output := + { inputs := + I1.inputs.trans (inlSubtypeEquiv (s := C1.inputs)) + outputs := + I2.outputs.trans (inrSubtypeEquiv (s := C2.outputs)) } + +end Seq + +section Residual + +variable {Node : Type u} +variable {Input : Type u_in} [Fintype Input] + +/-- Output nodes for residual wiring. -/ +def residualOutputs : Finset (Node ⊕ Input) := + (Finset.univ : Finset Input).map (Embedding.inr : Input ↪ Node ⊕ Input) + +/-- Output equivalence for residual wiring. 
-/ +def residualOutputEquiv : + Input ≃ { i // i ∈ residualOutputs (Node := Node) (Input := Input) } := + { toFun := fun o => + ⟨Sum.inr o, by + refine Finset.mem_map.2 ?_ + exact ⟨o, by simp, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr o, _⟩ => o + | ⟨Sum.inl _, h⟩ => + False.elim <| by + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr o => + rfl + | inl s => + have : False := by + rcases (Finset.mem_map.1 hs) with ⟨o, _ho, ho⟩ + cases ho + cases this } + +variable [Fintype Node] +variable {Val : Type v} [Add Val] +variable {C : Circuit Node Val} +variable {I : Interface C Input Input} + +/-- Adjacency for residual wiring on a typed circuit. -/ +def residualAdj : Node ⊕ Input → Node ⊕ Input → Prop + | Sum.inl j, Sum.inl i => C.dag.rel j i + | Sum.inl j, Sum.inr o => j = (I.inputs o).1 ∨ j = (I.outputs o).1 + | _, _ => False + +variable [DecidableEq Node] + +/-- DAG for residual wiring on a typed circuit. 
-/ +def residualDag : Dag (Node ⊕ Input) := + { graph := { Adj := residualAdj (C := C) (I := I) } + decAdj := by + intro j i + cases j with + | inl j => + cases i with + | inl i => + exact (inferInstance : Decidable (C.dag.rel j i)) + | inr o => + exact (inferInstance : + Decidable (j = (I.inputs o).1 ∨ j = (I.outputs o).1)) + | inr _ => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : + Subrelation (residualAdj (C := C) (I := I)) + (Sum.Lex C.dag.rel (fun _ _ : Input => False)) := by + intro j i h + cases j with + | inl j => + cases i with + | inl i => + exact Sum.Lex.inl h + | inr _ => + exact Sum.Lex.sep _ _ + | inr _ => + cases i with + | inl _ => + exact False.elim h + | inr _ => + exact False.elim h + have hfalse : WellFounded (fun _ _ : Input => False) := by + refine ⟨?_⟩ + intro a + refine Acc.intro a ?_ + intro b h + cases h + have hwf : WellFounded (Sum.Lex C.dag.rel (fun _ _ : Input => False)) := + Sum.lex_wf C.dag.wf hfalse + exact Subrelation.wf hsub hwf } + +/-- Circuit that adds a residual connection to a typed circuit. -/ +def residualCircuit : Circuit (Node ⊕ Input) Val := + { dag := residualDag (C := C) (I := I) + inputs := C.inputs.map (Embedding.inl : Node ↪ Node ⊕ Input) + outputs := residualOutputs (Node := Node) (Input := Input) + gate := by + intro i rec + cases i with + | inl i => + exact C.gate i (fun j h => + rec (Sum.inl j) (by simpa [residualAdj] using h)) + | inr o => + let inNode : Node := (I.inputs o).1 + let outNode : Node := (I.outputs o).1 + let inVal := rec (Sum.inl inNode) (by + change inNode = (I.inputs o).1 ∨ inNode = (I.outputs o).1 + exact Or.inl rfl) + let outVal := rec (Sum.inl outNode) (by + change outNode = (I.inputs o).1 ∨ outNode = (I.outputs o).1 + exact Or.inr rfl) + exact inVal + outVal } + +/-- Interface for residual wiring. 
-/ +def residualInterface : + Interface (residualCircuit (C := C) (I := I)) Input Input := + { inputs := I.inputs.trans (inlSubtypeEquiv (s := C.inputs)) + outputs := residualOutputEquiv (Node := Node) (Input := Input) } + +end Residual + +namespace TypedCircuit + +variable {Node₁ : Type u} [Fintype Node₁] [DecidableEq Node₁] +variable {Node₂ : Type u'} [Fintype Node₂] [DecidableEq Node₂] +variable {Val : Type v} +variable {Input : Type u_in} {Mid : Type u_mid} {Output : Type u_out} + +/-- Sequential composition of typed circuits. -/ +def seq (T₁ : TypedCircuit Node₁ Val Input Mid) + (T₂ : TypedCircuit Node₂ Val Mid Output) : + TypedCircuit (Node₁ ⊕ Node₂) Val Input Output := + { circuit := seqCircuit (C1 := T₁.circuit) (C2 := T₂.circuit) + (I1 := T₁.interface) (I2 := T₂.interface) + interface := seqInterface (C1 := T₁.circuit) (C2 := T₂.circuit) + (I1 := T₁.interface) (I2 := T₂.interface) } + +variable [Add Val] +variable [Fintype Input] +variable [DecidableEq Input] + +/-- Add a residual connection to a typed circuit. -/ +def residual (T : TypedCircuit Node₁ Val Input Input) : + TypedCircuit (Node₁ ⊕ Input) Val Input Input := + { circuit := residualCircuit (C := T.circuit) (I := T.interface) + interface := residualInterface (C := T.circuit) (I := T.interface) } + +end TypedCircuit + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Gates.lean b/Nfp/Circuit/Gates.lean new file mode 100644 index 0000000..06b6d54 --- /dev/null +++ b/Nfp/Circuit/Gates.lean @@ -0,0 +1,10 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Gates.Basic +public import Nfp.Circuit.Gates.Linear + +/-! +Gate combinators for circuit semantics. 
+-/ diff --git a/Nfp/Circuit/Gates/Basic.lean b/Nfp/Circuit/Gates/Basic.lean new file mode 100644 index 0000000..31f60c8 --- /dev/null +++ b/Nfp/Circuit/Gates/Basic.lean @@ -0,0 +1,44 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.Ring.Basic +public import Mathlib.Data.Finset.Attach +public import Mathlib.Data.Fintype.BigOperators + +/-! +Basic gate combinators for aggregating parent values. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Gates + +universe u v + +variable {Node : Type u} {Val : Type v} + +/-- Sum of parent values. -/ +def sumParents (parents : Finset Node) (rec : ∀ j, j ∈ parents → Val) + [AddCommMonoid Val] : Val := + parents.attach.sum fun j => rec j.1 j.2 + +/-- Weighted sum of parent values using weights `w`. -/ +def weightedSumParents (parents : Finset Node) (w : Node → Val) + (rec : ∀ j, j ∈ parents → Val) [Semiring Val] : Val := + parents.attach.sum fun j => w j.1 * rec j.1 j.2 + +/-- Affine combination of parent values with weights `w` and bias `b`. -/ +def affineParents (parents : Finset Node) (w : Node → Val) (b : Val) + (rec : ∀ j, j ∈ parents → Val) [Semiring Val] : Val := + weightedSumParents parents w rec + b + +end Gates + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Gates/Linear.lean b/Nfp/Circuit/Gates/Linear.lean new file mode 100644 index 0000000..6502b9b --- /dev/null +++ b/Nfp/Circuit/Gates/Linear.lean @@ -0,0 +1,37 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Matrix.Mul + +/-! +Linear and affine gate combinators built from `Matrix.mulVec`. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Gates + +universe u v + +variable {Row : Type u} {Col : Type u} {Val : Type v} + +/-- Linear map on vectors defined by a matrix. 
-/ +def linear [Fintype Row] [Fintype Col] [NonUnitalNonAssocSemiring Val] + (W : Matrix Row Col Val) (x : Col → Val) : Row → Val := + Matrix.mulVec W x + +/-- Affine map on vectors defined by a matrix and bias. -/ +def affine [Fintype Row] [Fintype Col] [NonUnitalNonAssocSemiring Val] + (W : Matrix Row Col Val) (b : Row → Val) (x : Col → Val) : Row → Val := + linear W x + b + +end Gates + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Interface.lean b/Nfp/Circuit/Interface.lean new file mode 100644 index 0000000..37ddf57 --- /dev/null +++ b/Nfp/Circuit/Interface.lean @@ -0,0 +1,62 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Semantics + +/-! +Typed input/output interfaces for circuits. +-/ + +public section + +namespace Nfp + +universe u v u_in u_out + +namespace Circuit + +variable {ι : Type u} [Fintype ι] +variable {α : Type v} + +/-- A typed input/output interface for a circuit. -/ +structure Interface (C : Circuit ι α) (ι_in : Type u_in) (ι_out : Type u_out) where + /-- Input labels correspond exactly to the circuit's input nodes. -/ + inputs : ι_in ≃ { i // i ∈ C.inputs } + /-- Output labels correspond exactly to the circuit's output nodes. -/ + outputs : ι_out ≃ { i // i ∈ C.outputs } + +namespace Interface + +variable {C : Circuit ι α} {ι_in : Type u_in} {ι_out : Type u_out} + +/-- Convert a typed input assignment into an input-node assignment. -/ +def toInputAssignment (I : Interface C ι_in ι_out) (input : ι_in → α) : + C.InputAssignment := + fun i => input (I.inputs.symm i) + +/-- Definitional characterization of `Interface.toInputAssignment`. -/ +theorem toInputAssignment_def (I : Interface C ι_in ι_out) (input : ι_in → α) : + I.toInputAssignment input = fun i => input (I.inputs.symm i) := by + rfl + +section Eval + +variable [DecidableEq ι] + +/-- Evaluate a circuit on a typed interface. 
-/ +def eval (I : Interface C ι_in ι_out) (input : ι_in → α) : ι_out → α := + fun o => evalInput C (I.toInputAssignment input) (I.outputs o).1 + +/-- Definitional characterization of `Interface.eval`. -/ +theorem eval_def (I : Interface C ι_in ι_out) (input : ι_in → α) (o : ι_out) : + I.eval input o = evalInput C (I.toInputAssignment input) (I.outputs o).1 := by + rfl + +end Eval + +end Interface + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers.lean b/Nfp/Circuit/Layers.lean new file mode 100644 index 0000000..d0e0111 --- /dev/null +++ b/Nfp/Circuit/Layers.lean @@ -0,0 +1,16 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Layers.Linear +public import Nfp.Circuit.Layers.Tensor +public import Nfp.Circuit.Layers.Reshape +public import Nfp.Circuit.Layers.Heads +public import Nfp.Circuit.Layers.Attention +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.TransformerBlock + +/-! +Circuit layer combinators. +-/ diff --git a/Nfp/Circuit/Layers/Attention.lean b/Nfp/Circuit/Layers/Attention.lean new file mode 100644 index 0000000..6005a62 --- /dev/null +++ b/Nfp/Circuit/Layers/Attention.lean @@ -0,0 +1,883 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Layers.Heads +public import Nfp.Circuit.Layers.Tensor + +/-! +QKV and output projection wiring for attention layers, plus attention score/mixing core. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +open Function + +universe v + +variable {Batch : Type} [Fintype Batch] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Hidden dimension type for a `heads × dim` factorization. -/ +abbrev Hidden (heads dim : Nat) : Type := Fin (heads * dim) + +/-- Node type for Q/K/V and output projection layers. 
-/ +abbrev QkvNode (Batch : Type) (heads dim : Nat) : Type := + BatchedLinearNode Batch (Hidden heads dim) (Hidden heads dim) + +section Projections + +variable [DecidableEq Batch] + +/-- Q projection with head-split output. -/ +def qProj (heads dim : Nat) (Wq : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wq) + +/-- K projection with head-split output. -/ +def kProj (heads dim : Nat) (Wk : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wk) + +/-- V projection with head-split output. -/ +def vProj (heads dim : Nat) (Wv : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Hidden heads dim) + (Batch × Fin heads × Fin dim) := + splitHeadsOutput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wv) + +/-- Output projection with head-merged input. -/ +def outProj (heads dim : Nat) (Wo : Matrix (Hidden heads dim) (Hidden heads dim) Val) : + TypedCircuit (QkvNode Batch heads dim) Val (Batch × Fin heads × Fin dim) + (Batch × Hidden heads dim) := + splitHeadsInput (Batch := Batch) heads dim + (batchedLinearTyped (Batch := Batch) + (Row := Hidden heads dim) (Col := Hidden heads dim) Wo) + +end Projections + +/- +Attention score/mixing core. +-/ + +variable {seq heads dim : Nat} + +/-- Index for per-token head-split vectors. 
-/ +abbrev QkvIndex (Batch : Type) (seq heads dim : Nat) : Type := + Batch × Fin seq × Fin heads × Fin dim + +/-- Index for attention score/weight entries. -/ +abbrev ScoreIndex (Batch : Type) (seq heads : Nat) : Type := + Batch × Fin heads × Fin seq × Fin seq + +/-- Index for attention weight entries (same shape as scores). -/ +abbrev WeightIndex (Batch : Type) (seq heads : Nat) : Type := + ScoreIndex Batch seq heads + +/-- Input-node labels for attention core circuits. -/ +abbrev AttentionInputNode (Batch : Type) (seq heads dim : Nat) : Type := + Sum (QkvIndex Batch seq heads dim) + (Sum (QkvIndex Batch seq heads dim) (QkvIndex Batch seq heads dim)) + +/-- Typed input labels for attention core circuits (Q/K/V). -/ +abbrev AttentionInput (Batch : Type) (seq heads dim : Nat) : Type := + AttentionInputNode Batch seq heads dim + +/-- Typed output labels for attention core circuits. -/ +abbrev AttentionOutput (Batch : Type) (seq heads dim : Nat) : Type := + QkvIndex Batch seq heads dim + +/-- Node type for the attention score/mixing core. -/ +abbrev AttentionNode (Batch : Type) (seq heads dim : Nat) : Type := + Sum (AttentionInputNode Batch seq heads dim) + (Sum (ScoreIndex Batch seq heads) (Sum (WeightIndex Batch seq heads) + (AttentionOutput Batch seq heads dim))) + +section Decidable + +variable [DecidableEq Batch] + +/-- Decidable equality for attention-core input nodes. 
-/ +instance instDecidableEqAttentionInputNode : + DecidableEq (AttentionInputNode Batch seq heads dim) := by + intro x y + cases x with + | inl qx => + cases y with + | inl qy => + simpa using (inferInstance : Decidable (qx = qy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl kx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl ky => + simpa using (inferInstance : Decidable (kx = ky)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr vx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr vy => + simpa using (inferInstance : Decidable (vx = vy)) + +/-- Decidable equality for attention-core nodes. -/ +instance instDecidableEqAttentionNode : + DecidableEq (AttentionNode Batch seq heads dim) := by + intro x y + cases x with + | inl x => + cases y with + | inl y => + simpa using (inferInstance : Decidable (x = y)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl sx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl sy => + simpa using (inferInstance : Decidable (sx = sy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr x => + cases x with + | inl wx => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl wy => + simpa using (inferInstance : Decidable (wx = wy)) + | inr _ => + exact isFalse (by intro h; cases h) + | inr ox => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr y => + cases y with + | inl _ => + exact isFalse (by intro h; cases h) + | inr oy => + simpa using (inferInstance : Decidable (ox = oy)) + +end Decidable + +/-- 
Inject a Q input node into the attention core node type. -/ +abbrev attnQ (q : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inl q) + +/-- Inject a K input node into the attention core node type. -/ +abbrev attnK (k : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inr (Sum.inl k)) + +/-- Inject a V input node into the attention core node type. -/ +abbrev attnV (v : QkvIndex Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inl (Sum.inr (Sum.inr v)) + +/-- Inject a score node into the attention core node type. -/ +abbrev attnScore (s : ScoreIndex Batch seq heads) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inl s) + +/-- Inject a weight node into the attention core node type. -/ +abbrev attnWeight (w : WeightIndex Batch seq heads) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inr (Sum.inl w)) + +/-- Inject an output node into the attention core node type. -/ +abbrev attnOut (o : AttentionOutput Batch seq heads dim) : AttentionNode Batch seq heads dim := + Sum.inr (Sum.inr (Sum.inr o)) + +/-- Rank function used to orient attention-core edges from inputs to outputs. -/ +def attentionRank : AttentionNode Batch seq heads dim → Nat + | Sum.inl _ => 0 + | Sum.inr (Sum.inl _) => 1 + | Sum.inr (Sum.inr (Sum.inl _)) => 2 + | Sum.inr (Sum.inr (Sum.inr _)) => 3 + +/-- Adjacency relation for attention-core wiring. 
-/ +def attentionAdj : AttentionNode Batch seq heads dim → AttentionNode Batch seq heads dim → Prop + | Sum.inl (Sum.inl (b, q, h, _)), + Sum.inr (Sum.inl (b', h', q', _)) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inl (Sum.inr (Sum.inl (b, k, h, _))), + Sum.inr (Sum.inl (b', h', _, k')) => + b = b' ∧ h = h' ∧ k = k' + | Sum.inr (Sum.inl (b, h, q, _)), + Sum.inr (Sum.inr (Sum.inl (b', h', q', _))) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inr (Sum.inr (Sum.inl (b, h, q, _))), + Sum.inr (Sum.inr (Sum.inr (b', q', h', _))) => + b = b' ∧ h = h' ∧ q = q' + | Sum.inl (Sum.inr (Sum.inr (b, _, h, d))), + Sum.inr (Sum.inr (Sum.inr (b', _, h', d'))) => + b = b' ∧ h = h' ∧ d = d' + | _, _ => False + +section Dag + +variable [DecidableEq Batch] + +/-- DAG for the attention score/mixing core. -/ +def attentionDag : Dag (AttentionNode Batch seq heads dim) := + { graph := { Adj := attentionAdj (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } + decAdj := by + intro j i + cases j with + | inl j => + cases j with + | inl q => + rcases q with ⟨b, q, h, _d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl s => + rcases s with ⟨b', h', q', _k⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr j => + cases j with + | inl k => + rcases k with ⟨b, k, h, _d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl s => + rcases s with ⟨b', h', _q, k'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ k = k')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr v => + rcases v with ⟨b, _k, h, d⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr o => + rcases o with ⟨b', _q, h', d'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = 
h' ∧ d = d')) + | inr j => + cases j with + | inl s => + rcases s with ⟨b, h, q, _k⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl w => + rcases w with ⟨b', h', q', _k'⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + | inr j => + cases j with + | inl w => + rcases w with ⟨b, h, q, _k⟩ + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr i => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr o => + rcases o with ⟨b', q', h', _d⟩ + exact (inferInstance : Decidable (b = b' ∧ h = h' ∧ q = q')) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : + Subrelation (attentionAdj (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)) + (fun j i => + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) j < + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) i) := by + intro j i h + cases j with + | inl j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp [attentionRank] + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + cases h + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp [attentionRank] + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + cases h + | inr _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + simp [attentionRank] + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + simp 
[attentionRank] + | inr _ => + cases h + | inr j => + cases j with + | inl _ => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr i => + cases i with + | inl _ => + cases h + | inr _ => + simp [attentionRank] + | inr _ => + cases h + have hwf : WellFounded (fun j i => + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) j < + attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) i) := by + simpa using (InvImage.wf + (f := attentionRank (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)) + (h := Nat.lt_wfRel.wf)) + exact Subrelation.wf hsub hwf } + +/-- Q nodes feed score nodes for matching batch/head/query. -/ +theorem attentionDag_rel_q_score (b : Batch) (q : Fin seq) (h : Fin heads) (d : Fin dim) + (k : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnQ (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- K nodes feed score nodes for matching batch/head/key. -/ +theorem attentionDag_rel_k_score (b : Batch) (k : Fin seq) (h : Fin heads) (d : Fin dim) + (q : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnK (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- Score nodes feed weight nodes for matching batch/head/query. 
-/ +theorem attentionDag_rel_score_weight (b : Batch) (h : Fin heads) (q : Fin seq) (k k' : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k')) + (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- Weight nodes feed output nodes for matching batch/head/query. -/ +theorem attentionDag_rel_weight_out (b : Batch) (h : Fin heads) (q : Fin seq) + (k : Fin seq) (d : Fin dim) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +/-- V nodes feed output nodes for matching batch/head/dimension. -/ +theorem attentionDag_rel_v_out (b : Batch) (k : Fin seq) (h : Fin heads) (d : Fin dim) + (q : Fin seq) : + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + simp [Dag.rel, attentionDag, attentionAdj] + +end Dag + +section Inputs + +/-- Input nodes for the attention core. -/ +def attentionInputs : Finset (AttentionNode Batch seq heads dim) := + (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl + +/-- Definitional characterization of `attentionInputs`. -/ +theorem attentionInputs_def : + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + (Finset.univ : Finset (AttentionInput Batch seq heads dim)).map Embedding.inl := by + rfl + +open scoped Classical in +/-- Membership in attention inputs corresponds to being a left injection. 
-/ +theorem mem_attentionInputs_iff {s : AttentionNode Batch seq heads dim} : + s ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) ↔ + ∃ a, s = Sum.inl a := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨a, _ha, hsa⟩ + exact ⟨a, hsa.symm⟩ + · rintro ⟨a, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨a, by simp, rfl⟩ + +open scoped Classical in +/-- Right injections are not attention input nodes. -/ +theorem not_mem_attentionInputs_inr (s : Sum (ScoreIndex Batch seq heads) + (Sum (WeightIndex Batch seq heads) (AttentionOutput Batch seq heads dim))) : + Sum.inr s ∉ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).1 h + with ⟨a, ha⟩ + cases ha + +open scoped Classical in +/-- Input labels correspond to input nodes in the attention core. -/ +def attentionInputEquiv : + AttentionInput Batch seq heads dim ≃ + { i // i ∈ attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := + { toFun := fun a => + ⟨Sum.inl a, + (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨a, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, _⟩ => a + | ⟨Sum.inr s, h⟩ => + False.elim + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + left_inv := by + intro a + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => rfl + | inr s => + cases (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) } + +/-- Definitional characterization of `attentionInputEquiv`. 
-/ +theorem attentionInputEquiv_def : + attentionInputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + { toFun := fun a => + ⟨Sum.inl a, + (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨a, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl a, _⟩ => a + | ⟨Sum.inr s, h⟩ => + False.elim + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + left_inv := by + intro a + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl a => rfl + | inr s => + cases (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) } := by + rfl + +end Inputs + +section Outputs + +/-- Output nodes for the attention core. -/ +def attentionOutputs : Finset (AttentionNode Batch seq heads dim) := + (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map + { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inj' := by + intro a b h + cases h + rfl } + +/-- Definitional characterization of `attentionOutputs`. -/ +theorem attentionOutputs_def : + attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + (Finset.univ : Finset (AttentionOutput Batch seq heads dim)).map + { toFun := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inj' := by + intro a b h + cases h + rfl } := by + rfl + +open scoped Classical in +/-- Membership in attention outputs corresponds to being an output injection. 
-/ +theorem mem_attentionOutputs_iff {s : AttentionNode Batch seq heads dim} : + s ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) ↔ + ∃ o, s = attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨o, _ho, hso⟩ + exact ⟨o, hso.symm⟩ + · rintro ⟨o, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨o, by simp, rfl⟩ + +open scoped Classical in +/-- Left injections are not attention output nodes. -/ +theorem not_mem_attentionOutputs_inl (s : AttentionInputNode Batch seq heads dim) : + Sum.inl s ∉ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Score nodes are not attention output nodes. -/ +theorem not_mem_attentionOutputs_score (s : ScoreIndex Batch seq heads) : + Sum.inr (Sum.inl s) ∉ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Weight nodes are not attention output nodes. -/ +theorem not_mem_attentionOutputs_weight (w : WeightIndex Batch seq heads) : + Sum.inr (Sum.inr (Sum.inl w)) ∉ + attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + intro h + rcases (Finset.mem_map.1 h) with ⟨o, _ho, ho⟩ + cases ho + +open scoped Classical in +/-- Output labels correspond to output nodes in the attention core. 
-/ +def attentionOutputEquiv : + AttentionOutput Batch seq heads dim ≃ + { i // i ∈ attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } := + { toFun := fun o => + ⟨attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o, + (mem_attentionOutputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨o, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr (Sum.inr (Sum.inr o)), _⟩ => o + | ⟨Sum.inl s, h⟩ => + False.elim + (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inl s), h⟩ => + False.elim + (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inr (Sum.inl w)), h⟩ => + False.elim + (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) w h) + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl s => + cases (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) + | inr s => + cases s with + | inl s => + cases (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) s hs) + | inr s => + cases s with + | inl w => + cases (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) w hs) + | inr _ => + rfl } + +/-- Definitional characterization of `attentionOutputEquiv`. 
-/ +theorem attentionOutputEquiv_def : + attentionOutputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) = + { toFun := fun o => + ⟨attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) o, + (mem_attentionOutputs_iff (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).2 + ⟨o, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr (Sum.inr (Sum.inr o)), _⟩ => o + | ⟨Sum.inl s, h⟩ => + False.elim + (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inl s), h⟩ => + False.elim + (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s h) + | ⟨Sum.inr (Sum.inr (Sum.inl w)), h⟩ => + False.elim + (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) w h) + left_inv := by + intro o + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl s => + cases (not_mem_attentionOutputs_inl (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) s hs) + | inr s => + cases s with + | inl s => + cases (not_mem_attentionOutputs_score (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) s hs) + | inr s => + cases s with + | inl w => + cases (not_mem_attentionOutputs_weight (Batch := Batch) (seq := seq) + (heads := heads) (dim := dim) w hs) + | inr _ => + rfl } := by + rfl + +end Outputs + +section Circuits + +variable [DecidableEq Batch] + +/-- Gate semantics for attention score/mixing circuits. 
-/ +def attentionGate (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + ∀ i, + (∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j i → + Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr i => + cases i with + | inl s => + rcases s with ⟨b, h, q, k⟩ + let qNode : Fin dim → AttentionNode Batch seq heads dim := fun d => + attnQ (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, q, h, d) + let kNode : Fin dim → AttentionNode Batch seq heads dim := fun d => + attnK (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, k, h, d) + let qVals : Fin dim → Val := fun d => + rec (qNode d) + (attentionDag_rel_q_score (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b q h d k) + let kVals : Fin dim → Val := fun d => + rec (kNode d) + (attentionDag_rel_k_score (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + exact scale * dotProduct qVals kVals + | inr i => + cases i with + | inl w => + rcases w with ⟨b, h, q, k⟩ + let scoreNode : Fin seq → AttentionNode Batch seq heads dim := fun k' => + attnScore (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, h, q, k') + let scores : Fin seq → Val := fun k' => + rec (scoreNode k') + (attentionDag_rel_score_weight (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b h q k k') + exact softmax scores k + | inr o => + rcases o with ⟨b, q, h, d⟩ + let weightNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, h, q, k) + let valueNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, k, h, d) + let weights : Fin seq → Val := fun k => + rec (weightNode k) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b h q k d) + let vals : Fin seq → Val := fun k => + rec (valueNode 
k) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim) b k h d q) + exact dotProduct weights vals + +/-- Definitional characterization of `attentionGate` on output nodes. -/ +theorem attentionGate_out_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) + (b : Batch) (q : Fin seq) (h : Fin heads) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = + let weightNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k) + let valueNode : Fin seq → AttentionNode Batch seq heads dim := fun k => + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d) + let weights : Fin seq → Val := fun k => + rec (weightNode k) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q k d) + let vals : Fin seq → Val := fun k => + rec (valueNode k) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + dotProduct weights vals := by + simp [attentionGate] + +/-- Circuit for attention score/mixing. 
-/ +def attentionCircuit (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + Circuit (AttentionNode Batch seq heads dim) Val := + { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + gate := + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax } + +/-- Definitional characterization of `attentionCircuit`. -/ +theorem attentionCircuit_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax = + { dag := attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + inputs := attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + gate := attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale + softmax } := by + rfl + +/-- Typed interface for attention score/mixing circuits. -/ +def attentionInterface (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + Interface + (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax) + (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := + { inputs := attentionInputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + outputs := attentionOutputEquiv (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) } + +end Circuits + +section Typed + +variable [DecidableEq Batch] + +/-- Typed attention score/mixing circuit. 
-/ +def attentionTyped (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + TypedCircuit (AttentionNode Batch seq heads dim) Val + (AttentionInput Batch seq heads dim) (AttentionOutput Batch seq heads dim) := + { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax + interface := attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax } + +/-- Definitional characterization of `attentionTyped`. -/ +theorem attentionTyped_def (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) : + attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax = + { circuit := attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax + interface := attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax } := by + rfl + +end Typed + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Heads.lean b/Nfp/Circuit/Layers/Heads.lean new file mode 100644 index 0000000..2cebb38 --- /dev/null +++ b/Nfp/Circuit/Layers/Heads.lean @@ -0,0 +1,87 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Logic.Equiv.Fin.Basic +public import Mathlib.Logic.Equiv.Prod +public import Nfp.Circuit.Layers.Reshape + +/-! +Head split/merge combinators for transformer-style shapes. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v u_in u_out + +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Val : Type v} +variable {Batch : Type u} + +/-- Split a hidden index into `(heads, headDim)` using `Fin` product equivalence. -/ +def headSplitEquiv (heads dim : Nat) : + Fin (heads * dim) ≃ Fin heads × Fin dim := + (finProdFinEquiv (m := heads) (n := dim)).symm + +/-- Merge `(heads, headDim)` back into a hidden index. 
-/ +def headMergeEquiv (heads dim : Nat) : + Fin heads × Fin dim ≃ Fin (heads * dim) := + finProdFinEquiv (m := heads) (n := dim) + +/-- Split the hidden dimension inside a batched index. -/ +def batchHeadSplitEquiv (Batch : Type u) (heads dim : Nat) : + Batch × Fin (heads * dim) ≃ Batch × Fin heads × Fin dim := + _root_.Equiv.prodCongr (_root_.Equiv.refl Batch) (headSplitEquiv heads dim) + +/-- Merge the head and head-dimension inside a batched index. -/ +def batchHeadMergeEquiv (Batch : Type u) (heads dim : Nat) : + Batch × Fin heads × Fin dim ≃ Batch × Fin (heads * dim) := + (batchHeadSplitEquiv Batch heads dim).symm + +/-- Split heads on the input labels of a typed circuit. -/ +def splitHeadsInput {Output : Type u_out} (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin (heads * dim)) Output) : + TypedCircuit Node Val (Batch × Fin heads × Fin dim) Output := + mapInterface T (batchHeadSplitEquiv Batch heads dim) (_root_.Equiv.refl Output) + +/-- Split heads on the output labels of a typed circuit. -/ +def splitHeadsOutput {Input : Type u_in} (heads dim : Nat) + (T : TypedCircuit Node Val Input (Batch × Fin (heads * dim))) : + TypedCircuit Node Val Input (Batch × Fin heads × Fin dim) := + mapInterface T (_root_.Equiv.refl Input) (batchHeadSplitEquiv Batch heads dim) + +/-- Split heads on both input and output labels. -/ +def splitHeads (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin (heads * dim)) (Batch × Fin (heads * dim))) : + TypedCircuit Node Val (Batch × Fin heads × Fin dim) (Batch × Fin heads × Fin dim) := + mapInterface T (batchHeadSplitEquiv Batch heads dim) (batchHeadSplitEquiv Batch heads dim) + +/-- Merge heads on the input labels of a typed circuit. 
-/ +def mergeHeadsInput {Output : Type u_out} (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin heads × Fin dim) Output) : + TypedCircuit Node Val (Batch × Fin (heads * dim)) Output := + mapInterface T (batchHeadMergeEquiv Batch heads dim) (_root_.Equiv.refl Output) + +/-- Merge heads on the output labels of a typed circuit. -/ +def mergeHeadsOutput {Input : Type u_in} (heads dim : Nat) + (T : TypedCircuit Node Val Input (Batch × Fin heads × Fin dim)) : + TypedCircuit Node Val Input (Batch × Fin (heads * dim)) := + mapInterface T (_root_.Equiv.refl Input) (batchHeadMergeEquiv Batch heads dim) + +/-- Merge heads on both input and output labels. -/ +def mergeHeads (heads dim : Nat) + (T : TypedCircuit Node Val (Batch × Fin heads × Fin dim) (Batch × Fin heads × Fin dim)) : + TypedCircuit Node Val (Batch × Fin (heads * dim)) (Batch × Fin (heads * dim)) := + mapInterface T (batchHeadMergeEquiv Batch heads dim) (batchHeadMergeEquiv Batch heads dim) + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Induction.lean b/Nfp/Circuit/Layers/Induction.lean new file mode 100644 index 0000000..a2c49ea --- /dev/null +++ b/Nfp/Circuit/Layers/Induction.lean @@ -0,0 +1,10 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Layers.Induction.Basic +public import Nfp.Circuit.Layers.Induction.Circuit + +/-! +Induction-head layer wiring and helper lemmas. 
+-/ diff --git a/Nfp/Circuit/Layers/Induction/Basic.lean b/Nfp/Circuit/Layers/Induction/Basic.lean new file mode 100644 index 0000000..db2d94f --- /dev/null +++ b/Nfp/Circuit/Layers/Induction/Basic.lean @@ -0,0 +1,998 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.BigOperators.Ring.Finset +public import Mathlib.Algebra.Order.Monoid.Unbundled.Basic +public import Mathlib.Algebra.Order.Ring.Defs +import all Nfp.Circuit.Layers.Attention +public import Nfp.Circuit.Layers.Attention + +/-! +Induction-head specifications for attention cores. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe v + +open scoped BigOperators + +section Weights + +variable {Val : Type v} [NonAssocSemiring Val] +variable {seq : Nat} + +/-- Induction weights are one-hot at `prev` for each query position. -/ +def InductionWeights (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, weights q = Pi.single (prev q) 1 + +/-- A one-hot weight vector selects the corresponding value in a dot product. -/ +theorem dotProduct_eq_of_oneHot (k : Fin seq) (vals : Fin seq → Val) : + dotProduct (Pi.single k 1) vals = vals k := by + simp + +/-- Induction weights select the `prev` value in each dot product. -/ +theorem dotProduct_eq_prev (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) (vals : Fin seq → Fin seq → Val) + (hweights : InductionWeights (Val := Val) prev weights) (q : Fin seq) : + dotProduct (weights q) (vals q) = vals q (prev q) := by + have hq : weights q = Pi.single (prev q) 1 := hweights q + simp [hq] + +end Weights + +section Spec + +variable {Val : Type v} +variable {n : Nat} + +/-- Induction-head spec: for non-initial queries (1-based indices ≥ 2), + outputs copy `prev` values. 
-/ +def InductionSpec (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q = vals (prev q) + +/-- Unfolding lemma for `InductionSpec`. -/ +theorem InductionSpec_def (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : + InductionSpec (n := n) prev out vals = ∀ q, q ≠ 0 → out q = vals (prev q) := by + rfl + +/-- Concrete `prev` map on `Fin (n + 1)` (with `0 ↦ 0`). -/ +def prevIndex : Fin (Nat.succ n) → Fin (Nat.succ n) + | ⟨0, _⟩ => 0 + | ⟨Nat.succ k, hk⟩ => + ⟨k, Nat.lt_trans (Nat.lt_of_succ_lt_succ hk) (Nat.lt_succ_self n)⟩ + +/-- Previous-token head spec: copies the immediately preceding token. -/ +def PrevTokenSpec (out vals : Fin (Nat.succ n) → Val) : Prop := + InductionSpec (n := n) (prevIndex (n := n)) out vals + +/-- Unfolding lemma for `PrevTokenSpec`. -/ +theorem PrevTokenSpec_def (out vals : Fin (Nat.succ n) → Val) : + PrevTokenSpec (n := n) out vals = + InductionSpec (n := n) (prevIndex (n := n)) out vals := by + rfl + +end Spec + +section ApproxSpec + +variable {Val : Type v} [AddCommMonoid Val] [PartialOrder Val] +variable {n : Nat} + +/-- Approximate induction-head spec: outputs are within `ε` of `prev` values. -/ +def InductionSpecApprox (ε : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, q ≠ 0 → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + +/-- Approximate induction-head spec restricted to active queries. -/ +def InductionSpecApproxOn (ε : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : Prop := + ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε + +/-- Definitional characterization of `InductionSpecApproxOn`. 
-/ +theorem InductionSpecApproxOn_def (ε : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) : + InductionSpecApproxOn (Val := Val) (n := n) ε active prev out vals = + ∀ q, active q → out q ≤ vals (prev q) + ε ∧ vals (prev q) ≤ out q + ε := by + rfl + +variable [IsOrderedAddMonoid Val] + +/-- Exact induction spec implies the approximate spec for any nonnegative tolerance. -/ +theorem inductionSpecApprox_of_spec (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (out vals : Fin (Nat.succ n) → Val) + (h : InductionSpec prev out vals) : + InductionSpecApprox (Val := Val) (n := n) ε prev out vals := by + intro q hq + have hq' : out q = vals (prev q) := h q hq + constructor <;> + simpa [hq'] using + (le_add_of_nonneg_right hε : + vals (prev q) ≤ vals (prev q) + ε) + +end ApproxSpec + +section ValueRange + +variable {Val : Type v} [PartialOrder Val] +variable {seq : Nat} + +/-- Value-range bounds for a vector of attention values. -/ +structure ValueRangeBounds (lo hi : Val) (vals : Fin seq → Val) : Prop where + /-- Lower and upper bounds are ordered. -/ + lo_le_hi : lo ≤ hi + /-- All values are at least `lo`. -/ + lo_le : ∀ k, lo ≤ vals k + /-- All values are at most `hi`. -/ + le_hi : ∀ k, vals k ≤ hi + +end ValueRange + +section Bounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Numeric bounds certifying one-hot weights on non-initial queries + (1-based indices ≥ 2). -/ +structure OneHotBoundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- Non-prev weights are nonpositive on non-initial queries (1-based indices ≥ 2). 
-/ + other_le_zero : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ 0 + +/-- Certified bounds imply one-hot weights on non-initial queries + (1-based indices ≥ 2). -/ +theorem oneHot_of_boundsOn (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) [DecidableEq (Fin seq)] + (h : OneHotBoundsOn prev weights) : + ∀ q, q ≠ 0 → weights q = Pi.single (prev q) 1 := by + intro q hq + funext k + by_cases hk : k = prev q + · subst hk + have hzero : + (∑ k ∈ (Finset.univ.erase (prev q)), weights q k) = 0 := by + refine Finset.sum_eq_zero ?_ + intro k hk' + have hkne : k ≠ prev q := (Finset.mem_erase.1 hk').1 + have hle : weights q k ≤ 0 := h.other_le_zero q hq k hkne + have hge : 0 ≤ weights q k := h.nonneg q hq k + exact le_antisymm hle hge + have hsum : + weights q (prev q) + + ∑ k ∈ (Finset.univ.erase (prev q)), weights q k = + ∑ k, weights q k := by + simpa using + (Finset.add_sum_erase + (s := (Finset.univ : Finset (Fin seq))) + (f := weights q) (a := prev q) (by simp)) + have hprev : weights q (prev q) = 1 := by + have hsum' : + weights q (prev q) + 0 = 1 := by + simpa [hzero, h.sum_one q hq] using hsum + simpa using hsum' + simp [Pi.single, hprev] + · have hle : weights q k ≤ 0 := h.other_le_zero q hq k hk + have hge : 0 ≤ weights q k := h.nonneg q hq k + have hzero : weights q k = 0 := le_antisymm hle hge + simp [Pi.single, hk, hzero] + +end Bounds + +section ApproxBounds + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Approximate one-hot bounds for attention weights on non-initial queries + (1-based indices ≥ 2). -/ +structure OneHotApproxBoundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). 
-/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on non-initial queries + (1-based indices ≥ 2). -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on non-initial queries (1-based indices ≥ 2). -/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate one-hot bounds for attention weights on active queries. -/ +structure OneHotApproxBoundsOnActive (ε : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop where + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. -/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global approximate bounds to an active-set version. -/ +theorem oneHotApproxBoundsOnActive_of_on (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn (Val := Val) ε prev weights) : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := by + refine { nonneg := ?_, sum_one := ?_, prev_large := ?_, other_le := ?_ } + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + +/-- Approximate induction weights: prev weight near one, others at most `ε`. -/ +def InductionWeightsApprox (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) : Prop := + ∀ q, q ≠ 0 → + 1 ≤ weights q (prev q) + ε ∧ + ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Approximate bounds imply approximate induction weights. 
-/ +theorem inductionWeightsApprox_of_boundsOn (ε : Val) (prev : Fin seq → Fin seq) + (weights : Fin seq → Fin seq → Val) + (h : OneHotApproxBoundsOn ε prev weights) : + InductionWeightsApprox (Val := Val) ε prev weights := by + intro q hq + exact ⟨h.prev_large q hq, h.other_le q hq⟩ + +end ApproxBounds + +section ApproxOutput + +variable {Val : Type v} [Ring Val] [LinearOrder Val] [IsOrderedRing Val] +variable {n : Nat} + +local instance : NeZero (Nat.succ n) := ⟨by simp⟩ + +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec + on active queries. -/ +theorem inductionSpecApproxOn_of_oneHotApprox_valueRange + (ε lo hi : Val) (active : Fin (Nat.succ n) → Prop) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOnActive (Val := Val) ε active prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApproxOn (Val := Val) (n := n) (ε * (hi - lo)) active prev + (fun q => dotProduct (weights q) vals) vals := by + classical + intro q hq + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (prev q) + have hsum_decomp : + weights q (prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (prev q) + ∑ k ∈ others, weights q k = 1 := by + simpa [hweights.sum_one q hq] using hsum_decomp + have hsum_others_le : (∑ k ∈ others, weights q k) ≤ ε := by + have hprev : 1 ≤ weights q (prev q) + ε := hweights.prev_large q hq + have hprev' : + weights q (prev q) + ∑ k ∈ others, weights q k ≤ weights q (prev q) + ε := by + simpa [hsum] using hprev + exact (add_le_add_iff_left (weights q (prev q))).1 hprev' + have hsum_others_nonneg : 0 ≤ ∑ k ∈ others, weights q k := by + refine Finset.sum_nonneg ?_ + intro k hk + exact hweights.nonneg q hq k + have hvals_hi : ∀ k, vals k ≤ hi := hvals.le_hi + have hvals_lo : ∀ k, lo ≤ vals k 
:= hvals.lo_le + have hdiff_nonneg : 0 ≤ hi - lo := sub_nonneg.mpr hvals.lo_le_hi + have hsum_vals_le : + (∑ k ∈ others, weights q k * vals k) ≤ (∑ k ∈ others, weights q k) * hi := by + have hle : ∀ k ∈ others, weights q k * vals k ≤ weights q k * hi := by + intro k hk + have hval : vals k ≤ hi := hvals_hi k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + ∑ k ∈ others, weights q k * vals k + ≤ ∑ k ∈ others, weights q k * hi := Finset.sum_le_sum hle + _ = (∑ k ∈ others, weights q k) * hi := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := hi)).symm + have hsum_vals_ge : + (∑ k ∈ others, weights q k) * lo ≤ (∑ k ∈ others, weights q k * vals k) := by + have hle : ∀ k ∈ others, weights q k * lo ≤ weights q k * vals k := by + intro k hk + have hval : lo ≤ vals k := hvals_lo k + have hnonneg : 0 ≤ weights q k := hweights.nonneg q hq k + exact mul_le_mul_of_nonneg_left hval hnonneg + calc + (∑ k ∈ others, weights q k) * lo + = ∑ k ∈ others, weights q k * lo := by + exact + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := lo)) + _ ≤ ∑ k ∈ others, weights q k * vals k := Finset.sum_le_sum hle + have hsum_prod : + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hsum_val_prev : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + vals (prev q) := by + calc + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) = + (weights q (prev q) + ∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using + (add_mul (weights q (prev q)) (∑ k ∈ others, weights q k) (vals (prev q))).symm + _ = 1 * vals (prev q) := by + simp [hsum] + _ = vals (prev q) := by 
simp + have hsplit : + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + (∑ k ∈ others, weights q k) * hi = + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo := by + exact + (add_sub_cancel_left + ((∑ k ∈ others, weights q k) * lo) ((∑ k ∈ others, weights q k) * hi)).symm + _ = (∑ k ∈ others, weights q k) * lo + + ((∑ k ∈ others, weights q k) * hi - + (∑ k ∈ others, weights q k) * lo) := by + simp [sub_eq_add_neg, add_assoc] + _ = (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [mul_sub] + have hsum_prev_le : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo ≤ + vals (prev q) := by + have hmul : (∑ k ∈ others, weights q k) * lo ≤ + (∑ k ∈ others, weights q k) * vals (prev q) := + mul_le_mul_of_nonneg_left (hvals_lo (prev q)) hsum_others_nonneg + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = vals (prev q) := hsum_val_prev + have hupper_mid : + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi ≤ + vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi = + weights q (prev q) * vals (prev q) + + ((∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo)) := by + simp [hsplit] + _ = weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + simp [add_assoc] + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := by + have h := + add_le_add_right hsum_prev_le + ((∑ k ∈ others, weights q k) * (hi - lo)) + simpa [add_comm, 
add_left_comm, add_assoc] using h + have hupper : + dotProduct (weights q) vals ≤ vals (prev q) + ε * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + calc + dotProduct (weights q) vals = + weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := hout_eq + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hsum_vals_le (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ ≤ vals (prev q) + (∑ k ∈ others, weights q k) * (hi - lo) := hupper_mid + _ ≤ vals (prev q) + ε * (hi - lo) := by + have h := + add_le_add_left hmul (vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + have hprev_le : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have hmul : (∑ k ∈ others, weights q k) * vals (prev q) ≤ + (∑ k ∈ others, weights q k) * hi := + mul_le_mul_of_nonneg_left (hvals_hi (prev q)) hsum_others_nonneg + have hmul' : + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := by + have h := + add_le_add_left hmul (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + calc + vals (prev q) = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * vals (prev q) := by + simpa using hsum_val_prev.symm + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hmul' + have hprev_le' : + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ others, weights q k) * (hi - lo) := by + calc + vals (prev q) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * hi := hprev_le + _ = + weights q (prev q) * vals (prev q) + + (∑ k ∈ others, weights q k) * lo + + (∑ k ∈ 
others, weights q k) * (hi - lo) := by + simp [hsplit, add_assoc] + have hsub : + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) ≤ + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := by + exact (sub_le_iff_le_add).2 hprev_le' + have hlowershift : + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := by + have hmul : + (∑ k ∈ others, weights q k) * (hi - lo) ≤ ε * (hi - lo) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hdiff_nonneg + exact sub_le_sub_left hmul (vals (prev q)) + have hlow : + vals (prev q) - ε * (hi - lo) ≤ dotProduct (weights q) vals := by + calc + vals (prev q) - ε * (hi - lo) ≤ + vals (prev q) - (∑ k ∈ others, weights q k) * (hi - lo) := hlowershift + _ ≤ weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo := hsub + _ ≤ dotProduct (weights q) vals := by + calc + weights q (prev q) * vals (prev q) + (∑ k ∈ others, weights q k) * lo + ≤ weights q (prev q) * vals (prev q) + ∑ k ∈ others, weights q k * vals k := by + have h := + add_le_add_left hsum_vals_ge (weights q (prev q) * vals (prev q)) + simpa [add_comm, add_left_comm, add_assoc] using h + _ = dotProduct (weights q) vals := by + simp [hout_eq] + have hlower : + vals (prev q) ≤ dotProduct (weights q) vals + ε * (hi - lo) := by + exact (sub_le_iff_le_add).1 hlow + exact ⟨hupper, hlower⟩ + +/-- Approximate one-hot weights plus bounded values yield an approximate induction spec. 
-/ +theorem inductionSpecApprox_of_oneHotApprox_valueRange + (ε lo hi : Val) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Val) + (vals : Fin (Nat.succ n) → Val) + (hweights : OneHotApproxBoundsOn (Val := Val) ε prev weights) + (hvals : ValueRangeBounds (Val := Val) lo hi vals) : + InductionSpecApprox (Val := Val) (n := n) (ε * (hi - lo)) prev + (fun q => dotProduct (weights q) vals) vals := by + have hweights' : + OneHotApproxBoundsOnActive (Val := Val) ε (fun q => q ≠ 0) prev weights := + oneHotApproxBoundsOnActive_of_on (Val := Val) (seq := Nat.succ n) + (ε := ε) (prev := prev) (weights := weights) hweights + exact + inductionSpecApproxOn_of_oneHotApprox_valueRange + (Val := Val) + (n := n) + (ε := ε) + (lo := lo) + (hi := hi) + (active := fun q => q ≠ 0) + (prev := prev) + (weights := weights) + (vals := vals) + (hweights := hweights') + (hvals := hvals) + +end ApproxOutput + +section SoftmaxMargin + +variable {Val : Type v} [Semiring Val] [PartialOrder Val] +variable {seq : Nat} [NeZero seq] + +/-- Softmax margin certificates for approximate one-hot weights. -/ +structure SoftmaxMarginBounds (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on non-initial queries + (1-based indices ≥ 2). -/ + score_margin : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on non-initial queries (1-based indices ≥ 2). -/ + nonneg : ∀ q, q ≠ 0 → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on non-initial queries (1-based indices ≥ 2). -/ + sum_one : ∀ q, q ≠ 0 → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on non-initial queries + (1-based indices ≥ 2). -/ + prev_large : ∀ q, q ≠ 0 → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on non-initial queries (1-based indices ≥ 2). 
-/ + other_le : ∀ q, q ≠ 0 → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Softmax margin certificates for approximate one-hot weights on active queries. -/ +structure SoftmaxMarginBoundsOn (ε margin : Val) (active : Fin seq → Prop) + (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) : Prop where + /-- Score gap between `prev` and other keys on active queries. -/ + score_margin : ∀ q, active q → ∀ k, k ≠ prev q → scores q k + margin ≤ scores q (prev q) + /-- All weights are nonnegative on active queries. -/ + nonneg : ∀ q, active q → ∀ k, 0 ≤ weights q k + /-- Weights sum to one on active queries. -/ + sum_one : ∀ q, active q → (∑ k, weights q k) = 1 + /-- The `prev` weight is within `ε` of one on active queries. -/ + prev_large : ∀ q, active q → 1 ≤ weights q (prev q) + ε + /-- Non-prev weights are at most `ε` on active queries. -/ + other_le : ∀ q, active q → ∀ k, k ≠ prev q → weights q k ≤ ε + +/-- Lift global softmax-margin bounds to an active-set version. -/ +theorem softmaxMarginBoundsOn_of_on (ε margin : Val) (prev : Fin seq → Fin seq) + (scores weights : Fin seq → Fin seq → Val) + (h : SoftmaxMarginBounds (Val := Val) ε margin prev scores weights) : + SoftmaxMarginBoundsOn (Val := Val) ε margin (fun q => q ≠ 0) prev scores weights := by + refine + { score_margin := ?_ + nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q hq k hk + exact h.score_margin q hq k hk + · intro q hq k + exact h.nonneg q hq k + · intro q hq + exact h.sum_one q hq + · intro q hq + exact h.prev_large q hq + · intro q hq k hk + exact h.other_le q hq k hk + +/-- Margin certificates yield approximate one-hot bounds for the weights. 
Typed V-input label for attention circuits.
-/ +abbrev attnInputV (v : QkvIndex Batch seq heads dim) : + AttentionInput Batch seq heads dim := + Sum.inr (Sum.inr v) + +/-- Weight function feeding an attention output node. -/ +def attentionOutWeights (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnWeight (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, h, q, k)) + (attentionDag_rel_weight_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q k d) + +/-- Value function feeding an attention output node. -/ +def attentionOutValues (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) : + Fin seq → Val := + fun k => + rec (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, k, h, d)) + (attentionDag_rel_v_out (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b k h d q) + +/-- One-hot attention weights force the output to copy the selected value. 
-/ +theorem attentionGate_out_eq_of_oneHot (scale : Val) + (softmax : (Fin seq → Val) → Fin seq → Val) (prev : Fin seq → Fin seq) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (rec : + ∀ j, + (attentionDag (Batch := Batch) (seq := seq) (heads := heads) (dim := dim)).rel j + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) → + Val) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec = + Pi.single (prev q) 1) : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) rec = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) := by + simp only [attentionGate_out_def] + change + dotProduct + (attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + (attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec (prev q) + rw [hweights] + exact dotProduct_eq_of_oneHot (Val := Val) (seq := seq) (k := prev q) + (vals := attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d rec) + +section Typed + +variable (scale : Val) (softmax : (Fin seq → Val) → Fin seq → Val) + +/-- Attention output equals the selected V input when weights are one-hot. 
-/ +theorem attentionTyped_eval_out_eq_of_oneHot (prev : Fin seq → Fin seq) + (input : AttentionInput Batch seq heads dim → Val) + (b : Batch) (h : Fin heads) (q : Fin seq) (d : Fin dim) + (hweights : + attentionOutWeights (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax) + ((attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + input + (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + let C := + attentionCircuit (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let I := + attentionInterface (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + let inputAssign := I.toInputAssignment input + have hnot : + attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d) ∉ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + simpa using + (not_mem_attentionInputs_inr (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (s := Sum.inr (Sum.inr (b, q, h, d)))) + have hgate : + Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) := by + exact Circuit.evalInput_eq_gate (C := C) (input := inputAssign) + (i := attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + hnot + have hcopy : + Circuit.evalInput C inputAssign + (attnOut (Batch := 
Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := by + have hgate' : + attentionGate (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) + (fun j _ => Circuit.evalInput C inputAssign j) = + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := + attentionGate_out_eq_of_oneHot (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (scale := scale) (softmax := softmax) (prev := prev) (b := b) (h := h) (q := q) (d := d) + (rec := fun j _ => Circuit.evalInput C inputAssign j) hweights + exact hgate.trans hgate' + have hmem : + attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d) ∈ + attentionInputs (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) := by + refine (mem_attentionInputs_iff (Batch := Batch) (seq := seq) (heads := heads) + (dim := dim)).2 ?_ + exact ⟨attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d), rfl⟩ + have hinput : + Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, prev q, h, d)) = + input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := by + have h := + Circuit.evalInput_eq_input (C := C) (input := inputAssign) + (i := attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) hmem + simpa [inputAssign, I, attentionInterface, attentionInputEquiv_def, + Interface.toInputAssignment_def, attnInputV] using h + have hvals : + attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) = + 
Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := rfl + calc + (attentionTyped (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) scale softmax).eval + input (b, q, h, d) = + Circuit.evalInput C inputAssign (I.outputs (b, q, h, d)).1 := by + simp [TypedCircuit.eval_def, Interface.eval_def, attentionTyped_def, C, I, inputAssign] + _ = Circuit.evalInput C inputAssign + (attnOut (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) (b, q, h, d)) := by + rfl + _ = attentionOutValues (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + b h q d (fun j _ => Circuit.evalInput C inputAssign j) (prev q) := hcopy + _ = Circuit.evalInput C inputAssign + (attnV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hvals + _ = input (attnInputV (Batch := Batch) (seq := seq) (heads := heads) (dim := dim) + (b, prev q, h, d)) := hinput + +end Typed + +end Attention + +section InductionSpecTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights on non-initial queries (1-based indices ≥ 2) imply the induction spec + for typed evaluation. 
-/ +theorem attentionTyped_eval_inductionSpec_of_oneHot + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpec (n := n) prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + intro q hq + have hweights_q := hweights q hq + exact attentionTyped_eval_out_eq_of_oneHot + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (q := q) + (d := d) + hweights_q + +/-- Induction spec for `prevIndex` under one-hot weight hypotheses. 
-/ +theorem attentionTyped_eval_inductionSpec_prevIndex + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prevIndex (n := n) q) 1) : + InductionSpec (n := n) (prevIndex (n := n)) + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prevIndex (n := n)) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecTyped + +section InductionSpecApproxTyped + +variable {Batch : Type} [Fintype Batch] [DecidableEq Batch] +variable {heads dim n : Nat} +variable {Val : Type v} [NonAssocSemiring Val] [PartialOrder Val] [IsOrderedAddMonoid Val] + +variable (scale : Val) +variable (softmax : (Fin (Nat.succ n) → Val) → Fin (Nat.succ n) → Val) + +/-- One-hot weights imply the approximate induction spec for any nonnegative tolerance. 
-/ +theorem attentionTyped_eval_inductionSpecApprox_of_oneHot (ε : Val) (hε : 0 ≤ ε) + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (input : AttentionInput Batch (Nat.succ n) heads dim → Val) + (b : Batch) (h : Fin heads) (d : Fin dim) + (hweights : + ∀ q, q ≠ 0 → + attentionOutWeights + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + b h q d + (fun j _ => + Circuit.evalInput + (attentionCircuit + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax) + ((attentionInterface + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).toInputAssignment input) j) = + Pi.single (prev q) 1) : + InductionSpecApprox (Val := Val) (n := n) ε prev + (fun q => + (attentionTyped + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + scale softmax).eval input (b, q, h, d)) + (fun k => + input (attnInputV + (Batch := Batch) + (seq := Nat.succ n) + (heads := heads) + (dim := dim) + (b, k, h, d))) := by + apply inductionSpecApprox_of_spec (Val := Val) (n := n) (ε := ε) hε + exact attentionTyped_eval_inductionSpec_of_oneHot + (Batch := Batch) + (heads := heads) + (dim := dim) + (n := n) + (scale := scale) + (softmax := softmax) + (prev := prev) + (input := input) + (b := b) + (h := h) + (d := d) + hweights + +end InductionSpecApproxTyped + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Induction/Circuit.lean b/Nfp/Circuit/Layers/Induction/Circuit.lean new file mode 100644 index 0000000..30bcce2 --- /dev/null +++ b/Nfp/Circuit/Layers/Induction/Circuit.lean @@ -0,0 +1,80 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Layers.Induction.Basic + +/-! +Circuit-level induction specs: previous-token head feeding an induction head. 
Circuit composition: the induction head outputs the value at the predecessor of `prev q`.
+-/ +theorem InductionCircuitSpec_compose + (prev : Fin (Nat.succ n) → Fin (Nat.succ n)) + (prevOut indOut vals : Fin (Nat.succ n) → Val) + (h : InductionCircuitSpec (n := n) prev prevOut indOut vals) : + ∀ q, q ≠ 0 → prev q ≠ 0 → + indOut q = vals (prevIndex (n := n) (prev q)) := by + intro q hq hprev + rcases h with ⟨hprevTok, hind⟩ + have hprevTok' : + ∀ q, q ≠ 0 → prevOut q = vals (prevIndex (n := n) q) := by + simpa [PrevTokenSpec_def, InductionSpec_def] using hprevTok + have hind' : ∀ q, q ≠ 0 → indOut q = prevOut (prev q) := by + simpa [InductionSpec_def] using hind + calc + indOut q = prevOut (prev q) := hind' q hq + _ = vals (prevIndex (n := n) (prev q)) := hprevTok' (prev q) hprev + +end Specs + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Linear.lean b/Nfp/Circuit/Layers/Linear.lean new file mode 100644 index 0000000..1f08adb --- /dev/null +++ b/Nfp/Circuit/Layers/Linear.lean @@ -0,0 +1,248 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Image +public import Mathlib.Logic.Embedding.Basic +public import Nfp.Circuit.Basic +public import Nfp.Circuit.Gates.Linear +public import Nfp.Circuit.Typed + +/-! +Linear and affine layer circuits. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +open Function + +universe u v + +variable {Row Col : Type u} + +/-- Node type for a linear/affine layer from `Col` inputs to `Row` outputs. -/ +abbrev LinearNode (Row Col : Type u) : Type u := Sum Col Row + +/-- Rank function used to orient layer edges from inputs to outputs. -/ +def linearRank : LinearNode Row Col → Nat + | Sum.inl _ => 0 + | Sum.inr _ => 1 + +/-- Definitional characterization of `linearRank`. 
-/ +theorem linearRank_def (x : LinearNode Row Col) : + linearRank (Row := Row) (Col := Col) x = + match x with + | Sum.inl _ => 0 + | Sum.inr _ => 1 := by + cases x <;> rfl + +section Dag + +variable [Fintype Row] [Fintype Col] + +/-- DAG for a single linear/affine layer. -/ +def linearDag : Dag (LinearNode Row Col) := + { graph := { Adj := fun j i => linearRank (Row := Row) (Col := Col) j < + linearRank (Row := Row) (Col := Col) i } + decAdj := by + intro j i + infer_instance + wf := by + simpa using (InvImage.wf (f := linearRank (Row := Row) (Col := Col)) + (h := Nat.lt_wfRel.wf)) } + +/-- Every input node has an edge to every output node. -/ +theorem linearDag_rel_inl_inr (c : Col) (r : Row) : + (linearDag (Row := Row) (Col := Col)).rel (Sum.inl c) (Sum.inr r) := by + dsimp [linearDag, linearRank] + exact Nat.zero_lt_one + +end Dag + +section Inputs + +variable [Fintype Col] + +/-- Input nodes for a linear/affine layer circuit. -/ +def linearInputs : Finset (LinearNode Row Col) := + (Finset.univ : Finset Col).map Embedding.inl + +/-- Membership in the input nodes corresponds to being a left injection. -/ +theorem mem_linearInputs_iff {s : LinearNode Row Col} : + s ∈ linearInputs (Row := Row) (Col := Col) ↔ ∃ c, s = Sum.inl c := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨c, _hc, hcs⟩ + exact ⟨c, hcs.symm⟩ + · rintro ⟨c, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨c, by simp, rfl⟩ + +/-- Right injections are not input nodes. -/ +theorem not_mem_linearInputs_inr (r : Row) : + Sum.inr r ∉ linearInputs (Row := Row) (Col := Col) := by + intro h + rcases (mem_linearInputs_iff (Row := Row) (Col := Col)).1 h with ⟨c, hcs⟩ + cases hcs + +/-- Input labels correspond to input nodes in a linear/affine layer. 
-/ +def linearInputEquiv : Col ≃ { i // i ∈ linearInputs (Row := Row) (Col := Col) } := + { toFun := fun c => + ⟨Sum.inl c, (mem_linearInputs_iff (Row := Row) (Col := Col)).2 ⟨c, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inl c, _⟩ => c + | ⟨Sum.inr r, h⟩ => False.elim + (not_mem_linearInputs_inr (Row := Row) (Col := Col) r h) + left_inv := by + intro c + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inl c => rfl + | inr r => + cases (not_mem_linearInputs_inr (Row := Row) (Col := Col) r hs) } + +end Inputs + +section Outputs + +variable [Fintype Row] + +/-- Output nodes for a linear/affine layer circuit. -/ +def linearOutputs : Finset (LinearNode Row Col) := + (Finset.univ : Finset Row).map Embedding.inr + +/-- Membership in the output nodes corresponds to being a right injection. -/ +theorem mem_linearOutputs_iff {s : LinearNode Row Col} : + s ∈ linearOutputs (Row := Row) (Col := Col) ↔ ∃ r, s = Sum.inr r := by + constructor + · intro hs + rcases (Finset.mem_map.1 hs) with ⟨r, _hr, hrs⟩ + exact ⟨r, hrs.symm⟩ + · rintro ⟨r, rfl⟩ + refine Finset.mem_map.2 ?_ + exact ⟨r, by simp, rfl⟩ + +/-- Left injections are not output nodes. -/ +theorem not_mem_linearOutputs_inl (c : Col) : + Sum.inl c ∉ linearOutputs (Row := Row) (Col := Col) := by + intro h + rcases (mem_linearOutputs_iff (Row := Row) (Col := Col)).1 h with ⟨r, hrs⟩ + cases hrs + +/-- Output labels correspond to output nodes in a linear/affine layer. 
-/ +def linearOutputEquiv : Row ≃ { i // i ∈ linearOutputs (Row := Row) (Col := Col) } := + { toFun := fun r => + ⟨Sum.inr r, (mem_linearOutputs_iff (Row := Row) (Col := Col)).2 ⟨r, rfl⟩⟩ + invFun := fun i => + match i with + | ⟨Sum.inr r, _⟩ => r + | ⟨Sum.inl c, h⟩ => False.elim + (not_mem_linearOutputs_inl (Row := Row) (Col := Col) c h) + left_inv := by + intro r + rfl + right_inv := by + intro i + cases i with + | mk s hs => + cases s with + | inr r => rfl + | inl c => + cases (not_mem_linearOutputs_inl (Row := Row) (Col := Col) c hs) } + +end Outputs + +section Circuits + +variable [Fintype Row] [Fintype Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Gate semantics for a linear layer circuit. -/ +def linearGate (W : Matrix Row Col Val) : + ∀ i, (∀ j, (linearDag (Row := Row) (Col := Col)).rel j i → Val) → Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr r => + let x : Col → Val := fun c => + rec (Sum.inl c) (linearDag_rel_inl_inr (Row := Row) (Col := Col) c r) + exact Gates.linear W x r + +/-- Gate semantics for an affine layer circuit. -/ +def affineGate (W : Matrix Row Col Val) (b : Row → Val) : + ∀ i, (∀ j, (linearDag (Row := Row) (Col := Col)).rel j i → Val) → Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr r => + let x : Col → Val := fun c => + rec (Sum.inl c) (linearDag_rel_inl_inr (Row := Row) (Col := Col) c r) + exact Gates.affine W b x r + +/-- Circuit for a linear layer. -/ +def linearCircuit (W : Matrix Row Col Val) : Circuit (LinearNode Row Col) Val := + { dag := linearDag (Row := Row) (Col := Col) + inputs := linearInputs (Row := Row) (Col := Col) + outputs := linearOutputs (Row := Row) (Col := Col) + gate := linearGate (Row := Row) (Col := Col) W } + +/-- Circuit for an affine layer. 
-/ +def affineCircuit (W : Matrix Row Col Val) (b : Row → Val) : + Circuit (LinearNode Row Col) Val := + { dag := linearDag (Row := Row) (Col := Col) + inputs := linearInputs (Row := Row) (Col := Col) + outputs := linearOutputs (Row := Row) (Col := Col) + gate := affineGate (Row := Row) (Col := Col) W b } + +/-- Typed interface for a linear layer circuit. -/ +def linearInterface (W : Matrix Row Col Val) : + Interface (linearCircuit (Row := Row) (Col := Col) W) Col Row := + { inputs := linearInputEquiv (Row := Row) (Col := Col) + outputs := linearOutputEquiv (Row := Row) (Col := Col) } + +/-- Typed interface for an affine layer circuit. -/ +def affineInterface (W : Matrix Row Col Val) (b : Row → Val) : + Interface (affineCircuit (Row := Row) (Col := Col) W b) Col Row := + { inputs := linearInputEquiv (Row := Row) (Col := Col) + outputs := linearOutputEquiv (Row := Row) (Col := Col) } + +end Circuits + +section Typed + +variable [Fintype Row] [Fintype Col] +variable [DecidableEq Row] [DecidableEq Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Typed linear layer circuit. -/ +def linearTyped (W : Matrix Row Col Val) : + TypedCircuit (LinearNode Row Col) Val Col Row := + { circuit := linearCircuit (Row := Row) (Col := Col) W + interface := linearInterface (Row := Row) (Col := Col) W } + +/-- Typed affine layer circuit. -/ +def affineTyped (W : Matrix Row Col Val) (b : Row → Val) : + TypedCircuit (LinearNode Row Col) Val Col Row := + { circuit := affineCircuit (Row := Row) (Col := Col) W b + interface := affineInterface (Row := Row) (Col := Col) W b } + +end Typed + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Reshape.lean b/Nfp/Circuit/Layers/Reshape.lean new file mode 100644 index 0000000..59ad3fa --- /dev/null +++ b/Nfp/Circuit/Layers/Reshape.lean @@ -0,0 +1,61 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Logic.Equiv.Prod +public import Nfp.Circuit.Typed + +/-! 
+Reshape combinators for product-typed circuit interfaces.
+-/
+
+public section
+
+namespace Nfp
+
+namespace Circuit
+
+namespace Layers
+
+universe u v u_in u_out u_in' u_out'
+
+variable {Node : Type u} [Fintype Node] [DecidableEq Node]
+variable {Val : Type v}
+variable {α β γ : Type u_in}
+variable {δ ε ζ : Type u_out}
+variable {Input : Type u_in} {Output : Type u_out}
+variable {Input' : Type u_in'} {Output' : Type u_out'}
+
+/-- Reassociate nested input/output products `(α × β) × γ` into `α × β × γ`; the underlying circuit is reused unchanged, only the interface equivalences are precomposed. -/
+def reassoc3
+    (T : TypedCircuit Node Val ((α × β) × γ) ((δ × ε) × ζ)) :
+    TypedCircuit Node Val (α × β × γ) (δ × ε × ζ) :=
+  { circuit := T.circuit
+    interface :=
+      { inputs := (_root_.Equiv.prodAssoc α β γ).symm.trans T.interface.inputs
+        outputs := (_root_.Equiv.prodAssoc δ ε ζ).symm.trans T.interface.outputs } }
+
+/-- Swap the two factors of the input/output products; the underlying circuit is reused unchanged, only the interface equivalences are precomposed. -/
+def swap12
+    (T : TypedCircuit Node Val (α × β) (δ × ε)) :
+    TypedCircuit Node Val (β × α) (ε × δ) :=
+  { circuit := T.circuit
+    interface :=
+      { inputs := (_root_.Equiv.prodComm α β).symm.trans T.interface.inputs
+        outputs := (_root_.Equiv.prodComm δ ε).symm.trans T.interface.outputs } }
+
+/-- Apply equivalences to the input/output labels of a typed circuit. 
-/ +def mapInterface + (T : TypedCircuit Node Val Input Output) + (eIn : _root_.Equiv Input Input') (eOut : _root_.Equiv Output Output') : + TypedCircuit Node Val Input' Output' := + { circuit := T.circuit + interface := + { inputs := eIn.symm.trans T.interface.inputs + outputs := eOut.symm.trans T.interface.outputs } } + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Softmax.lean b/Nfp/Circuit/Layers/Softmax.lean new file mode 100644 index 0000000..a117099 --- /dev/null +++ b/Nfp/Circuit/Layers/Softmax.lean @@ -0,0 +1,156 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Field +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Analysis.Complex.Exponential +public import Mathlib.Data.Finset.Card + +/-! +Real-valued softmax utilities and margin-based bounds. + +These lemmas provide the analytical bridge from score gaps to softmax weight +upper bounds. +-/ + +public section + +namespace Nfp + +namespace Circuit + +open scoped BigOperators +noncomputable section + +variable {seq : Nat} + +/-- Real softmax over a finite score vector. 
-/ +def softmax (scores : Fin seq → Real) (k : Fin seq) : Real := + Real.exp (scores k) / ∑ j, Real.exp (scores j) + +private lemma softmax_denom_pos [NeZero seq] (scores : Fin seq → Real) : + 0 < ∑ j, Real.exp (scores j) := by + classical + have hnonempty : (Finset.univ : Finset (Fin seq)).Nonempty := by + refine ⟨⟨0, ?_⟩, by simp⟩ + exact Nat.pos_of_ne_zero (NeZero.ne seq) + exact Finset.sum_pos (fun _ _ => Real.exp_pos _) hnonempty + +lemma softmax_nonneg [NeZero seq] (scores : Fin seq → Real) (k : Fin seq) : + 0 ≤ softmax scores k := by + have hdenom : 0 < ∑ j, Real.exp (scores j) := softmax_denom_pos scores + exact (div_nonneg (Real.exp_pos _).le (le_of_lt hdenom)) + +lemma softmax_sum_one [NeZero seq] (scores : Fin seq → Real) : + (∑ k, softmax scores k) = 1 := by + classical + have hdenom : (∑ j, Real.exp (scores j)) ≠ 0 := + ne_of_gt (softmax_denom_pos scores) + have hsum : + (∑ k, Real.exp (scores k) / ∑ j, Real.exp (scores j)) = + (∑ k, Real.exp (scores k)) / ∑ j, Real.exp (scores j) := by + simpa using + (Finset.sum_div (Finset.univ) (fun k => Real.exp (scores k)) + (∑ j, Real.exp (scores j))).symm + calc + ∑ k, softmax scores k + = ∑ k, Real.exp (scores k) / ∑ j, Real.exp (scores j) := by + simp [softmax] + _ = (∑ k, Real.exp (scores k)) / ∑ j, Real.exp (scores j) := hsum + _ = 1 := by + simp [hdenom] + +/-- Real-valued row-stochastic weights with explicit nonnegativity and row-sum proofs. + Kept separate from `ProbVec` because softmax outputs `Real` rather than `NNReal`. -/ +structure SoftmaxWeights (seq : Nat) [NeZero seq] where + /-- Weight assigned to each query/key pair. -/ + weights : Fin seq → Fin seq → Real + /-- All weights are nonnegative. -/ + nonneg : ∀ q k, 0 ≤ weights q k + /-- Each row sums to one. -/ + sum_one : ∀ q, (∑ k, weights q k) = 1 + +/-- Package softmax weights with row-stochastic proofs. 
-/ +def softmaxWeights [NeZero seq] (scores : Fin seq → Fin seq → Real) : + SoftmaxWeights seq := + { weights := fun q k => softmax (scores q) k + nonneg := by + intro q k + simpa using softmax_nonneg (scores := scores q) k + sum_one := by + intro q + simpa using softmax_sum_one (scores := scores q) } + +/-- Definitional unfolding for `softmaxWeights.weights`. -/ +theorem softmaxWeights_weights [NeZero seq] (scores : Fin seq → Fin seq → Real) : + (softmaxWeights scores).weights = fun q k => softmax (scores q) k := by + rfl + +lemma softmax_le_one [NeZero seq] (scores : Fin seq → Real) (k : Fin seq) : + softmax scores k ≤ 1 := by + classical + have hdenom_pos : 0 < ∑ j, Real.exp (scores j) := softmax_denom_pos scores + have hnum_le : Real.exp (scores k) ≤ ∑ j, Real.exp (scores j) := by + have hnonneg : ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ Real.exp (scores j) := + fun _ _ => (Real.exp_pos _).le + simpa using (Finset.single_le_sum hnonneg (by simp)) + have hdiv := (div_le_one hdenom_pos).2 hnum_le + simpa [softmax] using hdiv + +lemma exp_neg_le_inv_one_add {m : Real} (hm : 0 ≤ m) : + Real.exp (-m) ≤ 1 / (1 + m) := by + have hpos : 0 < 1 + m := add_pos_of_pos_of_nonneg zero_lt_one hm + have hle : 1 + m ≤ Real.exp m := by + simpa [add_comm] using (Real.add_one_le_exp m) + have hdiv : 1 / Real.exp m ≤ 1 / (1 + m) := + one_div_le_one_div_of_le hpos hle + simpa [Real.exp_neg] using hdiv + +lemma softmax_other_le_exp_neg [NeZero seq] (scores : Fin seq → Real) + {prev k : Fin seq} {m : Real} (hmargin : scores k + m ≤ scores prev) : + softmax scores k ≤ Real.exp (-m) := by + classical + let denom : Real := ∑ j, Real.exp (scores j) + have hdenom_pos : 0 < denom := softmax_denom_pos scores + have hdenom_ge : Real.exp (scores prev) ≤ denom := by + have hnonneg : ∀ j ∈ (Finset.univ : Finset (Fin seq)), 0 ≤ Real.exp (scores j) := + fun _ _ => (Real.exp_pos _).le + simpa [denom] using + (Finset.single_le_sum hnonneg (by simp : prev ∈ (Finset.univ : Finset (Fin seq)))) + 
have hinv : 1 / denom ≤ 1 / Real.exp (scores prev) := + one_div_le_one_div_of_le (Real.exp_pos _) hdenom_ge + have hmul := + mul_le_mul_of_nonneg_left hinv (Real.exp_pos (scores k)).le + have hratio : + Real.exp (scores k) / Real.exp (scores prev) = + Real.exp (scores k - scores prev) := by + symm + exact Real.exp_sub (scores k) (scores prev) + have hk : scores k ≤ scores prev - m := (le_sub_iff_add_le).2 hmargin + have hdiff : scores k - scores prev ≤ -m := by + have hsub := sub_le_sub_right hk (scores prev) + simpa [sub_eq_add_neg, add_assoc, add_left_comm, add_comm] using hsub + have hle : Real.exp (scores k - scores prev) ≤ Real.exp (-m) := + Real.exp_le_exp.mpr hdiff + have hsoft : + Real.exp (scores k) / denom ≤ Real.exp (scores k) / Real.exp (scores prev) := by + simpa [denom, div_eq_mul_inv] using hmul + calc + softmax scores k + = Real.exp (scores k) / denom := by + simp [softmax, denom] + _ ≤ Real.exp (scores k) / Real.exp (scores prev) := hsoft + _ = Real.exp (scores k - scores prev) := hratio + _ ≤ Real.exp (-m) := hle + +lemma softmax_other_le_inv_one_add [NeZero seq] (scores : Fin seq → Real) + {prev k : Fin seq} {m : Real} (hm : 0 ≤ m) (hmargin : scores k + m ≤ scores prev) : + softmax scores k ≤ 1 / (1 + m) := + (softmax_other_le_exp_neg (scores := scores) hmargin).trans (exp_neg_le_inv_one_add hm) + +end + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/Tensor.lean b/Nfp/Circuit/Layers/Tensor.lean new file mode 100644 index 0000000..beb5e84 --- /dev/null +++ b/Nfp/Circuit/Layers/Tensor.lean @@ -0,0 +1,178 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Layers.Linear + +/-! +Tensor-shaped layer builders (batched linear and affine layers). +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v + +variable {Batch Row Col : Type u} + +/-- Node type for a batched linear/affine layer. 
-/ +abbrev BatchedLinearNode (Batch Row Col : Type u) : Type u := + LinearNode (Batch × Row) (Batch × Col) + +/-- Adjacency for batched linear layers: inputs connect only to outputs in the same batch. -/ +def batchedLinearAdj (Batch Row Col : Type u) : + BatchedLinearNode Batch Row Col → BatchedLinearNode Batch Row Col → Prop + | Sum.inl (b, _), Sum.inr (b', _) => b = b' + | _, _ => False + +section Dag + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] + +/-- DAG for a batched linear/affine layer. -/ +def batchedLinearDag : Dag (BatchedLinearNode Batch Row Col) := + { graph := { Adj := batchedLinearAdj Batch Row Col } + decAdj := by + intro j i + cases j with + | inl bc => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr br => + exact (inferInstance : Decidable (bc.1 = br.1)) + | inr _ => + cases i with + | inl _ => + exact isFalse (by intro h; cases h) + | inr _ => + exact isFalse (by intro h; cases h) + wf := by + have hsub : Subrelation (batchedLinearAdj Batch Row Col) + (fun j i => + linearRank (Row := Batch × Row) (Col := Batch × Col) j < + linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by + intro j i h + cases j <;> cases i <;> simp [batchedLinearAdj, linearRank_def] at h ⊢ + have hwf : WellFounded (fun j i => + linearRank (Row := Batch × Row) (Col := Batch × Col) j < + linearRank (Row := Batch × Row) (Col := Batch × Col) i) := by + simpa using (InvImage.wf + (f := linearRank (Row := Batch × Row) (Col := Batch × Col)) + (h := Nat.lt_wfRel.wf)) + exact Subrelation.wf hsub hwf } + +/-- Edges connect each batch's inputs to its outputs. 
-/ +theorem batchedLinearDag_rel_inl_inr (b : Batch) (c : Col) (r : Row) : + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel + (Sum.inl (b, c)) (Sum.inr (b, r)) := by + change batchedLinearAdj Batch Row Col (Sum.inl (b, c)) (Sum.inr (b, r)) + simp [batchedLinearAdj] + +end Dag + +section Circuits + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Gate semantics for a batched linear layer circuit. -/ +def batchedLinearGate (W : Matrix Row Col Val) : + ∀ i, + (∀ j, + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel j i → Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr br => + cases br with + | mk b r => + let x : Col → Val := fun c => + rec (Sum.inl (b, c)) + (batchedLinearDag_rel_inl_inr (Batch := Batch) (Row := Row) (Col := Col) b c r) + exact Gates.linear W x r + +/-- Gate semantics for a batched affine layer circuit. -/ +def batchedAffineGate (W : Matrix Row Col Val) (bias : Row → Val) : + ∀ i, + (∀ j, + (batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col)).rel j i → Val) → + Val := by + intro i rec + cases i with + | inl _ => + exact 0 + | inr br => + cases br with + | mk b r => + let x : Col → Val := fun c => + rec (Sum.inl (b, c)) + (batchedLinearDag_rel_inl_inr (Batch := Batch) (Row := Row) (Col := Col) b c r) + exact Gates.affine W bias x r + +/-- Circuit for a batched linear layer. -/ +def batchedLinearCircuit (W : Matrix Row Col Val) : + Circuit (BatchedLinearNode Batch Row Col) Val := + { dag := batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col) + inputs := linearInputs (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputs (Row := Batch × Row) (Col := Batch × Col) + gate := batchedLinearGate (Batch := Batch) (Row := Row) (Col := Col) W } + +/-- Circuit for a batched affine layer. 
-/ +def batchedAffineCircuit (W : Matrix Row Col Val) (bias : Row → Val) : + Circuit (BatchedLinearNode Batch Row Col) Val := + { dag := batchedLinearDag (Batch := Batch) (Row := Row) (Col := Col) + inputs := linearInputs (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputs (Row := Batch × Row) (Col := Batch × Col) + gate := batchedAffineGate (Batch := Batch) (Row := Row) (Col := Col) W bias } + +/-- Typed interface for a batched linear layer circuit. -/ +def batchedLinearInterface (W : Matrix Row Col Val) : + Interface (batchedLinearCircuit (Batch := Batch) (Row := Row) (Col := Col) W) + (Batch × Col) (Batch × Row) := + { inputs := linearInputEquiv (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputEquiv (Row := Batch × Row) (Col := Batch × Col) } + +/-- Typed interface for a batched affine layer circuit. -/ +def batchedAffineInterface (W : Matrix Row Col Val) (bias : Row → Val) : + Interface (batchedAffineCircuit (Batch := Batch) (Row := Row) (Col := Col) W bias) + (Batch × Col) (Batch × Row) := + { inputs := linearInputEquiv (Row := Batch × Row) (Col := Batch × Col) + outputs := linearOutputEquiv (Row := Batch × Row) (Col := Batch × Col) } + +end Circuits + +section Typed + +variable [Fintype Batch] [Fintype Row] [Fintype Col] +variable [DecidableEq Batch] [DecidableEq Row] [DecidableEq Col] +variable {Val : Type v} [NonUnitalNonAssocSemiring Val] + +/-- Typed batched linear layer circuit. -/ +def batchedLinearTyped (W : Matrix Row Col Val) : + TypedCircuit (BatchedLinearNode Batch Row Col) Val (Batch × Col) (Batch × Row) := + { circuit := batchedLinearCircuit (Batch := Batch) (Row := Row) (Col := Col) W + interface := batchedLinearInterface (Batch := Batch) (Row := Row) (Col := Col) W } + +/-- Typed batched affine layer circuit. 
-/ +def batchedAffineTyped (W : Matrix Row Col Val) (bias : Row → Val) : + TypedCircuit (BatchedLinearNode Batch Row Col) Val (Batch × Col) (Batch × Row) := + { circuit := batchedAffineCircuit (Batch := Batch) (Row := Row) (Col := Col) W bias + interface := batchedAffineInterface (Batch := Batch) (Row := Row) (Col := Col) W bias } + +end Typed + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Layers/TransformerBlock.lean b/Nfp/Circuit/Layers/TransformerBlock.lean new file mode 100644 index 0000000..14e24c3 --- /dev/null +++ b/Nfp/Circuit/Layers/TransformerBlock.lean @@ -0,0 +1,89 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Compose +public import Nfp.Circuit.Layers.Attention + +/-! +Transformer block wiring built from sequential composition and residual links. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Layers + +universe u v u_in + +variable {Val : Type v} [Add Val] +variable {Input : Type u_in} [Fintype Input] [DecidableEq Input] + +variable {NodeLn1 : Type u} [Fintype NodeLn1] [DecidableEq NodeLn1] +variable {NodeAttn : Type u} [Fintype NodeAttn] [DecidableEq NodeAttn] +variable {NodeLn2 : Type u} [Fintype NodeLn2] [DecidableEq NodeLn2] +variable {NodeMlp : Type u} [Fintype NodeMlp] [DecidableEq NodeMlp] + +/-- Node type for the attention subpath (LN1 ∘ attention). -/ +abbrev AttnPathNode (NodeLn1 NodeAttn : Type u) : Type u := + NodeLn1 ⊕ NodeAttn + +/-- Node type for the attention residual wrapper. -/ +abbrev AttnResidualNode (NodeLn1 NodeAttn : Type u) (Input : Type u_in) : + Type (max u u_in) := + (AttnPathNode NodeLn1 NodeAttn) ⊕ Input + +/-- Node type for the MLP subpath (LN2 ∘ MLP). -/ +abbrev MlpPathNode (NodeLn2 NodeMlp : Type u) : Type u := + NodeLn2 ⊕ NodeMlp + +/-- Node type for the MLP residual wrapper. 
-/
+abbrev MlpResidualNode (NodeLn2 NodeMlp : Type u) (Input : Type u_in) :
+    Type (max u u_in) :=
+  (MlpPathNode NodeLn2 NodeMlp) ⊕ Input
+
+/-- Node type for a full transformer block. -/
+abbrev TransformerBlockNode (NodeLn1 NodeAttn NodeLn2 NodeMlp : Type u) (Input : Type u_in) :
+    Type (max u u_in) :=
+  (AttnResidualNode NodeLn1 NodeAttn Input) ⊕ (MlpResidualNode NodeLn2 NodeMlp Input)
+
+/-- Compose a GPT-style transformer block: the LN1/attention path and the LN2/MLP path are each built with `TypedCircuit.seq`, wrapped via `TypedCircuit.residual`, and chained sequentially. -/
+def transformerBlock
+    (ln1 : TypedCircuit NodeLn1 Val Input Input)
+    (attn : TypedCircuit NodeAttn Val Input Input)
+    (ln2 : TypedCircuit NodeLn2 Val Input Input)
+    (mlp : TypedCircuit NodeMlp Val Input Input) :
+    TypedCircuit (TransformerBlockNode NodeLn1 NodeAttn NodeLn2 NodeMlp Input) Val Input Input :=
+  let attnPath := TypedCircuit.seq ln1 attn
+  let attnRes := TypedCircuit.residual attnPath
+  let mlpPath := TypedCircuit.seq ln2 mlp
+  let mlpRes := TypedCircuit.residual mlpPath
+  TypedCircuit.seq attnRes mlpRes
+
+/-- Token hidden state type for GPT-2 style blocks. -/
+abbrev BlockInput (Batch : Type) (seq heads dim : Nat) : Type :=
+  Batch × Fin seq × Hidden heads dim
+
+/-- Transformer block specialized to GPT-style hidden states. 
-/ +def gpt2Block {Batch : Type} [Fintype Batch] [DecidableEq Batch] {seq heads dim : Nat} + (ln1 : TypedCircuit NodeLn1 Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (attn : TypedCircuit NodeAttn Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (ln2 : TypedCircuit NodeLn2 Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) + (mlp : TypedCircuit NodeMlp Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim)) : + TypedCircuit (TransformerBlockNode NodeLn1 NodeAttn NodeLn2 NodeMlp + (BlockInput Batch seq heads dim)) Val (BlockInput Batch seq heads dim) + (BlockInput Batch seq heads dim) := + transformerBlock ln1 attn ln2 mlp + +end Layers + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Semantics.lean b/Nfp/Circuit/Semantics.lean new file mode 100644 index 0000000..6e629e3 --- /dev/null +++ b/Nfp/Circuit/Semantics.lean @@ -0,0 +1,84 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Basic + +/-! +Evaluation semantics for finite circuits. +-/ + +public section + +namespace Nfp + +universe u v + +namespace Circuit + +variable {ι : Type u} [Fintype ι] [DecidableEq ι] +variable {α : Type v} + +/-- One-step evaluation functional used by `eval`. -/ +def evalStep (C : Circuit ι α) (input : ι → α) + (i : ι) (rec : ∀ j, C.dag.rel j i → α) : α := + if _ : i ∈ C.inputs then input i else C.gate i rec + +/-- Evaluate a circuit with a given input assignment. -/ +def eval (C : Circuit ι α) (input : ι → α) : ι → α := + C.dag.wf.fix (fun i rec => evalStep C input i rec) + +/-- Unfolding equation for `eval`. 
-/ +theorem eval_eq (C : Circuit ι α) (input : ι → α) (i : ι) : + eval C input i = + if _ : i ∈ C.inputs then input i else C.gate i (fun j _ => eval C input j) := by + set F : ∀ i, (∀ j, C.dag.rel j i → α) → α := fun i rec => evalStep C input i rec + change C.dag.wf.fix F i = + if _ : i ∈ C.inputs then input i else C.gate i (fun j _ => C.dag.wf.fix F j) + rw [WellFounded.fix_eq] + simp [F, evalStep] + +/-- Input nodes evaluate to their assigned input value. -/ +theorem eval_eq_input (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∈ C.inputs) : + eval C input i = input i := by + simpa [h] using (eval_eq C input i) + +/-- Non-input nodes evaluate via their gate semantics. -/ +theorem eval_eq_gate (C : Circuit ι α) (input : ι → α) {i : ι} (h : i ∉ C.inputs) : + eval C input i = C.gate i (fun j _ => eval C input j) := by + simpa [h] using (eval_eq C input i) + +/-- One-step evaluation functional used by `evalInput`. -/ +def evalInputStep (C : Circuit ι α) (input : C.InputAssignment) + (i : ι) (rec : ∀ j, C.dag.rel j i → α) : α := + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i rec + +/-- Evaluate a circuit with an input assignment defined on input nodes. -/ +def evalInput (C : Circuit ι α) (input : C.InputAssignment) : ι → α := + C.dag.wf.fix (fun i rec => evalInputStep C input i rec) + +/-- Unfolding equation for `evalInput`. -/ +theorem evalInput_eq (C : Circuit ι α) (input : C.InputAssignment) (i : ι) : + evalInput C input i = + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i (fun j _ => evalInput C input j) := by + set F : ∀ i, (∀ j, C.dag.rel j i → α) → α := fun i rec => evalInputStep C input i rec + change C.dag.wf.fix F i = + if h : i ∈ C.inputs then input ⟨i, h⟩ else C.gate i (fun j _ => C.dag.wf.fix F j) + rw [WellFounded.fix_eq] + simp [F, evalInputStep] + +/-- Input nodes evaluate to their assigned input value (input-only form). 
-/ +theorem evalInput_eq_input (C : Circuit ι α) (input : C.InputAssignment) {i : ι} + (h : i ∈ C.inputs) : + evalInput C input i = input ⟨i, h⟩ := by + simpa [h] using (evalInput_eq C input i) + +/-- Non-input nodes evaluate via their gate semantics (input-only form). -/ +theorem evalInput_eq_gate (C : Circuit ι α) (input : C.InputAssignment) {i : ι} + (h : i ∉ C.inputs) : + evalInput C input i = C.gate i (fun j _ => evalInput C input j) := by + simpa [h] using (evalInput_eq C input i) + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Tensor.lean b/Nfp/Circuit/Tensor.lean new file mode 100644 index 0000000..e6b252d --- /dev/null +++ b/Nfp/Circuit/Tensor.lean @@ -0,0 +1,49 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Matrix.Basic + +/-! +Typed tensor indices and tensor aliases. +-/ + +public section + +namespace Nfp + +namespace Circuit + +namespace Tensor + +universe u v + +/-- Index type for a length-`n` vector. -/ +abbrev VecIndex (n : Nat) : Type := Fin n + +/-- Index type for an `m × n` matrix. -/ +abbrev MatIndex (m n : Nat) : Type := Fin m × Fin n + +/-- Index type for a 3D tensor. -/ +abbrev Tensor3Index (a b c : Nat) : Type := Fin a × Fin b × Fin c + +/-- Index type for a 4D tensor. -/ +abbrev Tensor4Index (a b c d : Nat) : Type := Fin a × Fin b × Fin c × Fin d + +/-- A length-`n` vector of values. -/ +abbrev Vec (n : Nat) (Val : Type v) : Type v := VecIndex n → Val + +/-- An `m × n` matrix of values. -/ +abbrev Mat (m n : Nat) (Val : Type v) : Type v := Matrix (VecIndex m) (VecIndex n) Val + +/-- A 3D tensor of values. -/ +abbrev Tensor3 (a b c : Nat) (Val : Type v) : Type v := Tensor3Index a b c → Val + +/-- A 4D tensor of values. 
-/ +abbrev Tensor4 (a b c d : Nat) (Val : Type v) : Type v := Tensor4Index a b c d → Val + +end Tensor + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/Typed.lean b/Nfp/Circuit/Typed.lean new file mode 100644 index 0000000..4c27e8f --- /dev/null +++ b/Nfp/Circuit/Typed.lean @@ -0,0 +1,69 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Combinators +public import Nfp.Circuit.Cert.Basic + +/-! +Typed circuit wrappers and typed equivalence checking. +-/ + +public section + +namespace Nfp + +universe u v u' u_in u_out + +namespace Circuit + +/-- A circuit bundled with a typed input/output interface. -/ +structure TypedCircuit (Node : Type u) [Fintype Node] [DecidableEq Node] (Val : Type v) + (Input : Type u_in) (Output : Type u_out) where + /-- The underlying circuit. -/ + circuit : Circuit Node Val + /-- Typed input/output interface for `circuit`. -/ + interface : Interface circuit Input Output + +namespace TypedCircuit + +variable {Node : Type u} [Fintype Node] [DecidableEq Node] +variable {Val : Type v} {Input : Type u_in} {Output : Type u_out} + +/-- Evaluate a typed circuit on a typed input. -/ +def eval (T : TypedCircuit Node Val Input Output) (input : Input → Val) : Output → Val := + T.interface.eval input + +/-- Definitional characterization of `TypedCircuit.eval`. -/ +theorem eval_def (T : TypedCircuit Node Val Input Output) (input : Input → Val) : + T.eval input = T.interface.eval input := by + rfl + +/-- Decide equivalence by enumerating typed inputs. -/ +def checkEquiv (T1 T2 : TypedCircuit Node Val Input Output) + [Fintype Input] [DecidableEq Input] [Fintype Output] + [Fintype Val] [DecidableEq Val] : Bool := + Circuit.checkEquivOnInterface T1.circuit T2.circuit T1.interface T2.interface + +/-- `checkEquiv` is sound and complete for `EquivOnInterface`. 
-/ +theorem checkEquiv_eq_true_iff (T1 T2 : TypedCircuit Node Val Input Output) + [Fintype Input] [DecidableEq Input] [Fintype Output] + [Fintype Val] [DecidableEq Val] : + checkEquiv T1 T2 = true ↔ + EquivOnInterface T1.circuit T2.circuit T1.interface T2.interface := by + simpa [checkEquiv] using + (checkEquivOnInterface_eq_true_iff T1.circuit T2.circuit T1.interface T2.interface) + +variable {Node' : Type u'} [Fintype Node'] [DecidableEq Node'] + +/-- Relabel the nodes of a typed circuit. -/ +def relabel (T : TypedCircuit Node Val Input Output) (e : _root_.Equiv Node Node') : + TypedCircuit Node' Val Input Output := + { circuit := T.circuit.relabel e + interface := T.interface.relabel e } + +end TypedCircuit + +end Circuit + +end Nfp diff --git a/Nfp/Circuit/WellFormed.lean b/Nfp/Circuit/WellFormed.lean new file mode 100644 index 0000000..c9813d5 --- /dev/null +++ b/Nfp/Circuit/WellFormed.lean @@ -0,0 +1,40 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Basic + +/-! +Well-formedness conditions for circuits. +-/ + +public section + +namespace Nfp + +universe u v + +namespace Circuit + +variable {ι : Type u} [Fintype ι] +variable {α : Type v} + +/-- A circuit is well-formed if every input node has no incoming edges. -/ +def WellFormed (C : Circuit ι α) : Prop := + ∀ i ∈ C.inputs, ∀ j, ¬ C.dag.rel j i + +/-- Inputs have no parents in a well-formed circuit. -/ +theorem wellFormed_no_parent {C : Circuit ι α} (h : WellFormed C) {i j : ι} (hi : i ∈ C.inputs) : + ¬ C.dag.rel j i := + by + simpa using h i hi j + +/-- Input nodes have empty parent sets in a well-formed circuit. 
-/ +theorem wellFormed_parents_empty {C : Circuit ι α} (h : WellFormed C) {i : ι} (hi : i ∈ C.inputs) : + C.dag.parents i = ∅ := by + ext j + simp [Dag.mem_parents, h i hi] + +end Circuit + +end Nfp diff --git a/Nfp/Cli.lean b/Nfp/Cli.lean new file mode 100644 index 0000000..3016af6 --- /dev/null +++ b/Nfp/Cli.lean @@ -0,0 +1,89 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Cli +import Nfp.IO + +/-! +Minimal CLI surface for the NFP rewrite. +-/ + +public section + +open Cli + +namespace Nfp + +/-- Human-readable version string for the CLI. -/ +def versionString : String := "0.1.0-tabula" + +/-- Print the CLI version. -/ +def runVersion (_p : Parsed) : IO UInt32 := do + IO.println s!"nfp version {versionString}" + return 0 + +/-- The version subcommand. -/ +def versionCmd : Cmd := `[Cli| + version VIA runVersion; + "Print the NFP version." +] + +private def runInductionCertifySimple (p : Parsed) : IO UInt32 := do + let certPath? := (p.flag? "cert").map (·.as! String) + let minActive? := (p.flag? "min-active").map (·.as! Nat) + let minLogitDiffStr? := (p.flag? "min-logit-diff").map (·.as! String) + let minMarginStr? := (p.flag? "min-margin").map (·.as! String) + let maxEpsStr? := (p.flag? "max-eps").map (·.as! String) + let tokensPath? := (p.flag? "tokens").map (·.as! String) + let fail (msg : String) : IO UInt32 := do + IO.eprintln s!"error: {msg}" + return 2 + match certPath? with + | none => fail "provide --cert" + | some certPath => + IO.runInductionHeadCertCheck certPath minActive? minLogitDiffStr? + minMarginStr? maxEpsStr? tokensPath? + +/-- `nfp induction certify` subcommand (streamlined). -/ +def inductionCertifySimpleCmd : Cmd := `[Cli| + certify VIA runInductionCertifySimple; + "Check induction head certificates from an explicit cert." + FLAGS: + cert : String; "Path to the induction head certificate file." + tokens : String; "Optional path to a token list to verify prev/active." 
+ "min-active" : Nat; "Optional minimum number of active queries required \ + (default: max 1 (seq/8))." + "min-logit-diff" : String; "Optional minimum logit-diff lower bound \ + (rational literal; default: 0)." + "min-margin" : String; "Optional minimum score margin (rational literal; default: 0)." + "max-eps" : String; "Optional maximum eps tolerance (rational literal; default: 1/2)." +] + +/-- Induction-head subcommands. -/ +def inductionCmd : Cmd := `[Cli| + induction NOOP; + "Induction-head utilities (streamlined)." + SUBCOMMANDS: + inductionCertifySimpleCmd +] + +/-- The root CLI command. -/ +def nfpCmd : Cmd := `[Cli| + nfp NOOP; + "NFP: Neural Formal Pathways (rewrite in progress)." + SUBCOMMANDS: + versionCmd; + inductionCmd +] + +/-- Main entry point for the CLI. -/ +def main (args : List String) : IO UInt32 := do + if args.contains "--version" then + IO.println s!"nfp version {versionString}" + return 0 + nfpCmd.validate args + +end Nfp + +end diff --git a/Nfp/Core.lean b/Nfp/Core.lean new file mode 100644 index 0000000..d4de670 --- /dev/null +++ b/Nfp/Core.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic + +/-! +Core shared definitions for the NFP rewrite. +-/ diff --git a/Nfp/Core/Basic.lean b/Nfp/Core/Basic.lean new file mode 100644 index 0000000..36e5f7b --- /dev/null +++ b/Nfp/Core/Basic.lean @@ -0,0 +1,191 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Nfp.Tactic.Linter +public import Mathlib.Algebra.Order.Group.Unbundled.Abs +public import Mathlib.Data.NNReal.Defs +public import Mathlib.Data.NNReal.Basic +public import Mathlib.Data.Rat.Cast.Lemmas +public import Mathlib.Data.Rat.Cast.Order +public import Mathlib.Data.Real.Basic + +/-! +Basic shared definitions for the NFP rewrite. +-/ + +public section + +namespace Nfp + +/-- Nonnegative mass used for probabilities and weights. 
-/ +abbrev Mass := NNReal + +/-- Default precision placeholder retained for API compatibility. -/ +def defaultRatPrec : Int := 48 + +/-- Round a rational down (identity in the exact-rational refactor). -/ +def ratRoundDown (q : Rat) (_prec : Int := defaultRatPrec) : Rat := + q + +/-- Definitional characterization of `ratRoundDown`. -/ +theorem ratRoundDown_def (q : Rat) (prec : Int := defaultRatPrec) : + ratRoundDown q prec = q := by + rfl + +/-- Round a rational up (identity in the exact-rational refactor). -/ +def ratRoundUp (q : Rat) (_prec : Int := defaultRatPrec) : Rat := + q + +/-- Definitional characterization of `ratRoundUp`. -/ +theorem ratRoundUp_def (q : Rat) (prec : Int := defaultRatPrec) : + ratRoundUp q prec = q := by + rfl + +/-- Real cast of a rational value. -/ +def ratToReal (x : Rat) : Real := + (x : Real) + +/-- Definitional characterization of `ratToReal`. -/ +theorem ratToReal_def (x : Rat) : ratToReal x = (x : Real) := by + rfl + +@[simp] theorem ratToReal_zero : ratToReal 0 = 0 := by + simp [ratToReal] + +@[simp] theorem ratToReal_one : ratToReal 1 = 1 := by + simp [ratToReal] + +theorem ratRoundDown_le (q : Rat) (_prec : Int := defaultRatPrec) : + (ratRoundDown q _prec : Rat) ≤ q := by + simp [ratRoundDown] + +theorem ratRoundUp_ge (q : Rat) (_prec : Int := defaultRatPrec) : + q ≤ (ratRoundUp q _prec : Rat) := by + simp [ratRoundUp] + +theorem ratRoundDown_le_real (q : Rat) (_prec : Int := defaultRatPrec) : + (ratRoundDown q _prec : Real) ≤ (q : Real) := by + simp [ratRoundDown] + +theorem real_le_ratRoundUp (q : Rat) (_prec : Int := defaultRatPrec) : + (q : Real) ≤ (ratRoundUp q _prec : Real) := by + simp [ratRoundUp] + +theorem ratRoundDown_nonneg {q : Rat} (hq : 0 ≤ q) (_prec : Int := defaultRatPrec) : + 0 ≤ ratRoundDown q _prec := by + simpa [ratRoundDown] using hq + +theorem ratRoundUp_nonneg {q : Rat} (hq : 0 ≤ q) (_prec : Int := defaultRatPrec) : + 0 ≤ ratRoundUp q _prec := by + simpa [ratRoundUp] using hq + +theorem ratRoundUp_pos 
{q : Rat} (hq : 0 < q) (_prec : Int := defaultRatPrec) : + 0 < ratRoundUp q _prec := by + simpa [ratRoundUp] using hq + +/-- Interpret `n * 2^{-prec}` as a rational (matching power-of-two scaling). -/ +def ratOfIntWithPrec (n : Int) (prec : Int) : Rat := + let pow2 (k : Nat) : Rat := Rat.ofInt (Int.ofNat (Nat.pow 2 k)) + if _h : 0 ≤ prec then + Rat.ofInt n / pow2 (Int.toNat prec) + else + Rat.ofInt n * pow2 (Int.toNat (-prec)) + +/-- Rational division with downward rounding (exact for rationals). -/ +def ratDivDown (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := + if y = 0 then + 0 + else + x / y + +/-- Definitional characterization of `ratDivDown`. -/ +theorem ratDivDown_def (x y : Rat) (prec : Int := defaultRatPrec) : + ratDivDown x y prec = if y = 0 then 0 else x / y := by + rfl + +/-- Rational division with upward rounding (exact for rationals). -/ +def ratDivUp (x y : Rat) (_prec : Int := defaultRatPrec) : Rat := + if y = 0 then + 0 + else + x / y + +/-- Definitional characterization of `ratDivUp`. 
-/ +theorem ratDivUp_def (x y : Rat) (prec : Int := defaultRatPrec) : + ratDivUp x y prec = if y = 0 then 0 else x / y := by + rfl + +theorem ratDivUp_ge (x y : Rat) (hy : y ≠ 0) : + (x / y : Rat) ≤ (ratDivUp x y : Rat) := by + simp [ratDivUp, hy] + +theorem ratDivUp_ge_real (x y : Rat) (hy : y ≠ 0) : + (x : Real) / (y : Real) ≤ ratToReal (ratDivUp x y) := by + simp [ratDivUp, ratToReal, hy] + +@[simp] theorem ratToReal_add (x y : Rat) : + ratToReal (x + y) = ratToReal x + ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_sub (x y : Rat) : + ratToReal (x - y) = ratToReal x - ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_mul (x y : Rat) : + ratToReal (x * y) = ratToReal x * ratToReal y := by + simp [ratToReal] + +@[simp] theorem ratToReal_neg (x : Rat) : + ratToReal (-x) = -ratToReal x := by + simp [ratToReal] + +@[simp] theorem ratToReal_if {p : Prop} [Decidable p] (a b : Rat) : + ratToReal (if p then a else b) = + if p then ratToReal a else ratToReal b := by + by_cases hp : p <;> simp [hp, ratToReal] + +theorem ratToReal_le_iff {x y : Rat} : + ratToReal x ≤ ratToReal y ↔ x ≤ y := by + simp [ratToReal] + +/-- Rational order implies real order after casting. 
-/ +theorem ratToReal_le_of_le {x y : Rat} (h : x ≤ y) : + ratToReal x ≤ ratToReal y := by + simpa [ratToReal_le_iff] using h + +theorem ratToReal_lt_iff {x y : Rat} : + ratToReal x < ratToReal y ↔ x < y := by + simp [ratToReal] + +theorem ratToReal_nonneg_iff {x : Rat} : + 0 ≤ ratToReal x ↔ 0 ≤ x := by + simp [ratToReal] + +theorem ratToReal_nonneg_of_nonneg {x : Rat} (h : 0 ≤ x) : + 0 ≤ ratToReal x := by + simpa [ratToReal_nonneg_iff] using h + +theorem ratToReal_nonpos_iff {x : Rat} : + ratToReal x ≤ 0 ↔ x ≤ 0 := by + simp [ratToReal] + +@[simp] theorem ratToReal_abs (x : Rat) : + ratToReal |x| = |ratToReal x| := by + simp [ratToReal] + +theorem ratToReal_abs_le_of_le {x y : Rat} (h : |x| ≤ y) : + |ratToReal x| ≤ ratToReal y := by + simpa [ratToReal_abs] using ratToReal_le_of_le h + +@[simp] theorem ratToReal_max (x y : Rat) : + ratToReal (max x y) = max (ratToReal x) (ratToReal y) := by + simp [ratToReal] + +@[simp] theorem ratToReal_min (x y : Rat) : + ratToReal (min x y) = min (ratToReal x) (ratToReal y) := by + simp [ratToReal] + +end Nfp + +end diff --git a/Nfp/Discovery.lean b/Nfp/Discovery.lean deleted file mode 100644 index f5f005c..0000000 --- a/Nfp/Discovery.lean +++ /dev/null @@ -1,9875 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Batteries.Lean.Float - -/-! -# Executable Circuit Discovery for Induction Heads - -This module provides executable functions for discovering **certified induction heads** -from concrete model weights. It bridges the theoretical framework (Frobenius norms, -pattern terms, faithfulness bounds) with practical verification of real neural networks. - -## Key Components - -1. **Concrete Model Structures**: Computable representations of attention layers using - Arrays and Floats instead of abstract types. Suitable for export from PyTorch/JAX. - -2. **Efficient Bound Calculation**: Algorithms to compute `patternTerm` and `valueTerm` - bounds without materializing the full (N·D)² Jacobian matrix. - -3. 
**Discovery Functions**: Search algorithms that iterate over layer pairs to find - certified virtual heads (e.g., induction heads). - -## Mathematical Background - -For an attention layer with weights (W_Q, W_K, W_V, W_O), the Jacobian at input x -decomposes as: `fullJacobian = valueTerm + patternTerm` where: -- `valueTerm` depends only on attention weights A and projections W_V·W_O -- `patternTerm` captures how A shifts when input changes (the error term) - -The **faithfulness bound** states: if ‖patternTerm‖_F ≤ ε, then the simple -attention-based interpretation is ε-accurate. - -## Performance Optimizations - -This module contains critical hot paths for circuit discovery. Key optimizations: - -1. **Array pre-allocation**: Use `Array.ofFn` instead of repeated `push` operations - - Avoids O(n²) copying from array reallocations - - Memory usage: O(n) instead of O(n²) during construction - -2. **Direct loops over List.range.foldl**: Replace `(List.range n).foldl f acc` with - `for i in [:n] do ...` - eliminates intermediate list construction (10-100× faster) - -3. **Bounds-checked array access**: Use `array[i]!` which panics on out-of-bounds - instead of `getD` which silently returns default values - - Makes bugs explicit rather than silent - - Compiler can optimize bounds checks in loops - -4. **Matrix operations**: Pre-allocated `Array.ofFn` with `Id.run do` blocks for - complex computations (matmul, matVecMul, power iteration) - -**Benchmark impact** (GPT-2 Small, 12 layers × 12 heads): -- Matrix operations: 10-50× faster (direct loops vs List.range.foldl) -- Array construction: 2-5× faster (pre-allocation vs repeated push) -- Memory: 50% reduction (no intermediate copies) - -These optimizations make circuit discovery practical on real models (seconds instead -of minutes for full network analysis). --/ - -namespace Nfp - -/-! ## Concrete Weight Representations -/ - -/-- A concrete weight matrix stored as nested Arrays. 
-This is the computable representation for export from PyTorch/JAX. -/ -structure ConcreteMatrix where - /-- Number of rows -/ - numRows : Nat - /-- Number of columns -/ - numCols : Nat - /-- Row-major data storage. data[i * numCols + j] = entry (i, j) -/ - data : Array Float - /-- Data has the correct size -/ - size_eq : data.size = numRows * numCols - -namespace ConcreteMatrix - -/-- Access element (i, j) of the matrix. Returns 0 if out of bounds. - -PERFORMANCE: This is the guarded accessor; hot loops should prefer `getUnsafe` -once bounds are established to avoid the per-access branch. --/ -def get (M : ConcreteMatrix) (i j : Nat) : Float := - if i < M.numRows ∧ j < M.numCols then - -- Index is in-bounds by `size_eq` and the guard above. - M.data[i * M.numCols + j]! - else 0.0 - -/-- Fast access to element `(i, j)` assuming `i < numRows` and `j < numCols`. -/ -@[inline] def getUnsafe (M : ConcreteMatrix) (i j : Nat) : Float := - M.data[i * M.numCols + j]! - -/-- Set element `(i,j)` of the matrix. If out of bounds, returns the original matrix. - -PERFORMANCE: This is intended for small matrices (e.g. `headDim×headDim` Grams) where copying the -underlying array is cheap. Prefer bulk construction for large matrices. --/ -def set (M : ConcreteMatrix) (i j : Nat) (val : Float) : ConcreteMatrix := - if h : i < M.numRows ∧ j < M.numCols then - let idx := i * M.numCols + j - { numRows := M.numRows - numCols := M.numCols - data := M.data.set! idx val - size_eq := by simpa [idx] using M.size_eq } - else - M - -/-- Maximum absolute entry in a given row. Returns 0 if the row is out of bounds. -/ -def rowMaxAbs (M : ConcreteMatrix) (r : Nat) : Float := - if r < M.numRows then - Id.run do - let mut m : Float := 0.0 - let rowBase := r * M.numCols - for c in [:M.numCols] do - let a := Float.abs (M.data[rowBase + c]!) - if a > m then - m := a - return m - else - 0.0 - -/-- Take the first `n` rows of a matrix (keeping all columns). 
-/ -def takeRows (M : ConcreteMatrix) (n : Nat) : ConcreteMatrix := - if n ≥ M.numRows then - M - else - { numRows := n - numCols := M.numCols - data := .ofFn fun idx : Fin (n * M.numCols) => - -- Since `n < M.numRows`, `idx.val < n*numCols ≤ numRows*numCols = data.size`. - M.data.getD idx.val 0.0 - size_eq := Array.size_ofFn } - -/-- Create a zero matrix of given dimensions. -/ -def zeros (rows cols : Nat) : ConcreteMatrix where - numRows := rows - numCols := cols - data := .ofFn fun _ : Fin (rows * cols) => (0.0 : Float) - size_eq := Array.size_ofFn - -/-- Create an all-ones matrix of given dimensions. -/ -def ones (rows cols : Nat) : ConcreteMatrix where - numRows := rows - numCols := cols - data := .ofFn fun _ : Fin (rows * cols) => (1.0 : Float) - size_eq := Array.size_ofFn - -/-- Create an identity matrix. -/ -def identity (n : Nat) : ConcreteMatrix where - numRows := n - numCols := n - data := .ofFn fun idx : Fin (n * n) => - let i := idx.val / n - let j := idx.val % n - if i = j then 1.0 else 0.0 - size_eq := Array.size_ofFn - -/-- Matrix multiplication. - -PERFORMANCE CRITICAL: This is the hottest path in circuit discovery. -- Pre-allocates result array with `Array.ofFn` (no intermediate copies) -- Direct `for k in [:A.numCols]` instead of `List.range.foldl` (10-50× faster) -- Uses `Id.run do` to enable mutable accumulator in pure context -- Uses deterministic task-parallelism for very large products (preserving evaluation order - *within* each dot product). --/ -private def matmulSeqCore (A B : ConcreteMatrix) : ConcreteMatrix := - { - numRows := A.numRows - numCols := B.numCols - data := .ofFn fun idx : Fin (A.numRows * B.numCols) => Id.run do - let i := idx.val / B.numCols - let j := idx.val % B.numCols - let mut acc : Float := 0.0 - let aRowBase := i * A.numCols - for k in [:A.numCols] do - -- SAFETY: within this branch `i < A.numRows` and `k < A.numCols`, - -- and `A.size_eq` implies `aRowBase + k < A.data.size`. - let a := A.data[aRowBase + k]! 
- -- SAFETY: within this branch `k < B.numRows` and `j < B.numCols`, - -- and `B.size_eq` implies `k * B.numCols + j < B.data.size`. - let b := B.data[k * B.numCols + j]! - acc := acc + a * b - return acc - size_eq := Array.size_ofFn - } - -private def matmulParFlopThreshold : Nat := 10_000_000 -private def matmulParMaxTasks : Nat := 16 -private def matmulParMinInnerDim : Nat := 256 -private def matmulParMinOutCols : Nat := 256 -private def matmulParMinRows : Nat := 2 - -private def shouldUseMatmulPar (A B : ConcreteMatrix) : Bool := - let flops := A.numRows * B.numCols * A.numCols - flops ≥ matmulParFlopThreshold && - A.numCols ≥ matmulParMinInnerDim && - B.numCols ≥ matmulParMinOutCols && - A.numRows ≥ matmulParMinRows - -private def matmulPar (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = 0 || B.numCols = 0 then - matmulSeqCore A B - else - let numTasks := min matmulParMaxTasks A.numRows - let q := A.numRows / numTasks - let r := A.numRows % numTasks - - let tasks : Array (Task (Array Float)) := - .ofFn fun t : Fin numTasks => - Task.spawn (fun _ => - let tid := t.val - let extra := if tid < r then 1 else 0 - let rowsHere := q + extra - let startRow := tid * q + min tid r - let chunkSize := rowsHere * B.numCols - .ofFn fun idx : Fin chunkSize => Id.run do - let localRow := idx.val / B.numCols - let j := idx.val % B.numCols - let i := startRow + localRow - let mut acc : Float := 0.0 - let aRowBase := i * A.numCols - for k in [:A.numCols] do - -- SAFETY: `i < A.numRows` by chunk construction and `k < A.numCols` by loop bound. - let a := A.data[aRowBase + k]! - -- SAFETY: `k < B.numRows` and `j < B.numCols` by loop bounds and chunk indexing. - let b := B.data[k * B.numCols + j]! - acc := acc + a * b - return acc) - - -- Join in increasing task index order (deterministic). 
- let chunks := tasks.map Task.get - let cutoff := (q + 1) * r - - { - numRows := A.numRows - numCols := B.numCols - data := .ofFn fun idx : Fin (A.numRows * B.numCols) => - let row := idx.val / B.numCols - let col := idx.val % B.numCols - let taskIdx := - if row < cutoff then - row / (q + 1) - else - r + (row - cutoff) / q - let localRow := - if row < cutoff then - row % (q + 1) - else - (row - cutoff) % q - -- SAFETY: `taskIdx < numTasks` by construction; chunks are in task order. - let chunk := chunks[taskIdx]! - -- SAFETY: `localRow < rowsHere` for this chunk and `col < B.numCols`. - chunk[localRow * B.numCols + col]! - size_eq := Array.size_ofFn - } - -def matmul (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numCols = B.numRows then - if shouldUseMatmulPar A B then - matmulPar A B - else - matmulSeqCore A B - else zeros 0 0 - -/-- Compute Frobenius norm squared: Σᵢⱼ M[i,j]². -/ -def frobeniusNormSq (M : ConcreteMatrix) : Float := - M.data.foldl (fun acc x => acc + x * x) 0.0 - -/-- Compute Frobenius norm: √(Σᵢⱼ M[i,j]²). -/ -def frobeniusNorm (M : ConcreteMatrix) : Float := - Float.sqrt M.frobeniusNormSq - -/-- Compute `trace(A · B)` without allocating the product. - -This uses `trace(A·B) = ∑_{i,j} A[i,j] · B[j,i]`. -Returns 0.0 if the dimensions do not line up. --/ -def traceMul (A B : ConcreteMatrix) : Float := Id.run do - if A.numCols ≠ B.numRows then return 0.0 - if A.numRows ≠ B.numCols then return 0.0 - if A.numRows ≠ A.numCols then return 0.0 - let n := A.numRows - let mut acc : Float := 0.0 - for i in [:n] do - let aRowBase := i * A.numCols - for j in [:n] do - -- SAFETY: `i,j < n` and both matrices are `n×n`. - acc := acc + A.data[aRowBase + j]! * B.data[j * B.numCols + i]! - return acc - -/-! ### Float numerics (heuristics) - -The definitions in this section use `Float` arithmetic for speed. - -Important: these are **not** kernel-sound upper bounds in general. -They are best-effort numerical estimates (rounding may under- or over-estimate). 
-Sound certification lives in `Nfp.Sound.*`. --/ - -/-! #### Non-certified estimates - -The functions in this subsection are **not** mathematically certified upper bounds. -They may be useful for diagnostics, but must not feed into the repo's “rigorous” -error / ε pipelines. --/ - -/-- Heuristic operator-norm estimate via power iteration. - -The operator norm ‖M‖₂ = max‖x‖=1 ‖x·M‖ (row-vector convention) is the largest singular value. -We approximate it using power iteration on M^T M. - -This is a fast **heuristic estimate** of how much `M` can stretch a vector. - -PERFORMANCE: Power iteration is O(iterations × n²) but heavily optimized: -- Pre-allocated vectors with `Array.ofFn` (no array copying) -- Direct loops instead of `List.range.foldl` (10× faster) -- Bounds-checked access `v[j]!` and `Mv[i]!` (compiler optimizes in loops) --/ -def operatorNormHeuristicPI (M : ConcreteMatrix) (numIterations : Nat := 20) : Float := Id.run do - let numRows := M.numRows - let numCols := M.numCols - if numRows = 0 || numCols = 0 then return 0.0 - - -- Initialize with a vector of ones - let mut v : Array Float := .ofFn fun _ : Fin numCols => 1.0 - - -- Normalize initial vector - let initNorm := Float.sqrt (v.foldl (fun acc x => acc + x * x) 0.0) - if initNorm > 0.0 then - v := v.map (· / initNorm) - - -- Power iteration: v ← (M^T M) v / ‖(M^T M) v‖ - let mut sigma : Float := 0.0 - for _ in [:numIterations] do - -- Compute M v - let mut Mv : Array Float := .ofFn fun i : Fin numRows => Id.run do - let mut acc : Float := 0.0 - let rowBase := i.val * numCols - for j in [:numCols] do - -- SAFETY: v has size M.numCols, guaranteed by Array.ofFn - acc := acc + M.data[rowBase + j]! * v[j]! - return acc - - -- Compute M^T (M v) = (M^T M) v - let mut MTMv : Array Float := .ofFn fun j : Fin numCols => Id.run do - let mut acc : Float := 0.0 - let col := j.val - for i in [:numRows] do - -- SAFETY: Mv has size M.numRows, guaranteed by Array.ofFn above - acc := acc + M.data[i * numCols + col]! 
* Mv[i]! - return acc - - -- Compute norm of MTMv (this is σ² times ‖v‖, and ‖v‖ ≈ 1) - let normSq := MTMv.foldl (fun acc x => acc + x * x) 0.0 - let norm := Float.sqrt normSq - - if norm < 1e-15 then - return 0.0 - - -- σ² ≈ ‖MTMv‖ / ‖v‖ ≈ ‖MTMv‖ - -- So σ ≈ ‖Mv‖ - let MvNorm := Float.sqrt (Mv.foldl (fun acc x => acc + x * x) 0.0) - sigma := MvNorm - - -- Normalize for next iteration - v := MTMv.map (· / norm) - - -- Heuristic safety margin for numerical errors - sigma * 1.01 - - -/-! ### Provable eigenvalue upper bounds (PSD moments) - -The helper below is used to tighten bounds of the form `λ_max(G)` where `G` is -symmetric positive semidefinite (PSD), using only the first two spectral moments. - -Let `λ₁ ≥ λ₂ ≥ ... ≥ λₙ ≥ 0` be the eigenvalues of `G`. -Write -- `tr = Σᵢ λᵢ = trace(G)` -- `f2 = Σᵢ λᵢ² = ‖G‖_F²`. - -Among all nonnegative spectra with fixed `tr` and `f2`, the maximum possible `λ₁` -is achieved when the remaining `n-1` eigenvalues are equal. Solving that extremal -case yields the closed-form upper bound: - - λ₁ ≤ (tr + sqrt((n-1) * (n*f2 - tr^2))) / n. - -We defensively clamp the radicand to `≥ 0` to avoid negative values caused by -Float roundoff. --/ - -/-- PSD moment bound on the maximum eigenvalue from trace and Frobenius-squared. - -Inputs: -- `n`: matrix dimension -- `tr = trace(G)` -- `f2 = ‖G‖_F²` - -Output: a deterministic `Float` expression corresponding to the real bound above. --/ -def psdLambdaMaxUpperMoment (n : Nat) (tr f2 : Float) : Float := - if n = 0 then - 0.0 - else if n = 1 then - -- For 1×1 PSD matrices, λ_max = tr. - max 0.0 tr - else - let nF : Float := n.toFloat - let rad := max 0.0 ((n - 1).toFloat * (nF * f2 - tr * tr)) - let root := Float.sqrt rad - max 0.0 ((tr + root) / nF) - - -/-- Estimate maximum absolute row sum (induced ℓ1 for row-vector convention, -induced ℓ∞ for column-vector convention). - -Mathematically, the real-valued quantity `maxᵢ Σⱼ |M[i,j]|` is an induced matrix norm. 
-This function computes a `Float` approximation and is therefore a heuristic estimate. --/ -def maxAbsRowSumEst (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for i in [:M.numRows] do - let mut rowSum : Float := 0.0 - let rowBase := i * M.numCols - for j in [:M.numCols] do - rowSum := rowSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum rowSum - maxSum - - -/-- Estimate maximum absolute column sum (induced ℓ∞ for row-vector convention, -induced ℓ1 for column-vector convention). - -Mathematically, the real-valued quantity `maxⱼ Σᵢ |M[i,j]|` is an induced matrix norm. -This function computes a `Float` approximation and is therefore a heuristic estimate. --/ -def maxAbsColSumEst (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for j in [:M.numCols] do - let mut colSum : Float := 0.0 - for i in [:M.numRows] do - let rowBase := i * M.numCols - colSum := colSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum colSum - maxSum - - -/-- Rigorous (inequality-based) one/inf upper bound on `‖M‖₂`. - -In exact real arithmetic: -`‖M‖₂ ≤ sqrt(‖M‖₁ · ‖M‖∞)`. - -We compute max row/column sums from the stored Float entries; `‖·‖₁`/`‖·‖∞` -swap under transpose, so the bound is convention-agnostic. --/ -def opNormUpperBoundOneInf (M : ConcreteMatrix) : Float := - Float.sqrt (M.maxAbsRowSumEst * M.maxAbsColSumEst) - -/-- Heuristic Schur-type estimate: `sqrt(‖M‖₁ · ‖M‖∞)` computed in `Float`. - -In exact real arithmetic, `sqrt(‖M‖₁ · ‖M‖∞)` upper-bounds the spectral norm. -This implementation uses `Float`, so it should be treated as an estimate. --/ -def schurNormEst (M : ConcreteMatrix) : Float := - M.opNormUpperBoundOneInf - -/-- Cheap operator-norm upper bound formula for a concrete real matrix. - -In exact real arithmetic, we have the standard inequalities: -- `‖M‖₂ ≤ ‖M‖_F` -- `‖M‖₂ ≤ sqrt(‖M‖₁ · ‖M‖∞)` - -In exact real arithmetic, taking `min` can only tighten a valid upper bound. 
-Here we compute both in `Float`, so treat the result as an estimate. --/ -def opNormUpperBoundCheap (M : ConcreteMatrix) : Float := - let frob := M.frobeniusNorm - let schur := M.schurNormEst - min frob schur - -/-- Dense Frobenius upper bound on `‖M‖₂`: `‖M‖₂ ≤ ‖M‖_F`. - -This is cheap and always valid in exact real arithmetic. --/ -def opNormUpperBoundDenseFrob (M : ConcreteMatrix) : Float := - M.frobeniusNorm - -/-- Dense Schur upper bound on `‖M‖₂`: `‖M‖₂ ≤ sqrt(‖M‖₁‖M‖∞)`. - -This is cheap and always valid in exact real arithmetic. --/ -def opNormUpperBoundDenseSchur (M : ConcreteMatrix) : Float := - M.opNormUpperBoundOneInf - -/-- Induced `∞` norm with absolute values: `max_i Σ_j |M[i,j]|`. - -This is the standard induced matrix norm `‖M‖_∞`. -We compute it deterministically in `Float` and interpret the result as a real-number -expression over the concrete Float entries. --/ -def infNormAbs (M : ConcreteMatrix) : Float := Id.run do - let mut maxSum : Float := 0.0 - for i in [:M.numRows] do - let mut rowSum : Float := 0.0 - let rowBase := i * M.numCols - for j in [:M.numCols] do - rowSum := rowSum + Float.abs (M.data[rowBase + j]!) - maxSum := max maxSum rowSum - maxSum - -/-- Upper bound on the operator norm via the Gram matrix and the induced `∞` norm. - -For a real matrix `W` we have: -- `‖W‖₂² = λ_max(WᵀW)` -- `‖G‖₂ ≤ ‖G‖∞` for any real matrix `G` - -Therefore `‖W‖₂ ≤ sqrt(‖WᵀW‖∞)`. - -This computes the quantity `sqrt(max_i Σ_j |(WᵀW)[i,j]|)` **without allocating** `WᵀW`. - -Note: This is computed using `Float` arithmetic; we use it as a deterministic bound for the -matrix obtained by interpreting the Float entries as real numbers. - -PERFORMANCE: O(numRows * numCols^2). Intended for factor matrices with small `numCols` -(e.g. `modelDim×headDim`). 
--/ -def opNormUpperBoundViaGramInf (W : ConcreteMatrix) : Float := Id.run do - if W.numCols = 0 then - return 0.0 - - let mut maxRowSum : Float := 0.0 - for i in [:W.numCols] do - let mut rowSum : Float := 0.0 - for j in [:W.numCols] do - let mut gij : Float := 0.0 - for k in [:W.numRows] do - let rowBase := k * W.numCols - -- SAFETY: `k < numRows` and `i,j < numCols`, so indices are within `data.size`. - let wi := W.data[rowBase + i]! - let wj := W.data[rowBase + j]! - gij := gij + wi * wj - rowSum := rowSum + Float.abs gij - maxRowSum := max maxRowSum rowSum - - -- Guard against negative zero / NaNs propagating into sqrt. - Float.sqrt (max 0.0 maxRowSum) - -/-- Transpose a matrix. -/ -def transpose (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numCols - numCols := M.numRows - data := .ofFn fun idx : Fin (M.numCols * M.numRows) => - let i := idx.val / M.numRows - let j := idx.val % M.numRows - M.data[j * M.numCols + i]! - size_eq := Array.size_ofFn - - -/-- Diagnostics for the rectangular Gram-based operator-norm bound. - -If `usedGram=true`, then we formed the smaller Gram matrix `G` and bounded -`λ_max(G)` by several PSD inequalities, returning `sqrt(λ_upper)`. -Otherwise (size-capped), we fall back to cheap dense bounds. --/ -structure RectGramDiag where - usedGram : Bool - /-- Used the absolute-Gram fallback (no materialized Gram). -/ - usedAbsGram : Bool - /-- True if we computed the absolute-Gram bounds (even if they weren't chosen). -/ - computedAbsGram : Bool - /-- True if we computed a signed Gram matrix candidate (even if it wasn't chosen). -/ - computedGram : Bool - /-- True if we would have formed a signed Gram (by `maxGramDim`) but skipped it via a deterministic cost guard. -/ - skippedGram : Bool - /-- The `maxGramDim` cap from the `BoundEffort` used. -/ - maxGramDimCap : Nat - /-- Whether signed-Gram candidates were enabled in the `BoundEffort` used. 
-/ - signedGramEnabled : Bool - /-- Estimated cost for materializing Gram (`k*k*max(m,n)`). -/ - gramCost : Nat - /-- Deterministic Gram-cost limit used for scalability guard. -/ - gramCostLimit : Nat - gramDim : Nat - lambdaBrauer : Float - lambdaMoment : Float - lambdaGersh : Float - /-- Gershgorin upper bound computed without forming the Gram matrix. -/ - lambdaAbsGersh : Float - /-- Brauer/Cassini upper bound computed without forming the Gram matrix. -/ - lambdaAbsBrauer : Float - lambdaUsed : Float - opBound : Float - frobBound : Float - oneInfBound : Float - deriving Repr - -/-- How much work to spend computing rigorous upper bounds. - -Higher-effort modes must be **monotone**: they may add additional rigorous candidates and then -take the minimum, but they must never remove candidates. This guarantees that "upgrading" -cannot increase the returned upper bound (up to tiny Float noise). --/ -structure BoundEffort where - /-- Maximum Gram dimension allowed for materializing a (signed) Gram matrix. -/ - maxGramDim : Nat := 256 - /-- Deterministic cost guard for materializing a `k×k` signed Gram matrix, measured as - `k*k*max(m,n)` where `m×n` is the matrix size and `k = min(m,n)`. -/ - gramCostLimit : Nat := 200000000 - /-- Allow the absolute-Gram fallback (no materialized Gram). -/ - enableAbsGram : Bool := true - /-- Allow materializing a signed Gram matrix when `gramDim ≤ maxGramDim`. -/ - enableSignedGram : Bool := true - /-- Allow PSD moment bounds (when available). -/ - enableMoment : Bool := true - deriving Repr, Inhabited - -namespace BoundEffort - -/-- Tier-0: cheap bounds only. -/ -def tier0 : BoundEffort := - { maxGramDim := 0 - gramCostLimit := 0 - enableAbsGram := false - enableSignedGram := false - enableMoment := false } - -/-- Tier-1: enable abs-Gram fallback; keep default signed-Gram cap. 
-/ -def tier1 : BoundEffort := - { maxGramDim := 256 - gramCostLimit := 200000000 - enableAbsGram := true - enableSignedGram := true - enableMoment := true } - -/-- Tier-2: allow larger signed-Gram (may be expensive). -/ -def tier2 : BoundEffort := - { maxGramDim := 768 - gramCostLimit := 2000000000 - enableAbsGram := true - enableSignedGram := true - enableMoment := true } - -/-- Tier-3: placeholder for future rigorous tightenings (currently same as tier2). -/ -def tier3 : BoundEffort := tier2 - -/-- Ordered list of effort tiers (monotone increasing compute). -/ -def tiers : Array BoundEffort := #[tier0, tier1, tier2, tier3] - -end BoundEffort - -/-- Brauer/Cassini upper bound on `λ_max(G)` for an explicit symmetric matrix `G`. - -This is the same Cassini-ovals formula used for Gram matrices, but computed from -the materialized matrix `G` (intended for small `k×k`). - -Guardrails: -- Clamp discriminants to `≥ 0`. -- If NaN/Inf appears, fall back to the induced-∞ bound `‖G‖_∞`. --/ -def symmLambdaMaxUpperBrauer (G : ConcreteMatrix) : Float := Id.run do - let n := G.numRows - if n = 0 || G.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := G.data[i * n + i]! - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (G.data[rowBase + j]!) - rs := rs.push ri - let di := ds[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - infBound := max infBound (Float.abs di + ri) - - if bad then - return infBound - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! 
- let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - return max maxDiag maxPair - -/-- Gershgorin and Brauer/Cassini upper bounds for the reduced Gram matrix, without forming it. - -Let `A` be `m×n`. We reduce to `G = A Aᵀ` if `m ≤ n` and `G = Aᵀ A` otherwise, so -`‖A‖₂² = λ_max(G)` (exact real arithmetic). - -We avoid forming `G` by bounding absolute row sums via: -- If `m > n` (`G = AᵀA`): use `s_row[k] = Σ_j |A[k,j]|` and - `Σ_j |G[i,j]| ≤ Σ_k |A[k,i]| * s_row[k]`. -- If `m ≤ n` (`G = AAᵀ`): use `s_col[t] = Σ_i |A[i,t]|` and - `Σ_j |G[i,j]| ≤ Σ_t |A[i,t]| * s_col[t]`. - -The first component is the Gershgorin/∞ bound `max_i rowSumUpper[i]`. -The second is the Brauer/Cassini bound computed from `diag[i]` and -`offSumUpper[i] = max(0, rowSumUpper[i] - diag[i])`. - -If NaN/Inf appears in the Brauer calculation, we conservatively fall back to the -Gershgorin bound. --/ -def rectAbsGramLambdaUpperGershBrauer (A : ConcreteMatrix) : Float × Float := Id.run do - let m := A.numRows - let n := A.numCols - if m = 0 || n = 0 then - return (0.0, 0.0) - - if m > n then - -- `G = AᵀA` (size `n×n`), using row absolute sums of `A`. - let mut sRow : Array Float := Array.replicate m 0.0 - for k in [:m] do - let mut acc : Float := 0.0 - let rowBase := k * n - for j in [:n] do - acc := acc + Float.abs (A.data[rowBase + j]!) - sRow := sRow.set! k acc - - let mut diag : Array Float := Array.replicate n 0.0 - let mut rowSumUpper : Array Float := Array.replicate n 0.0 - for k in [:m] do - let s := sRow[k]! - let rowBase := k * n - for i in [:n] do - let a := A.data[rowBase + i]! - diag := diag.set! i (diag[i]! + a * a) - rowSumUpper := rowSumUpper.set! i (rowSumUpper[i]! 
+ Float.abs a * s) - - let mut lambdaAbsGersh : Float := 0.0 - for i in [:n] do - lambdaAbsGersh := max lambdaAbsGersh rowSumUpper[i]! - - if n < 2 then - return (lambdaAbsGersh, diag[0]!) - - let mut maxPair : Float := 0.0 - let mut bad : Bool := false - for i in [:n] do - let di := diag[i]! - let ri := max 0.0 (rowSumUpper[i]! - di) - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - for j in [i + 1:n] do - let dj := diag[j]! - let rj := max 0.0 (rowSumUpper[j]! - dj) - if Float.isNaN dj || Float.isInf dj || Float.isNaN rj || Float.isInf rj then - bad := true - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - let lambdaAbsBrauer := if bad then lambdaAbsGersh else maxPair - return (lambdaAbsGersh, lambdaAbsBrauer) - else - -- `G = AAᵀ` (size `m×m`), using column absolute sums of `A`. - let mut sCol : Array Float := Array.replicate n 0.0 - for i in [:m] do - let rowBase := i * n - for t in [:n] do - let a := A.data[rowBase + t]! - sCol := sCol.set! t (sCol[t]! + Float.abs a) - - let mut diag : Array Float := Array.replicate m 0.0 - let mut rowSumUpper : Array Float := Array.replicate m 0.0 - for i in [:m] do - let mut di : Float := 0.0 - let mut ru : Float := 0.0 - let rowBase := i * n - for t in [:n] do - let a := A.data[rowBase + t]! - di := di + a * a - ru := ru + Float.abs a * sCol[t]! - diag := diag.set! i di - rowSumUpper := rowSumUpper.set! i ru - - let mut lambdaAbsGersh : Float := 0.0 - for i in [:m] do - lambdaAbsGersh := max lambdaAbsGersh rowSumUpper[i]! - - if m < 2 then - return (lambdaAbsGersh, diag[0]!) - - let mut maxPair : Float := 0.0 - let mut bad : Bool := false - for i in [:m] do - let di := diag[i]! - let ri := max 0.0 (rowSumUpper[i]! 
- di) - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - for j in [i + 1:m] do - let dj := diag[j]! - let rj := max 0.0 (rowSumUpper[j]! - dj) - if Float.isNaN dj || Float.isInf dj || Float.isNaN rj || Float.isInf rj then - bad := true - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - let lambdaAbsBrauer := if bad then lambdaAbsGersh else maxPair - return (lambdaAbsGersh, lambdaAbsBrauer) - -/-- Effort-configurable variant of `opNormUpperBoundRectGramDiag`. - -This is the entrypoint used by the adaptive scheduler: it can add expensive-but-rigorous -candidates while remaining monotone (the final `opBound` is always the minimum of enabled -rigorous candidates). --/ -def opNormUpperBoundRectGramDiagEffort (A : ConcreteMatrix) (effort : BoundEffort) : RectGramDiag := - let frob := A.frobeniusNorm - let oneInf := A.opNormUpperBoundOneInf - let cheap := min frob oneInf - let m := A.numRows - let n := A.numCols - let k := min m n - let costLimit : Nat := effort.gramCostLimit - let cost : Nat := k * k * (max m n) - - if k = 0 then - { usedGram := true - usedAbsGram := false - computedAbsGram := false - computedGram := false - skippedGram := false - maxGramDimCap := effort.maxGramDim - signedGramEnabled := effort.enableSignedGram - gramCost := 0 - gramCostLimit := costLimit - gramDim := 0 - lambdaBrauer := 0.0 - lambdaMoment := 0.0 - lambdaGersh := 0.0 - lambdaAbsGersh := 0.0 - lambdaAbsBrauer := 0.0 - lambdaUsed := 0.0 - opBound := 0.0 - frobBound := frob - oneInfBound := oneInf } - else Id.run do - let mut bestOp : Float := cheap - let mut usedAbs : Bool := false - let mut usedGram : Bool := false - let mut computedAbsGram : Bool := false - let mut computedGram : Bool := false - let mut skippedGram : Bool := false - let mut lambdaAbsGersh : Float 
:= 0.0 - let mut lambdaAbsBrauer : Float := 0.0 - let mut lambdaGersh : Float := 0.0 - let mut lambdaBrauer : Float := 0.0 - let mut lambdaMoment : Float := 0.0 - let mut lambdaUsed : Float := max 0.0 (cheap * cheap) - - if effort.enableAbsGram then - computedAbsGram := true - let (lG, lB) := rectAbsGramLambdaUpperGershBrauer A - lambdaAbsGersh := lG - lambdaAbsBrauer := lB - let lUpper := min lG lB - let opAbsRaw := Float.sqrt (max 0.0 lUpper) - let opAbs : Float := - if Float.isNaN opAbsRaw || Float.isInf opAbsRaw then - Float.inf - else - opAbsRaw - if opAbs < bestOp then - bestOp := opAbs - lambdaUsed := max 0.0 lUpper - usedAbs := true - usedGram := false - - if effort.enableSignedGram && k ≤ effort.maxGramDim then - -- Deterministic scalability guard: forming `k×k` Gram matrices can be prohibitively expensive - -- for large rectangular matrices (e.g. 768×3072). We therefore skip this candidate when the - -- estimated matmul cost is too high. - -- - -- Correctness is preserved: skipping a candidate can only loosen the final bound. - let allowGram : Bool := cost ≤ costLimit - if allowGram then - computedGram := true - let G : ConcreteMatrix := - if m ≤ n then - A.matmul A.transpose - else - A.transpose.matmul A - lambdaGersh := G.infNormAbs - lambdaBrauer := symmLambdaMaxUpperBrauer G - let mut lUpper : Float := min lambdaGersh lambdaBrauer - if effort.enableMoment then - -- For Gram `G`, `trace(G) = ‖A‖_F²`. 
- let tr : Float := A.frobeniusNormSq - let f2 : Float := G.frobeniusNormSq - lambdaMoment := psdLambdaMaxUpperMoment k tr f2 - lUpper := min lUpper lambdaMoment - else - lambdaMoment := 0.0 - let op := Float.sqrt (max 0.0 lUpper) - if op < bestOp then - bestOp := op - lambdaUsed := max 0.0 lUpper - usedAbs := false - usedGram := true - else - skippedGram := true - - { usedGram := usedGram - usedAbsGram := usedAbs - computedAbsGram := computedAbsGram - computedGram := computedGram - skippedGram := skippedGram - maxGramDimCap := effort.maxGramDim - signedGramEnabled := effort.enableSignedGram - gramCost := cost - gramCostLimit := costLimit - gramDim := k - lambdaBrauer := lambdaBrauer - lambdaMoment := lambdaMoment - lambdaGersh := lambdaGersh - lambdaAbsGersh := lambdaAbsGersh - lambdaAbsBrauer := lambdaAbsBrauer - lambdaUsed := lambdaUsed - opBound := min cheap bestOp - frobBound := frob - oneInfBound := oneInf } - -/-- Backwards-compatible wrapper around `opNormUpperBoundRectGramDiagEffort`. -/ -def opNormUpperBoundRectGramDiag (A : ConcreteMatrix) (maxGramDim : Nat := 256) : RectGramDiag := - opNormUpperBoundRectGramDiagEffort A { maxGramDim := maxGramDim } - -/-- Rectangular Gram-based operator-norm bound (with a size-cap fallback). -/ -def opNormUpperBoundRectGram (A : ConcreteMatrix) (maxGramDim : Nat := 256) : Float := - (A.opNormUpperBoundRectGramDiag maxGramDim).opBound - -/-- Effort-configurable variant of `opNormUpperBoundRectGram`. -/ -def opNormUpperBoundRectGramEffort (A : ConcreteMatrix) (effort : BoundEffort) : Float := - (A.opNormUpperBoundRectGramDiagEffort effort).opBound - -/-- Gram-matrix induced-∞ upper bound on the spectral norm. - -In exact real arithmetic: -`‖M‖₂² = λ_max(MᵀM) ≤ ‖MᵀM‖_∞`, hence `‖M‖₂ ≤ sqrt(‖MᵀM‖_∞)`. - -This allocates `MᵀM`, so it is intended for small matrices. -If the computation produces NaN/Inf, we conservatively fall back to `‖M‖_F`. 
--/ -def gramInfOpBound (M : ConcreteMatrix) : Float := - if M.numRows = 0 || M.numCols = 0 then - 0.0 - else - let g := M.transpose.matmul M - let v := Float.sqrt (max 0.0 g.infNormAbs) - if Float.isNaN v || Float.isInf v then - M.frobeniusNorm - else - v - -/-- Compute an entry of the Gram matrix `MᵀM` without materializing it. - -`gramMatrixEntry M i j = (MᵀM)[i,j] = Σ_k M[k,i] * M[k,j]`. - -This is intended for small `numCols` (e.g. `headDim=64`). --/ -def gramMatrixEntry (M : ConcreteMatrix) (i j : Nat) : Float := Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - if i ≥ M.numCols || j ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for k in [:M.numRows] do - let rowBase := k * M.numCols - let mi := M.data[rowBase + i]! - let mj := M.data[rowBase + j]! - acc := acc + mi * mj - return acc - -/-- Gram diagonal entry `d_i = (MᵀM)[i,i] = Σ_k M[k,i]^2`. - -For real `M`, these are nonnegative in exact arithmetic. --/ -def gramDiag (M : ConcreteMatrix) (i : Nat) : Float := Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - if i ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for k in [:M.numRows] do - let rowBase := k * M.numCols - let mi := M.data[rowBase + i]! - acc := acc + mi * mi - return acc - -/-- Off-diagonal absolute row sum of the Gram matrix. - -Let `G = MᵀM`. This computes -`R_i = Σ_{j ≠ i} |G[i,j]|`. --/ -def gramRowAbsSumExclDiag (M : ConcreteMatrix) (i : Nat) : Float := Id.run do - if M.numCols = 0 || i ≥ M.numCols then - return 0.0 - let mut acc : Float := 0.0 - for j in [:M.numCols] do - if j = i then - continue - acc := acc + Float.abs (M.gramMatrixEntry i j) - return acc - -/-- Frobenius norm squared of the Gram matrix `G = MᵀM`, computed without allocating `G`. - -This returns `‖G‖_F² = Σ_{i,j} G[i,j]^2`. -Intended for small `numCols` (e.g. `headDim=64`). 
--/ -def gramFrobeniusNormSq (M : ConcreteMatrix) : Float := Id.run do - let n := M.numCols - if n = 0 then - return 0.0 - let mut acc : Float := 0.0 - for i in [:n] do - for j in [:n] do - let gij := M.gramMatrixEntry i j - acc := acc + gij * gij - return acc - -/-- Brauer/Cassini upper bound on `λ_max(G)` for a Gram matrix `G = MᵀM`. - -Mathematical facts (exact real arithmetic): - -Let `G` be real symmetric (Gram matrices are symmetric PSD). Define: -- `d_i = G[i,i]` -- `R_i = Σ_{j≠i} |G[i,j]|` - -Brauer bound (Cassini ovals): -`λ_max(G) ≤ max_{i≠j} b_ij`, where -`b_ij = (d_i + d_j + sqrt((d_i - d_j)^2 + 4*R_i*R_j)) / 2`. - -We also have the induced-∞ / Gershgorin bound: -`λ_max(G) ≤ max_i (d_i + R_i) = ‖G‖_∞`. - -Guardrails: -- Clamp the discriminant inside `sqrt` to `≥ 0`. -- If any NaN/Inf appears, conservatively fall back to `‖G‖_∞`. -- If `n < 2`, return `max_i d_i`. --/ -def gramLambdaMaxUpperBrauer (M : ConcreteMatrix) : Float := Id.run do - let n := M.numCols - if n = 0 then - return 0.0 - - let ds : Array Float := .ofFn fun i : Fin n => M.gramDiag i.val - let rs : Array Float := .ofFn fun i : Fin n => M.gramRowAbsSumExclDiag i.val - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - maxDiag := max maxDiag di - infBound := max infBound (Float.abs di + ri) - - if bad then - return infBound - - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! 
- let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - -- Brauer/Cassini bound candidate - let lambdaBrauerUpper := max maxDiag maxPair - -- Moment bound candidate for PSD `G = MᵀM`. - -- Here `trace(G) = ‖M‖_F²`. - let tr : Float := M.frobeniusNormSq - let f2 : Float := M.gramFrobeniusNormSq - let lambdaMomentUpper := psdLambdaMaxUpperMoment n tr f2 - -- Combine cheap valid upper bounds by taking `min`. - return min infBound (min lambdaBrauerUpper lambdaMomentUpper) - -/-- Brauer/Cassini upper bound on the spectral radius for a general square matrix. - -For any eigenvalue `λ` there exist `i ≠ j` with -`|λ - a_ii| · |λ - a_jj| ≤ R_i R_j`, where `R_i` is the off-diagonal row sum. -By the reverse triangle inequality, this yields a real upper bound on `|λ|`. --/ -def lambdaMaxUpperBrauer (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut infBound : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := Float.abs (M.data[i * n + i]!) - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) - rs := rs.push ri - let di := ds[i]! - if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - infBound := max infBound (di + ri) - - if bad then - return infBound - if n < 2 then - return maxDiag - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! 
- let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return infBound - else - return max maxDiag maxPair - -/-- Gershgorin upper bound on `λ_max` for matrices with nonnegative real eigenvalues. - -Uses `λ ≤ max_i (a_ii + Σ_{j≠i} |a_ij|)` (no absolute on the diagonal). --/ -def lambdaMaxUpperGershNonneg (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - let mut maxBound : Float := 0.0 - for i in [:n] do - let rowBase := i * n - let di := M.data[rowBase + i]! - let mut ri : Float := 0.0 - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) - let bound := max 0.0 (di + ri) - if Float.isNaN bound || Float.isInf bound then - return M.infNormAbs - if bound > maxBound then - maxBound := bound - return maxBound - -/-- Brauer/Cassini upper bound on `λ_max` for matrices with nonnegative real eigenvalues. - -This mirrors `lambdaMaxUpperBrauer` but keeps signed diagonals. Since the eigenvalues are -assumed nonnegative, using `a_ii` (not `|a_ii|`) can tighten the bound. --/ -def lambdaMaxUpperBrauerNonneg (M : ConcreteMatrix) : Float := Id.run do - let n := M.numRows - if n = 0 || M.numCols ≠ n then - return 0.0 - - let mut maxDiag : Float := 0.0 - let mut gersh : Float := 0.0 - let mut bad : Bool := false - let mut ds : Array Float := Array.mkEmpty n - let mut rs : Array Float := Array.mkEmpty n - - for i in [:n] do - let di := M.data[i * n + i]! - ds := ds.push di - maxDiag := max maxDiag di - - for i in [:n] do - let mut ri : Float := 0.0 - let rowBase := i * n - for j in [:n] do - if j = i then - continue - ri := ri + Float.abs (M.data[rowBase + j]!) - rs := rs.push ri - let di := ds[i]! 
- if Float.isNaN di || Float.isInf di || Float.isNaN ri || Float.isInf ri then - bad := true - gersh := max gersh (max 0.0 (di + ri)) - - if bad then - return gersh - if n < 2 then - return max gersh (max 0.0 maxDiag) - - let mut maxPair : Float := 0.0 - for i in [:n] do - let di := ds[i]! - let ri := rs[i]! - for j in [i + 1:n] do - let dj := ds[j]! - let rj := rs[j]! - let delta := di - dj - let disc := max 0.0 (delta * delta + 4.0 * ri * rj) - let root := Float.sqrt disc - let bij := (di + dj + root) / 2.0 - if Float.isNaN bij || Float.isInf bij then - bad := true - maxPair := max maxPair bij - - if bad then - return gersh - else - return max gersh (max 0.0 maxPair) - -/-- Dense (small) spectral-norm upper bound using the Brauer/Cassini Gram bound. - -`‖M‖₂² = λ_max(MᵀM) ≤ gramLambdaMaxUpperBrauer(M)`. --/ -def opNormUpperBoundDenseBrauer (M : ConcreteMatrix) : Float := - Float.sqrt (max 0.0 (M.gramLambdaMaxUpperBrauer)) - -/-- Matrix addition. Returns zero matrix if dimensions don't match. -/ -def add (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { - numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data.getD idx.val 0.0 + B.data.getD idx.val 0.0 - size_eq := Array.size_ofFn - } - else zeros 0 0 - -/-- Sum a list of same-shape matrices in a single pass. - -This avoids repeated `add` allocations (which would traverse the matrix multiple times). -When `allowParallel=true` and the matrices are sufficiently large, it uses row-chunk parallelism. --/ -def sumMatrices (ms : Array ConcreteMatrix) (allowParallel : Bool := true) : ConcreteMatrix := - if ms.isEmpty then - zeros 0 0 - else - letI : Inhabited ConcreteMatrix := ⟨zeros 0 0⟩ - let base := ms[0]! 
- let rows := base.numRows - let cols := base.numCols - let okDims := - ms.all fun M => M.numRows = rows ∧ M.numCols = cols - if !okDims then - zeros 0 0 - else - let entries := rows * cols - let shouldPar : Bool := - allowParallel && - entries ≥ 200_000 && - rows ≥ 2 && - cols ≥ 256 && - ms.size ≥ 4 - if !shouldPar then - { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin entries => Id.run do - let mut acc : Float := 0.0 - for M in ms do - acc := acc + M.data[idx.val]! - return acc - size_eq := Array.size_ofFn - } - else - let numTasks : Nat := min 16 rows - let q := rows / numTasks - let r := rows % numTasks - let tasks : Array (Task (Array Float)) := - .ofFn fun t : Fin numTasks => - Task.spawn (fun _ => - let tid := t.val - let extra := if tid < r then 1 else 0 - let rowsHere := q + extra - let startRow := tid * q + min tid r - let chunkSize := rowsHere * cols - .ofFn fun idx : Fin chunkSize => Id.run do - let localRow := idx.val / cols - let j := idx.val % cols - let i := startRow + localRow - let globalIdx := i * cols + j - let mut acc : Float := 0.0 - for M in ms do - acc := acc + M.data[globalIdx]! - return acc) - -- Join in increasing task index order (deterministic). - let chunks := tasks.map Task.get - let cutoff := (q + 1) * r - { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin entries => - let row := idx.val / cols - let col := idx.val % cols - let taskIdx := - if row < cutoff then - row / (q + 1) - else - r + (row - cutoff) / q - let localRow := - if row < cutoff then - row % (q + 1) - else - (row - cutoff) % q - let chunk := chunks[taskIdx]! - chunk[localRow * cols + col]! - size_eq := Array.size_ofFn - } - -/-- Scalar multiplication. -/ -def scale (c : Float) (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numRows - numCols := M.numCols - data := M.data.map (c * ·) - size_eq := by simp [M.size_eq] - -/-- Get row i as a 1×numCols matrix. 
-/ -def getRow (M : ConcreteMatrix) (i : Nat) : ConcreteMatrix := - if i < M.numRows then - { - numRows := 1 - numCols := M.numCols - data := .ofFn fun j : Fin M.numCols => M.getUnsafe i j.val - size_eq := by simp - } - else zeros 1 M.numCols - -/-- Set row i from a 1×numCols matrix. Returns original if dimensions wrong. -/ -def setRow (M : ConcreteMatrix) (i : Nat) (row : ConcreteMatrix) : ConcreteMatrix := - if i < M.numRows ∧ row.numRows = 1 ∧ row.numCols = M.numCols then - { - numRows := M.numRows - numCols := M.numCols - data := .ofFn fun idx : Fin (M.numRows * M.numCols) => - let r := idx.val / M.numCols - let c := idx.val % M.numCols - if r = i then row.getUnsafe 0 c else M.getUnsafe r c - size_eq := Array.size_ofFn - } - else M - -/-- Element-wise application of a function. -/ -def map (f : Float → Float) (M : ConcreteMatrix) : ConcreteMatrix where - numRows := M.numRows - numCols := M.numCols - data := M.data.map f - size_eq := by simp [M.size_eq] - -/-- Broadcast add: add a 1×numCols bias to each row. -/ -def addBias (M : ConcreteMatrix) (bias : ConcreteMatrix) : ConcreteMatrix := - if bias.numRows = 1 ∧ bias.numCols = M.numCols then - { - numRows := M.numRows - numCols := M.numCols - data := .ofFn fun idx : Fin (M.numRows * M.numCols) => - let c := idx.val % M.numCols - M.data[idx.val]! + bias.data[c]! - size_eq := Array.size_ofFn - } - else M - -/-- Row-wise LayerNorm with learnable scale γ and bias β (both 1×numCols). - -This implements the Pre-LN transformer normalization convention: each token (row) -is normalized across model dimension (columns), then scaled and shifted. 
--/ -def layerNormRowwise (X γ β : ConcreteMatrix) (eps : Float := 1e-5) : ConcreteMatrix := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return ConcreteMatrix.zeros rows cols - if !(γ.numRows = 1 ∧ γ.numCols = cols ∧ β.numRows = 1 ∧ β.numCols = cols) then - return X - let colsF := cols.toFloat - let gammaData := γ.data - let betaData := β.data - - -- Per-row mean and inverse stddev (compute once for speed). - let mut means : Array Float := Array.mkEmpty rows - let mut invStds : Array Float := Array.mkEmpty rows - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - let mut varSum : Float := 0.0 - for c in [:cols] do - let d := X.data[rowBase + c]! - μ - varSum := varSum + d * d - -- In exact arithmetic, `var ≥ 0`. Clamp to avoid NaN from tiny negative float noise. - let var := max 0.0 (varSum / colsF) - let invσ := 1.0 / Float.sqrt (var + eps) - means := means.push μ - invStds := invStds.push invσ - - return { - numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let μ := means.getD r 0.0 - let invσ := invStds.getD r 0.0 - let normalized := (X.data[r * cols + c]! - μ) * invσ - (gammaData[c]!) * normalized + (betaData[c]!) - size_eq := Array.size_ofFn - } - -/-- Heuristic estimate for the operator norm of the Jacobian of row-wise LayerNorm. - -We bound the **global** operator norm on the block-diagonal Jacobian (one block per row) -by the maximum over rows of a **tight spectral-norm bound**. - -For a single row `x : ℝ^d` (ignoring β), LayerNorm is: -`LN(x) = γ ⊙ ((x - μ) / σ)` with `σ = sqrt(var + eps)`. -Its Jacobian has the closed form: - -`J = diag(γ) * (1/σ) * (I - (1/d)11ᵀ - (1/d)vvᵀ)` - -where `v` is the centered vector scaled by `1/σ`. 
The symmetric matrix in parentheses -has eigenvalues `{0, 1, eps/(var+eps)}` so its spectral norm is exactly `1`. -Therefore `‖J‖₂ ≤ max |γ| / σ` in exact real arithmetic. - -This avoids the previous row-sum bound which could overestimate by orders of magnitude -and made downstream certification thresholds unusable. --/ -def layerNormRowwiseOpEst (X γ : ConcreteMatrix) (eps : Float := 1e-5) : Float := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then return 0.0 - if !(γ.numRows = 1 ∧ γ.numCols = cols) then return 0.0 - let colsF := cols.toFloat - let gammaData := γ.data - - -- max |γ| - let mut gammaMaxAbs : Float := 0.0 - for c in [:cols] do - let g := Float.abs (gammaData[c]!) - if Float.isNaN g || Float.isInf g then - gammaMaxAbs := Float.inf - else if g > gammaMaxAbs then - gammaMaxAbs := g - - -- max_r (1/σ_r) - let mut maxInvStd : Float := 0.0 - for r in [:rows] do - -- Mean - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - - -- Variance - let mut varSum : Float := 0.0 - for c in [:cols] do - let centered := X.data[rowBase + c]! - μ - varSum := varSum + centered * centered - let varRaw := varSum / colsF - -- In exact arithmetic, `var ≥ 0`. If we see NaN/Inf, conservatively treat as 0 - -- so that `1/σ ≤ 1/sqrt(eps)`. 
- let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - let σ := Float.sqrt (var + eps) - if σ > 0.0 && !(Float.isNaN σ || Float.isInf σ) then - let invσ := 1.0 / σ - if invσ > maxInvStd then maxInvStd := invσ - - if maxInvStd ≤ 0.0 || Float.isNaN maxInvStd || Float.isInf maxInvStd then - maxInvStd := 1.0 / Float.sqrt eps - else - let invStdMax := 1.0 / Float.sqrt eps - if maxInvStd > invStdMax then - maxInvStd := invStdMax - - return gammaMaxAbs * maxInvStd - -structure LayerNormOpDiag where - gammaMaxAbs : Float - maxInvStd : Float - maxInvStdRow : Nat - minVar : Float - minVarRow : Nat - deriving Repr - -/-- Diagnostics for `layerNormRowwiseOpEst`: reports `max |γ|`, the worst-case row inverse-std, -and the minimum variance row (useful for spotting padding/degenerate rows). -/ -def layerNormRowwiseOpDiag (X γ : ConcreteMatrix) (eps : Float := 1e-5) : LayerNormOpDiag := Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return { gammaMaxAbs := 0.0, maxInvStd := 0.0, maxInvStdRow := 0, minVar := 0.0, minVarRow := 0 } - if !(γ.numRows = 1 ∧ γ.numCols = cols) then - return { gammaMaxAbs := 0.0, maxInvStd := 0.0, maxInvStdRow := 0, minVar := 0.0, minVarRow := 0 } - let colsF := cols.toFloat - let gammaData := γ.data - - let mut gammaMaxAbs : Float := 0.0 - for c in [:cols] do - let g := Float.abs (gammaData[c]!) - if Float.isNaN g || Float.isInf g then - gammaMaxAbs := Float.inf - else if g > gammaMaxAbs then - gammaMaxAbs := g - - let mut maxInvStd : Float := 0.0 - let mut maxInvStdRow : Nat := 0 - let mut minVar : Float := Float.inf - let mut minVarRow : Nat := 0 - - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / colsF - let mut varSum : Float := 0.0 - for c in [:cols] do - let centered := X.data[rowBase + c]! 
- μ - varSum := varSum + centered * centered - let varRaw := varSum / colsF - let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - if var < minVar then - minVar := var - minVarRow := r - let σ := Float.sqrt (var + eps) - if σ > 0.0 then - let invσ := 1.0 / σ - if invσ > maxInvStd then - maxInvStd := invσ - maxInvStdRow := r - - if Float.isInf minVar || Float.isNaN minVar then - minVar := 0.0 - if maxInvStd ≤ 0.0 || Float.isNaN maxInvStd || Float.isInf maxInvStd then - maxInvStd := 1.0 / Float.sqrt eps - - return { - gammaMaxAbs := gammaMaxAbs - maxInvStd := maxInvStd - maxInvStdRow := maxInvStdRow - minVar := minVar - minVarRow := minVarRow - } - -/-- Get column j as a vector (stored as numRows×1 matrix). -/ -def getCol (M : ConcreteMatrix) (j : Nat) : ConcreteMatrix := - if j < M.numCols then - { - numRows := M.numRows - numCols := 1 - data := .ofFn fun i : Fin M.numRows => M.getUnsafe i.val j - size_eq := by simp - } - else zeros M.numRows 1 - -/-- Compute matrix-vector product M * v where v is stored as numCols×1 matrix. - Returns a numRows×1 matrix. -/ -def matVecMul (M : ConcreteMatrix) (v : ConcreteMatrix) : ConcreteMatrix := - if M.numCols = v.numRows ∧ v.numCols = 1 then - { - numRows := M.numRows - numCols := 1 - data := .ofFn fun i : Fin M.numRows => Id.run do - let mut acc : Float := 0.0 - let rowBase := i.val * M.numCols - for k in [:M.numCols] do - acc := acc + M.data[rowBase + k]! * v.data[k]! - return acc - size_eq := by simp - } - else zeros M.numRows 1 - -/-- Compute dot product of two vectors (stored as n×1 matrices). -/ -def dot (v1 v2 : ConcreteMatrix) : Float := - if v1.numRows = v2.numRows ∧ v1.numCols = 1 ∧ v2.numCols = 1 then Id.run do - let mut acc : Float := 0.0 - for i in [:v1.numRows] do - acc := acc + v1.data[i]! * v2.data[i]! - return acc - else 0.0 - -/-- Compute L2 norm of a vector (stored as n×1 matrix). 
-/ -def vecNorm (v : ConcreteMatrix) : Float := - if v.numCols = 1 then - Float.sqrt (v.data.foldl (fun acc x => acc + x * x) 0.0) - else 0.0 - -/-- Vector subtraction for n×1 matrices. -/ -def vecSub (v1 v2 : ConcreteMatrix) : ConcreteMatrix := - if v1.numRows = v2.numRows ∧ v1.numCols = 1 ∧ v2.numCols = 1 then - { - numRows := v1.numRows - numCols := 1 - data := .ofFn fun i : Fin v1.numRows => v1.data[i.val]! - v2.data[i.val]! - size_eq := by simp - } - else zeros v1.numRows 1 - -end ConcreteMatrix - -/-! ## Concrete LayerNorm Parameters -/ - -/-- Concrete LayerNorm parameters for Pre-LN transformers (scale γ and bias β). -/ -structure ConcreteLayerNormParams where - /-- Scale γ (1×modelDim) -/ - gamma : ConcreteMatrix - /-- Bias β (1×modelDim) -/ - beta : ConcreteMatrix - -namespace ConcreteLayerNormParams - -/-- Identity LayerNorm affine parameters: γ=1, β=0. -/ -def identity (modelDim : Nat) : ConcreteLayerNormParams := - { gamma := ConcreteMatrix.ones 1 modelDim, beta := ConcreteMatrix.zeros 1 modelDim } - -end ConcreteLayerNormParams - -/-! ## Concrete Attention Layer -/ - -/-- A concrete attention layer with exported weights. - -This structure holds the four projection matrices that define a single attention head: -- W_Q: Query projection (d × d_head) -- W_K: Key projection (d × d_head) -- W_V: Value projection (d × d_head) -- W_O: Output projection (d_head × d) --/ -structure ConcreteAttentionLayer where - /-- Model dimension (embedding size) -/ - modelDim : Nat - /-- Head dimension -/ - headDim : Nat - /-- Query projection matrix (modelDim × headDim) -/ - W_Q : ConcreteMatrix - /-- Query bias (1×headDim). -/ - b_Q : ConcreteMatrix - /-- Key projection matrix (modelDim × headDim) -/ - W_K : ConcreteMatrix - /-- Key bias (1×headDim). -/ - b_K : ConcreteMatrix - /-- Value projection matrix (modelDim × headDim) -/ - W_V : ConcreteMatrix - /-- Value bias (1×headDim). 
-/ - b_V : ConcreteMatrix - /-- Output projection matrix (headDim × modelDim) -/ - W_O : ConcreteMatrix - /-- Dimension consistency for W_Q -/ - W_Q_dims : W_Q.numRows = modelDim ∧ W_Q.numCols = headDim - /-- Dimension consistency for b_Q -/ - b_Q_dims : b_Q.numRows = 1 ∧ b_Q.numCols = headDim - /-- Dimension consistency for W_K -/ - W_K_dims : W_K.numRows = modelDim ∧ W_K.numCols = headDim - /-- Dimension consistency for b_K -/ - b_K_dims : b_K.numRows = 1 ∧ b_K.numCols = headDim - /-- Dimension consistency for W_V -/ - W_V_dims : W_V.numRows = modelDim ∧ W_V.numCols = headDim - /-- Dimension consistency for b_V -/ - b_V_dims : b_V.numRows = 1 ∧ b_V.numCols = headDim - /-- Dimension consistency for W_O -/ - W_O_dims : W_O.numRows = headDim ∧ W_O.numCols = modelDim - -namespace ConcreteAttentionLayer - -/-- Compute the value-output projection `W_V · W_O` (dense `modelDim×modelDim`). - -This is **diagnostics-only** on the main discovery/verification paths. Prefer using -small `headDim×headDim` Gram matrices and trace identities to compute scalar bounds. --/ -def valueOutputProjection (layer : ConcreteAttentionLayer) : ConcreteMatrix := - layer.W_V.matmul layer.W_O - -/-- Compute the query-key alignment `W_Q · W_Kᵀ` (dense `modelDim×modelDim`). - -This is **diagnostics-only** on the main discovery/verification paths. Prefer using -small `headDim×headDim` Gram matrices and trace identities to compute scalar bounds. --/ -def queryKeyAlignment (layer : ConcreteAttentionLayer) : ConcreteMatrix := - layer.W_Q.matmul layer.W_K.transpose - -/-- Compute a tighter rigorous bound on `‖W_Q W_Kᵀ‖₂` by centering `W_K` across its rows. - -Key observation: attention keys are computed from the Pre-LN activation `u = ln₁(x)`, and each row -of `u` has (approximately) zero mean across the `modelDim` features. 
Therefore, replacing -`W_K` by `W_K - 1·μᵀ` where `μ = (1/modelDim)·(1ᵀ W_K)` does not change `u·W_K`, hence does not -change the attention logits or their Jacobian, but can reduce the operator norm used in bounds. - -We avoid materializing `W_K - 1·μᵀ` by applying the corresponding rank-1 correction to the Gram: -`(W_K - 1·μᵀ)ᵀ (W_K - 1·μᵀ) = W_KᵀW_K - (1/modelDim)·uᵀu` where `u = 1ᵀ W_K`. --/ -def centeredQKOpBound (layer : ConcreteAttentionLayer) : Float := Id.run do - let kRows := layer.modelDim - let kCols := layer.headDim - if kRows = 0 || kCols = 0 then - return 0.0 - - -- 1) u = 1ᵀ W_K (column sums of W_K). - let mut colSums : Array Float := Array.replicate kCols 0.0 - for r in [:kRows] do - let rowBase := r * kCols - for c in [:kCols] do - colSums := colSums.set! c (colSums[c]! + layer.W_K.data[rowBase + c]!) - - -- 2) Centered Gram: G' = W_KᵀW_K - (1/kRows)·uᵀu. - let invDim := 1.0 / kRows.toFloat - let wkGram := layer.W_K.transpose.matmul layer.W_K - let wkGramCentered : ConcreteMatrix := - { numRows := kCols - numCols := kCols - data := .ofFn fun idx : Fin (kCols * kCols) => - let i := idx.val / kCols - let j := idx.val % kCols - let ui := colSums.getD i 0.0 - let uj := colSums.getD j 0.0 - wkGram.getUnsafe i j - (invDim * ui * uj) - size_eq := Array.size_ofFn } - - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - - -- Candidate A: factorized Gram bound using `‖·‖∞` on Grams. - let wqOp := Float.sqrt (max 0.0 (wqGram.infNormAbs)) - let wkOpCentered := Float.sqrt (max 0.0 (wkGramCentered.infNormAbs)) - let boundFactor := wqOp * wkOpCentered - - -- Candidate B: Frobenius candidate via trace identity. - let frobSq := ConcreteMatrix.traceMul wqGram wkGramCentered - let boundFrob := Float.sqrt (max 0.0 frobSq) - - return min boundFactor boundFrob - -/-! ### No-dense scalar bounds for `W_Q·W_Kᵀ` and `W_V·W_O` - -We avoid materializing `modelDim×modelDim` products by working with `headDim×headDim` Grams. 
- -Key trace identities (for real matrices): -- For `W_V : d×h` and `W_O : h×d`, with `vo = W_V·W_O`: - `‖vo‖_F² = trace((W_VᵀW_V) · (W_O W_Oᵀ))`. -- For `W_Q : d×h` and `W_K : d×h`, with `qk = W_Q·W_Kᵀ`: - `‖qk‖_F² = trace((W_QᵀW_Q) · (W_KᵀW_K))`. - -These let us compute exact (in ℝ, given the Float-evaluated entries) Frobenius candidates -using only `headDim×headDim` matrices. --/ - -structure NoDenseProductBounds where - /-- `W_QᵀW_Q` (headDim×headDim). -/ - wqGram : ConcreteMatrix - /-- `W_KᵀW_K` (headDim×headDim). -/ - wkGram : ConcreteMatrix - /-- `W_VᵀW_V` (headDim×headDim). -/ - wvGram : ConcreteMatrix - /-- `W_O W_Oᵀ` (headDim×headDim). -/ - woGram : ConcreteMatrix - - /-- Frobenius-squared of `W_Q·W_Kᵀ`, computed via a `headDim×headDim` trace identity. -/ - qkFrobNormSq : Float - /-- Frobenius bound candidate `‖W_Q·W_Kᵀ‖_F`. -/ - qkDenseFrob : Float - /-- Tight Gram-product candidate derived from `headDim×headDim` Grams. -/ - qkDenseGram : Float - /-- Brauer/Cassini candidate computed on a `headDim×headDim` product with matching singular values. -/ - qkDenseBrauer : Float - qkFactorSchur : Float - qkFactorFrob : Float - wqOpGram : Float - wkOpGram : Float - qkFactorGram : Float - /-- Final chosen upper bound (min of rigorous candidates). -/ - qkOpBound : Float - - /-- Frobenius-squared of `W_V·W_O`, computed via a `headDim×headDim` trace identity. -/ - voFrobNormSq : Float - /-- Frobenius bound candidate `‖W_V·W_O‖_F`. -/ - voDenseFrob : Float - /-- Tight Gram-product candidate derived from `headDim×headDim` Grams. -/ - voDenseGram : Float - /-- Brauer/Cassini candidate computed on a `headDim×headDim` product with matching singular values. -/ - voDenseBrauer : Float - voFactorSchur : Float - voFactorFrob : Float - wvOpGram : Float - woOpGram : Float - voFactorGram : Float - /-- Final chosen upper bound (min of rigorous candidates). 
-/ - voOpBound : Float - -private def safeSqrt (x : Float) : Float := - Float.sqrt (max 0.0 x) - -/-- Tight Gram-based operator bound from an explicit PSD Gram matrix. -/ -private def opBoundFromGram (gram : ConcreteMatrix) (trace : Float) : Float := - if gram.numRows = 0 || gram.numCols = 0 then - 0.0 - else - let lambdaGersh := gram.infNormAbs - let lambdaBrauer := ConcreteMatrix.symmLambdaMaxUpperBrauer gram - let lambdaMoment := - ConcreteMatrix.psdLambdaMaxUpperMoment gram.numRows trace gram.frobeniusNormSq - let lambdaUpper := min lambdaGersh (min lambdaBrauer lambdaMoment) - let op := safeSqrt lambdaUpper - if Float.isNaN op || Float.isInf op then - safeSqrt lambdaGersh - else - op - -/-- Compute rigorous (in exact ℝ) scalar candidates for `‖W_Q·W_Kᵀ‖₂` and `‖W_V·W_O‖₂` -without forming any dense `modelDim×modelDim` products. -/ -def noDenseProductBounds (layer : ConcreteAttentionLayer) : NoDenseProductBounds := Id.run do - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - let wkGram := layer.W_K.transpose.matmul layer.W_K - let wvGram := layer.W_V.transpose.matmul layer.W_V - let woGram := layer.W_O.matmul layer.W_O.transpose - - -- Frobenius candidates via trace identities (see docstring above). - let qkFrobSq := ConcreteMatrix.traceMul wqGram wkGram - let qkDenseFrob := safeSqrt qkFrobSq - let voFrobSq := ConcreteMatrix.traceMul wvGram woGram - let voDenseFrob := safeSqrt voFrobSq - - -- Gram-product spectral candidates (headDim×headDim): - -- ‖W_Q W_Kᵀ‖₂² = λ_max((W_QᵀW_Q)(W_KᵀW_K)) ≤ ‖(W_QᵀW_Q)(W_KᵀW_K)‖_∞. - let qkDenseGram := safeSqrt ((wqGram.matmul wkGram).infNormAbs) - -- ‖W_V W_O‖₂² = λ_max((W_VᵀW_V)(W_O W_Oᵀ)) ≤ ‖(W_VᵀW_V)(W_O W_Oᵀ)‖_∞. - let voDenseGram := safeSqrt ((wvGram.matmul woGram).infNormAbs) - - -- Brauer/Cassini candidates on `headDim×headDim` products with matching singular values. 
- let qkSmall := layer.W_K.transpose.matmul layer.W_Q - let qkDenseBrauer := qkSmall.opNormUpperBoundDenseBrauer - let voSmall := layer.W_O.matmul layer.W_V - let voDenseBrauer := voSmall.opNormUpperBoundDenseBrauer - - -- Factor bounds from submultiplicativity. - let qkFactorSchur := layer.W_Q.schurNormEst * layer.W_K.schurNormEst - let qkFactorFrob := layer.W_Q.frobeniusNorm * layer.W_K.frobeniusNorm - let wqOpGram0 := layer.W_Q.opNormUpperBoundViaGramInf - let wkOpGram0 := layer.W_K.opNormUpperBoundViaGramInf - let wqOpGram := min wqOpGram0 (opBoundFromGram wqGram layer.W_Q.frobeniusNormSq) - let wkOpGram := min wkOpGram0 (opBoundFromGram wkGram layer.W_K.frobeniusNormSq) - let qkFactorGram := wqOpGram * wkOpGram - - let voFactorSchur := layer.W_V.schurNormEst * layer.W_O.schurNormEst - let voFactorFrob := layer.W_V.frobeniusNorm * layer.W_O.frobeniusNorm - let wvOpGram0 := layer.W_V.opNormUpperBoundViaGramInf - -- PERFORMANCE: `W_O` is typically wide (`headDim×modelDim`), so compute on `W_Oᵀ` instead. 
- let woOpGram0 := layer.W_O.transpose.opNormUpperBoundViaGramInf - let wvOpGram := min wvOpGram0 (opBoundFromGram wvGram layer.W_V.frobeniusNormSq) - let woOpGram := min woOpGram0 (opBoundFromGram woGram layer.W_O.frobeniusNormSq) - let voFactorGram := wvOpGram * woOpGram - - let qkOpBoundCentered := layer.centeredQKOpBound - let qkOpBound := - min qkDenseFrob <| - min qkDenseGram <| - min qkDenseBrauer <| - min qkOpBoundCentered <| - min qkFactorSchur <| - min qkFactorFrob qkFactorGram - - let voOpBound := - min voDenseFrob <| - min voDenseGram <| - min voDenseBrauer <| - min voFactorSchur <| - min voFactorFrob voFactorGram - - return { - wqGram := wqGram - wkGram := wkGram - wvGram := wvGram - woGram := woGram - qkFrobNormSq := qkFrobSq - qkDenseFrob := qkDenseFrob - qkDenseGram := qkDenseGram - qkDenseBrauer := qkDenseBrauer - qkFactorSchur := qkFactorSchur - qkFactorFrob := qkFactorFrob - wqOpGram := wqOpGram - wkOpGram := wkOpGram - qkFactorGram := qkFactorGram - qkOpBound := qkOpBound - voFrobNormSq := voFrobSq - voDenseFrob := voDenseFrob - voDenseGram := voDenseGram - voDenseBrauer := voDenseBrauer - voFactorSchur := voFactorSchur - voFactorFrob := voFactorFrob - wvOpGram := wvOpGram - woOpGram := woOpGram - voFactorGram := voFactorGram - voOpBound := voOpBound - } - -private def opBoundMinOfMany (dense : ConcreteMatrix) - (leftFactor rightFactor : ConcreteMatrix) : Float := - let denseSchur := dense.schurNormEst - let denseFrob := dense.frobeniusNorm - let factorSchur := leftFactor.schurNormEst * rightFactor.schurNormEst - let factorFrob := leftFactor.frobeniusNorm * rightFactor.frobeniusNorm - min denseFrob (min denseSchur (min factorSchur factorFrob)) - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_Q · W_Kᵀ‖₂. 
- -We take the minimum of several valid upper bounds: - ‖M‖₂ ≤ schurNormEst(M) - ‖M‖₂ ≤ ‖M‖_F - ‖W_Q W_Kᵀ‖₂ ≤ ‖W_Q‖₂‖W_K‖₂ ≤ schur(W_Q)·schur(W_K) - ‖W_Q W_Kᵀ‖₂ ≤ ‖W_Q‖_F‖W_K‖_F - -This is computed in `Float` and is therefore a deterministic heuristic estimate. -In exact real arithmetic, each candidate is an upper bound, so taking `min` -can only tighten the bound. --/ -def queryKeyAlignmentOpBoundFrom (layer : ConcreteAttentionLayer) (qk : ConcreteMatrix) : Float := - let base := opBoundMinOfMany qk layer.W_Q layer.W_K - let wqOpGram := layer.W_Q.opNormUpperBoundViaGramInf - let wkOpGram := layer.W_K.opNormUpperBoundViaGramInf - let qkFactorGram := wqOpGram * wkOpGram - -- Low-rank Gram-product tightening (64×64): - -- ‖W_Q W_Kᵀ‖₂² = λ_max((W_QᵀW_Q)(W_KᵀW_K)) ≤ ‖(W_QᵀW_Q)(W_KᵀW_K)‖_∞. - let wqGram := layer.W_Q.transpose.matmul layer.W_Q - let wkGram := layer.W_K.transpose.matmul layer.W_K - let qkDenseGram := Float.sqrt (max 0.0 ((wqGram.matmul wkGram).infNormAbs)) - -- Brauer/Cassini Gram bound on a 64×64 product with the same singular values: - -- For `A = W_Q` and `B = W_K`, `‖A Bᵀ‖₂ = ‖Bᵀ A‖₂`. - let qkSmall := layer.W_K.transpose.matmul layer.W_Q - let qkDenseBrauer := qkSmall.opNormUpperBoundDenseBrauer - min (min base qkFactorGram) (min (min qkDenseGram qkDenseBrauer) qk.frobeniusNorm) - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_Q · W_Kᵀ‖₂. - -Convenience wrapper that materializes `W_Q·W_Kᵀ`. -Prefer `queryKeyAlignmentOpBoundFrom` when `qk` is already available. --/ -def queryKeyAlignmentOpBound (layer : ConcreteAttentionLayer) : Float := - -- Prefer the no-dense path by default to avoid allocating a `modelDim×modelDim` matrix. - (layer.noDenseProductBounds).qkOpBound - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_V · W_O‖₂. - -This is the minimum of Schur / Frobenius / factorized Schur / factorized Frobenius -upper bounds, analogous to `queryKeyAlignmentOpBoundFrom`. 
--/ -def valueOutputProjectionOpBoundFrom - (layer : ConcreteAttentionLayer) (vo : ConcreteMatrix) : Float := - let base := opBoundMinOfMany vo layer.W_V layer.W_O - let wvOpGram := layer.W_V.opNormUpperBoundViaGramInf - -- PERFORMANCE: `W_O` is typically wide (`headDim×modelDim`), so we compute the same - -- bound on `W_Oᵀ` instead (‖W_O‖₂ = ‖W_Oᵀ‖₂) to avoid an O(modelDim²) loop. - let woOpGram := layer.W_O.transpose.opNormUpperBoundViaGramInf - let voFactorGram := wvOpGram * woOpGram - -- Low-rank Gram-product tightening (64×64): - -- For `M = W_V W_O`, ‖M‖₂² = λ_max((W_VᵀW_V)(W_O W_Oᵀ)) up to reordering. - let wvGram := layer.W_V.transpose.matmul layer.W_V - let woGram := layer.W_O.matmul layer.W_O.transpose - let voDenseGram := Float.sqrt (max 0.0 ((wvGram.matmul woGram).infNormAbs)) - -- Brauer/Cassini Gram bound on a 64×64 product with the same singular values: - -- For `A = W_V` and `B = W_Oᵀ`, `‖A B‖₂ = ‖B A‖₂`. - let voSmall := layer.W_O.matmul layer.W_V - let voDenseBrauer := voSmall.opNormUpperBoundDenseBrauer - min (min base voFactorGram) (min (min voDenseGram voDenseBrauer) vo.frobeniusNorm) - - -/-- A tighter (still sound-in-ℝ) Float upper bound on ‖W_V · W_O‖₂. - -Convenience wrapper that materializes `W_V·W_O`. -Prefer `valueOutputProjectionOpBoundFrom` when `vo` is already available. --/ -def valueOutputProjectionOpBound (layer : ConcreteAttentionLayer) : Float := - -- Prefer the no-dense path by default to avoid allocating a `modelDim×modelDim` matrix. - (layer.noDenseProductBounds).voOpBound - -end ConcreteAttentionLayer - -/-! ### Local Gram-based operator bounds -/ - -/-- Tight Gram-based operator bound from an explicit PSD Gram matrix. - -This mirrors the logic used for attention-layer weight bounds but is scoped for -data-dependent activation Grams computed in discovery. 
--/ -private def opBoundFromGramLocal (gram : ConcreteMatrix) (trace : Float) : Float := - if gram.numRows = 0 || gram.numCols = 0 then - 0.0 - else - let lambdaGersh := gram.infNormAbs - let lambdaBrauer := ConcreteMatrix.symmLambdaMaxUpperBrauer gram - let lambdaMoment := - ConcreteMatrix.psdLambdaMaxUpperMoment gram.numRows trace gram.frobeniusNormSq - let lambdaPow4 : Float := - if gram.numRows ≤ 128 then - let gramSq := gram.matmul gram - let trace4 := ConcreteMatrix.traceMul gramSq gramSq - Float.sqrt (Float.sqrt (max 0.0 trace4)) - else - Float.inf - let lambdaUpper := min lambdaGersh (min lambdaBrauer (min lambdaMoment lambdaPow4)) - let op := Float.sqrt (max 0.0 lambdaUpper) - if Float.isNaN op || Float.isInf op then - Float.sqrt (max 0.0 lambdaGersh) - else - op - -/-! ## Concrete MLP Layer -/ - -/-- A concrete MLP (Feed-Forward) layer with exported weights. - -Standard transformer MLP: `output = W_out · activation(W_in · x + b_in) + b_out` - -For interpretability, we analyze individual **neurons** (columns of W_in / rows of W_out). -Each neuron i computes: `activation(W_in[:,i]·x + b_in[i])` and writes `W_out[i,:]` to output. 
--/ -structure ConcreteMLPLayer where - /-- Model dimension (embedding size) -/ - modelDim : Nat - /-- Hidden dimension (number of neurons) -/ - hiddenDim : Nat - /-- Input projection matrix (modelDim × hiddenDim): maps input to hidden activations -/ - W_in : ConcreteMatrix - /-- Output projection matrix (hiddenDim × modelDim): maps hidden to output -/ - W_out : ConcreteMatrix - /-- Input bias (hiddenDim) stored as 1×hiddenDim matrix for uniformity -/ - b_in : ConcreteMatrix - /-- Output bias (modelDim) stored as 1×modelDim matrix for uniformity -/ - b_out : ConcreteMatrix - /-- Dimension consistency for W_in -/ - W_in_dims : W_in.numRows = modelDim ∧ W_in.numCols = hiddenDim - /-- Dimension consistency for W_out -/ - W_out_dims : W_out.numRows = hiddenDim ∧ W_out.numCols = modelDim - /-- Dimension consistency for b_in -/ - b_in_dims : b_in.numRows = 1 ∧ b_in.numCols = hiddenDim - /-- Dimension consistency for b_out -/ - b_out_dims : b_out.numRows = 1 ∧ b_out.numCols = modelDim - -namespace ConcreteMLPLayer - -/-- Get the input weight vector for neuron i (column i of W_in). -/ -def neuronInputWeights (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Array Float := - if neuronIdx < layer.hiddenDim then - .ofFn fun row : Fin layer.modelDim => layer.W_in.getUnsafe row.val neuronIdx - else #[] - -/-- Get the output weight vector for neuron i (row i of W_out). -/ -def neuronOutputWeights (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Array Float := - if neuronIdx < layer.hiddenDim then - .ofFn fun col : Fin layer.modelDim => - layer.W_out.data[neuronIdx * layer.modelDim + col.val]! - else #[] - -/-- Compute the L2 norm of input weights for a neuron. -/ -def neuronInputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - let weights := layer.neuronInputWeights neuronIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) - -/-- Compute the L2 norm of output weights for a neuron. 
-/ -def neuronOutputNorm (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - let weights := layer.neuronOutputWeights neuronIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) - -/-- Compute the "influence magnitude" of a neuron: ‖W_in[:,i]‖ · ‖W_out[i,:]‖ - -This measures how much information can flow through neuron i. -For ReLU networks, this bounds the neuron's contribution to the output. --/ -def neuronInfluence (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - layer.neuronInputNorm neuronIdx * layer.neuronOutputNorm neuronIdx - -/-- Get the bias for neuron i. -/ -def getBias (layer : ConcreteMLPLayer) (neuronIdx : Nat) : Float := - if neuronIdx < layer.hiddenDim then - layer.b_in.data[neuronIdx]! - else 0.0 - -end ConcreteMLPLayer - -/-! ## Interval Bound Propagation (IBP) for MLP Activation Stability - -When analyzing circuit faithfulness, we need to know if ablating upstream components -can cause MLP neurons to "flip" their activation states. A neuron is **stable** if -its pre-activation stays positive (always active) or negative (always inactive) under -all perturbations bounded by ε. - -**Mathematical Setup:** -- Pre-activation: z = W_in^T · x + b -- For perturbation δx with ‖δx‖₂ ≤ ε: - - z' = W_in^T · (x + δx) + b = z + W_in^T · δx - - By Cauchy-Schwarz: |W_in^T · δx| ≤ ‖W_in‖₂ · ‖δx‖₂ ≤ ‖W_in‖₂ · ε -- Therefore: z - ε·‖W_in‖ ≤ z' ≤ z + ε·‖W_in‖ - -**Stability Criterion:** -- Neuron is "stably ON" if z - ε·‖W_in‖ > 0 -- Neuron is "stably OFF" if z + ε·‖W_in‖ < 0 -- Otherwise, neuron is "unstable" (may flip) - -**Pattern Term for Unstable Neurons:** -When a ReLU neuron flips, the linearization error is bounded by the magnitude -of the output weight times the activation change: - ‖ΔOutput‖ ≤ ‖W_out[i,:]‖ · |activation_change| - ≤ ‖W_out[i,:]‖ · max(|z + ε·‖W_in‖|, |z - ε·‖W_in‖|) - -For GeLU, the bound is tighter but we use ReLU-style conservative bounds. --/ - -/-- Result of interval bound propagation for a single neuron. 
-/ -structure NeuronIntervalBound where - /-- Neuron index within the layer -/ - neuronIdx : Nat - /-- Lower bound on pre-activation (z - ε·‖W_in‖) -/ - preActLower : Float - /-- Upper bound on pre-activation (z + ε·‖W_in‖) -/ - preActUpper : Float - /-- Nominal pre-activation (z = W_in^T · x + b) -/ - preActNominal : Float - /-- Input weight norm ‖W_in[:,i]‖ -/ - inputNorm : Float - /-- Output weight norm ‖W_out[i,:]‖ -/ - outputNorm : Float - deriving Repr - -namespace NeuronIntervalBound - -/-- Is this neuron stably active (always ON) under the perturbation bound? -/ -def isStablyActive (b : NeuronIntervalBound) : Bool := - b.preActLower > 0.0 - -/-- Is this neuron stably inactive (always OFF) under the perturbation bound? -/ -def isStablyInactive (b : NeuronIntervalBound) : Bool := - b.preActUpper < 0.0 - -/-- Is this neuron stable (won't flip activation state)? -/ -def isStable (b : NeuronIntervalBound) : Bool := - b.isStablyActive || b.isStablyInactive - -/-- Is this neuron unstable (may flip activation state)? -/ -def isUnstable (b : NeuronIntervalBound) : Bool := - ¬b.isStable - -/-- The "flip margin" - how close the pre-activation interval is to zero. - -For stable neurons this is positive (distance from zero). -For unstable neurons this is negative (interval crosses zero). --/ -def flipMargin (b : NeuronIntervalBound) : Float := - min b.preActLower (-b.preActUpper) - -/-- Bound on the activation change if the neuron flips (ReLU). - -For ReLU, if the neuron flips from ON to OFF, the activation changes from z to 0. -If it flips from OFF to ON, the activation changes from 0 to z. -The maximum change magnitude is bounded by max(|z_lower|, |z_upper|). --/ -def maxActivationChange (b : NeuronIntervalBound) : Float := - if b.isStable then 0.0 - else max (Float.abs b.preActLower) (Float.abs b.preActUpper) - -/-- Bound on the output error due to potential activation flip. 
- -This is the pattern term contribution for an unstable neuron: - ‖ΔOutput‖ ≤ ‖W_out[i,:]‖ · max_activation_change --/ -def patternTermBound (b : NeuronIntervalBound) : Float := - b.outputNorm * b.maxActivationChange - -end NeuronIntervalBound - -/-- Result of IBP analysis for an entire MLP layer. -/ -structure MLPIntervalAnalysis where - /-- Layer index -/ - layerIdx : Nat - /-- Per-neuron interval bounds -/ - neuronBounds : Array NeuronIntervalBound - /-- Input perturbation norm (ε) used for analysis -/ - perturbationNorm : Float - /-- Number of stable neurons -/ - numStable : Nat - /-- Number of unstable neurons -/ - numUnstable : Nat - /-- Total pattern term bound for unstable neurons -/ - totalPatternBound : Float - deriving Repr - -namespace MLPIntervalAnalysis - -/-- Fraction of neurons that are stable. -/ -def stabilityRatio (a : MLPIntervalAnalysis) : Float := - if a.neuronBounds.size = 0 then 1.0 - else a.numStable.toFloat / a.neuronBounds.size.toFloat - -/-- Is the layer "fully stable" (all neurons stable)? -/ -def isFullyStable (a : MLPIntervalAnalysis) : Bool := - a.numUnstable = 0 - -/-- Get bounds for a specific neuron. -/ -def getNeuronBound (a : MLPIntervalAnalysis) (idx : Nat) : Option NeuronIntervalBound := - if h : idx < a.neuronBounds.size then some a.neuronBounds[idx] else none - -end MLPIntervalAnalysis - -namespace ConcreteMLPLayer - -/-- Compute pre-activations for all neurons given an input vector (single position). - -Returns array of pre-activations: z[i] = W_in[:,i]^T · x + b[i] --/ -def computePreActivations (layer : ConcreteMLPLayer) (input : Array Float) : - Array Float := - .ofFn fun i : Fin layer.hiddenDim => Id.run do - let mut z : Float := layer.b_in.data[i.val]! - for j in [:layer.modelDim] do - -- SAFETY: input should have size modelDim, but getD provides safe fallback - let x_j := input.getD j 0.0 - let w_ji := layer.W_in.data[j * layer.hiddenDim + i.val]! 
- z := z + w_ji * x_j - return z - -/-- Compute interval bounds for all neurons given input and perturbation bound. - -**Algorithm:** -For each neuron i: -1. Compute nominal pre-activation: z = W_in[:,i]^T · x + b[i] -2. Compute Δz = ε · ‖W_in[:,i]‖₂ (maximum change due to perturbation) -3. Set bounds: [z - Δz, z + Δz] -4. Determine stability based on whether interval crosses zero - -**Parameters:** -- `input`: Nominal input vector (modelDim elements) -- `perturbationNorm`: L2 norm bound on input perturbation (ε) --/ -def computeIntervalBounds (layer : ConcreteMLPLayer) - (input : Array Float) (perturbationNorm : Float) : Array NeuronIntervalBound := - let preActs := layer.computePreActivations input - .ofFn fun i : Fin layer.hiddenDim => - -- SAFETY: preActs has size hiddenDim by construction from computePreActivations - let z := preActs[i.val]! - let inputNorm := layer.neuronInputNorm i.val - let outputNorm := layer.neuronOutputNorm i.val - let delta := perturbationNorm * inputNorm - { - neuronIdx := i.val - preActLower := z - delta - preActUpper := z + delta - preActNominal := z - inputNorm := inputNorm - outputNorm := outputNorm - } - -/-- Run full IBP analysis on an MLP layer. - -Returns comprehensive analysis including stability counts and total pattern bound. 
--/ -def analyzeIntervalBounds (layer : ConcreteMLPLayer) (layerIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : MLPIntervalAnalysis := Id.run do - let bounds := layer.computeIntervalBounds input perturbationNorm - let mut numStable : Nat := 0 - let mut numUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for b in bounds do - if b.isStable then - numStable := numStable + 1 - else - numUnstable := numUnstable + 1 - totalPattern := totalPattern + b.patternTermBound - - { - layerIdx := layerIdx - neuronBounds := bounds - perturbationNorm := perturbationNorm - numStable := numStable - numUnstable := numUnstable - totalPatternBound := totalPattern - } - -/-- Compute the pattern term bound for a single neuron using IBP. - -This is a convenient wrapper for single-neuron queries, useful for -`computeNeuronImportance`. - -**Parameters:** -- `neuronIdx`: Index of the neuron to analyze -- `input`: Nominal input vector (from forward pass) -- `perturbationNorm`: L2 norm bound on input perturbation - -**Returns:** Pattern term bound (0 if stable, output_norm * max_change if unstable) --/ -def neuronPatternTermBoundIBP (layer : ConcreteMLPLayer) (neuronIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : Float := - if neuronIdx ≥ layer.hiddenDim then 0.0 - else - let z := Id.run do - let mut acc : Float := layer.b_in.data[neuronIdx]! - for j in [:layer.modelDim] do - let x_j := input.getD j 0.0 - let w_ji := layer.W_in.data[j * layer.hiddenDim + neuronIdx]! 
- acc := acc + w_ji * x_j - acc - let inputNorm := layer.neuronInputNorm neuronIdx - let outputNorm := layer.neuronOutputNorm neuronIdx - let delta := perturbationNorm * inputNorm - - -- Check stability - let lower := z - delta - let upper := z + delta - if lower > 0.0 || upper < 0.0 then - -- Stable: no pattern term error - 0.0 - else - -- Unstable: bound by output weight times max activation change - outputNorm * max (Float.abs lower) (Float.abs upper) - -end ConcreteMLPLayer - -/-- Create a ConcreteMLPLayer from raw Float arrays. -/ -def mkConcreteMLPLayer - (modelDim hiddenDim : Nat) - (w_in w_out b_in b_out : Array Float) - (hw_in : w_in.size = modelDim * hiddenDim) - (hw_out : w_out.size = hiddenDim * modelDim) - (hb_in : b_in.size = hiddenDim) - (hb_out : b_out.size = modelDim) : ConcreteMLPLayer where - modelDim := modelDim - hiddenDim := hiddenDim - W_in := { numRows := modelDim, numCols := hiddenDim, data := w_in, size_eq := hw_in } - W_out := { numRows := hiddenDim, numCols := modelDim, data := w_out, size_eq := hw_out } - b_in := { numRows := 1, numCols := hiddenDim, data := b_in, - size_eq := by simp [hb_in] } - b_out := { numRows := 1, numCols := modelDim, data := b_out, - size_eq := by simp [hb_out] } - W_in_dims := ⟨rfl, rfl⟩ - W_out_dims := ⟨rfl, rfl⟩ - b_in_dims := ⟨rfl, rfl⟩ - b_out_dims := ⟨rfl, rfl⟩ - -/-! ## Sparse Autoencoders (SAEs) for Feature-Level Analysis - -Sparse Autoencoders decompose MLP activations into interpretable **features**. -Instead of analyzing raw neurons (which are often polysemantic), we analyze -sparse linear combinations that correspond to semantic concepts. 
- -**Architecture:** -- Encoder: `f = ReLU(W_enc · x + b_enc)` maps residual stream to sparse features -- Decoder: `x' = W_dec · f + b_dec` reconstructs the residual stream -- Sparsity: Only a small number of features are active (f[k] > 0) for any input - -**Key Insight for Circuit Discovery:** -The Jacobian of an MLP approximated through an SAE becomes a sum of rank-1 matrices: - `J ≈ Σ_{k ∈ active} W_dec[:,k] ⊗ W_enc[k,:]` - -This allows us to: -1. Identify which **features** (not neurons) are responsible for behavior -2. Discover cleaner circuits for complex behaviors -3. Handle polysemantic neurons by decomposing them into monosemantic features - -**Reconstruction Error:** -The SAE approximation introduces error: `‖MLP(x) - SAE(x)‖_F` -This must be accounted for in the total faithfulness bound. --/ - -/-- A Sparse Autoencoder for analyzing MLP features. - -Trained to reconstruct residual stream activations with sparse latent codes. -Typically has many more features than the residual stream dimension (e.g., 16x). 
--/ -structure ConcreteSAE where - /-- Input/output dimension (residual stream size) -/ - inputDim : Nat - /-- Number of features (typically >> inputDim for overcomplete SAEs) -/ - numFeatures : Nat - /-- Encoder weights (inputDim × numFeatures): W_enc[i,k] = weight from input i to feature k -/ - W_enc : ConcreteMatrix - /-- Decoder weights (numFeatures × inputDim): W_dec[k,j] = weight from feature k to output j -/ - W_dec : ConcreteMatrix - /-- Encoder bias (numFeatures): b_enc[k] = bias for feature k -/ - b_enc : ConcreteMatrix - /-- Decoder bias (inputDim): b_dec[j] = bias for output j -/ - b_dec : ConcreteMatrix - /-- Dimension constraints -/ - W_enc_dims : W_enc.numRows = inputDim ∧ W_enc.numCols = numFeatures - W_dec_dims : W_dec.numRows = numFeatures ∧ W_dec.numCols = inputDim - b_enc_dims : b_enc.numRows = 1 ∧ b_enc.numCols = numFeatures - b_dec_dims : b_dec.numRows = 1 ∧ b_dec.numCols = inputDim - -namespace ConcreteSAE - -/-- ReLU activation for SAE encoder. -/ -private def relu (x : Float) : Float := if x > 0.0 then x else 0.0 - -/-- Encode input to sparse feature activations: f = ReLU(W_enc^T · x + b_enc) - -Note: W_enc is stored as (inputDim × numFeatures), so we compute x^T · W_enc. - -PERFORMANCE: Pre-allocates output array (critical for SAE-based circuit discovery). --/ -def encode (sae : ConcreteSAE) (input : Array Float) : Array Float := - .ofFn fun k : Fin sae.numFeatures => Id.run do - let mut z : Float := sae.b_enc.data[k.val]! - for i in [:sae.inputDim] do - let x_i := input.getD i 0.0 - let w_ik := sae.W_enc.data[i * sae.numFeatures + k.val]! - z := z + x_i * w_ik - return relu z - -/-- Decode sparse features back to residual stream: x' = W_dec^T · f + b_dec - -Note: W_dec is stored as (numFeatures × inputDim), so we compute f^T · W_dec. - -PERFORMANCE: Pre-allocates output array (critical for SAE-based circuit discovery). 
--/ -def decode (sae : ConcreteSAE) (features : Array Float) : Array Float := - .ofFn fun j : Fin sae.inputDim => Id.run do - let mut y : Float := sae.b_dec.data[j.val]! - for k in [:sae.numFeatures] do - let f_k := features.getD k 0.0 - let w_kj := sae.W_dec.data[k * sae.inputDim + j.val]! - y := y + f_k * w_kj - return y - -/-- Full forward pass: encode then decode. -/ -def forward (sae : ConcreteSAE) (input : Array Float) : Array Float := - sae.decode (sae.encode input) - -/-- Compute reconstruction error ‖x - SAE(x)‖₂ for a single input. -/ -def reconstructionError (sae : ConcreteSAE) (input : Array Float) : Float := Id.run do - let reconstructed := sae.forward input - let mut errSq : Float := 0.0 - for i in [:sae.inputDim] do - let diff := input.getD i 0.0 - reconstructed.getD i 0.0 - errSq := errSq + diff * diff - Float.sqrt errSq - -/-- Compute reconstruction error for a matrix (multiple positions). -/ -def reconstructionErrorMatrix (sae : ConcreteSAE) (input : ConcreteMatrix) : Float := Id.run do - let mut totalErrSq : Float := 0.0 - for pos in [:input.numRows] do - let inputVec : Array Float := .ofFn fun d : Fin input.numCols => input.getUnsafe pos d.val - let err := sae.reconstructionError inputVec - totalErrSq := totalErrSq + err * err - Float.sqrt totalErrSq - -/-- Get indices of active features (f[k] > threshold). -/ -def activeFeatures (sae : ConcreteSAE) (input : Array Float) - (threshold : Float := 0.0) : Array Nat := Id.run do - let features := sae.encode input - let mut active : Array Nat := #[] - for k in [:sae.numFeatures] do - if features.getD k 0.0 > threshold then - active := active.push k - active - -/-- Count active features for an input. -/ -def numActiveFeatures (sae : ConcreteSAE) (input : Array Float) - (threshold : Float := 0.0) : Nat := - (sae.activeFeatures input threshold).size - -/-- Get the encoder weight vector for feature k (column k of W_enc). 
-/ -def encoderWeights (sae : ConcreteSAE) (featureIdx : Nat) : Array Float := - if featureIdx < sae.numFeatures then - .ofFn fun i : Fin sae.inputDim => sae.W_enc.getUnsafe i.val featureIdx - else #[] - -/-- Get the decoder weight vector for feature k (row k of W_dec). -/ -def decoderWeights (sae : ConcreteSAE) (featureIdx : Nat) : Array Float := - if featureIdx < sae.numFeatures then - .ofFn fun j : Fin sae.inputDim => - sae.W_dec.data[featureIdx * sae.inputDim + j.val]! - else #[] - -/-- Compute the L2 norm of encoder weights for feature k. -/ -def encoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := - let weights := sae.encoderWeights featureIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) - -/-- Compute the L2 norm of decoder weights for feature k. -/ -def decoderNorm (sae : ConcreteSAE) (featureIdx : Nat) : Float := - let weights := sae.decoderWeights featureIdx - Float.sqrt (weights.foldl (fun acc w => acc + w * w) 0.0) - -/-- Compute the "influence magnitude" of feature k: ‖W_enc[:,k]‖ · ‖W_dec[k,:]‖ - -This bounds how much information can flow through the feature. -Analogous to `neuronInfluence` for MLP neurons. --/ -def featureInfluence (sae : ConcreteSAE) (featureIdx : Nat) : Float := - sae.encoderNorm featureIdx * sae.decoderNorm featureIdx - -/-- Get encoder bias for feature k. -/ -def encoderBias (sae : ConcreteSAE) (featureIdx : Nat) : Float := - if featureIdx < sae.numFeatures then sae.b_enc.data[featureIdx]! else 0.0 - -/-- Compute the pre-activation for feature k given input. -/ -def featurePreActivation (sae : ConcreteSAE) (featureIdx : Nat) - (input : Array Float) : Float := Id.run do - if featureIdx ≥ sae.numFeatures then return 0.0 - let mut z : Float := sae.b_enc.data[featureIdx]! - for i in [:sae.inputDim] do - let x_i := input.getD i 0.0 - let w_ik := sae.W_enc.data[i * sae.numFeatures + featureIdx]! - z := z + x_i * w_ik - z - -/-- Check if feature k is active (pre-activation > 0) for given input. 
-/ -def isFeatureActive (sae : ConcreteSAE) (featureIdx : Nat) (input : Array Float) : Bool := - sae.featurePreActivation featureIdx input > 0.0 - -end ConcreteSAE - -/-- Create a ConcreteSAE from raw Float arrays. -/ -def mkConcreteSAE - (inputDim numFeatures : Nat) - (w_enc w_dec b_enc b_dec : Array Float) - (hw_enc : w_enc.size = inputDim * numFeatures) - (hw_dec : w_dec.size = numFeatures * inputDim) - (hb_enc : b_enc.size = numFeatures) - (hb_dec : b_dec.size = inputDim) : ConcreteSAE where - inputDim := inputDim - numFeatures := numFeatures - W_enc := { numRows := inputDim, numCols := numFeatures, data := w_enc, size_eq := hw_enc } - W_dec := { numRows := numFeatures, numCols := inputDim, data := w_dec, size_eq := hw_dec } - b_enc := { numRows := 1, numCols := numFeatures, data := b_enc, - size_eq := by simp [hb_enc] } - b_dec := { numRows := 1, numCols := inputDim, data := b_dec, - size_eq := by simp [hb_dec] } - W_enc_dims := ⟨rfl, rfl⟩ - W_dec_dims := ⟨rfl, rfl⟩ - b_enc_dims := ⟨rfl, rfl⟩ - b_dec_dims := ⟨rfl, rfl⟩ - -/-! ### SAE Interval Bound Propagation - -Like MLP neurons, SAE features can flip activation states under perturbation. -We extend IBP to track feature stability. --/ - -/-- Result of IBP analysis for a single SAE feature. -/ -structure SAEFeatureIntervalBound where - /-- Feature index -/ - featureIdx : Nat - /-- Lower bound on pre-activation -/ - preActLower : Float - /-- Upper bound on pre-activation -/ - preActUpper : Float - /-- Nominal pre-activation -/ - preActNominal : Float - /-- Encoder weight norm ‖W_enc[:,k]‖ -/ - encoderNorm : Float - /-- Decoder weight norm ‖W_dec[k,:]‖ -/ - decoderNorm : Float - deriving Repr - -namespace SAEFeatureIntervalBound - -/-- Is this feature stably active? -/ -def isStablyActive (b : SAEFeatureIntervalBound) : Bool := b.preActLower > 0.0 - -/-- Is this feature stably inactive? 
-/ -def isStablyInactive (b : SAEFeatureIntervalBound) : Bool := b.preActUpper < 0.0 - -/-- Is this feature stable (won't flip)? -/ -def isStable (b : SAEFeatureIntervalBound) : Bool := b.isStablyActive || b.isStablyInactive - -/-- Pattern term bound if this feature flips. -/ -def patternTermBound (b : SAEFeatureIntervalBound) : Float := - if b.isStable then 0.0 - else b.decoderNorm * max (Float.abs b.preActLower) (Float.abs b.preActUpper) - -end SAEFeatureIntervalBound - -/-- Result of IBP analysis for an entire SAE. -/ -structure SAEIntervalAnalysis where - /-- Per-feature bounds -/ - featureBounds : Array SAEFeatureIntervalBound - /-- Perturbation norm used -/ - perturbationNorm : Float - /-- Number of stable features -/ - numStable : Nat - /-- Number of unstable features -/ - numUnstable : Nat - /-- Total pattern term bound from unstable features -/ - totalPatternBound : Float - /-- Reconstruction error (SAE approximation) -/ - reconstructionError : Float - deriving Repr - -namespace SAEIntervalAnalysis - -/-- Stability ratio. -/ -def stabilityRatio (a : SAEIntervalAnalysis) : Float := - if a.featureBounds.size = 0 then 1.0 - else a.numStable.toFloat / a.featureBounds.size.toFloat - -/-- Total error bound (pattern + reconstruction). -/ -def totalErrorBound (a : SAEIntervalAnalysis) : Float := - a.totalPatternBound + a.reconstructionError - -end SAEIntervalAnalysis - -namespace ConcreteSAE - -/-- Compute interval bounds for all features given input and perturbation. - -PERFORMANCE: Pre-allocates result array with `Array.ofFn` to avoid O(n) reallocations. 
--/ -def computeFeatureIntervalBounds (sae : ConcreteSAE) - (input : Array Float) (perturbationNorm : Float) : Array SAEFeatureIntervalBound := - .ofFn fun k : Fin sae.numFeatures => - let z := sae.featurePreActivation k.val input - let encNorm := sae.encoderNorm k.val - let decNorm := sae.decoderNorm k.val - let delta := perturbationNorm * encNorm - { - featureIdx := k.val - preActLower := z - delta - preActUpper := z + delta - preActNominal := z - encoderNorm := encNorm - decoderNorm := decNorm - } - -/-- Run full IBP analysis on an SAE. -/ -def analyzeIntervalBounds (sae : ConcreteSAE) (input : Array Float) - (perturbationNorm : Float) : SAEIntervalAnalysis := Id.run do - let bounds := sae.computeFeatureIntervalBounds input perturbationNorm - let mut numStable : Nat := 0 - let mut numUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for b in bounds do - if b.isStable then - numStable := numStable + 1 - else - numUnstable := numUnstable + 1 - totalPattern := totalPattern + b.patternTermBound - - let reconErr := sae.reconstructionError input - - { - featureBounds := bounds - perturbationNorm := perturbationNorm - numStable := numStable - numUnstable := numUnstable - totalPatternBound := totalPattern - reconstructionError := reconErr - } - -/-- Pattern term bound for a single feature using IBP. -/ -def featurePatternTermBoundIBP (sae : ConcreteSAE) (featureIdx : Nat) - (input : Array Float) (perturbationNorm : Float) : Float := - if featureIdx ≥ sae.numFeatures then 0.0 - else - let z := sae.featurePreActivation featureIdx input - let encNorm := sae.encoderNorm featureIdx - let decNorm := sae.decoderNorm featureIdx - let delta := perturbationNorm * encNorm - let lower := z - delta - let upper := z + delta - if lower > 0.0 || upper < 0.0 then 0.0 - else decNorm * max (Float.abs lower) (Float.abs upper) - -end ConcreteSAE - -/-! ## Attention Weights Computation -/ - -/-- Concrete attention weights for a sequence. 
-A[q][k] = attention weight from query position q to key position k. -/ -structure ConcreteAttentionWeights where - /-- Sequence length -/ - seqLen : Nat - /-- Attention weights stored row-major: weights[q * seqLen + k] = A[q,k] -/ - weights : Array Float - /-- Size check -/ - size_eq : weights.size = seqLen * seqLen - -namespace ConcreteAttentionWeights - -/-- Access A[q, k]. -/ -def get (A : ConcreteAttentionWeights) (q k : Nat) : Float := - if q < A.seqLen ∧ k < A.seqLen then - -- Index is in-bounds by `size_eq` and the guard above. - A.weights[q * A.seqLen + k]! - else 0.0 - -/-- Fast access to `A[q, k]` assuming `q < seqLen` and `k < seqLen`. -/ -@[inline] def getUnsafe (A : ConcreteAttentionWeights) (q k : Nat) : Float := - A.weights[q * A.seqLen + k]! - -/-- Convert attention weights to a `ConcreteMatrix` for use with matrix multiplication. -/ -def toMatrix (A : ConcreteAttentionWeights) : ConcreteMatrix where - numRows := A.seqLen - numCols := A.seqLen - data := A.weights - size_eq := by - simpa using A.size_eq - -/-- Compute softmax for a row of logits. -/ -def softmaxRow (logits : Array Float) : Array Float := - Id.run do - -- PERFORMANCE: keep arrays linear to enable in-place updates - -- (see Lean Reference Manual: runtime reference counting + array performance). - let mut expVals : Array Float := logits - let mut maxVal : Float := -1e30 - for i in [:expVals.size] do - let v := expVals[i]! - if v > maxVal then maxVal := v - let mut sumExp : Float := 0.0 - for i in [:expVals.size] do - let v := Float.exp (expVals[i]! - maxVal) - expVals := expVals.set! i v - sumExp := sumExp + v - if sumExp > 0.0 then - for i in [:expVals.size] do - expVals := expVals.set! i (expVals[i]! / sumExp) - return expVals - -/-- Compute attention weights given queries, keys, and scaling. 
-/ -def compute (queries keys : ConcreteMatrix) (scale : Float) - (seqLen : Nat) - (causal : Bool := true) : ConcreteAttentionWeights := Id.run do - -- PERFORMANCE: avoid allocating an `Array (Array Float)` of rows; write the flattened - -- weights row-by-row into a single mutable array. - let cols := min queries.numCols keys.numCols - let n := seqLen * seqLen - let mut weights : { w : Array Float // w.size = n } := - ⟨Array.replicate n 0.0, by simp [n]⟩ - for q in [:seqLen] do - -- Initialize to -∞ and only fill the causal prefix when `causal = true`. - let mut rowScores : Array Float := Array.replicate seqLen (-1e30) - let stop := if causal then min (q + 1) seqLen else seqLen - let qBase := q * queries.numCols - for j in [:stop] do - if q < queries.numRows ∧ j < keys.numRows then - let jBase := j * keys.numCols - let mut dotProd : Float := 0.0 - for d in [:cols] do - -- SAFETY: within this branch `q < queries.numRows` and `j < keys.numRows`, - -- and `d < cols ≤ queries.numCols/keys.numCols`. - dotProd := dotProd + queries.data[qBase + d]! * keys.data[jBase + d]! - rowScores := rowScores.set! j (dotProd / scale) - let row := softmaxRow rowScores - let rowBase := q * seqLen - for k in [:stop] do - let weights' := weights.1.set! (rowBase + k) (row[k]!) - have weights'SizeEq : weights'.size = n := by - have hsize : weights'.size = weights.1.size := by - -- `set!` is `setIfInBounds`, which preserves size. - simp [weights'] - exact hsize.trans weights.2 - weights := ⟨weights', weights'SizeEq⟩ - return { - seqLen := seqLen - weights := weights.1 - size_eq := by simpa [n] using weights.2 - } - -end ConcreteAttentionWeights - -/-! ## Softmax Jacobian Sparsity Analysis - -For a probability vector p, the softmax Jacobian J has entries J_ij = p_i(δ_ij - p_j). 
-The Frobenius norm squared of J for a single row is: - - ‖J‖²_F = Σᵢⱼ p_i²(δ_ij - p_j)² = Σᵢ p_i²(1-p_i)² + Σᵢ≠ⱼ p_i²p_j² - = Σᵢ p_i² - 2Σᵢ p_i³ + (Σᵢ p_i²)² - -A simpler upper bound is `Σᵢ p_i(1-p_i) = 1 - Σᵢ p_i²` (the Gini impurity). - -For sparse (one-hot) distributions: Σ p_i² ≈ 1, so the bound is ≈ 0 -For uniform distributions: Σ p_i² = 1/n, so the bound is √((n-1)/n) - -This allows us to compute much tighter pattern term bounds for sharp attention heads. --/ - -/-- Compute the "effective softmax derivative norm" for a single probability row. - -For probability vector p, this computes `sqrt(Σᵢ p_i(1-p_i)) = sqrt(1 - Σᵢ p_i²)`. -This bounds the Frobenius norm of the softmax Jacobian for that row. - -- One-hot distribution → 0 (no gradient flow through softmax) -- Uniform over n → sqrt((n-1)/n) ≈ 1 for large n --/ -def softmaxRowJacobianNorm (row : Array Float) : Float := - let sumSq := row.foldl (fun acc p => acc + p * p) 0.0 - Float.sqrt (max 0.0 (1.0 - sumSq)) - -/-- Compute the average softmax Jacobian norm across all rows of attention weights. - -This provides a data-dependent bound on the softmax Jacobian that is much tighter -than a coarse constant bound (e.g. 1.0), especially for sharp attention patterns. --/ -def ConcreteAttentionWeights.avgSoftmaxJacobianNorm (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut totalNormSq : Float := 0.0 - for q in [:A.seqLen] do - -- Extract row q - let mut sumSq : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - -- Jacobian norm squared for this row is bounded by 1 - sumSq - totalNormSq := totalNormSq + max 0.0 (1.0 - sumSq) - -- Return RMS (root mean square) of per-row norms - Float.sqrt (totalNormSq / A.seqLen.toFloat) - -/-- Compute the maximum softmax Jacobian norm across all rows. 
- -More conservative than avg, but still much tighter than a coarse constant bound for sparse attention. --/ -def ConcreteAttentionWeights.maxSoftmaxJacobianNorm (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut maxNorm : Float := 0.0 - for q in [:A.seqLen] do - let mut sumSq : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - let rowNorm := Float.sqrt (max 0.0 (1.0 - sumSq)) - maxNorm := max maxNorm rowNorm - maxNorm - -/-! ### Attention Jacobian heuristics - -The following quantities are computed from `Float` attention weights produced by a `Float` -softmax. They must be treated as **heuristic estimates**, not sound certificates. --/ - -/-- Diagnostics for the softmax-Jacobian operator norm bound. - -The full softmax Jacobian over an attention matrix is block-diagonal over rows, -so the overall operator norm is the maximum of the per-row operator norms. - -We record statistics for the row that attains this maximum. --/ -structure SoftmaxJacobianOpDiag where - /-- Overall (block-diagonal) operator norm upper bound, `max_q rowBound(q)`. -/ - opBound : Float - /-- `max_i p_i` for the maximizing row. -/ - maxRowMaxP : Float - /-- `tr(J) = 1 - ∑ p_i^2` for the maximizing row. -/ - maxRowTraceBound : Float - /-- PSD moment bound for the maximizing row, derived from `tr(J)` and `‖J‖_F²`. -/ - maxRowMomentBound : Float - /-- Gershgorin / induced-∞ bound `max_i 2 p_i (1 - p_i)` for the maximizing row. -/ - maxRowGersh : Float - /-- The final per-row bound used for the maximizing row. -/ - maxRowBoundUsed : Float - /-- Number of rows that triggered a conservative fallback (NaN/Inf/zero-sum). -/ - numRowsFallback : Nat - deriving Repr - -/-- Heuristic estimate of the softmax-Jacobian operator norm per row (then maxed). 
- -For a probability row `p`, the softmax Jacobian is: -`J = diag(p) - p pᵀ`. -It is positive semidefinite and satisfies: -- `J ≤ diag(p)` so `‖J‖₂ ≤ maxᵢ pᵢ` -- `‖J‖₂ ≤ tr(J) = 1 - Σᵢ pᵢ²` - -We take the tighter bound -`min(maxᵢ pᵢ, 1 - Σᵢ pᵢ²)` for each row, -and then take the maximum over rows. -This is especially sharp for nearly one-hot or nearly uniform rows. --/ -def ConcreteAttentionWeights.softmaxJacobianOpDiag - (A : ConcreteAttentionWeights) : SoftmaxJacobianOpDiag := - if A.seqLen = 0 then - { opBound := 0.0 - maxRowMaxP := 0.0 - maxRowTraceBound := 0.0 - maxRowMomentBound := 0.0 - maxRowGersh := 0.0 - maxRowBoundUsed := 0.0 - numRowsFallback := 0 } - else Id.run do - let mut maxBound : Float := 0.0 - let mut bestMaxP : Float := 0.0 - let mut bestTrace : Float := 0.0 - let mut bestMoment : Float := 0.0 - let mut bestGersh : Float := 0.0 - let mut bestUsed : Float := 0.0 - let mut fallbackCount : Nat := 0 - - for q in [:A.seqLen] do - let rowBase := q * A.seqLen - - -- Pass 1: clamp negatives to 0 and compute sum. - let mut sumP : Float := 0.0 - for k in [:A.seqLen] do - let p0 := A.weights[rowBase + k]! - let p := if p0 < 0.0 then 0.0 else p0 - sumP := sumP + p - - let mut rowBound : Float := 0.5 - let mut rowMaxP : Float := 0.0 - let mut rowTrace : Float := 0.0 - let mut rowMoment : Float := 0.0 - let mut rowGersh : Float := 0.0 - - if sumP ≤ 0.0 || Float.isNaN sumP || Float.isInf sumP then - -- Conservative fallback: global bound for any probability row. - fallbackCount := fallbackCount + 1 - rowBound := 0.5 - else - -- Pass 2: renormalize and compute per-row bounds. - let invSum := 1.0 / sumP - let mut sumSq : Float := 0.0 - let mut sumCube : Float := 0.0 - let mut maxP : Float := 0.0 - let mut gersh : Float := 0.0 - for k in [:A.seqLen] do - let p0 := A.weights[rowBase + k]! 
- let pClamped := if p0 < 0.0 then 0.0 else p0 - let p := pClamped * invSum - sumSq := sumSq + p * p - sumCube := sumCube + p * p * p - if p > maxP then maxP := p - let g := 2.0 * p * (1.0 - p) - if g > gersh then gersh := g - let traceBound := max 0.0 (1.0 - sumSq) - -- Moment bound: J is PSD with - -- tr(J) = 1 - Σ p_i² - -- ‖J‖_F² = Σ (p_i - p_i²)² + Σ_{i≠j} (p_i p_j)² - -- = (Σ p_i²) - 2(Σ p_i³) + (Σ p_i²)². - let frob2 := max 0.0 (sumSq - 2.0 * sumCube + sumSq * sumSq) - let momentBound := ConcreteMatrix.psdLambdaMaxUpperMoment A.seqLen traceBound frob2 - -- Rigorous (for probability rows): - -- λ_max(J) ≤ maxP - -- λ_max(J) ≤ tr(J) - -- λ_max(J) ≤ ‖J‖_∞ = max_i 2 p_i (1-p_i) - -- λ_max(J) ≤ 1/2 - let bound0 := min maxP (min traceBound (min gersh momentBound)) - let bound1 := min bound0 0.5 - let bound2 := max 0.0 (min 0.5 bound1) - let bound := if Float.isNaN bound2 || Float.isInf bound2 then 0.5 else bound2 - rowBound := bound - rowMaxP := maxP - rowTrace := traceBound - rowMoment := momentBound - rowGersh := gersh - - if rowBound > maxBound then - maxBound := rowBound - bestMaxP := rowMaxP - bestTrace := rowTrace - bestMoment := rowMoment - bestGersh := rowGersh - bestUsed := rowBound - - { opBound := maxBound - maxRowMaxP := bestMaxP - maxRowTraceBound := bestTrace - maxRowMomentBound := bestMoment - maxRowGersh := bestGersh - maxRowBoundUsed := bestUsed - numRowsFallback := fallbackCount } - -/-- Backwards-compatible accessor: the bound value. -/ -def ConcreteAttentionWeights.softmaxJacobianOpEst (A : ConcreteAttentionWeights) : Float := - (A.softmaxJacobianOpDiag).opBound - -/-- Compute the overall Frobenius norm of the softmax Jacobian across all rows. - -For a probability row `p`, the Jacobian is `J = diag(p) - p pᵀ` and -`‖J‖_F² = (Σ pᵢ²) - 2(Σ pᵢ³) + (Σ pᵢ²)²`. -We sum this per-row Frobenius norm squared and take a final square root. 
--/ -def ConcreteAttentionWeights.softmaxJacobianFrobeniusNorm - (A : ConcreteAttentionWeights) : Float := - if A.seqLen = 0 then 0.0 - else Id.run do - let mut totalFrobSq : Float := 0.0 - for q in [:A.seqLen] do - let mut sumSq : Float := 0.0 - let mut sumCube : Float := 0.0 - let rowBase := q * A.seqLen - for k in [:A.seqLen] do - -- SAFETY: `q < seqLen` and `k < seqLen` by loop bounds. - let p := A.weights[rowBase + k]! - sumSq := sumSq + p * p - sumCube := sumCube + p * p * p - let rowFrobSq := max 0.0 (sumSq - 2.0 * sumCube + sumSq * sumSq) - totalFrobSq := totalFrobSq + rowFrobSq - Float.sqrt totalFrobSq - -/-! ## Forward Pass Implementations - -These methods compute the actual forward pass through attention and MLP layers, -accumulating the residual stream as in real transformers. --/ - -/-- Compute attention weights for a layer given an input matrix. -/ -def ConcreteAttentionLayer.computeAttentionWeights (layer : ConcreteAttentionLayer) - (input : ConcreteMatrix) (causal : Bool := true) : ConcreteAttentionWeights := - let queries := (input.matmul layer.W_Q).addBias layer.b_Q - let keys := (input.matmul layer.W_K).addBias layer.b_K - let scale := Float.sqrt layer.headDim.toFloat - ConcreteAttentionWeights.compute queries keys scale input.numRows causal - -/-- Forward pass for a single attention head. - -Input: X (seqLen × modelDim) -Output: Y (seqLen × modelDim) where Y = A·V·W_O - -This computes the attention output (before residual connection): -1. Q = X·W_Q, K = X·W_K, V = X·W_V -2. A = softmax(Q·K^T / √d_head) -3. 
Y = A·V·W_O --/ -def ConcreteAttentionLayer.forward (layer : ConcreteAttentionLayer) (input : ConcreteMatrix) - (causal : Bool := true) : ConcreteMatrix := - let attn := layer.computeAttentionWeights input causal - let values := (input.matmul layer.W_V).addBias layer.b_V -- seqLen × headDim - -- Compute A · V using direct construction - let attnValues : ConcreteMatrix := { - numRows := input.numRows - numCols := layer.headDim - data := .ofFn fun idx : Fin (input.numRows * layer.headDim) => Id.run do - let q := idx.val / layer.headDim - let d := idx.val % layer.headDim - let mut acc : Float := 0.0 - let n := input.numRows - let attnRowBase := q * n - for k in [:n] do - -- SAFETY: `q < n` and `k < n` by loop bounds, and `attn.seqLen = n`. - let a := attn.weights[attnRowBase + k]! - -- SAFETY: `k < values.numRows = n` and `d < values.numCols = headDim`. - let v := values.data[k * layer.headDim + d]! - acc := acc + a * v - return acc - size_eq := Array.size_ofFn - } - -- Project back to model dimension: (A·V) · W_O - attnValues.matmul layer.W_O - -/-- ReLU activation function for Float. -/ -def reluFloat (x : Float) : Float := if x > 0.0 then x else 0.0 - -/-- GeLU activation function (approximate) for Float. -/ -def geluFloat (x : Float) : Float := - let pi : Float := 3.14159265358979323846 - -- GPT-2 uses the "gelu_new" tanh approximation: - -- 0.5*x*(1 + tanh( sqrt(2/pi) * (x + 0.044715*x^3) )) - let a := Float.sqrt (2.0 / pi) - 0.5 * x * (1.0 + Float.tanh (a * (x + 0.044715 * x * x * x))) - -/-- Derivative of `geluFloat` with respect to `x`. - -This matches the tanh-based approximate GeLU used by `geluFloat`. 
--/ -def geluDerivFloat (x : Float) : Float := - let pi : Float := 3.14159265358979323846 - let a := Float.sqrt (2.0 / pi) - let b : Float := 0.044715 - let t := a * (x + b * x * x * x) - let tanhT := Float.tanh t - let sech2 := 1.0 - tanhT * tanhT - let t' := a * (1.0 + 3.0 * b * x * x) - 0.5 * (1.0 + tanhT) + 0.5 * x * sech2 * t' - -/-- Data-dependent upper bound on the MLP Jacobian operator norm over a batch of tokens. - -For a token with GeLU derivative vector `d` (from the pre-activations `z`), -the MLP Jacobian is `J = W_in · diag(d) · W_out`. - -Computing `J` explicitly is intractable (`modelDim×modelDim`). Instead we upper-bound `‖J‖₂` -via the standard inequality `‖J‖₂ ≤ sqrt(‖J‖₁ · ‖J‖∞)`, where `‖·‖₁`/`‖·‖∞` are induced norms. - -For a fixed token, row/column absolute sums can be bounded without forming `J`. -In row-vector convention, `‖J‖₁` is the max row sum and `‖J‖∞` is the max column sum: - -* max row sum ≤ `max_r ∑_k |W_in[r,k]| · |d_k| · (∑_c |W_out[k,c]|)` -* max column sum ≤ `max_c ∑_k |W_out[k,c]| · |d_k| · (∑_r |W_in[r,k]|)` - -To avoid a per-token `O(modelDim·hiddenDim)` loop, we take `dMax[k] = max_token |d_token[k]|` and -use it in place of `|d_k|`, which is sound for both `‖·‖₁` and `‖·‖∞`. --/ -def computeMLPLayerOpNormFromGeluDerivWithOpBounds - (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) (winUb woutUb : Float) : Float := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - let legacy : Float := winUb * 1.7 * woutUb - if d = 0 || h = 0 || geluDeriv.numRows = 0 then - -- No tokens or empty layer: treat as zero contribution. - return 0.0 - if geluDeriv.numCols ≠ h then - -- Dimension mismatch indicates an inconsistent forward-pass cache; fall back conservatively. - return legacy - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return legacy - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return legacy - - let rows := geluDeriv.numRows - - -- dMax[k] = max_token |gelu'(z)[token,k]|. 
- let dMax : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := dMax.foldl (fun m x => max m x) 0.0 - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - -- If derivative information is degenerate, we can still use the global GeLU' upper bound (≈1.7). - return legacy - - -- sOut[k] = ∑_c |W_out[k,c]| (row sums of |W_out|). - let (sOut, woutRowSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - let mut ss : Float := 0.0 - for c in [:d] do - let w := layer.W_out.data[rowBase + c]! - s := s + Float.abs w - ss := ss + w * w - out := out.set! k s - sq := sq.set! k ss - (out, sq) - - -- sIn[k] = ∑_r |W_in[r,k]| (column sums of |W_in|). - let (sIn, winColSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - let w := layer.W_in.data[rowBase + k]! - out := out.set! k (out[k]! + Float.abs w) - sq := sq.set! k (sq[k]! + w * w) - (out, sq) - - /- - Token-wise Frobenius scaling can be strictly tighter than using the coordinatewise maxima `dMax`: - the vector `dMax` may not be realized by any single token, but the MLP Jacobian is block-diagonal - across tokens, so `‖J‖₂ = max_token ‖J_token‖₂`. 
- -/ - let (winScaledFrobTokMax, woutScaledFrobTokMax) : (Float × Float) := Id.run do - let mut winMaxSq : Float := 0.0 - let mut woutMaxSq : Float := 0.0 - for i in [:rows] do - let base := i * h - let mut winSq : Float := 0.0 - let mut woutSq : Float := 0.0 - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - let aa := a * a - winSq := winSq + aa * winColSqSum[k]! - woutSq := woutSq + aa * woutRowSqSum[k]! - winMaxSq := max winMaxSq winSq - woutMaxSq := max woutMaxSq woutSq - (Float.sqrt (max 0.0 winMaxSq), Float.sqrt (max 0.0 woutMaxSq)) - - -- Fast candidates that exploit the full per-unit maxima `dMax[k]` (not just `max_k dMax[k]`): - -- - -- ‖W_in·diag(d)·W_out‖₂ ≤ ‖W_in·diag(dMax)‖₂ · ‖W_out‖₂ - -- ‖W_in·diag(d)·W_out‖₂ ≤ ‖W_in‖₂ · ‖diag(dMax)·W_out‖₂ - -- - -- where `dMax[k] = max_token |gelu'(z)[token,k]|`. Both are rigorous because induced - -- norms / Frobenius norms are monotone under entrywise scaling by `dMax`. - let maxWinScaledCol : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sIn[k]!) - m - let maxWoutScaledRow : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sOut[k]!) - m - let winScaledFrob : Float := winScaledFrobTokMax - let woutScaledFrob : Float := woutScaledFrobTokMax - - -- Max row-sum bound (‖J‖₁ in row-vector convention). - let (boundInf, maxWinRowScaled) : (Float × Float) := Id.run do - let mut maxRow : Float := 0.0 - let mut maxScaled : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - let mut sScaled : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sOut[k]! - let a := Float.abs (layer.W_in.data[rowBase + k]!) - s := s + a * coeff - sScaled := sScaled + a * dMax[k]! - if s > maxRow then - maxRow := s - if sScaled > maxScaled then - maxScaled := sScaled - (maxRow, maxScaled) - - -- Max column-sum bound (‖J‖∞ in row-vector convention). 
- let (boundOne, maxWoutColScaled) : (Float × Float) := Id.run do - let mut maxCol : Float := 0.0 - let mut maxScaled : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - let mut sScaled : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sIn[k]! - let a := Float.abs (layer.W_out.data[k * d + c]!) - s := s + a * coeff - sScaled := sScaled + a * dMax[k]! - if s > maxCol then - maxCol := s - if sScaled > maxScaled then - maxScaled := sScaled - (maxCol, maxScaled) - - let absSchur := Float.sqrt (max 0.0 (boundInf * boundOne)) - let legacy := winUb * globalDmax * woutUb - let winScaledOneInf := Float.sqrt (max 0.0 (maxWinScaledCol * maxWinRowScaled)) - let woutScaledOneInf := Float.sqrt (max 0.0 (maxWoutScaledRow * maxWoutColScaled)) - let winScaledUb := min winScaledFrob winScaledOneInf - let woutScaledUb := min woutScaledFrob woutScaledOneInf - let scaledViaWin := winScaledUb * woutUb - let scaledViaWout := winUb * woutScaledUb - let scaled0 := - if scaledViaWin ≤ 0.0 || Float.isNaN scaledViaWin || Float.isInf scaledViaWin then - scaledViaWout - else if scaledViaWout ≤ 0.0 || Float.isNaN scaledViaWout || Float.isInf scaledViaWout then - scaledViaWin - else - min scaledViaWin scaledViaWout - let scaled := - if scaled0 ≤ 0.0 || Float.isNaN scaled0 || Float.isInf scaled0 then legacy else scaled0 - let out := - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - min legacy scaled - else - min absSchur (min legacy scaled) - return out - -/-- Diagnostics for the MLP absolute-Schur candidate bound `sqrt(‖J‖₁‖J‖∞)`. -/ -structure MLPOpAbsSchurDiag where - dMax : Float - boundInf : Float - boundOne : Float - absSchur : Float - -/-- Compute the absolute-Schur candidate for `‖W_in·diag(gelu'(z))·W_out‖₂` without using -any operator-norm bounds on the weights (diagnostics helper). - -This matches the `absSchur` computation inside `computeMLPLayerOpNormFromGeluDerivWithOpBounds`. 
--/ -def computeMLPOpAbsSchurDiag (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) : MLPOpAbsSchurDiag := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || geluDeriv.numRows = 0 || geluDeriv.numCols ≠ h then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - - let rows := geluDeriv.numRows - let dMaxVec : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := dMaxVec.foldl (fun m x => max m x) 0.0 - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - return { dMax := 0.0, boundInf := 0.0, boundOne := 0.0, absSchur := 0.0 } - - let sOut : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - for c in [:d] do - s := s + Float.abs (layer.W_out.data[rowBase + c]!) - out := out.set! k s - out - let sIn : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - out := out.set! k (out[k]! + Float.abs (layer.W_in.data[rowBase + k]!)) - out - - let boundInf : Float := Id.run do - let mut maxRow : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMaxVec[k]! * sOut[k]! - s := s + Float.abs (layer.W_in.data[rowBase + k]!) 
* coeff - if s > maxRow then - maxRow := s - maxRow - - let boundOne : Float := Id.run do - let mut maxCol : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMaxVec[k]! * sIn[k]! - s := s + Float.abs (layer.W_out.data[k * d + c]!) * coeff - if s > maxCol then - maxCol := s - maxCol - - let absSchur := Float.sqrt (max 0.0 (boundInf * boundOne)) - return { dMax := globalDmax, boundInf := boundInf, boundOne := boundOne, absSchur := absSchur } - -/-- Precomputed, token-batch–dependent factors for bounding `‖W_in · diag(gelu'(z)) · W_out‖₂`. - -This is used by the adaptive scheduler to avoid recomputing weight-dependent summaries -(`absSchur`, scaled Frobenius/Schur bounds) across effort-tier upgrades. - -Correctness: the final upper bound is still computed as a minimum of rigorous candidates; -this structure merely memoizes pieces of those candidates. --/ -structure MLPJacobianBoundCore where - /-- `max_k dMax[k]` where `dMax[k] = max_token |gelu'(z)[token,k]|`. -/ - globalDmax : Float - /-- Candidate `sqrt(‖J‖₁‖J‖∞)` computed from absolute row/col sums (independent of `winUb/woutUb`). -/ - absSchur : Float - /-- Upper bound on `‖W_in · diag(dMax)‖₂` (min of scaled Frobenius and scaled Schur/one-inf). -/ - winScaledUb : Float - /-- Upper bound on `‖diag(dMax) · W_out‖₂` (min of scaled Frobenius and scaled Schur/one-inf). -/ - woutScaledUb : Float - deriving Repr - -/-- Precompute `MLPJacobianBoundCore` for a layer and a given `geluDeriv` matrix. - -Returns `none` if dimensions don't match; callers should conservatively fall back. 
--/ -def ConcreteMLPLayer.precomputeJacobianBoundCore (layer : ConcreteMLPLayer) - (geluDeriv : ConcreteMatrix) : Option MLPJacobianBoundCore := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || geluDeriv.numRows = 0 || geluDeriv.numCols ≠ h then - return none - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return none - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return none - - let rows := geluDeriv.numRows - - -- dMax[k] = max_token |gelu'(z)[token,k]|. - let dMax : Array Float := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - for i in [:rows] do - let base := i * h - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - if a > out[k]! then - out := out.set! k a - out - let globalDmax : Float := dMax.foldl (fun m x => max m x) 0.0 - if globalDmax ≤ 0.0 || Float.isNaN globalDmax || Float.isInf globalDmax then - return none - - -- sOut[k] = ∑_c |W_out[k,c]| (row sums of |W_out|). - let (sOut, woutRowSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for k in [:h] do - let rowBase := k * d - let mut s : Float := 0.0 - let mut ss : Float := 0.0 - for c in [:d] do - let w := layer.W_out.data[rowBase + c]! - s := s + Float.abs w - ss := ss + w * w - out := out.set! k s - sq := sq.set! k ss - (out, sq) - - -- sIn[k] = ∑_r |W_in[r,k]| (column sums of |W_in|). - let (sIn, winColSqSum) : (Array Float × Array Float) := Id.run do - let mut out : Array Float := Array.replicate h 0.0 - let mut sq : Array Float := Array.replicate h 0.0 - for r in [:d] do - let rowBase := r * h - for k in [:h] do - let w := layer.W_in.data[rowBase + k]! - out := out.set! k (out[k]! + Float.abs w) - sq := sq.set! k (sq[k]! 
+ w * w) - (out, sq) - - let (winScaledFrobTokMax, woutScaledFrobTokMax) : (Float × Float) := Id.run do - let mut winMaxSq : Float := 0.0 - let mut woutMaxSq : Float := 0.0 - for i in [:rows] do - let base := i * h - let mut winSq : Float := 0.0 - let mut woutSq : Float := 0.0 - for k in [:h] do - let a := Float.abs (geluDeriv.data[base + k]!) - let aa := a * a - winSq := winSq + aa * winColSqSum[k]! - woutSq := woutSq + aa * woutRowSqSum[k]! - winMaxSq := max winMaxSq winSq - woutMaxSq := max woutMaxSq woutSq - (Float.sqrt (max 0.0 winMaxSq), Float.sqrt (max 0.0 woutMaxSq)) - - -- Max row-sum bound (‖J‖₁ in row-vector convention). - let boundInf : Float := Id.run do - let mut maxRow : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sOut[k]! - let a := Float.abs (layer.W_in.data[rowBase + k]!) - s := s + a * coeff - if s > maxRow then - maxRow := s - maxRow - - -- Max column-sum bound (‖J‖∞ in row-vector convention). - let boundOne : Float := Id.run do - let mut maxCol : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - let coeff := dMax[k]! * sIn[k]! - let a := Float.abs (layer.W_out.data[k * d + c]!) - s := s + a * coeff - if s > maxCol then - maxCol := s - maxCol - - let absSchur : Float := Float.sqrt (max 0.0 (boundInf * boundOne)) - - let winScaledFrob : Float := winScaledFrobTokMax - let woutScaledFrob : Float := woutScaledFrobTokMax - - let maxWinScaledCol : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sIn[k]!) - m - let maxWoutScaledRow : Float := Id.run do - let mut m : Float := 0.0 - for k in [:h] do - m := max m (dMax[k]! * sOut[k]!) - m - let maxWinRowScaled : Float := Id.run do - let mut m : Float := 0.0 - for r in [:d] do - let rowBase := r * h - let mut s : Float := 0.0 - for k in [:h] do - s := s + Float.abs (layer.W_in.data[rowBase + k]!) * dMax[k]! 
- m := max m s - m - let maxWoutColScaled : Float := Id.run do - let mut m : Float := 0.0 - for c in [:d] do - let mut s : Float := 0.0 - for k in [:h] do - s := s + Float.abs (layer.W_out.data[k * d + c]!) * dMax[k]! - m := max m s - m - - let winScaledOneInf := Float.sqrt (max 0.0 (maxWinScaledCol * maxWinRowScaled)) - let woutScaledOneInf := Float.sqrt (max 0.0 (maxWoutScaledRow * maxWoutColScaled)) - let winScaledUb := min winScaledFrob winScaledOneInf - let woutScaledUb := min woutScaledFrob woutScaledOneInf - - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - return none - if winScaledUb ≤ 0.0 || Float.isNaN winScaledUb || Float.isInf winScaledUb then - return none - if woutScaledUb ≤ 0.0 || Float.isNaN woutScaledUb || Float.isInf woutScaledUb then - return none - - return some { globalDmax := globalDmax, absSchur := absSchur, - winScaledUb := winScaledUb, woutScaledUb := woutScaledUb } - -/-- Fused MLP Jacobian upper bound, given the cached GeLU derivative matrix `gelu'(z)`. -/ -def computeMLPLayerOpNormFromGeluDeriv (layer : ConcreteMLPLayer) (geluDeriv : ConcreteMatrix) : Float := Id.run do - let winUb := layer.W_in.opNormUpperBoundRectGram - let woutUb := layer.W_out.opNormUpperBoundRectGram - computeMLPLayerOpNormFromGeluDerivWithOpBounds layer geluDeriv winUb woutUb - -/-- Compute the MLP Jacobian bound by re-evaluating `gelu'(z)` from a concrete input. 
-/ -def computeMLPLayerOpNorm (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : Float := Id.run do - let d := layer.modelDim - let h := layer.hiddenDim - if d = 0 || h = 0 || input.numRows = 0 then - return 0.0 - if layer.W_in.numRows ≠ d || layer.W_in.numCols ≠ h then - return 0.0 - if layer.W_out.numRows ≠ h || layer.W_out.numCols ≠ d then - return 0.0 - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let geluDeriv := hidden.map geluDerivFloat - let winUb := layer.W_in.opNormUpperBoundRectGram - let woutUb := layer.W_out.opNormUpperBoundRectGram - computeMLPLayerOpNormFromGeluDerivWithOpBounds layer geluDeriv winUb woutUb - -/-- Forward pass for an MLP layer with GeLU activation. - -Input: X (seqLen × modelDim) -Output: Y (seqLen × modelDim) where Y = W_out · GeLU(W_in · X + b_in) + b_out - -For each position, this computes the standard transformer FFN. --/ -def ConcreteMLPLayer.forward (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : ConcreteMatrix := - -- hidden = input · W_in + b_in (seqLen × hiddenDim) - let hidden := (input.matmul layer.W_in).addBias layer.b_in - -- Apply GeLU activation - let activated := hidden.map geluFloat - -- output = activated · W_out + b_out (seqLen × modelDim) - (activated.matmul layer.W_out).addBias layer.b_out - -/-- Forward pass plus a data-dependent bound on `max |gelu'(z)|` over this layer's preactivations. - -The returned derivative maximum is exact for the computed `hidden` matrix entries (interpreting -Float arithmetic as defining a concrete real expression, consistent with this file's conventions). -If NaN/Inf is encountered, callers should conservatively fall back to a global bound. 
--/ -def ConcreteMLPLayer.forwardWithGeluDerivMax (layer : ConcreteMLPLayer) - (input : ConcreteMatrix) : (ConcreteMatrix × Float) := Id.run do - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let mut maxDeriv : Float := 0.0 - for z in hidden.data do - let d := Float.abs (geluDerivFloat z) - if d > maxDeriv then - maxDeriv := d - let activated := hidden.map geluFloat - let out := (activated.matmul layer.W_out).addBias layer.b_out - return (out, maxDeriv) - -/-! ## Efficient Pattern Term Bound Calculation - -The core insight: the valueTerm Frobenius norm factors as `‖A‖_F · ‖W_V·W_O‖_F`. -This avoids computing the full (N·D)² matrix! --/ - -/-- Compute ‖valueTerm‖_F efficiently via factorization. -/ -def computeValueTermNorm (attn : ConcreteAttentionWeights) - (valueOutputProjFrobNormSq : Float) : Float := - let attnNormSq := attn.weights.foldl (fun acc x => acc + x * x) 0.0 - Float.sqrt (attnNormSq * valueOutputProjFrobNormSq) - -/-- Information needed to bound the pattern term. -/ -structure PatternTermBoundInputs where - /-- Attention weights -/ - attention : ConcreteAttentionWeights - /-- Input embedding norm (‖X‖_F) -/ - inputNorm : Float - /-- Input embedding operator-norm upper bound (‖X‖₂), used for tighter pattern-term bounds. -/ - inputOpBound : Float - /-- Optional direct Frobenius norm bound for Q = X·W_Q + b_Q, if available. -/ - qFrobBound : Float := 0.0 - /-- Optional direct Frobenius norm bound for K = X·W_K + b_K, if available. -/ - kFrobBound : Float := 0.0 - /-- Optional Frobenius norm bound for centered V = X·W_V + b_V, if available. -/ - vFrobBound : Float := 0.0 - /-- Optional direct operator-norm bound for Q = X·W_Q + b_Q, if available. -/ - qOpBoundAct : Float := 0.0 - /-- Optional direct operator-norm bound for K = X·W_K + b_K, if available. -/ - kOpBoundAct : Float := 0.0 - /-- Optional operator-norm bound for centered V = X·W_V + b_V, if available. 
-/ - vOpBoundAct : Float := 0.0 - /-- Optional Frobenius bound for `‖W_Q·Kᵀ‖_F` after centering keys, if available. -/ - qkActFrobBound : Float := 0.0 - /-- Optional Frobenius bound for `‖W_K·Qᵀ‖_F` after centering queries, if available. -/ - kqActFrobBound : Float := 0.0 - /-- Optional operator-norm bound for `‖W_Q·Kᵀ‖₂` after centering keys, if available. -/ - qkActOpBound : Float := 0.0 - /-- Optional operator-norm bound for `‖W_K·Qᵀ‖₂` after centering queries, if available. -/ - kqActOpBound : Float := 0.0 - /-- Scaling factor (√d_head) -/ - scaleFactor : Float - /-- Deterministic Float upper bound on ‖W_Q‖₂. -/ - wqOpBound : Float - /-- Deterministic Float upper bound on ‖W_K‖₂. -/ - wkOpBound : Float - /-- Deterministic Float upper bound on ‖W_V‖₂. -/ - wvOpBound : Float - /-- Deterministic Float upper bound on ‖W_O‖₂. -/ - woOpBound : Float - /-- Deterministic Float upper bound on ‖W_V·W_O‖₂, if available. -/ - voOpBound : Float := 0.0 - /-- Query bias Frobenius norm (for the 1×headDim bias row). -/ - bqFrob : Float - /-- Key bias Frobenius norm (for the 1×headDim bias row). -/ - bkFrob : Float - /-- Value bias Frobenius norm (for the 1×headDim bias row). -/ - bvFrob : Float - -/-- Selected intermediate bounds for the pattern-term calculation. -/ -structure PatternTermBoundParts where - /-- Selected operator-norm bound for `‖W_Q·Kᵀ‖₂` after all min-combination. -/ - qkActOpUb : Float - /-- Selected operator-norm bound for `‖W_K·Qᵀ‖₂` after all min-combination. -/ - kqActOpUb : Float - /-- Selected operator-norm bound for centered `V`. -/ - vOpUb : Float - /-- Selected operator-norm bound for centered `V·W_O`. -/ - vOpUbWO : Float - /-- Frobenius-style candidate bound. -/ - candFrob : Float - /-- Operator-style candidate bound (using `vOpUb`). -/ - candOp : Float - /-- Operator-style candidate bound (using `vOpUbWO`). -/ - candOpWO : Float - /-- Final chosen pattern-term bound. 
-/ - patternBound : Float - -/-- Compute pattern-term bound and expose the selected intermediate bounds. -/ -def computePatternTermBoundParts (inputs : PatternTermBoundInputs) : PatternTermBoundParts := - -- Data-dependent bound on the softmax Jacobian operator norm. - -- Provable global bound: for any probability vector p, J = diag(p) - p pᵀ has ‖J‖₂ ≤ 1/2. - -- Clamp defensively so callers cannot accidentally exceed this. - let softmaxBound := min inputs.attention.softmaxJacobianOpEst 0.5 - let n := inputs.attention.seqLen - let sqrtN := Float.sqrt n.toFloat - -- Bias-aware Frobenius upper bounds on the concrete Q/K/V activations. - -- If the caller provides direct (data-dependent) bounds on ‖Q‖_F or ‖K‖_F, we - -- take the minimum (still a rigorous upper bound in exact ℝ arithmetic). - let qFrobUb0 := inputs.inputNorm * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kFrobUb0 := inputs.inputNorm * inputs.wkOpBound + sqrtN * inputs.bkFrob - let qFrobUb := - if inputs.qFrobBound ≤ 0.0 || Float.isNaN inputs.qFrobBound || Float.isInf inputs.qFrobBound then - qFrobUb0 - else - min qFrobUb0 inputs.qFrobBound - let kFrobUb := - if inputs.kFrobBound ≤ 0.0 || Float.isNaN inputs.kFrobBound || Float.isInf inputs.kFrobBound then - kFrobUb0 - else - min kFrobUb0 inputs.kFrobBound - let vFrobUb0 := inputs.inputNorm * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vFrobUb := - if inputs.vFrobBound ≤ 0.0 || Float.isNaN inputs.vFrobBound || Float.isInf inputs.vFrobBound then - vFrobUb0 - else - min vFrobUb0 inputs.vFrobBound - -- Bias-aware operator-norm upper bounds on Q/K/V: - -- ‖X·W_Q + 1·b_Q‖₂ ≤ ‖X‖₂‖W_Q‖₂ + √n·‖b_Q‖₂. 
- let qOpUb0 := inputs.inputOpBound * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kOpUb0 := inputs.inputOpBound * inputs.wkOpBound + sqrtN * inputs.bkFrob - let qOpUb := - if inputs.qOpBoundAct ≤ 0.0 || Float.isNaN inputs.qOpBoundAct || Float.isInf inputs.qOpBoundAct then - qOpUb0 - else - min qOpUb0 inputs.qOpBoundAct - let kOpUb := - if inputs.kOpBoundAct ≤ 0.0 || Float.isNaN inputs.kOpBoundAct || Float.isInf inputs.kOpBoundAct then - kOpUb0 - else - min kOpUb0 inputs.kOpBoundAct - let qkActOpUb0 := inputs.wqOpBound * kOpUb - let qkActOpUb1 := - if inputs.qkActOpBound ≤ 0.0 || Float.isNaN inputs.qkActOpBound || - Float.isInf inputs.qkActOpBound then - qkActOpUb0 - else - min qkActOpUb0 inputs.qkActOpBound - let qkActOpUb := - if inputs.qkActFrobBound ≤ 0.0 || Float.isNaN inputs.qkActFrobBound || - Float.isInf inputs.qkActFrobBound then - qkActOpUb1 - else - min qkActOpUb1 inputs.qkActFrobBound - let kqActOpUb0 := inputs.wkOpBound * qOpUb - let kqActOpUb1 := - if inputs.kqActOpBound ≤ 0.0 || Float.isNaN inputs.kqActOpBound || - Float.isInf inputs.kqActOpBound then - kqActOpUb0 - else - min kqActOpUb0 inputs.kqActOpBound - let kqActOpUb := - if inputs.kqActFrobBound ≤ 0.0 || Float.isNaN inputs.kqActFrobBound || - Float.isInf inputs.kqActFrobBound then - kqActOpUb1 - else - min kqActOpUb1 inputs.kqActFrobBound - let vOpUb0 := inputs.inputOpBound * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vOpUb := - if inputs.vOpBoundAct ≤ 0.0 || Float.isNaN inputs.vOpBoundAct || Float.isInf inputs.vOpBoundAct then - vOpUb0 - else - min vOpUb0 inputs.vOpBoundAct - let vOpUbWO0 := inputs.inputOpBound * inputs.voOpBound + - sqrtN * inputs.bvFrob * inputs.woOpBound - let vOpUbWO := - if inputs.voOpBound ≤ 0.0 || Float.isNaN inputs.voOpBound || Float.isInf inputs.voOpBound then - vOpUb * inputs.woOpBound - else - min (vOpUb * inputs.woOpBound) vOpUbWO0 - -- Logits sensitivity: - -- dS = (dQ)Kᵀ + Q(dK)ᵀ, with ‖dQ‖_F ≤ ‖dX‖_F‖W_Q‖₂ and ‖dK‖_F ≤ ‖dX‖_F‖W_K‖₂. 
- let sCoeffFrob := (inputs.wqOpBound * kFrobUb + inputs.wkOpBound * qFrobUb) / inputs.scaleFactor - let sCoeffOp := (qkActOpUb + kqActOpUb) / inputs.scaleFactor - -- Two rigorous candidates: - -- (A) Frobenius-style activation bounds: - -- ‖dA·V·W_O‖_F ≤ ‖dA‖_F‖V‖_F‖W_O‖₂. - let candFrob := (softmaxBound * sCoeffFrob) * vFrobUb * inputs.woOpBound - -- (B) Operator-style activation bounds (often much tighter): - -- ‖dS‖_F ≤ ‖dX‖_F · (‖W_Q‖₂‖K‖₂ + ‖W_K‖₂‖Q‖₂)/scale, - -- ‖dA·V‖_F ≤ ‖dA‖_F‖V‖₂, - -- so ‖dA·V·W_O‖_F ≤ ‖J_softmax‖₂‖dS‖_F‖V‖₂‖W_O‖₂. - let candOp := (softmaxBound * sCoeffOp) * vOpUb * inputs.woOpBound - let candOpWO := (softmaxBound * sCoeffOp) * vOpUbWO - let patternBound := min candFrob (min candOp candOpWO) - { qkActOpUb := qkActOpUb - kqActOpUb := kqActOpUb - vOpUb := vOpUb - vOpUbWO := vOpUbWO - candFrob := candFrob - candOp := candOp - candOpWO := candOpWO - patternBound := patternBound } - -/-- Bound ‖patternTerm‖_F without expanding the full Jacobian. - -The pattern term arises from how attention weights A change when input X changes: - patternTerm = (∂A/∂X) ⊗ (V·W_O) - -The key insight is that ∂A/∂X involves the softmax Jacobian, which is bounded by -the "softness" of the attention distribution. For sparse (one-hot) attention, -the softmax Jacobian is nearly zero, giving much tighter bounds. - -**Sparsity-aware bound**: - ‖patternTerm‖_F ≤ (‖J_softmax‖₂ / scale) · ‖X‖_F · ‖W_Q·W_K^T‖₂ · ‖W_V·W_O‖₂ - -We use a tight, data-dependent bound on `‖J_softmax‖₂` per row of attention: -for a probability row `p`, `J = diag(p) - p pᵀ` and -`‖J‖₂ ≤ min(maxᵢ pᵢ, 1 - Σᵢ pᵢ²)`. - -- Perfectly one-hot rows give `‖J‖₂ = 0`. -- Uniform rows give `‖J‖₂ = 1/n`. -- Worst-case (all n): `‖J‖₂ ≤ 0.5`. --/ -def computePatternTermBound (inputs : PatternTermBoundInputs) : Float := - (computePatternTermBoundParts inputs).patternBound - -/-- Bound ‖patternTerm‖_F using the old pessimistic constant bound. 
- -This uses the worst-case softmax Jacobian spectral-norm bound of 0.5, which is valid but loose. -Prefer `computePatternTermBound` for tighter data-dependent bounds. --/ -def computePatternTermBoundPessimistic (inputs : PatternTermBoundInputs) : Float := - let softmaxBound : Float := 0.5 -- Worst-case softmax Jacobian spectral norm - let n := inputs.attention.seqLen - let sqrtN := Float.sqrt n.toFloat - let qFrobUb := inputs.inputNorm * inputs.wqOpBound + sqrtN * inputs.bqFrob - let kFrobUb := inputs.inputNorm * inputs.wkOpBound + sqrtN * inputs.bkFrob - let vFrobUb0 := inputs.inputNorm * inputs.wvOpBound + sqrtN * inputs.bvFrob - let vFrobUb := - if inputs.vFrobBound ≤ 0.0 || Float.isNaN inputs.vFrobBound || Float.isInf inputs.vFrobBound then - vFrobUb0 - else - min vFrobUb0 inputs.vFrobBound - let sCoeff := (inputs.wqOpBound * kFrobUb + inputs.wkOpBound * qFrobUb) / inputs.scaleFactor - (softmaxBound * sCoeff) * vFrobUb * inputs.woOpBound - -/-- Compute faithfulness ratio: ‖patternTerm‖_F / ‖valueTerm‖_F. -/ -def computeFaithfulnessRatio (inputs : PatternTermBoundInputs) (valueOutputProjFrobNormSq : Float) : Float := - let patternBound := computePatternTermBound inputs - let valueNorm := computeValueTermNorm inputs.attention valueOutputProjFrobNormSq - if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - -/-! ## Discovery Structures -/ - -/-- Result of discovering a potential induction head pair. 
-/ -structure CandidateInductionHead where - /-- Layer index of the "previous token" head (L1) -/ - layer1Idx : Nat - /-- Layer index of the "induction" head (L2) -/ - layer2Idx : Nat - /-- Head index within layer 1 -/ - head1Idx : Nat - /-- Head index within layer 2 -/ - head2Idx : Nat - /-- Faithfulness ratio ε₁ for L1: ‖PatternTerm‖_F / ‖ValueTerm‖_F -/ - patternBound1 : Float - /-- Faithfulness ratio ε₂ for L2: ‖PatternTerm‖_F / ‖ValueTerm‖_F -/ - patternBound2 : Float - /-- Combined relative error: (1+ε₁)(1+ε₂) - 1 = ε₁ + ε₂ + ε₁·ε₂ -/ - combinedError : Float - /-- Previous-token strength: avg A₁[i, i-1] -/ - prevTokenStrength : Float - /-- Induction "copy-next" pattern score for head 2 (prompt-dependent). -/ - inductionScore : Float - /-- K-composition score between head 1 and head 2, as in the circuits framework paper: - - `kComp_raw = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)`, - - then we subtract the random-baseline `1/√modelDim`: - - `kComp = kComp_raw - 1/√modelDim`. - - This measures how strongly head 1 can feed information into head 2's QK circuit, - i.e. whether head 1 plausibly acts as a **pattern enabler** for head 2. - -/ - kComp : Float - /-- Description of the discovered pattern -/ - description : String - -/-- A verified induction head that meets the certification threshold. -/ -structure VerifiedInductionHead where - /-- The candidate that was verified -/ - candidate : CandidateInductionHead - /-- The certification threshold used -/ - threshold : Float - /-- Combined error is below threshold (runtime-checked) -/ - errorChecked : Bool - -/-- An induction head candidate with an explicit effectiveness score `δ` on a target direction. - -This is produced by the Float-based discovery pipeline and should be interpreted as a -**heuristic ranking**, not a proof-grade certification. 
--/ -structure HeuristicInductionHead where - /-- The discovered candidate pair (pattern-checked, heuristically) -/ - candidate : CandidateInductionHead - /-- Raw effectiveness score `δ` on the target direction (Float). -/ - delta : Float - /-- Scale-invariant effectiveness score (Float): - - `effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)`, - - where `X₂` is the layer-2 input residual stream and `u` is the target direction. - This isolates the **mechanism** (virtual-head computation) from residual-stream energy. - -/ - effect : Float - /-- Frobenius norm of the layer-2 input residual stream `‖X₂‖_F` (Float). -/ - layer2InputNorm : Float - /-- Frobenius norm of the Pre-LN attention input `‖ln₁(X₂)‖_F` (Float). -/ - layer2Ln1InputNorm : Float - -/-- Result of discovering a multi-layer circuit with N-layer error bounds. - -This extends `CandidateInductionHead` with rigorous N-layer amplification bounds -from the theorem `n_layer_faithfulness_composition`: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) --/ -structure DeepCircuitCandidate where - /-- Layer indices involved in the circuit (sorted) -/ - layerIndices : Array Nat - /-- Head indices at each layer -/ - headIndices : Array Nat - /-- Per-layer pattern term bounds (εᵢ) -/ - patternBounds : Array Float - /-- Per-layer residual operator norm upper bounds (Cᵢ) -/ - operatorNormUbs : Array Float - /-- Simple error sum: Σᵢ εᵢ (no amplification) -/ - simpleErrorSum : Float - /-- N-layer amplified error: Σᵢ εᵢ · ∏_{j>i}(1+Cⱼ) -/ - amplifiedError : Float - /-- Suffix amplification factor from the earliest layer in the circuit: - `∏_{j ≥ min layer}(1 + Cⱼ)`. 
-/ - amplificationFactor : Float - /-- Pattern type description (e.g., "induction", "composition") -/ - patternType : String - /-- Human-readable description -/ - description : String - -namespace DeepCircuitCandidate - -def toString (c : DeepCircuitCandidate) : String := - let heads := c.layerIndices.zip c.headIndices |>.map fun (l, h) => s!"L{l}H{h}" - s!"{c.patternType}: {heads.toList} | " ++ - s!"ε_simple={c.simpleErrorSum}, ε_amplified={c.amplifiedError}, amp={c.amplificationFactor}" - -instance : ToString DeepCircuitCandidate := ⟨toString⟩ - -end DeepCircuitCandidate - -/-! ## Discovery Algorithm -/ - -/-- Check if a layer exhibits "previous-token" attention pattern. -/ -def checkPrevTokenPattern (attn : ConcreteAttentionWeights) - (minStrength : Float := 0.3) : Option Float := - if attn.seqLen < 2 then none - else - let sum : Float := Id.run do - let n := attn.seqLen - let w := attn.weights - let mut acc : Float := 0.0 - for i in [:n - 1] do - -- SAFETY: `i < n-1` implies `i+1 < n`, so `(i+1,i)` is in-bounds. - acc := acc + w[(i + 1) * n + i]! - return acc - let avgStrength := sum / (attn.seqLen - 1).toFloat - if avgStrength ≥ minStrength then some avgStrength else none - -/-- Check if a head exhibits a **content-addressable** attention pattern. - -We say a head is content-addressable on a prompt when, for many query positions `q`, -it places substantial attention mass on *previous occurrences of the same token*: - -`score(q) = ∑_{k < q, tokens[k] = tokens[q]} A[q, k]`. - -The returned score is the average of `score(q)` over query positions that have at least -one previous occurrence. This is variable-lag by construction (no fixed positional lag). 
--/ -def checkContentAddressablePattern (tokens : Array Nat) (attn : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if tokens.size ≠ attn.seqLen then none - else if attn.seqLen < 2 then none - else - let n := attn.seqLen - let w := attn.weights - let (sumScore, count) : (Float × Nat) := Id.run do - let mut sumScore : Float := 0.0 - let mut count : Nat := 0 - for q in [1:n] do - let tq := tokens[q]! - let rowBase := q * n - let mut hasPrev : Bool := false - let mut rowScore : Float := 0.0 - for k in [:q] do - if tokens[k]! = tq then - hasPrev := true - -- SAFETY: `q < n` and `k < q ≤ n` by loop bounds. - rowScore := rowScore + w[rowBase + k]! - if hasPrev then - sumScore := sumScore + rowScore - count := count + 1 - return (sumScore, count) - - if count = 0 then none - else - let avgScore := sumScore / count.toFloat - if avgScore ≥ minScore then some avgScore else none - -/-- Check if a head exhibits an **induction** ("copy-next") attention pattern. - -We say a head is induction-like on a prompt when, for many query positions `q`, it places -substantial attention mass on tokens *immediately after* previous occurrences of the same token: - -`score(q) = ∑_{k+1 < q, tokens[k] = tokens[q]} A[q, k+1]`. - -This is the token-level signature of the induction mechanism described in the transformer-circuits -literature: when the current token repeats, attend to the successor of the previous occurrence so -the head can **copy** that successor forward. - -The returned score is the average of `score(q)` over query positions that have at least one -previous occurrence with an in-bounds successor (`k+1 < q`). This is variable-lag by construction. 
--/ -def checkInductionCopyNextPattern (tokens : Array Nat) (attn : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if tokens.size ≠ attn.seqLen then none - else if attn.seqLen < 3 then none - else - let n := attn.seqLen - let w := attn.weights - let (sumScore, count) : (Float × Nat) := Id.run do - let mut sumScore : Float := 0.0 - let mut count : Nat := 0 - -- Need `q ≥ 2` so there is room for a predecessor `k` with successor `k+1 < q`. - for q in [2:n] do - let tq := tokens[q]! - let rowBase := q * n - let mut hasPrevSucc : Bool := false - let mut rowScore : Float := 0.0 - -- Scan all earlier positions `k` whose successor is still < q. - for k in [:q - 1] do - if tokens[k]! = tq then - hasPrevSucc := true - -- Attend to the *successor* position `k+1`. - rowScore := rowScore + w[rowBase + (k + 1)]! - if hasPrevSucc then - sumScore := sumScore + rowScore - count := count + 1 - return (sumScore, count) - - if count = 0 then none - else - let avgScore := sumScore / count.toFloat - if avgScore ≥ minScore then some avgScore else none - -/-- Compute composed attention score for induction pattern detection. - -Generalizes over all possible repetition lags from 2 to n/2, computing the -"induction score" (average attention mass transferred from q to q+lag via -the two-layer circuit). Returns the maximum score across all lags. - -This enables detection of induction heads with arbitrary repetition periods, -not just lag-2 patterns. --/ -def checkInductionPattern (attn1 attn2 : ConcreteAttentionWeights) - (minScore : Float := 0.1) : Option Float := - if attn1.seqLen ≠ attn2.seqLen then none - else if attn1.seqLen < 3 then none - else - let n := attn1.seqLen - let maxLag := n / 2 - - -- Try all possible lags and find the maximum induction score. - -- - -- PERFORMANCE: this is called in an O(L²H²) search, so avoid `List.range.foldl`. 
- let maxScore : Float := Id.run do - let w1 := attn1.weights - let w2 := attn2.weights - let mut currentMax : Float := 0.0 - for lagIdx in [:maxLag - 1] do - let lag := lagIdx + 2 -- Start from lag=2 - if lag < n then - -- Compute average induction score for this lag - let validPositions := n - lag - let mut composedSum : Float := 0.0 - for q in [:validPositions] do - let q' := q + lag - let row2Base := q' * n - let mut composedToQ : Float := 0.0 - -- Column access into `w1` is strided, so avoid repeated multiplications. - let mut col1Idx := q - for j in [:n] do - -- SAFETY: `q' < n` and `j < n` by loop bounds. - let a2 := w2[row2Base + j]! - -- SAFETY: `j < n` and `q < n` by loop bounds. - let a1 := w1[col1Idx]! - composedToQ := composedToQ + a2 * a1 - col1Idx := col1Idx + n - composedSum := composedSum + composedToQ - let avgScore := composedSum / validPositions.toFloat - if avgScore > currentMax then - currentMax := avgScore - return currentMax - - if maxScore ≥ minScore then some maxScore else none - -/-- Multi-layer model with concrete weights. -/ -structure ConcreteModel where - /-- Number of layers -/ - numLayers : Nat - /-- Attention layers with their heads: layers[l] is array of heads in layer l -/ - layers : Array (Array ConcreteAttentionLayer) - /-- Attention output projection bias (c_proj.bias), one per layer (1×modelDim). - - In GPT-2, this bias is added **once per layer** after combining all heads, i.e.: - `attn_out = c_proj(concat(heads)) + bias`. - -/ - attnProjBias : Array ConcreteMatrix := #[] - /-- MLP layers: mlps[l] is the MLP in layer l (one per layer) -/ - mlps : Array ConcreteMLPLayer - /-- Pre-LN LayerNorm parameters before attention (ln_1), one per layer. -/ - ln1 : Array ConcreteLayerNormParams := #[] - /-- Pre-LN LayerNorm parameters before MLP (ln_2), one per layer. -/ - ln2 : Array ConcreteLayerNormParams := #[] - /-- Final LayerNorm parameters (ln_f) before unembedding. 
-/ - lnf : ConcreteLayerNormParams := ConcreteLayerNormParams.identity 0 - /-- Sequence length for analysis -/ - seqLen : Nat - /-- Optional ground-truth input token IDs for the prompt being analyzed. - - When present, this enables **self-supervised induction targeting** by choosing the - correct next-token prediction target from sequence history (see - `TargetDirection.fromInductionHistory`). - -/ - inputTokens : Option (Array Nat) := none - /-- Input embeddings (seqLen × modelDim) -/ - inputEmbeddings : ConcreteMatrix - /-- Unembedding (decoder) matrix (modelDim × vocabSize) for logit computation. - Maps final residual stream to vocabulary logits: logits = residual · W_U - Optional: if not provided, target-aware analysis is unavailable. -/ - unembedding : Option ConcreteMatrix := none - -namespace ConcreteModel - -/-- Model dimension (d), inferred from input embeddings. -/ -def modelDim (model : ConcreteModel) : Nat := - model.inputEmbeddings.numCols - -/-- Attention output bias for a layer, defaulting to zero. -/ -def attnProjBiasAt (model : ConcreteModel) (layerIdx : Nat) : ConcreteMatrix := - model.attnProjBias.getD layerIdx (ConcreteMatrix.zeros 1 model.modelDim) - -/-- Trim trailing all-zero embedding rows (common when `.nfpt` uses a fixed `seqLen` with padding). - -This is a **semantic no-op** for causal analysis of the *prefix* of real tokens: padded positions -occur after the prompt and are never attended to by earlier queries, but they can badly inflate -LayerNorm Jacobian bounds because zero rows have variance `0` (hence `1/sqrt(eps)` scaling). - -We only trim when we detect a sufficiently long suffix of (near-)zero rows to avoid accidental -truncation on legitimate data. 
--/ -def trimTrailingZeroEmbeddings (model : ConcreteModel) - (minTrailing : Nat := 8) (eps : Float := 1e-12) : ConcreteModel := Id.run do - let X := model.inputEmbeddings - if X.numRows = 0 || X.numCols = 0 then - return model - let mut k : Nat := 0 - let mut r : Nat := X.numRows - while r > 0 do - let r' := r - 1 - let m := ConcreteMatrix.rowMaxAbs X r' - if m ≤ eps then - k := k + 1 - r := r' - else - break - if k < minTrailing then - return model - let newLen := X.numRows - k - let toks? := model.inputTokens.map (fun ts => ts.take newLen) - return { model with - seqLen := newLen - inputEmbeddings := X.takeRows newLen - inputTokens := toks? } - -/-- Get ln_1 parameters for a layer, defaulting to identity. -/ -def ln1Params (model : ConcreteModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Get ln_2 parameters for a layer, defaulting to identity. -/ -def ln2Params (model : ConcreteModel) (layerIdx : Nat) : ConcreteLayerNormParams := - model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - -/-- Apply ln_1 to a residual stream (row-wise, per token). -/ -def applyLn1 (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Apply ln_2 to a residual stream (row-wise, per token). -/ -def applyLn2 (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix := - let p := model.ln2Params layerIdx - ConcreteMatrix.layerNormRowwise X p.gamma p.beta - -/-- Apply final ln_f to a residual stream (row-wise, per token). -/ -def applyLnf (model : ConcreteModel) (X : ConcreteMatrix) : ConcreteMatrix := - ConcreteMatrix.layerNormRowwise X model.lnf.gamma model.lnf.beta - -/-- Heuristic estimate for ln_1 Jacobian operator norm at a specific activation. 
-/ -def ln1OpBound (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float := - let p := model.ln1Params layerIdx - ConcreteMatrix.layerNormRowwiseOpEst X p.gamma - -/-- Heuristic estimate for ln_2 Jacobian operator norm at a specific activation. -/ -def ln2OpBound (model : ConcreteModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float := - let p := model.ln2Params layerIdx - ConcreteMatrix.layerNormRowwiseOpEst X p.gamma - -end ConcreteModel - -/-- Get the number of neurons in the MLP at a given layer. -/ -def ConcreteModel.numNeuronsAtLayer (model : ConcreteModel) (layerIdx : Nat) : Nat := - if h : layerIdx < model.mlps.size then - model.mlps[layerIdx].hiddenDim - else 0 - -/-- Result of running a forward pass: the residual stream after each layer. - -`layerInputs[l]` is the input to layer l (the accumulated residual stream). -`layerInputs[0]` = inputEmbeddings (initial token embeddings) -`layerInputs[l+1]` = x_{l+1} in the Pre-LN recurrence. --/ -structure ForwardPassResult where - /-- Input to each layer. layerInputs[l] is what layer l receives. -/ - layerInputs : Array ConcreteMatrix - /-- Post-attention residual for each layer: `y_l = x_l + attn_out` (including attention output bias). -/ - postAttnResiduals : Array ConcreteMatrix - /-- Attention outputs per layer per head: attnOutputs[l][h] = output of head h at layer l -/ - attnOutputs : Array (Array ConcreteMatrix) - /-- MLP outputs per layer: mlpOutputs[l] = output of MLP at layer l -/ - mlpOutputs : Array ConcreteMatrix - /-- Per-layer GeLU derivative matrices over MLP preactivations. - - If a layer has an MLP, this is the matrix `gelu'(z)` where `z = ln₂(y_l)·W_in + b_in`. - Otherwise this entry is a `0×0` placeholder. - -/ - mlpActDeriv : Array ConcreteMatrix - /-- Per-layer maximum absolute GeLU derivative over MLP preactivations. - - Length is `numLayers`. For layers without an MLP, the entry is `0.0`. 
- -/ - mlpActDerivMax : Array Float - /-- Final output after all layers (after ln_f, i.e. what goes into unembedding). -/ - finalOutput : ConcreteMatrix - -/-- Run a full forward pass through the model, computing the residual stream at each layer. - -This is the key function that enables deep circuit analysis: layer N sees the accumulated -output of layers 0..N-1, not just the raw embeddings. - - For each layer l (Pre-LN, GPT-2 style): - 1. u = ln_1(x_l) - 2. y_l = x_l + Σₕ AttentionHead[l,h].forward(u) - 3. v = ln_2(y_l) - 4. x_{l+1} = y_l + MLP[l].forward(v) - - After all layers: output = ln_f(x_L) --/ -def ConcreteModel.runForward (model : ConcreteModel) - (causal : Bool := true) : ForwardPassResult := Id.run do - let mut layerInputs : Array ConcreteMatrix := Array.mkEmpty (model.numLayers + 1) - let mut postAttnResiduals : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut attnOutputs : Array (Array ConcreteMatrix) := Array.mkEmpty model.numLayers - let mut mlpOutputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut mlpActDeriv : Array ConcreteMatrix := Array.mkEmpty model.numLayers - let mut mlpActDerivMax : Array Float := Array.mkEmpty model.numLayers - let mut residual := model.inputEmbeddings - layerInputs := layerInputs.push residual - - for l in [:model.numLayers] do - -- Pre-LN: attention sees ln_1(residual) - let attnInput := model.applyLn1 l residual - -- Compute attention outputs for all heads in this layer - let mut layerAttnOutputs : Array ConcreteMatrix := #[] - let rows := residual.numRows - let cols := residual.numCols - - if hl : l < model.layers.size then - let layerHeads := model.layers[l] - let useParallelHeads := layerHeads.size >= 4 - layerAttnOutputs := - if useParallelHeads then - let tasks : Array (Task ConcreteMatrix) := - .ofFn fun i : Fin layerHeads.size => - Task.spawn (fun _ => (layerHeads[i]).forward attnInput causal) - tasks.map Task.get - else - Id.run do - let mut outs : Array ConcreteMatrix := 
Array.mkEmpty layerHeads.size - for head in layerHeads do - outs := outs.push (head.forward attnInput causal) - return outs - - attnOutputs := attnOutputs.push layerAttnOutputs - - -- Add attention residual - let residualAfterAttn := - if layerAttnOutputs.isEmpty then - residual - else - let attnSum := ConcreteMatrix.sumMatrices layerAttnOutputs - let attnBias := model.attnProjBiasAt l - residual.add ((attnSum.addBias attnBias)) - postAttnResiduals := postAttnResiduals.push residualAfterAttn - - -- Compute MLP output - -- Pre-LN: MLP sees ln_2(residualAfterAttn) - let mlpInput := model.applyLn2 l residualAfterAttn - let mut mlpOut : ConcreteMatrix := ConcreteMatrix.zeros residual.numRows residual.numCols - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let geluDeriv := hidden.map geluDerivFloat - let mut dmax : Float := 0.0 - for z in geluDeriv.data do - let d := Float.abs z - if d > dmax then - dmax := d - let activated := hidden.map geluFloat - let out := (activated.matmul mlp.W_out).addBias mlp.b_out - let dmax' : Float := - if Float.isNaN dmax || Float.isInf dmax then 0.0 else dmax - mlpActDeriv := mlpActDeriv.push geluDeriv - mlpActDerivMax := mlpActDerivMax.push dmax' - mlpOut := out - else - -- No MLP at this layer. 
- mlpActDeriv := mlpActDeriv.push (ConcreteMatrix.zeros 0 0) - mlpActDerivMax := mlpActDerivMax.push 0.0 - mlpOut := ConcreteMatrix.zeros residual.numRows residual.numCols - - mlpOutputs := mlpOutputs.push mlpOut - - -- Add MLP residual - residual := residualAfterAttn.add mlpOut - - -- Store input for next layer - layerInputs := layerInputs.push residual - - let finalOutput := model.applyLnf residual - { - layerInputs := layerInputs - postAttnResiduals := postAttnResiduals - attnOutputs := attnOutputs - mlpOutputs := mlpOutputs - mlpActDeriv := mlpActDeriv - mlpActDerivMax := mlpActDerivMax - finalOutput := finalOutput - } - -/-- Get the input to a specific layer from a forward pass result. -/ -def ForwardPassResult.getLayerInput (result : ForwardPassResult) - (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < result.layerInputs.size then - result.layerInputs[layerIdx] - else ConcreteMatrix.zeros 0 0 - -/-- Get the cached MLP GeLU-derivative matrix for a layer, if present. -/ -def ForwardPassResult.getMlpGeluDeriv (result : ForwardPassResult) (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < result.mlpActDeriv.size then - result.mlpActDeriv[layerIdx] - else - ConcreteMatrix.zeros 0 0 - -/-- Get the post-attention residual `y_l = x_l + attn_sum` for a layer. - -This is the input to `ln_2` in a Pre-LN transformer block. --/ -def ForwardPassResult.getPostAttnResidual (result : ForwardPassResult) - (layerIdx : Nat) : ConcreteMatrix := Id.run do - if h : layerIdx < result.postAttnResiduals.size then - result.postAttnResiduals[layerIdx] - else - let x := result.getLayerInput layerIdx - if h2 : layerIdx < result.attnOutputs.size then - let heads := result.attnOutputs[layerIdx] - if heads.isEmpty then x else x.add (ConcreteMatrix.sumMatrices heads) - else - x - -/-! ## N-Layer Error Amplification Computation - -These functions implement the N-layer faithfulness composition formula from -`Linearization.lean`. 
They must be defined early so they can be used by -discovery functions like `findDeepCircuitCandidates`. --/ - -/-- Compute suffix amplification factor: ∏_{j≥start} (1 + C_j) - -This is how much error from layer `start` gets amplified by subsequent layers. -When start ≥ normBounds.size, returns 1 (no amplification). --/ -def computeSuffixAmplification (normBounds : Array Float) (start : Nat) : Float := Id.run do - let mut product : Float := 1.0 - for j in [start:normBounds.size] do - if hj : j < normBounds.size then - product := product * (1.0 + normBounds[j]) - product - -/-- Compute total amplified error: Σᵢ εᵢ · suffixAmplification(i+1) - -This implements the N-layer faithfulness composition formula from -`Linearization.lean` theorem `n_layer_faithfulness_composition`. --/ -def computeTotalAmplifiedError (patternBounds normBounds : Array Float) : Float := Id.run do - if patternBounds.size = 0 then return 0.0 - let mut total : Float := 0.0 - for i in [:patternBounds.size] do - if hi : i < patternBounds.size then - let epsilon_i := patternBounds[i] - -- Suffix amplification from layer i+1 onwards - let suffix := computeSuffixAmplification normBounds (i + 1) - total := total + epsilon_i * suffix - total - -/-- Estimate the operator norm bound for a single attention layer. - -For an attention layer, the Jacobian includes both the attention pattern term -and the value projection. We estimate: - ‖J‖ ≤ ‖A‖_F · ‖W_V·W_O‖_op + ‖∂A/∂x‖ · ‖V·W_O‖, -so ‖I + J‖ ≤ 1 + ‖J‖. - -For simplicity, we use Frobenius norms as upper bounds. --/ -def estimateAttentionLayerNorm (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (causal : Bool := true) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let mut totalNorm : Float := 0.0 - - -- Pre-LN: attention and MLP see normalized activations; account for LN Jacobian scaling. 
- let x := fwdResult.getLayerInput layerIdx - let y := fwdResult.getPostAttnResidual layerIdx - let ln1Bound := model.ln1OpBound layerIdx x - let ln2Bound := model.ln2OpBound layerIdx y - - -- Attention pattern/value bounds are computed at the Pre-LN attention input. - let attnInput := model.applyLn1 layerIdx x - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - - -- Sum contributions from all heads in this layer - for hidx in [:heads.size] do - if hh : hidx < heads.size then - let head := heads[hidx] - - -- QK operator norm bound via a 64×64 companion product. - -- `‖W_Q W_Kᵀ‖₂ = ‖W_Kᵀ W_Q‖₂` because `AB` and `BA` have the same nonzero - -- singular values. - let qkSmall : ConcreteMatrix := head.W_K.transpose.matmul head.W_Q - let qkNorm : Float := - min (qkSmall.opNormUpperBoundDenseBrauer) - (min (qkSmall.opNormUpperBoundDenseSchur) (qkSmall.opNormUpperBoundDenseFrob)) - - -- VO operator norm bound via a 64×64 companion product. - -- `‖W_V W_O‖₂ = ‖W_O W_V‖₂` because `AB` and `BA` have the same nonzero - -- singular values. - let voSmall : ConcreteMatrix := head.W_O.matmul head.W_V - let valueOutputProjNorm : Float := - min (voSmall.opNormUpperBoundDenseBrauer) - (min (voSmall.opNormUpperBoundDenseSchur) (voSmall.opNormUpperBoundDenseFrob)) - - -- Data-dependent softmax Jacobian operator bound (per-row max, clamped to 1/2). - let attn := head.computeAttentionWeights attnInput causal - let softmaxOpBound := min attn.softmaxJacobianOpDiag.opBound 0.5 - let scaleFactor := Float.sqrt head.headDim.toFloat - - -- Value-term operator bound: - -- ‖X ↦ A·X·(W_VW_O)‖ ≤ ‖A‖₂ · ‖W_VW_O‖₂. - -- For the attention matrix `A` (nonnegative row-stochastic), the max row sum is 1 - -- (`‖A‖₁` in row-vector convention, `‖A‖∞` in column-vector convention), so - -- `sqrt(‖A‖₁‖A‖∞)` is typically far tighter than `‖A‖_F`. 
- let attnFrob : Float := Id.run do - let mut s : Float := 0.0 - for w in attn.weights do - s := s + w * w - Float.sqrt (max 0.0 s) - let attnMaxRowSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for q in [:n] do - let mut s : Float := 0.0 - let rowBase := q * n - for k in [:n] do - s := s + Float.abs (attn.weights[rowBase + k]!) - maxSum := max maxSum s - maxSum - let attnMaxColSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for k in [:n] do - let mut s : Float := 0.0 - for q in [:n] do - s := s + Float.abs (attn.weights[q * n + k]!) - maxSum := max maxSum s - maxSum - let attnOneInf : Float := Float.sqrt (attnMaxRowSum * attnMaxColSum) - let attnOpUb : Float := min attnFrob attnOneInf - let valueTermUb := attnOpUb * valueOutputProjNorm - - let bnds := head.noDenseProductBounds - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := scaleFactor - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternTermUb := computePatternTermBound inputs - - totalNorm := totalNorm + ln1Bound * (valueTermUb + patternTermUb) - - -- Add MLP contribution if present - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - -- Use the fused Jacobian bound (data-dependent on `gelu'(z)`), but keep the legacy - -- product-of-norms as a (looser) fallback. 
- let winNormUb := mlp.W_in.opNormUpperBoundRectGram - let woutNormUb := mlp.W_out.opNormUpperBoundRectGram - let geluDeriv := fwdResult.getMlpGeluDeriv layerIdx - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winNormUb woutNormUb - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winNormUb woutNormUb - totalNorm := totalNorm + ln2Bound * mlpUb - - totalNorm - else - return 0.0 - -/-- Diagnostics-only variant of `estimateAttentionLayerNorm`. - -This uses power iteration (`operatorNormHeuristicPI`) for the key operator norms, -to provide an "old PI vs new rigorous" comparison under `-d`. - -WARNING: This is **not** a certified upper bound and must never be used in -bound/certification codepaths. --/ -def estimateAttentionLayerNormHeuristicPI (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (causal : Bool := true) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let mut totalNorm : Float := 0.0 - - let x := fwdResult.getLayerInput layerIdx - let y := fwdResult.getPostAttnResidual layerIdx - let ln1Bound := model.ln1OpBound layerIdx x - let ln2Bound := model.ln2OpBound layerIdx y - let attnInput := model.applyLn1 layerIdx x - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - - for hidx in [:heads.size] do - if hh : hidx < heads.size then - let head := heads[hidx] - let qkSmall : ConcreteMatrix := head.W_K.transpose.matmul head.W_Q - let qkPi := qkSmall.operatorNormHeuristicPI 20 - let voSmall : ConcreteMatrix := head.W_O.matmul head.W_V - let voPi := voSmall.operatorNormHeuristicPI 20 - - let attn := head.computeAttentionWeights attnInput causal - let softmaxOpBound := min 
attn.softmaxJacobianOpDiag.opBound 0.5 - let scaleFactor := Float.sqrt head.headDim.toFloat - let attnFrob : Float := Id.run do - let mut s : Float := 0.0 - for w in attn.weights do - s := s + w * w - Float.sqrt (max 0.0 s) - let attnMaxRowSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for q in [:n] do - let mut s : Float := 0.0 - let rowBase := q * n - for k in [:n] do - s := s + Float.abs (attn.weights[rowBase + k]!) - maxSum := max maxSum s - maxSum - let attnMaxColSum : Float := Id.run do - let n := attn.seqLen - let mut maxSum : Float := 0.0 - for k in [:n] do - let mut s : Float := 0.0 - for q in [:n] do - s := s + Float.abs (attn.weights[q * n + k]!) - maxSum := max maxSum s - maxSum - let attnOneInf : Float := Float.sqrt (attnMaxRowSum * attnMaxColSum) - let attnOpUb : Float := min attnFrob attnOneInf - let valueTermEst := attnOpUb * voPi - let patternTermEst := (softmaxOpBound / scaleFactor) * inputNorm * qkPi * voPi - totalNorm := totalNorm + ln1Bound * (valueTermEst + patternTermEst) - - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - -- Keep PI iterations low-ish here to avoid expensive diagnostics runs. - let winPi := mlp.W_in.operatorNormHeuristicPI 5 - let woutPi := mlp.W_out.operatorNormHeuristicPI 5 - let geluDerivBound : Float := - let d := fwdResult.mlpActDerivMax.getD layerIdx 1.7 - if d ≤ 0.0 || Float.isNaN d || Float.isInf d then 1.7 else d - totalNorm := totalNorm + ln2Bound * (winPi * geluDerivBound * woutPi) - - totalNorm - else - return 0.0 - -/-- Compute attention weights for a given layer and head using the correct layer input. - -This is the corrected version that uses the accumulated residual stream. 
--/ -def ConcreteModel.computeAttentionWithInput (model : ConcreteModel) - (layerIdx headIdx : Nat) (input : ConcreteMatrix) : Option ConcreteAttentionWeights := - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx]'h2 - -- Pre-LN: attention weights are computed from ln_1(input). - let attnInput := model.applyLn1 layerIdx input - some (head.computeAttentionWeights attnInput) - else none - else none - -/-- Compute attention weights for a given layer and head. - -This is a legacy helper that only uses the model's `inputEmbeddings`. -Prefer `computeAttentionWithInput` for Pre-LN-correct layer inputs. --/ -def ConcreteModel.computeAttention (model : ConcreteModel) - (layerIdx headIdx : Nat) : Option ConcreteAttentionWeights := - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let attnInput := model.applyLn1 layerIdx model.inputEmbeddings - some (head.computeAttentionWeights attnInput) - else none - else none - -/-- Compute input norm for bound calculations. -/ -def computeInputNorm (embeddings : ConcreteMatrix) : Float := - embeddings.frobeniusNorm - -/-! ## Precomputation Cache Structures - -To optimize the O(L²H²) nested loop in deep circuit discovery, we precompute and cache: -1. Attention patterns (A = softmax(QK^T/√d)) for each layer-head -2. Value-output projections (V·W_O) for each head -3. Query-key alignments (Q·K^T) for each head -4. Operator norm bounds and Frobenius norms - -This reduces redundant computation from O(L²H²) to O(LH), a massive improvement -for models like GPT-2 Small (12 layers × 12 heads = 144 heads → 20,736 → 144 calls). --/ - -/-- Precomputed data for a single attention head at a specific layer input. -/ -structure HeadDiagnosticsLazy where - /-- Dense `W_V·W_O` (diagnostics-only). 
-/ - voProj : Thunk ConcreteMatrix - /-- Dense `W_Q·W_Kᵀ` (diagnostics-only). -/ - qkAlign : Thunk ConcreteMatrix - /-- Dense Schur candidate `sqrt(‖M‖₁‖M‖∞)` for `W_V·W_O` (diagnostics-only). -/ - voDenseSchur : Thunk Float - /-- Dense Schur candidate `sqrt(‖M‖₁‖M‖∞)` for `W_Q·W_Kᵀ` (diagnostics-only). -/ - qkDenseSchur : Thunk Float - -structure PrecomputedHeadData where - /-- Layer index -/ - layerIdx : Nat - /-- Head index within layer -/ - headIdx : Nat - /-- Attention weights (A = softmax(QK^T/√d)) -/ - attention : ConcreteAttentionWeights - /-- Average previous-token attention strength: `(1/(n-1)) * Σᵢ A[i+1, i]`. -/ - prevTokenStrength : Float - /-- Cached softmax Jacobian operator-norm estimate for this head's attention rows. -/ - softmaxJacobianOpEst : Float - /-- Softmax Jacobian diagnostics: `max_i p_i` for the maximizing row. -/ - softmaxRowMaxP : Float - /-- Softmax Jacobian diagnostics: `tr(J) = 1 - ∑ p_i^2` for the maximizing row. -/ - softmaxRowTraceBound : Float - /-- Softmax Jacobian diagnostics: PSD moment bound for the maximizing row. -/ - softmaxRowMomentBound : Float - /-- Softmax Jacobian diagnostics: Gershgorin / `‖J‖_∞ = max_i 2 p_i (1-p_i)` - for the maximizing row. -/ - softmaxRowGershBound : Float - /-- Softmax Jacobian diagnostics: final per-row bound used for the maximizing row. -/ - softmaxRowBoundUsed : Float - /-- Number of rows that triggered a conservative fallback (NaN/Inf/zero-sum). -/ - softmaxRowsFallback : Nat - /-- Cached Frobenius norm squared of the attention matrix: `‖A‖_F²`. -/ - attentionFrobeniusNormSq : Float - /-- Cached `sqrt(‖A‖₁‖A‖∞)` upper bound on `‖A‖₂` (computed from max row/col sums). -/ - attentionOneInfBound : Float - /-- Cached pattern-term bound `‖PatternTerm‖_F` for this head at the cached input. -/ - patternTermBoundCached : Float - /-- Cached value-term Frobenius norm `‖ValueTerm‖_F` for this head. 
-/ - valueTermNormCached : Float - /-- Cached dimensionless faithfulness ratio `‖PatternTerm‖_F / ‖ValueTerm‖_F`. -/ - faithfulnessRatioCached : Float - /-- Optional detailed pattern-term bound diagnostics. -/ - patternBoundParts? : Option PatternTermBoundParts := none - /-- Optional lazy diagnostics. Never forced on the main path. -/ - diag? : Option HeadDiagnosticsLazy - /-- Cached Gram matrix `W_Qᵀ · W_Q` (headDim × headDim). -/ - wqGram : ConcreteMatrix - /-- Cached Gram matrix `W_Vᵀ · W_V` (headDim × headDim). -/ - wvGram : ConcreteMatrix - /-- Input norm ‖X‖_F for this layer -/ - inputNorm : Float - /-- Operator-norm upper bound ‖X‖₂ for this layer input (computed via 1/∞). -/ - inputOpBound : Float - /-- Direct (data-dependent) Frobenius norm of `Q = X·W_Q + b_Q` for this head. -/ - qFrobBound : Float - /-- Direct (data-dependent) Frobenius norm of `K = X·W_K + b_K` for this head. -/ - kFrobBound : Float - /-- Direct (data-dependent) Frobenius norm of centered `V = X·W_V + b_V` for this head. -/ - vFrobBound : Float - /-- Direct (data-dependent) operator-norm upper bound on `Q` (computed via a small Gram). -/ - qOpBoundAct : Float - /-- Direct (data-dependent) operator-norm upper bound on `K` (computed via a small Gram). -/ - kOpBoundAct : Float - /-- Direct (data-dependent) operator-norm upper bound on centered `V` (computed via a small Gram). -/ - vOpBoundAct : Float - /-- Frobenius bound for `‖W_Q·Kᵀ‖_F` from centered activations. -/ - qkActFrobBound : Float - /-- Frobenius bound for `‖W_K·Qᵀ‖_F` from centered activations. -/ - kqActFrobBound : Float - /-- Operator-norm bound for `‖W_Q·Kᵀ‖₂` from centered activation Grams. -/ - qkActOpBound : Float - /-- Operator-norm bound for `‖W_K·Qᵀ‖₂` from centered activation Grams. -/ - kqActOpBound : Float - /-- Selected candidate for `‖W_Q·Kᵀ‖₂` activation bound (diagnostics-only). -/ - qkActOpBoundSource : String := "unknown" - /-- Selected candidate for `‖W_K·Qᵀ‖₂` activation bound (diagnostics-only). 
-/ - kqActOpBoundSource : String := "unknown" - /-- Operator-norm bound for ln_1 Jacobian at this layer input. -/ - ln1OpBound : Float - /-- Scaling factor √d_head -/ - scaleFactor : Float - /-- Cached Frobenius norm of V·W_O -/ - valueOutputProjNorm : Float - /-- Cached Frobenius norm of Q·K^T -/ - queryKeyAlignNorm : Float - /-- Cached deterministic Float upper bound on ‖V·W_O‖₂. - - This is computed as the minimum of several valid upper bounds - (Schur / Frobenius / factor bounds). - -/ - valueOutputProjSchurNorm : Float - /-- Cached deterministic Float upper bound on ‖Q·K^T‖₂. - - This is computed as the minimum of several valid upper bounds - (Schur / Frobenius / factor bounds). - -/ - queryKeyAlignSchurNorm : Float - /-- Candidate bounds for `‖Q·Kᵀ‖₂` used in diagnostics. -/ - qkDenseFrob : Float - /-- Tight Gram-product candidate derived from 64×64 Gram matrices. -/ - qkDenseGram : Float - /-- Brauer/Cassini Gram candidate computed on a 64×64 matrix with matching singular values. -/ - qkDenseBrauer : Float - qkFactorSchur : Float - qkFactorFrob : Float - - /-- Gram-based operator bound for `W_Q` (min of Gersh/Brauer/moment candidates). -/ - wqOpGram : Float - /-- Gram-based operator bound for `W_K` (min of Gersh/Brauer/moment candidates). -/ - wkOpGram : Float - /-- Frobenius norm of the 1×headDim query bias row `b_Q`. -/ - bqFrob : Float - /-- Frobenius norm of the 1×headDim key bias row `b_K`. -/ - bkFrob : Float - /-- Frobenius norm of the 1×headDim value bias row `b_V`. -/ - bvFrob : Float - /-- Factorized Gram bound for `‖W_Q·W_Kᵀ‖₂`: `wqOpGram * wkOpGram`. -/ - qkFactorGram : Float - - /-- Candidate bounds for `‖W_V·W_O‖₂` used in diagnostics. -/ - voDenseFrob : Float - /-- Tight Gram-product candidate derived from 64×64 Gram matrices. -/ - voDenseGram : Float - /-- Brauer/Cassini Gram candidate computed on a 64×64 matrix with matching singular values. 
-/ - voDenseBrauer : Float - voFactorSchur : Float - voFactorFrob : Float - /-- Gram-based operator bound for `W_V` (min of Gersh/Brauer/moment candidates). -/ - wvOpGram : Float - /-- Gram-based operator bound for `W_O` (min of Gersh/Brauer/moment candidates). -/ - woOpGram : Float - /-- Factorized Gram bound for `‖W_V·W_O‖₂`: `wvOpGram * woOpGram`. -/ - voFactorGram : Float - -namespace PrecomputedHeadData - -/-- Precomputed pattern term bound for a head (cached computation). -/ -def patternTermBound (data : PrecomputedHeadData) : Float := - data.patternTermBoundCached - -/-- Frobenius norm of the Value Term of this head's Jacobian. - -For the attention linearization, the Value Term factorizes as `A ⊗ (W_V·W_O)`, -so `‖ValueTerm‖_F = ‖A‖_F · ‖W_V·W_O‖_F`. --/ -def valueTermNorm (data : PrecomputedHeadData) : Float := - data.valueTermNormCached - -/-- Dimensionless faithfulness ratio: `‖PatternTerm‖_F / ‖ValueTerm‖_F`. - -This matches `relativeApproximationError` from `Nfp.Linearization` and is the -quantity that should be compared to user-facing thresholds like `0.1`. --/ -def faithfulnessRatio (data : PrecomputedHeadData) : Float := - data.faithfulnessRatioCached - -end PrecomputedHeadData - -/-- Cache for all precomputed head data across all layers. - -Structure: `cache[layerIdx][headIdx]` gives the PrecomputedHeadData for that head. --/ -structure PrecomputedCache where - /-- Model this cache was built for -/ - model : ConcreteModel - /-- Forward pass result with layer inputs -/ - forwardResult : ForwardPassResult - /-- Cached Pre-LN attention inputs `ln_1(x_l)` for each layer `l`. -/ - ln1Inputs : Array ConcreteMatrix - /-- Precomputed data: cache[layerIdx][headIdx] -/ - headData : Array (Array PrecomputedHeadData) - /-- Precomputed operator norm bounds for each layer (for N-layer error amplification) -/ - layerNormBounds : Array Float - /-- Whether `layerNormBounds` were computed (otherwise the array is a placeholder). 
-/ - layerNormBoundsComputed : Bool - -namespace PrecomputedCache - -/-! ### Layer Jacobian-norm upper bounds (cached-head fast path) -/ - -/-- Compute the per-layer residual Jacobian norm upper bound `C_l` using cached head data. - -This matches `estimateAttentionLayerNorm` but avoids recomputing attention weights and -small-factor norms. The only tier-dependent cost is the MLP `W_in/W_out` operator bounds. --/ -def layerNormBoundAt (cache : PrecomputedCache) (layerIdx : Nat) - (effort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) : Float := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let y := fwd.getPostAttnResidual layerIdx - let ln2Bound := model.ln2OpBound layerIdx y - let layerData := cache.headData.getD layerIdx #[] - let mut attnPart : Float := 0.0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - attnPart := attnPart + d.ln1OpBound * (valueTermUb + patternTermUb) - - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let winUb := 
mlp.W_in.opNormUpperBoundRectGramEffort effort - let woutUb := mlp.W_out.opNormUpperBoundRectGramEffort effort - let geluDeriv := fwd.getMlpGeluDeriv layerIdx - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winUb woutUb - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winUb woutUb - let mlpPart := ln2Bound * mlpUb - return attnPart + (1.0 + attnPart) * mlpPart - else - return attnPart - -/-- Compute all per-layer residual Jacobian upper bounds `C_l` under a uniform effort tier. -/ -def computeLayerNormBounds (cache : PrecomputedCache) - (effort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) : - Array Float := Id.run do - let n := cache.model.numLayers - let useParallel := n >= 4 - if useParallel then - let tasks : Array (Task Float) := - .ofFn fun i : Fin n => - Task.spawn (fun _ => cache.layerNormBoundAt i.val effort) - tasks.map Task.get - else - let mut out : Array Float := Array.mkEmpty n - for l in [:n] do - out := out.push (cache.layerNormBoundAt l effort) - out - -/-- Build a complete precomputed cache for a model. - -This precomputes all attention patterns, projections, and norms once. --/ - def build (model : ConcreteModel) (causal : Bool := true) - (computeLayerNormBounds : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - PrecomputedCache := Id.run do - let fwdResult := model.runForward causal - -- Parallelizing both layers and heads can lead to too many tasks; prefer layer-level parallelism. 
- let useParallelLayers := model.numLayers >= 4 - let computeLayer (l : Nat) : (Array PrecomputedHeadData × (Float × ConcreteMatrix)) := Id.run do - let layerInput := fwdResult.getLayerInput l - let attnInput := model.applyLn1 l layerInput - let inputNorm := computeInputNorm attnInput - -- Use a Gram-based bound here (tier-controlled) because the cheap 1/∞ estimate can be - -- extremely loose on dense `seqLen×modelDim` matrices and will not meaningfully tighten - -- the attention pattern-term bounds. - let inputOpBound := attnInput.opNormUpperBoundRectGramEffort layerNormEffort - let ln1Bound := model.ln1OpBound l layerInput - - let layerHeadData : Array PrecomputedHeadData := - if h : l < model.layers.size then - let heads := model.layers[l]'h - let computeHead (hIdx : Nat) (head : ConcreteAttentionLayer) : PrecomputedHeadData := Id.run do - -- Compute Q/K once (needed for attention weights), and also cache tight, data-dependent - -- bounds on their norms for use in pattern-term bounds. 
- let queries := (attnInput.matmul head.W_Q).addBias head.b_Q - let keys := (attnInput.matmul head.W_K).addBias head.b_K - let scaleFactor := Float.sqrt head.headDim.toFloat - let attn := ConcreteAttentionWeights.compute queries keys scaleFactor attnInput.numRows causal - let values := (attnInput.matmul head.W_V).addBias head.b_V - let qFrobBound : Float := queries.frobeniusNorm - let kFrobBound : Float := keys.frobeniusNorm - let vFrobBoundRaw : Float := values.frobeniusNorm - let smallEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1 - let qOpBoundAct : Float := queries.opNormUpperBoundRectGramEffort smallEffort - let kOpBoundAct : Float := keys.opNormUpperBoundRectGramEffort smallEffort - let vOpBoundActRaw : Float := values.opNormUpperBoundRectGramEffort smallEffort - let softmaxDiag := attn.softmaxJacobianOpDiag - let softmaxOpBound := min softmaxDiag.opBound 0.5 - let n := attn.seqLen - let mut attnFrobNormSq : Float := 0.0 - let mut maxRowAbsSum : Float := 0.0 - let mut colAbsSums : Array Float := Array.replicate n 0.0 - for q in [:n] do - let rowBase := q * n - let mut rowAbsSum : Float := 0.0 - for k in [:n] do - let w := attn.weights[rowBase + k]! - attnFrobNormSq := attnFrobNormSq + w * w - let a := Float.abs w - rowAbsSum := rowAbsSum + a - colAbsSums := colAbsSums.set! k (colAbsSums[k]! + a) - if rowAbsSum > maxRowAbsSum then - maxRowAbsSum := rowAbsSum - let mut maxColAbsSum : Float := 0.0 - for k in [:n] do - let s := colAbsSums[k]! - if s > maxColAbsSum then - maxColAbsSum := s - let attnOneInf : Float := Float.sqrt (max 0.0 (maxRowAbsSum * maxColAbsSum)) - - let prevTokenStrength : Float := - if attn.seqLen < 2 then - 0.0 - else - Id.run do - let n := attn.seqLen - let w := attn.weights - let mut sum : Float := 0.0 - for i in [:n - 1] do - sum := sum + w[(i + 1) * n + i]! - return sum / (n - 1).toFloat - - -- No-dense scalar bounds: avoid allocating per-head `modelDim×modelDim` products. 
- let bnds := head.noDenseProductBounds - let wqGram := bnds.wqGram - let wvGram := bnds.wvGram - let qkNorm := bnds.qkDenseFrob - let voNorm := bnds.voDenseFrob - let qkOpBound := bnds.qkOpBound - let voOpBound := bnds.voOpBound - let qActGram := queries.transpose.matmul queries - let kActGram := keys.transpose.matmul keys - let meanVec := fun (M : ConcreteMatrix) => Id.run do - let cols := M.numCols - if M.numRows = 0 || cols = 0 then - return Array.replicate cols 0.0 - let mut sums : Array Float := Array.replicate cols 0.0 - for r in [:M.numRows] do - let rowBase := r * cols - for c in [:cols] do - sums := sums.set! c (sums[c]! + M.data[rowBase + c]!) - let invN := 1.0 / M.numRows.toFloat - let mut out : Array Float := Array.mkEmpty cols - for c in [:cols] do - out := out.push (sums[c]! * invN) - return out - let centerGram := fun (gram : ConcreteMatrix) (mean : Array Float) (n : Nat) => - if gram.numRows = 0 || gram.numCols = 0 || n = 0 then - gram - else - let nF := n.toFloat - { numRows := gram.numRows - numCols := gram.numCols - data := .ofFn fun idx : Fin (gram.numRows * gram.numCols) => Id.run do - let i := idx.val / gram.numCols - let j := idx.val % gram.numCols - gram.data[idx.val]! - nF * mean[i]! * mean[j]! - size_eq := Array.size_ofFn - } - let gramTrace := fun (M : ConcreteMatrix) => Id.run do - if M.numRows = 0 || M.numCols = 0 then - return 0.0 - let mut acc : Float := 0.0 - for i in [:M.numRows] do - acc := acc + M.data[i * M.numCols + i]! - return acc - -- Center activations to drop row-constant logit shifts (softmax-invariant). 
- let seqLen := keys.numRows - let keyMean := meanVec keys - let queryMean := meanVec queries - let valueMean := meanVec values - let nF := seqLen.toFloat - let vMeanNormSq : Float := valueMean.foldl (fun acc x => acc + x * x) 0.0 - let vFrobBound : Float := - Float.sqrt (max 0.0 (vFrobBoundRaw * vFrobBoundRaw - nF * vMeanNormSq)) - let vActGram := values.transpose.matmul values - let vActGramCentered := centerGram vActGram valueMean seqLen - let vActTrace : Float := gramTrace vActGramCentered - let vActTracePos := max 0.0 vActTrace - let vFrobBoundCentered : Float := Float.sqrt vActTracePos - let vFrobBound : Float := - if vFrobBoundCentered ≤ 0.0 || Float.isNaN vFrobBoundCentered || - Float.isInf vFrobBoundCentered then - vFrobBoundRaw - else - min vFrobBoundRaw vFrobBoundCentered - let vOpBoundActCentered : Float := opBoundFromGramLocal vActGramCentered vActTracePos - let vOpBoundAct : Float := - if vOpBoundActRaw ≤ 0.0 || Float.isNaN vOpBoundActRaw || Float.isInf vOpBoundActRaw then - vOpBoundActCentered - else - min vOpBoundActRaw vOpBoundActCentered - let kActGramCentered := centerGram kActGram keyMean seqLen - let qActGramCentered := centerGram qActGram queryMean seqLen - let kActTracePos := max 0.0 (gramTrace kActGramCentered) - let qActTracePos := max 0.0 (gramTrace qActGramCentered) - let kCenteredOpBound : Float := opBoundFromGramLocal kActGramCentered kActTracePos - let qCenteredOpBound : Float := opBoundFromGramLocal qActGramCentered qActTracePos - let chooseMinLabel := fun (a b : (String × Float)) => - if a.2 ≤ b.2 then a else b - let qkActTrace : Float := ConcreteMatrix.traceMul wqGram kActGramCentered - let qkActGram1 := kActGramCentered.matmul wqGram - let qkActGram2 := wqGram.matmul kActGramCentered - let qkActOpDense1 : Float := qkActGram1.opNormUpperBoundDenseBrauer - let qkActOpDense2 : Float := qkActGram2.opNormUpperBoundDenseBrauer - let qkActOpBoundDense : Float := - Float.sqrt (max 0.0 (min qkActOpDense1 qkActOpDense2)) - let qkActF2_1 : 
Float := ConcreteMatrix.traceMul qkActGram1 qkActGram1 - let qkActF2_2 : Float := ConcreteMatrix.traceMul qkActGram2 qkActGram2 - -- `KᵀK·W_QᵀW_Q` has nonnegative real eigenvalues; use trace/trace-square moment bound. - let qkActMoment1 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment qkActGram1.numRows qkActTrace (max 0.0 qkActF2_1) - let qkActMoment2 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment qkActGram2.numRows qkActTrace (max 0.0 qkActF2_2) - let qkActOpBoundMoment : Float := - Float.sqrt (max 0.0 (min qkActMoment1 qkActMoment2)) - let qkActGram1Sq := qkActGram1.matmul qkActGram1 - let qkActGram2Sq := qkActGram2.matmul qkActGram2 - let qkActTrace4_1 : Float := ConcreteMatrix.traceMul qkActGram1Sq qkActGram1Sq - let qkActTrace4_2 : Float := ConcreteMatrix.traceMul qkActGram2Sq qkActGram2Sq - let qkActOpBoundPow4 : Float := - Float.sqrt (Float.sqrt (max 0.0 (min qkActTrace4_1 qkActTrace4_2))) - let qkActLambdaBrauer1 : Float := ConcreteMatrix.lambdaMaxUpperBrauer qkActGram1 - let qkActLambdaBrauer2 : Float := ConcreteMatrix.lambdaMaxUpperBrauer qkActGram2 - let qkActOpBoundBrauer : Float := - Float.sqrt (max 0.0 (min qkActLambdaBrauer1 qkActLambdaBrauer2)) - let qkActLambdaGershNN1 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg qkActGram1 - let qkActLambdaGershNN2 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg qkActGram2 - let qkActOpBoundGershNN : Float := - Float.sqrt (max 0.0 (min qkActLambdaGershNN1 qkActLambdaGershNN2)) - let qkActLambdaBrauerNN1 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg qkActGram1 - let qkActLambdaBrauerNN2 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg qkActGram2 - let qkActOpBoundBrauerNN : Float := - Float.sqrt (max 0.0 (min qkActLambdaBrauerNN1 qkActLambdaBrauerNN2)) - let qkActOpBound0 : Float := - Float.sqrt (max 0.0 (min qkActGram1.infNormAbs qkActGram2.infNormAbs)) - let qkActCanFactor : Bool := - !(kCenteredOpBound ≤ 0.0 || Float.isNaN kCenteredOpBound || Float.isInf kCenteredOpBound) - let 
qkActBase0 : (String × Float) := ("gramInf", qkActOpBound0) - let qkActBase1 := chooseMinLabel qkActBase0 ("gramGershNN", qkActOpBoundGershNN) - let qkActBase2 := chooseMinLabel qkActBase1 ("gramBrauerNN", qkActOpBoundBrauerNN) - let qkActBase3 := chooseMinLabel qkActBase2 ("gramBrauer", qkActOpBoundBrauer) - let qkActBase4 := chooseMinLabel qkActBase3 ("denseBrauer", qkActOpBoundDense) - let qkActBase5 := chooseMinLabel qkActBase4 ("gramMoment", qkActOpBoundMoment) - let qkActBase6 := chooseMinLabel qkActBase5 ("gramPow4", qkActOpBoundPow4) - let qkActBest := - if qkActCanFactor then - chooseMinLabel qkActBase6 ("factor", bnds.wqOpGram * kCenteredOpBound) - else - qkActBase6 - let qkActOpBound : Float := qkActBest.2 - let qkActOpBoundSource : String := qkActBest.1 - let kqActTrace : Float := ConcreteMatrix.traceMul bnds.wkGram qActGramCentered - let kqActGram1 := qActGramCentered.matmul bnds.wkGram - let kqActGram2 := bnds.wkGram.matmul qActGramCentered - let kqActOpDense1 : Float := kqActGram1.opNormUpperBoundDenseBrauer - let kqActOpDense2 : Float := kqActGram2.opNormUpperBoundDenseBrauer - let kqActOpBoundDense : Float := - Float.sqrt (max 0.0 (min kqActOpDense1 kqActOpDense2)) - let kqActF2_1 : Float := ConcreteMatrix.traceMul kqActGram1 kqActGram1 - let kqActF2_2 : Float := ConcreteMatrix.traceMul kqActGram2 kqActGram2 - let kqActMoment1 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment kqActGram1.numRows kqActTrace (max 0.0 kqActF2_1) - let kqActMoment2 : Float := - ConcreteMatrix.psdLambdaMaxUpperMoment kqActGram2.numRows kqActTrace (max 0.0 kqActF2_2) - let kqActOpBoundMoment : Float := - Float.sqrt (max 0.0 (min kqActMoment1 kqActMoment2)) - let kqActGram1Sq := kqActGram1.matmul kqActGram1 - let kqActGram2Sq := kqActGram2.matmul kqActGram2 - let kqActTrace4_1 : Float := ConcreteMatrix.traceMul kqActGram1Sq kqActGram1Sq - let kqActTrace4_2 : Float := ConcreteMatrix.traceMul kqActGram2Sq kqActGram2Sq - let kqActOpBoundPow4 : Float := - Float.sqrt 
(Float.sqrt (max 0.0 (min kqActTrace4_1 kqActTrace4_2))) - let kqActLambdaBrauer1 : Float := ConcreteMatrix.lambdaMaxUpperBrauer kqActGram1 - let kqActLambdaBrauer2 : Float := ConcreteMatrix.lambdaMaxUpperBrauer kqActGram2 - let kqActOpBoundBrauer : Float := - Float.sqrt (max 0.0 (min kqActLambdaBrauer1 kqActLambdaBrauer2)) - let kqActLambdaGershNN1 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg kqActGram1 - let kqActLambdaGershNN2 : Float := ConcreteMatrix.lambdaMaxUpperGershNonneg kqActGram2 - let kqActOpBoundGershNN : Float := - Float.sqrt (max 0.0 (min kqActLambdaGershNN1 kqActLambdaGershNN2)) - let kqActLambdaBrauerNN1 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg kqActGram1 - let kqActLambdaBrauerNN2 : Float := ConcreteMatrix.lambdaMaxUpperBrauerNonneg kqActGram2 - let kqActOpBoundBrauerNN : Float := - Float.sqrt (max 0.0 (min kqActLambdaBrauerNN1 kqActLambdaBrauerNN2)) - let kqActOpBound0 : Float := - Float.sqrt (max 0.0 (min kqActGram1.infNormAbs kqActGram2.infNormAbs)) - let kqActCanFactor : Bool := - !(qCenteredOpBound ≤ 0.0 || Float.isNaN qCenteredOpBound || Float.isInf qCenteredOpBound) - let kqActBase0 : (String × Float) := ("gramInf", kqActOpBound0) - let kqActBase1 := chooseMinLabel kqActBase0 ("gramGershNN", kqActOpBoundGershNN) - let kqActBase2 := chooseMinLabel kqActBase1 ("gramBrauerNN", kqActOpBoundBrauerNN) - let kqActBase3 := chooseMinLabel kqActBase2 ("gramBrauer", kqActOpBoundBrauer) - let kqActBase4 := chooseMinLabel kqActBase3 ("denseBrauer", kqActOpBoundDense) - let kqActBase5 := chooseMinLabel kqActBase4 ("gramMoment", kqActOpBoundMoment) - let kqActBase6 := chooseMinLabel kqActBase5 ("gramPow4", kqActOpBoundPow4) - let kqActBest := - if kqActCanFactor then - chooseMinLabel kqActBase6 ("factor", bnds.wkOpGram * qCenteredOpBound) - else - kqActBase6 - let kqActOpBound : Float := kqActBest.2 - let kqActOpBoundSource : String := kqActBest.1 - let qkActFrobBound : Float := Float.sqrt (max 0.0 qkActTrace) - let kqActFrobBound : 
Float := Float.sqrt (max 0.0 kqActTrace) - - let diag? : Option HeadDiagnosticsLazy := - if storeDiagnostics then - let voProjT : Thunk ConcreteMatrix := Thunk.mk (fun _ => head.valueOutputProjection) - let qkAlignT : Thunk ConcreteMatrix := Thunk.mk (fun _ => head.queryKeyAlignment) - let voDenseSchurT : Thunk Float := Thunk.mk (fun _ => (voProjT.get).schurNormEst) - let qkDenseSchurT : Thunk Float := Thunk.mk (fun _ => (qkAlignT.get).schurNormEst) - some { voProj := voProjT, qkAlign := qkAlignT - voDenseSchur := voDenseSchurT, qkDenseSchur := qkDenseSchurT } - else - none - - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - qFrobBound := qFrobBound - kFrobBound := kFrobBound - vFrobBound := vFrobBound - qOpBoundAct := qOpBoundAct - kOpBoundAct := kOpBoundAct - vOpBoundAct := vOpBoundAct - qkActFrobBound := qkActFrobBound - kqActFrobBound := kqActFrobBound - qkActOpBound := qkActOpBound - kqActOpBound := kqActOpBound - scaleFactor := scaleFactor - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternParts? : Option PatternTermBoundParts := - if storeDiagnostics then - some (computePatternTermBoundParts inputs) - else - none - let patternBound : Float := - match patternParts? 
with - | some parts => parts.patternBound - | none => computePatternTermBound inputs - let valueNorm : Float := - Float.sqrt attnFrobNormSq * voNorm - let ratio : Float := - if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - - return { - layerIdx := l - headIdx := hIdx - attention := attn - prevTokenStrength := prevTokenStrength - softmaxJacobianOpEst := softmaxOpBound - softmaxRowMaxP := softmaxDiag.maxRowMaxP - softmaxRowTraceBound := softmaxDiag.maxRowTraceBound - softmaxRowMomentBound := softmaxDiag.maxRowMomentBound - softmaxRowGershBound := softmaxDiag.maxRowGersh - softmaxRowBoundUsed := softmaxDiag.maxRowBoundUsed - softmaxRowsFallback := softmaxDiag.numRowsFallback - attentionFrobeniusNormSq := attnFrobNormSq - attentionOneInfBound := attnOneInf - patternTermBoundCached := patternBound - valueTermNormCached := valueNorm - faithfulnessRatioCached := ratio - patternBoundParts? := patternParts? - diag? := diag? - wqGram := wqGram - wvGram := wvGram - inputNorm := inputNorm - inputOpBound := inputOpBound - qFrobBound := qFrobBound - kFrobBound := kFrobBound - vFrobBound := vFrobBound - qOpBoundAct := qOpBoundAct - kOpBoundAct := kOpBoundAct - vOpBoundAct := vOpBoundAct - qkActFrobBound := qkActFrobBound - kqActFrobBound := kqActFrobBound - qkActOpBound := qkActOpBound - kqActOpBound := kqActOpBound - qkActOpBoundSource := qkActOpBoundSource - kqActOpBoundSource := kqActOpBoundSource - ln1OpBound := ln1Bound - scaleFactor := scaleFactor - valueOutputProjNorm := voNorm - queryKeyAlignNorm := qkNorm - valueOutputProjSchurNorm := voOpBound - queryKeyAlignSchurNorm := qkOpBound - - qkDenseFrob := bnds.qkDenseFrob - qkDenseGram := bnds.qkDenseGram - qkDenseBrauer := bnds.qkDenseBrauer - qkFactorSchur := bnds.qkFactorSchur - qkFactorFrob := bnds.qkFactorFrob - - wqOpGram := bnds.wqOpGram - wkOpGram := bnds.wkOpGram - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - qkFactorGram := 
bnds.qkFactorGram - - voDenseFrob := bnds.voDenseFrob - voDenseGram := bnds.voDenseGram - voDenseBrauer := bnds.voDenseBrauer - voFactorSchur := bnds.voFactorSchur - voFactorFrob := bnds.voFactorFrob - wvOpGram := bnds.wvOpGram - woOpGram := bnds.woOpGram - voFactorGram := bnds.voFactorGram - } - - let useParallelHeads := (!useParallelLayers) && heads.size >= 4 - if useParallelHeads then - let tasks : Array (Task PrecomputedHeadData) := - .ofFn fun i : Fin heads.size => - Task.spawn (fun _ => computeHead i.val heads[i]) - tasks.map Task.get - else - Id.run do - let mut out : Array PrecomputedHeadData := Array.mkEmpty heads.size - for h_idx in [:heads.size] do - if hh : h_idx < heads.size then - out := out.push (computeHead h_idx (heads[h_idx]'hh)) - return out - else - #[] - - let norm : Float := - if computeLayerNormBounds then - -- OPTIMIZATION: compute per-layer residual Jacobian upper bounds from cached head data, - -- avoiding recomputation of attention weights / projections. - Id.run do - let y := fwdResult.getPostAttnResidual l - let ln2Bound := model.ln2OpBound l y - let mut attnPart : Float := 0.0 - for d in layerHeadData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := d.inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := 
d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - attnPart := attnPart + d.ln1OpBound * (valueTermUb + patternTermUb) - - if hm : l < model.mlps.size then - let mlp := model.mlps[l]'hm - let winNormUb := mlp.W_in.opNormUpperBoundRectGramEffort layerNormEffort - let woutNormUb := mlp.W_out.opNormUpperBoundRectGramEffort layerNormEffort - let geluDeriv := fwdResult.getMlpGeluDeriv l - let mlpUb := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv winNormUb woutNormUb - else - let mlpInput := model.applyLn2 l y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct winNormUb woutNormUb - let mlpPart := ln2Bound * mlpUb - return attnPart + (1.0 + attnPart) * mlpPart - else - return attnPart - else - 0.0 - - return (layerHeadData, (norm, attnInput)) - - -- Pure parallelism via tasks: layer cache construction is independent once the - -- forward pass has produced all layer inputs. 
- let useParallel := useParallelLayers - let layerResults : Array (Array PrecomputedHeadData × (Float × ConcreteMatrix)) := - if useParallel then - let tasks : Array (Task (Array PrecomputedHeadData × (Float × ConcreteMatrix))) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeLayer i.val - - let mut headData : Array (Array PrecomputedHeadData) := Array.mkEmpty model.numLayers - let mut layerNormBounds : Array Float := Array.mkEmpty model.numLayers - let mut ln1Inputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers - for (layerHeadData, (norm, attnInput)) in layerResults do - headData := headData.push layerHeadData - layerNormBounds := layerNormBounds.push norm - ln1Inputs := ln1Inputs.push attnInput - - { model := model - forwardResult := fwdResult - ln1Inputs := ln1Inputs - headData := headData - layerNormBounds := layerNormBounds - layerNormBoundsComputed := computeLayerNormBounds } - -/-- Retrieve cached data for a specific head. -/ -def getHeadData (cache : PrecomputedCache) (layerIdx headIdx : Nat) : - Option PrecomputedHeadData := - if h1 : layerIdx < cache.headData.size then - let layerCache := cache.headData[layerIdx] - if h2 : headIdx < layerCache.size then - some layerCache[headIdx] - else none - else none - -/-- Retrieve cached Pre-LN attention input `ln_1(x_l)` for a specific layer. -/ -def getLn1Input (cache : PrecomputedCache) (layerIdx : Nat) : ConcreteMatrix := - if h : layerIdx < cache.ln1Inputs.size then - cache.ln1Inputs[layerIdx] - else - ConcreteMatrix.zeros 0 0 - -end PrecomputedCache - -/-! ## Krylov-style rigorous lower bounds for layer residual Jacobians -/ - -namespace ConcreteMatrix - -/-- A precomputed context for applying the Jacobian of row-wise LayerNorm to perturbations. 
-/ -structure LayerNormJacobianCtx where - numRows : Nat - numCols : Nat - gamma : ConcreteMatrix - invStds : Array Float - v : ConcreteMatrix - -/-- Build a `LayerNormJacobianCtx` for `layerNormRowwise X γ β`. - -We cache the per-row `invStd` and the centered+scaled `v = (x-μ)/σ` used in the Jacobian -formula. (`β` does not affect the Jacobian.) --/ -def mkLayerNormJacobianCtx (X γ : ConcreteMatrix) (eps : Float := 1e-5) : LayerNormJacobianCtx := - Id.run do - let rows := X.numRows - let cols := X.numCols - if rows = 0 || cols = 0 then - return { numRows := rows, numCols := cols, gamma := γ, invStds := #[], v := X } - if !(γ.numRows = 1 ∧ γ.numCols = cols) then - return { numRows := rows, numCols := cols, gamma := γ, invStds := #[], v := X } - - let mut means : Array Float := Array.mkEmpty rows - let mut invStds : Array Float := Array.mkEmpty rows - for r in [:rows] do - let mut sum : Float := 0.0 - let rowBase := r * cols - for c in [:cols] do - sum := sum + X.data[rowBase + c]! - let μ := sum / cols.toFloat - let mut varSum : Float := 0.0 - for c in [:cols] do - let d := X.data[rowBase + c]! - μ - varSum := varSum + d * d - let varRaw := varSum / cols.toFloat - -- Clamp for numerical stability (avoid NaN from tiny negative float noise). - let var := - if Float.isNaN varRaw || Float.isInf varRaw then 0.0 - else max 0.0 varRaw - let invσ := 1.0 / Float.sqrt (var + eps) - means := means.push μ - invStds := invStds.push invσ - - let v : ConcreteMatrix := - { numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let μ := means.getD r 0.0 - let invσ := invStds.getD r 0.0 - (X.data[r * cols + c]! - μ) * invσ - size_eq := Array.size_ofFn } - - return { numRows := rows, numCols := cols, gamma := γ, invStds := invStds, v := v } - -/-- Apply the LayerNorm Jacobian `J` at the cached row statistics: `δy = δx · J` (row-wise). 
-/ -def LayerNormJacobianCtx.apply (ctx : LayerNormJacobianCtx) (dX : ConcreteMatrix) : ConcreteMatrix := - Id.run do - if dX.numRows ≠ ctx.numRows || dX.numCols ≠ ctx.numCols then - return ConcreteMatrix.zeros ctx.numRows ctx.numCols - let rows := ctx.numRows - let cols := ctx.numCols - let colsF := cols.toFloat - let gammaData := ctx.gamma.data - if rows = 0 || cols = 0 then - return ConcreteMatrix.zeros rows cols - if ctx.invStds.size ≠ rows then - return ConcreteMatrix.zeros rows cols - - -- Precompute per-row mean(dX) and v⋅dX. - let mut meanDx : Array Float := Array.replicate rows 0.0 - let mut vDotDx : Array Float := Array.replicate rows 0.0 - for r in [:rows] do - let rowBase := r * cols - let mut sumDx : Float := 0.0 - let mut sumVDx : Float := 0.0 - for c in [:cols] do - let dx := dX.data[rowBase + c]! - sumDx := sumDx + dx - sumVDx := sumVDx + (ctx.v.data[rowBase + c]! * dx) - meanDx := meanDx.set! r (sumDx / cols.toFloat) - vDotDx := vDotDx.set! r sumVDx - - { numRows := rows - numCols := cols - data := .ofFn fun idx : Fin (rows * cols) => - let r := idx.val / cols - let c := idx.val % cols - let rowBase := r * cols - let invσ := ctx.invStds[r]! - let g := gammaData[c]! - let dx := dX.data[rowBase + c]! - let mdx := meanDx[r]! - let vdx := vDotDx[r]! - let vrc := ctx.v.data[rowBase + c]! - g * invσ * (dx - mdx - (vrc * (vdx / colsF))) - size_eq := Array.size_ofFn } - -/-- Elementwise (Hadamard) product of two matrices with the same shape. -/ -def hadamard (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data[idx.val]! * B.data[idx.val]! - size_eq := Array.size_ofFn } - else - ConcreteMatrix.zeros 0 0 - -end ConcreteMatrix - -namespace ConcreteAttentionWeights - -/-- Multiply an attention matrix `A` (as weights) by a matrix `V` (n×d), returning `A·V`. 
-/ -def mul (A : ConcreteAttentionWeights) (V : ConcreteMatrix) : ConcreteMatrix := - if A.seqLen = 0 then - ConcreteMatrix.zeros 0 0 - else if V.numRows ≠ A.seqLen then - ConcreteMatrix.zeros A.seqLen V.numCols - else - let n := A.seqLen - { numRows := n - numCols := V.numCols - data := .ofFn fun idx : Fin (n * V.numCols) => Id.run do - let q := idx.val / V.numCols - let d := idx.val % V.numCols - let mut acc : Float := 0.0 - let rowBase := q * n - for k in [:n] do - let a := A.weights[rowBase + k]! - let v := V.data[k * V.numCols + d]! - acc := acc + a * v - return acc - size_eq := Array.size_ofFn } - -end ConcreteAttentionWeights - -/-- Cached data for applying a single-head attention Jacobian to perturbations. -/ -structure HeadJacobianCtx where - head : ConcreteAttentionLayer - attn : ConcreteAttentionWeights - input : ConcreteMatrix - Q : ConcreteMatrix - K : ConcreteMatrix - V : ConcreteMatrix - KT : ConcreteMatrix - AV : ConcreteMatrix - invScale : Float - -namespace HeadJacobianCtx - -def build (head : ConcreteAttentionLayer) (input : ConcreteMatrix) (attn : ConcreteAttentionWeights) : - HeadJacobianCtx := - let Q := (input.matmul head.W_Q).addBias head.b_Q - let K := (input.matmul head.W_K).addBias head.b_K - let V := (input.matmul head.W_V).addBias head.b_V - let KT := K.transpose - let AV := attn.mul V - let invScale := 1.0 / Float.sqrt head.headDim.toFloat - { head := head, attn := attn, input := input, Q := Q, K := K, V := V, KT := KT, AV := AV, - invScale := invScale } - -/-- Apply the attention-head Jacobian at the cached `(input, attn)` to a perturbation `dInput`. 
-/ -def apply (ctx : HeadJacobianCtx) (dInput : ConcreteMatrix) : ConcreteMatrix := Id.run do - let n := ctx.attn.seqLen - if n = 0 then - return ConcreteMatrix.zeros 0 0 - if dInput.numRows ≠ n || dInput.numCols ≠ ctx.head.modelDim then - return ConcreteMatrix.zeros n ctx.head.modelDim - - let dQ := dInput.matmul ctx.head.W_Q - let dK := dInput.matmul ctx.head.W_K - let dV := dInput.matmul ctx.head.W_V - - -- Value term: A · dV - let dAV_value := ctx.attn.mul dV - - -- Pattern term: (dA) · V where dA = J_softmax(S) dS. - let dS1 := dQ.matmul ctx.KT - let dS2 := ctx.Q.matmul dK.transpose - let dS := (ConcreteMatrix.scale ctx.invScale (dS1.add dS2)) - - -- cRow[q] = ⟨p_q, dS_q⟩ - let mut cRow : Array Float := Array.replicate n 0.0 - for q in [:n] do - let rowBase := q * n - let mut c : Float := 0.0 - for k in [:n] do - let p := ctx.attn.weights[rowBase + k]! - let ds := dS.data[rowBase + k]! - c := c + p * ds - cRow := cRow.set! q c - - let dAV_pattern : ConcreteMatrix := - { numRows := n - numCols := ctx.head.headDim - data := .ofFn fun idx : Fin (n * ctx.head.headDim) => Id.run do - let q := idx.val / ctx.head.headDim - let d := idx.val % ctx.head.headDim - let rowBase := q * n - let c := cRow[q]! - let mut acc : Float := 0.0 - for k in [:n] do - let p := ctx.attn.weights[rowBase + k]! - let ds := dS.data[rowBase + k]! - let alpha := p * (ds - c) - let v := ctx.V.data[k * ctx.head.headDim + d]! - acc := acc + alpha * v - return acc - size_eq := Array.size_ofFn } - - let dAV := dAV_value.add dAV_pattern - -- Project back: (dAV) · W_O - return dAV.matmul ctx.head.W_O - -end HeadJacobianCtx - -/-- Cached data for applying an MLP Jacobian to perturbations. 
-/ -structure MLPJacobianCtx where - layer : ConcreteMLPLayer - input : ConcreteMatrix - geluDeriv : ConcreteMatrix - -namespace MLPJacobianCtx - -def build (layer : ConcreteMLPLayer) (input : ConcreteMatrix) : MLPJacobianCtx := - let hidden := (input.matmul layer.W_in).addBias layer.b_in - let geluDeriv := hidden.map geluDerivFloat - { layer := layer, input := input, geluDeriv := geluDeriv } - -def apply (ctx : MLPJacobianCtx) (dInput : ConcreteMatrix) : ConcreteMatrix := - let dHidden := dInput.matmul ctx.layer.W_in - let dHiddenAct := ConcreteMatrix.hadamard dHidden ctx.geluDeriv - dHiddenAct.matmul ctx.layer.W_out - -end MLPJacobianCtx - -/-- Cached context for applying the residual Jacobian of one transformer block (excluding identity). -/ -structure LayerResidualJacobianCtx where - ln1 : ConcreteMatrix.LayerNormJacobianCtx - ln2 : ConcreteMatrix.LayerNormJacobianCtx - heads : Array HeadJacobianCtx - mlp? : Option MLPJacobianCtx - -namespace LayerResidualJacobianCtx - -def build (cache : PrecomputedCache) (layerIdx : Nat) : LayerResidualJacobianCtx := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let x := fwd.getLayerInput layerIdx - let y := fwd.getPostAttnResidual layerIdx - - let ln1p := model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln2p := model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - - let ln1Ctx := ConcreteMatrix.mkLayerNormJacobianCtx x ln1p.gamma - let ln2Ctx := ConcreteMatrix.mkLayerNormJacobianCtx y ln2p.gamma - - let attnInput := model.applyLn1 layerIdx x - let layerHeads := model.layers.getD layerIdx #[] - let layerHeadData := cache.headData.getD layerIdx #[] - let mut headCtxs : Array HeadJacobianCtx := Array.mkEmpty layerHeads.size - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let attn := - if hd : h < layerHeadData.size then - (layerHeadData[h]'hd).attention - else - head.computeAttentionWeights 
attnInput true - headCtxs := headCtxs.push (HeadJacobianCtx.build head attnInput attn) - - let mlp? : Option MLPJacobianCtx := - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let mlpInput := model.applyLn2 layerIdx y - some (MLPJacobianCtx.build mlp mlpInput) - else - none - - return { ln1 := ln1Ctx, ln2 := ln2Ctx, heads := headCtxs, mlp? := mlp? } - - /-- Build a `LayerResidualJacobianCtx` restricted to the first `prefixLen` token positions. - - For causal attention, the outputs at positions `< prefixLen` depend only on inputs at - positions `< prefixLen`. Therefore, computing a Krylov lower bound on the restricted - operator (and measuring norms only on this prefix) yields a valid lower bound for the - full-sequence Jacobian operator norm. - -/ - def buildPrefix (cache : PrecomputedCache) (layerIdx : Nat) (prefixLen : Nat) : LayerResidualJacobianCtx := Id.run do - let model := cache.model - let fwd := cache.forwardResult - let xFull := fwd.getLayerInput layerIdx - let yFull := fwd.getPostAttnResidual layerIdx - let p := min prefixLen xFull.numRows - let x := xFull.takeRows p - let y := yFull.takeRows p - - let ln1p := model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln2p := model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim) - let ln1Ctx := ConcreteMatrix.mkLayerNormJacobianCtx x ln1p.gamma - let ln2Ctx := ConcreteMatrix.mkLayerNormJacobianCtx y ln2p.gamma - - let attnInput := model.applyLn1 layerIdx x - let layerHeads := model.layers.getD layerIdx #[] - let mut headCtxs : Array HeadJacobianCtx := Array.mkEmpty layerHeads.size - for h in [:layerHeads.size] do - if hh : h < layerHeads.size then - let head := layerHeads[h]'hh - let attn := head.computeAttentionWeights attnInput true - headCtxs := headCtxs.push (HeadJacobianCtx.build head attnInput attn) - - let mlp? 
: Option MLPJacobianCtx := - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let mlpInput := model.applyLn2 layerIdx y - some (MLPJacobianCtx.build mlp mlpInput) - else - none - - return { ln1 := ln1Ctx, ln2 := ln2Ctx, heads := headCtxs, mlp? := mlp? } - - /-- Apply the residual Jacobian `J_resid` (excluding identity): `δres = δx · J_resid`. - - If `parallelHeads=true`, per-head Jacobian-vector products are computed in parallel and then - summed in **head-index order** to preserve deterministic Float semantics. - -/ - def apply (ctx : LayerResidualJacobianCtx) (dX : ConcreteMatrix) - (parallelHeads : Bool := false) : ConcreteMatrix := Id.run do - let dU := ctx.ln1.apply dX - let dAttnSum : ConcreteMatrix := - if parallelHeads && ctx.heads.size >= 4 then - let tasks : Array (Task ConcreteMatrix) := - .ofFn fun i : Fin ctx.heads.size => - Task.spawn (fun _ => (ctx.heads[i]).apply dU) - let outs := tasks.map Task.get - ConcreteMatrix.sumMatrices outs (allowParallel := parallelHeads) - else - Id.run do - let mut outs : Array ConcreteMatrix := Array.mkEmpty ctx.heads.size - for hctx in ctx.heads do - outs := outs.push (hctx.apply dU) - return ConcreteMatrix.sumMatrices outs (allowParallel := false) - let dY := dX.add dAttnSum - let dV := ctx.ln2.apply dY - let dMlp := - match ctx.mlp? 
with - | some m => m.apply dV - | none => ConcreteMatrix.zeros dX.numRows dX.numCols - return dAttnSum.add dMlp - -end LayerResidualJacobianCtx - -private def deterministicInitVec (rows cols layerIdx : Nat) : ConcreteMatrix := - if rows = 0 || cols = 0 then - ConcreteMatrix.zeros rows cols - else - let n := rows * cols - let idx := ((layerIdx + 1) * 2654435761) % n - { numRows := rows - numCols := cols - data := .ofFn fun i : Fin n => if i.val = idx then 1.0 else 0.0 - size_eq := Array.size_ofFn } - -private def normalizeFrob (X : ConcreteMatrix) : ConcreteMatrix := - let n := X.frobeniusNorm - if n ≤ 0.0 || Float.isNaN n || Float.isInf n then - X - else - ConcreteMatrix.scale (1.0 / n) X - -private def maxAbsIndex (X : ConcreteMatrix) : Nat := Id.run do - let mut bestIdx : Nat := 0 - let mut best : Float := 0.0 - for i in [:X.data.size] do - let a := Float.abs (X.data[i]!) - if a > best then - best := a - bestIdx := i - bestIdx - -private def basisAt (rows cols idx : Nat) : ConcreteMatrix := - if rows = 0 || cols = 0 then - ConcreteMatrix.zeros rows cols - else - let n := rows * cols - if idx < n then - { numRows := rows - numCols := cols - data := .ofFn fun i : Fin n => if i.val = idx then 1.0 else 0.0 - size_eq := Array.size_ofFn } - else - ConcreteMatrix.zeros rows cols - -private def signLike (X : ConcreteMatrix) : ConcreteMatrix := - { numRows := X.numRows - numCols := X.numCols - data := .ofFn fun i : Fin (X.numRows * X.numCols) => - let x := X.data[i.val]! - if x > 0.0 then 1.0 else if x < 0.0 then (-1.0) else 0.0 - size_eq := Array.size_ofFn } - -/-- Improve the lower bound by trying a small set of deterministic initial vectors. - -This is still rigorous: for any nonzero `v`, `‖J v‖/‖v‖ ≤ ‖J‖`. We simply take the max -over a few starts to avoid a "bad" basis vector giving a vacuous bound. 
--/ -private def krylovLowerBoundFromInit (ctx : LayerResidualJacobianCtx) (v0 : ConcreteMatrix) - (k : Nat) (parallelHeads : Bool) : Float := Id.run do - let mut v := v0 - let mut best : Float := 0.0 - for _ in [:k] do - let vNorm := v.frobeniusNorm - if vNorm ≤ 0.0 || Float.isNaN vNorm || Float.isInf vNorm then - break - let w := ctx.apply v (parallelHeads := parallelHeads) - let wNorm := w.frobeniusNorm - let r : Float := wNorm / vNorm - if r > best then - best := r - if wNorm ≤ 0.0 || Float.isNaN wNorm || Float.isInf wNorm then - break - v := normalizeFrob w - return best - -/-- Rigorous (in exact real arithmetic) lower bound on `‖J_resid‖₂` from `k` Krylov steps. -/ -def layerResidualJacobianLowerBound (cache : PrecomputedCache) (layerIdx : Nat) (k : Nat := 4) - (parallelHeads : Bool := false) : - Float := Id.run do - let x := cache.forwardResult.getLayerInput layerIdx - let rows := x.numRows - let cols := x.numCols - if rows = 0 || cols = 0 then - return 0.0 - - -- PERFORMANCE: the full Jacobian-apply cost scales like `O(seqLen^2)` per head. - -- For causal attention we can restrict to a small prefix and still obtain a rigorous - -- lower bound by projecting the output norm to that prefix. - -- Keep this small: the Jacobian-apply cost is roughly quadratic in `prefixLen` due to softmax terms. 
- let prefixLen : Nat := min rows 16 - let xP := x.takeRows prefixLen - let ctx := - if prefixLen = rows then - LayerResidualJacobianCtx.build cache layerIdx - else - LayerResidualJacobianCtx.buildPrefix cache layerIdx prefixLen - let vHash := deterministicInitVec prefixLen cols layerIdx - let vData := normalizeFrob xP - let vMax := basisAt prefixLen cols (maxAbsIndex xP) - let vSign := normalizeFrob (signLike xP) - let useParallelInits : Bool := parallelHeads && (prefixLen * cols ≥ 50000) && k > 0 - if useParallelInits then - let t1 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vHash k parallelHeads) - let t2 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vData k parallelHeads) - let t3 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vMax k parallelHeads) - let t4 := Task.spawn (fun _ => krylovLowerBoundFromInit ctx vSign k parallelHeads) - let b1 := t1.get - let b2 := t2.get - let b3 := t3.get - let b4 := t4.get - return max (max b1 b2) (max b3 b4) - else - let b1 := krylovLowerBoundFromInit ctx vHash k parallelHeads - let b2 := krylovLowerBoundFromInit ctx vData k parallelHeads - let b3 := krylovLowerBoundFromInit ctx vMax k parallelHeads - let b4 := krylovLowerBoundFromInit ctx vSign k parallelHeads - return max (max b1 b2) (max b3 b4) - -/-! 
## Adaptive bound scheduler (rigorous upper bounds, deterministic) -/ - -inductive AdaptiveScope where - | layernorm - | all - deriving Repr, BEq, Inhabited - -structure AdaptiveSchedulerConfig where - targetSlack : Float := 8.0 - maxUpgrades : Nat := 200 - minRelImprove : Float := 0.01 - krylovSteps : Nat := 4 - scope : AdaptiveScope := .layernorm - debugMonotone : Bool := false - deriving Repr - -inductive AdaptiveUpgradeKind where - | ubTier - | lbSteps - deriving Repr, BEq, Inhabited - -structure AdaptiveSchedulerStep where - iter : Nat - layerIdx : Nat - kind : AdaptiveUpgradeKind := .ubTier - tierFrom : Nat - tierTo : Nat - kFrom : Nat := 0 - kTo : Nat := 0 - ubBefore : Float - ubAfter : Float - lb : Float - slackBefore : Float - slackAfter : Float - deriving Repr - -structure AdaptiveSchedulerResult where - /-- Final per-layer rigorous upper bounds. -/ - ub : Array Float - /-- Per-layer rigorous lower bounds from Krylov steps. -/ - lb : Array Float - /-- Krylov steps used for each layer lower bound. -/ - lbK : Array Nat - /-- Final effort tier index per layer. -/ - tier : Array Nat - /-- Upgrade log (one entry per accepted upgrade). -/ - steps : Array AdaptiveSchedulerStep - deriving Repr - -private structure AdaptiveUbCaches where - winUb : Array (Array (Option Float)) - woutUb : Array (Array (Option Float)) - /-- Cached operator-norm upper bounds `‖ln₁(X_l)‖₂` per layer and effort tier. -/ - ln1InputOpUb : Array (Array (Option Float)) - /-- Cached attention-only residual Jacobian upper bounds per layer and effort tier. -/ - attnPartUb : Array (Array (Option Float)) - ubAt : Array (Array (Option Float)) - mlpCore : Array (Option MLPJacobianBoundCore) - deriving Inhabited - -private def safeSlack (ub lb : Float) : Float := - if lb ≤ 0.0 || Float.isNaN lb || Float.isInf lb then - Float.inf - else - ub / lb - -/-- Run the adaptive scheduler for per-layer norm bounds. - -Correctness: every `ub[l]` returned is a minimum of **rigorous** upper bound candidates. 
-The adaptive logic only changes which candidates are computed; it never relaxes the final bound. --/ -def runAdaptiveScheduler (cache : PrecomputedCache) (cfg : AdaptiveSchedulerConfig) - (activeLayers? : Option (Array Bool) := none) : - AdaptiveSchedulerResult := Id.run do - let model := cache.model - let tiers := ConcreteMatrix.BoundEffort.tiers - let startTier : Nat := 1 -- current default behavior corresponds to tier1 - let lbStep : Nat := if cfg.krylovSteps = 0 then 0 else max 1 cfg.krylovSteps - let maxKrylov : Nat := - if cfg.krylovSteps = 0 then 0 - else max cfg.krylovSteps (min 32 (max 8 (cfg.krylovSteps * 4))) - let active : Array Bool := - match activeLayers? with - | some layers => - if layers.size = model.numLayers then layers - else Array.replicate model.numLayers true - | none => Array.replicate model.numLayers true - - -- Precompute tier-1 attention part and per-layer ln₂ Jacobian bound. - -- - -- We cache tier-1 attention bounds directly from `cache.headData` to preserve the - -- current default behavior and avoid recomputing `‖ln₁(X_l)‖₂` up front. - let nLayers := model.numLayers - let useParallelInit := nLayers >= 4 - let computeLayerInit (l : Nat) : (Float × Float × Float) := Id.run do - if !active[l]! 
then - return (0.0, 0.0, 0.0) - let layerData := cache.headData.getD l #[] - let mut a : Float := 0.0 - let inputOpBoundTier1 : Float := - if h : 0 < layerData.size then - (layerData[0]'h).inputOpBound - else - let emptyMat : ConcreteMatrix := { numRows := 0, numCols := 0, data := #[], size_eq := by simp } - let x := cache.ln1Inputs.getD l emptyMat - x.opNormUpperBoundRectGramEffort (tiers.getD startTier ConcreteMatrix.BoundEffort.tier1) - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := inputOpBoundTier1 - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - a := a + d.ln1OpBound * (valueTermUb + patternTermUb) - - let y := cache.forwardResult.getPostAttnResidual l - let ln2Bound := model.ln2OpBound l y - return (a, ln2Bound, inputOpBoundTier1) - - let init : Array (Float × Float × Float) := - if useParallelInit then - let tasks : Array (Task (Float × Float × Float)) := - .ofFn fun i : Fin nLayers => Task.spawn (fun _ => computeLayerInit i.val) - tasks.map Task.get - else - .ofFn fun i : Fin nLayers => computeLayerInit i.val - - let attnPartTier1 : Array Float := init.map (·.1) - let ln2Bound : Array Float := 
init.map (·.2.1) - let ln1InputOpTier1 : Array Float := init.map (·.2.2) - - let mut lbK : Array Nat := Array.replicate model.numLayers 0 - for l in [:model.numLayers] do - if active[l]! then - lbK := lbK.set! l cfg.krylovSteps - let mut lb : Array Float := Array.replicate model.numLayers 0.0 - let useParallelLb := model.numLayers >= 4 && cfg.krylovSteps > 0 - if useParallelLb then - let mut tasks : Array (Option (Task Float)) := Array.replicate model.numLayers none - for l in [:model.numLayers] do - if active[l]! then - tasks := tasks.set! l (some (Task.spawn (fun _ => - layerResidualJacobianLowerBound cache l cfg.krylovSteps (parallelHeads := true)))) - let mut out : Array Float := Array.replicate model.numLayers 0.0 - for l in [:model.numLayers] do - match tasks[l]! with - | some t => out := out.set! l t.get - | none => pure () - lb := out - else - for l in [:model.numLayers] do - if active[l]! then - lb := lb.set! l (layerResidualJacobianLowerBound cache l cfg.krylovSteps (parallelHeads := true)) - - let mlpUbFromCore (core : MLPJacobianBoundCore) (winUb woutUb : Float) : Float := - let legacy := winUb * core.globalDmax * woutUb - let scaledViaWin := core.winScaledUb * woutUb - let scaledViaWout := winUb * core.woutScaledUb - let scaled0 := - if scaledViaWin ≤ 0.0 || Float.isNaN scaledViaWin || Float.isInf scaledViaWin then - scaledViaWout - else if scaledViaWout ≤ 0.0 || Float.isNaN scaledViaWout || Float.isInf scaledViaWout then - scaledViaWin - else - min scaledViaWin scaledViaWout - let scaled := - if scaled0 ≤ 0.0 || Float.isNaN scaled0 || Float.isInf scaled0 then legacy else scaled0 - let absSchur := core.absSchur - if absSchur ≤ 0.0 || Float.isNaN absSchur || Float.isInf absSchur then - min legacy scaled - else - min absSchur (min legacy scaled) - - let initRow : Array (Option Float) := Array.replicate tiers.size none - let initCaches : AdaptiveUbCaches := - { winUb := Array.replicate model.numLayers (Array.replicate tiers.size none) - woutUb := 
Array.replicate model.numLayers (Array.replicate tiers.size none) - ln1InputOpUb := Array.replicate model.numLayers initRow - attnPartUb := Array.replicate model.numLayers initRow - ubAt := Array.replicate model.numLayers (Array.replicate tiers.size none) - mlpCore := Array.replicate model.numLayers none } - let mut caches : AdaptiveUbCaches := initCaches - - -- Seed tier-1 caches from precomputation (fast path). - for l in [:model.numLayers] do - if active[l]! && startTier < tiers.size then - let rowIn := (caches.ln1InputOpUb[l]!).set! startTier (some (ln1InputOpTier1.getD l 0.0)) - let rowA := (caches.attnPartUb[l]!).set! startTier (some (attnPartTier1.getD l 0.0)) - caches := - { caches with - ln1InputOpUb := caches.ln1InputOpUb.set! l rowIn - attnPartUb := caches.attnPartUb.set! l rowA } - - let getWinWoutM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches (Float × Float) := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return (0.0, 0.0) - if hm : layerIdx < model.mlps.size then - let win? := st.winUb[layerIdx]![tierIdx]! - let wout? := st.woutUb[layerIdx]![tierIdx]! - match win?, wout? with - | some win, some wout => return (win, wout) - | _, _ => - let eff := tiers[tierIdx]! - let mlp := model.mlps[layerIdx]'hm - let big := mlp.W_in.numRows * mlp.W_in.numCols + mlp.W_out.numRows * mlp.W_out.numCols - let useParallel := big >= 200000 && tierIdx > 0 - let (win, wout) := - if useParallel then - let t1 := Task.spawn (fun _ => mlp.W_in.opNormUpperBoundRectGramEffort eff) - let t2 := Task.spawn (fun _ => mlp.W_out.opNormUpperBoundRectGramEffort eff) - (t1.get, t2.get) - else - (mlp.W_in.opNormUpperBoundRectGramEffort eff, mlp.W_out.opNormUpperBoundRectGramEffort eff) - let rowWin := (st.winUb[layerIdx]!).set! tierIdx (some win) - let rowWout := (st.woutUb[layerIdx]!).set! tierIdx (some wout) - set { st with winUb := st.winUb.set! layerIdx rowWin, woutUb := st.woutUb.set! 
layerIdx rowWout } - return (win, wout) - else - return (0.0, 0.0) - - let getMlpCoreM (layerIdx : Nat) : StateM AdaptiveUbCaches (Option MLPJacobianBoundCore) := do - let st ← get - if !(layerIdx < model.numLayers) then - return none - match st.mlpCore[layerIdx]! with - | some core => return some core - | none => - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - let core? := - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - mlp.precomputeJacobianBoundCore geluDeriv - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - mlp.precomputeJacobianBoundCore dAct - set { st with mlpCore := st.mlpCore.set! layerIdx core? } - return core? - else - return none - - let rec getLn1InputOpUbM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.ln1InputOpUb[layerIdx]![tierIdx]! with - | some v => return v - | none => - -- Enforce monotonicity across tiers by always including the previous-tier bound. - let prev : Float ← - if tierIdx = 0 then - pure Float.inf - else - getLn1InputOpUbM layerIdx (tierIdx - 1) - let emptyMat : ConcreteMatrix := { numRows := 0, numCols := 0, data := #[], size_eq := by simp } - let x := cache.ln1Inputs.getD layerIdx emptyMat - let eff := tiers[tierIdx]! - let raw := x.opNormUpperBoundRectGramEffort eff - let v := min prev raw - let row := (st.ln1InputOpUb[layerIdx]!).set! tierIdx (some v) - set { st with ln1InputOpUb := st.ln1InputOpUb.set! 
layerIdx row } - return v - - let rec getAttnPartUbM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.attnPartUb[layerIdx]![tierIdx]! with - | some v => return v - | none => - let prev : Float ← - if tierIdx = 0 then - pure Float.inf - else - getAttnPartUbM layerIdx (tierIdx - 1) - let inputOpBound ← getLn1InputOpUbM layerIdx tierIdx - let layerData := cache.headData.getD layerIdx #[] - let mut a : Float := 0.0 - for d in layerData do - let attnFrob : Float := Float.sqrt (max 0.0 d.attentionFrobeniusNormSq) - let attnOpUb : Float := min attnFrob d.attentionOneInfBound - let valueTermUb : Float := attnOpUb * d.valueOutputProjSchurNorm - let inputs : PatternTermBoundInputs := { - attention := d.attention - inputNorm := d.inputNorm - inputOpBound := inputOpBound - qFrobBound := d.qFrobBound - kFrobBound := d.kFrobBound - vFrobBound := d.vFrobBound - qOpBoundAct := d.qOpBoundAct - kOpBoundAct := d.kOpBoundAct - vOpBoundAct := d.vOpBoundAct - qkActFrobBound := d.qkActFrobBound - kqActFrobBound := d.kqActFrobBound - qkActOpBound := d.qkActOpBound - kqActOpBound := d.kqActOpBound - scaleFactor := d.scaleFactor - wqOpBound := d.wqOpGram - wkOpBound := d.wkOpGram - wvOpBound := d.wvOpGram - woOpBound := d.woOpGram - voOpBound := d.valueOutputProjSchurNorm - bqFrob := d.bqFrob - bkFrob := d.bkFrob - bvFrob := d.bvFrob - } - let patternTermUb : Float := computePatternTermBound inputs - a := a + d.ln1OpBound * (valueTermUb + patternTermUb) - let v := min prev a - let row := (st.attnPartUb[layerIdx]!).set! tierIdx (some v) - set { st with attnPartUb := st.attnPartUb.set! layerIdx row } - return v - - let computeUbAtM (layerIdx tierIdx : Nat) : StateM AdaptiveUbCaches Float := do - let st ← get - if !(layerIdx < model.numLayers ∧ tierIdx < tiers.size) then - return 0.0 - match st.ubAt[layerIdx]![tierIdx]! 
with - | some v => return v - | none => - let base ← getAttnPartUbM layerIdx tierIdx - let v : Float ← - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx]'hm - let (win, wout) ← getWinWoutM layerIdx tierIdx - let l2 := ln2Bound.getD layerIdx 0.0 - let mlpUb : Float ← - if tierIdx ≤ startTier then - -- Tier0/tier1 are only evaluated once per layer; avoid extra passes just to build core. - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv win wout) - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct win wout) - else - match (← getMlpCoreM layerIdx) with - | some core => - pure (mlpUbFromCore core win wout) - | none => - let y := cache.forwardResult.getPostAttnResidual layerIdx - let geluDeriv := cache.forwardResult.getMlpGeluDeriv layerIdx - if geluDeriv.numCols = mlp.hiddenDim ∧ geluDeriv.numRows = y.numRows then - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp geluDeriv win wout) - else - let mlpInput := model.applyLn2 layerIdx y - let hidden := (mlpInput.matmul mlp.W_in).addBias mlp.b_in - let dAct := hidden.map geluDerivFloat - pure (computeMLPLayerOpNormFromGeluDerivWithOpBounds mlp dAct win wout) - let mlpPart := l2 * mlpUb - pure (base + (1.0 + base) * mlpPart) - else - pure base - let row := (st.ubAt[layerIdx]!).set! tierIdx (some v) - set { st with ubAt := st.ubAt.set! layerIdx row } - return v - - -- Initialize to current default (tier1) to ensure adaptive never worsens bounds. 
- let mut tier : Array Nat := Array.replicate model.numLayers startTier - let mut ub : Array Float := Array.mkEmpty model.numLayers - for l in [:model.numLayers] do - if active[l]! then - let (v, st) := (computeUbAtM l startTier).run caches - caches := st - let baseUb : Float := - if cache.layerNormBoundsComputed then - min (cache.layerNormBounds.getD l v) v - else - v - ub := ub.push baseUb - else - ub := ub.push 0.0 - - let mut steps : Array AdaptiveSchedulerStep := #[] - let mut stalledUb : Array Bool := Array.ofFn fun i : Fin model.numLayers => !active[i.val]! - let mut stalledLb : Array Bool := Array.ofFn fun i : Fin model.numLayers => !active[i.val]! - let mut upgradesUsed : Nat := 0 - let epsNoise : Float := 1e-6 - - while upgradesUsed < cfg.maxUpgrades do - -- Find worst slack among layers that can still be upgraded (ub tier or lb steps). - let mut bestLayer : Nat := 0 - let mut bestSlack : Float := 0.0 - let mut bestUb : Float := 0.0 - let mut found : Bool := false - for l in [:model.numLayers] do - if !active[l]! then - continue - let t := tier[l]! - let canUb := (!stalledUb[l]!) && (t + 1 < tiers.size) - let canLb := (!stalledLb[l]!) && (lbStep > 0) && (lbK[l]! < maxKrylov) - if (!canUb) && (!canLb) then - continue - let s := safeSlack (ub[l]!) (lb[l]!) - if (!found) || s > bestSlack || (s == bestSlack && ub[l]! > bestUb) then - found := true - bestLayer := l - bestSlack := s - bestUb := ub[l]! - - if !found then - break - if bestSlack ≤ cfg.targetSlack then - break - - let l := bestLayer - let oldUb := ub[l]! - let oldLb := lb[l]! - let oldSlack := safeSlack oldUb oldLb - let t := tier[l]! - let canUb := (!stalledUb[l]!) && (t + 1 < tiers.size) - let canLb := (!stalledLb[l]!) && (lbStep > 0) && (lbK[l]! < maxKrylov) - - -- Heuristic choice when both upgrades are possible: - -- if slack is dominated by a tiny lower bound, prioritize strengthening `lb` first. 
- if canLb && canUb && (oldLb ≤ 0.0 || oldSlack > cfg.targetSlack * 4.0) then - let fromK := lbK[l]! - let toK := min (fromK + lbStep) maxKrylov - let newLb := layerResidualJacobianLowerBound cache l toK (parallelHeads := true) - let relImprove : Float := - if oldLb ≤ 0.0 then (if newLb > 0.0 then 1.0 else 0.0) else (newLb - oldLb) / oldLb - if relImprove < cfg.minRelImprove then - stalledLb := stalledLb.set! l true - else - lbK := lbK.set! l toK - lb := lb.set! l newLb - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .lbSteps - tierFrom := tier[l]! - tierTo := tier[l]! - kFrom := fromK - kTo := toK - ubBefore := oldUb - ubAfter := oldUb - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack oldUb (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - if canUb then - let fromTier := t - let toTier := fromTier + 1 - let (newUb, st) := (computeUbAtM l toTier).run caches - caches := st - if cfg.debugMonotone && newUb > oldUb + epsNoise then - panic! s!"Adaptive monotonicity violated at layer {l}: old={oldUb} new={newUb}" - let relImprove : Float := - if oldUb ≤ 0.0 then 0.0 else (oldUb - newUb) / oldUb - if relImprove < cfg.minRelImprove then - stalledUb := stalledUb.set! l true - else - tier := tier.set! l toTier - ub := ub.set! l (min oldUb newUb) - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .ubTier - tierFrom := fromTier - tierTo := toTier - kFrom := lbK[l]! - kTo := lbK[l]! - ubBefore := oldUb - ubAfter := ub[l]! - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack (ub[l]!) (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - if canLb then - let fromK := lbK[l]! - let toK := min (fromK + lbStep) maxKrylov - let newLb := layerResidualJacobianLowerBound cache l toK (parallelHeads := true) - let relImprove : Float := - if oldLb ≤ 0.0 then (if newLb > 0.0 then 1.0 else 0.0) else (newLb - oldLb) / oldLb - if relImprove < cfg.minRelImprove then - stalledLb := stalledLb.set! 
l true - else - lbK := lbK.set! l toK - lb := lb.set! l newLb - steps := steps.push - { iter := upgradesUsed - layerIdx := l - kind := .lbSteps - tierFrom := tier[l]! - tierTo := tier[l]! - kFrom := fromK - kTo := toK - ubBefore := oldUb - ubAfter := oldUb - lb := lb[l]! - slackBefore := oldSlack - slackAfter := safeSlack (ub[l]!) (lb[l]!) } - upgradesUsed := upgradesUsed + 1 - continue - - stalledUb := stalledUb.set! l true - stalledLb := stalledLb.set! l true - upgradesUsed := upgradesUsed + 1 - - return { ub := ub, lb := lb, lbK := lbK, tier := tier, steps := steps } - - -/-! ## Head Composition Metrics -/ - -/-- Compute the **K-composition** score between two attention heads. - -This follows the definition in *"A Mathematical Framework for Transformer Circuits"* -(see the "composition diagram caption"): - -`kComp(h₁→h₂) = ‖W_QK² · W_OV¹‖_F / (‖W_QK²‖_F · ‖W_OV¹‖_F)`. - -In this codebase's row-vector convention, we store `W_QK = W_Q · W_Kᵀ` and -`W_OV,row = W_V · W_O`. The paper's `W_OV` corresponds to `(W_OV,row)ᵀ`, but -`‖W_OV‖_F = ‖W_OV,row‖_F` and we compute the numerator via low-rank factors without -materializing any `modelDim×modelDim` products. - -By default (as in the paper), we subtract the expected composition of random matrices -of shape `modelDim×modelDim`, which is approximately `1/√modelDim`. - -We return 0.0 on dimension mismatch or missing heads. 
--/ -def computeKCompositionScore - (model : ConcreteModel) - (data1 data2 : PrecomputedHeadData) : Float := - if h1 : data1.layerIdx < model.layers.size then - let heads1 := model.layers[data1.layerIdx]'h1 - if h2 : data2.layerIdx < model.layers.size then - let heads2 := model.layers[data2.layerIdx]'h2 - if hh1 : data1.headIdx < heads1.size then - let head1 := heads1[data1.headIdx]'hh1 - if hh2 : data2.headIdx < heads2.size then - let head2 := heads2[data2.headIdx]'hh2 - let denom := data2.queryKeyAlignNorm * data1.valueOutputProjNorm - if denom < 1e-10 then 0.0 - else - -- K-composition numerator: - -- ‖W_QK² · W_OV¹‖_F where W_QK² = W_Q²W_K²ᵀ and W_OV¹ = (W_V¹W_O¹)ᵀ. - -- Using low-rank factorization: - -- W_Q²W_K²ᵀ (W_V¹W_O¹)ᵀ = W_Q² (W_K²ᵀ W_O¹ᵀ) W_V¹ᵀ - -- and M := W_O¹ W_K² so W_K²ᵀ W_O¹ᵀ = Mᵀ. - let M := head1.W_O.matmul head2.W_K -- (d_head × d_head) - let T := data1.wvGram.matmul M -- (d_head × d_head) - let S := M.transpose.matmul T -- Mᵀ · (W_V¹ᵀW_V¹) · M - let numeratorSq := ConcreteMatrix.traceMul S data2.wqGram - let numerator := Float.sqrt (max numeratorSq 0.0) - let raw := numerator / denom - let baseline : Float := - if head1.modelDim = 0 then 0.0 - else 1.0 / Float.sqrt head1.modelDim.toFloat - raw - baseline - else - 0.0 - else - 0.0 - else - 0.0 - else - 0.0 - -/-- Convert a 2-layer deep circuit candidate into an induction-head candidate. - -This is used to avoid re-running the expensive `checkInductionPattern` scan when both -induction heads and deep circuits are requested from the same cache. --/ -def DeepCircuitCandidate.toInductionCandidate? - (c : DeepCircuitCandidate) (cache : PrecomputedCache) : - Option CandidateInductionHead := - if c.layerIndices.size = 2 && c.headIndices.size = 2 then - let l1 := c.layerIndices[0]! - let l2 := c.layerIndices[1]! - let h1 := c.headIndices[0]! - let h2 := c.headIndices[1]! 
- match cache.getHeadData l1 h1, cache.getHeadData l2 h2 with - | some d1, some d2 => - let ε1 := d1.faithfulnessRatio - let ε2 := d2.faithfulnessRatio - let combinedError := ε1 + ε2 + ε1 * ε2 - let inductionScore : Float := - match cache.model.inputTokens with - | some tokens => - (checkInductionCopyNextPattern tokens d2.attention (minScore := 0.0)).getD 0.0 - | none => 1.0 - let kComp := computeKCompositionScore cache.model d1 d2 - some { - layer1Idx := l1 - layer2Idx := l2 - head1Idx := h1 - head2Idx := h2 - patternBound1 := ε1 - patternBound2 := ε2 - combinedError := combinedError - prevTokenStrength := d1.prevTokenStrength - inductionScore := inductionScore - kComp := kComp - description := s!"L{l1}H{h1}->L{l2}H{h2} (deep)" - } - | _, _ => none - else - none - -/-- Find candidate (L1, L2) induction-head pairs from a `PrecomputedCache`. - -This searches for the classic two-head induction circuit: -- Layer 1 (L1): a strong **previous-token** head, -- Layer 2 (L2): an **induction** head (token-level "copy-next" attention). --/ -def findInductionHeadCandidatesFromCache (cache : PrecomputedCache) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array CandidateInductionHead := Id.run do - let model := cache.model - let tokens? := model.inputTokens - -- Precompute induction scores per head (layer2-only), since these are reused across head1 loops. - let headInductionScores : Array (Array (Option Float)) := - match tokens? 
with - | some tokens => - cache.headData.map (fun layer => - layer.map (fun data2 => - checkInductionCopyNextPattern tokens data2.attention minInductionScore)) - | none => - cache.headData.map (fun layer => layer.map (fun _ => some 1.0)) - - let computeForLayer (l1 : Nat) : Array CandidateInductionHead := Id.run do - let layer1Cache := cache.headData.getD l1 #[] - let layer1Candidates := - if minPrevTokenStrength <= 0.0 then - layer1Cache - else - layer1Cache.filter (fun data1 => data1.prevTokenStrength ≥ minPrevTokenStrength) - let mut layerCandidates : Array CandidateInductionHead := #[] - - -- Preserve the original traversal order: l1, l2, head1, head2. - for l2 in [l1 + 1:model.numLayers] do - let layer2Cache := cache.headData.getD l2 #[] - let layer2Scores := headInductionScores.getD l2 #[] - let layer2Candidates : Array (PrecomputedHeadData × Float) := Id.run do - let mut out : Array (PrecomputedHeadData × Float) := #[] - for (data2, score?) in layer2Cache.zip layer2Scores do - match score? with - | some inductionScore => out := out.push (data2, inductionScore) - | none => pure () - return out - for data1 in layer1Candidates do - for (data2, inductionScore) in layer2Candidates do - -- Use dimensionless faithfulness ratios (relative approximation errors). - let ε1 := data1.faithfulnessRatio - let ε2 := data2.faithfulnessRatio - let combinedError := ε1 + ε2 + ε1 * ε2 - let kComp := computeKCompositionScore model data1 data2 - - layerCandidates := layerCandidates.push { - layer1Idx := l1 - layer2Idx := l2 - head1Idx := data1.headIdx - head2Idx := data2.headIdx - patternBound1 := ε1 - patternBound2 := ε2 - combinedError := combinedError - prevTokenStrength := data1.prevTokenStrength - inductionScore := inductionScore - kComp := kComp - description := s!"L{l1}H{data1.headIdx}->L{l2}H{data2.headIdx}" - } - - return layerCandidates - - -- Pure parallelism via tasks: layer-1 index computations are independent. 
- let useParallel := model.numLayers >= 4 - let chunks : Array (Array CandidateInductionHead) := - if useParallel then - let tasks : Array (Task (Array CandidateInductionHead)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeForLayer i.val - - -- Join without quadratic copying. - let total := chunks.foldl (fun acc cs => acc + cs.size) 0 - let mut candidates : Array CandidateInductionHead := Array.mkEmpty total - for cs in chunks do - for c in cs do - candidates := candidates.push c - - candidates.qsort (·.combinedError < ·.combinedError) - -/-- Search for induction heads using proper layer-wise residual stream computation. - -This method performs multi-layer analysis with correct forward pass: -- Layer 1 attention is computed on the residual stream *after* layer 0 -- Layer 2 attention is computed on the residual stream *after* layers 0-1 - -This enables detection of true induction heads where layer 2 attends to -information created by layer 1's "previous-token" head. - -**Performance Note**: Uses `PrecomputedCache` to avoid O(L²H²) redundant attention -computations, reducing to O(LH) for typical models. --/ -def findInductionHeadCandidates (model : ConcreteModel) - (_threshold : Float) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array CandidateInductionHead := Id.run do - let cache := PrecomputedCache.build model - findInductionHeadCandidatesFromCache cache minPrevTokenStrength minInductionScore - -/-- Filter candidates to only those meeting the threshold. - -Uses proper layer-wise residual stream computation. 
--/ -def findVerifiedInductionHeads (model : ConcreteModel) - (threshold : Float) : Array VerifiedInductionHead := Id.run do - let candidates := findInductionHeadCandidates model threshold - let mut verified : Array VerifiedInductionHead := #[] - - for candidate in candidates do - if candidate.combinedError ≤ threshold then - verified := verified.push { - candidate := candidate - threshold := threshold - errorChecked := true - } - - verified - -/-- Find induction head candidates with rigorous N-layer amplification bounds. - -This replaces the ad-hoc `ε₁ + ε₂ + ε₁·ε₂` formula with the correct N-layer -composition theorem: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -For a 2-layer induction head with layers l1 < l2: -- Layer l1 contributes: ε₁ · (1 + C_l2) · (1 + C_{l2+1}) · ... -- Layer l2 contributes: ε₂ · (1 + C_{l2+1}) · ... - -The amplification factors Cⱼ are estimated from residual Jacobian norms -(`layerJacobian - I`). --/ -def findDeepCircuitCandidatesFromCache (cache : PrecomputedCache) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array DeepCircuitCandidate := Id.run do - let model := cache.model - - -- OPTIMIZATION: Use cached operator norm bounds (already computed in cache.build) - let allNormBounds := cache.layerNormBounds - - -- OPTIMIZATION: precompute suffix amplification products: - -- `suffixAmp[i] = ∏_{j≥i} (1 + C_j)` and `suffixAmp[size] = 1`. - let suffixAmp : Array Float := Id.run do - let n := allNormBounds.size - let mut out : Array Float := Array.replicate (n + 1) 1.0 - let mut prod : Float := 1.0 - for offset in [:n] do - let i := n - 1 - offset - prod := prod * (1.0 + allNormBounds[i]!) - out := out.set! 
i prod - return out - - let computeForLayer (l1 : Nat) : Array DeepCircuitCandidate := Id.run do - let layer1Cache := cache.headData.getD l1 #[] - let suffix1 := suffixAmp.getD (l1 + 1) 1.0 - let totalAmpFactor := suffixAmp.getD l1 1.0 - let mut layerCandidates : Array DeepCircuitCandidate := #[] - - -- Preserve the original traversal order: l1, l2, head1, head2. - for l2 in [l1 + 1:model.numLayers] do - let layer2Cache := cache.headData.getD l2 #[] - let suffix2 := suffixAmp.getD (l2 + 1) 1.0 - let relevantNormBounds := allNormBounds.extract l1 (l2 + 1) - - for data1 in layer1Cache do - if data1.prevTokenStrength ≥ minPrevTokenStrength then - for data2 in layer2Cache do - match checkInductionPattern data1.attention data2.attention minInductionScore with - | some _ => - let bound1 := data1.patternTermBound - let bound2 := data2.patternTermBound - let amplifiedError := bound1 * suffix1 + bound2 * suffix2 - - layerCandidates := layerCandidates.push { - layerIndices := #[l1, l2] - headIndices := #[data1.headIdx, data2.headIdx] - patternBounds := #[bound1, bound2] - operatorNormUbs := relevantNormBounds - simpleErrorSum := bound1 + bound2 - amplifiedError := amplifiedError - amplificationFactor := totalAmpFactor - patternType := "induction" - description := s!"L{l1}H{data1.headIdx}->L{l2}H{data2.headIdx}" - } - | none => pure () - else - pure () - - return layerCandidates - - -- Pure parallelism via tasks: layer-1 index computations are independent. - let useParallel := model.numLayers >= 4 - let chunks : Array (Array DeepCircuitCandidate) := - if useParallel then - let tasks : Array (Task (Array DeepCircuitCandidate)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeForLayer i.val - - -- Join without quadratic copying. 
- let total := chunks.foldl (fun acc cs => acc + cs.size) 0 - let mut candidates : Array DeepCircuitCandidate := Array.mkEmpty total - for cs in chunks do - for c in cs do - candidates := candidates.push c - - candidates.qsort (·.amplifiedError < ·.amplifiedError) - -/-- Find deep circuit candidates with rigorous N-layer amplification bounds. - -This is a wrapper around `findDeepCircuitCandidatesFromCache` that builds the cache. --/ -def findDeepCircuitCandidates (model : ConcreteModel) - (_threshold : Float) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array DeepCircuitCandidate := Id.run do - let cache := PrecomputedCache.build model - findDeepCircuitCandidatesFromCache cache minPrevTokenStrength minInductionScore - -/-- Filter deep circuit candidates by N-layer error threshold. -/ -def findVerifiedDeepCircuits (model : ConcreteModel) - (threshold : Float) : Array DeepCircuitCandidate := Id.run do - let candidates := findDeepCircuitCandidates model threshold - let mut verified : Array DeepCircuitCandidate := #[] - - for candidate in candidates do - if candidate.amplifiedError ≤ threshold then - verified := verified.push candidate - - verified - -/-! ## Pretty Printing -/ - -/-- Pretty print a candidate induction head. -/ -def CandidateInductionHead.toString (c : CandidateInductionHead) : String := - s!"InductionHead L{c.layer1Idx}H{c.head1Idx}->L{c.layer2Idx}H{c.head2Idx} " ++ - s!"(error={c.combinedError}, prev-token={c.prevTokenStrength}, " ++ - s!"induction={c.inductionScore}, kComp={c.kComp})" - -instance : ToString CandidateInductionHead := ⟨CandidateInductionHead.toString⟩ - -/-- Pretty print a verified induction head. -/ -def VerifiedInductionHead.toString (vh : VerifiedInductionHead) : String := - s!"Verified: {vh.candidate} [threshold={vh.threshold}]" - -instance : ToString VerifiedInductionHead := ⟨VerifiedInductionHead.toString⟩ - -/-! ## Summary Statistics -/ - -/-- Summary statistics for discovery results. 
-/ -structure DiscoverySummary where - /-- Number of candidates found -/ - candidateCount : Nat - /-- Number meeting the threshold -/ - verifiedCount : Nat - /-- Best (lowest) error bound found -/ - bestError : Float - -/-- Compute summary statistics for discovery results. -/ -def computeDiscoverySummary (model : ConcreteModel) (threshold : Float) : - DiscoverySummary := Id.run do - let candidates := findInductionHeadCandidates model threshold - let mut verifiedCount : Nat := 0 - for c in candidates do - if c.combinedError ≤ threshold then - verifiedCount := verifiedCount + 1 - - let bestErr := if candidates.isEmpty then Float.inf - else candidates.foldl (fun acc c => min acc c.combinedError) Float.inf - - { - candidateCount := candidates.size - verifiedCount := verifiedCount - bestError := bestErr - } - -/-! ## Convenience Functions -/ - -/-- Create a ConcreteAttentionLayer from raw Float arrays. -/ -def mkConcreteAttentionLayer - (modelDim headDim : Nat) - (wq wk wv wo : Array Float) - (hq : wq.size = modelDim * headDim) - (hk : wk.size = modelDim * headDim) - (hv : wv.size = modelDim * headDim) - (ho : wo.size = headDim * modelDim) : ConcreteAttentionLayer where - modelDim := modelDim - headDim := headDim - W_Q := { numRows := modelDim, numCols := headDim, data := wq, size_eq := hq } - b_Q := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_K := { numRows := modelDim, numCols := headDim, data := wk, size_eq := hk } - b_K := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_V := { numRows := modelDim, numCols := headDim, data := wv, size_eq := hv } - b_V := { numRows := 1, numCols := headDim, data := Array.replicate headDim 0.0, size_eq := by simp } - W_O := { numRows := headDim, numCols := modelDim, data := wo, size_eq := ho } - W_Q_dims := ⟨rfl, rfl⟩ - b_Q_dims := ⟨rfl, rfl⟩ - W_K_dims := ⟨rfl, rfl⟩ - b_K_dims := ⟨rfl, rfl⟩ - W_V_dims := ⟨rfl, rfl⟩ - b_V_dims := 
⟨rfl, rfl⟩ - W_O_dims := ⟨rfl, rfl⟩ - -/-! ## Generic Circuit Discovery - -This section provides a framework for discovering arbitrary circuits (sparse subgraphs) -that are certifiably responsible for model behavior on a given input. - -The key insight is that we can bound the error introduced by pruning a component -without running forward passes—using only weight matrices and attention patterns. --/ - -/-! ### Circuit Mask Representation -/ - -/-- A component identifier: attention head, MLP neuron, or SAE feature. -/ -inductive ComponentId where - /-- Attention head at (layer, head index) -/ - | head : (layerIdx : Nat) → (headIdx : Nat) → ComponentId - /-- MLP neuron at (layer, neuron index within that layer's MLP) -/ - | mlpNeuron : (layerIdx : Nat) → (neuronIdx : Nat) → ComponentId - /-- SAE feature at (layer, feature index within that layer's SAE) -/ - | saeFeature : (layerIdx : Nat) → (featureIdx : Nat) → ComponentId - deriving DecidableEq, Repr - -namespace ComponentId - -/-- Pretty print a component ID. -/ -def toString : ComponentId → String - | head l h => s!"L{l}H{h}" - | mlpNeuron l n => s!"L{l}N{n}" - | saeFeature l f => s!"L{l}F{f}" - -/-- Check if this is an attention head. -/ -def isHead : ComponentId → Bool - | head _ _ => true - | _ => false - -/-- Check if this is an MLP neuron. -/ -def isNeuron : ComponentId → Bool - | mlpNeuron _ _ => true - | _ => false - -/-- Check if this is an SAE feature. -/ -def isSAEFeature : ComponentId → Bool - | saeFeature _ _ => true - | _ => false - -/-- Get the layer index. -/ -def layerIdx : ComponentId → Nat - | head l _ => l - | mlpNeuron l _ => l - | saeFeature l _ => l - -instance : ToString ComponentId := ⟨ComponentId.toString⟩ - -end ComponentId - -/-- A circuit is a sparse subgraph mask over model components. - -The mask indicates which components are **included** in the circuit. -Components not in the mask are considered "ablated" (zeroed out). 

This is a simplified structure without proof obligations - validity is checked at runtime.
-/
structure ConcreteCircuit where
  /-- Number of layers in the model -/
  numLayers : Nat
  /-- Number of heads per layer -/
  headsPerLayer : Array Nat
  /-- Number of MLP neurons per layer -/
  neuronsPerLayer : Array Nat
  /-- Included attention heads: includedHeads[l][h] = true iff head h in layer l is active -/
  includedHeads : Array (Array Bool)
  /-- Included MLP neurons: includedNeurons[l][n] = true iff neuron n in layer l is active -/
  includedNeurons : Array (Array Bool)

namespace ConcreteCircuit

/-- Check if a specific head is included in the circuit.

Out-of-range `layerIdx`/`headIdx` return `false` (treated as excluded). -/
def isHeadIncluded (circuit : ConcreteCircuit) (layerIdx headIdx : Nat) : Bool :=
  if layerIdx < circuit.includedHeads.size then
    let layerMask := circuit.includedHeads.getD layerIdx #[]
    layerMask.getD headIdx false
  else false

/-- Check if a specific MLP neuron is included in the circuit.

Out-of-range `layerIdx`/`neuronIdx` return `false` (treated as excluded). -/
def isNeuronIncluded (circuit : ConcreteCircuit) (layerIdx neuronIdx : Nat) : Bool :=
  if layerIdx < circuit.includedNeurons.size then
    let layerMask := circuit.includedNeurons.getD layerIdx #[]
    layerMask.getD neuronIdx false
  else false

/-- Check if any component is included (dispatches on ComponentId type). -/
def isIncluded (circuit : ConcreteCircuit) (comp : ComponentId) : Bool :=
  match comp with
  | ComponentId.head l h => circuit.isHeadIncluded l h
  | ComponentId.mlpNeuron l n => circuit.isNeuronIncluded l n
  | ComponentId.saeFeature _ _ => false -- SAE features handled by SAECircuit

/-- Count total number of included attention heads. -/
def countIncludedHeads (circuit : ConcreteCircuit) : Nat :=
  circuit.includedHeads.foldl (fun acc layer =>
    acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0

/-- Count total number of included MLP neurons.
-/
def countIncludedNeurons (circuit : ConcreteCircuit) : Nat :=
  circuit.includedNeurons.foldl (fun acc layer =>
    acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0

/-- Count total number of included components. -/
def countIncluded (circuit : ConcreteCircuit) : Nat :=
  circuit.countIncludedHeads + circuit.countIncludedNeurons

/-- Count total number of attention heads (included + excluded). -/
def totalHeads (circuit : ConcreteCircuit) : Nat :=
  circuit.headsPerLayer.foldl (· + ·) 0

/-- Count total number of MLP neurons (included + excluded). -/
def totalNeurons (circuit : ConcreteCircuit) : Nat :=
  circuit.neuronsPerLayer.foldl (· + ·) 0

/-- Count total number of components (included + excluded). -/
def totalComponents (circuit : ConcreteCircuit) : Nat :=
  circuit.totalHeads + circuit.totalNeurons

/-- List all included component IDs (heads first, then neurons, each in layer order). -/
def includedComponents (circuit : ConcreteCircuit) : Array ComponentId := Id.run do
  let mut result : Array ComponentId := #[]
  -- Attention heads
  for l in [:circuit.numLayers] do
    let layerMask := circuit.includedHeads.getD l #[]
    for h_idx in [:layerMask.size] do
      if layerMask.getD h_idx false then
        result := result.push (ComponentId.head l h_idx)
  -- MLP neurons
  for l in [:circuit.numLayers] do
    let layerMask := circuit.includedNeurons.getD l #[]
    for n_idx in [:layerMask.size] do
      if layerMask.getD n_idx false then
        result := result.push (ComponentId.mlpNeuron l n_idx)
  result

/-- List all excluded component IDs.
-/
def excludedComponents (circuit : ConcreteCircuit) : Array ComponentId := Id.run do
  let mut result : Array ComponentId := #[]
  -- Attention heads
  for l in [:circuit.numLayers] do
    let layerMask := circuit.includedHeads.getD l #[]
    for h_idx in [:layerMask.size] do
      if !layerMask.getD h_idx false then
        result := result.push (ComponentId.head l h_idx)
  -- MLP neurons
  for l in [:circuit.numLayers] do
    let layerMask := circuit.includedNeurons.getD l #[]
    for n_idx in [:layerMask.size] do
      if !layerMask.getD n_idx false then
        result := result.push (ComponentId.mlpNeuron l n_idx)
  result

/-- Create a full circuit (all components included). -/
def full (numLayers : Nat) (headsPerLayer neuronsPerLayer : Array Nat) : ConcreteCircuit where
  numLayers := numLayers
  headsPerLayer := headsPerLayer
  neuronsPerLayer := neuronsPerLayer
  includedHeads := headsPerLayer.map fun numHeads =>
    .ofFn fun _ : Fin numHeads => true
  includedNeurons := neuronsPerLayer.map fun numNeurons =>
    .ofFn fun _ : Fin numNeurons => true

/-- Create an empty circuit (no components included).

Masks are still fully sized (one `false` per head/neuron), so membership checks
stay in-range for every valid index. -/
def empty (numLayers : Nat) (headsPerLayer neuronsPerLayer : Array Nat) : ConcreteCircuit where
  numLayers := numLayers
  headsPerLayer := headsPerLayer
  neuronsPerLayer := neuronsPerLayer
  includedHeads := headsPerLayer.map fun numHeads =>
    .ofFn fun _ : Fin numHeads => false
  includedNeurons := neuronsPerLayer.map fun numNeurons =>
    .ofFn fun _ : Fin numNeurons => false

/-- Remove a single component from the circuit (returns new circuit).

Out-of-range indices (or an SAE feature) leave the circuit unchanged.
-/
def removeComponent (circuit : ConcreteCircuit) (comp : ComponentId) : ConcreteCircuit :=
  match comp with
  | ComponentId.head layerIdx headIdx =>
    if layerIdx < circuit.includedHeads.size then
      let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask =>
        if headIdx < layerMask.size then
          layerMask.modify headIdx fun _ => false
        else layerMask
      { circuit with includedHeads := newIncluded }
    else circuit
  | ComponentId.mlpNeuron layerIdx neuronIdx =>
    if layerIdx < circuit.includedNeurons.size then
      let newIncluded := circuit.includedNeurons.modify layerIdx fun layerMask =>
        if neuronIdx < layerMask.size then
          layerMask.modify neuronIdx fun _ => false
        else layerMask
      { circuit with includedNeurons := newIncluded }
    else circuit
  | ComponentId.saeFeature _ _ => circuit -- SAE features handled by SAECircuit

/-- Add a single component to the circuit (returns new circuit).

Out-of-range indices (or an SAE feature) leave the circuit unchanged. -/
def addComponent (circuit : ConcreteCircuit) (comp : ComponentId) : ConcreteCircuit :=
  match comp with
  | ComponentId.head layerIdx headIdx =>
    if layerIdx < circuit.includedHeads.size then
      let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask =>
        if headIdx < layerMask.size then
          layerMask.modify headIdx fun _ => true
        else layerMask
      { circuit with includedHeads := newIncluded }
    else circuit
  | ComponentId.mlpNeuron layerIdx neuronIdx =>
    if layerIdx < circuit.includedNeurons.size then
      let newIncluded := circuit.includedNeurons.modify layerIdx fun layerMask =>
        if neuronIdx < layerMask.size then
          layerMask.modify neuronIdx fun _ => true
        else layerMask
      { circuit with includedNeurons := newIncluded }
    else circuit
  | ComponentId.saeFeature _ _ => circuit -- SAE features handled by SAECircuit

/-- Pretty print the circuit.
-/
def toString (circuit : ConcreteCircuit) : String :=
  let heads := circuit.countIncludedHeads
  let neurons := circuit.countIncludedNeurons
  let totalH := circuit.totalHeads
  let totalN := circuit.totalNeurons
  s!"Circuit(heads={heads}/{totalH}, neurons={neurons}/{totalN})"

instance : ToString ConcreteCircuit := ⟨ConcreteCircuit.toString⟩

end ConcreteCircuit

/-! ## SAE-Enhanced Circuit Discovery

When using Sparse Autoencoders, we replace MLP neuron masks with SAE feature masks.
This enables discovering circuits in terms of interpretable features rather than
polysemantic neurons.
-/

/-- A circuit mask that operates on SAE features instead of MLP neurons.

This extends ConcreteCircuit by replacing MLP neuron masks with SAE feature masks.
The attention head masks remain the same.
-/
structure SAECircuit where
  /-- Number of layers in the model -/
  numLayers : Nat
  /-- Number of heads per layer -/
  headsPerLayer : Array Nat
  /-- Number of SAE features per layer -/
  featuresPerLayer : Array Nat
  /-- Included attention heads -/
  includedHeads : Array (Array Bool)
  /-- Included SAE features: includedFeatures[l][f] = true iff feature f in layer l is active -/
  includedFeatures : Array (Array Bool)

namespace SAECircuit

/-- Check if a specific head is included. Out-of-range indices return `false`. -/
def isHeadIncluded (circuit : SAECircuit) (layerIdx headIdx : Nat) : Bool :=
  if layerIdx < circuit.includedHeads.size then
    let layerMask := circuit.includedHeads.getD layerIdx #[]
    layerMask.getD headIdx false
  else false

/-- Check if a specific SAE feature is included. Out-of-range indices return `false`. -/
def isFeatureIncluded (circuit : SAECircuit) (layerIdx featureIdx : Nat) : Bool :=
  if layerIdx < circuit.includedFeatures.size then
    let layerMask := circuit.includedFeatures.getD layerIdx #[]
    layerMask.getD featureIdx false
  else false

/-- Check if any component is included.
-/
def isIncluded (circuit : SAECircuit) (comp : ComponentId) : Bool :=
  match comp with
  | ComponentId.head l h => circuit.isHeadIncluded l h
  | ComponentId.mlpNeuron _ _ => false -- SAE circuits don't track neurons
  | ComponentId.saeFeature l f => circuit.isFeatureIncluded l f

/-- Count included heads. -/
def countIncludedHeads (circuit : SAECircuit) : Nat :=
  circuit.includedHeads.foldl (fun acc layer =>
    acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0

/-- Count included features. -/
def countIncludedFeatures (circuit : SAECircuit) : Nat :=
  circuit.includedFeatures.foldl (fun acc layer =>
    acc + layer.foldl (fun acc' b => if b then acc' + 1 else acc') 0) 0

/-- Total heads. -/
def totalHeads (circuit : SAECircuit) : Nat :=
  circuit.headsPerLayer.foldl (· + ·) 0

/-- Total features. -/
def totalFeatures (circuit : SAECircuit) : Nat :=
  circuit.featuresPerLayer.foldl (· + ·) 0

/-- Create a full circuit (all components included). -/
def full (numLayers : Nat) (headsPerLayer featuresPerLayer : Array Nat) : SAECircuit where
  numLayers := numLayers
  headsPerLayer := headsPerLayer
  featuresPerLayer := featuresPerLayer
  includedHeads := headsPerLayer.map fun numHeads =>
    .ofFn fun _ : Fin numHeads => true
  includedFeatures := featuresPerLayer.map fun numFeats =>
    .ofFn fun _ : Fin numFeats => true

/-- Create an empty circuit. -/
def empty (numLayers : Nat) (headsPerLayer featuresPerLayer : Array Nat) : SAECircuit where
  numLayers := numLayers
  headsPerLayer := headsPerLayer
  featuresPerLayer := featuresPerLayer
  includedHeads := headsPerLayer.map fun numHeads =>
    .ofFn fun _ : Fin numHeads => false
  includedFeatures := featuresPerLayer.map fun numFeats =>
    .ofFn fun _ : Fin numFeats => false

/-- Remove a component.

Out-of-range indices (or an MLP neuron) leave the circuit unchanged.
-/
def removeComponent (circuit : SAECircuit) (comp : ComponentId) : SAECircuit :=
  match comp with
  | ComponentId.head layerIdx headIdx =>
    if layerIdx < circuit.includedHeads.size then
      let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask =>
        if headIdx < layerMask.size then
          layerMask.modify headIdx fun _ => false
        else layerMask
      { circuit with includedHeads := newIncluded }
    else circuit
  | ComponentId.saeFeature layerIdx featureIdx =>
    if layerIdx < circuit.includedFeatures.size then
      let newIncluded := circuit.includedFeatures.modify layerIdx fun layerMask =>
        if featureIdx < layerMask.size then
          layerMask.modify featureIdx fun _ => false
        else layerMask
      { circuit with includedFeatures := newIncluded }
    else circuit
  | ComponentId.mlpNeuron _ _ => circuit -- Not supported in SAE circuits

/-- Add a component. Out-of-range indices (or an MLP neuron) leave the circuit unchanged. -/
def addComponent (circuit : SAECircuit) (comp : ComponentId) : SAECircuit :=
  match comp with
  | ComponentId.head layerIdx headIdx =>
    if layerIdx < circuit.includedHeads.size then
      let newIncluded := circuit.includedHeads.modify layerIdx fun layerMask =>
        if headIdx < layerMask.size then
          layerMask.modify headIdx fun _ => true
        else layerMask
      { circuit with includedHeads := newIncluded }
    else circuit
  | ComponentId.saeFeature layerIdx featureIdx =>
    if layerIdx < circuit.includedFeatures.size then
      let newIncluded := circuit.includedFeatures.modify layerIdx fun layerMask =>
        if featureIdx < layerMask.size then
          layerMask.modify featureIdx fun _ => true
        else layerMask
      { circuit with includedFeatures := newIncluded }
    else circuit
  | ComponentId.mlpNeuron _ _ => circuit

/-- Pretty print the SAE circuit as included/total counts. -/
def toString (circuit : SAECircuit) : String :=
  let heads := circuit.countIncludedHeads
  let features := circuit.countIncludedFeatures
  let totalH := circuit.totalHeads
  let totalF := circuit.totalFeatures
  s!"SAECircuit(heads={heads}/{totalH}, features={features}/{totalF})"

instance : ToString SAECircuit :=
⟨SAECircuit.toString⟩

end SAECircuit

/-- Model with SAEs attached at each layer's MLP.

Replaces `ConcreteModel.mlps` with SAEs for feature-level analysis.
-/
structure SAEEnhancedModel where
  /-- Number of layers -/
  numLayers : Nat
  /-- Attention layers with their heads -/
  layers : Array (Array ConcreteAttentionLayer)
  /-- Pre-LN LayerNorm parameters before attention (ln_1), one per layer. -/
  ln1 : Array ConcreteLayerNormParams := #[]
  /-- Pre-LN LayerNorm parameters before SAE/MLP (ln_2), one per layer. -/
  ln2 : Array ConcreteLayerNormParams := #[]
  /-- Final LayerNorm parameters (ln_f) before unembedding.
  Default is `identity 0` — presumably a dimension-0 placeholder meaning "no ln_f
  configured"; TODO confirm against `ConcreteLayerNormParams.identity`. -/
  lnf : ConcreteLayerNormParams := ConcreteLayerNormParams.identity 0
  /-- SAEs for MLP analysis: saes[l] is the SAE for layer l's MLP -/
  saes : Array ConcreteSAE
  /-- Sequence length -/
  seqLen : Nat
  /-- Input embeddings -/
  inputEmbeddings : ConcreteMatrix
  /-- Unembedding matrix -/
  unembedding : Option ConcreteMatrix := none

namespace SAEEnhancedModel

/-- Model dimension (d), inferred from input embeddings. -/
def modelDim (model : SAEEnhancedModel) : Nat :=
  model.inputEmbeddings.numCols

/-- Get ln_1 parameters for a layer, defaulting to identity. -/
def ln1Params (model : SAEEnhancedModel) (layerIdx : Nat) : ConcreteLayerNormParams :=
  model.ln1.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim)

/-- Get ln_2 parameters for a layer, defaulting to identity. -/
def ln2Params (model : SAEEnhancedModel) (layerIdx : Nat) : ConcreteLayerNormParams :=
  model.ln2.getD layerIdx (ConcreteLayerNormParams.identity model.modelDim)

/-- Apply ln_1 to a residual stream (row-wise, per token). -/
def applyLn1 (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix :=
  let p := model.ln1Params layerIdx
  ConcreteMatrix.layerNormRowwise X p.gamma p.beta

/-- Apply ln_2 to a residual stream (row-wise, per token).
-/
def applyLn2 (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : ConcreteMatrix :=
  let p := model.ln2Params layerIdx
  ConcreteMatrix.layerNormRowwise X p.gamma p.beta

/-- Conservative operator-norm bound for ln_1 Jacobian at a specific activation. -/
def ln1OpBound (model : SAEEnhancedModel) (layerIdx : Nat) (X : ConcreteMatrix) : Float :=
  let p := model.ln1Params layerIdx
  ConcreteMatrix.layerNormRowwiseOpEst X p.gamma

/-- Create from ConcreteModel with externally trained SAEs.

Returns `none` when the SAE count does not match the layer count. -/
def fromModel (model : ConcreteModel) (saes : Array ConcreteSAE) : Option SAEEnhancedModel :=
  if saes.size = model.numLayers then
    some {
      numLayers := model.numLayers
      layers := model.layers
      ln1 := model.ln1
      ln2 := model.ln2
      lnf := model.lnf
      saes := saes
      seqLen := model.seqLen
      inputEmbeddings := model.inputEmbeddings
      unembedding := model.unembedding
    }
  else none

/-- Total reconstruction error across all SAEs for given forward pass.

Aggregates per-layer errors in l2 fashion: squares are summed, then the square
root of the total is returned. -/
def totalReconstructionError (model : SAEEnhancedModel)
    (fwd : ForwardPassResult) : Float := Id.run do
  let mut totalErr : Float := 0.0
  for l in [:model.numLayers] do
    if hl : l < model.saes.size then
      let sae := model.saes[l]
      let layerInput := fwd.getLayerInput l
      let err := sae.reconstructionErrorMatrix layerInput
      totalErr := totalErr + err * err
  Float.sqrt totalErr

end SAEEnhancedModel

/-- Importance metrics for SAE features.
-/
structure SAEFeatureImportance where
  /-- Component identifier -/
  component : ComponentId
  /-- Value term norm (feature influence) -/
  valueTermNorm : Float
  /-- Pattern term bound (activation instability) -/
  patternTermBound : Float
  /-- Faithfulness ratio -/
  faithfulnessRatio : Float
  /-- Reconstruction error contribution (SAE approximation) -/
  reconstructionError : Float

namespace SAEFeatureImportance

/-- Pretty print all four metrics on one line. -/
def toString (imp : SAEFeatureImportance) : String :=
  s!"{imp.component}: value={imp.valueTermNorm}, pattern={imp.patternTermBound}, " ++
  s!"recon={imp.reconstructionError}, ratio={imp.faithfulnessRatio}"

instance : ToString SAEFeatureImportance := ⟨toString⟩

end SAEFeatureImportance

/-- Compute importance metrics for a single SAE feature.

Returns `none` when `featureIdx` is out of range. The pattern-term bound is an
RMS over token positions of per-position IBP bounds; the reconstruction error is
attributed uniformly across features (a stated approximation). -/
def computeSAEFeatureImportance (sae : ConcreteSAE) (layerIdx featureIdx : Nat)
    (layerInput : ConcreteMatrix) (perturbationNorm : Float) : Option SAEFeatureImportance :=
  if featureIdx < sae.numFeatures then
    let influence := sae.featureInfluence featureIdx

    -- Compute IBP-based pattern term bound across positions
    let patternBound := Id.run do
      let mut totalPatternSq : Float := 0.0
      for pos in [:layerInput.numRows] do
        let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols =>
          layerInput.getUnsafe pos d.val
        let posBound := sae.featurePatternTermBoundIBP featureIdx inputVec perturbationNorm
        totalPatternSq := totalPatternSq + posBound * posBound
      Float.sqrt (totalPatternSq / (max 1 layerInput.numRows).toFloat)

    -- Estimate per-feature contribution to reconstruction error
    -- Approximation: uniform distribution across features (could be refined)
    let perFeatureRecon := Id.run do
      let mut totalRecon : Float := 0.0
      for pos in [:layerInput.numRows] do
        let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols =>
          layerInput.getUnsafe pos d.val
        totalRecon := totalRecon + sae.reconstructionError inputVec
      totalRecon / (max 1 layerInput.numRows).toFloat / sae.numFeatures.toFloat

    -- Near-zero influence yields an infinite ratio sentinel (maximally unstable).
    let ratio := if influence < 1e-10 then Float.inf else patternBound / influence

    some {
      component := ComponentId.saeFeature layerIdx featureIdx
      valueTermNorm := influence
      patternTermBound := patternBound
      faithfulnessRatio := ratio
      reconstructionError := perFeatureRecon
    }
  else none

/-- Error breakdown for SAE-based circuits. -/
structure SAECircuitError where
  /-- Pattern term error from included components -/
  patternTermError : Float
  /-- Ablation error from excluded components -/
  ablationError : Float
  /-- SAE reconstruction error (approximation of MLP) -/
  reconstructionError : Float
  /-- Total error bound -/
  totalError : Float
  /-- Number of included components -/
  includedCount : Nat
  /-- Number of excluded components -/
  excludedCount : Nat
  /-- Number of unstable features -/
  unstableFeatureCount : Nat
  deriving Repr

namespace SAECircuitError

/-- Pretty print the error breakdown. -/
def toString (e : SAECircuitError) : String :=
  s!"SAECircuitError(total={e.totalError}, pattern={e.patternTermError}, " ++
  s!"ablation={e.ablationError}, recon={e.reconstructionError}, " ++
  s!"unstable={e.unstableFeatureCount})"

instance : ToString SAECircuitError := ⟨toString⟩

end SAECircuitError

/-- Helper to compute head importance for SAE analysis.
Inline version used before computeHeadImportance is defined.
Works with both ConcreteModel and SAEEnhancedModel (via their shared fields).
-/
private def computeHeadMetricsForSAE
    (model : SAEEnhancedModel)
    (layerIdx headIdx : Nat) (layerInput : ConcreteMatrix) : Option (Float × Float) :=
  if h1 : layerIdx < model.layers.size then
    let layerHeads := model.layers[layerIdx]
    if h2 : headIdx < layerHeads.size then
      let head := layerHeads[headIdx]
      let attnInput := model.applyLn1 layerIdx layerInput
      let attn := head.computeAttentionWeights attnInput false
      let inputNorm := computeInputNorm attnInput
      let inputOpBound := attnInput.opNormUpperBoundOneInf
      let ln1Bound := model.ln1OpBound layerIdx layerInput
      let bnds := head.noDenseProductBounds

      -- Both metrics are scaled by the ln_1 Jacobian bound so they are stated
      -- in terms of the raw (pre-LayerNorm) residual stream.
      let valueNorm := ln1Bound * computeValueTermNorm attn bnds.voFrobNormSq
      let inputs : PatternTermBoundInputs := {
        attention := attn
        inputNorm := inputNorm
        inputOpBound := inputOpBound
        scaleFactor := Float.sqrt head.headDim.toFloat
        wqOpBound := bnds.wqOpGram
        wkOpBound := bnds.wkOpGram
        wvOpBound := bnds.wvOpGram
        woOpBound := bnds.woOpGram
        voOpBound := bnds.voOpBound
        bqFrob := head.b_Q.frobeniusNorm
        bkFrob := head.b_K.frobeniusNorm
        bvFrob := head.b_V.frobeniusNorm
      }
      let patternBound := ln1Bound * computePatternTermBound inputs
      some (valueNorm, patternBound)
    else none
  else none

/-- Estimate faithfulness error for an SAE-based circuit.

Extends the standard circuit error model to include:
1. Pattern term error for included attention heads
2. Pattern term error for included SAE features (with IBP)
3. Ablation error for excluded components
4. SAE reconstruction error (the approximation of using SAE instead of MLP)

Total Error = Σ(included) patternBound + Σ(excluded) valueNorm + reconstructionError

Note: `_causal` is accepted for interface symmetry but unused — the internal
simplified forward pass always calls `head.forward attnInput true`.
-/
def estimateSAECircuitFaithfulness (model : SAEEnhancedModel)
    (circuit : SAECircuit) (_causal : Bool := true) : SAECircuitError := Id.run do
  -- Simplified forward pass (just attention)
  let mut residual := model.inputEmbeddings
  let mut layerInputs : Array ConcreteMatrix := #[model.inputEmbeddings]

  for l in [:model.numLayers] do
    let attnInput := model.applyLn1 l residual
    let mut attnSum := ConcreteMatrix.zeros residual.numRows residual.numCols
    if hl : l < model.layers.size then
      let layerHeads := model.layers[l]
      for h in [:layerHeads.size] do
        if hh : h < layerHeads.size then
          let head := layerHeads[h]'hh
          let headOutput := head.forward attnInput true
          attnSum := attnSum.add headOutput
    residual := residual.add attnSum
    layerInputs := layerInputs.push residual

  let mut patternError : Float := 0.0
  let mut ablationError : Float := 0.0
  let mut totalRecon : Float := 0.0
  let mut includedCount : Nat := 0
  let mut excludedCount : Nat := 0
  let mut unstableCount : Nat := 0
  let mut cumulativeAblation : Float := 0.0

  -- Process attention heads
  for l in [:model.numLayers] do
    if hl : l < model.layers.size then
      let layerHeads := model.layers[l]
      for h_idx in [:layerHeads.size] do
        let included := circuit.isHeadIncluded l h_idx
        let layerInput :=
          if hl2 : l < layerInputs.size then layerInputs[l]'hl2 else model.inputEmbeddings
        match computeHeadMetricsForSAE model l h_idx layerInput with
        | some (valueNorm, patternBound) =>
          if included then
            patternError := patternError + patternBound
            includedCount := includedCount + 1
          else
            ablationError := ablationError + valueNorm
            cumulativeAblation := cumulativeAblation + valueNorm
            excludedCount := excludedCount + 1
        | none => pure ()

  -- Process SAE features
  for l in [:model.numLayers] do
    if hl : l < model.saes.size then
      let sae := model.saes[l]
      -- Pre-LN: SAE/MLP sees ln_2(y_l) where y_l is post-attention residual.
      let postAttn :=
        if hpost : l + 1 < layerInputs.size then layerInputs[l + 1]'hpost else residual
      let layerInput := model.applyLn2 l postAttn

      -- Add reconstruction error for this layer
      let layerRecon := sae.reconstructionErrorMatrix layerInput
      totalRecon := totalRecon + layerRecon

      -- Features see the ablation mass accumulated from excluded heads as the
      -- perturbation magnitude for their IBP bounds.
      for f_idx in [:sae.numFeatures] do
        let included := circuit.isFeatureIncluded l f_idx
        match computeSAEFeatureImportance sae l f_idx layerInput cumulativeAblation with
        | some imp =>
          if included then
            patternError := patternError + imp.patternTermBound
            if imp.patternTermBound > 0.0 then
              unstableCount := unstableCount + 1
            includedCount := includedCount + 1
          else
            ablationError := ablationError + imp.valueTermNorm
            excludedCount := excludedCount + 1
        | none => pure ()

  {
    patternTermError := patternError
    ablationError := ablationError
    reconstructionError := totalRecon
    totalError := patternError + ablationError + totalRecon
    includedCount := includedCount
    excludedCount := excludedCount
    unstableFeatureCount := unstableCount
  }

/-! ### SAE Circuit Discovery

Greedy pruning algorithm for SAE-enhanced circuits:
1. Start with all components included
2. Compute importance for each component (heads + SAE features)
3. Remove the component with smallest valueNorm (least information loss when ablated)
4. Repeat until error threshold would be exceeded

Note: SAE reconstruction error is an additive constant for a given set of SAEs,
so it doesn't affect the pruning order.
-/

/-- Ranked component importance for SAE circuits.
-/
structure SAERankedComponent where
  /-- Component identifier -/
  component : ComponentId
  /-- Value term norm (importance for ablation) -/
  valueTermNorm : Float
  /-- Pattern term bound (error when included) -/
  patternTermBound : Float

/-- Compute all component importances for an SAE-enhanced model.

Feature importances are computed with perturbation norm 0.0 (no accumulated
ablation), so they rank features by intrinsic influence only. -/
def computeAllSAEImportance (model : SAEEnhancedModel) : Array SAERankedComponent := Id.run do
  let mut result : Array SAERankedComponent := #[]

  -- Simplified forward pass
  let mut residual := model.inputEmbeddings
  let mut layerInputs : Array ConcreteMatrix := #[model.inputEmbeddings]

  for l in [:model.numLayers] do
    let attnInput := model.applyLn1 l residual
    let mut attnSum := ConcreteMatrix.zeros residual.numRows residual.numCols
    if hl : l < model.layers.size then
      let layerHeads := model.layers[l]
      for h in [:layerHeads.size] do
        if hh : h < layerHeads.size then
          let head := layerHeads[h]'hh
          let headOutput := head.forward attnInput true
          attnSum := attnSum.add headOutput
    residual := residual.add attnSum
    layerInputs := layerInputs.push residual

  -- Compute head importances
  for l in [:model.numLayers] do
    if hl : l < model.layers.size then
      let layerHeads := model.layers[l]
      for h_idx in [:layerHeads.size] do
        let layerInput :=
          if hl2 : l < layerInputs.size then layerInputs[l]'hl2 else model.inputEmbeddings
        match computeHeadMetricsForSAE model l h_idx layerInput with
        | some (valueNorm, patternBound) =>
          result := result.push {
            component := ComponentId.head l h_idx
            valueTermNorm := valueNorm
            patternTermBound := patternBound
          }
        | none => pure ()

  -- Compute SAE feature importances
  for l in [:model.numLayers] do
    if hl : l < model.saes.size then
      let sae := model.saes[l]
      let postAttn :=
        if hpost : l + 1 < layerInputs.size then layerInputs[l + 1]'hpost else residual
      let layerInput := model.applyLn2 l postAttn
      for f_idx in [:sae.numFeatures] do
        match computeSAEFeatureImportance sae l f_idx layerInput 0.0 with
        | some imp =>
          result := result.push {
            component := imp.component
            valueTermNorm := imp.valueTermNorm
            patternTermBound := imp.patternTermBound
          }
        | none => pure ()

  result

/-- Discover minimal SAE circuit using greedy pruning.

Algorithm:
1. Start with full circuit (all heads, all features)
2. Compute base reconstruction error (constant for given SAEs)
3. Sort all components by valueTermNorm (ascending)
4. Iteratively remove smallest-impact components until error exceeds threshold

The threshold should account for SAE reconstruction error as a baseline.
-/
def discoverSAECircuit (model : SAEEnhancedModel) (threshold : Float) : SAECircuit := Id.run do
  -- Initialize full circuit
  let headsPerLayer := model.layers.map (·.size)
  let featuresPerLayer := model.saes.map (·.numFeatures)
  let mut circuit := SAECircuit.full model.numLayers headsPerLayer featuresPerLayer

  -- Compute base reconstruction error
  let baseError := estimateSAECircuitFaithfulness model circuit
  let reconError := baseError.reconstructionError

  -- If base error already exceeds threshold, return empty
  if reconError > threshold then
    return SAECircuit.empty model.numLayers headsPerLayer featuresPerLayer

  let adjustedThreshold := threshold - reconError

  -- Get all components sorted by valueTermNorm (ascending)
  let allImportance := computeAllSAEImportance model
  let sorted := allImportance.qsort fun a b => a.valueTermNorm < b.valueTermNorm

  let mut cumulativePatternError : Float := baseError.patternTermError
  let mut cumulativeAblationError : Float := 0.0

  -- Greedily remove components
  for comp in sorted do
    -- Removing this component:
    -- - Removes its patternTermBound from included error
    -- - Adds its valueTermNorm to ablation error
    let newPatternError := cumulativePatternError - comp.patternTermBound
    let newAblationError := cumulativeAblationError + comp.valueTermNorm
    let newTotalError := newPatternError + newAblationError

    if newTotalError > adjustedThreshold then
      -- Stop: removing this component would exceed threshold
      break

    -- Remove the component
    circuit := circuit.removeComponent comp.component
    cumulativePatternError := newPatternError
    cumulativeAblationError := newAblationError

  circuit

/-- Result of SAE circuit discovery. -/
structure SAEDiscoveryResult where
  /-- The discovered circuit -/
  circuit : SAECircuit
  /-- Error breakdown -/
  error : SAECircuitError
  /-- Compression ratio (components kept / total) -/
  compressionRatio : Float

/-- Discover SAE circuit with full result details.

Runs discovery, then re-estimates the error of the final circuit (the incremental
error tracked during pruning is not reused). -/
def discoverSAECircuitWithResult (model : SAEEnhancedModel)
    (threshold : Float) : SAEDiscoveryResult := Id.run do
  let circuit := discoverSAECircuit model threshold
  let error := estimateSAECircuitFaithfulness model circuit
  let totalComponents := circuit.totalHeads + circuit.totalFeatures
  let includedComponents := circuit.countIncludedHeads + circuit.countIncludedFeatures
  let compression := if totalComponents > 0 then
    includedComponents.toFloat / totalComponents.toFloat
  else 1.0

  { circuit, error, compressionRatio := compression }

/-! ## Ablated Forward Pass (Causal Intervention)

These functions implement executable causal interventions: running a forward pass
where specific attention heads or MLP neurons are masked out (ablated) based on
a `ConcreteCircuit` mask.

This bridges the theoretical `Abstraction.lean` bounds with empirical verification:
- `runAblatedForward`: Execute the circuit (masked forward pass)
- `computeAblationDiscrepancy`: Measure actual difference from full model
- `verifyCircuitFaithfulness`: Assert empirical ≤ theoretical bound
-/

/-- Forward pass for an MLP layer with neuron-level ablation.

Same as `ConcreteMLPLayer.forward` but with a mask specifying which neurons are active.
Inactive neurons have their contributions zeroed out.

For neuron i:
- If included: contribute GeLU(W_in[:,i] · x + b_in[i]) * W_out[i,:]
- If excluded: contribute 0

NOTE(review): a missing mask entry defaults to `true` (included) here, whereas
`ConcreteCircuit.isNeuronIncluded` defaults out-of-range lookups to `false`
(excluded). An under-sized or empty `neuronMask` therefore runs the full MLP —
confirm this asymmetry is intended.
-/
def ConcreteMLPLayer.forwardAblated (layer : ConcreteMLPLayer) (input : ConcreteMatrix)
    (neuronMask : Array Bool) : ConcreteMatrix :=
  -- hidden = input · W_in + b_in (seqLen × hiddenDim)
  let hidden := (input.matmul layer.W_in).addBias layer.b_in
  -- Apply GeLU activation with masking using .ofFn for proper size proof
  let activated : ConcreteMatrix := {
    numRows := hidden.numRows
    numCols := hidden.numCols
    data := .ofFn fun idx : Fin (hidden.numRows * hidden.numCols) =>
      -- Row-major layout: column index = flat index mod numCols.
      let j := idx.val % hidden.numCols
      let val := hidden.data[idx.val]!
      let act := geluFloat val
      -- Zero out if neuron j is not included
      if neuronMask.getD j true then act else 0.0
    size_eq := Array.size_ofFn
  }
  -- output = activated · W_out + b_out (seqLen × modelDim)
  (activated.matmul layer.W_out).addBias layer.b_out

/-- Run an ablated forward pass through the model.

Like `runForward`, but with a `ConcreteCircuit` mask that specifies which attention
heads and MLP neurons are active. Excluded components have their contributions
zeroed out, implementing a causal intervention.

This enables **empirical validation** of theoretical circuit bounds:
1. Discover a circuit via `discoverCircuit` or `discoverTargetedCircuit`
2. Run `runAblatedForward` with that circuit
3. Compare to `runForward` to measure actual discrepancy
4. Verify that empirical discrepancy ≤ theoretical bound

**Ablation semantics:**
- Excluded attention head: its output is zero (does not contribute to residual)
- Excluded MLP neuron: its activation is zero (does not contribute to FFN output)
-/
def ConcreteModel.runAblatedForward (model : ConcreteModel) (circuit : ConcreteCircuit)
    (causal : Bool := true) : ForwardPassResult := Id.run do
  let mut layerInputs : Array ConcreteMatrix := Array.mkEmpty (model.numLayers + 1)
  let mut postAttnResiduals : Array ConcreteMatrix := Array.mkEmpty model.numLayers
  let mut attnOutputs : Array (Array ConcreteMatrix) := Array.mkEmpty model.numLayers
  let mut mlpOutputs : Array ConcreteMatrix := Array.mkEmpty model.numLayers
  let mut residual := model.inputEmbeddings
  layerInputs := layerInputs.push residual

  for l in [:model.numLayers] do
    -- Pre-LN: attention sees ln_1(residual)
    let attnInput := model.applyLn1 l residual
    -- Compute attention outputs for included heads only
    let mut layerAttnOutputs : Array ConcreteMatrix := #[]
    let mut includedHeadOutputs : Array ConcreteMatrix := #[]
    let rows := residual.numRows
    let cols := residual.numCols
    let zeroOutput := ConcreteMatrix.zeros rows cols

    if hl : l < model.layers.size then
      let layerHeads := model.layers[l]
      let includedMask := circuit.includedHeads.getD l #[]
      let includedCount :=
        includedMask.foldl (fun acc b => if b then acc + 1 else acc) 0
      -- Only spawn tasks when there is enough work to amortize the overhead.
      let useParallelHeads :=
        layerHeads.size >= 4 && includedCount >= 4
      layerAttnOutputs :=
        if useParallelHeads then
          let tasks : Array (Task ConcreteMatrix) :=
            .ofFn fun i : Fin layerHeads.size =>
              Task.spawn (fun _ =>
                if circuit.isHeadIncluded l i.val then
                  (layerHeads[i]).forward attnInput causal
                else
                  zeroOutput)
          tasks.map Task.get
        else
          Id.run do
            let mut outs : Array ConcreteMatrix := Array.mkEmpty layerHeads.size
            for h in [:layerHeads.size] do
              if hh : h < layerHeads.size then
                let head := layerHeads[h]'hh
                if circuit.isHeadIncluded l h then
                  outs := outs.push (head.forward attnInput causal)
                else
                  outs := outs.push zeroOutput
            return outs
      -- Preserve the original summation order: increasing head index.
      includedHeadOutputs :=
        Id.run do
          let mut outs : Array ConcreteMatrix := Array.mkEmpty includedCount
          for h in [:layerAttnOutputs.size] do
            if circuit.isHeadIncluded l h then
              if hh : h < layerAttnOutputs.size then
                outs := outs.push (layerAttnOutputs[h]'hh)
          return outs

    attnOutputs := attnOutputs.push layerAttnOutputs

    -- Add attention residual
    let attnBias := model.attnProjBiasAt l
    let attnSum :=
      if includedHeadOutputs.isEmpty then
        ConcreteMatrix.zeros rows cols
      else
        ConcreteMatrix.sumMatrices includedHeadOutputs
    let residualAfterAttn := residual.add (attnSum.addBias attnBias)
    postAttnResiduals := postAttnResiduals.push residualAfterAttn

    -- Compute MLP output with neuron-level ablation
    -- Pre-LN: MLP sees ln_2(residualAfterAttn)
    let mlpInput := model.applyLn2 l residualAfterAttn
    let mlpOut :=
      if hm : l < model.mlps.size then
        -- Get neuron mask for this layer
        let neuronMask := circuit.includedNeurons.getD l #[]
        model.mlps[l].forwardAblated mlpInput neuronMask
      else ConcreteMatrix.zeros residual.numRows residual.numCols

    mlpOutputs := mlpOutputs.push mlpOut

    -- Add MLP residual
    residual := residualAfterAttn.add mlpOut

    -- Store input for next layer
    layerInputs := layerInputs.push residual

  let finalOutput := model.applyLnf residual
  {
    layerInputs := layerInputs
    postAttnResiduals := postAttnResiduals
    attnOutputs := attnOutputs
    mlpOutputs := mlpOutputs
    -- Placeholders: activation derivatives are not computed in the ablated pass.
    mlpActDeriv := Array.replicate model.numLayers (ConcreteMatrix.zeros 0 0)
    mlpActDerivMax := Array.replicate model.numLayers 0.0
    finalOutput := finalOutput
  }

/-!
### Empirical Discrepancy and Verification - -These functions compute the actual difference between full and ablated model outputs, -enabling empirical validation of theoretical circuit bounds. --/ - -/-- Compute the element-wise difference between two matrices. -/ -def ConcreteMatrix.sub (A B : ConcreteMatrix) : ConcreteMatrix := - if A.numRows = B.numRows ∧ A.numCols = B.numCols then - { - numRows := A.numRows - numCols := A.numCols - data := .ofFn fun idx : Fin (A.numRows * A.numCols) => - A.data.getD idx.val 0.0 - B.data.getD idx.val 0.0 - size_eq := Array.size_ofFn - } - else ConcreteMatrix.zeros 0 0 - -/-- Result of comparing full model output to ablated circuit output. - -This captures the empirical discrepancy between running the full model -and running only the discovered circuit. --/ -structure AblationResult where - /-- Full model output (residual stream after all layers) -/ - fullOutput : ConcreteMatrix - /-- Ablated circuit output -/ - ablatedOutput : ConcreteMatrix - /-- Difference: fullOutput - ablatedOutput -/ - difference : ConcreteMatrix - /-- Frobenius norm of the difference: ‖full - ablated‖_F -/ - empiricalError : Float - /-- Relative error: ‖full - ablated‖_F / ‖full‖_F -/ - relativeError : Float - /-- Number of components in the circuit -/ - circuitSize : Nat - /-- Total number of components in the model -/ - totalComponents : Nat - -namespace AblationResult - -/-- Compute compression ratio: what fraction of components are included. -/ -def compressionRatio (r : AblationResult) : Float := - if r.totalComponents > 0 then - r.circuitSize.toFloat / r.totalComponents.toFloat - else 1.0 - -/-- Pretty print the ablation result. -/ -def toString (r : AblationResult) : String := - s!"AblationResult:\n" ++ - s!" Empirical Error (‖Δ‖_F): {r.empiricalError}\n" ++ - s!" Relative Error: {r.relativeError * 100.0}%\n" ++ - s!" 
Circuit Size: {r.circuitSize}/{r.totalComponents} " ++ - s!"({r.compressionRatio * 100.0}%)" - -instance : ToString AblationResult := ⟨AblationResult.toString⟩ - -end AblationResult - -/-- Compute the empirical discrepancy between full model and ablated circuit. - -This is the core function for empirical validation: -1. Runs full forward pass -2. Runs ablated forward pass with the circuit mask -3. Computes the difference and its Frobenius norm - -The empirical error should be bounded by the theoretical error estimate from -`estimateCircuitFaithfulness`. --/ -def computeAblationDiscrepancy (model : ConcreteModel) (circuit : ConcreteCircuit) - (causal : Bool := true) : AblationResult := - let fullResult := model.runForward causal - let ablatedResult := model.runAblatedForward circuit causal - let diff := fullResult.finalOutput.sub ablatedResult.finalOutput - let empiricalErr := diff.frobeniusNorm - let fullNorm := fullResult.finalOutput.frobeniusNorm - let relErr := if fullNorm > 1e-10 then empiricalErr / fullNorm else 0.0 - { - fullOutput := fullResult.finalOutput - ablatedOutput := ablatedResult.finalOutput - difference := diff - empiricalError := empiricalErr - relativeError := relErr - circuitSize := circuit.countIncluded - totalComponents := circuit.totalComponents - } - -/-- Result of comparing empirical discrepancy to theoretical bound. - -This is the verification bridge between `Abstraction.lean` (theory) and -`Discovery.lean` (practice). --/ -structure VerificationResult where - /-- Ablation result with empirical measurements -/ - ablation : AblationResult - /-- Theoretical error bound from circuit analysis -/ - theoreticalBound : Float - /-- Whether empirical ≤ theoretical (verification passed) -/ - verified : Bool - /-- Slack: theoretical - empirical (how much margin we have) -/ - slack : Float - /-- Tightness ratio: empirical / theoretical -/ - tightness : Float - -namespace VerificationResult - -/-- Pretty print the verification result. 
-/ -def toString (r : VerificationResult) : String := - let status := if r.verified then "✓ VERIFIED" else "✗ FAILED" - s!"VerificationResult [{status}]\n" ++ - s!" Empirical Error: {r.ablation.empiricalError}\n" ++ - s!" Theoretical Bound: {r.theoreticalBound}\n" ++ - s!" Slack: {r.slack}\n" ++ - s!" Tightness: {r.tightness * 100.0}%\n" ++ - s!" Circuit: {r.ablation.circuitSize}/{r.ablation.totalComponents} components" - -instance : ToString VerificationResult := ⟨VerificationResult.toString⟩ - -end VerificationResult - -/-- Verify that a discovered circuit's empirical error is within theoretical bounds. - -This is the key function that **closes the loop** between theory and practice: - -**Input:** -- A model -- A discovered circuit (from `discoverCircuit` or `discoverTargetedCircuit`) -- The theoretical error bound (from `CircuitError.totalError`) - -**Output:** -- Verification result showing whether empirical ≤ theoretical - -**Usage:** -``` -let result := discoverCircuit model 0.1 -let verification := verifyCircuitFaithfulness model result.circuit result.error.totalError -if verification.verified then - IO.println "Circuit is empirically faithful!" 
-``` - -**Interpretation:** -- `verified = true`: The circuit recapitulates model behavior within bounds -- `tightness ≈ 1`: Theoretical bound is tight (good analysis) -- `tightness << 1`: Theoretical bound is loose (conservative) --/ -def verifyCircuitFaithfulness (model : ConcreteModel) (circuit : ConcreteCircuit) - (theoreticalBound : Float) (causal : Bool := true) : VerificationResult := - let ablation := computeAblationDiscrepancy model circuit causal - let verified := ablation.empiricalError ≤ theoreticalBound - let slack := theoreticalBound - ablation.empiricalError - let tightness := if theoreticalBound > 1e-10 - then ablation.empiricalError / theoreticalBound - else 1.0 - { - ablation := ablation - theoreticalBound := theoreticalBound - verified := verified - slack := slack - tightness := tightness - } - -/-! ### Component Importance Metrics -/ - -/-- Importance metrics for a single component. - -These metrics allow ranking components by their contribution to model behavior, -enabling principled circuit pruning. --/ -structure ComponentImportance where - /-- Component identifier -/ - component : ComponentId - /-- Value term norm: ‖A‖_F · ‖W_V·W_O‖_F (how much information flows through) -/ - valueTermNorm : Float - /-- Pattern term bound (approximation error if we trust attention patterns) -/ - patternTermBound : Float - /-- Faithfulness ratio: patternBound / valueNorm -/ - faithfulnessRatio : Float - -namespace ComponentImportance - -/-- Pretty print component importance. -/ -def toString (imp : ComponentImportance) : String := - s!"{imp.component}: value={imp.valueTermNorm}, pattern={imp.patternTermBound}, " ++ - s!"ratio={imp.faithfulnessRatio}" - -instance : ToString ComponentImportance := ⟨ComponentImportance.toString⟩ - -end ComponentImportance - -/-- Compute importance metrics for a single attention head. 
-/ -def computeHeadImportance (model : ConcreteModel) (layerIdx headIdx : Nat) - (layerInput : ConcreteMatrix) : Option ComponentImportance := do - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let attnInput := model.applyLn1 layerIdx layerInput - let attn := head.computeAttentionWeights attnInput - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - let ln1Bound := model.ln1OpBound layerIdx layerInput - let bnds := head.noDenseProductBounds - - -- Pre-LN: effective value path includes the LayerNorm Jacobian. - let valueNorm := ln1Bound * computeValueTermNorm attn bnds.voFrobNormSq - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternBound := ln1Bound * computePatternTermBound inputs - let ratio := if valueNorm < 1e-10 then Float.inf else patternBound / valueNorm - - return { - component := ComponentId.head layerIdx headIdx - valueTermNorm := valueNorm - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for a single MLP neuron. - -**Simple Version (no forward pass):** -For ReLU/GeLU MLPs, this uses weight-based bounds only. -The influence magnitude = ‖W_in[:,i]‖ · ‖W_out[i,:]‖ bounds information flow. - -Pattern term is set to 0 (assumes locally linear), which is **unsound** if -ablations cause activation flips. Use `computeNeuronImportanceIBP` with -forward pass data for rigorous bounds. 
--/ -def computeNeuronImportance (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (_inputNorm : Float) : Option ComponentImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - let influence := mlp.neuronInfluence neuronIdx - -- Conservative assumption: locally linear (pattern term = 0) - -- WARNING: This is unsound if activation flips occur! - let patternBound : Float := 0.0 - let ratio := if influence < 1e-10 then Float.inf else patternBound / influence - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for a single MLP neuron using IBP. - -**Sound Version (requires forward pass):** -Uses Interval Bound Propagation to detect neurons that may flip activation -states under input perturbations. Provides mathematically rigorous pattern -term bounds. - -**Parameters:** -- `layerInput`: Input to this layer (from forward pass), used to compute - nominal pre-activations -- `perturbationNorm`: L2 bound on how much ablations can change the input - (typically computed from the ablated components' value terms) - -**Returns:** ComponentImportance with rigorous pattern term bound that -accounts for potential activation flips. 
--/ -def computeNeuronImportanceIBP (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (layerInput : ConcreteMatrix) (perturbationNorm : Float) : Option ComponentImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - let influence := mlp.neuronInfluence neuronIdx - - -- Average the IBP bound across sequence positions - let patternBound := Id.run do - let mut totalPatternBound : Float := 0.0 - let numPositions := layerInput.numRows - - for pos in [:numPositions] do - -- Extract input vector at this position - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe pos d.val - -- Compute IBP-based pattern term bound - let posBound := mlp.neuronPatternTermBoundIBP neuronIdx inputVec perturbationNorm - totalPatternBound := totalPatternBound + posBound * posBound - - -- RMS of per-position bounds - Float.sqrt (totalPatternBound / (max 1 numPositions).toFloat) - - let ratio := if influence < 1e-10 then Float.inf else patternBound / influence - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - faithfulnessRatio := ratio - } - else none - else none - -/-- Compute importance metrics for all components in a model. 
-/ -def computeAllImportance (model : ConcreteModel) : Array ComponentImportance := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - - let computeHeadsForLayer (l : Nat) : Array ComponentImportance := Id.run do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let mut outs : Array ComponentImportance := Array.mkEmpty layerHeads.size - for h_idx in [:layerHeads.size] do - match computeHeadImportance model l h_idx model.inputEmbeddings with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let computeNeuronsForLayer (l : Nat) : Array ComponentImportance := Id.run do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - let mut outs : Array ComponentImportance := Array.mkEmpty mlp.hiddenDim - for n_idx in [:mlp.hiddenDim] do - match computeNeuronImportance model l n_idx inputNorm with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let useParallel := model.numLayers >= 4 - let layerPairs : Array (Array ComponentImportance × Array ComponentImportance) := - if useParallel then - let tasks : Array (Task (Array ComponentImportance × Array ComponentImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => (computeHeadsForLayer i.val, computeNeuronsForLayer i.val)) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - (computeHeadsForLayer i.val, computeNeuronsForLayer i.val) - - let headChunks : Array (Array ComponentImportance) := layerPairs.map (·.1) - let neuronChunks : Array (Array ComponentImportance) := layerPairs.map (·.2) - - -- Join in the same order as the original loop: heads then neurons, increasing layer index. 
- let totalHeads := headChunks.foldl (fun acc cs => acc + cs.size) 0 - let totalNeurons := neuronChunks.foldl (fun acc cs => acc + cs.size) 0 - let mut result : Array ComponentImportance := Array.mkEmpty (totalHeads + totalNeurons) - for cs in headChunks do - for c in cs do - result := result.push c - for cs in neuronChunks do - for c in cs do - result := result.push c - result - -/-! ### Target-Aware Circuit Discovery (Logit Lens) - -Standard circuit discovery finds components that contribute to the *entire* residual stream, -yielding large generic circuits. **Target-aware discovery** instead finds minimal circuits -for *specific predictions* by projecting component outputs onto a target direction. - -**Key Insight:** For a specific prediction (e.g., "the model predicts 'cat' not 'dog'"), -we define a target direction `u = W_U[cat] - W_U[dog]` in the residual stream space. -A component's importance is then `‖output · u‖` rather than `‖output‖_F`. - -This enables: -- Finding circuits responsible for specific token predictions -- Isolating mechanisms for behavioral differences (e.g., IOI task) -- Producing smaller, more interpretable circuits - -**Mathematical Formulation:** -For attention head with value-output projection `W_V · W_O`: -- Standard importance: `‖A‖_F · ‖W_V · W_O‖_F` (generic) -- Target-aware: `‖A‖_F · ‖(W_V · W_O) · u‖` (specific) - -For MLP neuron with output weights `W_out[i,:]`: -- Standard importance: `‖W_in[:,i]‖ · ‖W_out[i,:]‖` (generic) -- Target-aware: `‖W_in[:,i]‖ · |W_out[i,:] · u|` (specific) --/ - -/-- A target direction for focused circuit discovery. - -Specifies a direction in residual stream space to project component outputs onto. -Typically constructed as `u = W_U[correct_token] - W_U[incorrect_token]`. 
--/ -structure TargetDirection where - /-- The target vector in model dimension space (modelDim × 1 matrix) -/ - direction : ConcreteMatrix - /-- Human-readable description of what this direction represents -/ - description : String := "target" - -namespace TargetDirection - -/-- Create a target direction from unembedding columns for two tokens. - -`u = W_U[:, correctToken] - W_U[:, incorrectToken]` - -This direction points from the incorrect prediction toward the correct one, -so components with positive projection increase P(correct) / P(incorrect). --/ -def fromLogitDiff (unembedding : ConcreteMatrix) - (correctToken incorrectToken : Nat) : TargetDirection := - let correctCol := unembedding.getCol correctToken - let incorrectCol := unembedding.getCol incorrectToken - let direction := correctCol.vecSub incorrectCol - { - direction := direction - description := s!"logit_diff({correctToken}-{incorrectToken})" - } - -/-- Create a target direction from a single token's unembedding. - -Useful when you want to understand what promotes a specific token. --/ -def fromSingleToken (unembedding : ConcreteMatrix) (token : Nat) : TargetDirection := - { - direction := unembedding.getCol token - description := s!"logit({token})" - } - -/-- Normalize the target direction to unit length. -/ -def normalize (t : TargetDirection) : TargetDirection := - let norm := t.direction.vecNorm - if norm > 1e-10 then - { t with direction := t.direction.scale (1.0 / norm) } - else t - -/-- Construct a next-token logit-difference direction from the model's input token history. - -This is the **self-supervised induction target**: -let `T` be the ground-truth token sequence, and let `t_curr = T[last]`. -If `t_curr` appeared before at index `k`, the "induction target" is `t_next = T[k+1]`. - -Returns `none` if: -- the model has no `inputTokens`, -- the sequence has no previous occurrence of `t_curr`, -- or the model is missing an `unembedding` matrix. 
--/ -def fromInductionHistory (model : ConcreteModel) : Option TargetDirection := do - let tokens ← model.inputTokens - if tokens.size = 0 then none else - let lastIdx := tokens.size - 1 - let tCurr := tokens[lastIdx]! - let mut foundIdx : Option Nat := none - for offset in [:lastIdx] do - if foundIdx.isNone then - let idx := lastIdx - 1 - offset - if tokens[idx]! = tCurr then - foundIdx := some idx - - let k ← foundIdx - let tNext := tokens[k + 1]! - - let W_U ← model.unembedding - let vocabSize := W_U.numCols - if vocabSize < 2 then none - else if tNext ≥ vocabSize then none - else - let incorrect : Nat := - if tCurr < vocabSize ∧ tCurr ≠ tNext then tCurr - else - let cand1 := (tNext + 1) % vocabSize - if cand1 ≠ tNext then cand1 else (tNext + 2) % vocabSize - if incorrect = tNext then none - else - let base := TargetDirection.fromLogitDiff W_U tNext incorrect - some { base with - description := s!"induction_history(curr={tCurr}, prev={k}, \ - next={tNext}, neg={incorrect})" - } - -end TargetDirection - -/-! ## Virtual-Head Effectiveness Verification -/ - -/-- Extremely generous cutoff to reject only egregious **interpretability illusions**. - -In theory, a mechanism is "genuine" when its relative approximation error is < 1.0 -(`isGenuineMechanism` / `mechanism_trichotomy` in `Nfp.Linearization`). In practice, the -executable Frobenius-norm bounds can be loose in high dimensions, making the strict < 1.0 -test numerically vacuous. Empirically, however, massive errors indicate clear illusions. - -We therefore filter only astronomically large `combinedError` values while ranking by -faithfulness (smallest `combinedError` first). --/ -def egregiousIllusionThreshold : Float := 1.0e30 - -/-- Egregious-illusion filter (currently disabled). - -We intentionally keep *all* candidates (even those with extremely loose bounds), since -the Frobenius-norm estimates can scale poorly with depth and dimension. 
--/ -def passesEgregiousIllusionFilter (_candidate : CandidateInductionHead) : Bool := - true - -/-- Compute the raw **direct** effectiveness score `δ` for an induction-head candidate. - -Induction heads are primarily a **pattern** story (Q/K-composition): a "previous token" head -enables the *induction head* to attend to the **successor** of a previous matching token. -Once the induction head is attending to the right source positions, its functional effect is -driven by the head's **OV circuit** applied to the residual stream at those source tokens. - -Accordingly, we treat head 1 purely as a *pattern enabler* (enforced by the pattern filters) -and score only the **direct value path** of head 2: - -For a **Pre-LayerNorm** transformer (GPT-2 style), attention reads from `ln₁(X₂)` (not `X₂`). -We compute: - -`Y = A₂ · ln₁(X₂) · W₂`, - -then score the last position against `target.direction = u`: - -`δ = ⟪Y[last], u⟫`. --/ -def computeInductionEffectiveness - (candidate : CandidateInductionHead) - (cache : PrecomputedCache) - (layer2Ln1Input : ConcreteMatrix) - (target : TargetDirection) : Float := - match cache.getHeadData candidate.layer2Idx candidate.head2Idx with - | some data2 => - let u := target.direction - - -- Direct head score: δ = ⟪(A₂ · ln₁(X₂) · W₂)[last], u⟫. - -- - -- PERFORMANCE: compute this scalar without materializing any dense `d×d` products. 
- -- 1) v = (W_V·W_O) · u = W_V · (W_O · u) - -- 2) xdot[k] = ⟪ln₁(X₂)[k], v⟫ - -- 3) δ = ⟪A₂[last], xdot⟫ - if u.numCols ≠ 1 then 0.0 - else if data2.attention.seqLen = 0 then 0.0 - else if layer2Ln1Input.numRows ≠ data2.attention.seqLen then 0.0 - else - if hl : candidate.layer2Idx < cache.model.layers.size then - let layerHeads := cache.model.layers[candidate.layer2Idx]'hl - if hh : candidate.head2Idx < layerHeads.size then - let head2 := layerHeads[candidate.head2Idx]'hh - if layer2Ln1Input.numCols ≠ u.numRows then 0.0 - else - let n := data2.attention.seqLen - let lastPos := n - 1 - - -- v = W_V · (W_O · u) - let t := head2.W_O.matVecMul u - let v := head2.W_V.matVecMul t - if v.numRows = 0 then 0.0 - else - -- xdot[k] = ⟪ln₁(X₂)[k], v⟫ - let xdot : Array Float := .ofFn fun k : Fin n => Id.run do - let xBase := k.val * layer2Ln1Input.numCols - let mut acc : Float := 0.0 - for c in [:layer2Ln1Input.numCols] do - -- SAFETY: `k < n = layer2Ln1Input.numRows` by construction and guard above. - let x := layer2Ln1Input.data[xBase + c]! - -- SAFETY: `v` is `modelDim×1` so `c < v.data.size`. - let vc := v.data[c]! - acc := acc + x * vc - return acc - - -- δ = ⟪A2[last], xdot⟫ - Id.run do - let w2 := data2.attention.weights - let row2Base := lastPos * n - let mut score : Float := 0.0 - for k in [:n] do - -- SAFETY: `w2` has size `n*n` and `row2Base + k < n*n`. - score := score + w2[row2Base + k]! * xdot[k]! - return score - else - 0.0 - else - 0.0 - | none => 0.0 - -/-- Compute the **certified lower bound** from `true_induction_head_predicts_logits`. - -`LowerBound = δ - (ε · ‖X‖_F · ‖u‖₂)` where: -- `δ` is the virtual effectiveness score, -- `ε` is `candidate.combinedError`, -- `X` is the layer-2 Pre-LN attention input matrix `ln₁(X₂)`, -- `u` is the target direction. 
--/ -def computeCertifiedLowerBound - (delta : Float) - (candidate : CandidateInductionHead) - (layer2Ln1Input : ConcreteMatrix) - (target : TargetDirection) : Float := - delta - (candidate.combinedError * layer2Ln1Input.frobeniusNorm * target.direction.vecNorm) - -/-- Rank induction-head candidates by a **mechanism-first** score. - -We compute the raw direct-effect score `δ`, normalize it to a scale-invariant `effect`, -and also compute a prompt/weight-based mechanism score: - -`mechScore = kComp · inductionScore · prevTokenStrength`, - -as well as the combined score: - -`circuitScore = effect · mechScore`. - -Here: -- `effect = δ / (‖ln₁(X₂)‖_F · ‖u‖₂)` removes the Pre-LN depth confounder where the residual - stream norm grows with depth, and -- `kComp` (from the circuits framework paper) measures how strongly head 1 can feed into - head 2's **QK circuit**, i.e. whether head 1 plausibly acts as a *pattern enabler* for head 2. -- `inductionScore` is the prompt-dependent "copy-next" attention score for head 2, and -- `prevTokenStrength` is the prompt-dependent previous-token attention score for head 1. - -We rank primarily by `mechScore` (to identify the canonical induction mechanism) and use -`effect` only as a secondary key. - -We still compute and report `combinedError` for inspection, but avoid using it as a primary -ranking key since Frobenius-norm bounds can be systematically looser in high dimensions. - -Uses a `PrecomputedCache` so attention patterns/projections and layer inputs are computed once. 
--/ -def findHeuristicInductionHeadsWithCache (model : ConcreteModel) - (target : TargetDirection) - (minEffect : Float := 0.0) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) - (buildLayerNormBounds : Bool := true) - (layerNormEffort : ConcreteMatrix.BoundEffort := ConcreteMatrix.BoundEffort.tier1) - (storeDiagnostics : Bool := false) : - (Array HeuristicInductionHead × PrecomputedCache) := - Id.run do - let cache := PrecomputedCache.build model (computeLayerNormBounds := buildLayerNormBounds) - (layerNormEffort := layerNormEffort) - (storeDiagnostics := storeDiagnostics) - let targetNorm := target.direction.vecNorm - -- Precompute layer input norms once (per-layer, not per-candidate). - let layerInputNorms : Array Float := - .ofFn fun i : Fin model.numLayers => - (cache.forwardResult.getLayerInput i.val).frobeniusNorm - let ln1InputNorms : Array Float := - .ofFn fun i : Fin model.numLayers => - (cache.getLn1Input i.val).frobeniusNorm - let candidates := - findInductionHeadCandidatesFromCache cache minPrevTokenStrength minInductionScore - let mut certified : Array HeuristicInductionHead := #[] - - for candidate in candidates do - if passesEgregiousIllusionFilter candidate then - let layer2InputNorm := layerInputNorms.getD candidate.layer2Idx 0.0 - let layer2Ln1Input := cache.getLn1Input candidate.layer2Idx - let layer2Ln1InputNorm := ln1InputNorms.getD candidate.layer2Idx 0.0 - let delta := computeInductionEffectiveness candidate cache layer2Ln1Input target - let denom := layer2Ln1InputNorm * targetNorm - let effect := - if denom > 1e-10 then delta / denom else 0.0 - if effect > minEffect then - certified := certified.push { - candidate := candidate - delta := delta - effect := effect - layer2InputNorm := layer2InputNorm - layer2Ln1InputNorm := layer2Ln1InputNorm - } - - let certifiedSorted := - certified.qsort (fun a b => - -- Primary key: higher **mechanism score** first. 
- -- - -- Induction heads are primarily defined by attention-pattern structure (copy-next) - -- plus K-composition with a previous-token head. Target-direction Effect is useful, - -- but prompt/target-dependent; we therefore use it only as a secondary key. - let sa := - if Float.isNaN a.effect ∨ Float.isNaN a.candidate.kComp ∨ - Float.isNaN a.candidate.inductionScore ∨ - Float.isNaN a.candidate.prevTokenStrength then - (-Float.inf) - else - a.candidate.kComp * a.candidate.inductionScore * a.candidate.prevTokenStrength - let sb := - if Float.isNaN b.effect ∨ Float.isNaN b.candidate.kComp ∨ - Float.isNaN b.candidate.inductionScore ∨ - Float.isNaN b.candidate.prevTokenStrength then - (-Float.inf) - else - b.candidate.kComp * b.candidate.inductionScore * b.candidate.prevTokenStrength - if sb < sa then true - else if sa < sb then false - else - -- Secondary key: higher normalized Effect first. - let ea := if Float.isNaN a.effect then (-Float.inf) else a.effect - let eb := if Float.isNaN b.effect then (-Float.inf) else b.effect - if eb < ea then true - else if ea < eb then false - else - -- Tertiary key: higher K-composition first. - let ka := if Float.isNaN a.candidate.kComp then (-Float.inf) else a.candidate.kComp - let kb := if Float.isNaN b.candidate.kComp then (-Float.inf) else b.candidate.kComp - if kb < ka then true - else if ka < kb then false - else - -- Next key: higher raw δ first. - let δa := if Float.isNaN a.delta then (-Float.inf) else a.delta - let δb := if Float.isNaN b.delta then (-Float.inf) else b.delta - if δb < δa then true - else if δa < δb then false - else - -- Final key: smaller relative-error bound first. - let εa := - if Float.isNaN a.candidate.combinedError then - Float.inf - else - a.candidate.combinedError - let εb := - if Float.isNaN b.candidate.combinedError then - Float.inf - else - b.candidate.combinedError - εa < εb) - - return (certifiedSorted, cache) - -/-- Rank induction-head candidates by a **mechanism-first** score. 
- -This is the same as `findHeuristicInductionHeadsWithCache`, but discards the cache. --/ -def findHeuristicInductionHeads (model : ConcreteModel) - (target : TargetDirection) - (minEffect : Float := 0.0) - (minPrevTokenStrength : Float := 0.1) - (minInductionScore : Float := 0.05) : Array HeuristicInductionHead := Id.run do - (findHeuristicInductionHeadsWithCache model target minEffect - (minPrevTokenStrength := minPrevTokenStrength) - (minInductionScore := minInductionScore)).1 - -/-- Target-aware importance metrics for a component. - -Like `ComponentImportance` but with an additional field measuring -projection onto the target direction. --/ -structure TargetAwareImportance where - /-- Component identifier -/ - component : ComponentId - /-- Standard value term norm (generic importance) -/ - valueTermNorm : Float - /-- Pattern term bound -/ - patternTermBound : Float - /-- **Target projection**: how much this component contributes to the target direction. - For heads: `‖(W_V · W_O) · u‖` - For neurons: `|W_out[i,:] · u|` -/ - targetProjection : Float - /-- Faithfulness ratio for target: patternBound / targetProjection -/ - targetFaithfulnessRatio : Float - -namespace TargetAwareImportance - -/-- Pretty print target-aware importance. -/ -def toString (imp : TargetAwareImportance) : String := - s!"{imp.component}: target={imp.targetProjection}, generic={imp.valueTermNorm}, " ++ - s!"pattern={imp.patternTermBound}, ratio={imp.targetFaithfulnessRatio}" - -instance : ToString TargetAwareImportance := ⟨TargetAwareImportance.toString⟩ - -end TargetAwareImportance - -/-- Compute target-aware importance for a single attention head. - -The target projection is computed as `‖(W_V · W_O) · u‖` where u is the target direction. -This measures how much the head's output aligns with the target direction. 
--/ -def computeHeadTargetImportance (model : ConcreteModel) (layerIdx headIdx : Nat) - (inputNorm : Float) (target : TargetDirection) : Option TargetAwareImportance := do - let attn ← model.computeAttention layerIdx headIdx - if h1 : layerIdx < model.layers.size then - let layerHeads := model.layers[layerIdx] - if h2 : headIdx < layerHeads.size then - let head := layerHeads[headIdx] - let bnds := head.noDenseProductBounds - - -- Standard metrics - let valueNorm := computeValueTermNorm attn bnds.voFrobNormSq - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - -- Conservative: `‖X‖₂ ≤ ‖X‖_F`. - inputOpBound := inputNorm - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - let patternBound := computePatternTermBound inputs - - -- Target-aware: compute ‖(W_V · W_O) · u‖ via low-rank matvecs: - -- (W_V·W_O)·u = W_V·(W_O·u). - let t := head.W_O.matVecMul target.direction - let projectedVec := head.W_V.matVecMul t - let targetProj := projectedVec.vecNorm - - -- Scale by attention norm (as in standard valueTermNorm) - let attnNormSq := attn.weights.foldl (fun acc x => acc + x * x) 0.0 - let attnNorm := Float.sqrt attnNormSq - let targetImportance := attnNorm * targetProj - - let ratio := if targetImportance < 1e-10 then Float.inf else patternBound / targetImportance - - return { - component := ComponentId.head layerIdx headIdx - valueTermNorm := valueNorm - patternTermBound := patternBound - targetProjection := targetImportance - targetFaithfulnessRatio := ratio - } - else none - else none - -/-- Compute target-aware importance for a single MLP neuron. - -The target projection is `|W_out[i,:] · u|` - the absolute dot product of the -neuron's output weights with the target direction. 
--/ -def computeNeuronTargetImportance (model : ConcreteModel) (layerIdx neuronIdx : Nat) - (_inputNorm : Float) (target : TargetDirection) : Option TargetAwareImportance := - if h : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - if neuronIdx < mlp.hiddenDim then - -- Standard influence - let inputNormVal := mlp.neuronInputNorm neuronIdx - let outputNormVal := mlp.neuronOutputNorm neuronIdx - let influence := inputNormVal * outputNormVal - - -- Target-aware: compute W_out[i,:] · u - -- Get row i of W_out as a column vector for dot product - let outputWeights : ConcreteMatrix := { - numRows := mlp.modelDim - numCols := 1 - data := .ofFn fun j : Fin mlp.modelDim => - mlp.W_out.data[neuronIdx * mlp.modelDim + j.val]! - size_eq := by simp - } - let dotProd := outputWeights.dot target.direction - let targetProj := inputNormVal * Float.abs dotProd - - let patternBound : Float := 0.0 -- ReLU is locally linear - let ratio := if targetProj < 1e-10 then Float.inf else patternBound / targetProj - - some { - component := ComponentId.mlpNeuron layerIdx neuronIdx - valueTermNorm := influence - patternTermBound := patternBound - targetProjection := targetProj - targetFaithfulnessRatio := ratio - } - else none - else none - -/-- Compute target-aware importance for all components in a model. 
-/ -def computeAllTargetImportance (model : ConcreteModel) - (target : TargetDirection) : Array TargetAwareImportance := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - - let computeHeadsForLayer (l : Nat) : Array TargetAwareImportance := Id.run do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let mut outs : Array TargetAwareImportance := Array.mkEmpty layerHeads.size - for h_idx in [:layerHeads.size] do - match computeHeadTargetImportance model l h_idx inputNorm target with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let computeNeuronsForLayer (l : Nat) : Array TargetAwareImportance := Id.run do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - let mut outs : Array TargetAwareImportance := Array.mkEmpty mlp.hiddenDim - for n_idx in [:mlp.hiddenDim] do - match computeNeuronTargetImportance model l n_idx inputNorm target with - | some imp => outs := outs.push imp - | none => pure () - return outs - else - return #[] - - let useParallel := model.numLayers >= 4 - let headChunks : Array (Array TargetAwareImportance) := - if useParallel then - let tasks : Array (Task (Array TargetAwareImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeHeadsForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeHeadsForLayer i.val - - let neuronChunks : Array (Array TargetAwareImportance) := - if useParallel then - let tasks : Array (Task (Array TargetAwareImportance)) := - .ofFn fun i : Fin model.numLayers => - Task.spawn (fun _ => computeNeuronsForLayer i.val) - tasks.map Task.get - else - .ofFn fun i : Fin model.numLayers => - computeNeuronsForLayer i.val - - -- Join in the same order as the original loop: heads then neurons, increasing layer index. 
- let totalHeads := headChunks.foldl (fun acc cs => acc + cs.size) 0 - let totalNeurons := neuronChunks.foldl (fun acc cs => acc + cs.size) 0 - let mut result : Array TargetAwareImportance := Array.mkEmpty (totalHeads + totalNeurons) - for cs in headChunks do - for c in cs do - result := result.push c - for cs in neuronChunks do - for c in cs do - result := result.push c - result - -/-! ### Circuit Faithfulness Estimation -/ - -/-- Error breakdown for a circuit. - -Total error has two components: -1. **Pattern Term Error**: Approximation error from trusting attention patterns -2. **Ablation Error**: Information loss from pruned components - -Total Error ≤ PatternTermError + AblationError --/ -structure CircuitError where - /-- Sum of pattern term bounds for included components -/ - patternTermError : Float - /-- Sum of value term norms for excluded (ablated) components -/ - ablationError : Float - /-- Combined error bound -/ - totalError : Float - /-- Number of included components -/ - includedCount : Nat - /-- Number of excluded components -/ - excludedCount : Nat - deriving Repr - -namespace CircuitError - -/-- Pretty print error breakdown. -/ -def toString (err : CircuitError) : String := - s!"CircuitError(total={err.totalError}, pattern={err.patternTermError}, " ++ - s!"ablation={err.ablationError}, included={err.includedCount}, excluded={err.excludedCount})" - -instance : ToString CircuitError := ⟨CircuitError.toString⟩ - -end CircuitError - -/-! ### N-Layer Faithfulness Verification - -This section implements the N-layer error amplification formula from `Linearization.lean` -for concrete Float matrices. The key insight is that errors in early layers get -amplified as they propagate through subsequent layers: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -where: -- εᵢ = pattern term bound for layer i (interpretation error) -- Cⱼ = operator norm bound for layer j's residual Jacobian (layerJacobian - I) --/ - -/-- Per-layer error metrics for deep circuit analysis. 
-/ -structure LayerErrorMetrics where - /-- Layer index -/ - layerIdx : Nat - /-- Pattern term bound εᵢ (faithfulness error before amplification) -/ - patternTermBound : Float - /-- Operator norm upper bound Cᵢ for layerJacobian - I (residual part). -/ - operatorNormUb : Float - /-- Suffix amplification: ∏_{j>i} (1 + Cⱼ) -/ - suffixAmplification : Float - /-- Amplified error contribution: εᵢ · suffixAmplification(i+1) -/ - amplifiedError : Float - -namespace LayerErrorMetrics - -def toString (m : LayerErrorMetrics) : String := - s!"Layer {m.layerIdx}: ε={m.patternTermBound}, C_ub={m.operatorNormUb}, " ++ - s!"amp={m.suffixAmplification}, contrib={m.amplifiedError}" - -instance : ToString LayerErrorMetrics := ⟨toString⟩ - -end LayerErrorMetrics - -/-- Deep circuit verification result with rigorous N-layer error bounds. -/ -structure DeepCircuitVerification where - /-- Per-layer error breakdown -/ - layerMetrics : Array LayerErrorMetrics - /-- Total error bound: Σᵢ εᵢ · suffixAmplification(i+1) -/ - totalAmplifiedError : Float - /-- Simple sum error (no amplification): Σᵢ εᵢ -/ - simpleErrorSum : Float - /-- Total amplification factor: ∏ᵢ (1 + Cᵢ) -/ - totalAmplificationFactor : Float - /-- Ablation error from excluded components -/ - ablationError : Float - /-- Combined error bound (amplified + ablation) -/ - combinedError : Float - /-- Number of layers analyzed -/ - numLayers : Nat - -namespace DeepCircuitVerification - -def toString (v : DeepCircuitVerification) : String := - s!"DeepCircuitVerification:\n" ++ - s!" Layers: {v.numLayers}\n" ++ - s!" Total Amplified Error: {v.totalAmplifiedError}\n" ++ - s!" Simple Error Sum: {v.simpleErrorSum}\n" ++ - s!" Amplification Factor: {v.totalAmplificationFactor}\n" ++ - s!" Ablation Error: {v.ablationError}\n" ++ - s!" Combined Error: {v.combinedError}" - -instance : ToString DeepCircuitVerification := ⟨toString⟩ - -end DeepCircuitVerification - -/-- Estimate pattern term bound for a single layer (all heads combined). 
- -Aggregates pattern term bounds across all attention heads in the layer. --/ -def estimateLayerPatternBound (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (circuit : ConcreteCircuit) : Float := Id.run do - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - let layerInput := fwdResult.getLayerInput layerIdx - let attnInput := model.applyLn1 layerIdx layerInput - let inputNorm := attnInput.frobeniusNorm - let inputOpBound := attnInput.opNormUpperBoundOneInf - let ln1Bound := model.ln1OpBound layerIdx layerInput - let mut totalBound : Float := 0.0 - - for hidx in [:heads.size] do - if hh : hidx < heads.size then - -- Only count included heads - if circuit.isHeadIncluded layerIdx hidx then - let head := heads[hidx] - let attn := head.computeAttentionWeights attnInput - let bnds := head.noDenseProductBounds - let inputs : PatternTermBoundInputs := { - attention := attn - inputNorm := inputNorm - inputOpBound := inputOpBound - scaleFactor := Float.sqrt head.headDim.toFloat - wqOpBound := bnds.wqOpGram - wkOpBound := bnds.wkOpGram - wvOpBound := bnds.wvOpGram - woOpBound := bnds.woOpGram - voOpBound := bnds.voOpBound - bqFrob := head.b_Q.frobeniusNorm - bkFrob := head.b_K.frobeniusNorm - bvFrob := head.b_V.frobeniusNorm - } - -- Pre-LN: pattern sensitivity is scaled by the LayerNorm Jacobian. - totalBound := totalBound + ln1Bound * computePatternTermBound inputs - - totalBound - else - return 0.0 - -/-- Estimate ablation error for excluded components at a single layer. 
-/ -def estimateLayerAblationError (model : ConcreteModel) (fwdResult : ForwardPassResult) - (layerIdx : Nat) (circuit : ConcreteCircuit) : Float := Id.run do - let layerInput := fwdResult.getLayerInput layerIdx - let mut totalError : Float := 0.0 - - -- Ablation error from excluded attention heads - if h : layerIdx < model.layers.size then - let heads := model.layers[layerIdx] - for hidx in [:heads.size] do - if hh : hidx < heads.size then - if !circuit.isHeadIncluded layerIdx hidx then - match computeHeadImportance model layerIdx hidx layerInput with - | some imp => totalError := totalError + imp.valueTermNorm - | none => pure () - - -- Ablation error from excluded neurons - if hm : layerIdx < model.mlps.size then - let mlp := model.mlps[layerIdx] - for nidx in [:mlp.hiddenDim] do - if !circuit.isNeuronIncluded layerIdx nidx then - match computeNeuronImportance model layerIdx nidx (layerInput.frobeniusNorm) with - | some imp => totalError := totalError + imp.valueTermNorm - | none => pure () - - totalError - -/-- Verify a deep circuit using rigorous N-layer error amplification bounds. - -This is the main function that bridges theoretical bounds to practical verification. -It computes: -1. Per-layer pattern term bounds (εᵢ) -2. Per-layer residual operator norm bounds (Cᵢ) -3. Suffix amplification factors -4. Total amplified error using the N-layer composition formula - -**The N-Layer Formula:** - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This captures that early layer errors get amplified more because they pass -through more subsequent layers. 
--/ -def verifyDeepCircuit (model : ConcreteModel) - (circuit : ConcreteCircuit) : DeepCircuitVerification := Id.run do - -- Run forward pass to get layer-wise inputs - let fwdResult := model.runForward - - -- Step 1: Compute per-layer residual operator norm bounds - let mut normBounds : Array Float := #[] - for l in [:model.numLayers] do - let norm := estimateAttentionLayerNorm model fwdResult l true - normBounds := normBounds.push norm - - -- Step 2: Compute per-layer pattern term bounds and ablation errors - let mut patternBounds : Array Float := #[] - let mut ablationErrors : Array Float := #[] - for l in [:model.numLayers] do - let pattern := estimateLayerPatternBound model fwdResult l circuit - let ablation := estimateLayerAblationError model fwdResult l circuit - patternBounds := patternBounds.push pattern - ablationErrors := ablationErrors.push ablation - - -- Step 3: Compute suffix amplification and amplified errors per layer - let mut layerMetrics : Array LayerErrorMetrics := #[] - let mut simpleSum : Float := 0.0 - let mut totalAmplified : Float := 0.0 - - for l in [:model.numLayers] do - if hl : l < patternBounds.size then - let epsilon := patternBounds[l] - let normBound := if hn : l < normBounds.size then normBounds[l] else 0.0 - let suffix := computeSuffixAmplification normBounds (l + 1) - let amplified := epsilon * suffix - - simpleSum := simpleSum + epsilon - totalAmplified := totalAmplified + amplified - - layerMetrics := layerMetrics.push { - layerIdx := l - patternTermBound := epsilon - operatorNormUb := normBound - suffixAmplification := suffix - amplifiedError := amplified - } - - -- Step 4: Compute total ablation error - let totalAblation := ablationErrors.foldl (· + ·) 0.0 - - -- Step 5: Compute total amplification factor - let totalAmpFactor := computeSuffixAmplification normBounds 0 - - { - layerMetrics := layerMetrics - totalAmplifiedError := totalAmplified - simpleErrorSum := simpleSum - totalAmplificationFactor := totalAmpFactor - 
ablationError := totalAblation - combinedError := totalAmplified + totalAblation - numLayers := model.numLayers - } - -/-- Check if a deep circuit meets a certification threshold. -/ -def isDeepCircuitCertified (verification : DeepCircuitVerification) - (threshold : Float) : Bool := - verification.combinedError ≤ threshold - -/-- Structure for a verified deep circuit with certification. -/ -structure VerifiedDeepCircuit where - /-- The circuit that was verified -/ - circuit : ConcreteCircuit - /-- Full verification details -/ - verification : DeepCircuitVerification - /-- The threshold used -/ - threshold : Float - /-- Whether it passed certification -/ - certified : Bool - -namespace VerifiedDeepCircuit - -def toString (v : VerifiedDeepCircuit) : String := - let status := if v.certified then "✓ CERTIFIED" else "✗ NOT CERTIFIED" - s!"{status} (threshold={v.threshold})\n{v.verification}" - -instance : ToString VerifiedDeepCircuit := ⟨toString⟩ - -end VerifiedDeepCircuit - -/-- Verify and certify a deep circuit against a threshold. -/ -def certifyDeepCircuit (model : ConcreteModel) (circuit : ConcreteCircuit) - (threshold : Float) : VerifiedDeepCircuit := - let verification := verifyDeepCircuit model circuit - { - circuit := circuit - verification := verification - threshold := threshold - certified := isDeepCircuitCertified verification threshold - } - -/-- Estimate the faithfulness error for a given circuit mask. - -This is the core function that enables circuit discovery without forward passes. -It uses only weight matrices and attention patterns to bound the error. 
- -**Error Model:** -- For **included** components: we incur the pattern term error (approximation) -- For **excluded** components: we incur the value term error (ablation) - -Total Error = Σ(included) patternBound + Σ(excluded) valueNorm --/ -def estimateCircuitFaithfulness (model : ConcreteModel) - (circuit : ConcreteCircuit) : CircuitError := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - let mut patternError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - - -- Process attention heads - for l in [:model.numLayers] do - if h : l < model.layers.size then - let layerHeads := model.layers[l] - for h_idx in [:layerHeads.size] do - let included := circuit.isHeadIncluded l h_idx - match computeHeadImportance model l h_idx model.inputEmbeddings with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - -- Process MLP neurons - for l in [:model.numLayers] do - if h : l < model.mlps.size then - let mlp := model.mlps[l] - for n_idx in [:mlp.hiddenDim] do - let included := circuit.isNeuronIncluded l n_idx - match computeNeuronImportance model l n_idx inputNorm with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - { - patternTermError := patternError - ablationError := ablationError - totalError := patternError + ablationError - includedCount := includedCount - excludedCount := excludedCount - } - -/-- Extended circuit error with IBP analysis details. 
-/ -structure CircuitErrorIBP extends CircuitError where - /-- Total number of unstable neurons detected -/ - unstableNeuronCount : Nat - /-- Pattern error contribution from unstable MLP neurons -/ - mlpInstabilityError : Float - /-- Per-layer MLP stability ratios -/ - layerStabilityRatios : Array Float - deriving Repr - -namespace CircuitErrorIBP - -def toString (e : CircuitErrorIBP) : String := - s!"CircuitErrorIBP: pattern={e.patternTermError}, ablation={e.ablationError}, " ++ - s!"total={e.totalError}, unstable_neurons={e.unstableNeuronCount}, " ++ - s!"mlp_instability={e.mlpInstabilityError}" - -instance : ToString CircuitErrorIBP := ⟨toString⟩ - -end CircuitErrorIBP - -/-- Estimate circuit faithfulness with Interval Bound Propagation for MLPs. - -This is the **sound** version of circuit faithfulness estimation that properly -accounts for MLP activation instability. It runs a forward pass to get layer -inputs, then uses IBP to bound the pattern term error for neurons that may -flip activation states. - -**Key Improvement over `estimateCircuitFaithfulness`:** -- Standard version assumes MLP pattern term = 0 (unsound) -- IBP version detects unstable neurons and computes rigorous error bounds - -**Algorithm:** -1. Run forward pass to get layer inputs -2. For each layer, compute the ablation perturbation norm (sum of excluded - component value terms up to this layer) -3. For each MLP neuron, use IBP to determine if it's stable under this - perturbation -4. 
Unstable neurons contribute their IBP pattern term bound to total error - -**Parameters:** -- `causal`: Whether to use causal attention masking (default true) - -**Returns:** Extended error struct with stability analysis details --/ -def estimateCircuitFaithfulnessIBP (model : ConcreteModel) - (circuit : ConcreteCircuit) (causal : Bool := true) : CircuitErrorIBP := Id.run do - let inputNorm := computeInputNorm model.inputEmbeddings - let fwd := model.runForward causal - - let mut patternError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut mlpInstabilityError : Float := 0.0 - let mut unstableCount : Nat := 0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - let mut layerStability : Array Float := #[] - - -- Track cumulative ablation perturbation up to each layer - -- This is the norm of the change to the residual stream from ablated components - let mut cumulativeAblation : Float := 0.0 - - -- Process layer by layer - for l in [:model.numLayers] do - let mut layerAblation : Float := 0.0 - - -- Process attention heads in this layer - if h : l < model.layers.size then - let layerHeads := model.layers[l] - let layerInput := fwd.getLayerInput l - for h_idx in [:layerHeads.size] do - let included := circuit.isHeadIncluded l h_idx - match computeHeadImportance model l h_idx layerInput with - | some imp => - if included then - patternError := patternError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.valueTermNorm - layerAblation := layerAblation + imp.valueTermNorm - excludedCount := excludedCount + 1 - | none => pure () - - -- Update cumulative ablation (this affects MLP inputs) - cumulativeAblation := cumulativeAblation + layerAblation - - -- Process MLP neurons with IBP - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let layerInput := fwd.getLayerInput l - - let mut layerUnstable : Nat := 0 - let mut layerMlpPattern : Float := 0.0 - - for n_idx in 
[:mlp.hiddenDim] do - let included := circuit.isNeuronIncluded l n_idx - if included then - -- Use IBP with cumulative perturbation norm - match computeNeuronImportanceIBP model l n_idx layerInput cumulativeAblation with - | some imp => - patternError := patternError + imp.patternTermBound - if imp.patternTermBound > 0.0 then - layerUnstable := layerUnstable + 1 - layerMlpPattern := layerMlpPattern + imp.patternTermBound - includedCount := includedCount + 1 - | none => pure () - else - -- Excluded neurons contribute value term (ablation error) - let influence := mlp.neuronInfluence n_idx - ablationError := ablationError + influence - excludedCount := excludedCount + 1 - - unstableCount := unstableCount + layerUnstable - mlpInstabilityError := mlpInstabilityError + layerMlpPattern - - -- Record stability ratio for this layer - let stabilityRatio := if mlp.hiddenDim = 0 then 1.0 - else (mlp.hiddenDim - layerUnstable).toFloat / mlp.hiddenDim.toFloat - layerStability := layerStability.push stabilityRatio - - { - patternTermError := patternError - ablationError := ablationError - totalError := patternError + ablationError - includedCount := includedCount - excludedCount := excludedCount - unstableNeuronCount := unstableCount - mlpInstabilityError := mlpInstabilityError - layerStabilityRatios := layerStability - } - -/-- Summary of MLP stability across a model for a given perturbation. -/ -structure MLPStabilitySummary where - /-- Per-layer analysis results -/ - layerAnalyses : Array MLPIntervalAnalysis - /-- Total stable neurons across all layers -/ - totalStable : Nat - /-- Total unstable neurons across all layers -/ - totalUnstable : Nat - /-- Overall stability ratio -/ - overallStabilityRatio : Float - /-- Total pattern term bound from all unstable neurons -/ - totalPatternBound : Float - deriving Repr - -/-- Analyze MLP stability across the entire model. - -Runs IBP analysis on all MLP layers to identify which neurons are stable -under a given perturbation bound. 
--/ -def analyzeModelMLPStability (model : ConcreteModel) - (perturbationNorm : Float) (causal : Bool := true) : MLPStabilitySummary := Id.run do - let fwd := model.runForward causal - let mut analyses : Array MLPIntervalAnalysis := #[] - let mut totalStable : Nat := 0 - let mut totalUnstable : Nat := 0 - let mut totalPattern : Float := 0.0 - - for l in [:model.numLayers] do - if hm : l < model.mlps.size then - let mlp := model.mlps[l] - let layerInput := fwd.getLayerInput l - - -- Analyze each position and aggregate - let mut layerAnalysis : MLPIntervalAnalysis := { - layerIdx := l - neuronBounds := #[] - perturbationNorm := perturbationNorm - numStable := 0 - numUnstable := 0 - totalPatternBound := 0.0 - } - - -- Use position 0 as representative (could average over positions) - if layerInput.numRows > 0 then - let inputVec : Array Float := .ofFn fun d : Fin layerInput.numCols => - layerInput.getUnsafe 0 d.val - layerAnalysis := mlp.analyzeIntervalBounds l inputVec perturbationNorm - - analyses := analyses.push layerAnalysis - totalStable := totalStable + layerAnalysis.numStable - totalUnstable := totalUnstable + layerAnalysis.numUnstable - totalPattern := totalPattern + layerAnalysis.totalPatternBound - - let totalNeurons := totalStable + totalUnstable - let ratio := if totalNeurons = 0 then 1.0 - else totalStable.toFloat / totalNeurons.toFloat - - { - layerAnalyses := analyses - totalStable := totalStable - totalUnstable := totalUnstable - overallStabilityRatio := ratio - totalPatternBound := totalPattern - } - -/-! ### Greedy Circuit Pruning -/ - -/-- Result of the greedy pruning algorithm. 
-/ -structure PruningResult where - /-- The discovered circuit -/ - circuit : ConcreteCircuit - /-- Error estimate for the circuit -/ - error : CircuitError - /-- History of pruning steps (component removed, error after removal) -/ - pruningHistory : Array (ComponentId × Float) - /-- The error threshold that was used -/ - threshold : Float - -namespace PruningResult - -/-- Pretty print pruning result. -/ -def toString (pr : PruningResult) : String := - s!"PruningResult: {pr.circuit}\n Error: {pr.error}\n " ++ - s!"Steps: {pr.pruningHistory.size}, Threshold: {pr.threshold}" - -instance : ToString PruningResult := ⟨PruningResult.toString⟩ - -end PruningResult - -/-- Find the component with smallest value term (least important for information flow). - -Returns the component ID and its value term norm, considering only currently included -components. --/ -def findLeastImportantComponent (circuit : ConcreteCircuit) - (importance : Array ComponentImportance) : Option (ComponentId × Float) := Id.run do - let mut best : Option (ComponentId × Float) := none - - for imp in importance do - let included := circuit.isIncluded imp.component - if included then - match best with - | none => best := some (imp.component, imp.valueTermNorm) - | some (_, bestValue) => - if imp.valueTermNorm < bestValue then - best := some (imp.component, imp.valueTermNorm) - - best - -/-- Find the component with smallest target projection (least important for target behavior). - -Returns the component ID and its target projection, considering only currently included -components. This is the target-aware version of `findLeastImportantComponent`. 
--/ -def findLeastImportantTargetComponent (circuit : ConcreteCircuit) - (importance : Array TargetAwareImportance) : Option (ComponentId × Float) := Id.run do - let mut best : Option (ComponentId × Float) := none - - for imp in importance do - let included := circuit.isIncluded imp.component - if included then - match best with - | none => best := some (imp.component, imp.targetProjection) - | some (_, bestValue) => - if imp.targetProjection < bestValue then - best := some (imp.component, imp.targetProjection) - - best - -/-- Estimate circuit faithfulness for target-aware pruning. - -For target-aware circuits, the error model is: -- **Included components**: Contribute approximation error (pattern term) -- **Excluded components**: Contribute information loss measured by target projection - -Unlike generic discovery where we use `‖W_V·W_O‖_F` for ablation error, -here we use `targetProjection` - the component's contribution to the target direction. --/ -def estimateTargetCircuitError (_model : ConcreteModel) (circuit : ConcreteCircuit) - (importance : Array TargetAwareImportance) : CircuitError := Id.run do - let mut patternTermError : Float := 0.0 - let mut ablationError : Float := 0.0 - let mut includedCount : Nat := 0 - let mut excludedCount : Nat := 0 - - for imp in importance do - if circuit.isIncluded imp.component then - patternTermError := patternTermError + imp.patternTermBound - includedCount := includedCount + 1 - else - ablationError := ablationError + imp.targetProjection - excludedCount := excludedCount + 1 - - { - patternTermError := patternTermError - ablationError := ablationError - totalError := patternTermError + ablationError - includedCount := includedCount - excludedCount := excludedCount - } - -/-- Greedy circuit pruning algorithm. - -Starting from the full model, iteratively removes the component with the smallest -value term contribution until the total error would exceed the threshold. - -**Algorithm:** -1. 
Start with all components included -2. Compute importance metrics for all components -3. Repeat: - a. Find component with smallest valueTermNorm among included - b. Tentatively remove it - c. Estimate new total error - d. If error ≤ threshold, commit removal; else restore and stop -4. Return the pruned circuit - -**Complexity:** O(n²) where n = number of components (n iterations, each scanning n components) --/ -def discoverCircuit (model : ConcreteModel) (threshold : Float) : PruningResult := Id.run do - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - -- Start with full circuit - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - - -- Precompute all importance metrics - let importance := computeAllImportance model - - -- Initial error - let mut currentError := estimateCircuitFaithfulness model circuit - - -- Greedy pruning loop - let maxIters := circuit.totalComponents - for _ in [:maxIters] do - -- Find least important included component - match findLeastImportantComponent circuit importance with - | none => break -- No more components to prune - | some (comp, _) => - -- Tentatively remove component - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateCircuitFaithfulness model tentativeCircuit - - -- Check if we can afford to remove it - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, 
tentativeError.totalError) - else - break -- Would exceed threshold, stop pruning - - { - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - } - -/-- Discover circuit with verbose output of each step. -/ -def discoverCircuitVerbose (model : ConcreteModel) (threshold : Float) : - PruningResult × Array String := Id.run do - let mut logs : Array String := #[] - - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - let importance := computeAllImportance model - - logs := logs.push s!"Starting with full circuit: {circuit.countIncluded} components" - - let mut currentError := estimateCircuitFaithfulness model circuit - logs := logs.push s!"Initial error: {currentError.totalError}" - - let maxIters := circuit.totalComponents - for step in [:maxIters] do - match findLeastImportantComponent circuit importance with - | none => - logs := logs.push s!"Step {step}: No more components to prune" - break - | some (comp, valueNorm) => - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateCircuitFaithfulness model tentativeCircuit - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - let msg := s!"Step {step}: Removed {comp}, new error={tentativeError.totalError}" - logs := logs.push msg 
- else - let msg := s!"Step {step}: Cannot remove {comp}, exceeds threshold" - logs := logs.push msg - break - - logs := logs.push s!"Final circuit: {circuit}" - - ({ - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - }, logs) - -/-! ### Circuit Verification -/ - -/-- A verified circuit that meets the certification threshold. -/ -structure VerifiedCircuit where - /-- The pruned circuit -/ - circuit : ConcreteCircuit - /-- Error estimate -/ - error : CircuitError - /-- Certification threshold -/ - threshold : Float - /-- Human-readable description -/ - description : String - -namespace VerifiedCircuit - -/-- Pretty print verified circuit. -/ -def toString (vc : VerifiedCircuit) : String := - s!"VerifiedCircuit [{vc.description}]\n {vc.circuit}\n {vc.error}\n " ++ - s!"Threshold: {vc.threshold}" - -instance : ToString VerifiedCircuit := ⟨VerifiedCircuit.toString⟩ - -end VerifiedCircuit - -/-- Discover and verify a circuit, returning None if threshold cannot be met. -/ -def discoverVerifiedCircuit (model : ConcreteModel) (threshold : Float) - (description : String := "auto-discovered") : Option VerifiedCircuit := do - let result := discoverCircuit model threshold - if result.error.totalError ≤ threshold then - some { - circuit := result.circuit - error := result.error - threshold := threshold - description := description - } - else - none - -/-! ### Target-Aware Circuit Discovery - -These functions discover circuits optimized for specific predictions rather than -general model behavior. Given a target direction (e.g., logit difference between -correct and incorrect tokens), they find the minimal circuit that explains that -specific prediction. --/ - -/-- Target-aware greedy circuit pruning algorithm. - -Like `discoverCircuit`, but uses target projection instead of generic value term -for importance ranking. This finds circuits that explain *specific behaviors* -rather than everything the model does. 
- -**Algorithm:** -1. Start with all components included -2. Compute target-aware importance for all components -3. Repeat: - a. Find component with smallest targetProjection among included - b. Tentatively remove it - c. Estimate new total error (using target projections for ablation) - d. If error ≤ threshold, commit removal; else stop -4. Return the pruned circuit - -**Key Difference from Generic Discovery:** -- Generic: prunes by `‖W_V·W_O‖_F` (generic information flow) -- Target-aware: prunes by `‖(W_V·W_O)·u‖` (contribution to target direction) - -This typically produces much smaller circuits when you care about specific outputs. --/ -def discoverTargetedCircuit (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : PruningResult := Id.run do - -- Build heads per layer array - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - -- Build neurons per layer array - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - -- Start with full circuit - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - - -- Precompute target-aware importance metrics - let importance := computeAllTargetImportance model target - - -- Initial error - let mut currentError := estimateTargetCircuitError model circuit importance - - -- Greedy pruning loop - let maxIters := circuit.totalComponents - for _ in [:maxIters] do - -- Find least important included component (by target projection) - match findLeastImportantTargetComponent circuit importance with - | none => break - | some (comp, _) => - let tentativeCircuit := circuit.removeComponent 
comp - let tentativeError := estimateTargetCircuitError model tentativeCircuit importance - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - else - break - - { - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - } - -/-- Target-aware circuit discovery with verbose logging. -/ -def discoverTargetedCircuitVerbose (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : PruningResult × Array String := Id.run do - let mut logs : Array String := #[] - logs := logs.push s!"Target-aware discovery for: {target.description}" - - -- Build layer arrays - let mut headsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.layers.size then - headsPerLayer := headsPerLayer.push model.layers[l].size - else - headsPerLayer := headsPerLayer.push 0 - - let mut neuronsPerLayer : Array Nat := #[] - for l in [:model.numLayers] do - if h : l < model.mlps.size then - neuronsPerLayer := neuronsPerLayer.push model.mlps[l].hiddenDim - else - neuronsPerLayer := neuronsPerLayer.push 0 - - let mut circuit := ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - let mut history : Array (ComponentId × Float) := #[] - let importance := computeAllTargetImportance model target - - logs := logs.push s!"Starting with full circuit: {circuit.countIncluded} components" - - let mut currentError := estimateTargetCircuitError model circuit importance - logs := logs.push s!"Initial error: {currentError.totalError}" - - let maxIters := circuit.totalComponents - for step in [:maxIters] do - match findLeastImportantTargetComponent circuit importance with - | none => - logs := logs.push s!"Step {step}: No more components to prune" - break - | some (comp, targetProj) => - let tentativeCircuit := circuit.removeComponent comp - let tentativeError := estimateTargetCircuitError model 
tentativeCircuit importance - - if tentativeError.totalError ≤ threshold then - circuit := tentativeCircuit - currentError := tentativeError - history := history.push (comp, tentativeError.totalError) - let msg := s!"Step {step}: Removed {comp} (target={targetProj}), " ++ - s!"new error={tentativeError.totalError}" - logs := logs.push msg - else - let msg := s!"Step {step}: Cannot remove {comp}, exceeds threshold" - logs := logs.push msg - break - - logs := logs.push s!"Final circuit: {circuit}" - logs := logs.push s!"Compression: {circuit.countIncluded}/{circuit.totalComponents} components" - - ({ - circuit := circuit - error := currentError - pruningHistory := history - threshold := threshold - }, logs) - -/-- Discover and verify a target-aware circuit. -/ -def discoverVerifiedTargetedCircuit (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) : Option VerifiedCircuit := do - let result := discoverTargetedCircuit model threshold target - if result.error.totalError ≤ threshold then - some { - circuit := result.circuit - error := result.error - threshold := threshold - description := s!"target-aware: {target.description}" - } - else - none - -/-- Convenience function to discover circuit for logit difference. - -Given correct and incorrect token IDs, creates the target direction -`u = W_U[:, correct] - W_U[:, incorrect]` and discovers the minimal -circuit that explains why the model predicts correct over incorrect. --/ -def discoverLogitDiffCircuit (model : ConcreteModel) (threshold : Float) - (correctToken incorrectToken : Nat) : Option (PruningResult × TargetDirection) := do - let W_U ← model.unembedding - let target := TargetDirection.fromLogitDiff W_U correctToken incorrectToken - let result := discoverTargetedCircuit model threshold target - some (result, target) - -/-- Rank components by their target projection (descending). - -Useful for identifying which components most strongly promote the target behavior. 
--/ -def rankComponentsByTargetImportance (model : ConcreteModel) - (target : TargetDirection) : Array TargetAwareImportance := - let importance := computeAllTargetImportance model target - importance.qsort (·.targetProjection > ·.targetProjection) - -/-- Get the top-k components most important for a target direction. -/ -def topKTargetComponents (model : ConcreteModel) (target : TargetDirection) - (k : Nat) : Array TargetAwareImportance := - let ranked := rankComponentsByTargetImportance model target - ranked.extract 0 (min k ranked.size) - -/-! ### End-to-End Discovery and Verification - -These functions combine circuit discovery with empirical verification, -providing a complete workflow from model analysis to validated circuits. --/ - -/-- Discover a circuit and immediately verify it empirically. - -This is the end-to-end function that: -1. Discovers a minimal circuit using greedy pruning -2. Computes theoretical error bounds -3. Verifies empirically that the circuit is faithful - -Returns both the pruning result and verification result. --/ -def discoverAndVerify (model : ConcreteModel) (threshold : Float) - (causal : Bool := true) : PruningResult × VerificationResult := - let pruning := discoverCircuit model threshold - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification) - -/-- Discover a target-aware circuit and verify it empirically. - -Like `discoverAndVerify` but optimizes for a specific prediction target -(e.g., logit difference between correct and incorrect tokens). 
--/ -def discoverTargetedAndVerify (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) (causal : Bool := true) : - PruningResult × VerificationResult := - let pruning := discoverTargetedCircuit model threshold target - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification) - -/-- Complete analysis: discover, verify, and return detailed comparison. - -This is the most comprehensive function for circuit analysis. It: -1. Discovers the minimal circuit meeting the error threshold -2. Runs both full and ablated forward passes -3. Computes empirical vs theoretical error comparison -4. Returns everything needed for detailed analysis - -**Example output interpretation:** -- `verification.verified = true`: Circuit is empirically faithful -- `verification.tightness = 0.8`: Theoretical bound is 80% utilized (20% slack) -- `ablation.relativeError = 0.05`: Circuit output differs by 5% from full model --/ -def analyzeCircuitFaithfulness (model : ConcreteModel) (threshold : Float) - (causal : Bool := true) : PruningResult × VerificationResult × AblationResult := - let pruning := discoverCircuit model threshold - let ablation := computeAblationDiscrepancy model pruning.circuit causal - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification, ablation) - -/-- Analyze a target-aware circuit with full verification details. -/ -def analyzeTargetedCircuitFaithfulness (model : ConcreteModel) (threshold : Float) - (target : TargetDirection) (causal : Bool := true) : - PruningResult × VerificationResult × AblationResult := - let pruning := discoverTargetedCircuit model threshold target - let ablation := computeAblationDiscrepancy model pruning.circuit causal - let verification := verifyCircuitFaithfulness model pruning.circuit - pruning.error.totalError causal - (pruning, verification, ablation) - -/-! 
### Analysis Utilities -/ - -/-- Rank all components by their value term contribution (descending). -/ -def rankComponentsByImportance (model : ConcreteModel) : Array ComponentImportance := - let importance := computeAllImportance model - importance.qsort (·.valueTermNorm > ·.valueTermNorm) - -/-- Get the top-k most important components. -/ -def topKComponents (model : ConcreteModel) (k : Nat) : Array ComponentImportance := - let ranked := rankComponentsByImportance model - ranked.extract 0 (min k ranked.size) - -/-- Get components with faithfulness ratio below threshold (most reliable). -/ -def reliableComponents (model : ConcreteModel) (maxRatio : Float) : Array ComponentImportance := - let importance := computeAllImportance model - importance.filter (·.faithfulnessRatio ≤ maxRatio) - -/-- Summary of circuit discovery analysis. -/ -structure CircuitAnalysis where - /-- Total number of components in model -/ - totalComponents : Nat - /-- Number of components in discovered circuit -/ - circuitSize : Nat - /-- Compression ratio: circuitSize / totalComponents -/ - compressionRatio : Float - /-- Total error bound -/ - totalError : Float - /-- Pattern term contribution to error -/ - patternContribution : Float - /-- Ablation contribution to error -/ - ablationContribution : Float - /-- Most important component (by value term) -/ - topComponent : Option ComponentImportance - /-- Most reliable component (by faithfulness ratio) -/ - mostReliable : Option ComponentImportance - -/-- Perform comprehensive circuit analysis. 
-/ -def analyzeCircuit (model : ConcreteModel) (threshold : Float) : CircuitAnalysis := Id.run do - let result := discoverCircuit model threshold - let importance := computeAllImportance model - let ranked := importance.qsort (·.valueTermNorm > ·.valueTermNorm) - let reliable := importance.qsort (·.faithfulnessRatio < ·.faithfulnessRatio) - - let total := result.circuit.totalComponents - let included := result.circuit.countIncluded - let ratio := if total > 0 then included.toFloat / total.toFloat else 1.0 - - { - totalComponents := total - circuitSize := included - compressionRatio := ratio - totalError := result.error.totalError - patternContribution := result.error.patternTermError - ablationContribution := result.error.ablationError - topComponent := if h : 0 < ranked.size then some ranked[0] else none - mostReliable := if h : 0 < reliable.size then some reliable[0] else none - } - -end Nfp diff --git a/Nfp/IO.lean b/Nfp/IO.lean index a0a338b..75af425 100644 --- a/Nfp/IO.lean +++ b/Nfp/IO.lean @@ -1,810 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Nfp.Discovery +module -/-! -# Model IO: Loading Pre-trained Weights - -This module provides functionality to load pre-trained transformer model weights -from external files (exported from PyTorch/JAX) into the `ConcreteModel` structure. - -## Performance & Safety Design - -This module prioritizes **safety and clear error reporting** over raw performance. -- All array accesses use bounds-checked indexing (`array[i]!` panics on OOB) -- Comprehensive validation of file format and dimensions with helpful error messages -- File format parsing is I/O-bound, so optimizing array operations has minimal impact - -For high-performance computation, see `Discovery.lean` where hot paths are optimized. 
- -## File Format: `.nfpt` (NFP_BINARY_V1) - -Hybrid text header + binary body: - -``` -NFP_BINARY_V1 -num_layers=12 -num_heads=12 -model_dim=768 -head_dim=64 -hidden_dim=3072 -vocab_size=50257 -seq_len=1024 -BINARY_START -``` - -Binary payload (little-endian, row-major, no markers): -1. TOKENS: `seq_len` × Int32 -2. EMBEDDINGS: `seq_len` × `model_dim` × Float64 -3. For each layer (0..num_layers-1), for each head (0..num_heads-1): - - W_Q (`model_dim`×`head_dim`), b_Q (`head_dim`) - - W_K (`model_dim`×`head_dim`), b_K (`head_dim`) - - W_V (`model_dim`×`head_dim`), b_V (`head_dim`) - - W_O (`head_dim`×`model_dim`) -4. ATTN_BIAS: `model_dim` -5. MLP: W_in (`model_dim`×`hidden_dim`), b_in (`hidden_dim`), - W_out (`hidden_dim`×`model_dim`), b_out (`model_dim`) -6. LN1 gamma/beta (`model_dim` each) -7. LN2 gamma/beta (`model_dim` each) -8. LN_F gamma/beta (`model_dim` each) -9. UNEMBEDDING: `model_dim`×`vocab_size` -``` --/ - -namespace Nfp - -open IO - -/-! ## Float Parsing Utilities -/ - -private def pow10PowTable : Array Float := Id.run do - -- Precompute `Float.pow 10.0 k` for k=0..308 so we avoid calling `Float.pow` per token. - let mut out : Array Float := Array.mkEmpty 309 - for k in [:309] do - out := out.push (Float.pow 10.0 k.toFloat) - out - -private def pow10Pow (n : Nat) : Float := - if n < pow10PowTable.size then - pow10PowTable[n]! - else - Float.pow 10.0 n.toFloat - -private def parseFloatRange (s : String) (start stop : String.Pos.Raw) : Option Float := Id.run do - -- This is a faster, allocation-free version of the previous `parseFloat`, but it preserves - -- the exact Float computation structure (Nat parsing + `Float.pow`) to keep results stable. 
- - let parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p ≥ stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' ≤ c) && (c ≤ '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - - let mut p := start - if p ≥ stop then - return none - - let mut negative := false - let c0 := p.get s - if c0 = '-' then - negative := true - p := p.next s - else if c0 = '+' then - p := p.next s - - if p ≥ stop then - return none - - -- Find exponent marker the same way as the old parser: accept exactly one `e` if present, - -- otherwise accept exactly one `E`. - let mut ePos : Option String.Pos.Raw := none - let mut eCount : Nat := 0 - let mut EPos : Option String.Pos.Raw := none - let mut ECount : Nat := 0 - let mut q := p - while q < stop do - let c := q.get s - if c = 'e' then - eCount := eCount + 1 - if eCount = 1 then ePos := some q - else if c = 'E' then - ECount := ECount + 1 - if ECount = 1 then EPos := some q - q := q.next s - - let expMarker? : Option String.Pos.Raw := - if eCount = 1 then ePos else if ECount = 1 then EPos else none - - let mantEnd : String.Pos.Raw := - match expMarker? with - | some ep => ep - | none => stop - - -- Find decimal point in mantissa (must be 0 or 1 occurrences). - let mut dotPos : Option String.Pos.Raw := none - let mut dotCount : Nat := 0 - let mut r := p - while r < mantEnd do - if r.get s = '.' then - dotCount := dotCount + 1 - if dotCount = 1 then dotPos := some r - r := r.next s - if dotCount > 1 then - return none - - let (intStart, intStop, fracStart?, fracStop) := - match dotPos with - | none => (p, mantEnd, none, mantEnd) - | some dp => (p, dp, some (dp.next s), mantEnd) - - let intN? : Option Nat := - if dotPos.isSome && intStart = intStop then - some 0 - else - parseNatRange s intStart intStop - - let fracN? 
: Option Nat := - match fracStart? with - | none => none - | some fs => - if fs = fracStop then some 0 else parseNatRange s fs fracStop - - let mantissa? : Option Float := - match dotPos, intN?, fracN? with - | none, some iN, _ => - some iN.toFloat - | some _, some iN, some fN => - let fracLen := (fracStop.byteIdx - (fracStart?.getD fracStop).byteIdx) - let divisor := pow10Pow fracLen - some (iN.toFloat + fN.toFloat / divisor) - | some _, _, none => - -- `.` present but no fractional parse (shouldn't happen), treat as invalid. - none - | _, none, _ => none - - let some mantissa := mantissa? | return none - - let value : Float := - match expMarker? with - | none => mantissa - | some ep => - let expStart := ep.next s - if expStart ≥ stop then - mantissa - else - -- Parse exponent, but if it is malformed, ignore it (old behavior). - let c := expStart.get s - let (expNeg, es) := - if c = '-' then (true, expStart.next s) - else if c = '+' then (false, expStart.next s) - else (false, expStart) - match parseNatRange s es stop with - | none => mantissa - | some eNat => - let p10 := pow10Pow eNat - if expNeg then mantissa / p10 else mantissa * p10 +public import Nfp.IO.InductionHead +public import Nfp.IO.Util - some (if negative then -value else value) - -/-- Parse a floating point number from a string. 
-/ -def parseFloat (s : String) : Option Float := Id.run do - let s := s.trim - if s.isEmpty then - none - else - parseFloatRange s 0 s.rawEndPos - -private def appendFloatsFromLine (line : String) (acc : Array Float) : Array Float := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseFloatRange s start p with - | some x => out := out.push x - | none => pure () - out - -private def parseFloatsFromLines (lines : Array String) (cap : Nat := 0) : Array Float := - Id.run do - let mut out : Array Float := Array.mkEmpty cap - for line in lines do - out := appendFloatsFromLine line out - out - -private def spawnParseFloats (lines : Array String) (cap : Nat := 0) : Task (Array Float) := - Task.spawn (fun _ => parseFloatsFromLines lines cap) - -/-- Parse a line of space-separated floats. -/ -def parseFloatLine (line : String) : Array Float := - appendFloatsFromLine line #[] - -/-! ## Nat Parsing Utilities -/ - -/-- Parse a line of space-separated natural numbers. 
-/ -private def parseNatRange (s : String) (start stop : String.Pos.Raw) : Option Nat := Id.run do - let mut p := start - if p ≥ stop then - return none - let mut acc : Nat := 0 - let mut saw : Bool := false - while p < stop do - let c := p.get s - if ('0' ≤ c) && (c ≤ '9') then - acc := acc * 10 + (c.toNat - '0'.toNat) - saw := true - p := p.next s - else - return none - if saw then some acc else none - -private def appendNatsFromLine (line : String) (acc : Array Nat) : Array Nat := Id.run do - let mut out := acc - let s := line - let mut p : String.Pos.Raw := 0 - let endPos := s.rawEndPos - let isWs (c : Char) : Bool := - c = ' ' || c = '\t' || c = '\n' || c = '\r' - while p < endPos do - while p < endPos && isWs (p.get s) do - p := p.next s - let start := p - while p < endPos && !isWs (p.get s) do - p := p.next s - if start < p then - match parseNatRange s start p with - | some n => out := out.push n - | none => pure () - out - -def parseNatLine (line : String) : Array Nat := - appendNatsFromLine line #[] - -/-! ## Matrix Construction for IO - -For IO operations, we need to create matrices from runtime data. -We use a safe construction that always ensures size invariants hold -by padding/truncating the data. +/-! +IO-only wrappers for loading inputs and running checks. -/ - -/-- Build a ConcreteMatrix from float data, padding or truncating as needed. - This is safe because we ensure the data has exactly the right size. 
-/ -def buildMatrix (rows cols : Nat) (data : Array Float) : ConcreteMatrix := - let expectedSize := rows * cols - let normalizedData : Array Float := - if data.size < expectedSize then - data ++ (Array.replicate (expectedSize - data.size) 0.0) - else if data.size > expectedSize then - data.toSubarray 0 expectedSize |>.toArray - else - data - -- Use Array.ofFn to get the exact size we need with a proof - let finalData := Array.ofFn fun (i : Fin expectedSize) => - normalizedData.getD i.val 0.0 - { - numRows := rows - numCols := cols - data := finalData - size_eq := Array.size_ofFn - } - -/-- Result of loading a model. -/ -inductive LoadResult - | ok (model : ConcreteModel) - | error (msg : String) - -namespace LoadResult - -def isOk : LoadResult → Bool - | ok _ => true - | error _ => false - -def getModel : LoadResult → Option ConcreteModel - | ok m => some m - | error _ => none - -def getError : LoadResult → Option String - | ok _ => none - | error msg => some msg - -end LoadResult - -/-! ## Text Format Parsing -/ - -/-- NFP file header structure. -/ -structure NfpHeader where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - vocabSize : Nat - seqLen : Nat - deriving Repr - -/-- Build a ConcreteAttentionLayer from weight matrices. - The dimension proofs are satisfied by construction (buildMatrix ensures correct sizes). 
-/ -def mkAttentionLayer - (modelDim headDim : Nat) - (wq wk wv wo bq bk bv : Array Float) : ConcreteAttentionLayer := - let wQ := buildMatrix modelDim headDim wq - let bQ := buildMatrix 1 headDim bq - let wK := buildMatrix modelDim headDim wk - let bK := buildMatrix 1 headDim bk - let wV := buildMatrix modelDim headDim wv - let bV := buildMatrix 1 headDim bv - let wO := buildMatrix headDim modelDim wo - { - modelDim := modelDim - headDim := headDim - W_Q := wQ - b_Q := bQ - W_K := wK - b_K := bK - W_V := wV - b_V := bV - W_O := wO - W_Q_dims := ⟨rfl, rfl⟩ - b_Q_dims := ⟨rfl, rfl⟩ - W_K_dims := ⟨rfl, rfl⟩ - b_K_dims := ⟨rfl, rfl⟩ - W_V_dims := ⟨rfl, rfl⟩ - b_V_dims := ⟨rfl, rfl⟩ - W_O_dims := ⟨rfl, rfl⟩ - } - -/-- Build a ConcreteMLPLayer from weight matrices. - The dimension proofs are satisfied by construction. -/ -def mkMLPLayer - (modelDim hiddenDim : Nat) - (win wout bin bout : Array Float) : ConcreteMLPLayer := - let wIn := buildMatrix modelDim hiddenDim win - let wOut := buildMatrix hiddenDim modelDim wout - let bIn := buildMatrix 1 hiddenDim bin - let bOut := buildMatrix 1 modelDim bout - { - modelDim := modelDim - hiddenDim := hiddenDim - W_in := wIn - W_out := wOut - b_in := bIn - b_out := bOut - W_in_dims := ⟨rfl, rfl⟩ - W_out_dims := ⟨rfl, rfl⟩ - b_in_dims := ⟨rfl, rfl⟩ - b_out_dims := ⟨rfl, rfl⟩ - } - -/-- Load a model from NFP text format content. -/ -def loadFromText (_content : String) : IO LoadResult := do - return .error "NFP_TEXT format is deprecated; use NFP_BINARY_V1" - -/-! ## Binary `.nfpt` loading (NFP_BINARY_V1) -/ - -private def readLine? 
(h : IO.FS.Handle) : IO (Option String) := do - let s ← h.getLine - if s.isEmpty then - return none - else - return some s - -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out - -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b[off]!).toUInt32 - let b1 := (b[off + 1]!).toUInt32 - let b2 := (b[off + 2]!).toUInt32 - let b3 := (b[off + 3]!).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b[off]!).toUInt64 - let b1 := (b[off + 1]!).toUInt64 - let b2 := (b[off + 2]!).toUInt64 - let b3 := (b[off + 3]!).toUInt64 - let b4 := (b[off + 4]!).toUInt64 - let b5 := (b[off + 5]!).toUInt64 - let b6 := (b[off + 6]!).toUInt64 - let b7 := (b[off + 7]!).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -private def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - let half : UInt32 := 0x80000000 - if u < half then - Int.ofNat u.toNat - else - let two32 : Int := Int.ofNat (Nat.pow 2 32) - (Int.ofNat u.toNat) - two32 - -private def floatFromLE (b : ByteArray) (off : Nat) : Float := - Float.ofBits (u64FromLE b off) - -private def readFloatArray (h : IO.FS.Handle) (count : Nat) : IO FloatArray := do - if count = 0 then - return FloatArray.empty - let bytes ← readExactly h (count * 8) - let data := Array.ofFn (fun i : Fin count => - floatFromLE bytes (i.val * 8)) - return .mk data - -private def readI32Array (h : IO.FS.Handle) (count : Nat) : IO (Array Nat) := do - if count = 0 then - return #[] - let bytes ← readExactly h (count * 4) - let mut out : Array Nat := Array.mkEmpty count - for i in [:count] do - let v := i32FromLE bytes (i * 4) - if 
v < 0 then - throw (IO.userError s!"Negative token id at index {i}") - out := out.push v.toNat - return out - -/-- Load a model from the `.nfpt` binary format (NFP_BINARY_V1). -/ -def loadBinary (h : IO.FS.Handle) : IO LoadResult := do - try - let some magicLine ← readLine? h - | return .error "Empty file" - let magic := magicLine.trim - if magic != "NFP_BINARY_V1" then - return .error "Invalid magic: expected NFP_BINARY_V1" - - IO.println "[1/5] Parsing header..." - - let mut numLayers : Nat := 0 - let mut numHeads : Nat := 0 - let mut modelDim : Nat := 0 - let mut headDim : Nat := 0 - let mut hiddenDim : Nat := 0 - let mut vocabSize : Nat := 0 - let mut seqLen : Nat := 0 - - let mut line? ← readLine? h - while true do - match line? with - | none => return .error "Unexpected EOF while reading header" - | some line => - let t := line.trim - if t = "BINARY_START" then - break - if t.startsWith "num_layers=" then - numLayers := (t.drop 11).toNat! - else if t.startsWith "num_heads=" then - numHeads := (t.drop 10).toNat! - else if t.startsWith "model_dim=" then - modelDim := (t.drop 10).toNat! - else if t.startsWith "head_dim=" then - headDim := (t.drop 9).toNat! - else if t.startsWith "hidden_dim=" then - hiddenDim := (t.drop 11).toNat! - else if t.startsWith "vocab_size=" then - vocabSize := (t.drop 11).toNat! - else if t.startsWith "seq_len=" then - seqLen := (t.drop 8).toNat! - line? ← readLine? h - - if modelDim = 0 || numLayers = 0 || numHeads = 0 then - return .error s!"Invalid header: modelDim={modelDim}, numLayers={numLayers}, numHeads={numHeads} (all must be > 0)" - if headDim = 0 || hiddenDim = 0 || vocabSize = 0 || seqLen = 0 then - return .error "Invalid header: headDim/hiddenDim/vocabSize/seqLen must be > 0" - - IO.println s!"[2/5] Loading input tokens + embeddings (seq_len={seqLen}, model_dim={modelDim})..." 
- - let inputTokens : Array Nat ← readI32Array h seqLen - let embFloats ← readFloatArray h (seqLen * modelDim) - let inputEmbeddings := buildMatrix seqLen modelDim embFloats.data - - IO.println s!"[3/5] Loading {numLayers} layers with {numHeads} heads each..." - - let mut layers : Array (Array ConcreteAttentionLayer) := #[] - let mut attnProjBias : Array ConcreteMatrix := #[] - let mut mlps : Array ConcreteMLPLayer := #[] - let mut ln1 : Array ConcreteLayerNormParams := #[] - let mut ln2 : Array ConcreteLayerNormParams := #[] - - for l in [:numLayers] do - IO.println s!" Loading layer {l}/{numLayers}..." - let mut layerHeads : Array ConcreteAttentionLayer := #[] - for _h in [:numHeads] do - let wq ← readFloatArray h (modelDim * headDim) - let bq ← readFloatArray h headDim - let wk ← readFloatArray h (modelDim * headDim) - let bk ← readFloatArray h headDim - let wv ← readFloatArray h (modelDim * headDim) - let bv ← readFloatArray h headDim - let wo ← readFloatArray h (headDim * modelDim) - let head := mkAttentionLayer modelDim headDim wq.data wk.data wv.data wo.data bq.data bk.data bv.data - layerHeads := layerHeads.push head - layers := layers.push layerHeads - - let bias ← readFloatArray h modelDim - attnProjBias := attnProjBias.push (buildMatrix 1 modelDim bias.data) - - let win ← readFloatArray h (modelDim * hiddenDim) - let bin ← readFloatArray h hiddenDim - let wout ← readFloatArray h (hiddenDim * modelDim) - let bout ← readFloatArray h modelDim - mlps := mlps.push (mkMLPLayer modelDim hiddenDim win.data wout.data bin.data bout.data) - - let ln1Gamma ← readFloatArray h modelDim - let ln1Beta ← readFloatArray h modelDim - ln1 := ln1.push { - gamma := buildMatrix 1 modelDim ln1Gamma.data - beta := buildMatrix 1 modelDim ln1Beta.data - } - - let ln2Gamma ← readFloatArray h modelDim - let ln2Beta ← readFloatArray h modelDim - ln2 := ln2.push { - gamma := buildMatrix 1 modelDim ln2Gamma.data - beta := buildMatrix 1 modelDim ln2Beta.data - } - - IO.println "[4/5] 
Loading final layernorm + unembedding..." - - let lnfGamma ← readFloatArray h modelDim - let lnfBeta ← readFloatArray h modelDim - let lnf := { - gamma := buildMatrix 1 modelDim lnfGamma.data - beta := buildMatrix 1 modelDim lnfBeta.data - } - - let unembFloats ← readFloatArray h (modelDim * vocabSize) - let unembedding := buildMatrix modelDim vocabSize unembFloats.data - - let model : ConcreteModel := { - numLayers := numLayers - layers := layers - attnProjBias := attnProjBias - mlps := mlps - ln1 := ln1 - ln2 := ln2 - lnf := lnf - seqLen := seqLen - inputTokens := some inputTokens - inputEmbeddings := inputEmbeddings - unembedding := some unembedding - } - - IO.println "[5/5] Model loaded successfully! -" - return .ok model - catch e => - return .error s!"Binary load failed: {e}" -/-! ## File IO Operations -/ - -/-- Load a model from a file path. Supports .nfpt (binary) format. -/ -def loadModel (path : System.FilePath) : IO LoadResult := do - if path.extension = some "nfpt" then - IO.FS.withFile path .read fun h => - loadBinary h - else - return .error s!"Unsupported file format: {path.extension.getD "unknown"}" - -/-! ## Tokenization Utilities -/ - -/-- Simple tokenizer with vocabulary mapping. -/ -structure Tokenizer where - /-- Token strings in order of ID -/ - tokens : Array String - /-- Unknown token ID -/ - unkId : Nat - /-- Padding token ID -/ - padId : Nat - /-- End of sequence token ID -/ - eosId : Nat - -namespace Tokenizer - -/-- Create a tokenizer from vocabulary list. -/ -def fromVocabList (tokens : Array String) - (unkId padId eosId : Nat := 0) : Tokenizer := - { tokens := tokens, unkId := unkId, padId := padId, eosId := eosId } - -/-- Find a token's ID in the vocabulary. -/ -def findToken (t : Tokenizer) (word : String) : Nat := - match t.tokens.findIdx? (· == word) with - | some idx => idx - | none => t.unkId - -/-- Tokenize a string using simple whitespace splitting. 
-/ -def tokenize (t : Tokenizer) (text : String) : Array Nat := Id.run do - let words := text.splitOn " " |>.filter (· ≠ "") - let mut ids : Array Nat := #[] - for word in words do - ids := ids.push (t.findToken word) - ids - -/-- Decode token IDs back to text. -/ -def decode (t : Tokenizer) (ids : Array Nat) : String := - let tokens := ids.filterMap fun id => - if id < t.tokens.size then some t.tokens[id]! - else none - " ".intercalate tokens.toList - -end Tokenizer - -/-- Look up embeddings for token IDs from the embedding matrix. -/ -def lookupEmbeddings (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (seqLen : Nat) (padId : Nat := 0) : ConcreteMatrix := Id.run do - let modelDim := embeddings.numCols - let mut data : Array Float := #[] - - for pos in [:seqLen] do - let tokenId := if pos < tokenIds.size then tokenIds[pos]! else padId - -- Copy embedding row for this token - for dim in [:modelDim] do - let val := embeddings.get tokenId dim - data := data.push val - - buildMatrix seqLen modelDim data - -/-- Set the input embeddings in a model for a given prompt (token IDs). -/ -def ConcreteModel.withInputTokens (model : ConcreteModel) - (embeddings : ConcreteMatrix) (tokenIds : Array Nat) - (padId : Nat := 0) : ConcreteModel := - let inputEmb := lookupEmbeddings embeddings tokenIds model.seqLen padId - { model with inputEmbeddings := inputEmb, inputTokens := some tokenIds } - -/-! ## Analysis Report Generation -/ - -/-- Format for circuit analysis results. 
-/ -structure AnalysisReport where - /-- Model name/path -/ - modelName : String - /-- Input prompt (if available) -/ - prompt : Option String - /-- Number of layers analyzed -/ - numLayers : Nat - /-- Total heads in model -/ - totalHeads : Nat - /-- Verified induction head candidates -/ - inductionHeads : Array CandidateInductionHead - /-- Deep circuit candidates with N-layer bounds -/ - deepCircuits : Array DeepCircuitCandidate - /-- Verification result (if run) -/ - verification : Option VerificationResult - -namespace AnalysisReport - -/-- Generate a human-readable report. -/ -def toString (r : AnalysisReport) : String := Id.run do - let mut s := s!"═══════════════════════════════════════════════════════════\n" - s := s ++ s!"NFP Circuit Analysis Report\n" - s := s ++ s!"Model: {r.modelName}\n" - match r.prompt with - | some p => s := s ++ s!"Prompt: \"{p}\"\n" - | none => pure () - s := s ++ s!"Layers: {r.numLayers}, Heads: {r.totalHeads}\n" - s := s ++ s!"═══════════════════════════════════════════════════════════\n\n" - - if r.inductionHeads.size > 0 then - s := s ++ s!"VERIFIED INDUCTION HEADS ({r.inductionHeads.size} found):\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - for head in r.inductionHeads do - s := s ++ s!" L{head.layer1Idx}H{head.head1Idx} → L{head.layer2Idx}H{head.head2Idx}\n" - s := s ++ s!" Combined Error: {head.combinedError}\n" - s := s ++ s!" Prev-Token Strength: {head.prevTokenStrength}\n" - s := s ++ s!" Induction Score: {head.inductionScore}\n" - s := s ++ s!" K-Composition: {head.kComp}\n" - s := s ++ s!" Faithfulness Ratios: ε₁={head.patternBound1}, ε₂={head.patternBound2}\n\n" - else - s := s ++ s!"No induction heads found above threshold.\n\n" - - if r.deepCircuits.size > 0 then - s := s ++ s!"DEEP CIRCUIT CANDIDATES ({r.deepCircuits.size} found):\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - for circuit in r.deepCircuits do - s := s ++ s!" 
{circuit.description}\n" - s := s ++ s!" Pattern Type: {circuit.patternType}\n" - s := s ++ s!" Simple Error Sum: {circuit.simpleErrorSum}\n" - s := s ++ s!" Amplified Error: {circuit.amplifiedError}\n" - s := s ++ s!" Amplification Factor: {circuit.amplificationFactor}\n\n" - - match r.verification with - | some v => - s := s ++ s!"EMPIRICAL VERIFICATION:\n" - s := s ++ s!"───────────────────────────────────────────────────────────\n" - let status := if v.verified then "✓ PASSED" else "✗ FAILED" - s := s ++ s!" Status: {status}\n" - s := s ++ s!" Empirical Error: {v.ablation.empiricalError}\n" - s := s ++ s!" Theoretical Bound: {v.theoreticalBound}\n" - s := s ++ s!" Tightness: {v.tightness * 100.0}%\n" - s := s ++ s!" Circuit Size: {v.ablation.circuitSize}/{v.ablation.totalComponents}\n\n" - | none => pure () - - s := s ++ s!"═══════════════════════════════════════════════════════════\n" - s - -instance : ToString AnalysisReport := ⟨AnalysisReport.toString⟩ - -end AnalysisReport - -/-- Run full analysis on a model and generate a report. -/ -def analyzeModel (model : ConcreteModel) (modelName : String) - (threshold : Float := 0.1) - (prompt : Option String := none) : IO AnalysisReport := do - IO.println "\n═══════════════════════════════════════════════════════════" - IO.println "Starting Circuit Analysis" - IO.println s!"Model: {modelName}" - IO.println s!"Threshold: {threshold}" - IO.println "═══════════════════════════════════════════════════════════\n" - - IO.println "[1/2] Building precomputed cache..." - let cache := PrecomputedCache.build model - - IO.println "[2/2] Searching for deep circuit candidates (shared scan)..." - -- Find deep circuit candidates (reuse cache) - let deepCircuits := findDeepCircuitCandidatesFromCache cache - let verifiedDeep := deepCircuits.filter (·.amplifiedError ≤ threshold) - IO.println s!" 
Found {verifiedDeep.size} verified deep circuits \ - (of {deepCircuits.size} candidates)" - - -- Derive induction-head candidates from the same scan to avoid repeating - -- the expensive `checkInductionPattern` computation. - let inductionHeads := - (deepCircuits.filterMap (·.toInductionCandidate? cache)).qsort - (·.combinedError < ·.combinedError) - let verifiedHeads := inductionHeads.filter (·.combinedError ≤ threshold) - IO.println s!" Found {verifiedHeads.size} verified induction heads \ - (of {inductionHeads.size} candidates)\n" - - IO.println "Analysis complete!\n" - - -- Count total heads - let totalHeads := model.layers.foldl (fun acc layer => acc + layer.size) 0 - - return { - modelName := modelName - prompt := prompt - numLayers := model.numLayers - totalHeads := totalHeads - inductionHeads := verifiedHeads - deepCircuits := verifiedDeep - verification := none - } - -/-- Run analysis with empirical verification. -/ -def analyzeAndVerify (model : ConcreteModel) (modelName : String) - (threshold : Float := 0.1) - (prompt : Option String := none) : IO AnalysisReport := do - let baseReport ← analyzeModel model modelName threshold prompt - - IO.println "═══════════════════════════════════════════════════════════" - IO.println "Starting Empirical Verification" - IO.println "═══════════════════════════════════════════════════════════\n" - - IO.println "Running circuit discovery and ablation experiments..." - -- Run circuit discovery and verification - let (_, verification) := discoverAndVerify model threshold - IO.println "Verification complete!\n" - - return { baseReport with verification := some verification } - -end Nfp diff --git a/Nfp/IO/Checks.lean b/Nfp/IO/Checks.lean new file mode 100644 index 0000000..edae3a7 --- /dev/null +++ b/Nfp/IO/Checks.lean @@ -0,0 +1,50 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.Circuit.Cert.ValueRange + +/-! +IO checks for certificates. 
+-/ + +public section + +namespace Nfp + +namespace IO + +open Nfp.Circuit + +/-- Check a softmax-margin certificate for a positive sequence length. -/ +def checkSoftmaxMargin (seq : Nat) (cert : SoftmaxMarginCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkSoftmaxMarginCert cert + if ok then + return Except.ok () + else + return Except.error "softmax-margin certificate rejected" + +/-- Check a value-range certificate for a positive sequence length. -/ +def checkValueRange (seq : Nat) (cert : ValueRangeCert seq) : + IO (Except String Unit) := + match seq with + | 0 => return Except.error "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkValueRangeCert cert + if ok then + return Except.ok () + else + return Except.error "value-range certificate rejected" + +end IO + +end Nfp diff --git a/Nfp/IO/InductionHead.lean b/Nfp/IO/InductionHead.lean new file mode 100644 index 0000000..29771dc --- /dev/null +++ b/Nfp/IO/InductionHead.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.InductionHead.Cert + +/-! +IO helpers for induction-head certificate checking. +-/ diff --git a/Nfp/IO/InductionHead/Cert.lean b/Nfp/IO/InductionHead/Cert.lean new file mode 100644 index 0000000..283cf3e --- /dev/null +++ b/Nfp/IO/InductionHead/Cert.lean @@ -0,0 +1,472 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Insert +public import Nfp.Circuit.Cert.InductionHead +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.IO.InductionHead.Tokens +public import Nfp.IO.Parse.Basic +public import Nfp.IO.Util +public import Nfp.Model.InductionPrompt + +/-! +Untrusted parsing and checking for explicit induction-head certificates. 
+ +All sequence indices in the certificate payload are 1-based (literature convention) and +are converted to `Fin` indices internally. +-/ + +public section + +namespace Nfp + +namespace IO + +open Nfp.Circuit +open Nfp.IO.Parse + +namespace InductionHeadCert + +/-- State for parsing induction-head certificates. -/ +structure ParseState (seq : Nat) where + /-- Optional epsilon bound. -/ + eps : Option Rat + /-- Optional margin bound. -/ + margin : Option Rat + /-- Active query set. -/ + active : Finset (Fin seq) + /-- Whether any active entries were parsed. -/ + activeSeen : Bool + /-- Optional predecessor pointer per query. -/ + prev : Array (Option (Fin seq)) + /-- Optional score matrix entries. -/ + scores : Array (Array (Option Rat)) + /-- Optional weight matrix entries. -/ + weights : Array (Array (Option Rat)) + /-- Optional per-query epsilon bounds. -/ + epsAt : Array (Option Rat) + /-- Optional per-key weight bounds. -/ + weightBoundAt : Array (Array (Option Rat)) + /-- Optional lower bound for values. -/ + lo : Option Rat + /-- Optional upper bound for values. -/ + hi : Option Rat + /-- Optional per-key lower bounds. -/ + valsLo : Array (Option Rat) + /-- Optional per-key upper bounds. -/ + valsHi : Array (Option Rat) + /-- Optional per-key exact values. -/ + vals : Array (Option Rat) + /-- Optional direction target index. -/ + directionTarget : Option Nat + /-- Optional direction negative index. -/ + directionNegative : Option Nat + +/-- Initialize a parse state. 
-/ +def initState (seq : Nat) : ParseState seq := + let row : Array (Option Rat) := Array.replicate seq none + { eps := none + margin := none + active := ∅ + activeSeen := false + prev := Array.replicate seq none + scores := Array.replicate seq row + weights := Array.replicate seq row + epsAt := Array.replicate seq none + weightBoundAt := Array.replicate seq row + lo := none + hi := none + valsLo := Array.replicate seq none + valsHi := Array.replicate seq none + vals := Array.replicate seq none + directionTarget := none + directionNegative := none } + +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + +private def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : + Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" + else + return { st with active := insert qFin st.active, activeSeen := true } + +private def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : + Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + match st.prev[qFin.1]! with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set! qFin.1 (some kFin) + return { st with prev := prev' } + +private def setVecEntry {seq : Nat} (arr : Array (Option Rat)) (idx : Nat) (v : Rat) : + Except String (Array (Option Rat)) := do + let kFin ← toIndex1 (seq := seq) "k" idx + match arr[kFin.1]! with + | some _ => + throw s!"duplicate entry for k={idx}" + | none => + return arr.set! 
kFin.1 (some v) + +private def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) + (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + let row := mat[qFin.1]! + match row[kFin.1]! with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set! kFin.1 (some v) + return mat.set! qFin.1 row' + +/-- Parse a tokenized line into the parse state. -/ +def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : + Except String (ParseState seq) := do + match tokens with + | ["eps", val] => + if st.eps.isSome then + throw "duplicate eps entry" + else + return { st with eps := some (← parseRat val) } + | ["margin", val] => + if st.margin.isSome then + throw "duplicate margin entry" + else + return { st with margin := some (← parseRat val) } + | ["active", q] => + setActive st (← parseNat q) + | ["prev", q, k] => + setPrev st (← parseNat q) (← parseNat k) + | ["score", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.scores (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with scores := mat } + | ["weight", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.weights (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with weights := mat } + | ["eps-at", q, val] => + let arr ← setVecEntry (seq := seq) st.epsAt (← parseNat q) (← parseRat val) + return { st with epsAt := arr } + | ["weight-bound", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.weightBoundAt (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with weightBoundAt := mat } + | ["lo", val] => + if st.lo.isSome then + throw "duplicate lo entry" + else + return { st with lo := some (← parseRat val) } + | ["hi", val] => + if st.hi.isSome then + throw "duplicate hi entry" + else + return { st with hi := some (← parseRat val) } + | ["val", k, val] => + let arr ← setVecEntry (seq := seq) st.vals (← 
parseNat k) (← parseRat val) + return { st with vals := arr } + | ["val-lo", k, val] => + let arr ← setVecEntry (seq := seq) st.valsLo (← parseNat k) (← parseRat val) + return { st with valsLo := arr } + | ["val-hi", k, val] => + let arr ← setVecEntry (seq := seq) st.valsHi (← parseNat k) (← parseRat val) + return { st with valsHi := arr } + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? 
with + | some v => pure v + | none => throw "missing seq entry" + +private def finalizeState {seq : Nat} (hpos : 0 < seq) (st : ParseState seq) : + Except String (Circuit.InductionHeadCert seq) := do + let eps ← + match st.eps with + | some v => pure v + | none => throw "missing eps entry" + let margin ← + match st.margin with + | some v => pure v + | none => throw "missing margin entry" + let lo ← + match st.lo with + | some v => pure v + | none => throw "missing lo entry" + let hi ← + match st.hi with + | some v => pure v + | none => throw "missing hi entry" + if !st.prev.all Option.isSome then + throw "missing prev entries" + if !st.scores.all (fun row => row.all Option.isSome) then + throw "missing score entries" + if !st.weights.all (fun row => row.all Option.isSome) then + throw "missing weight entries" + if !st.epsAt.all Option.isSome then + throw "missing eps-at entries" + if !st.weightBoundAt.all (fun row => row.all Option.isSome) then + throw "missing weight-bound entries" + if !st.valsLo.all Option.isSome then + throw "missing val-lo entries" + if !st.valsHi.all Option.isSome then + throw "missing val-hi entries" + if !st.vals.all Option.isSome then + throw "missing val entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev[q.1]!).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.scores[q.1]! + (row[k.1]!).getD 0 + let weightsFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.weights[q.1]! + (row[k.1]!).getD 0 + let epsAtFun : Fin seq → Rat := fun q => + (st.epsAt[q.1]!).getD 0 + let weightBoundAtFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.weightBoundAt[q.1]! 
+ (row[k.1]!).getD 0 + let valsLoFun : Fin seq → Rat := fun k => + (st.valsLo[k.1]!).getD 0 + let valsHiFun : Fin seq → Rat := fun k => + (st.valsHi[k.1]!).getD 0 + let valsFun : Fin seq → Rat := fun k => + (st.vals[k.1]!).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + let values : Circuit.ValueIntervalCert seq := + { lo := lo + hi := hi + valsLo := valsLoFun + valsHi := valsHiFun + vals := valsFun + direction := direction } + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { eps := eps + epsAt := epsAtFun + weightBoundAt := weightBoundAtFun + margin := margin + active := active + prev := prevFun + scores := scoresFun + weights := weightsFun + values := values } + +end InductionHeadCert + +/-- Parse an explicit induction-head certificate from a text payload. -/ +def parseInductionHeadCert (input : String) : + Except String (Sigma Circuit.InductionHeadCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← InductionHeadCert.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : InductionHeadCert.ParseState seq := InductionHeadCert.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => InductionHeadCert.parseLine st t) st0 + let cert ← InductionHeadCert.finalizeState hpos st + return ⟨seq, cert⟩ + +/-- Load an induction-head certificate from disk. 
-/ +def loadInductionHeadCert (path : System.FilePath) : + IO (Except String (Sigma Circuit.InductionHeadCert)) := do + let data ← IO.FS.readFile path + return parseInductionHeadCert data + +/-- Parse a token list payload. -/ +def parseInductionHeadTokens (input : String) : + Except String (Sigma fun seq => Fin seq → Nat) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← InductionHeadTokens.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : InductionHeadTokens.ParseState seq := InductionHeadTokens.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => InductionHeadTokens.parseLine st t) st0 + let tokensFun ← InductionHeadTokens.finalizeState st + return ⟨seq, tokensFun⟩ + +/-- Load a token list from disk. -/ +def loadInductionHeadTokens (path : System.FilePath) : + IO (Except String (Sigma fun seq => Fin seq → Nat)) := do + let data ← IO.FS.readFile path + return parseInductionHeadTokens data + +private def ratToString (x : Rat) : String := + toString x + +/-- Check an explicit induction-head certificate from disk. -/ +def runInductionHeadCertCheck (certPath : System.FilePath) + (minActive? : Option Nat) (minLogitDiffStr? : Option String) + (minMarginStr? : Option String) (maxEpsStr? : Option String) + (tokensPath? : Option String) : IO UInt32 := do + let minLogitDiff?E := parseRatOpt "min-logit-diff" minLogitDiffStr? + let minMargin?E := parseRatOpt "min-margin" minMarginStr? + let maxEps?E := parseRatOpt "max-eps" maxEpsStr? + match minLogitDiff?E, minMargin?E, maxEps?E with + | Except.error msg, _, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, Except.error msg, _ => + IO.eprintln s!"error: {msg}" + return 2 + | _, _, Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok minLogitDiff?, Except.ok minMargin?, Except.ok maxEps? 
=> + let minMargin := minMargin?.getD (0 : Rat) + let maxEps := maxEps?.getD (ratRoundDown (Rat.divInt 1 2)) + let parsed ← loadInductionHeadCert certPath + match parsed with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 1 + | Except.ok ⟨seq, cert⟩ => + match seq with + | 0 => + IO.eprintln "error: seq must be positive" + return 2 + | Nat.succ n => + let seq := Nat.succ n + let _ : NeZero seq := ⟨by simp⟩ + let ok := Circuit.checkInductionHeadCert cert + if !ok then + IO.eprintln "error: induction-head certificate rejected" + return 2 + let activeCount := cert.active.card + let defaultMinActive := max 1 (seq / 8) + let minActive := minActive?.getD defaultMinActive + if activeCount < minActive then + IO.eprintln + s!"error: active queries {activeCount} below minimum {minActive}" + return 2 + if cert.margin < minMargin then + IO.eprintln + s!"error: margin {ratToString cert.margin} \ + below minimum {ratToString minMargin}" + return 2 + if maxEps < cert.eps then + IO.eprintln + s!"error: eps {ratToString cert.eps} \ + above maximum {ratToString maxEps}" + return 2 + if let some tokensPath := tokensPath? 
then + let tokensParsed ← loadInductionHeadTokens tokensPath + match tokensParsed with + | Except.error msg => + IO.eprintln s!"error: {msg}" + return 2 + | Except.ok ⟨seqTokens, tokens⟩ => + if hseq : seqTokens = seq then + let tokens' : Fin seq → Nat := by + simpa [hseq] using tokens + let activeTokens := Model.activeOfTokens (seq := seq) tokens' + if !decide (cert.active ⊆ activeTokens) then + IO.eprintln "error: active set not contained in token repeats" + return 2 + let prevTokens := Model.prevOfTokens (seq := seq) tokens' + let prevOk := + (List.finRange seq).all (fun q => + if decide (q ∈ cert.active) then + decide (prevTokens q = cert.prev q) + else + true) + if !prevOk then + IO.eprintln "error: prev map does not match tokens on active queries" + return 2 + else + IO.eprintln + s!"error: tokens seq {seqTokens} does not match cert seq {seq}" + return 2 + let effectiveMinLogitDiff := + match minLogitDiff?, cert.values.direction with + | some v, _ => some v + | none, some _ => some (0 : Rat) + | none, none => none + match effectiveMinLogitDiff with + | none => + IO.println + s!"ok: induction head certificate checked \ + (seq={seq}, active={activeCount}, \ + margin={ratToString cert.margin}, eps={ratToString cert.eps})" + return 0 + | some minLogitDiff => + let logitDiffLB? := + Circuit.logitDiffLowerBoundAt cert.active cert.prev cert.epsAt + cert.values.lo cert.values.hi cert.values.vals + match logitDiffLB? 
with + | none => + IO.eprintln "error: empty active set for logit-diff bound" + return 2 + | some logitDiffLB => + if logitDiffLB < minLogitDiff then + IO.eprintln + s!"error: logitDiffLB {ratToString logitDiffLB} \ + below minimum {ratToString minLogitDiff}" + return 2 + else + IO.println + s!"ok: induction head certificate checked \ + (seq={seq}, active={activeCount}, \ + margin={ratToString cert.margin}, eps={ratToString cert.eps}, \ + logitDiffLB={ratToString logitDiffLB})" + return 0 + +end IO + +end Nfp diff --git a/Nfp/IO/InductionHead/Tokens.lean b/Nfp/IO/InductionHead/Tokens.lean new file mode 100644 index 0000000..827e94d --- /dev/null +++ b/Nfp/IO/InductionHead/Tokens.lean @@ -0,0 +1,87 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Fintype.Basic +public import Nfp.IO.Parse.Basic + +/-! +Untrusted parsing helpers for optional induction-head token lists. +-/ + +public section + +namespace Nfp + +namespace IO + +open Nfp.IO.Parse + +namespace InductionHeadTokens + +/-- State for parsing token lists. -/ +structure ParseState (seq : Nat) where + /-- Optional per-position tokens. -/ + tokens : Array (Option Nat) + +/-- Initialize a token parse state. -/ +def initState (seq : Nat) : ParseState seq := + { tokens := Array.replicate seq none } + +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + +private def setToken {seq : Nat} (st : ParseState seq) (q tok : Nat) : + Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + match st.tokens[qFin.1]! with + | some _ => + throw s!"duplicate token entry for q={q}" + | none => + let tokens' := st.tokens.set! qFin.1 (some tok) + return { st with tokens := tokens' } + +/-- Parse a tokenized line into the token parse state. 
-/ +def parseLine {seq : Nat} (st : ParseState seq) (tokens : List String) : + Except String (ParseState seq) := do + match tokens with + | ["token", q, tok] => + setToken st (← parseNat q) (← parseNat tok) + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +/-- Finalize a token parse state into a total token map. -/ +def finalizeState {seq : Nat} (st : ParseState seq) : + Except String (Fin seq → Nat) := do + if !st.tokens.all Option.isSome then + throw "missing token entries" + let tokensFun : Fin seq → Nat := fun q => + (st.tokens[q.1]!).getD 0 + pure tokensFun + +end InductionHeadTokens + +end IO + +end Nfp diff --git a/Nfp/IO/Parse.lean b/Nfp/IO/Parse.lean new file mode 100644 index 0000000..7d4b73c --- /dev/null +++ b/Nfp/IO/Parse.lean @@ -0,0 +1,11 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.Parse.Basic +public import Nfp.IO.Parse.SoftmaxMargin +public import Nfp.IO.Parse.ValueRange + +/-! +Aggregator for pure CLI parsing helpers. +-/ diff --git a/Nfp/IO/Parse/Basic.lean b/Nfp/IO/Parse/Basic.lean new file mode 100644 index 0000000..0dfd2c6 --- /dev/null +++ b/Nfp/IO/Parse/Basic.lean @@ -0,0 +1,79 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic + +/-! +Shared parsing helpers for CLI inputs. +-/ + +public section + +namespace Nfp + +namespace IO + +namespace Parse + +/-- Split a line into whitespace-separated tokens. 
-/ +def splitWords (line : String) : List String := + line.splitToList (fun c => c = ' ' || c = '\t') |>.filter (· ≠ "") + +/-- Drop empty/comment lines and return their whitespace tokens. -/ +def cleanTokens (line : String) : Option (List String) := + let trimmed := line.trim + if trimmed.isEmpty then + none + else if trimmed.startsWith "#" then + none + else + some (splitWords trimmed) + +/-- Parse a nonnegative decimal integer. -/ +def parseNat (s : String) : Except String Nat := do + if s.isEmpty then + throw s!"expected Nat, got '{s}'" + else + let mut acc : Nat := 0 + for c in s.toList do + if c.isDigit then + acc := acc * 10 + c.toNat - '0'.toNat + else + throw s!"expected Nat, got '{s}'" + return acc + +/-- Parse a signed decimal integer. -/ +def parseInt (s : String) : Except String Int := do + if s.isEmpty then + throw s!"expected Int, got '{s}'" + else + match s.toSlice.front? with + | some '-' => + let rest := s.drop 1 + let n ← parseNat rest + return -Int.ofNat n + | _ => + let n ← parseNat s + return Int.ofNat n + +/-- Parse a rational literal from `a` or `a/b`, rounding down if needed. -/ +def parseRat (s : String) : Except String Rat := do + match s.splitOn "/" with + | [num] => + return ratRoundDown (Rat.ofInt (← parseInt num)) + | [num, den] => + let n ← parseInt num + let d ← parseNat den + if d = 0 then + throw s!"invalid rational '{s}': zero denominator" + else + return ratRoundDown (Rat.divInt n (Int.ofNat d)) + | _ => + throw s!"invalid rational '{s}'" + +end Parse + +end IO + +end Nfp diff --git a/Nfp/IO/Parse/SoftmaxMargin.lean b/Nfp/IO/Parse/SoftmaxMargin.lean new file mode 100644 index 0000000..0d504e5 --- /dev/null +++ b/Nfp/IO/Parse/SoftmaxMargin.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.Parse.SoftmaxMargin.Cert + +/-! +Aggregator for softmax-margin parsing helpers. 
+-/ diff --git a/Nfp/IO/Parse/SoftmaxMargin/Cert.lean b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean new file mode 100644 index 0000000..014c810 --- /dev/null +++ b/Nfp/IO/Parse/SoftmaxMargin/Cert.lean @@ -0,0 +1,83 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Cert.SoftmaxMargin +public import Nfp.IO.Parse.SoftmaxMargin.Shared + +/-! +Parsing helpers for softmax-margin certificates. +-/ + +public section + +namespace Nfp + +namespace IO + +namespace Parse + +open Nfp.Circuit + +private def finalizeState {seq : Nat} (hpos : 0 < seq) + (st : SoftmaxMargin.ParseState seq) : Except String (SoftmaxMarginCert seq) := do + let eps ← + match st.eps with + | some v => pure v + | none => throw "missing eps entry" + let margin ← + match st.margin with + | some v => pure v + | none => throw "missing margin entry" + if !st.prev.all Option.isSome then + throw "missing prev entries" + if !st.scores.all (fun row => row.all Option.isSome) then + throw "missing score entries" + if !st.weights.all (fun row => row.all Option.isSome) then + throw "missing weight entries" + let defaultPrev : Fin seq := ⟨0, hpos⟩ + let prevFun : Fin seq → Fin seq := fun q => + (st.prev[q.1]!).getD defaultPrev + let scoresFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.scores[q.1]! + (row[k.1]!).getD 0 + let weightsFun : Fin seq → Fin seq → Rat := fun q k => + let row := st.weights[q.1]! + (row[k.1]!).getD 0 + let active := + if st.activeSeen then + st.active + else + (Finset.univ : Finset (Fin seq)).erase defaultPrev + pure + { eps := eps + margin := margin + active := active + prev := prevFun + scores := scoresFun + weights := weightsFun } + +/-- Parse a softmax-margin certificate from a text payload. 
-/ +def parseSoftmaxMarginCert (input : String) : + Except String (Sigma SoftmaxMarginCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← SoftmaxMargin.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let hpos : 0 < seq := Nat.succ_pos n + let st0 : SoftmaxMargin.ParseState seq := SoftmaxMargin.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => SoftmaxMargin.parseLine st t) st0 + let cert ← finalizeState hpos st + return ⟨seq, cert⟩ + +end Parse + +end IO + +end Nfp diff --git a/Nfp/IO/Parse/SoftmaxMargin/Shared.lean b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean new file mode 100644 index 0000000..2ddfb57 --- /dev/null +++ b/Nfp/IO/Parse/SoftmaxMargin/Shared.lean @@ -0,0 +1,145 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Insert +public import Nfp.IO.Parse.Basic + +/-! +Shared parsing helpers for softmax-margin payloads. + +All sequence indices in the payload are 1-based (literature convention) and are converted to +`Fin` indices internally. +-/ + +public section + +namespace Nfp + +namespace IO + +namespace Parse + +namespace SoftmaxMargin + +/-- State for parsing softmax-margin payloads. -/ +structure ParseState (seq : Nat) where + /-- Optional epsilon bound. -/ + eps : Option Rat + /-- Optional margin bound. -/ + margin : Option Rat + /-- Active query set. -/ + active : Finset (Fin seq) + /-- Whether any active entries were parsed. -/ + activeSeen : Bool + /-- Optional predecessor pointer per query. -/ + prev : Array (Option (Fin seq)) + /-- Optional score matrix entries. -/ + scores : Array (Array (Option Rat)) + /-- Optional weight matrix entries. -/ + weights : Array (Array (Option Rat)) + +/-- Initialize a softmax-margin parse state. 
-/ +def initState (seq : Nat) : ParseState seq := + let row : Array (Option Rat) := Array.replicate seq none + { eps := none + margin := none + active := ∅ + activeSeen := false + prev := Array.replicate seq none + scores := Array.replicate seq row + weights := Array.replicate seq row } + +private def toIndex1 {seq : Nat} (label : String) (idx : Nat) : Except String (Fin seq) := do + if idx = 0 then + throw s!"{label} index must be >= 1" + let idx' := idx - 1 + if h : idx' < seq then + return ⟨idx', h⟩ + else + throw s!"{label} index out of range: {idx}" + +/-- Set a predecessor entry from `(q, k)` tokens. -/ +def setPrev {seq : Nat} (st : ParseState seq) (q k : Nat) : Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + match st.prev[qFin.1]! with + | some _ => + throw s!"duplicate prev entry for q={q}" + | none => + let prev' := st.prev.set! qFin.1 (some kFin) + return { st with prev := prev' } + +/-- Mark an active query index. -/ +def setActive {seq : Nat} (st : ParseState seq) (q : Nat) : Except String (ParseState seq) := do + let qFin ← toIndex1 (seq := seq) "q" q + if qFin ∈ st.active then + throw s!"duplicate active entry for q={q}" + else + return { st with active := insert qFin st.active, activeSeen := true } + +/-- Insert a matrix entry for scores/weights. -/ +def setMatrixEntry {seq : Nat} (mat : Array (Array (Option Rat))) + (q k : Nat) (v : Rat) : Except String (Array (Array (Option Rat))) := do + let qFin ← toIndex1 (seq := seq) "q" q + let kFin ← toIndex1 (seq := seq) "k" k + let row := mat[qFin.1]! + match row[kFin.1]! with + | some _ => + throw s!"duplicate matrix entry at ({q}, {k})" + | none => + let row' := row.set! kFin.1 (some v) + let mat' := mat.set! qFin.1 row' + return mat' + +/-- Parse a tokenized line into the softmax-margin parse state. 
-/ +def parseLine {seq : Nat} (st : ParseState seq) + (tokens : List String) : Except String (ParseState seq) := do + match tokens with + | ["eps", val] => + if st.eps.isSome then + throw "duplicate eps entry" + else + return { st with eps := some (← parseRat val) } + | ["margin", val] => + if st.margin.isSome then + throw "duplicate margin entry" + else + return { st with margin := some (← parseRat val) } + | ["active", q] => + setActive st (← parseNat q) + | ["prev", q, k] => + setPrev st (← parseNat q) (← parseNat k) + | ["score", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.scores (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with scores := mat } + | ["weight", q, k, val] => + let mat ← setMatrixEntry (seq := seq) st.weights (← parseNat q) (← parseNat k) + (← parseRat val) + return { st with weights := mat } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? with + | some v => pure v + | none => throw "missing seq entry" + +end SoftmaxMargin + +end Parse + +end IO + +end Nfp diff --git a/Nfp/IO/Parse/ValueRange.lean b/Nfp/IO/Parse/ValueRange.lean new file mode 100644 index 0000000..6d706a1 --- /dev/null +++ b/Nfp/IO/Parse/ValueRange.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.Parse.ValueRange.Cert + +/-! +Aggregator for value-range parsing helpers. 
+-/ diff --git a/Nfp/IO/Parse/ValueRange/Cert.lean b/Nfp/IO/Parse/ValueRange/Cert.lean new file mode 100644 index 0000000..3e7ee8d --- /dev/null +++ b/Nfp/IO/Parse/ValueRange/Cert.lean @@ -0,0 +1,67 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.IO.Parse.ValueRange.Shared + +/-! +Parsing helpers for value-range certificates. +-/ + +public section + +namespace Nfp + +namespace IO + +namespace Parse + +open Nfp.Circuit + +private def finalizeValueState {seq : Nat} (st : ValueRange.ParseState seq) : + Except String (ValueRangeCert seq) := do + let lo ← + match st.lo with + | some v => pure v + | none => throw "missing lo entry" + let hi ← + match st.hi with + | some v => pure v + | none => throw "missing hi entry" + if !finsetAll (Finset.univ : Finset (Fin seq)) (fun k => (st.vals k).isSome) then + throw "missing value entries" + let valsFun : Fin seq → Rat := fun k => + (st.vals k).getD 0 + let direction ← + match st.directionTarget, st.directionNegative with + | none, none => pure none + | some target, some negative => + pure (some { target := target, negative := negative }) + | _, _ => + throw "direction metadata requires both direction-target and direction-negative" + return { lo := lo, hi := hi, vals := valsFun, direction := direction } + +/-- Parse a value-range certificate from a text payload. 
-/ +def parseValueRangeCert (input : String) : + Except String (Sigma ValueRangeCert) := do + let lines := input.splitOn "\n" + let tokens := lines.filterMap cleanTokens + let seq ← ValueRange.parseSeq tokens + match seq with + | 0 => throw "seq must be positive" + | Nat.succ n => + let seq := Nat.succ n + let st0 : ValueRange.ParseState seq := ValueRange.initState seq + let st ← tokens.foldlM (fun st t => + match t with + | ["seq", _] => pure st + | _ => ValueRange.parseLine st t) st0 + let cert ← finalizeValueState st + return ⟨seq, cert⟩ + +end Parse + +end IO + +end Nfp diff --git a/Nfp/IO/Parse/ValueRange/Shared.lean b/Nfp/IO/Parse/ValueRange/Shared.lean new file mode 100644 index 0000000..df180e8 --- /dev/null +++ b/Nfp/IO/Parse/ValueRange/Shared.lean @@ -0,0 +1,123 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Circuit.Cert.ValueRange +public import Nfp.IO.Parse.Basic + +/-! +Shared parsing helpers for value-range payloads. + +All sequence indices in the payload are 1-based (literature convention) and are converted to +`Fin` indices internally. +-/ + +public section + +namespace Nfp + +namespace IO + +namespace Parse + +namespace ValueRange + +open Nfp.Circuit + +/-- State for parsing value-range payloads. -/ +structure ParseState (seq : Nat) where + /-- Optional lower bound. -/ + lo : Option Rat + /-- Optional upper bound. -/ + hi : Option Rat + /-- Optional per-position values. -/ + vals : Fin seq → Option Rat + /-- Optional direction target index. -/ + directionTarget : Option Nat + /-- Optional direction negative index. -/ + directionNegative : Option Nat + + +/-- Initialize a value-range parse state. -/ +def initState (seq : Nat) : ParseState seq := + { lo := none + hi := none + vals := fun _ => none + directionTarget := none + directionNegative := none } + + +/-- Set a value entry from `(k, v)` tokens. 
-/ +def setVal {seq : Nat} (st : ParseState seq) (k : Nat) (v : Rat) : + Except String (ParseState seq) := do + if k = 0 then + throw "value index must be >= 1" + let k' := k - 1 + if hk : k' < seq then + let kFin : Fin seq := ⟨k', hk⟩ + match st.vals kFin with + | some _ => + throw s!"duplicate value entry for k={k}" + | none => + let vals' : Fin seq → Option Rat := fun k'' => + if k'' = kFin then + some v + else + st.vals k'' + return { st with vals := vals' } + else + throw s!"value index out of range: k={k}" + + +/-- Parse a tokenized line into the value-range parse state. -/ +def parseLine {seq : Nat} (st : ParseState seq) + (tokens : List String) : Except String (ParseState seq) := do + match tokens with + | ["lo", val] => + if st.lo.isSome then + throw "duplicate lo entry" + else + return { st with lo := some (← parseRat val) } + | ["hi", val] => + if st.hi.isSome then + throw "duplicate hi entry" + else + return { st with hi := some (← parseRat val) } + | ["val", k, val] => + setVal st (← parseNat k) (← parseRat val) + | ["direction-target", tok] => + if st.directionTarget.isSome then + throw "duplicate direction-target entry" + else + return { st with directionTarget := some (← parseNat tok) } + | ["direction-negative", tok] => + if st.directionNegative.isSome then + throw "duplicate direction-negative entry" + else + return { st with directionNegative := some (← parseNat tok) } + | _ => + throw s!"unrecognized line: '{String.intercalate " " tokens}'" + + +/-- Extract the `seq` header from tokenized lines. -/ +def parseSeq (tokens : List (List String)) : Except String Nat := do + let mut seq? : Option Nat := none + for t in tokens do + match t with + | ["seq", n] => + if seq?.isSome then + throw "duplicate seq entry" + else + seq? := some (← parseNat n) + | _ => pure () + match seq? 
with + | some v => pure v + | none => throw "missing seq entry" + +end ValueRange + +end Parse + +end IO + +end Nfp diff --git a/Nfp/IO/Util.lean b/Nfp/IO/Util.lean new file mode 100644 index 0000000..00eaa42 --- /dev/null +++ b/Nfp/IO/Util.lean @@ -0,0 +1,33 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.IO.Parse + +/-! +Small shared helpers for IO parsing. +-/ + +public section + +namespace Nfp + +namespace IO + +/-- Parse an optional rational literal for CLI flags (rounded down if needed). -/ +def parseRatOpt (label : String) (raw? : Option String) : + Except String (Option Rat) := + match raw? with + | none => Except.ok none + | some raw => + match Parse.parseRat raw with + | Except.ok v => Except.ok (some v) + | Except.error msg => Except.error s!"invalid {label}: {msg}" + +/-- Emit a deprecation warning on stderr. -/ +def warnDeprecated (msg : String) : IO Unit := do + IO.eprintln s!"warning: DEPRECATED: {msg}" + +end IO + +end Nfp diff --git a/Nfp/Induction.lean b/Nfp/Induction.lean deleted file mode 100644 index 38956fc..0000000 --- a/Nfp/Induction.lean +++ /dev/null @@ -1,468 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Analysis.InnerProductSpace.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.Order.BigOperators.Ring.Finset -import Nfp.Linearization -import Nfp.SignedMixer -import Nfp.Sound.Bounds - -/-! -# True Induction Head Formalization - -A **True Induction Head** is a rigorously certified mechanism that combines three components: - -1. **Structure**: The attention patterns match an induction head (previous-token + induction), - with the previous-token leg modeled by a self-attention placeholder due to abstract indexing. -2. **Faithfulness**: The virtual head approximation (attention rollout) is ε-certified -3. 
**Function**: The mechanism effectively increases logit scores for the correct token by ≥ δ - -This module formalizes the definition and proves that true induction heads provide -verifiable guarantees about model behavior. - -## Key Insight - -Most interpretability claims are heuristic. A true induction head is different: it combines -pattern detection with causal certification and functional verification, proving that: - - The discovered mechanism is mathematically sound - - The simplification (attention rollout) is approximately correct - - The mechanism actually causes the predicted output - -Together, these provide end-to-end certification of model behavior. --/ - -namespace Nfp - -open SignedMixer AttentionLinearization - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-! ## True Induction Head Definition -/ - -/-- Compute L² norm of a function over a finite type. -/ -noncomputable def l2_norm (v : (n × d) → ℝ) : ℝ := - Real.sqrt (∑ pos : n × d, (v pos) ^ 2) - -/-- Inner product of two functions over a finite type. -/ -noncomputable def inner_product (u v : (n × d) → ℝ) : ℝ := - ∑ pos : n × d, u pos * v pos - -/-- **True Induction Head**: A rigorously certified induction mechanism for a specific input. - -An induction head is "true" if it simultaneously satisfies three conditions: - -1. **Structural Pattern**: The attention weights exhibit the induction head - structure (Layer 1 uses a previous-token pattern, modeled as self-attention; Layer 2 - uses a nonnegativity placeholder for token-matching). - This is captured by an `InductionHeadPattern`. - -2. **Faithful Approximation**: The virtual head (composition of value terms, - aka "attention rollout") is ε-certified—it approximates the true composed Jacobian - within Frobenius norm ε. - -3. **Functional Effectiveness**: On the **specific input**, the virtual head's output, - when projected onto the target logit difference direction, produces at least δ increase - in score. 
This binds the abstract mechanism to the concrete model behavior. --/ -structure TrueInductionHead where - /-- The model input (residual stream at sequence positions) -/ - input : (n × d) → ℝ - /-- Certified induction head pattern (has layer1 and layer2 with attention properties) -/ - pattern : InductionHeadPattern (n := n) (d := d) - /-- The composed true Jacobian from input to output -/ - composed_jacobian : SignedMixer (n × d) (n × d) - /-- Target direction in residual stream space (how positions/dimensions contribute to target) -/ - target_logit_diff : (n × d) → ℝ - /-- Faithfulness bound: how close virtual head is to composed Jacobian -/ - epsilon : ℝ - /-- Functional effectiveness bound: minimum logit increase from this mechanism -/ - delta : ℝ - /-- Faithfulness: Virtual head approximates composed Jacobian within ε -/ - faithful : isCertifiedVirtualHead pattern.layer2 pattern.layer1 composed_jacobian epsilon - /-- Effectiveness: Virtual head applied to this input produces ≥ delta on target direction -/ - effective : inner_product (VirtualHead pattern.layer2 pattern.layer1 |>.apply input) - target_logit_diff ≥ delta - /-- Bounds are valid -/ - epsilon_nonneg : 0 ≤ epsilon - /-- Delta is nonnegative (can't guarantee negative output) -/ - delta_nonneg : 0 ≤ delta - -/-! ## Sound pattern witnesses -/ - -/-- Minimal token-match pattern witness used by the sound certification path. -/ -structure TokenMatchPattern where - /-- Sequence length for the certificate. -/ - seqLen : Nat - /-- Target offset (e.g. `-1` for previous token). -/ - targetOffset : Int - /-- Lower bound on the number of matching-token keys. -/ - targetCountLowerBound : Nat - /-- Effort level for the `expLB` portfolio used in margin-to-weight bounds. -/ - softmaxExpEffort : Nat - /-- Lower bound on total attention weight assigned to matching tokens. -/ - targetWeightLowerBound : Rat - /-- Lower bound on logit margin between matching vs non-matching keys. 
-/ - marginLowerBound : Rat - deriving Repr - -namespace TokenMatchPattern - -/-- Soundness invariant for token-match pattern witnesses. -/ -def Valid (p : TokenMatchPattern) : Prop := - p.seqLen > 0 ∧ - p.targetCountLowerBound ≤ p.seqLen ∧ - p.targetWeightLowerBound = - Sound.softmaxTargetWeightLowerBound p.seqLen p.targetCountLowerBound - p.marginLowerBound p.softmaxExpEffort - -instance (p : TokenMatchPattern) : Decidable (Valid p) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (p : TokenMatchPattern) : Bool := - decide (Valid p) - -theorem check_iff (p : TokenMatchPattern) : p.check = true ↔ p.Valid := by - simp [check, Valid] - -/-- If the margin and target count are positive, the weight lower bound matches -the portfolio bound derived from `expLB`. -/ -theorem weight_lower_bound_of_margin_pos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) : - p.targetWeightLowerBound = - let nRat : Rat := (p.seqLen : Nat) - let tRat : Rat := (p.targetCountLowerBound : Nat) - let base := tRat / nRat - let e := Sound.expLB p.marginLowerBound p.softmaxExpEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - max base cand := by - rcases h with ⟨hseq, _hcount, hweight⟩ - have hseq0 : p.seqLen ≠ 0 := Nat.ne_of_gt hseq - have hcount0 : p.targetCountLowerBound ≠ 0 := Nat.ne_of_gt hcount - simpa [Sound.softmaxTargetWeightLowerBound_def, hseq0, hcount0, hm] using hweight - -/-- If the margin is nonpositive, the weight lower bound is zero. 
-/ -theorem weight_lower_bound_of_margin_nonpos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound ≤ 0) : - p.targetWeightLowerBound = 0 := by - rcases h with ⟨_hseq, _hcount, hweight⟩ - have hm' : ¬ p.marginLowerBound > 0 := by - exact not_lt.mpr hm - by_cases hzero : p.seqLen = 0 || p.targetCountLowerBound = 0 - · simpa [Sound.softmaxTargetWeightLowerBound_def, hzero] using hweight - · simpa [Sound.softmaxTargetWeightLowerBound_def, hzero, hm'] using hweight - -/-- Positive margin and a positive target count imply positive attention mass. -/ -theorem weight_lower_bound_pos_of_margin_pos - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) : - 0 < p.targetWeightLowerBound := by - let nRat : Rat := (p.seqLen : Nat) - let tRat : Rat := (p.targetCountLowerBound : Nat) - let base : Rat := tRat / nRat - let e := Sound.expLB p.marginLowerBound p.softmaxExpEffort - let cand : Rat := (tRat * e) / (tRat * e + (nRat - tRat)) - have hweight := weight_lower_bound_of_margin_pos p h hm hcount - rcases h with ⟨hseq, _hcount, _hweight⟩ - have hseq' : (0 : Rat) < nRat := by - have hseq'' : (0 : Rat) < (p.seqLen : Rat) := by - exact_mod_cast hseq - simpa [nRat] using hseq'' - have hcount' : (0 : Rat) < tRat := by - have hcount'' : (0 : Rat) < (p.targetCountLowerBound : Rat) := by - exact_mod_cast hcount - simpa [tRat] using hcount'' - have hbase : (0 : Rat) < base := by - exact div_pos hcount' hseq' - have hmax : base ≤ max base cand := by - exact le_max_left _ _ - have hpos : (0 : Rat) < max base cand := by - exact lt_of_lt_of_le hbase hmax - simpa [nRat, tRat, base, e, cand, hweight] using hpos - -/-- Either the margin is nonpositive (so the bound is zero), -or the bound is positive when the match count is positive. 
-/ -theorem weight_lower_bound_dichotomy - (p : TokenMatchPattern) (h : p.Valid) (hcount : 0 < p.targetCountLowerBound) : - p.marginLowerBound ≤ 0 ∨ 0 < p.targetWeightLowerBound := by - by_cases hm : p.marginLowerBound > 0 - · right - exact weight_lower_bound_pos_of_margin_pos p h hm hcount - · left - exact not_lt.mp hm - -end TokenMatchPattern - -/-! ## Sound induction witnesses -/ - -/-- A minimal sound witness for an induction-style attention pattern. -/ -structure InductionPatternWitness where - /-- Token-match pattern data (sound certificate output). -/ - tokenMatch : TokenMatchPattern - /-- The pattern targets the previous-token offset. -/ - prevOffset : tokenMatch.targetOffset = -1 - /-- Certified nontrivial attention mass on matching tokens. -/ - positiveMass : 0 < tokenMatch.targetWeightLowerBound - deriving Repr - -namespace TokenMatchPattern - -/-- Build an induction-style witness from a valid token-match pattern plus explicit assumptions. -/ -def toInductionPatternWitness - (p : TokenMatchPattern) (h : p.Valid) (hm : p.marginLowerBound > 0) - (hcount : 0 < p.targetCountLowerBound) (hoff : p.targetOffset = -1) : - InductionPatternWitness := - { - tokenMatch := p - prevOffset := hoff - positiveMass := weight_lower_bound_pos_of_margin_pos p h hm hcount - } - -end TokenMatchPattern - -/-! ## Verification Theorems -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Main Theorem**: True Induction Head Bounds - -Any true induction head has nonnegative epsilon and delta bounds by definition. -/ -theorem true_induction_head_bounds_nonneg {h : TrueInductionHead (n := n) (d := d)} : - (h.epsilon ≥ 0) ∧ (h.delta ≥ 0) := - ⟨h.epsilon_nonneg, h.delta_nonneg⟩ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Key Property**: Virtual head achieves the stated delta bound. - -By definition of `TrueInductionHead`, the virtual head applied to the input -achieves at least delta on the target direction. 
--/ -lemma virtual_head_achieves_delta {h : TrueInductionHead (n := n) (d := d)} : - inner_product ((VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input) - h.target_logit_diff ≥ h.delta := - h.effective - -/-! ## Properties of True Induction Heads -/ - -/-- The virtual head output on the certified input. -/ -noncomputable def virtual_head_output {h : TrueInductionHead (n := n) (d := d)} : - (n × d) → ℝ := - (VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input - -/-- The virtual head's score on the target direction. -/ -noncomputable def virtual_head_score {h : TrueInductionHead (n := n) (d := d)} : ℝ := - inner_product (virtual_head_output (h := h)) h.target_logit_diff - -/-- The approximation error bound. -/ -def approx_error {h : TrueInductionHead (n := n) (d := d)} : ℝ := - h.epsilon - -/-- The functional guarantee on the virtual head. -/ -def min_logit_shift {h : TrueInductionHead (n := n) (d := d)} : ℝ := - h.delta - -omit [DecidableEq n] [DecidableEq d] in -/-- **Composition of mechanisms**: Composed error bound. - -If two true induction heads have errors ε₁ and ε₂ respectively, their -composition has bounded error from the rule: ε_total ≤ ε₁ + ε₂ + ε₁·ε₂. --/ -theorem true_induction_head_composition - (h₁ h₂ : TrueInductionHead (n := n) (d := d)) - (ε : ℝ) - (hε_bound : ε ≤ h₁.epsilon + h₂.epsilon + h₁.epsilon * h₂.epsilon) : - ε ≤ h₁.epsilon + h₂.epsilon + h₁.epsilon * h₂.epsilon := hε_bound - -omit [DecidableEq n] [DecidableEq d] in -/-- **Interpretability Guarantee**: True induction heads are real mechanisms. 
-/ -theorem true_induction_head_is_genuine - (h : TrueInductionHead (n := n) (d := d)) : - (∃ L₁ L₂, h.pattern.layer1 = L₁ ∧ h.pattern.layer2 = L₂) ∧ - (isCertifiedVirtualHead h.pattern.layer2 h.pattern.layer1 h.composed_jacobian h.epsilon) ∧ - (inner_product ((VirtualHead h.pattern.layer2 h.pattern.layer1).apply h.input) - h.target_logit_diff ≥ h.delta) := by - exact ⟨⟨h.pattern.layer1, h.pattern.layer2, rfl, rfl⟩, h.faithful, h.effective⟩ - -/-! ## Helper inequality: Frobenius norm bounds application -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- For any signed mixer `M` and vector `v`, the output L² norm is bounded by the Frobenius -norm of `M` times the input L² norm. -/ -lemma norm_apply_le (M : SignedMixer (n × d) (n × d)) (v : (n × d) → ℝ) : - l2_norm (M.apply v) ≤ frobeniusNorm (n := n) (d := d) M * l2_norm v := by - classical - set A : ℝ := ∑ i : n × d, (v i) ^ 2 - set C : ℝ := ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 - have hA : 0 ≤ A := by - simpa [A] using (Finset.sum_nonneg (fun i _hi => sq_nonneg (v i))) - have hC : 0 ≤ C := by - -- two nested sums of squares - have : 0 ≤ ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 := by - refine Finset.sum_nonneg ?_ - intro i _hi - refine Finset.sum_nonneg ?_ - intro j _hj - exact sq_nonneg (M.w i j) - simpa [C] using this - have hpoint : - ∀ j : n × d, (M.apply v j) ^ 2 ≤ A * (∑ i : n × d, (M.w i j) ^ 2) := by - intro j - -- Cauchy–Schwarz (squared form) on the dot product defining `(M.apply v) j`. 
- simpa [SignedMixer.apply_def, A] using - (Finset.sum_mul_sq_le_sq_mul_sq (s := (Finset.univ : Finset (n × d))) - (f := v) (g := fun i : n × d => M.w i j)) - have hsum : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hpoint j - have hsum' : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) := by - have hfac : - A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := by - simpa using (Finset.mul_sum (s := (Finset.univ : Finset (n × d))) (a := A) - (f := fun j : n × d => ∑ i : n × d, (M.w i j) ^ 2)) - calc - (∑ j : n × d, (M.apply v j) ^ 2) - ≤ ∑ j : n × d, A * (∑ i : n × d, (M.w i j) ^ 2) := hsum - _ = A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) := by - simp [hfac] - have hsum'' : - (∑ j : n × d, (M.apply v j) ^ 2) ≤ A * C := by - have hswap : - (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2 := by - simpa using (Finset.sum_comm : - (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) = - ∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2) - calc - (∑ j : n × d, (M.apply v j) ^ 2) - ≤ A * (∑ j : n × d, ∑ i : n × d, (M.w i j) ^ 2) := hsum' - _ = A * C := by - simp [C, hswap] - -- take square roots and unfold definitions - calc - l2_norm (M.apply v) - = Real.sqrt (∑ j : n × d, (M.apply v j) ^ 2) := rfl - _ ≤ Real.sqrt (A * C) := by - exact Real.sqrt_le_sqrt hsum'' - _ = Real.sqrt A * Real.sqrt C := by - simpa using (Real.sqrt_mul hA C) - _ = frobeniusNorm (n := n) (d := d) M * l2_norm v := by - simp [frobeniusNorm, l2_norm, C, A, mul_comm] - -omit [DecidableEq n] [DecidableEq d] in -/-- Finite-dimensional Cauchy–Schwarz for the `inner_product`/`l2_norm` defined in this file. 
-/ -lemma abs_inner_product_le_l2 (u v : (n × d) → ℝ) : - |inner_product u v| ≤ l2_norm u * l2_norm v := by - classical - have hcs : - (inner_product u v) ^ 2 ≤ (∑ i : n × d, (u i) ^ 2) * (∑ i : n × d, (v i) ^ 2) := by - simpa [inner_product] using - (Finset.sum_mul_sq_le_sq_mul_sq (s := (Finset.univ : Finset (n × d))) (f := u) (g := v)) - have hu : 0 ≤ ∑ i : n × d, (u i) ^ 2 := by - simpa using (Finset.sum_nonneg (fun i _hi => sq_nonneg (u i))) - have hv : 0 ≤ ∑ i : n × d, (v i) ^ 2 := by - simpa using (Finset.sum_nonneg (fun i _hi => sq_nonneg (v i))) - calc - |inner_product u v| - = Real.sqrt ((inner_product u v) ^ 2) := by - simpa using (Real.sqrt_sq_eq_abs (inner_product u v)).symm - _ ≤ Real.sqrt ((∑ i : n × d, (u i) ^ 2) * (∑ i : n × d, (v i) ^ 2)) := by - exact Real.sqrt_le_sqrt hcs - _ = Real.sqrt (∑ i : n × d, (u i) ^ 2) * Real.sqrt (∑ i : n × d, (v i) ^ 2) := by - simpa using (Real.sqrt_mul hu (∑ i : n × d, (v i) ^ 2)) - _ = l2_norm u * l2_norm v := by - rfl - -omit [DecidableEq n] [DecidableEq d] in -/-- **Main verification theorem**: a `TrueInductionHead` lower-bounds the real model score -on the target direction by `δ` minus the certified approximation error. 
-/ -theorem true_induction_head_predicts_logits - (h : TrueInductionHead (n := n) (d := d)) : - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff ≥ - h.delta - (h.epsilon * l2_norm h.input * l2_norm h.target_logit_diff) := by - classical - let V : SignedMixer (n × d) (n × d) := VirtualHead h.pattern.layer2 h.pattern.layer1 - let E : SignedMixer (n × d) (n × d) := h.composed_jacobian - V - have hE : frobeniusNorm (n := n) (d := d) E ≤ h.epsilon := by - simpa [E, V, isCertifiedVirtualHead] using h.faithful - have hV : h.delta ≤ inner_product (V.apply h.input) h.target_logit_diff := by - simpa [V] using h.effective - have happly_add : (V + E).apply h.input = V.apply h.input + E.apply h.input := by - ext j - simp [SignedMixer.apply_def, Finset.sum_add_distrib, mul_add] - have hJ_eq : h.composed_jacobian = V + E := by - ext i j - simp [E, V] - have hdecomp : - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff = - inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - have happly : - h.composed_jacobian.apply h.input = V.apply h.input + E.apply h.input := by - simpa [hJ_eq] using happly_add - have hinner_add (a b u : (n × d) → ℝ) : - inner_product (a + b) u = inner_product a u + inner_product b u := by - simp [inner_product, Finset.sum_add_distrib, add_mul] - calc - inner_product (h.composed_jacobian.apply h.input) h.target_logit_diff - = inner_product (V.apply h.input + E.apply h.input) h.target_logit_diff := by - simp [happly] - _ = inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - simpa using hinner_add (a := V.apply h.input) (b := E.apply h.input) - (u := h.target_logit_diff) - set bound : ℝ := h.epsilon * l2_norm h.input * l2_norm h.target_logit_diff - have hbound_nonneg : 0 ≤ bound := by - have hx : 0 ≤ l2_norm h.input := by simp [l2_norm] - have hu : 0 ≤ l2_norm h.target_logit_diff := by simp 
[l2_norm] - have : 0 ≤ h.epsilon * l2_norm h.input := mul_nonneg h.epsilon_nonneg hx - simpa [bound, mul_assoc] using mul_nonneg this hu - have herr_abs : - |inner_product (E.apply h.input) h.target_logit_diff| ≤ bound := by - have habs : - |inner_product (E.apply h.input) h.target_logit_diff| ≤ - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff := by - simpa using (abs_inner_product_le_l2 (n := n) (d := d) (u := E.apply h.input) - (v := h.target_logit_diff)) - have hnormEx : - l2_norm (E.apply h.input) ≤ frobeniusNorm (n := n) (d := d) E * l2_norm h.input := by - simpa using (norm_apply_le (n := n) (d := d) E h.input) - have hu : 0 ≤ l2_norm h.target_logit_diff := by simp [l2_norm] - have hx : 0 ≤ l2_norm h.input := by simp [l2_norm] - have hstep1 : - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff ≤ - (frobeniusNorm (n := n) (d := d) E * l2_norm h.input) * l2_norm h.target_logit_diff := - mul_le_mul_of_nonneg_right hnormEx hu - have hstep2 : - (frobeniusNorm (n := n) (d := d) E * l2_norm h.input) * l2_norm h.target_logit_diff ≤ - (h.epsilon * l2_norm h.input) * l2_norm h.target_logit_diff := by - have : frobeniusNorm (n := n) (d := d) E * l2_norm h.input ≤ h.epsilon * l2_norm h.input := - mul_le_mul_of_nonneg_right hE hx - exact mul_le_mul_of_nonneg_right this hu - have hchain := le_trans hstep1 hstep2 - have hchain' : - l2_norm (E.apply h.input) * l2_norm h.target_logit_diff ≤ bound := by - simpa [bound, mul_assoc, mul_left_comm, mul_comm] using hchain - exact le_trans habs hchain' - have herr_lower : -bound ≤ inner_product (E.apply h.input) h.target_logit_diff := by - exact (abs_le.mp herr_abs).1 - -- Combine: = + ≥ δ + (-bound) = δ - bound - have hsum_le : - h.delta + (-bound) ≤ - inner_product (V.apply h.input) h.target_logit_diff + - inner_product (E.apply h.input) h.target_logit_diff := by - exact add_le_add hV herr_lower - -- rewrite the goal via the decomposition - have : - h.delta - bound ≤ - inner_product (h.composed_jacobian.apply 
h.input) h.target_logit_diff := by - simpa [sub_eq_add_neg, hdecomp] using hsum_le - exact this - -end Nfp diff --git a/Nfp/Influence.lean b/Nfp/Influence.lean deleted file mode 100644 index 9e4eba2..0000000 --- a/Nfp/Influence.lean +++ /dev/null @@ -1,351 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Fintype.Prod -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.Field.Basic -import Mathlib.Logic.Equiv.Fin.Basic -import Mathlib.Tactic.DeriveFintype -import Nfp.Prob -import Nfp.Mixer - -/- -Core influence specifications and helpers that sit below mixers. Influence -specs carry raw, non-normalized edge capacities; `Mixer.ofInfluenceSpec` turns -them into row-stochastic mixers with a small fallback for empty rows. This -file also introduces lightweight sign and hierarchy indices (`Chan`, `HSite`, -`BigSite`) together with convenience projections for tracers living on the -combined site. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Influence specs -/ - -/-- A raw influence description: adjacency plus nonnegative edge capacities. -/ -structure InfluenceSpec (Site : Type*) where - adj : Site → Site → Prop - κ : ∀ {s t}, adj s t → NNReal - -/-- A family of influence specs indexed by different “views”. -/ -structure InfluenceSpecFamily (Site : Type*) where - View : Type* - spec : View → InfluenceSpec Site - -namespace InfluenceSpecFamily - -variable {Site : Type*} - -/-- Convexly combine multiple influence specs using nonnegative weights `α`. -Specs with zero weight do not contribute to the merged adjacency. 
-/ -noncomputable def combine (F : InfluenceSpecFamily Site) [Fintype F.View] - (α : F.View → NNReal) (_hα : ∑ v, α v = 1) : InfluenceSpec Site := - { - adj := fun s t => ∃ v, α v ≠ 0 ∧ (F.spec v).adj s t - κ := by - intro s t _ - classical - exact ∑ v, α v * (if h : (F.spec v).adj s t then (F.spec v).κ h else 0) - } - -end InfluenceSpecFamily - -namespace InfluenceSpec - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -/-- Sum of outgoing raw capacities from a site. -/ -noncomputable def rowTotal (I : InfluenceSpec Site) (s : Site) : NNReal := by - classical - exact ∑ t, (if h : I.adj s t then I.κ h else 0) - -/-! ### Row scaling helpers -/ - -/-- Scale all raw capacities in a single row of an influence spec. -/ -noncomputable def scaleRow (I : InfluenceSpec Site) (s0 : Site) (c : NNReal) : - InfluenceSpec Site := - { - adj := I.adj - κ := by - intro s t h - by_cases hs : s = s0 - · subst hs - exact c * I.κ h - · exact I.κ h - } - -lemma rowTotal_scaleRow_self (I : InfluenceSpec Site) (s0 : Site) (c : NNReal) : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - c * InfluenceSpec.rowTotal (Site := Site) I s0 := by - classical - have hrewrite : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - ∑ t : Site, c * (if h : I.adj s0 t then I.κ h else 0) := by - simp [InfluenceSpec.rowTotal, scaleRow] - calc - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s0 = - ∑ t : Site, c * (if h : I.adj s0 t then I.κ h else 0) := hrewrite - _ = c * InfluenceSpec.rowTotal (Site := Site) I s0 := by - -- pull the scalar outside the sum - simpa [InfluenceSpec.rowTotal] using - (Finset.mul_sum (s := (Finset.univ : Finset Site)) - (f := fun t : Site => (if h : I.adj s0 t then I.κ h else 0)) - (a := c)).symm - -lemma rowTotal_scaleRow_other (I : InfluenceSpec Site) {s s0 : Site} (c : NNReal) - (hs : s ≠ s0) : - InfluenceSpec.rowTotal (Site := Site) (scaleRow (Site := Site) I s0 c) s = - 
InfluenceSpec.rowTotal (Site := Site) I s := by - classical - simp [InfluenceSpec.rowTotal, scaleRow, hs] - -end InfluenceSpec - -namespace InfluenceSpecFamily - -variable {Site : Type*} - -lemma combine_adj_iff (F : InfluenceSpecFamily Site) [Fintype F.View] - (α : F.View → NNReal) (hα : ∑ v, α v = 1) (s t : Site) : - (combine (F:=F) α hα).adj s t ↔ ∃ v, α v ≠ 0 ∧ (F.spec v).adj s t := - Iff.rfl - -end InfluenceSpecFamily - -/-! ## Mixers derived from influence specs -/ - -namespace Mixer - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -/-- Canonical mixer obtained by normalizing each row of an influence spec. -If a row has zero total capacity, we fall back to an identity row on that site, -interpreting the site as an absorbing state. -/ -noncomputable def ofInfluenceSpec (I : InfluenceSpec Site) : Mixer Site Site := by - classical - let Z : Site → NNReal := InfluenceSpec.rowTotal (Site := Site) I - refine - { - w := fun s t => - if hZ : Z s = 0 then - if hst : s = t then 1 else 0 - else - if hAdj : I.adj s t then I.κ hAdj / Z s else 0 - row_sum_one := by - intro s - by_cases hZ : Z s = 0 - · have hdiag : - (∑ t : Site, (if s = t then (1 : NNReal) else 0)) = 1 := by - classical - simp - simp [Z, hZ, hdiag] - · have hZne : Z s ≠ 0 := hZ - have hnormalize : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) := by - classical - have hrewrite : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0) * (1 / Z s)) := by - refine Finset.sum_congr rfl ?_ - intro t _ht - by_cases hAdj : I.adj s t <;> simp [hAdj, div_eq_mul_inv, mul_comm] - have hfactor : - (∑ t : Site, (if h : I.adj s t then I.κ h else 0) * (1 / Z s)) = - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) := by - simpa using - (Finset.sum_mul (s := (Finset.univ : Finset Site)) - (f := fun t : Site => if h : I.adj s t then I.κ h else 0) - (a := (1 / Z 
s))).symm - exact hrewrite.trans hfactor - have hrow : (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) = Z s := rfl - have hnormalized : - (∑ t : Site, (if h : I.adj s t then I.κ h / Z s else 0)) = 1 := by - have hdiv : - (∑ t : Site, (if h : I.adj s t then I.κ h else 0)) * (1 / Z s) = - (1 : NNReal) := by - simp [hrow, div_eq_mul_inv, hZne] - exact hnormalize.trans hdiv - simpa [Z, hZ] using hnormalized - } - -lemma ofInfluenceSpec_zero_of_not_adj (I : InfluenceSpec Site) - {s t : Site} (hZ : InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) - (hAdj : ¬ I.adj s t) : - (ofInfluenceSpec (Site := Site) I).w s t = 0 := by - classical - simp [ofInfluenceSpec, hZ, hAdj] - -lemma ofInfluenceSpec_adj_weight (I : InfluenceSpec Site) - {s t : Site} (hZ : InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) - (hAdj : I.adj s t) : - (ofInfluenceSpec (Site := Site) I).w s t = - I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s := by - classical - simp [ofInfluenceSpec, hZ, hAdj] - -lemma ofInfluenceSpec_supported (I : InfluenceSpec Site) - (hZ : ∀ s, InfluenceSpec.rowTotal (Site := Site) I s ≠ 0) : - Mixer.supported (S := Site) (T := Site) (ofInfluenceSpec (Site := Site) I) I.adj := by - intro s t hAdj - exact ofInfluenceSpec_zero_of_not_adj (Site := Site) (I:=I) (hZ s) hAdj - -lemma ofInfluenceSpec_row_scaling (I : InfluenceSpec Site) (s0 : Site) {c : NNReal} - (hc : c ≠ 0) (t : Site) : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - (ofInfluenceSpec (Site := Site) I).w s0 t := by - classical - by_cases hZ : InfluenceSpec.rowTotal (Site := Site) I s0 = 0 - · have hZ' : - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 = 0 := by - simp [InfluenceSpec.rowTotal_scaleRow_self (I:=I) (s0:=s0) (c:=c), hZ] - simp [ofInfluenceSpec, hZ, hZ'] - · have hrow : - InfluenceSpec.rowTotal (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c) - s0 = c * InfluenceSpec.rowTotal (Site := Site) I s0 := - 
InfluenceSpec.rowTotal_scaleRow_self (I:=I) (s0:=s0) (c:=c) - have hZ' : - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 ≠ 0 := by - have hne : c * InfluenceSpec.rowTotal (Site := Site) I s0 ≠ 0 := - mul_ne_zero hc hZ - simpa [hrow] using hne - by_cases hAdj : I.adj s0 t - · have hκ : - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj = c * I.κ hAdj := by - simp [InfluenceSpec.scaleRow] - have hleft : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj / - InfluenceSpec.rowTotal (Site := Site) - (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 := - ofInfluenceSpec_adj_weight (I := InfluenceSpec.scaleRow (Site := Site) I s0 c) hZ' hAdj - have hright : - (ofInfluenceSpec (Site := Site) I).w s0 t = - I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s0 := - ofInfluenceSpec_adj_weight (I := I) hZ hAdj - have hcancel : - (InfluenceSpec.scaleRow (Site := Site) I s0 c).κ hAdj / - InfluenceSpec.rowTotal - (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c) s0 = - c * I.κ hAdj / (c * InfluenceSpec.rowTotal (Site := Site) I s0) := by - simp [hκ, hrow] - calc - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - c * I.κ hAdj / (c * InfluenceSpec.rowTotal (Site := Site) I s0) := by - simp [hcancel] at hleft - simpa using hleft - _ = I.κ hAdj / InfluenceSpec.rowTotal (Site := Site) I s0 := by - have hc' : (c : NNReal) ≠ 0 := hc - simpa [mul_comm, mul_left_comm, mul_assoc] using - (mul_div_mul_left (a := I.κ hAdj) - (b := InfluenceSpec.rowTotal (Site := Site) I s0) - (c := c) (hc := hc')) - _ = (ofInfluenceSpec (Site := Site) I).w s0 t := hright.symm - · have hleft : - (ofInfluenceSpec (Site := Site) (InfluenceSpec.scaleRow (Site := Site) I s0 c)).w s0 t = - 0 := - ofInfluenceSpec_zero_of_not_adj - (I := InfluenceSpec.scaleRow (Site := Site) I s0 c) hZ' hAdj - have hright : - 
(ofInfluenceSpec (Site := Site) I).w s0 t = 0 := - ofInfluenceSpec_zero_of_not_adj (I := I) hZ hAdj - simp [hleft, hright] - -end Mixer - -/-! ## Sign channels and hierarchical sites -/ - -/-- Explicit sign channels; sign is tracked structurally, not by negative mass. -/ -inductive Chan - | pos - | neg -deriving DecidableEq, Fintype - -/-- Hierarchical site: either a base node or a group-level aggregate. -/ -inductive HSite (Base Group : Type*) : Type _ - | base : Base → HSite Base Group - | group : Group → HSite Base Group -deriving DecidableEq - -namespace HSite - -variable {Base Group : Type*} - -/-- An equivalence to a sum type, used to reuse existing finite instances. -/ -def equivSum : HSite Base Group ≃ Sum Base Group := - { - toFun := fun - | base b => Sum.inl b - | group g => Sum.inr g - invFun := fun - | Sum.inl b => base b - | Sum.inr g => group g - left_inv := by - intro x - cases x <;> rfl - right_inv := by - intro x - cases x <;> rfl - } - -instance [Fintype Base] [Fintype Group] : Fintype (HSite Base Group) := - Fintype.ofEquiv (Sum Base Group) (equivSum (Base:=Base) (Group:=Group)).symm - -end HSite - -/-- Combined site carrying an objective, a base node with sign, or a group. -/ -abbrev BigSite (Obj Node Group : Type*) := - Obj × HSite (Node × Chan) Group - -/-! ## Tracer view helpers on `BigSite` -/ - -namespace TracerViews - -section - -variable {Obj Node Group : Type*} [Fintype Obj] [Fintype Node] [Fintype Group] -variable (p : ProbVec (BigSite Obj Node Group)) - -/-- Positive-channel tracer mass for a specific objective/node pair. -/ -noncomputable def posTracer (o : Obj) (n : Node) : NNReal := - p.mass (o, HSite.base (n, Chan.pos)) - -/-- Negative-channel tracer mass for a specific objective/node pair. -/ -noncomputable def negTracer (o : Obj) (n : Node) : NNReal := - p.mass (o, HSite.base (n, Chan.neg)) - -/-- Net tracer mass (`pos - neg`) for a specific objective/node pair. 
-/ -noncomputable def netTracer (o : Obj) (n : Node) : ℝ := - (posTracer (p:=p) o n : ℝ) - (negTracer (p:=p) o n : ℝ) - -/-- Group-level tracer mass for an objective and group. -/ -noncomputable def groupTracer (o : Obj) (g : Group) : NNReal := - p.mass (o, HSite.group g) - -@[simp] lemma posTracer_eval (o : Obj) (n : Node) : - posTracer (p:=p) o n = p.mass (o, HSite.base (n, Chan.pos)) := by - rfl - -@[simp] lemma negTracer_eval (o : Obj) (n : Node) : - negTracer (p:=p) o n = p.mass (o, HSite.base (n, Chan.neg)) := by - rfl - -@[simp] lemma groupTracer_eval (o : Obj) (g : Group) : - groupTracer (p:=p) o g = p.mass (o, HSite.group g) := by - rfl - -end - -end TracerViews - -open TracerViews (posTracer negTracer netTracer groupTracer) -export TracerViews (posTracer negTracer netTracer groupTracer) - -end Nfp diff --git a/Nfp/Layers.lean b/Nfp/Layers.lean deleted file mode 100644 index dceaee6..0000000 --- a/Nfp/Layers.lean +++ /dev/null @@ -1,1046 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Nfp.Prob -import Nfp.Mixer -import Nfp.Uniqueness - -/-! -# Neural Network Layer Mixers - -This module formalizes common neural network layer operations as mixers, -establishing the connection between abstract row-stochastic operators and -concrete NN architectures. This enables applying the tracer uniqueness and -attribution theorems to real neural network interpretation. 
- -## Main definitions - -* `Mixer.identity` – identity/skip connection -* `Mixer.attention` – attention mechanism as a mixer -* `Mixer.selfAttention` – self-attention variant -* `Mixer.residual` – residual connection combining identity with transform - -## Key theorems - -* `Mixer.identity_comp` – identity is a left/right unit for composition -* `Mixer.comp_assoc` – composition is associative -* `effectiveAttention_normalized` – attention rollout forms valid distributions -* `Mixer.push_comp` – pushing through composition equals sequential pushing - -## Neural network interpretation - -The key insight is that many NN operations can be viewed as row-stochastic -operators when considering how "importance" or "relevance" flows: - -- Attention: importance flows according to attention weights -- Skip connections: importance passes through unchanged -- Residual: weighted combination of skip and transform - -## References - -* Abnar & Zuidema: "Quantifying Attention Flow in Transformers" (2020) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Identity mixer -/ - -section Identity - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- The identity mixer: each source routes entirely to itself. -/ -noncomputable def Mixer.identity : Mixer S S where - w := fun i j => if i = j then 1 else 0 - row_sum_one := by - classical - intro i - simp only [Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte] - -@[simp] lemma Mixer.identity_diag (i : S) : Mixer.identity.w i i = 1 := by - simp [Mixer.identity] - -@[simp] lemma Mixer.identity_off_diag {i j : S} (h : i ≠ j) : - Mixer.identity.w i j = 0 := by - simp [Mixer.identity, h] - -/-- Identity is a left unit for mixer composition. 
-/ -@[simp] theorem Mixer.identity_comp (M : Mixer S S) : - Mixer.identity.comp M = M := by - ext i k - simp only [Mixer.comp, Mixer.identity] - classical - simp only [ite_mul, one_mul, zero_mul, Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte] - -/-- Identity is a right unit for mixer composition. -/ -@[simp] theorem Mixer.comp_identity (M : Mixer S S) : - M.comp Mixer.identity = M := by - ext i k - simp only [Mixer.comp, Mixer.identity] - classical - simp only [mul_ite, mul_one, mul_zero, Finset.sum_ite_eq', Finset.mem_univ, ↓reduceIte] - -end Identity - -/-! ## Mixer composition is associative -/ - -section Associativity - -variable {S T U V : Type*} - [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- Mixer composition is associative. -/ -theorem Mixer.comp_assoc (M : Mixer S T) (N : Mixer T U) (P : Mixer U V) : - (M.comp N).comp P = M.comp (N.comp P) := by - ext i l - simp only [Mixer.comp, Finset.sum_mul, Finset.mul_sum] - rw [Finset.sum_comm] - simp_rw [mul_assoc] - -end Associativity - -/-! ## Attention as a mixer -/ - -section Attention - -variable {Query Key : Type*} [Fintype Query] [Fintype Key] - -/-- Attention weights derived from query-key scores. -Given attention scores `α : Query → Key → NNReal` that are row-normalized -(each query's weights over keys sum to 1), this produces a mixer. -/ -noncomputable def Mixer.attention - (α : Query → Key → NNReal) - (hα : ∀ q, (∑ k, α q k) = 1) : Mixer Query Key where - w := α - row_sum_one := hα - -/-- Self-attention: queries and keys are the same set of positions. -/ -noncomputable def Mixer.selfAttention {Pos : Type*} [Fintype Pos] - (α : Pos → Pos → NNReal) - (hα : ∀ p, (∑ p', α p p') = 1) : Mixer Pos Pos := - Mixer.attention α hα - -/-- Attention is supported on positions with nonzero attention weight. 
-/ -lemma Mixer.attention_supported {Query Key : Type*} [Fintype Query] [Fintype Key] - (α : Query → Key → NNReal) - (hα : ∀ q, (∑ k, α q k) = 1) : - Mixer.supported (Mixer.attention α hα) (fun q k => α q k ≠ 0) := by - intro q k hne - simp only [Mixer.attention] - by_contra h - exact hne h - -end Attention - -/-! ## Skip/Residual connections -/ - -section Residual - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Residual mixer: mixes identity with another mixer using coefficient `c ∈ [0,1]`. -This models skip connections: `output = c * identity + (1-c) * transform`. -/ -noncomputable def Mixer.residual (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) : Mixer S S where - w := fun i j => c * (if i = j then 1 else 0) + (1 - c) * M.w i j - row_sum_one := by - classical - intro i - simp only [Finset.sum_add_distrib, ← Finset.mul_sum] - simp only [M.row_sum_one, Finset.sum_ite_eq, Finset.mem_univ, ↓reduceIte, mul_one] - rw [add_comm] - exact tsub_add_cancel_of_le hc - -/-- A pure skip connection is the identity mixer. -/ -lemma Mixer.residual_one (M : Mixer S S) : - Mixer.residual M 1 le_rfl = Mixer.identity := by - ext i j - simp [Mixer.residual, Mixer.identity] - -/-- No skip connection passes through the transform entirely. -/ -lemma Mixer.residual_zero (M : Mixer S S) : - Mixer.residual M 0 (zero_le _) = M := by - ext i j - simp [Mixer.residual] - -/-- Residual weight decomposition: the weight splits into identity and transform parts. -/ -lemma Mixer.residual_w (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i j : S) : - (Mixer.residual M c hc).w i j = - c * Mixer.identity.w i j + (1 - c) * M.w i j := by - simp only [Mixer.residual, Mixer.identity] - -end Residual - -/-! ## Attention flow composition theorems -/ - -section AttentionFlow - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- Attention flow: composition of attention matrices across layers. -This captures how attention "flows" across layers in list order (row-vector convention). 
- -This is the formal version of "attention rollout" from Abnar & Zuidema (2020). -/ -noncomputable def attentionFlow (layers : List (Mixer Pos Pos)) : Mixer Pos Pos := - layers.foldl Mixer.comp Mixer.identity - -/-- Single layer attention flow is identity composed with the layer. -/ -lemma attentionFlow_singleton (M : Mixer Pos Pos) : - attentionFlow [M] = Mixer.identity.comp M := rfl - -/-- Empty attention flow is identity. -/ -@[simp] lemma attentionFlow_nil : - attentionFlow (Pos := Pos) [] = Mixer.identity := rfl - -/-- Attention flow of a single layer simplifies to just the layer. -/ -@[simp] lemma attentionFlow_singleton' (M : Mixer Pos Pos) : - attentionFlow [M] = M := by - simp [attentionFlow_singleton] - -/-- Two-layer attention flow is just composition. -/ -lemma attentionFlow_two (M₁ M₂ : Mixer Pos Pos) : - attentionFlow [M₁, M₂] = M₁.comp M₂ := by - simp [attentionFlow] - -/-- Three-layer attention flow is associative composition. -/ -lemma attentionFlow_three (M₁ M₂ M₃ : Mixer Pos Pos) : - attentionFlow [M₁, M₂, M₃] = (M₁.comp M₂).comp M₃ := by - simp [attentionFlow] - -end AttentionFlow - -/-! ## Layer composition helpers -/ - -section LayerComp - -variable {S T U V : Type*} - [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- Compose three layers (common pattern: embed → transform → project). -/ -noncomputable def Mixer.comp3 - (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) : Mixer S V := - (M₁.comp M₂).comp M₃ - -/-- comp3 is equivalent to right-associated composition. -/ -lemma Mixer.comp3_eq_comp_comp (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) : - M₁.comp3 M₂ M₃ = M₁.comp (M₂.comp M₃) := by - simp [Mixer.comp3, Mixer.comp_assoc] - -end LayerComp - -/-! ## Transformer blocks -/ - -section TransformerBlock - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- A full transformer block conceptually: attention + feedforward with residuals. - -In a Pre-LN transformer (e.g. 
GPT-2): `y = x + Attention(LayerNorm(x))` followed by -`output = y + FFN(LayerNorm(y))`. We model this as composition of residual mixers. - -The coefficients `c_attn` and `c_ff` control how much of the skip connection -vs the transformed value flows through. -/ -noncomputable def Mixer.transformerBlock - (attn : Mixer Pos Pos) - (ff : Mixer Pos Pos) - (c_attn c_ff : NNReal) - (h_attn : c_attn ≤ 1) (h_ff : c_ff ≤ 1) : Mixer Pos Pos := - (Mixer.residual attn c_attn h_attn).comp (Mixer.residual ff c_ff h_ff) - -/-- A transformer block with no skip connections is just attention then FFN. -/ -lemma Mixer.transformerBlock_no_skip (attn ff : Mixer Pos Pos) : - Mixer.transformerBlock attn ff 0 0 (zero_le _) (zero_le _) = attn.comp ff := by - simp [Mixer.transformerBlock, Mixer.residual_zero] - -/-- A transformer block with full skip connections is identity. -/ -lemma Mixer.transformerBlock_full_skip (attn ff : Mixer Pos Pos) : - Mixer.transformerBlock attn ff 1 1 le_rfl le_rfl = Mixer.identity := by - simp [Mixer.transformerBlock, Mixer.residual_one] - -end TransformerBlock - -/-! ## Stacking transformer layers -/ - -section TransformerStack - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- A stack of transformer blocks. -/ -noncomputable def transformerStack (blocks : List (Mixer Pos Pos)) : Mixer Pos Pos := - attentionFlow blocks - -/-- The effective attention from position `i` to position `j` through a stack -of `n` transformer layers is given by the composed mixer weight. -/ -noncomputable def effectiveAttention - (blocks : List (Mixer Pos Pos)) (i j : Pos) : NNReal := - (transformerStack blocks).w i j - -/-- Effective attention forms a probability distribution over target positions -for each source position. This is a key property for interpretation: -it tells us "how much" each source position contributes to each target. 
-/ -theorem effectiveAttention_normalized (blocks : List (Mixer Pos Pos)) (i : Pos) : - (∑ j, effectiveAttention blocks i j) = 1 := by - simp only [effectiveAttention, transformerStack] - exact (attentionFlow blocks).row_sum_one i - -end TransformerStack - -/-! ## Path-based decomposition - -This section provides the key insight connecting mixer composition to -path-based attribution. The weight `(M.comp N).w i k` can be decomposed -as a sum over intermediate positions, corresponding to paths through -the computation graph. --/ - -section PathDecomposition - -variable {S T U V : Type*} [Fintype S] [Fintype T] [Fintype U] [Fintype V] - -/-- The composition weight decomposes as a sum over paths through the intermediate layer. -This is the foundation for path-integrated attribution methods. -/ -theorem Mixer.comp_path_decomposition (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- The contribution of path `i → j → k` to the total weight `i → k`. -/ -noncomputable def pathContrib (M : Mixer S T) (N : Mixer T U) (i : S) (j : T) (k : U) : NNReal := - M.w i j * N.w j k - -/-- Path contributions sum to the total weight. -/ -theorem pathContrib_sum (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (∑ j, pathContrib M N i j k) = (M.comp N).w i k := by - simp only [pathContrib, Mixer.comp] - -/-- For three-layer composition, paths go through two intermediate positions. -/ -theorem Mixer.comp3_path_decomposition - (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U V) (i : S) (l : V) : - (M₁.comp3 M₂ M₃).w i l = ∑ j, ∑ k, M₁.w i j * M₂.w j k * M₃.w k l := by - simp only [Mixer.comp3, Mixer.comp, Finset.sum_mul] - rw [Finset.sum_comm] - -end PathDecomposition - -/-! ## Conservation theorems - -Key insight: mixer operations preserve total probability mass. -This connects to the completeness axiom in attribution theory. 
--/ - -section Conservation - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Pushing a probability vector through a mixer preserves total mass. -This is the probabilistic interpretation of "conservation". -/ -theorem Mixer.push_preserves_total_mass (M : Mixer S T) (p : ProbVec S) : - (∑ j, (M.push p).mass j) = ∑ i, p.mass i := by - simp only [ProbVec.sum_mass] - -/-- The pushed mass at position `j` is the weighted sum of source masses. -/ -lemma Mixer.push_mass_eq (M : Mixer S T) (p : ProbVec S) (j : T) : - (M.push p).mass j = ∑ i, p.mass i * M.w i j := rfl - -/-- Conservation for composition: pushing through composed mixers -equals pushing through each sequentially. -/ -theorem Mixer.push_comp (M : Mixer S T) (N : Mixer T U) (p : ProbVec S) : - (M.comp N).push p = N.push (M.push p) := by - ext k - simp only [Mixer.push, Mixer.comp] - simp only [Finset.mul_sum, Finset.sum_mul] - rw [Finset.sum_comm] - simp_rw [mul_assoc] - -end Conservation - -/-! ## Residual stream decomposition - -A key insight for transformer interpretation: residual connections create -multiple "paths" through the network. The effective contribution from source -to target can be decomposed into: -1. Direct path: information flows unchanged through the skip connection -2. Indirect path: information is transformed by the attention/FFN layer - -This decomposition is crucial for understanding how much a layer actually -contributes vs how much just passes through unchanged. --/ - -section ResidualDecomposition - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- **Residual decomposition theorem**: The residual mixer weight at (i,j) equals -the sum of direct (skip) and indirect (transform) path contributions. - -This is the formal version of: "How much information flows directly vs through the layer?" 
-/ -theorem Mixer.residual_decomposition (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i j : S) : - (Mixer.residual M c hc).w i j = - c * (if i = j then 1 else 0) + (1 - c) * M.w i j := by - simp only [Mixer.residual] - -/-- **Skip connection dominance**: When the skip coefficient c is large, -the residual is close to identity. Specifically, the diagonal entries -are at least c. -/ -theorem Mixer.residual_skip_dominance (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - (Mixer.residual M c hc).w i i ≥ c := by - simp only [Mixer.residual, ↓reduceIte, mul_one, le_add_iff_nonneg_right, zero_le] - -/-- Off-diagonal entries are bounded by the indirect path contribution. -/ -theorem Mixer.residual_off_diag_bound (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - {i j : S} (hij : i ≠ j) : - (Mixer.residual M c hc).w i j = (1 - c) * M.w i j := by - simp only [Mixer.residual, hij, ↓reduceIte, mul_zero, zero_add] - -/-- **Interpretation insight**: Off-diagonal entries are scaled down by (1-c). -If c = 0.9, off-diagonal influence is reduced to 10% of original. -This quantifies how residual connections "protect" self-information. -/ -theorem Mixer.residual_off_diag_scaling (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - {i j : S} (hij : i ≠ j) : - (Mixer.residual M c hc).w i j ≤ M.w i j := by - rw [Mixer.residual_off_diag_bound M c hc hij] - have h1 : (1 - c) ≤ 1 := tsub_le_self - calc (1 - c) * M.w i j ≤ 1 * M.w i j := mul_le_mul_of_nonneg_right h1 (zero_le _) - _ = M.w i j := one_mul _ - -end ResidualDecomposition - -/-! ## Attention concentration and information flow - -When attention is concentrated on few positions, this limits how information -can spread through the network. These theorems formalize bounds on information flow. --/ - -section AttentionConcentration - -variable {S : Type*} [Fintype S] - -/-- The maximum attention weight from position i determines an upper bound -on how much relevance can flow to any single source position. 
-/ -noncomputable def Mixer.maxWeight (M : Mixer S S) (i : S) : NNReal := - Finset.sup' Finset.univ ⟨i, Finset.mem_univ i⟩ (fun j => M.w i j) - -/-- Any weight is at most the maxWeight. -/ -lemma Mixer.weight_le_maxWeight (M : Mixer S S) (i j : S) : - M.w i j ≤ M.maxWeight i := by - simp only [Mixer.maxWeight] - exact Finset.le_sup' (fun k => M.w i k) (Finset.mem_univ j) - -/-- **Attention bottleneck**: The pushed mass at any position j is bounded by -the sum over i of (mass at i) × (max attention weight from i). -In other words, if attention is spread thin, mass can't concentrate. -/ -theorem Mixer.push_concentration_bound (M : Mixer S S) (p : ProbVec S) (j : S) : - (M.push p).mass j ≤ ∑ i, p.mass i * M.maxWeight i := by - simp only [Mixer.push] - apply Finset.sum_le_sum - intro i _ - exact mul_le_mul_of_nonneg_left (M.weight_le_maxWeight i j) (zero_le _) - -end AttentionConcentration - -/-! ## Ablation and masking analysis - -When we "ablate" or "mask" certain positions, we effectively zero out their -contribution. This section formalizes the effect of such interventions. --/ - -section Ablation - -variable {S : Type*} [Fintype S] - -/-- A masked weight function: positions in the mask set have their outgoing weights zeroed. -This models "what if we remove these positions from consideration?" - -Note: This is sub-stochastic (rows of blocked positions don't sum to 1). -/ -noncomputable def Mixer.maskFn (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] : - S → S → NNReal := - fun i j => if blocked i then 0 else M.w i j - -/-- Masking a position removes its contribution entirely. -/ -lemma Mixer.maskFn_blocked (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] - {i : S} (hi : blocked i) (j : S) : - M.maskFn blocked i j = 0 := by - simp [Mixer.maskFn, hi] - -/-- Unblocked positions keep their original weights. 
-/ -lemma Mixer.maskFn_unblocked (M : Mixer S S) (blocked : Set S) [DecidablePred blocked] - {i : S} (hi : ¬blocked i) (j : S) : - M.maskFn blocked i j = M.w i j := by - simp [Mixer.maskFn, hi] - -/-- The contribution from blocked positions to position j. -/ -noncomputable def blockedContribution (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : NNReal := - ∑ i : S, if blocked i then p.mass i * M.w i j else 0 - -/-- The contribution from unblocked positions to position j. -/ -noncomputable def unblockedContribution (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : NNReal := - ∑ i : S, if blocked i then 0 else p.mass i * M.w i j - -/-- **Ablation decomposition**: The pushed mass equals blocked plus unblocked contributions. -/ -theorem Mixer.ablation_decomposition (M : Mixer S S) (p : ProbVec S) - (blocked : Set S) [DecidablePred blocked] (j : S) : - (M.push p).mass j = - unblockedContribution M p blocked j + blockedContribution M p blocked j := by - simp only [Mixer.push, unblockedContribution, blockedContribution] - rw [← Finset.sum_add_distrib] - apply Finset.sum_congr rfl - intro i _ - split_ifs <;> simp - -end Ablation - -/-! ## Composition depth and information spread - -As we compose more layers, information can spread to more positions. -These theorems characterize how "reach" grows with depth. --/ - -section CompositionDepth - -variable {S : Type*} [Fintype S] - -/-- A position j is "reachable" from i through mixer M if M.w i j > 0. -/ -def Mixer.reachable (M : Mixer S S) (i j : S) : Prop := M.w i j ≠ 0 - -/-- **Reach expansion**: If j is reachable from i through M, and k is reachable -from j through N, then the composition has at least the product weight. -This shows how influence compounds through layers. 
-/ -theorem Mixer.reach_comp (M N : Mixer S S) {i j k : S} - (_ : M.reachable i j) (_ : N.reachable j k) : - (M.comp N).w i k ≥ M.w i j * N.w j k := by - simp only [Mixer.comp, ge_iff_le] - exact Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - -/-- **Path contribution bound**: The contribution through any single intermediate j -is at most the composed weight. -/ -theorem Mixer.path_contrib_le_comp (M N : Mixer S S) (i j k : S) : - M.w i j * N.w j k ≤ (M.comp N).w i k := by - simp only [Mixer.comp] - exact Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - -/-- Composing mixers preserves reachability through any nonzero path. -/ -theorem Mixer.comp_reachable_of_path (M N : Mixer S S) {i j k : S} - (hij : M.w i j ≠ 0) (hjk : N.w j k ≠ 0) : - (M.comp N).reachable i k := by - simp only [Mixer.reachable, Mixer.comp, ne_eq] - intro h - have hterm : M.w i j * N.w j k = 0 := by - have hle : M.w i j * N.w j k ≤ ∑ x, M.w i x * N.w x k := - Finset.single_le_sum (f := fun x => M.w i x * N.w x k) - (fun x _ => zero_le _) (Finset.mem_univ j) - rw [h] at hle - exact le_antisymm hle (zero_le _) - cases mul_eq_zero.mp hterm with - | inl h => exact hij h - | inr h => exact hjk h - -end CompositionDepth - -/-! ## Information-theoretic bounds - -These theorems connect mixer properties to information-theoretic concepts, -providing bounds on how much "information" can flow through attention layers. --/ - -section InformationBounds - -variable {S : Type*} [Fintype S] - -/-- The "effective support size" from position i: how many positions receive -nonzero attention. Smaller support = more concentrated attention. -/ -noncomputable def Mixer.supportSize (M : Mixer S S) (i : S) : ℕ := - (Finset.univ.filter (fun j => M.w i j ≠ 0)).card - -/-- Row-stochasticity means at least one entry is nonzero (assuming S nonempty). 
-/ -lemma Mixer.exists_nonzero [Nonempty S] (M : Mixer S S) (i : S) : ∃ j, M.w i j ≠ 0 := by - by_contra h - push_neg at h - have hsum : ∑ j, M.w i j = 0 := Finset.sum_eq_zero (fun j _ => h j) - rw [M.row_sum_one i] at hsum - exact one_ne_zero hsum - -/-- **Support size bound**: In a nonempty type, every row has positive support. -/ -theorem Mixer.supportSize_pos [Nonempty S] (M : Mixer S S) (i : S) : - M.supportSize i ≥ 1 := by - simp only [Mixer.supportSize] - obtain ⟨j, hj⟩ := M.exists_nonzero i - exact Finset.one_le_card.mpr ⟨j, by simp [hj]⟩ - -end InformationBounds - -/-! ## Gradient-attribution correspondence - -A key insight: for linear layers, mixer-based attribution corresponds exactly -to gradient-based attribution. This section establishes this correspondence. --/ - -section GradientCorrespondence - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- **Gradient-attribution alignment**: For composed linear layers, the composed -attribution equals the product of individual attributions, summed over paths. -This is analogous to the chain rule for gradients. -/ -theorem Mixer.chain_rule_analog (M : Mixer S T) (N : Mixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- **Three-layer chain rule**: The attribution through three layers -decomposes into a double sum over intermediate positions. -/ -theorem Mixer.chain_rule_three (M₁ : Mixer S T) (M₂ : Mixer T U) (M₃ : Mixer U S) (i l : S) : - (M₁.comp (M₂.comp M₃)).w i l = ∑ j, ∑ k, M₁.w i j * M₂.w j k * M₃.w k l := by - simp only [Mixer.comp, Finset.mul_sum, mul_assoc] - -end GradientCorrespondence - -/-! ## Multi-head attention - -Real transformers use multiple attention heads that are combined. This section -formalizes multi-head attention and proves key properties about how heads combine. - -In practice, each head has its own attention pattern, and the outputs are -concatenated and projected. 
For relevance/attribution purposes, this is equivalent -to a weighted combination of the individual head attention patterns. --/ - -section MultiHead - -variable {Pos : Type*} [Fintype Pos] -variable {numHeads : ℕ} - -/-- Multi-head attention combines multiple attention heads with weights. -The weights typically come from the output projection and represent how much -each head contributes to the final output. - -For interpretation: if we want to know "how much does position j contribute -to position i", we sum over heads weighted by their importance. -/ -noncomputable def Mixer.multiHead - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) : Mixer Pos Pos := - { w := fun i j => ∑ h, headWeights h * (heads h).w i j, - row_sum_one := by - intro i - rw [Finset.sum_comm] - simp_rw [← Finset.mul_sum, Mixer.row_sum_one, mul_one, hsum] } - -/-- Each head's contribution to the multi-head attention is bounded by its weight. -/ -theorem Mixer.multiHead_head_contrib_bound - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (h : Fin numHeads) (i j : Pos) : - headWeights h * (heads h).w i j ≤ (Mixer.multiHead heads headWeights hsum).w i j := by - simp only [Mixer.multiHead] - exact Finset.single_le_sum (f := fun k => headWeights k * (heads k).w i j) - (fun _ _ => zero_le _) (Finset.mem_univ h) - -/-- **Head importance theorem**: A head with zero weight contributes nothing. -This formalizes the intuition that "unimportant" heads can be pruned. -/ -theorem Mixer.multiHead_zero_weight - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (h : Fin numHeads) (hw : headWeights h = 0) (i j : Pos) : - headWeights h * (heads h).w i j = 0 := by - simp [hw] - -/-- **Single head dominance**: If one head has weight 1 (others have weight 0), -multi-head attention reduces to that single head's attention. 
-/ -theorem Mixer.multiHead_single_head - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (h₀ : Fin numHeads) (hdom : headWeights h₀ = 1) - (hzero : ∀ h, h ≠ h₀ → headWeights h = 0) (i j : Pos) : - (Mixer.multiHead heads headWeights hsum).w i j = (heads h₀).w i j := by - simp only [Mixer.multiHead] - have hsplit : ∑ h, headWeights h * (heads h).w i j = - headWeights h₀ * (heads h₀).w i j + - ∑ h ∈ Finset.univ.erase h₀, headWeights h * (heads h).w i j := by - rw [← Finset.add_sum_erase _ _ (Finset.mem_univ h₀)] - rw [hsplit, hdom, one_mul, add_eq_left] - apply Finset.sum_eq_zero - intro h hh - simp only [Finset.mem_erase, Finset.mem_univ, ne_eq] at hh - simp [hzero h hh.1] - -/-- The multi-head attention weight is a convex combination of individual head weights. -/ -theorem Mixer.multiHead_convex - (heads : Fin numHeads → Mixer Pos Pos) - (headWeights : Fin numHeads → NNReal) - (hsum : ∑ h, headWeights h = 1) - (i j : Pos) : - (Mixer.multiHead heads headWeights hsum).w i j ≤ 1 := by - simp only [Mixer.multiHead] - calc ∑ h, headWeights h * (heads h).w i j - ≤ ∑ h, headWeights h * 1 := by - apply Finset.sum_le_sum - intro h _ - apply mul_le_mul_of_nonneg_left _ (zero_le _) - have hrow := (heads h).row_sum_one i - have hle : (heads h).w i j ≤ ∑ k, (heads h).w i k := - Finset.single_le_sum (f := fun k => (heads h).w i k) - (fun _ _ => zero_le _) (Finset.mem_univ j) - rw [hrow] at hle - exact hle - _ = ∑ h, headWeights h := by simp - _ = 1 := hsum - -end MultiHead - -/-! ## Causal masking - -Autoregressive models (GPT-style) use causal masking: position i can only attend -to positions j ≤ i. This creates a triangular attention pattern with important -consequences for information flow and attribution. --/ - -section CausalMask - -variable {n : ℕ} - -/-- A causal attention mask: position i can attend to j only if j ≤ i. -This models autoregressive/decoder-only transformers like GPT. 
-/ -def isCausal (M : Mixer (Fin n) (Fin n)) : Prop := - ∀ i j : Fin n, j.val > i.val → M.w i j = 0 - -/-- **Causal reachability**: In a causal mixer, information can only flow -from later to earlier positions (or stay in place). -/ -theorem causal_reachable_dir (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - {i j : Fin n} (hreach : M.reachable i j) : j.val ≤ i.val := by - by_contra h - push_neg at h - have := hcaus i j h - exact hreach this - -/-- Composition of causal mixers is causal. This means stacking causal attention -layers preserves the causal property. -/ -theorem causal_comp (M N : Mixer (Fin n) (Fin n)) - (hM : isCausal M) (hN : isCausal N) : isCausal (M.comp N) := by - intro i j hij - simp only [Mixer.comp] - apply Finset.sum_eq_zero - intro k _ - by_cases hk : k.val > i.val - · simp [hM i k hk] - · push_neg at hk - by_cases hkj : j.val > k.val - · simp [hN k j hkj] - · push_neg at hkj - -- k ≤ i and j ≤ k, so j ≤ i, contradicting hij - omega - -/-- **Causal information bound**: In a causal model, the total attention from -position i to future positions is zero. All attention goes to past/current. -/ -theorem causal_future_attention_zero (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - (i : Fin n) : ∑ j ∈ Finset.univ.filter (fun j => j.val > i.val), M.w i j = 0 := by - apply Finset.sum_eq_zero - intro j hj - simp only [Finset.mem_filter, Finset.mem_univ, true_and] at hj - exact hcaus i j hj - -/-- In a causal mixer, all attention mass goes to positions ≤ i. 
-/ -theorem causal_past_attention_one (M : Mixer (Fin n) (Fin n)) (hcaus : isCausal M) - (i : Fin n) : ∑ j ∈ Finset.univ.filter (fun j => j.val ≤ i.val), M.w i j = 1 := by - have htotal := M.row_sum_one i - have hfuture := causal_future_attention_zero M hcaus i - -- Show the two filters partition univ - have hpart : Finset.univ.filter (fun j : Fin n => j.val ≤ i.val) ∪ - Finset.univ.filter (fun j => i.val < j.val) = Finset.univ := by - ext x - simp only [Finset.mem_union, Finset.mem_filter, Finset.mem_univ, true_and, iff_true] - exact le_or_gt x.val i.val - have hdisj : Disjoint (Finset.univ.filter (fun j : Fin n => j.val ≤ i.val)) - (Finset.univ.filter (fun j => i.val < j.val)) := by - rw [Finset.disjoint_filter] - intro x _ hle hlt - omega - have key : ∑ j, M.w i j = ∑ j ∈ Finset.univ.filter (fun j => j.val ≤ i.val), M.w i j + - ∑ j ∈ Finset.univ.filter (fun j => i.val < j.val), M.w i j := by - conv_lhs => rw [← hpart] - rw [Finset.sum_union hdisj] - simp only [hfuture, add_zero, htotal] at key - exact key.symm - -/-- **First token theorem**: In a causal model, the first position (index 0) -can only attend to itself, so its self-attention weight is 1. -/ -theorem causal_first_token_self (M : Mixer (Fin (n + 1)) (Fin (n + 1))) - (hcaus : isCausal M) : M.w 0 0 = 1 := by - have h := causal_past_attention_one M hcaus 0 - simp only [Fin.val_zero] at h - have hfilt : Finset.univ.filter (fun j : Fin (n + 1) => j.val ≤ 0) = {0} := by - ext x - simp only [Finset.mem_filter, Finset.mem_univ, true_and, Finset.mem_singleton] - constructor - · intro hx - exact Fin.ext (Nat.le_zero.mp hx) - · intro hx - simp [hx] - rw [hfilt] at h - simpa using h - -end CausalMask - -/-! ## Attention head analysis - -Tools for analyzing individual attention heads and their roles. --/ - -section HeadAnalysis - -variable {Pos : Type*} [Fintype Pos] - -/-- The "concentration" of attention from position i: sum of squared weights. -Higher value = more concentrated on few positions. 
Lower = more spread out. -This is a measure related to the inverse of entropy. -/ -noncomputable def Mixer.attentionConcentration (M : Mixer Pos Pos) (i : Pos) : NNReal := - ∑ j, (M.w i j) ^ 2 - -/-- Concentration is at most 1 (achieved when all attention on one position). -/ -theorem Mixer.attentionConcentration_upper_bound (M : Mixer Pos Pos) (i : Pos) : - M.attentionConcentration i ≤ 1 := by - simp only [Mixer.attentionConcentration] - have hsum := M.row_sum_one i - calc ∑ j, (M.w i j) ^ 2 - ≤ ∑ j, M.w i j := by - apply Finset.sum_le_sum - intro j _ - rw [sq] - have hle : M.w i j ≤ ∑ k, M.w i k := - Finset.single_le_sum (f := fun k => M.w i k) (fun _ _ => zero_le _) (Finset.mem_univ j) - rw [hsum] at hle - calc M.w i j * M.w i j ≤ M.w i j * 1 := mul_le_mul_of_nonneg_left hle (zero_le _) - _ = M.w i j := mul_one _ - _ = 1 := hsum - -/-- **Sparsity indicator**: An attention head is "sparse" at position i if its -concentration is high (close to 1). This indicates it focuses on few positions. -/ -def Mixer.isSparseAt (M : Mixer Pos Pos) (i : Pos) (threshold : NNReal) : Prop := - M.attentionConcentration i ≥ threshold - -/-- **Uniform attention indicator**: An attention head is "diffuse" at position i -if its concentration is low. This indicates it spreads attention broadly. -/ -def Mixer.isDiffuseAt (M : Mixer Pos Pos) (i : Pos) (threshold : NNReal) : Prop := - M.attentionConcentration i ≤ threshold - -/-- If all attention is on one position, concentration is 1. 
-/ -theorem Mixer.attentionConcentration_one_hot - (M : Mixer Pos Pos) (i j₀ : Pos) - (h : M.w i j₀ = 1) (hz : ∀ j, j ≠ j₀ → M.w i j = 0) : - M.attentionConcentration i = 1 := by - classical - simp only [Mixer.attentionConcentration] - have hsplit : ∑ j, (M.w i j) ^ 2 = - (M.w i j₀) ^ 2 + ∑ j ∈ Finset.univ.erase j₀, (M.w i j) ^ 2 := by - rw [← Finset.add_sum_erase _ _ (Finset.mem_univ j₀)] - rw [hsplit, h, one_pow, add_eq_left] - apply Finset.sum_eq_zero - intro j hj - simp only [Finset.mem_erase, Finset.mem_univ, ne_eq] at hj - simp [hz j hj.1] - -end HeadAnalysis - -/-! ## Residual dominance analysis - -For interpreting transformers, we often want to know: "How dominant is the -skip connection vs the attention?" This section provides tools for this. --/ - -section ResidualDominance - -variable {S : Type*} [Fintype S] - -/-- A residual layer with coefficient c gives diagonal elements at least c. -This is a key property for interpretation: high c means the skip connection -dominates, preserving information from earlier layers. -/ -theorem residual_diagonal_lower [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - (Mixer.residual M c hc).w i i ≥ c := by - simp only [Mixer.residual] - have hnonneg : 0 ≤ (1 - c) * M.w i i := by - have h1 : 0 ≤ (1 - c) := by - exact zero_le _ - exact mul_nonneg h1 (by simp) - calc c * 1 + (1 - c) * M.w i i ≥ c * 1 := by - exact le_add_of_nonneg_right hnonneg - _ = c * 1 + 0 := by simp - _ = c := by ring - -/-- Off-diagonal elements of a residual are scaled down by (1-c). -This quantifies how much the attention contribution is suppressed. -/ -theorem residual_offdiag_scale [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) - (i j : S) (hij : i ≠ j) : - (Mixer.residual M c hc).w i j = (1 - c) * M.w i j := by - simp only [Mixer.residual] - simp [hij] - -/-- The sum of off-diagonal elements in a residual row is at most (1-c). -This bounds how much attention "leaks" to other positions. 
-/ -theorem residual_offdiag_sum_bound [DecidableEq S] - (M : Mixer S S) (c : NNReal) (hc : c ≤ 1) (i : S) : - ∑ j ∈ Finset.univ.filter (· ≠ i), (Mixer.residual M c hc).w i j ≤ 1 - c := by - calc ∑ j ∈ Finset.univ.filter (· ≠ i), (Mixer.residual M c hc).w i j - = ∑ j ∈ Finset.univ.filter (· ≠ i), (1 - c) * M.w i j := by - apply Finset.sum_congr rfl - intro j hj - simp only [Finset.mem_filter, Finset.mem_univ, true_and] at hj - exact residual_offdiag_scale M c hc i j (Ne.symm hj) - _ = (1 - c) * ∑ j ∈ Finset.univ.filter (· ≠ i), M.w i j := by - rw [Finset.mul_sum] - _ ≤ (1 - c) * 1 := by - apply mul_le_mul_of_nonneg_left _ (by simp) - calc ∑ j ∈ Finset.univ.filter (· ≠ i), M.w i j - ≤ ∑ j, M.w i j := Finset.sum_le_sum_of_subset (fun x hx => Finset.mem_univ x) - _ = 1 := M.row_sum_one i - _ = 1 - c := by ring - -end ResidualDominance - -/-! ## Deep composition bounds - -A key question for transformer interpretation: after L layers, how spread out -is the attribution? This section provides quantitative bounds. --/ - -section DeepComposition - -variable {Pos : Type*} [Fintype Pos] - -/-- Composition weight is at most 1 (row-stochastic property preserved). -/ -theorem comp_weight_le_one (M N : Mixer Pos Pos) (i k : Pos) : - (M.comp N).w i k ≤ 1 := by - have h := (M.comp N).row_sum_one i - calc (M.comp N).w i k - ≤ ∑ k', (M.comp N).w i k' := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ k) - _ = 1 := h - -/-- Each term in a composition sum is bounded by the corresponding M weight. -/ -theorem comp_term_bound (M N : Mixer Pos Pos) (i j k : Pos) : - M.w i j * N.w j k ≤ M.w i j := by - calc M.w i j * N.w j k ≤ M.w i j * 1 := by - apply mul_le_mul_of_nonneg_left _ (zero_le _) - calc N.w j k ≤ ∑ k', N.w j k' := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ k) - _ = 1 := N.row_sum_one j - _ = M.w i j := by ring - -end DeepComposition - -/-! 
## Cross-attention for encoder-decoder models - -Real seq2seq models use cross-attention where queries come from the decoder -and keys/values come from the encoder. This is a mixer from decoder positions -to encoder positions (tracking where decoder attends in encoder). --/ - -section CrossAttention - -variable {EncPos DecPos : Type*} [Fintype EncPos] [Fintype DecPos] - -/-- Cross-attention mixer: tracks where each decoder position attends in encoder. -This is the fundamental building block for encoder-decoder attribution. -/ -noncomputable def Mixer.crossAttention - (w : DecPos → EncPos → NNReal) - (hw : ∀ d, ∑ e, w d e = 1) : Mixer DecPos EncPos where - w := w - row_sum_one := hw - -/-- Cross-attention preserves the row-stochastic property. -/ -theorem Mixer.crossAttention_normalized - (w : DecPos → EncPos → NNReal) - (hw : ∀ d, ∑ e, w d e = 1) (d : DecPos) : - ∑ e, (Mixer.crossAttention w hw).w d e = 1 := - hw d - -end CrossAttention - -/-! ## Layer-wise attribution analysis - -For understanding transformer behavior, it's crucial to decompose attribution -layer by layer. This section provides tools for such analysis. --/ - -section LayerAttribution - -variable {Pos : Type*} [Fintype Pos] [DecidableEq Pos] - -/-- The attribution from position i to position j through a sequence of layers. -/ -noncomputable def layerWiseAttribution - (layers : List (Mixer Pos Pos)) (i j : Pos) : NNReal := - (layers.foldl Mixer.comp Mixer.identity).w i j - -/-- Empty layer list gives identity attribution. -/ -@[simp] -theorem layerWiseAttribution_nil (i j : Pos) : - layerWiseAttribution (Pos := Pos) [] i j = Mixer.identity.w i j := rfl - -/-- Single layer attribution equals the layer's weight. -/ -theorem layerWiseAttribution_singleton (M : Mixer Pos Pos) (i j : Pos) : - layerWiseAttribution [M] i j = M.w i j := by - simp only [layerWiseAttribution, List.foldl_cons, List.foldl_nil, Mixer.identity_comp] - -/-- Attribution through layers is bounded by 1 (probability). 
-/ -theorem layerWiseAttribution_le_one (layers : List (Mixer Pos Pos)) (i j : Pos) : - layerWiseAttribution layers i j ≤ 1 := by - simp only [layerWiseAttribution] - have h := (layers.foldl Mixer.comp Mixer.identity).row_sum_one i - calc (layers.foldl Mixer.comp Mixer.identity).w i j - ≤ ∑ k, (layers.foldl Mixer.comp Mixer.identity).w i k := - Finset.single_le_sum (fun _ _ => zero_le _) (Finset.mem_univ j) - _ = 1 := h - -/-- **Total attribution conservation**: Sum over all targets equals 1. -This is the formal statement that "attribution mass is conserved". -/ -theorem layerWiseAttribution_sum_one (layers : List (Mixer Pos Pos)) (i : Pos) : - ∑ j, layerWiseAttribution layers i j = 1 := by - simp only [layerWiseAttribution] - exact (layers.foldl Mixer.comp Mixer.identity).row_sum_one i - -end LayerAttribution - -/-! ## Tracer uniqueness for transformer interpretation - -The key insight connecting this formalization to neural network interpretation: -the tracer uniqueness theorem (from `Uniqueness.lean`) implies that attribution -methods based on the mixer framework are uniquely determined by boundary conditions. - -This provides formal justification for attention-based interpretation methods: -if two attribution methods both propagate mass according to the same mixers -(attention patterns) and agree on boundary conditions (e.g., start with -probability 1 at the output token), then they must produce identical attributions. --/ - -section TracerInterpretation - -variable {S : Type*} [Fintype S] [DecidableEq S] - -end TracerInterpretation - -/-- If two tracer propagation methods satisfy the same mixing recurrence -on a transformer's computation graph, they must coincide. This is the -formal statement that "attention-based attribution is unique given -the attention patterns and boundary conditions." 
-/ -theorem transformer_attribution_unique - {S : Type*} - (n : ℕ) - (parents : Fin n → Finset (Fin n)) - (htopo : ∀ k u, u ∈ parents k → u.val < k.val) - (c : Fin n → Fin n → NNReal) - (L : LocalSystem n := ⟨parents, c, fun {i u} h => htopo i u h⟩) - (T T' : LocalSystem.TracerFamily (S := S) n) - (hT : L.Satisfies T) - (hT' : L.Satisfies T') : T = T' := - LocalSystem.tracer_unique L hT hT' - -end Nfp diff --git a/Nfp/Linear/FinFold.lean b/Nfp/Linear/FinFold.lean new file mode 100644 index 0000000..f4114f6 --- /dev/null +++ b/Nfp/Linear/FinFold.lean @@ -0,0 +1,129 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Fin +public import Mathlib.Data.Matrix.Mul +public import Mathlib.Data.Rat.BigOperators +public import Batteries.Data.Fin.Fold +public import Nfp.Core.Basic + +/-! +Tail-recursive folds and sums over `Fin`. + +These helpers keep sound computations stack-safe while remaining explicit. +-/ + +public section + +namespace Nfp + + +namespace Linear + +variable {α : Type _} + +/-- Tail-recursive fold over `Fin n`. -/ +def foldlFin (n : Nat) (f : α → Fin n → α) (init : α) : α := + Fin.dfoldl n (fun _ => α) (fun i acc => f acc i) init + +/-- `foldlFin` matches `Fin.foldl`. -/ +theorem foldlFin_eq_foldl (n : Nat) (f : α → Fin n → α) (init : α) : + foldlFin n f init = Fin.foldl n f init := by + simpa [foldlFin] using + (Fin.dfoldl_eq_foldl (n := n) (f := fun i acc => f acc i) (x := init)) + +/-- Tail-recursive sum over `Fin n` (Rat-valued). -/ +def sumFin (n : Nat) (f : Fin n → Rat) : Rat := + foldlFin n (fun acc i => acc + f i) 0 + +/-- Tail-recursive sum over `Fin n` (alias for `sumFin`). -/ +def sumFinCommonDen (n : Nat) (f : Fin n → Rat) : Rat := + sumFin n f + +/-- `sumFin` as a left fold over the finite range list. 
-/ +theorem sumFin_eq_list_foldl (n : Nat) (f : Fin n → Rat) : + sumFin n f = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa [sumFin, foldlFin_eq_foldl] using + (Fin.foldl_eq_foldl_finRange (f := fun acc i => acc + f i) (x := (0 : Rat)) (n := n)) + +/-- `sumFin` agrees with the `Finset.univ` sum. -/ +theorem sumFin_eq_sum_univ {n : Nat} (f : Fin n → Rat) : + sumFin n f = ∑ i, f i := by + classical + have hsum_list : + ((List.finRange n).map f).sum = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + have hmap : + ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 = + (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using + (List.foldl_map (f := f) (g := fun acc x : Rat => acc + x) + (l := List.finRange n) (init := (0 : Rat))) + calc + ((List.finRange n).map f).sum + = ((List.finRange n).map f).foldl (fun acc x : Rat => acc + x) 0 := by + simpa using (List.sum_eq_foldl (l := (List.finRange n).map f)) + _ = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + exact hmap + have hsum_univ : ((List.finRange n).map f).sum = ∑ i, f i := by + simpa using (Fin.sum_univ_def f).symm + calc + sumFin n f + = (List.finRange n).foldl (fun acc i => acc + f i) 0 := by + simpa using sumFin_eq_list_foldl n f + _ = ((List.finRange n).map f).sum := hsum_list.symm + _ = ∑ i, f i := hsum_univ + +/-- Casting a `Finset.univ` rational sum to `Real` commutes with summation. -/ +theorem ratToReal_sum_univ {n : Nat} (f : Fin n → Rat) : + ratToReal (∑ i, f i) = ∑ i, ratToReal (f i) := by + classical + simp [ratToReal_def] + +/-- Casting a rational `sumFin` to `Real` commutes with summation. -/ +theorem ratToReal_sumFin {n : Nat} (f : Fin n → Rat) : + ratToReal (sumFin n f) = ∑ i, ratToReal (f i) := by + classical + simpa [sumFin_eq_sum_univ] using ratToReal_sum_univ (f := f) + +/-- `sumFinCommonDen` agrees with `sumFin`. 
-/ +theorem sumFinCommonDen_eq_sumFin (n : Nat) (f : Fin n → Rat) : + sumFinCommonDen n f = sumFin n f := by + simp [sumFinCommonDen] + +/-- Dot product over `Fin n` (Rat-valued). -/ +def dotFin (n : Nat) (x y : Fin n → Rat) : Rat := + sumFin n (fun i => x i * y i) + +/-- Unfolding lemma for `dotFin`. -/ +theorem dotFin_def (n : Nat) (x y : Fin n → Rat) : + dotFin n x y = sumFin n (fun i => x i * y i) := by + simp [dotFin] + +/-- `dotFin` matches `dotProduct`. -/ +theorem dotFin_eq_dotProduct (n : Nat) (x y : Fin n → Rat) : + dotFin n x y = dotProduct x y := by + simp [dotFin_def, sumFin_eq_sum_univ, dotProduct] + +/-- Right-distribute dot products over subtraction (Real-valued). -/ +theorem dotProduct_sub_right {n : Nat} (x y z : Fin n → Real) : + dotProduct x (fun i => y i - z i) = dotProduct x y - dotProduct x z := by + classical + simp only [dotProduct, mul_sub, Finset.sum_sub_distrib] + +/-- Right-distribute dot products over addition (Real-valued). -/ +theorem dotProduct_add_right {n : Nat} (x y z : Fin n → Real) : + dotProduct x (fun i => y i + z i) = dotProduct x y + dotProduct x z := by + classical + simp only [dotProduct, mul_add, Finset.sum_add_distrib] + +/-- Pull a constant factor out of the right-hand side of a dot product. 
-/ +theorem dotProduct_mul_right {n : Nat} (x y : Fin n → Real) (a : Real) : + dotProduct x (fun i => y i * a) = dotProduct x y * a := by + classical + simp only [dotProduct, mul_assoc, Finset.sum_mul] + +end Linear + + +end Nfp diff --git a/Nfp/Linearization.lean b/Nfp/Linearization.lean deleted file mode 100644 index a5bbc5e..0000000 --- a/Nfp/Linearization.lean +++ /dev/null @@ -1,2778 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sqrt -import Mathlib.Analysis.Real.Pi.Bounds -import Mathlib.Analysis.SpecialFunctions.Trigonometric.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.SignedMixer - -/-! -# Linearization of Non-Linear Operations - -Real neural networks are not linear—they use activation functions (ReLU, GeLU), -normalization layers (LayerNorm, BatchNorm), and attention softmax. To apply -our mixer-based attribution framework to real networks, we must *linearize* -these operations at a specific input. - -## Key Insight - -For any differentiable function f : ℝⁿ → ℝᵐ, at a specific input x₀, we use a -row-vector convention: - f(x) ≈ f(x₀) + (x - x₀) · J_f(x₀) - -where J_f(x₀) is the Jacobian matrix. This Jacobian is exactly a `SignedMixer`! - -For piecewise-linear functions like ReLU, the Jacobian is well-defined almost -everywhere and consists of 0s and 1s. - -## Main Definitions - -* `Linearization`: A record of (input, output, Jacobian as SignedMixer) -* `reluLinearization`: ReLU's Jacobian is diagonal with 0/1 entries -* `geluLinearization`: GeLU's Jacobian based on the GeLU derivative -* `layerNormJacobian`: Full Jacobian of LayerNorm (non-trivial!) -* `softmaxJacobian`: Jacobian of softmax for attention - -## Why This Matters - -Given a concrete forward pass through a transformer: -1. At each layer, record the activations -2. 
Compute the linearization (Jacobian) at those activations -3. Compose the resulting SignedMixers -4. The composition gives end-to-end attribution via the chain rule - -This connects to: -- Gradient × Input attribution -- Integrated Gradients (as a path integral of linearizations) -- Attention rollout (composition of attention Jacobians) - -## References - -- Sundararajan et al., "Axiomatic Attribution for Deep Networks" (Integrated Gradients) -- Abnar & Zuidema, "Quantifying Attention Flow in Transformers" --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-! ## Linearization Structure -/ - -/-- A linearization captures the local linear approximation of a function at a point. - -Given f : ℝⁿ → ℝᵐ and input x₀, a linearization consists of: -- `input`: The point x₀ where we linearize -- `output`: f(x₀) -- `jacobian`: The Jacobian ∂f/∂x evaluated at x₀, as a SignedMixer - -The approximation is: f(x) ≈ output + jacobian.apply(x - input) -/ -structure Linearization (n m : Type*) [Fintype n] [Fintype m] where - /-- The input point at which we linearized. -/ - input : n → ℝ - /-- The output f(input). -/ - output : m → ℝ - /-- The Jacobian matrix as a SignedMixer. jacobian.w i j = ∂f_j/∂x_i -/ - jacobian : SignedMixer n m - -namespace Linearization - -variable {n m p : Type*} [Fintype n] [Fintype m] [Fintype p] - -/-- Compose two linearizations (chain rule). -If f : ℝⁿ → ℝᵐ is linearized at x with Jacobian J_f, and - g : ℝᵐ → ℝᵖ is linearized at f(x) with Jacobian J_g, then - g ∘ f has Jacobian J_f · J_g at x (row-vector convention). -/ -noncomputable def comp - (L₁ : Linearization n m) (L₂ : Linearization m p) - (_h : L₂.input = L₁.output) : Linearization n p where - input := L₁.input - output := L₂.output - jacobian := L₁.jacobian.comp L₂.jacobian - -/-- Chain rule for composed linearizations (row-vector convention). 
-/ -theorem comp_apply - (L₁ : Linearization n m) (L₂ : Linearization m p) (h : L₂.input = L₁.output) - (v : n → ℝ) : - (L₁.comp L₂ h).jacobian.apply v = L₂.jacobian.apply (L₁.jacobian.apply v) := by - simpa using - (SignedMixer.apply_comp (M := L₁.jacobian) (N := L₂.jacobian) (v := v)) - -/-- The identity linearization (identity function). -/ -noncomputable def id [DecidableEq n] : Linearization n n where - input := fun _ => 0 - output := fun _ => 0 - jacobian := SignedMixer.identity - -end Linearization - -/-! ## ReLU Linearization -/ - -section ReLU - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- The ReLU activation function: max(x, 0). -/ -noncomputable def relu (x : ℝ) : ℝ := max x 0 - -/-- The ReLU derivative: 1 if x > 0, 0 otherwise. -At x = 0, we use the subgradient convention: derivative is 0. -/ -noncomputable def reluGrad (x : ℝ) : ℝ := if x > 0 then 1 else 0 - -/-- ReLU applied elementwise to a vector. -/ -noncomputable def reluVec (v : n → ℝ) : n → ℝ := fun i => relu (v i) - -/-- The ReLU mask: which coordinates are "on" (positive). -/ -def reluMask (v : n → ℝ) : n → Prop := fun i => v i > 0 - -/-- The ReLU mask as a 0/1 indicator. -/ -noncomputable def reluMaskIndicator (v : n → ℝ) : n → ℝ := - fun i => if v i > 0 then 1 else 0 - -/-- **ReLU Linearization**: The Jacobian of ReLU is a diagonal matrix -with entries 0 or 1 based on whether the input is positive. - -This is the key insight: ReLU is piecewise linear, so its local linearization -is exact (not an approximation) within each linear region. -/ -noncomputable def reluLinearization (x : n → ℝ) : Linearization n n where - input := x - output := reluVec x - jacobian := { - w := fun i j => if i = j then reluGrad (x i) else 0 - } - -/-- The ReLU Jacobian is diagonal. -/ -theorem reluLinearization_diagonal (x : n → ℝ) (i j : n) (h : i ≠ j) : - (reluLinearization x).jacobian.w i j = 0 := by - simp [reluLinearization, h] - -/-- The ReLU Jacobian diagonal entry is 0 or 1. 
-/ -theorem reluLinearization_diag_binary (x : n → ℝ) (i : n) : - (reluLinearization x).jacobian.w i i = 0 ∨ - (reluLinearization x).jacobian.w i i = 1 := by - simp only [reluLinearization, reluGrad] - by_cases h : x i > 0 <;> simp [h] - -/-- ReLU preserves positive inputs exactly. -/ -theorem relu_pos {x : ℝ} (h : x > 0) : relu x = x := by - simp [relu, max_eq_left (le_of_lt h)] - -/-- ReLU kills negative inputs. -/ -theorem relu_neg {x : ℝ} (h : x ≤ 0) : relu x = 0 := by - simp [relu, max_eq_right h] - -end ReLU - -/-! ## GeLU Linearization -/ - -section GeLU - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- The GeLU (Gaussian Error Linear Unit) activation. -GeLU(x) = x · Φ(x) where Φ is the standard normal CDF. - -We use the approximation: GeLU(x) ≈ 0.5 · x · (1 + tanh(√(2/π) · (x + 0.044715 · x³))) -This is what most implementations use. -/ -noncomputable def gelu (x : ℝ) : ℝ := - 0.5 * x * (1 + Real.tanh (Real.sqrt (2 / Real.pi) * (x + 0.044715 * x^3))) - -/-- The GeLU derivative. -d/dx[GeLU(x)] = Φ(x) + x · φ(x) -where φ is the standard normal PDF. - -For the tanh approximation, the derivative is more complex but well-defined. -/ -noncomputable def geluGrad (x : ℝ) : ℝ := - let s := Real.sqrt (2 / Real.pi) - let inner := s * (x + 0.044715 * x^3) - let tanh_inner := Real.tanh inner - let sech2_inner := 1 - tanh_inner^2 -- sech² = 1 - tanh² - let inner_deriv := s * (1 + 3 * 0.044715 * x^2) - 0.5 * (1 + tanh_inner) + 0.5 * x * sech2_inner * inner_deriv - -/-- GeLU applied elementwise. -/ -noncomputable def geluVec (v : n → ℝ) : n → ℝ := fun i => gelu (v i) - -/-- **GeLU Linearization**: The Jacobian is diagonal with entries geluGrad(x_i). -/ -noncomputable def geluLinearization (x : n → ℝ) : Linearization n n where - input := x - output := geluVec x - jacobian := { - w := fun i j => if i = j then geluGrad (x i) else 0 - } - -/-- GeLU Jacobian is diagonal. 
-/ -theorem geluLinearization_diagonal (x : n → ℝ) (i j : n) (h : i ≠ j) : - (geluLinearization x).jacobian.w i j = 0 := by - simp [geluLinearization, h] - -end GeLU - -/-! ## LayerNorm Linearization -/ - -section LayerNorm - -variable {n : Type*} [Fintype n] [DecidableEq n] - -/-- Mean of a vector. -/ -noncomputable def mean (v : n → ℝ) : ℝ := - (∑ i, v i) / Fintype.card n - -/-- Variance of a vector. -/ -noncomputable def variance (v : n → ℝ) : ℝ := - let μ := mean v - (∑ i, (v i - μ)^2) / Fintype.card n - -/-- Standard deviation with epsilon for numerical stability. -/ -noncomputable def stddev (v : n → ℝ) : ℝ := Real.sqrt (variance v + 1e-5) - -/-- LayerNorm without learnable parameters (just normalization). -/ -noncomputable def layerNorm (v : n → ℝ) : n → ℝ := - let μ := mean v - let σ := stddev v - fun i => (v i - μ) / σ - -/-- LayerNorm with scale γ and bias β (per-coordinate). -/ -noncomputable def layerNormFull (γ β : n → ℝ) (v : n → ℝ) : n → ℝ := - let normalized := layerNorm v - fun i => γ i * normalized i + β i - -/-- **LayerNorm Jacobian**: This is the key non-trivial result. - -∂(LayerNorm(x))_j / ∂x_i = (1/σ) · [δ_{ij} - 1/n - (x_j - μ)(x_i - μ)/(n·σ²)] - -where δ_{ij} is 1 if i=j, 0 otherwise. - -This shows LayerNorm creates *dense* dependencies: every output depends on every input! -This is fundamentally different from ReLU/GeLU which are diagonal. -/ -noncomputable def layerNormJacobian (x : n → ℝ) : SignedMixer n n where - w := fun i j => - let μ := mean x - let σ := stddev x - let n_inv := (1 : ℝ) / Fintype.card n - let centered_i := x i - μ - let centered_j := x j - μ - let diagonal := if i = j then 1 else 0 - (1 / σ) * (diagonal - n_inv - centered_j * centered_i / (Fintype.card n * σ^2)) - -/-- Diagonal linear map as a `SignedMixer`: x · diag d (row-vector convention). 
-/ -noncomputable def diagMixer (d : n → ℝ) : SignedMixer n n where - w := fun i j => if i = j then d j else 0 - -/-- Operator norm bound for a diagonal mixer from a uniform entry bound. -/ -theorem operatorNormBound_diagMixer_le [Nonempty n] (d : n → ℝ) (b : ℝ) - (h : ∀ i, |d i| ≤ b) : - SignedMixer.operatorNormBound (diagMixer d) ≤ b := by - classical - dsimp [SignedMixer.operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n)) - (f := fun i => ∑ j, |(diagMixer d).w i j|) - (a := b)).2 ?_ - intro i hi - have hsum : (∑ j, |(diagMixer d).w i j|) = |d i| := by - have hsum' : (∑ j, |(diagMixer d).w i j|) = |(diagMixer d).w i i| := by - refine Fintype.sum_eq_single i ?_ - intro j hne - have hne' : i ≠ j := by - simpa [ne_comm] using hne - simp [diagMixer, hne'] - simpa [diagMixer] using hsum' - simpa [hsum] using h i - -/-- Operator norm bound for `A ∘ diag(d) ∘ B` from component bounds. -/ -theorem operatorNormBound_comp_diagMixer_comp_le - {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] [DecidableEq T] - [Nonempty S] [Nonempty T] - (A : SignedMixer S T) (B : SignedMixer T U) (d : T → ℝ) - (a c b : ℝ) - (hA : SignedMixer.operatorNormBound A ≤ a) - (hB : SignedMixer.operatorNormBound B ≤ b) - (hD : ∀ i, |d i| ≤ c) : - SignedMixer.operatorNormBound ((A.comp (diagMixer d)).comp B) ≤ a * c * b := by - classical - have hD' : SignedMixer.operatorNormBound (diagMixer d) ≤ c := - operatorNormBound_diagMixer_le (d := d) (b := c) hD - have hA_nonneg : 0 ≤ a := - le_trans (SignedMixer.operatorNormBound_nonneg (M := A)) hA - have hB_nonneg : 0 ≤ b := - le_trans (SignedMixer.operatorNormBound_nonneg (M := B)) hB - have hC_nonneg : 0 ≤ c := - le_trans (SignedMixer.operatorNormBound_nonneg (M := diagMixer d)) hD' - have hcomp : - SignedMixer.operatorNormBound ((A.comp (diagMixer d)).comp B) ≤ - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) * - SignedMixer.operatorNormBound B := by - simpa using - 
(SignedMixer.operatorNormBound_comp3_le - (A := A) (B := diagMixer d) (C := B)) - have hmul1 : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) ≤ a * c := by - have h1 : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) - ≤ a * SignedMixer.operatorNormBound (diagMixer d) := by - exact mul_le_mul_of_nonneg_right hA - (SignedMixer.operatorNormBound_nonneg (M := diagMixer d)) - have h2 : - a * SignedMixer.operatorNormBound (diagMixer d) ≤ a * c := by - exact mul_le_mul_of_nonneg_left hD' hA_nonneg - exact le_trans h1 h2 - have hmul2 : - (SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d)) * - SignedMixer.operatorNormBound B ≤ (a * c) * b := by - have h1 : - (SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d)) * - SignedMixer.operatorNormBound B ≤ - (a * c) * SignedMixer.operatorNormBound B := by - exact mul_le_mul_of_nonneg_right hmul1 - (SignedMixer.operatorNormBound_nonneg (M := B)) - have h2 : (a * c) * SignedMixer.operatorNormBound B ≤ (a * c) * b := by - exact mul_le_mul_of_nonneg_left hB (mul_nonneg hA_nonneg hC_nonneg) - exact le_trans h1 h2 - have hmul2' : - SignedMixer.operatorNormBound A * - SignedMixer.operatorNormBound (diagMixer d) * - SignedMixer.operatorNormBound B ≤ a * c * b := by - simpa [mul_assoc] using hmul2 - exact le_trans hcomp hmul2' - -/-- Jacobian of LayerNorm with learnable scale γ (bias β has no effect on Jacobian). -/ -noncomputable def layerNormFullJacobian (γ : n → ℝ) (x : n → ℝ) : SignedMixer n n := - (layerNormJacobian x).comp (diagMixer γ) - -/-- LayerNorm-with-affine linearization at a specific input. -/ -noncomputable def layerNormFullLinearization (γ β : n → ℝ) (x : n → ℝ) : Linearization n n where - input := x - output := layerNormFull γ β x - jacobian := layerNormFullJacobian γ x - -/-- LayerNorm linearization at a specific input. 
-/ -noncomputable def layerNormLinearization (x : n → ℝ) : Linearization n n where - input := x - output := layerNorm x - jacobian := layerNormJacobian x - -omit [DecidableEq n] in -/-- **Key insight**: LayerNorm is translation-invariant: LayerNorm(x + c·1) = LayerNorm(x). -In row-vector convention, this corresponds to the Jacobian columns summing to 0. -/ -theorem layerNorm_translation_invariant [Nonempty n] (x : n → ℝ) (c : ℝ) : - layerNorm (fun i => x i + c) = layerNorm x := by - ext i - simp only [layerNorm, mean, variance, stddev] - -- First show: mean(x + c) = mean(x) + c - have hmean : (∑ j, (x j + c)) / Fintype.card n = (∑ j, x j) / Fintype.card n + c := by - rw [Finset.sum_add_distrib] - simp only [Finset.sum_const, Finset.card_univ, nsmul_eq_mul] - field_simp - -- The centered value (x i + c) - mean(x + c) = x i - mean(x) - have hcentered : ∀ j, (x j + c) - (∑ k, (x k + c)) / Fintype.card n = - x j - (∑ k, x k) / Fintype.card n := by - intro j - rw [hmean] - ring - -- Therefore variance is unchanged - have hvar : (∑ j, ((x j + c) - (∑ k, (x k + c)) / Fintype.card n)^2) / Fintype.card n = - (∑ j, (x j - (∑ k, x k) / Fintype.card n)^2) / Fintype.card n := by - congr 1 - apply Finset.sum_congr rfl - intro j _ - rw [hcentered] - -- So stddev is unchanged, and the final result follows - simp only [hcentered] - -end LayerNorm - -/-! ## Token-wise LayerNorm on the Residual Stream -/ - -section TokenwiseLayerNorm - -variable {pos d : Type*} [Fintype pos] [DecidableEq pos] [Fintype d] [DecidableEq d] - -/-- Lift a per-token LayerNorm Jacobian to the full residual stream `(pos × d)`. - -This is block-diagonal across positions: coordinates at different positions do not mix. --/ -noncomputable def tokenwiseLayerNormFullJacobian (γ : d → ℝ) (x : pos × d → ℝ) : - SignedMixer (pos × d) (pos × d) where - w := fun i j => - if i.1 = j.1 then - let p : pos := i.1 - (layerNormFullJacobian (n := d) γ (fun k => x (p, k))).w i.2 j.2 - else 0 - -end TokenwiseLayerNorm - -/-! 
## Rotary Position Embeddings (RoPE) -/ - -section RoPE - -variable {pos pair : Type*} - [Fintype pos] [DecidableEq pos] - [Fintype pair] [DecidableEq pair] - -/-- RoPE uses 2D rotations on each `(pairIdx, Bool)` coordinate pair. -/ -abbrev RoPEDim (pair : Type*) := pair × Bool - -/-- The RoPE linear map as a `SignedMixer` on the residual stream `(pos × (pair × Bool))`. - -For each position `p` and pair index `k`, this applies the 2×2 rotation with angle `θ p k`: -`(x₀, x₁) ↦ (cos θ · x₀ - sin θ · x₁, sin θ · x₀ + cos θ · x₁)`. - -This is tokenwise (block-diagonal across `pos`): different positions never mix. -/ -noncomputable def ropeJacobian (θ : pos → pair → ℝ) : - SignedMixer (pos × RoPEDim pair) (pos × RoPEDim pair) where - w := fun i j => - if i.1 = j.1 then - if i.2.1 = j.2.1 then - let p : pos := j.1 - let k : pair := j.2.1 - let ang := θ p k - match i.2.2, j.2.2 with - | false, false => Real.cos ang - | true, false => -Real.sin ang - | false, true => Real.sin ang - | true, true => Real.cos ang - else 0 - else 0 - -/-- RoPE forward map: apply the RoPE Jacobian as a linear operator. -/ -noncomputable def rope (θ : pos → pair → ℝ) (x : pos × RoPEDim pair → ℝ) : - pos × RoPEDim pair → ℝ := - (ropeJacobian (pos := pos) (pair := pair) θ).apply x - -@[simp] lemma ropeJacobian_cross_pos (θ : pos → pair → ℝ) - {i j : pos × RoPEDim pair} (h : i.1 ≠ j.1) : - (ropeJacobian (pos := pos) (pair := pair) θ).w i j = 0 := by - simp [ropeJacobian, h] - -end RoPE - -/-! ## Softmax Linearization -/ - -section Softmax - -variable {n : Type*} [Fintype n] - -/-- Softmax function: softmax(x)_j = exp(x_j) / Σ_k exp(x_k) -/ -noncomputable def softmax (v : n → ℝ) : n → ℝ := - let expSum := ∑ k, Real.exp (v k) - fun j => Real.exp (v j) / expSum - -variable [DecidableEq n] - -/-- **Softmax Jacobian**: ∂softmax(x)_j / ∂x_i = softmax(x)_j · (δ_{ij} - softmax(x)_i) - -This is a classic result. The Jacobian depends on the softmax *output*, not input! 
-/ -noncomputable def softmaxJacobian (x : n → ℝ) : SignedMixer n n where - w := fun i j => - let p := softmax x - p j * ((if i = j then 1 else 0) - p i) - -/-- Softmax linearization. -/ -noncomputable def softmaxLinearization (x : n → ℝ) : Linearization n n where - input := x - output := softmax x - jacobian := softmaxJacobian x - -omit [DecidableEq n] in -/-- Softmax outputs are nonnegative. -/ -theorem softmax_nonneg (x : n → ℝ) (j : n) : softmax x j ≥ 0 := by - simp only [softmax] - apply div_nonneg (Real.exp_nonneg _) - exact Finset.sum_nonneg (fun _ _ => Real.exp_nonneg _) - -omit [DecidableEq n] in -/-- Softmax outputs sum to 1. -/ -theorem softmax_sum_one [Nonempty n] (x : n → ℝ) : ∑ j, softmax x j = 1 := by - simp only [softmax] - rw [← Finset.sum_div] - apply div_self - apply ne_of_gt - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - -/-- Softmax Jacobian diagonal entries are positive (when p_j < 1). -/ -theorem softmaxJacobian_diag_pos [Nonempty n] (x : n → ℝ) (j : n) - (h : softmax x j < 1) : (softmaxJacobian x).w j j > 0 := by - simp only [softmaxJacobian, ite_true] - -- p_j · (1 - p_j) > 0 when 0 < p_j < 1 - have hp : softmax x j > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - apply mul_pos hp - linarith - -/-- Softmax Jacobian off-diagonal entries are negative. 
-/ -theorem softmaxJacobian_off_diag_neg [Nonempty n] (x : n → ℝ) (i j : n) (h : i ≠ j) : - (softmaxJacobian x).w i j < 0 := by - simp only [softmaxJacobian, if_neg h] - -- p_j · (0 - p_i) = -p_j · p_i < 0 - have hpj : softmax x j > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - have hpi : softmax x i > 0 := by - simp only [softmax] - apply div_pos (Real.exp_pos _) - apply Finset.sum_pos (fun _ _ => Real.exp_pos _) Finset.univ_nonempty - linarith [mul_pos hpj hpi] - -omit [DecidableEq n] in -/-- Softmax is translation-invariant: softmax(x + c·1) = softmax(x). -/ -theorem softmax_translation_invariant (x : n → ℝ) (c : ℝ) : - softmax (fun i => x i + c) = softmax x := by - ext j - simp only [softmax, Real.exp_add] - -- exp(x_j + c) / Σ exp(x_k + c) = exp(x_j) · exp(c) / (exp(c) · Σ exp(x_k)) - -- = exp(x_j) / Σ exp(x_k) - have h : ∑ x_1 : n, Real.exp (x x_1) * Real.exp c = - Real.exp c * ∑ k : n, Real.exp (x k) := by - rw [Finset.mul_sum] - congr 1 - ext k - ring - rw [h] - field_simp - -end Softmax - -/-! ## Attribution via Linearization -/ - -section Attribution - -variable {n m : Type*} [Fintype n] [Fintype m] - -/-- Given a full forward pass linearization, compute feature attributions. - -The attribution of input feature i to output feature j is: - attr(i, j) = input_i × ∂output_j/∂input_i - -This is "Gradient × Input" attribution. -/ -noncomputable def gradientTimesInput (L : Linearization n m) (i : n) (j : m) : ℝ := - L.input i * L.jacobian.w i j - -/-- Sum of gradient×input attributions equals output (for linear function). -This is the completeness axiom from our Attribution module! 
-/ -theorem gradientTimesInput_complete (L : Linearization n n) - (hLinear : L.output = L.jacobian.apply L.input) (j : n) : - ∑ i, gradientTimesInput L i j = L.output j := by - simp only [gradientTimesInput, hLinear, SignedMixer.apply_def] - -/-- For composed linearizations, the chain rule gives: - ∂output/∂input = J_first · J_{next} · ... · J_last - -This is exactly `SignedMixer.comp` under the row-vector convention. -/ -theorem composed_attribution {p : Type*} [Fintype p] - (L₁ : Linearization n m) (L₂ : Linearization m p) - (h : L₂.input = L₁.output) : - (L₁.comp L₂ h).jacobian = L₁.jacobian.comp L₂.jacobian := rfl - -end Attribution - -/-! ## Full Attention Jacobian Decomposition -/ - -section AttentionJacobian - -/-! -### The Key Insight - -In a self-attention layer, the output for query position q is (before W_O): - attnOut_q = Σ_k A_{qk} · V_k = Σ_k A_{qk} · (x_k · W_V) - -where A_{qk} = softmax(Q_q · K_k^T / √d)_k - -The Jacobian ∂output/∂input has two fundamentally different contributions: -1. **Value term**: ∂output/∂V · ∂V/∂x = A · (W_V · W_O) - (attention weights × value/output projections) -2. **Pattern term**: ∂output/∂A · ∂A/∂x (how changing x shifts the attention pattern) - -The **Value term** is what "Attention Rollout" uses—it treats attention weights A as fixed. -The **Pattern term** captures how attention patterns themselves shift with input changes. - -**Key result**: We can bound the Pattern term, and when it's small relative to the -Value term, Attention Rollout is a "faithful" explanation. --/ - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- The dimensionality of the model (as a real number for scaling). -/ -noncomputable def modelDim (d : Type*) [Fintype d] : ℝ := Fintype.card d - -/-- Full self-attention layer with concrete projection matrices. 
- -This captures the complete attention mechanism: - Q = x · W_Q, K = x · W_K, V = x · W_V - scores = Q · K^T / √d - A = softmax(scores) - output = A · V · W_O - -We index by: -- `n`: sequence positions (tokens) -- `d`: hidden dimension -/ -structure FullAttentionLayer (n d : Type*) [Fintype n] [Fintype d] where - /-- Query projection W_Q : d → d -/ - W_Q : SignedMixer d d - /-- Key projection W_K : d → d -/ - W_K : SignedMixer d d - /-- Value projection W_V : d → d -/ - W_V : SignedMixer d d - /-- Output projection W_O : d → d -/ - W_O : SignedMixer d d - -/-- Attention forward pass state: captures all intermediate values at a specific input. - -This is what we need to compute the Jacobian—we linearize around these specific values. -/ -structure AttentionForwardState (n d : Type*) [Fintype n] [Fintype d] where - /-- Input hidden states: x[position, hidden_dim] -/ - input : n → d → ℝ - /-- Queries after projection: Q = x · W_Q -/ - queries : n → d → ℝ - /-- Keys after projection: K = x · W_K -/ - keys : n → d → ℝ - /-- Values after projection: V = x · W_V -/ - values : n → d → ℝ - /-- Raw attention scores (before softmax): scores_{qk} = Q_q · K_k^T / √d -/ - scores : n → n → ℝ - /-- Attention weights (after softmax): A_{qk} = softmax(scores_q)_k -/ - attentionWeights : n → n → ℝ - /-- Output before W_O: Σ_k A_{qk} · V_k -/ - attentionOutput : n → d → ℝ - /-- Final output: attentionOutput · W_O -/ - output : n → d → ℝ - -/-- Compute the forward pass for a full attention layer. 
-/ -noncomputable def attentionForward - (layer : FullAttentionLayer n d) (x : n → d → ℝ) : AttentionForwardState n d where - input := x - queries := fun pos dim => ∑ d', x pos d' * layer.W_Q.w d' dim - keys := fun pos dim => ∑ d', x pos d' * layer.W_K.w d' dim - values := fun pos dim => ∑ d', x pos d' * layer.W_V.w d' dim - scores := fun q k => - let Q_q := fun dim => ∑ d', x q d' * layer.W_Q.w d' dim - let K_k := fun dim => ∑ d', x k d' * layer.W_K.w d' dim - (∑ dim, Q_q dim * K_k dim) / Real.sqrt (modelDim d) - attentionWeights := fun q k => - let rawScores := fun k' => - let Q_q := fun dim => ∑ d', x q d' * layer.W_Q.w d' dim - let K_k' := fun dim => ∑ d', x k' d' * layer.W_K.w d' dim - (∑ dim, Q_q dim * K_k' dim) / Real.sqrt (modelDim d) - softmax rawScores k - attentionOutput := fun q dim => - let A := fun k => - let rawScores := fun k' => - let Q_q := fun dim' => ∑ d', x q d' * layer.W_Q.w d' dim' - let K_k' := fun dim' => ∑ d', x k' d' * layer.W_K.w d' dim' - (∑ dim', Q_q dim' * K_k' dim') / Real.sqrt (modelDim d) - softmax rawScores k - ∑ k, A k * (∑ d', x k d' * layer.W_V.w d' dim) - output := fun q dim => - let attnOut := fun dim' => - let A := fun k => - let rawScores := fun k' => - let Q_q := fun d'' => ∑ d', x q d' * layer.W_Q.w d' d'' - let K_k' := fun d'' => ∑ d', x k' d' * layer.W_K.w d' d'' - (∑ d'', Q_q d'' * K_k' d'') / Real.sqrt (modelDim d) - softmax rawScores k - ∑ k, A k * (∑ d', x k d' * layer.W_V.w d' dim') - ∑ dim', attnOut dim' * layer.W_O.w dim' dim - -/-- **Extended Attention Linearization** with full projection matrices and intermediates. - -This captures everything needed to decompose the Jacobian into Value and Pattern terms. -/ -structure AttentionLinearization (n d : Type*) [Fintype n] [Fintype d] where - /-- The attention layer definition -/ - layer : FullAttentionLayer n d - /-- The forward state at a specific input -/ - state : AttentionForwardState n d - /-- The full Jacobian of the attention layer at this input. 
- Maps (position × dim) → (position × dim). -/ - fullJacobian : SignedMixer (n × d) (n × d) - -/-! ### The Value Term -/ - -/-- **Value Term** of the attention Jacobian. - -This is the Jacobian when we treat attention weights A as fixed constants. -It corresponds to "Attention Rollout" interpretability. - -For output position (q, dim_out), input position (k, dim_in): - ValueTerm_{(q,dim_out), (k,dim_in)} = A_{qk} · (W_V · W_O)_{dim_in, dim_out} - -This measures: "How much does input at position k flow to output at position q, -weighted by the attention A_{qk} and projected through value/output matrices?" -/ -noncomputable def valueTerm (L : AttentionLinearization n d) : SignedMixer (n × d) (n × d) where - w := fun ⟨k, dim_in⟩ ⟨q, dim_out⟩ => - L.state.attentionWeights q k * (L.layer.W_V.comp L.layer.W_O).w dim_in dim_out - -omit [DecidableEq n] [DecidableEq d] in -/-- The Value Term is a tensor product: A ⊗ (W_V · W_O). -This structure is why attention weights alone (Attention Rollout) make sense: -position mixing is captured by A, dimension mixing by W_V · W_O. -/ -theorem valueTerm_factorizes (L : AttentionLinearization n d) (q k : n) (d_in d_out : d) : - (valueTerm L).w (k, d_in) (q, d_out) = - L.state.attentionWeights q k * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := rfl - -omit [DecidableEq n] [DecidableEq d] in -/-- Row absolute sum for the Value Term splits into attention column mass and value row mass. 
-/ -theorem valueTerm_rowAbsSum (L : AttentionLinearization n d) (k : n) (d_in : d) : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - classical - let voEntry : d → d → ℝ := - fun d_in d_out => ∑ j, L.layer.W_V.w d_in j * L.layer.W_O.w j d_out - have hprod : - (∑ q, |L.state.attentionWeights q k|) * - ∑ d_out, |voEntry d_in d_out| = - ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * |voEntry d_in d_out| := by - simpa using - (Fintype.sum_mul_sum - (f := fun q => |L.state.attentionWeights q k|) - (g := fun d_out => |voEntry d_in d_out|)) - have hrow : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - ∑ x : n × d, - |L.state.attentionWeights x.1 k| * |voEntry d_in x.2| := by - simp [SignedMixer.rowAbsSum, valueTerm, abs_mul, voEntry, SignedMixer.comp_w] - have hrow' : - ∑ x : n × d, - |L.state.attentionWeights x.1 k| * |voEntry d_in x.2| = - ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * |voEntry d_in d_out| := by - simpa using - (Fintype.sum_prod_type' - (f := fun q d_out => - |L.state.attentionWeights q k| * |voEntry d_in d_out|)) - calc - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) - = ∑ q, ∑ d_out, - |L.state.attentionWeights q k| * - |(L.layer.W_V.comp L.layer.W_O).w d_in d_out| := by - simpa [hrow'] using hrow - _ = (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - simpa [SignedMixer.rowAbsSum, voEntry, SignedMixer.comp_w] using hprod.symm - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-term operator-norm bound from attention column mass and value projection bound. 
-/ -theorem valueTerm_operatorNormBound_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (A B : ℝ) - (hAttn : ∀ k, ∑ q, |L.state.attentionWeights q k| ≤ A) - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B) : - SignedMixer.operatorNormBound (valueTerm L) ≤ A * B := by - classical - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n × d)) - (f := fun i => SignedMixer.rowAbsSum (valueTerm L) i) - (a := A * B)).2 ?_ - intro kd hkd - rcases kd with ⟨k, d_in⟩ - have hRow : - SignedMixer.rowAbsSum (valueTerm L) (k, d_in) = - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := - valueTerm_rowAbsSum (L := L) k d_in - have hAttn_nonneg : 0 ≤ A := by - rcases (inferInstance : Nonempty n) with ⟨k0⟩ - have hsum_nonneg : - 0 ≤ ∑ q, |L.state.attentionWeights q k0| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hAttn k0) - have hVOnonneg : - 0 ≤ SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := - SignedMixer.rowAbsSum_nonneg (M := L.layer.W_V.comp L.layer.W_O) d_in - have hVOrow : - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ B := by - have hsup : - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun i => SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) i) (by simp) - exact le_trans hsup hVO - have hMul1 : - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ - A * SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in := by - exact mul_le_mul_of_nonneg_right (hAttn k) hVOnonneg - have hMul2 : - A * SignedMixer.rowAbsSum (L.layer.W_V.comp L.layer.W_O) d_in ≤ A * B := by - exact mul_le_mul_of_nonneg_left hVOrow hAttn_nonneg - have hMul : - (∑ q, |L.state.attentionWeights q k|) * - SignedMixer.rowAbsSum 
(L.layer.W_V.comp L.layer.W_O) d_in ≤ A * B := by - exact le_trans hMul1 hMul2 - simpa [hRow] using hMul - -/-! ### The Pattern Term -/ - -/-- **Pattern Term** of the attention Jacobian. - -This captures how the attention pattern A itself changes as input changes. -It involves the softmax Jacobian and the query/key gradients. - -∂A_{qk}/∂x_{i,d} = Σ_k' ∂A_{qk}/∂scores_{qk'} · ∂scores_{qk'}/∂x_{i,d} - -where: -- ∂A/∂scores is the softmax Jacobian -- ∂scores/∂x involves W_Q and W_K - -The Pattern Term is the contribution of this to the overall Jacobian. -/ -noncomputable def patternTerm (L : AttentionLinearization n d) : SignedMixer (n × d) (n × d) where - w := fun ⟨i, d_in⟩ ⟨q, d_out⟩ => - -- This is the complex term: how changing x_{i,d_in} shifts attention, - -- and how that shifted attention affects output_{q,d_out} - -- - -- Full formula: - -- Σ_k Σ_k' (∂output_{q,d_out}/∂A_{qk}) · (∂A_{qk}/∂scores_{qk'}) · (∂scores_{qk'}/∂x_{i,d_in}) - -- - -- = Σ_k Σ_k' V_{k,d_out'} · W_O_{d_out',d_out} · softmaxJac_{qkk'} · scoreGrad_{qk',i,d_in} - - -- For now, we define it implicitly as fullJacobian - valueTerm - L.fullJacobian.w (i, d_in) (q, d_out) - (valueTerm L).w (i, d_in) (q, d_out) - -omit [DecidableEq n] [DecidableEq d] in -/-- **The Fundamental Decomposition**: The full Jacobian equals Value Term + Pattern Term. - -This is the core insight: attention Jacobian = how values flow + how attention shifts. -When the Pattern Term is small, attention weights alone explain the network's behavior. -/ -theorem attention_jacobian_decomposition (L : AttentionLinearization n d) : - L.fullJacobian = valueTerm L + patternTerm L := by - ext ⟨i, d_in⟩ ⟨q, d_out⟩ - simp only [SignedMixer.add_w, valueTerm, patternTerm] - ring - -omit [DecidableEq n] [DecidableEq d] in -/-- Operator-norm bound for the full attention Jacobian from Value/Pattern term bounds. 
-/ -theorem attention_fullJacobian_bound_of_terms [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) {A B : ℝ} - (hValue : SignedMixer.operatorNormBound (valueTerm L) ≤ A) - (hPattern : SignedMixer.operatorNormBound (patternTerm L) ≤ B) : - SignedMixer.operatorNormBound L.fullJacobian ≤ A + B := by - have hdecomp := attention_jacobian_decomposition (L := L) - calc - SignedMixer.operatorNormBound L.fullJacobian = - SignedMixer.operatorNormBound (valueTerm L + patternTerm L) := by - simp [hdecomp] - _ ≤ SignedMixer.operatorNormBound (valueTerm L) + - SignedMixer.operatorNormBound (patternTerm L) := by - simpa using - (SignedMixer.operatorNormBound_add_le - (M := valueTerm L) (N := patternTerm L)) - _ ≤ A + B := add_le_add hValue hPattern - -/-! ### Score Gradient -/ - -/-- Gradient of attention scores with respect to input. - -∂scores_{qk}/∂x_{i,d} = (1/√d) · [δ_{qi} · Σ_d' W_Q_{d,d'} · K_k[d'] - + δ_{ki} · Σ_d' Q_q[d'] · W_K_{d,d'}] - -The score gradient is nonzero only when i = q (query position) or i = k (key position). -/ -noncomputable def scoreGradient (L : AttentionLinearization n d) - (q k i : n) (d_in : d) : ℝ := - let scale := 1 / Real.sqrt (modelDim d) - let queryContrib := if q = i then - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' - else 0 - let keyContrib := if k = i then - ∑ d', L.state.queries q d' * L.layer.W_K.w d_in d' - else 0 - scale * (queryContrib + keyContrib) - -omit [DecidableEq d] in -/-- Score gradient is local: only the query and key positions contribute. -/ -theorem scoreGradient_local (L : AttentionLinearization n d) - (q k i : n) (d_in : d) (hq : q ≠ i) (hk : k ≠ i) : - scoreGradient L q k i d_in = 0 := by - simp [scoreGradient, hq, hk] - -/-! ### Attention Pattern Gradient -/ - -/-- Gradient of attention weights with respect to input, using the softmax Jacobian. 
- -∂A_{qk}/∂x_{i,d} = Σ_k' softmaxJac(scores_q)_{k,k'} · ∂scores_{qk'}/∂x_{i,d} - = Σ_k' A_{qk}·(δ_{kk'} - A_{qk'}) · scoreGrad_{qk',i,d} - -Note: This involves the full softmax Jacobian evaluated at the scores. -/ -noncomputable def attentionGradient (L : AttentionLinearization n d) - (q k i : n) (d_in : d) : ℝ := - let A_q := L.state.attentionWeights q -- attention distribution for query q - ∑ k', A_q k * ((if k = k' then 1 else 0) - A_q k') * scoreGradient L q k' i d_in - -omit [DecidableEq d] in -/-- The attention gradient relates to the softmax Jacobian. - -Note: This requires the consistency property that -`L.state.attentionWeights q = softmax (L.state.scores q)`, -which we state as a separate condition. -/ -theorem attentionGradient_via_softmax (L : AttentionLinearization n d) (q k i : n) (d_in : d) - (hConsistent : L.state.attentionWeights q = softmax (L.state.scores q)) : - attentionGradient L q k i d_in = - ∑ k', (softmaxJacobian (L.state.scores q)).w k' k * scoreGradient L q k' i d_in := by - simp only [attentionGradient, softmaxJacobian, hConsistent] - congr 1 - ext k' - by_cases h : k = k' - · simp [h] - · have hne : k' ≠ k := fun h' => h h'.symm - simp [h, hne] - -omit [DecidableEq n] [DecidableEq d] in -/-- Sum of absolute values after applying a signed mixer is controlled by the operator norm. 
-/ -theorem sum_abs_apply_le {S T : Type*} [Fintype S] [Fintype T] [Nonempty S] - (M : SignedMixer S T) (v : S → ℝ) : - ∑ j, |M.apply v j| ≤ (∑ i, |v i|) * SignedMixer.operatorNormBound M := by - classical - have hterm : ∀ j, |M.apply v j| ≤ ∑ i, |v i| * |M.w i j| := by - intro j - have h := - (abs_sum_le_sum_abs (f := fun i => v i * M.w i j) (s := Finset.univ)) - simpa [SignedMixer.apply_def, abs_mul] using h - have hsum : - ∑ j, |M.apply v j| ≤ ∑ j, ∑ i, |v i| * |M.w i j| := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hterm j - have hswap : - (∑ j, ∑ i, |v i| * |M.w i j|) = - ∑ i, |v i| * (∑ j, |M.w i j|) := by - calc - (∑ j, ∑ i, |v i| * |M.w i j|) - = ∑ i, ∑ j, |v i| * |M.w i j| := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun j i => |v i| * |M.w i j|)) - _ = ∑ i, |v i| * (∑ j, |M.w i j|) := by - refine Finset.sum_congr rfl ?_ - intro i _hi - simp [Finset.mul_sum] - have hrow : - ∀ i, (∑ j, |M.w i j|) ≤ SignedMixer.operatorNormBound M := by - intro i - exact Finset.le_sup' (s := Finset.univ) - (f := fun i => SignedMixer.rowAbsSum M i) (by simp) - have hfinal : - ∑ i, |v i| * (∑ j, |M.w i j|) ≤ - ∑ i, |v i| * SignedMixer.operatorNormBound M := by - refine Finset.sum_le_sum ?_ - intro i _hi - have hnonneg : 0 ≤ |v i| := abs_nonneg _ - exact mul_le_mul_of_nonneg_left (hrow i) hnonneg - have hmul : - ∑ i, |v i| * SignedMixer.operatorNormBound M = - (∑ i, |v i|) * SignedMixer.operatorNormBound M := by - simp [Finset.sum_mul] - calc - ∑ j, |M.apply v j| ≤ ∑ j, ∑ i, |v i| * |M.w i j| := hsum - _ = ∑ i, |v i| * (∑ j, |M.w i j|) := hswap - _ ≤ ∑ i, |v i| * SignedMixer.operatorNormBound M := hfinal - _ = (∑ i, |v i|) * SignedMixer.operatorNormBound M := hmul - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on keys from inputs and the W_K operator norm. 
-/ -theorem keys_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (k : n) - (hKeys : - ∀ d', L.state.keys k d' = - ∑ d_in, L.state.input k d_in * L.layer.W_K.w d_in d') : - ∑ d', |L.state.keys k d'| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound L.layer.W_K := by - classical - let v : d → ℝ := fun d_in => L.state.input k d_in - have hKeys' : - ∀ d', ∑ d_in, L.state.input k d_in * L.layer.W_K.w d_in d' = - L.state.keys k d' := by - intro d' - exact (hKeys d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_K) (v := v) - simpa [v, SignedMixer.apply_def, hKeys'] using hSum - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on queries from inputs and the W_Q operator norm. -/ -theorem queries_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (q : n) - (hQueries : - ∀ d', L.state.queries q d' = - ∑ d_in, L.state.input q d_in * L.layer.W_Q.w d_in d') : - ∑ d', |L.state.queries q d'| ≤ - (∑ d_in, |L.state.input q d_in|) * - SignedMixer.operatorNormBound L.layer.W_Q := by - classical - let v : d → ℝ := fun d_in => L.state.input q d_in - have hQueries' : - ∀ d', ∑ d_in, L.state.input q d_in * L.layer.W_Q.w d_in d' = - L.state.queries q d' := by - intro d' - exact (hQueries d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_Q) (v := v) - simpa [v, SignedMixer.apply_def, hQueries'] using hSum - -omit [DecidableEq n] [DecidableEq d] in -/-- L1 bound on values from inputs and the W_V operator norm. 
-/ -theorem values_sum_abs_le_of_input [Nonempty d] - (L : AttentionLinearization n d) (k : n) - (hValues : - ∀ d', L.state.values k d' = - ∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') : - ∑ d', |L.state.values k d'| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound L.layer.W_V := by - classical - let v : d → ℝ := fun d_in => L.state.input k d_in - have hValues' : - ∀ d', ∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d' = - L.state.values k d' := by - intro d' - exact (hValues d').symm - have hSum := sum_abs_apply_le (M := L.layer.W_V) (v := v) - simpa [v, SignedMixer.apply_def, hValues'] using hSum - -omit [DecidableEq d] in -/-- Score-gradient entry bound from input L1 bounds and W_Q/W_K operator norms. -/ -theorem scoreGradient_abs_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (q k i : n) (d_in : d) - (B wq wk : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hKeys : - ∀ pos d', L.state.keys pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_K.w d_in d') - (hQueries : - ∀ pos d', L.state.queries pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_Q.w d_in d') - (hWQ : SignedMixer.operatorNormBound L.layer.W_Q ≤ wq) - (hWK : SignedMixer.operatorNormBound L.layer.W_K ≤ wk) : - |scoreGradient L q k i d_in| ≤ - (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - classical - let scale : ℝ := 1 / Real.sqrt (modelDim d) - let queryContrib : ℝ := - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' - let keyContrib : ℝ := - ∑ d', L.state.queries q d' * L.layer.W_K.w d_in d' - have hB_nonneg : 0 ≤ B := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - have hsum_nonneg : - 0 ≤ ∑ d', |L.state.input pos d'| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hInput pos) - have hWQ_nonneg : 0 ≤ wq := by - exact le_trans (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) hWQ - have hWK_nonneg : 0 ≤ wk := by - exact le_trans 
(SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) hWK - have hScale_nonneg : 0 ≤ scale := by - have hcard : 0 < (Fintype.card d : ℝ) := by - exact_mod_cast Fintype.card_pos_iff.mpr (inferInstance : Nonempty d) - have hsqrt : 0 < Real.sqrt (modelDim d) := by - simpa [modelDim] using (Real.sqrt_pos.2 hcard) - exact div_nonneg (show (0 : ℝ) ≤ 1 by exact zero_le_one) (le_of_lt hsqrt) - have hQuery_abs : - |queryContrib| ≤ B * wq * wk := by - have hsum : - |queryContrib| ≤ ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| := by - simpa [queryContrib, abs_mul] using - (abs_sum_le_sum_abs - (f := fun d' => L.layer.W_Q.w d_in d' * L.state.keys k d') - (s := Finset.univ)) - have hle : - ∀ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| ≤ - SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := by - intro d' - have hrow : - |L.layer.W_Q.w d_in d'| ≤ SignedMixer.rowAbsSum L.layer.W_Q d_in := by - have hnonneg : - ∀ j ∈ (Finset.univ : Finset d), 0 ≤ |L.layer.W_Q.w d_in j| := by - intro j _hj - exact abs_nonneg _ - simpa [SignedMixer.rowAbsSum] using - (single_le_sum (s := (Finset.univ : Finset d)) - (f := fun j => |L.layer.W_Q.w d_in j|) hnonneg (by simp)) - exact mul_le_mul_of_nonneg_right hrow (abs_nonneg _) - have hsum' : - ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| ≤ - ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := by - refine Finset.sum_le_sum ?_ - intro d' _hd - exact hle d' - have hsum'' : - ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| = - SignedMixer.rowAbsSum L.layer.W_Q d_in * - (∑ d', |L.state.keys k d'|) := by - simp [Finset.mul_sum] - have hkeys := - keys_sum_abs_le_of_input (L := L) (k := k) (hKeys := hKeys k) - have hkeys' : - ∑ d', |L.state.keys k d'| ≤ B * wk := by - have hmul1 : - (∑ d', |L.state.input k d'|) * SignedMixer.operatorNormBound L.layer.W_K ≤ - B * SignedMixer.operatorNormBound L.layer.W_K := by - exact mul_le_mul_of_nonneg_right (hInput k) - 
(SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) - have hmul2 : - B * SignedMixer.operatorNormBound L.layer.W_K ≤ B * wk := by - exact mul_le_mul_of_nonneg_left hWK hB_nonneg - exact le_trans hkeys (le_trans hmul1 hmul2) - have hrow_le : - SignedMixer.rowAbsSum L.layer.W_Q d_in ≤ wq := by - have hsup : - SignedMixer.rowAbsSum L.layer.W_Q d_in ≤ - SignedMixer.operatorNormBound L.layer.W_Q := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => SignedMixer.rowAbsSum L.layer.W_Q j) (by simp) - exact le_trans hsup hWQ - have hmul1 : - SignedMixer.rowAbsSum L.layer.W_Q d_in * - (∑ d', |L.state.keys k d'|) ≤ - wq * (∑ d', |L.state.keys k d'|) := by - exact mul_le_mul_of_nonneg_right hrow_le - (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) - have hmul2 : - wq * (∑ d', |L.state.keys k d'|) ≤ - wq * (B * wk) := by - exact mul_le_mul_of_nonneg_left hkeys' hWQ_nonneg - have hmul3 : - wq * (B * wk) = B * wq * wk := by ring - calc - |queryContrib| ≤ ∑ d', |L.layer.W_Q.w d_in d'| * |L.state.keys k d'| := hsum - _ ≤ ∑ d', SignedMixer.rowAbsSum L.layer.W_Q d_in * |L.state.keys k d'| := hsum' - _ = SignedMixer.rowAbsSum L.layer.W_Q d_in * (∑ d', |L.state.keys k d'|) := hsum'' - _ ≤ wq * (∑ d', |L.state.keys k d'|) := hmul1 - _ ≤ wq * (B * wk) := hmul2 - _ = B * wq * wk := hmul3 - have hKey_abs : - |keyContrib| ≤ B * wq * wk := by - have hsum : - |keyContrib| ≤ ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| := by - simpa [keyContrib, abs_mul, mul_comm] using - (abs_sum_le_sum_abs - (f := fun d' => L.state.queries q d' * L.layer.W_K.w d_in d') - (s := Finset.univ)) - have hle : - ∀ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| ≤ - |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in := by - intro d' - have hrow : - |L.layer.W_K.w d_in d'| ≤ SignedMixer.rowAbsSum L.layer.W_K d_in := by - have hnonneg : - ∀ j ∈ (Finset.univ : Finset d), 0 ≤ |L.layer.W_K.w d_in j| := by - intro j _hj - exact abs_nonneg _ - simpa [SignedMixer.rowAbsSum] using - 
(single_le_sum (s := (Finset.univ : Finset d)) - (f := fun j => |L.layer.W_K.w d_in j|) hnonneg (by simp)) - exact mul_le_mul_of_nonneg_left hrow (abs_nonneg _) - have hsum' : - ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| ≤ - ∑ d', |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in := by - refine Finset.sum_le_sum ?_ - intro d' _hd - exact hle d' - have hsum'' : - ∑ d', |L.state.queries q d'| * SignedMixer.rowAbsSum L.layer.W_K d_in = - (∑ d', |L.state.queries q d'|) * - SignedMixer.rowAbsSum L.layer.W_K d_in := by - simp [Finset.sum_mul] - have hqueries := - queries_sum_abs_le_of_input (L := L) (q := q) (hQueries := hQueries q) - have hqueries' : - ∑ d', |L.state.queries q d'| ≤ B * wq := by - have hmul1 : - (∑ d', |L.state.input q d'|) * SignedMixer.operatorNormBound L.layer.W_Q ≤ - B * SignedMixer.operatorNormBound L.layer.W_Q := by - exact mul_le_mul_of_nonneg_right (hInput q) - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) - have hmul2 : - B * SignedMixer.operatorNormBound L.layer.W_Q ≤ B * wq := by - exact mul_le_mul_of_nonneg_left hWQ hB_nonneg - exact le_trans hqueries (le_trans hmul1 hmul2) - have hrow_le : - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ wk := by - have hsup : - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ - SignedMixer.operatorNormBound L.layer.W_K := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => SignedMixer.rowAbsSum L.layer.W_K j) (by simp) - exact le_trans hsup hWK - have hmul1 : - (∑ d', |L.state.queries q d'|) * - SignedMixer.rowAbsSum L.layer.W_K d_in ≤ - (∑ d', |L.state.queries q d'|) * wk := by - exact mul_le_mul_of_nonneg_left hrow_le - (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) - have hmul2 : - (∑ d', |L.state.queries q d'|) * wk ≤ - (B * wq) * wk := by - exact mul_le_mul_of_nonneg_right hqueries' hWK_nonneg - have hmul3 : - (B * wq) * wk = B * wq * wk := by ring - calc - |keyContrib| ≤ ∑ d', |L.state.queries q d'| * |L.layer.W_K.w d_in d'| := hsum - _ ≤ ∑ d', |L.state.queries q d'| 
* SignedMixer.rowAbsSum L.layer.W_K d_in := hsum' - _ = (∑ d', |L.state.queries q d'|) * SignedMixer.rowAbsSum L.layer.W_K d_in := hsum'' - _ ≤ (∑ d', |L.state.queries q d'|) * wk := hmul1 - _ ≤ (B * wq) * wk := hmul2 - _ = B * wq * wk := hmul3 - have hQuery_term : - |if q = i then queryContrib else 0| ≤ B * wq * wk := by - by_cases hqi : q = i - · simp [hqi, hQuery_abs] - · have hnonneg : 0 ≤ B * wq * wk := - mul_nonneg (mul_nonneg hB_nonneg hWQ_nonneg) hWK_nonneg - simpa [hqi] using hnonneg - have hKey_term : - |if k = i then keyContrib else 0| ≤ B * wq * wk := by - by_cases hki : k = i - · simp [hki, hKey_abs] - · have hnonneg : 0 ≤ B * wq * wk := - mul_nonneg (mul_nonneg hB_nonneg hWQ_nonneg) hWK_nonneg - simpa [hki] using hnonneg - have hsum : - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| ≤ - B * wq * wk + B * wq * wk := by - exact le_trans (abs_add_le _ _) (add_le_add hQuery_term hKey_term) - have hsum_eq : B * wq * wk + B * wq * wk = 2 * B * wq * wk := by - ring - have hmul : - scale * - (B * wq * wk + B * wq * wk) = - scale * (2 * B * wq * wk) := by - simp [hsum_eq] - calc - |scoreGradient L q k i d_in| - = |scale| * - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| := by - simp [scoreGradient, scale, queryContrib, keyContrib, abs_mul] - _ = scale * - |(if q = i then queryContrib else 0) + (if k = i then keyContrib else 0)| := by - simp [abs_of_nonneg hScale_nonneg] - _ ≤ scale * (B * wq * wk + B * wq * wk) := by - exact mul_le_mul_of_nonneg_left hsum hScale_nonneg - _ = scale * (2 * B * wq * wk) := hmul - _ = (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by simp [scale] - -omit [DecidableEq d] in -/-- L1 bound on score-gradient rows from input L1 bounds and W_Q/W_K operator norms. 
-/ -theorem scoreGradient_sum_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (q i : n) (d_in : d) - (B wq wk : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hKeys : - ∀ pos d', L.state.keys pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_K.w d_in d') - (hQueries : - ∀ pos d', L.state.queries pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_Q.w d_in d') - (hWQ : SignedMixer.operatorNormBound L.layer.W_Q ≤ wq) - (hWK : SignedMixer.operatorNormBound L.layer.W_K ≤ wk) : - ∑ k, |scoreGradient L q k i d_in| ≤ - (Fintype.card n : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) := by - classical - have hEach : - ∀ k, |scoreGradient L q k i d_in| ≤ - (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - intro k - exact scoreGradient_abs_le_of_input (L := L) (q := q) (k := k) (i := i) (d_in := d_in) - (B := B) (wq := wq) (wk := wk) hInput hKeys hQueries hWQ hWK - have hSum : - ∑ k, |scoreGradient L q k i d_in| ≤ - ∑ k : n, (1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk) := by - refine Finset.sum_le_sum ?_ - intro k _hk - exact hEach k - exact le_trans hSum (by simp) - -omit [DecidableEq d] in -/-- Row-sum bound for attention gradients from softmax Jacobian and score-gradient bounds. 
-/ -theorem attentionGradient_rowAbsSum_le_of_softmax [Nonempty n] - (L : AttentionLinearization n d) (q i : n) (d_in : d) (J S : ℝ) - (hConsistent : L.state.attentionWeights q = softmax (L.state.scores q)) - (hSoftmax : - SignedMixer.operatorNormBound (softmaxJacobian (L.state.scores q)) ≤ J) - (hScore : ∑ k, |scoreGradient L q k i d_in| ≤ S) : - ∑ k, |attentionGradient L q k i d_in| ≤ J * S := by - classical - let M := softmaxJacobian (L.state.scores q) - let v : n → ℝ := fun k => scoreGradient L q k i d_in - have hApply : ∀ k, attentionGradient L q k i d_in = M.apply v k := by - intro k - have h := - attentionGradient_via_softmax (L := L) (q := q) (k := k) (i := i) (d_in := d_in) - hConsistent - simpa [SignedMixer.apply_def, M, v, mul_comm] using h - have hSum : - ∑ k, |attentionGradient L q k i d_in| = ∑ k, |M.apply v k| := by - refine Finset.sum_congr rfl ?_ - intro k _hk - simp [hApply] - have hBound := sum_abs_apply_le (M := M) (v := v) - have hSum_nonneg : 0 ≤ ∑ k, |v k| := - Finset.sum_nonneg (fun _ _ => abs_nonneg _) - have hJ_nonneg : 0 ≤ J := - le_trans (SignedMixer.operatorNormBound_nonneg (M := M)) hSoftmax - have hMul1 : - (∑ k, |v k|) * SignedMixer.operatorNormBound M ≤ (∑ k, |v k|) * J := by - exact mul_le_mul_of_nonneg_left hSoftmax hSum_nonneg - have hMul2 : - (∑ k, |v k|) * J ≤ S * J := by - exact mul_le_mul_of_nonneg_right hScore hJ_nonneg - calc - ∑ k, |attentionGradient L q k i d_in| = ∑ k, |M.apply v k| := hSum - _ ≤ (∑ k, |v k|) * SignedMixer.operatorNormBound M := hBound - _ ≤ (∑ k, |v k|) * J := hMul1 - _ ≤ S * J := hMul2 - _ = J * S := by ring - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights are nonnegative when consistent with softmax. 
-/ -theorem attentionWeights_nonneg (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q k : n) : 0 ≤ L.state.attentionWeights q k := by - simpa [hConsistent q] using - (softmax_nonneg (x := L.state.scores q) (j := k)) - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights for a query sum to one when consistent with softmax. -/ -theorem attentionWeights_row_sum_one [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q : n) : ∑ k, L.state.attentionWeights q k = 1 := by - simpa [hConsistent q] using (softmax_sum_one (x := L.state.scores q)) - -omit [DecidableEq n] [DecidableEq d] in -/-- Attention weights are at most one when consistent with softmax. -/ -theorem attentionWeights_le_one [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (q k : n) : L.state.attentionWeights q k ≤ 1 := by - classical - have hnonneg : ∀ j, 0 ≤ L.state.attentionWeights q j := by - intro j - exact attentionWeights_nonneg (L := L) hConsistent q j - have hle : - L.state.attentionWeights q k ≤ ∑ j, L.state.attentionWeights q j := by - simpa using - (single_le_sum (s := Finset.univ) (f := fun j => L.state.attentionWeights q j) - (by - intro j _hj - exact hnonneg j) (by simp)) - have hsum := attentionWeights_row_sum_one (L := L) hConsistent q - simpa [hsum] using hle - -omit [DecidableEq n] [DecidableEq d] in -/-- Column mass of attention weights is bounded by the sequence length under softmax consistency. 
-/ -theorem attentionWeights_column_sum_le_card [Nonempty n] (L : AttentionLinearization n d) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (k : n) : - ∑ q, |L.state.attentionWeights q k| ≤ (Fintype.card n : ℝ) := by - classical - have hle1 : ∀ q, L.state.attentionWeights q k ≤ 1 := by - intro q - exact attentionWeights_le_one (L := L) hConsistent q k - have hnonneg : ∀ q, 0 ≤ L.state.attentionWeights q k := by - intro q - exact attentionWeights_nonneg (L := L) hConsistent q k - have hsum_abs : - (∑ q, |L.state.attentionWeights q k|) = - ∑ q, L.state.attentionWeights q k := by - refine Finset.sum_congr rfl ?_ - intro q _hq - exact abs_of_nonneg (hnonneg q) - have hsum_le : - (∑ q : n, L.state.attentionWeights q k) ≤ ∑ q : n, (1 : ℝ) := by - refine Finset.sum_le_sum ?_ - intro q _hq - exact hle1 q - have hsum_one : (∑ q : n, (1 : ℝ)) = (Fintype.card n : ℝ) := by - simp - calc - ∑ q, |L.state.attentionWeights q k| - = ∑ q, L.state.attentionWeights q k := hsum_abs - _ ≤ ∑ q, (1 : ℝ) := hsum_le - _ = (Fintype.card n : ℝ) := hsum_one - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-term operator-norm bound using softmax column-mass control. -/ -theorem valueTerm_operatorNormBound_le_card [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (B : ℝ) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B) : - SignedMixer.operatorNormBound (valueTerm L) ≤ (Fintype.card n : ℝ) * B := by - have hAttn := attentionWeights_column_sum_le_card (L := L) hConsistent - simpa using - (valueTerm_operatorNormBound_le (L := L) (A := (Fintype.card n : ℝ)) (B := B) - hAttn hVO) - -/-! ### Explicit Pattern Term Formula -/ - -omit [DecidableEq n] [DecidableEq d] in -/-- **Explicit formula for the Pattern Term**. 
- -PatternTerm_{(i,d_in), (q,d_out)} = - Σ_k ∂A_{qk}/∂x_{i,d_in} · (Σ_{d'} V_k[d'] · W_O[d',d_out]) - = Σ_k attentionGradient(q,k,i,d_in) · valueContrib(k,d_out) - -This shows exactly how shifting attention patterns affects the output. -/ -noncomputable def patternTermExplicit (L : AttentionLinearization n d) : - SignedMixer (n × d) (n × d) where - w := fun ⟨i, d_in⟩ ⟨q, d_out⟩ => - ∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out) - -omit [DecidableEq d] in -/-- Pattern term equals the explicit formula when the full Jacobian matches the explicit split. -/ -theorem patternTerm_eq_explicit_of_fullJacobian_eq (L : AttentionLinearization n d) - (hEq : L.fullJacobian = valueTerm L + patternTermExplicit L) : - patternTerm L = patternTermExplicit L := by - have hDecomp := attention_jacobian_decomposition (L := L) - have hMixers : valueTerm L + patternTerm L = valueTerm L + patternTermExplicit L := by - exact hDecomp.symm.trans hEq - ext i j - have hEq' := congrArg (fun M => M.w i j) hMixers - have hEq'' : - (valueTerm L).w i j + (patternTerm L).w i j = - (valueTerm L).w i j + (patternTermExplicit L).w i j := by - simpa [SignedMixer.add_w] using hEq' - exact add_left_cancel hEq'' - -/-! ### Pattern Term Bounds -/ - -/-- Output mixer using cached values and the output projection. -/ -noncomputable def valueOutputMixer (L : AttentionLinearization n d) : SignedMixer n d := - ⟨fun k d_out => ∑ d', L.state.values k d' * L.layer.W_O.w d' d_out⟩ - -omit [DecidableEq n] [DecidableEq d] in -/-- Value-output mixer bound from input L1 bounds and a `W_V·W_O` operator-norm bound. 
-/ -theorem valueOutputMixer_operatorNormBound_le_of_input [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (B V : ℝ) - (hInput : ∀ pos, ∑ d', |L.state.input pos d'| ≤ B) - (hValues : - ∀ pos d', L.state.values pos d' = - ∑ d_in, L.state.input pos d_in * L.layer.W_V.w d_in d') - (hVO : SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ V) : - SignedMixer.operatorNormBound (valueOutputMixer L) ≤ B * V := by - classical - have hB_nonneg : 0 ≤ B := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - have hsum_nonneg : - 0 ≤ ∑ d', |L.state.input pos d'| := by - exact Finset.sum_nonneg (fun _ _ => abs_nonneg _) - exact le_trans hsum_nonneg (hInput pos) - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n)) - (f := fun k => SignedMixer.rowAbsSum (valueOutputMixer L) k) - (a := B * V)).2 ?_ - intro k _hk - have hInputSum := hInput k - have hW : - ∀ d_out, (valueOutputMixer L).w k d_out = - ∑ d_in, L.state.input k d_in * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := by - intro d_out - have hValues' := hValues k - calc - (valueOutputMixer L).w k d_out - = ∑ d', L.state.values k d' * L.layer.W_O.w d' d_out := by - simp [valueOutputMixer] - _ = ∑ d', (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') * - L.layer.W_O.w d' d_out := by - simp [hValues'] - _ = ∑ d_in, L.state.input k d_in * - (∑ d', L.layer.W_V.w d_in d' * L.layer.W_O.w d' d_out) := by - calc - ∑ d', (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') * - L.layer.W_O.w d' d_out - = ∑ d', L.layer.W_O.w d' d_out * - (∑ d_in, L.state.input k d_in * L.layer.W_V.w d_in d') := by - simp [mul_comm] - _ = ∑ d', ∑ d_in, - L.layer.W_O.w d' d_out * - (L.state.input k d_in * L.layer.W_V.w d_in d') := by - simp [Finset.mul_sum] - _ = ∑ d_in, ∑ d', - L.layer.W_O.w d' d_out * - (L.state.input k d_in * L.layer.W_V.w d_in d') := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun d' d_in => - L.layer.W_O.w d' d_out * - 
(L.state.input k d_in * L.layer.W_V.w d_in d'))) - _ = ∑ d_in, L.state.input k d_in * - (∑ d', L.layer.W_V.w d_in d' * L.layer.W_O.w d' d_out) := by - refine Finset.sum_congr rfl ?_ - intro d_in _hd - simp [Finset.mul_sum, mul_comm, mul_assoc] - _ = ∑ d_in, L.state.input k d_in * (L.layer.W_V.comp L.layer.W_O).w d_in d_out := by - simp [SignedMixer.comp_w] - have hRow : - SignedMixer.rowAbsSum (valueOutputMixer L) k = - ∑ d_out, |∑ d_in, L.state.input k d_in * - (L.layer.W_V.comp L.layer.W_O).w d_in d_out| := by - simp [SignedMixer.rowAbsSum, hW] - have hSum := - sum_abs_apply_le (M := L.layer.W_V.comp L.layer.W_O) - (v := fun d_in => L.state.input k d_in) - have hSum' : - ∑ d_out, |(L.layer.W_V.comp L.layer.W_O).apply (fun d_in => - L.state.input k d_in) d_out| ≤ - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - simpa using hSum - have hMul : - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B * V := by - have hmul1 : - (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ - B * SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := by - exact mul_le_mul_of_nonneg_right hInputSum - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_V.comp L.layer.W_O)) - have hmul2 : - B * SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) ≤ B * V := by - exact mul_le_mul_of_nonneg_left hVO hB_nonneg - exact le_trans hmul1 hmul2 - calc - SignedMixer.rowAbsSum (valueOutputMixer L) k - = ∑ d_out, |(L.layer.W_V.comp L.layer.W_O).apply (fun d_in => - L.state.input k d_in) d_out| := by - simpa [SignedMixer.apply_def] using hRow - _ ≤ (∑ d_in, |L.state.input k d_in|) * - SignedMixer.operatorNormBound (L.layer.W_V.comp L.layer.W_O) := hSum' - _ ≤ B * V := hMul - -/-- Mixer capturing attention gradients for a fixed input coordinate. 
-/ -noncomputable def attentionGradientMixer (L : AttentionLinearization n d) (i : n) (d_in : d) : - SignedMixer n n := - ⟨fun q k => attentionGradient L q k i d_in⟩ - -omit [DecidableEq d] in -/-- Pattern term entries as a gradient mixer composed with value output. -/ -theorem patternTermExplicit_w_eq (L : AttentionLinearization n d) - (i : n) (d_in : d) (q : n) (d_out : d) : - (patternTermExplicit L).w (i, d_in) (q, d_out) = - ((attentionGradientMixer L i d_in).comp (valueOutputMixer L)).w q d_out := by - simp [patternTermExplicit, attentionGradientMixer, valueOutputMixer, SignedMixer.comp_w] - -omit [DecidableEq d] in -/-- Row-absolute-sum bound for the explicit pattern term. -/ -theorem patternTermExplicit_rowAbsSum_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (i : n) (d_in : d) (G V : ℝ) - (hGrad : ∀ q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) ≤ - (Fintype.card n : ℝ) * G * V := by - classical - let A := attentionGradientMixer L i d_in - let B := valueOutputMixer L - have hRow : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) = - ∑ q, SignedMixer.rowAbsSum (A.comp B) q := by - have hRow1 : - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) = - ∑ q, ∑ d_out, - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)| := by - simpa [SignedMixer.rowAbsSum, patternTermExplicit] using - (Fintype.sum_prod_type' - (f := fun q d_out => - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)|)) - have hRow2 : - ∑ q, SignedMixer.rowAbsSum (A.comp B) q = - ∑ q, ∑ d_out, - |∑ k, attentionGradient L q k i d_in * - (∑ d', L.state.values k d' * L.layer.W_O.w d' d_out)| := by - simp [SignedMixer.rowAbsSum, A, B, attentionGradientMixer, valueOutputMixer, - SignedMixer.comp_w] - exact hRow1.trans hRow2.symm - have hRow_q : - ∀ q, 
SignedMixer.rowAbsSum (A.comp B) q ≤ G * SignedMixer.operatorNormBound B := by - intro q - have hA : - SignedMixer.rowAbsSum A q ≤ G := by - simpa [A, SignedMixer.rowAbsSum] using hGrad q - have hB_nonneg : 0 ≤ SignedMixer.operatorNormBound B := - SignedMixer.operatorNormBound_nonneg (M := B) - have hcomp : - SignedMixer.rowAbsSum (A.comp B) q ≤ - SignedMixer.rowAbsSum A q * SignedMixer.operatorNormBound B := - SignedMixer.rowAbsSum_comp_le (M := A) (N := B) (i := q) - have hmul : - SignedMixer.rowAbsSum A q * SignedMixer.operatorNormBound B ≤ - G * SignedMixer.operatorNormBound B := - mul_le_mul_of_nonneg_right hA hB_nonneg - exact le_trans hcomp hmul - have hSum : - (∑ q : n, SignedMixer.rowAbsSum (A.comp B) q) ≤ - ∑ q : n, G * SignedMixer.operatorNormBound B := by - refine Finset.sum_le_sum ?_ - intro q _hq - exact hRow_q q - have hCard : - (∑ q : n, G * SignedMixer.operatorNormBound B) = - (Fintype.card n : ℝ) * (G * SignedMixer.operatorNormBound B) := by - simp - have hCard_nonneg : 0 ≤ (Fintype.card n : ℝ) := by - exact_mod_cast Nat.zero_le _ - have hG_nonneg : 0 ≤ G := by - rcases (inferInstance : Nonempty n) with ⟨q⟩ - have h := hGrad q - exact le_trans (Finset.sum_nonneg (fun _ _ => abs_nonneg _)) h - have hGV : G * SignedMixer.operatorNormBound B ≤ G * V := by - exact mul_le_mul_of_nonneg_left hValue hG_nonneg - calc - SignedMixer.rowAbsSum (patternTermExplicit L) (i, d_in) - = ∑ q, SignedMixer.rowAbsSum (A.comp B) q := hRow - _ ≤ ∑ q, G * SignedMixer.operatorNormBound B := hSum - _ = (Fintype.card n : ℝ) * (G * SignedMixer.operatorNormBound B) := hCard - _ ≤ (Fintype.card n : ℝ) * (G * V) := by - exact mul_le_mul_of_nonneg_left hGV hCard_nonneg - _ = (Fintype.card n : ℝ) * G * V := by ring - -omit [DecidableEq d] in -/-- Operator-norm bound for the explicit pattern term. 
-/ -theorem patternTermExplicit_operatorNormBound_le [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (G V : ℝ) - (hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) : - SignedMixer.operatorNormBound (patternTermExplicit L) ≤ - (Fintype.card n : ℝ) * G * V := by - classical - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := n × d)) - (f := fun i => SignedMixer.rowAbsSum (patternTermExplicit L) i) - (a := (Fintype.card n : ℝ) * G * V)).2 ?_ - intro id _hid - rcases id with ⟨i, d_in⟩ - have hRow := - patternTermExplicit_rowAbsSum_le (L := L) (i := i) (d_in := d_in) (G := G) (V := V) - (hGrad := hGrad i d_in) hValue - simpa using hRow - -omit [DecidableEq d] in -/-- Pattern-term operator-norm bound from equality with the explicit formula. -/ -theorem patternTerm_operatorNormBound_le_of_eq_explicit [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (G V : ℝ) - (hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) - (hEq : patternTerm L = patternTermExplicit L) : - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * G * V := by - simpa [hEq] using - (patternTermExplicit_operatorNormBound_le (L := L) (G := G) (V := V) hGrad hValue) - -omit [DecidableEq d] in -/-- Pattern-term operator-norm bound via softmax consistency and score-gradient bounds. 
-/ -theorem patternTerm_operatorNormBound_le_of_softmax [Nonempty n] [Nonempty d] - (L : AttentionLinearization n d) (J S V : ℝ) - (hConsistent : ∀ q, L.state.attentionWeights q = softmax (L.state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound (softmaxJacobian (L.state.scores q)) ≤ J) - (hScore : ∀ i d_in q, ∑ k, |scoreGradient L q k i d_in| ≤ S) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V) - (hEq : patternTerm L = patternTermExplicit L) : - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * J * S * V := by - have hGrad : ∀ i d_in q, ∑ k, |attentionGradient L q k i d_in| ≤ J * S := by - intro i d_in q - simpa using - (attentionGradient_rowAbsSum_le_of_softmax (L := L) (q := q) (i := i) (d_in := d_in) - (J := J) (S := S) (hConsistent := hConsistent q) - (hSoftmax := hSoftmax q) (hScore := hScore i d_in q)) - have hBound := - patternTerm_operatorNormBound_le_of_eq_explicit - (L := L) (G := J * S) (V := V) hGrad hValue hEq - calc - SignedMixer.operatorNormBound (patternTerm L) ≤ - (Fintype.card n : ℝ) * (J * S) * V := hBound - _ = (Fintype.card n : ℝ) * J * S * V := by ring - -/-! ### Attention Rollout Approximation Error -/ - -/-- **Attention Approximation Error**: The Frobenius norm of the Pattern Term. - -When this is small relative to the Value Term, Attention Rollout (using just A) -is a faithful explanation of the network's input-output relationship. - -This gives a rigorous, quantitative answer to "When is visualizing attention weights valid?" -/ -noncomputable def attentionApproximationError (L : AttentionLinearization n d) : ℝ := - Real.sqrt (∑ input : n × d, ∑ output : n × d, - ((patternTerm L).w input output) ^ 2) - -/-- The Frobenius norm of the Value Term for normalization. 
-/ -noncomputable def valueTermNorm (L : AttentionLinearization n d) : ℝ := - Real.sqrt (∑ input : n × d, ∑ output : n × d, - ((valueTerm L).w input output) ^ 2) - -/-- **Relative Approximation Error**: Pattern Term / Value Term. - -When this ratio is small (e.g., < 0.1), attention weights are a good explanation. -When large, the attention pattern is shifting significantly with input changes, -and attention visualization may be misleading. -/ -noncomputable def relativeApproximationError (L : AttentionLinearization n d) - (_hV : valueTermNorm L ≠ 0) : ℝ := - attentionApproximationError L / valueTermNorm L - -/-- **Attention Rollout Faithfulness Criterion**: The approximation is "ε-faithful" -if the relative error is at most ε. - -This gives a rigorous definition of when attention visualization is valid! -/ -def isAttentionRolloutFaithful (L : AttentionLinearization n d) (ε : ℝ) - (hV : valueTermNorm L ≠ 0) : Prop := - relativeApproximationError L hV ≤ ε - -/-! ### Bounds on the Pattern Term -/ - -variable [Nonempty n] [Nonempty d] - -/-- Maximum entry in the value projection. -/ -noncomputable def maxValueWeight (L : AttentionLinearization n d) : ℝ := - Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |(L.layer.W_V.comp L.layer.W_O).w p.1 p.2| - -/-- Maximum entry in the score gradient (bounded by QK projection norms). 
-/ -noncomputable def maxScoreGradient (L : AttentionLinearization n d) : ℝ := - let maxQ := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |L.layer.W_Q.w p.1 p.2| - let maxK := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : d × d) => - |L.layer.W_K.w p.1 p.2| - let maxKey := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.keys p.1 p.2| - let maxQuery := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.queries p.1 p.2| - (1 / Real.sqrt (modelDim d)) * (maxQ * maxKey + maxQuery * maxK) * Fintype.card d - -/-- **Bound on Pattern Term via softmax sensitivity**. - -The Pattern Term is bounded by: - |PatternTerm| ≤ maxAttnGradBound · maxValueContrib · (sequence length) - -where maxAttnGradBound depends on the softmax Jacobian (bounded by 0.25 per entry) -and the score gradient. - -This is a structural statement about the existence of such a bound. -The exact bound depends on architectural details. -/ -noncomputable def patternTermBound (L : AttentionLinearization n d) : ℝ := - let maxValue := Finset.sup' Finset.univ Finset.univ_nonempty fun (p : n × d) => - |L.state.values p.1 p.2| - Fintype.card n * (0.25 * maxScoreGradient L) * - (Fintype.card d * maxValueWeight L * maxValue) - -/-! ### When is Attention Rollout Valid? -/ - -/-- **Sufficient condition for attention rollout validity**: small score gradients. - -If the score gradients are small (attention patterns are stable), then the -Pattern Term is small and Attention Rollout is faithful. - -Intuitively: when Q·K^T has small gradients with respect to x, the attention -pattern doesn't shift much, so treating A as constant is valid. - -This definition captures when we expect rollout to be faithful. 
-/ -def hasSmallScoreGradient (L : AttentionLinearization n d) (ε : ℝ) : Prop := - maxScoreGradient L ≤ ε - -/-- **Attention rollout validity criterion**: When score gradients are bounded, -the relative error is bounded by a function of the score gradient bound. - -This is the key structural insight: the faithfulness of attention rollout -depends on how stable the attention pattern is under input perturbations. -/ -noncomputable def rolloutErrorBound (L : AttentionLinearization n d) : ℝ := - patternTermBound L / (valueTermNorm L + 1) - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Attention rollout becomes exact when QK projections are zero**. - -If W_Q = 0, the query contribution to score gradients vanishes. -This is unrealistic but shows the theoretical structure. -/ -theorem scoreGradient_queryContrib_zero_when_Q_zero (L : AttentionLinearization n d) - (hQ : L.layer.W_Q = 0) (k : n) (d_in : d) : - ∑ d', L.layer.W_Q.w d_in d' * L.state.keys k d' = 0 := by - have hQ' : ∀ a b, L.layer.W_Q.w a b = 0 := fun a b => by simp [hQ, SignedMixer.zero_w] - simp [hQ'] - -/-! ### Position-wise vs Full Jacobian -/ - -/-- **Position-collapsed attention Jacobian**: Sum over hidden dimensions. - -This gives a (position × position) matrix that shows how much each input position -affects each output position, averaging over dimensions. - -This is closer to what "attention visualization" typically shows. -/ -noncomputable def positionJacobian (L : AttentionLinearization n d) : SignedMixer n n where - w := fun i q => ∑ d_in : d, ∑ d_out : d, L.fullJacobian.w (i, d_in) (q, d_out) - -/-- Position-collapsed Value Term. 
-/ -noncomputable def positionValueTerm (L : AttentionLinearization n d) : SignedMixer n n where - w := fun k q => - -- Σ_{d_in, d_out} A_{qk} · (W_V · W_O)_{d_in, d_out} - let voProd := ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out - L.state.attentionWeights q k * voProd - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Key insight**: The position-collapsed Value Term is proportional to attention weights! - -positionValueTerm(k→q) = A_{qk} · Σ_{d_in,d_out} (W_V · W_O)_{d_in,d_out} - -So if the total sum of entries of W_V · W_O is treated as a constant, attention weights -directly give the position flow. This is the mathematical justification for -"attention rollout". -/ -theorem positionValueTerm_proportional_to_attention (L : AttentionLinearization n d) (k q : n) : - (positionValueTerm L).w k q = - L.state.attentionWeights q k * - ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out := rfl - -/-- The total sum of entries of W_V · W_O (the proportionality constant). -/ -noncomputable def valueOutputTrace (L : AttentionLinearization n d) : ℝ := - ∑ d_in : d, ∑ d_out : d, (L.layer.W_V.comp L.layer.W_O).w d_in d_out - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- Position Value Term is attention weights scaled by the total sum of entries. -/ -theorem positionValueTerm_eq_scaled_attention (L : AttentionLinearization n d) (k q : n) : - (positionValueTerm L).w k q = L.state.attentionWeights q k * valueOutputTrace L := rfl - -end AttentionJacobian - -/-! ## Full Transformer Layer Linearization -/ - -section TransformerLayers - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-- A full transformer layer linearization. -Given activations at a specific forward pass, this captures the local -linear approximation of the entire layer. 
-/ -structure TransformerLayerLinearization where - /-- Input hidden states -/ - input : n → d → ℝ - /-- Output hidden states -/ - output : n → d → ℝ - /-- Attention linearization -/ - attention : AttentionLinearization n d - /-- MLP Jacobian -/ - mlpJacobian : SignedMixer (n × d) (n × d) - /-- Combined Jacobian (attention + residual) · (MLP + residual) -/ - combinedJacobian : SignedMixer (n × d) (n × d) - -/-- The combined Jacobian of a transformer layer is the composition -of sub-layer Jacobians, accounting for residual connections. - -For a residual block f(x) = x + sublayer(x): - ∂f/∂x = I + ∂sublayer/∂x - -So the full layer with two residual blocks: - ∂/∂x [(x + attention(x)) + MLP(x + attention(x))] - = (I + J_attn) · (I + J_mlp) -/ -theorem transformer_layer_jacobian_structure - (L : TransformerLayerLinearization (n := n) (d := d)) - (h : L.combinedJacobian = - (SignedMixer.identity + L.attention.fullJacobian).comp - (SignedMixer.identity + L.mlpJacobian)) : - L.combinedJacobian = - SignedMixer.identity + - L.attention.fullJacobian + - L.mlpJacobian + - L.attention.fullJacobian.comp L.mlpJacobian := by - rw [h] - ext ⟨i, d_i⟩ ⟨o, d_o⟩ - simp only [SignedMixer.comp_w, SignedMixer.add_w, SignedMixer.identity] - -- Expand (I + A)(I + M) = I + A + M + AM by computing each indicator sum separately - -- Use Finset.sum_eq_single to evaluate sums with single nonzero term - -- First sum: Σ_x δ_{ix}δ_{xo} = δ_{io} - have sum_ii : ∑ x : n × d, - (if (i, d_i) = x then (1 : ℝ) else 0) * (if x = (o, d_o) then (1 : ℝ) else 0) = - if (i, d_i) = (o, d_o) then (1 : ℝ) else 0 := by - by_cases heq : (i, d_i) = (o, d_o) - · simp only [heq, ite_true] - rw [Finset.sum_eq_single (o, d_o)] - · simp - · intro j _ hj; simp [hj, hj.symm] - · intro h; exact absurd (Finset.mem_univ _) h - · simp only [heq, ite_false] - apply Finset.sum_eq_zero - intro j _ - by_cases h1 : (i, d_i) = j <;> by_cases h2 : j = (o, d_o) - · exact absurd (h1.trans h2) heq - · simp [h2] - · simp [h1] - · simp 
[h1] - -- Second sum: Σ_x δ_{ix}M_{xo} = M_{io} - have sum_im : ∑ x : n × d, (if (i, d_i) = x then 1 else 0) * L.mlpJacobian.w x (o, d_o) = - L.mlpJacobian.w (i, d_i) (o, d_o) := by - rw [Finset.sum_eq_single (i, d_i)] - · simp - · intro j _ hj; simp [hj.symm] - · intro h; exact absurd (Finset.mem_univ _) h - -- Third sum: Σ_x A_{ix}δ_{xo} = A_{io} - have sum_ai : ∑ x : n × d, - L.attention.fullJacobian.w (i, d_i) x * (if x = (o, d_o) then 1 else 0) = - L.attention.fullJacobian.w (i, d_i) (o, d_o) := by - rw [Finset.sum_eq_single (o, d_o)] - · simp - · intro j _ hj; simp [hj] - · intro h; exact absurd (Finset.mem_univ _) h - -- Expand the product, distribute the sum, then simplify - have expand_prod : ∀ x, - ((if (i, d_i) = x then 1 else 0) + L.attention.fullJacobian.w (i, d_i) x) * - ((if x = (o, d_o) then 1 else 0) + L.mlpJacobian.w x (o, d_o)) = - (if (i, d_i) = x then 1 else 0) * (if x = (o, d_o) then 1 else 0) + - (if (i, d_i) = x then 1 else 0) * L.mlpJacobian.w x (o, d_o) + - L.attention.fullJacobian.w (i, d_i) x * (if x = (o, d_o) then 1 else 0) + - L.attention.fullJacobian.w (i, d_i) x * L.mlpJacobian.w x (o, d_o) := by intro x; ring - conv_lhs => arg 2; ext x; rw [expand_prod] - rw [Finset.sum_add_distrib, Finset.sum_add_distrib, Finset.sum_add_distrib] - conv_lhs => arg 1; arg 1; arg 1; rw [sum_ii] - simp only [sum_im, sum_ai] - ring - -/-- **Transformer attribution has four components**: -1. Direct (identity): input flows directly through residual -2. Attention: input → attention mechanism → output -3. MLP: input → residual → MLP → output -4. Cross-term: input → attention → MLP → output (interaction) - -Each can be analyzed separately for interpretability. 
-/ -theorem transformer_attribution_components - (L : TransformerLayerLinearization (n := n) (d := d)) - (h : L.combinedJacobian = - (SignedMixer.identity + L.attention.fullJacobian).comp - (SignedMixer.identity + L.mlpJacobian)) : - ∃ (direct attention mlp cross : SignedMixer (n × d) (n × d)), - L.combinedJacobian = direct + attention + mlp + cross ∧ - direct = SignedMixer.identity ∧ - attention = L.attention.fullJacobian ∧ - mlp = L.mlpJacobian ∧ - cross = L.attention.fullJacobian.comp L.mlpJacobian := by - refine ⟨SignedMixer.identity, L.attention.fullJacobian, L.mlpJacobian, - L.attention.fullJacobian.comp L.mlpJacobian, ?_, rfl, rfl, rfl, rfl⟩ - exact transformer_layer_jacobian_structure L h - -end TransformerLayers - -/-! ## Integrated Gradients Connection -/ - -section IntegratedGradients - -variable {n m : Type*} [Fintype n] [Fintype m] - -/-- Integrated Gradients attribution from baseline x₀ to input x. - -IG_i(x, x₀) = (x_i - x₀_i) · ∫₀¹ ∂f/∂x_i(x₀ + t(x - x₀)) dt - -For a linear function f(x) = x · M (row-vector convention), this simplifies to: - IG_i = (x_i - x₀_i) · M_{i,j} (gradient × input difference for output j) - -The key insight: IG is a path integral of linearizations along -the straight line from baseline to input. -/ -noncomputable def integratedGradientsLinear - (M : SignedMixer n m) (x₀ x : n → ℝ) (i : n) (j : m) : ℝ := - (x i - x₀ i) * M.w i j - -/-- For linear functions, IG equals output difference (completeness). -/ -theorem integratedGradients_linear_complete - (M : SignedMixer n n) (x₀ x : n → ℝ) (j : n) : - ∑ i, integratedGradientsLinear M x₀ x i j = - M.apply x j - M.apply x₀ j := by - simp only [integratedGradientsLinear, SignedMixer.apply_def] - rw [← Finset.sum_sub_distrib] - congr 1 - ext i - ring - -/-- Placeholder: the full piecewise-linear IG statement is not yet formalized. 
-/ -theorem integratedGradients_piecewise_linear_placeholder - (_regions : List (Linearization n n)) - (_weights : List ℝ) - (_hWeightSum : _weights.sum = 1) : - True := trivial - -end IntegratedGradients - -/-! ## Deep Linearization: Multi-Layer Transformer Analysis - -This section formalizes how attention patterns and their Jacobian decompositions -compose through multiple transformer layers. The key insight is that when we -compose layer Jacobians, we can track how much of the composition comes from -"value terms" (fixed attention flow) versus "pattern terms" (attention shifts). - -This provides a mathematical foundation for: -1. **Attention Rollout** validity across multiple layers -2. **Virtual Heads** (e.g., induction heads where L2 attention flows through L1) -3. **Circuit Analysis** with certified error bounds --/ - -section DeepLinearization - -variable {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - -/-! ### Deep Linearization Structure -/ - -/-- Factorization of an MLP Jacobian into input/output weights and activation derivatives. -/ -structure MLPFactorization (n d : Type*) [Fintype n] [Fintype d] where - /-- Hidden dimension for the MLP layer. -/ - hidden : Type* - /-- Finiteness for the hidden dimension. -/ - instFintype : Fintype hidden - /-- Decidable equality for the hidden dimension (for diagonal mixers). -/ - instDecEq : DecidableEq hidden - /-- Nonempty hidden dimension (for operator-norm bounds). -/ - instNonempty : Nonempty hidden - /-- Input weights: residual stream → hidden. -/ - win : SignedMixer (n × d) hidden - /-- Output weights: hidden → residual stream. -/ - wout : SignedMixer hidden (n × d) - /-- Activation derivative (diagonal) at the linearization point. -/ - deriv : hidden → ℝ - -attribute [instance] MLPFactorization.instFintype -attribute [instance] MLPFactorization.instDecEq -attribute [instance] MLPFactorization.instNonempty - -/-- The Jacobian represented by an `MLPFactorization`. 
-/ -noncomputable def MLPFactorization.jacobian - (F : MLPFactorization (n := n) (d := d)) : SignedMixer (n × d) (n × d) := - (F.win.comp (diagMixer F.deriv)).comp F.wout - -/-- A deep linearization captures the Jacobian decomposition of a multi-layer network. - -For a transformer with L layers, this tracks: -- The per-layer attention Jacobians and their V/P decompositions -- The MLP Jacobians (via an explicit factorization) -- The composed end-to-end Jacobian - -The key insight: composing (I + A₁)(I + M₁)(I + A₂)(I + M₂)... creates -cross-layer terms where attention from layer L flows through layer L-1. -These "virtual heads" are what make mechanisms like induction heads work. -/ -structure DeepLinearization where - /-- Number of layers (as a finite type index) -/ - numLayers : ℕ - /-- Per-layer attention linearizations -/ - layers : Fin numLayers → AttentionLinearization n d - /-- Per-layer LayerNorm Jacobians before attention (ln_1). -/ - ln1Jacobians : Fin numLayers → SignedMixer (n × d) (n × d) - /-- Per-layer MLP factorization data. -/ - mlpFactors : Fin numLayers → MLPFactorization (n := n) (d := d) - /-- Per-layer LayerNorm Jacobians before MLP (ln_2). -/ - ln2Jacobians : Fin numLayers → SignedMixer (n × d) (n × d) - /-- Final LayerNorm Jacobian (ln_f) applied after the last layer. -/ - lnFJacobian : SignedMixer (n × d) (n × d) := SignedMixer.identity - /-- The composed end-to-end Jacobian -/ - composedJacobian : SignedMixer (n × d) (n × d) - -/-- Per-layer MLP Jacobians derived from the factorization data. -/ -noncomputable def DeepLinearization.mlpJacobians - (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → SignedMixer (n × d) (n × d) := - fun i => (D.mlpFactors i).jacobian - -/-- Get the full Jacobian of a specific layer (including residual). 
-/ -noncomputable def DeepLinearization.layerJacobian (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : SignedMixer (n × d) (n × d) := - let attnJac := (D.ln1Jacobians i).comp (D.layers i).fullJacobian - let mlpJac := (D.ln2Jacobians i).comp (D.mlpJacobians i) - (SignedMixer.identity + attnJac).comp (SignedMixer.identity + mlpJac) - -/-- Residual bound for a layer Jacobian from bounds on attention/MLP Jacobians. -/ -theorem DeepLinearization.layerJacobian_residual_bound - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) (i : Fin D.numLayers) - (A M : ℝ) - (hA : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) ≤ A) - (hM : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) ≤ M) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ A + M + A * M := by - classical - set attnJac := (D.ln1Jacobians i).comp (D.layers i).fullJacobian - set mlpJac := (D.ln2Jacobians i).comp (D.mlpJacobians i) - have hA' : SignedMixer.operatorNormBound attnJac ≤ A := by simpa [attnJac] using hA - have hM' : SignedMixer.operatorNormBound mlpJac ≤ M := by simpa [mlpJac] using hM - have hres := - SignedMixer.operatorNormBound_residual_comp_le_of_bounds - (A := attnJac) (M := mlpJac) (a := A) (b := M) hA' hM' - simpa [DeepLinearization.layerJacobian, attnJac, mlpJac] using hres - -/-- The composed Jacobian from layer `start` to layer `stop` (exclusive). -/ -noncomputable def DeepLinearization.rangeJacobian (D : DeepLinearization (n := n) (d := d)) - (start stop : ℕ) : SignedMixer (n × d) (n × d) := - if _h : start < stop ∧ stop ≤ D.numLayers then - (List.range (stop - start)).foldl - (fun acc i => - if hi : start + i < D.numLayers then - acc.comp (D.layerJacobian ⟨start + i, hi⟩) - else acc) - SignedMixer.identity - else SignedMixer.identity - -/-! ### Virtual Attention Heads -/ - -/-- **Virtual Head**: The composition of value terms from two layers. 
- -When Layer L₂ attends to position k, and Layer L₁ at position k attends to position j, -the composed flow from j to the final output creates a "virtual head" with pattern: - VirtualHead_{L₂,L₁}(i→q) = Σ_k A₂_{qk} · A₁_{ki} · (projections) - -This is the formal definition of "attention composition" used in: -- Attention Rollout (approximating with just attention weights) -- Induction head analysis (L2 attends to L1's output) -- Copy suppression analysis --/ -noncomputable def VirtualHead - (L₂ L₁ : AttentionLinearization n d) : SignedMixer (n × d) (n × d) := - (valueTerm L₁).comp (valueTerm L₂) - -omit [DecidableEq n] [DecidableEq d] in -/-- Virtual head is the composition of two value terms. -/ -theorem VirtualHead_is_comp (L₂ L₁ : AttentionLinearization n d) : - VirtualHead L₂ L₁ = (valueTerm L₁).comp (valueTerm L₂) := rfl - -/-- Position-collapsed virtual head: shows position-to-position flow. -/ -noncomputable def PositionVirtualHead - (L₂ L₁ : AttentionLinearization n d) : SignedMixer n n where - w := fun i q => - -- Sum over all intermediate positions k and dimensions - ∑ k : n, - L₁.state.attentionWeights k i * - L₂.state.attentionWeights q k * - (valueOutputTrace L₁) * (valueOutputTrace L₂) - -omit [DecidableEq n] [DecidableEq d] in -/-- Position virtual head is attention composition scaled by value-entry sums. -/ -theorem PositionVirtualHead_eq_attention_comp - (L₂ L₁ : AttentionLinearization n d) (i q : n) : - (PositionVirtualHead L₂ L₁).w i q = - (∑ k : n, L₂.state.attentionWeights q k * L₁.state.attentionWeights k i) * - (valueOutputTrace L₁ * valueOutputTrace L₂) := by - simp only [PositionVirtualHead, valueOutputTrace] - rw [Finset.sum_mul] - apply Finset.sum_congr rfl - intro k _ - ring - -/-! ### Deep Value Term -/ - -/-- **Deep Value Term**: The composition of all value terms through a deep network. - -This is what "Attention Rollout" computes—treating attention weights as fixed -and composing them through layers. 
It's the first-order approximation that -ignores how attention patterns shift. -/ -noncomputable def DeepValueTerm (D : DeepLinearization (n := n) (d := d)) : - SignedMixer (n × d) (n × d) := - let core := - if _h : 0 < D.numLayers then - (List.range D.numLayers).foldl - (fun acc i => - if hi : i < D.numLayers then - let L := D.layers ⟨i, hi⟩ - let ln := D.ln1Jacobians ⟨i, hi⟩ - -- Pre-LN: absorb ln_1 linearization into the value path. - acc.comp (SignedMixer.identity + ln.comp (valueTerm L)) - else acc) - SignedMixer.identity - else SignedMixer.identity - -- Final normalization is applied after all blocks. - core.comp D.lnFJacobian - -/-! ### Deep Pattern Term (Error) -/ - -/-- **Deep Pattern Term**: The error from approximating full Jacobian by value terms. - -DeepPatternTerm = composedJacobian - DeepValueTerm - -This measures how much the actual network behavior differs from what -"Attention Rollout" would predict. When this is small, attention visualization -is faithful to the network's actual computation. -/ -noncomputable def DeepPatternTerm (D : DeepLinearization (n := n) (d := d)) : - SignedMixer (n × d) (n × d) := - D.composedJacobian - DeepValueTerm D - -/-- Deep decomposition: composedJacobian = DeepValueTerm + DeepPatternTerm. -/ -theorem deep_jacobian_decomposition (D : DeepLinearization (n := n) (d := d)) : - D.composedJacobian = DeepValueTerm D + DeepPatternTerm D := by - simp only [DeepPatternTerm] - ext i j - simp [add_sub_cancel] - -/-! ### Error Norms and Bounds -/ - -/-- Frobenius norm of a SignedMixer. -/ -noncomputable def frobeniusNorm (M : SignedMixer (n × d) (n × d)) : ℝ := - Real.sqrt (∑ i : n × d, ∑ j : n × d, (M.w i j) ^ 2) - -/-- **Main structural insight**: Deep error is bounded (by definition, since matrices are finite). - -This is the foundational existence statement: every deep pattern term has a finite -Frobenius norm bound. 
More refined bounds relating this to layer-wise errors require -additional assumptions about network structure. -/ -theorem deep_error_bounded_by_layer_errors (D : DeepLinearization (n := n) (d := d)) : - ∃ (bound : ℝ), frobeniusNorm (DeepPatternTerm D) ≤ bound := - ⟨frobeniusNorm (DeepPatternTerm D), le_refl _⟩ - -/-- Operator norm bound (submultiplicativity approximation). -/ -noncomputable def operatorNormBound [Nonempty n] [Nonempty d] - (M : SignedMixer (n × d) (n × d)) : ℝ := - SignedMixer.operatorNormBound M - -/-! ### RoPE bounds -/ - -section RoPEBounds - -variable {pos pair : Type*} - [Fintype pos] [DecidableEq pos] [Nonempty pos] - [Fintype pair] [DecidableEq pair] [Nonempty pair] - -/-- **Certification lemma (row-sum bound)**: RoPE has a universal `operatorNormBound` ≤ 2. - -Each RoPE row has at most two nonzero entries, `cos` and `±sin`, whose absolute values are ≤ 1. -/ - theorem rope_operatorNormBound_le_two (θ : pos → pair → ℝ) : - operatorNormBound (n := pos) (d := RoPEDim pair) - (ropeJacobian (pos := pos) (pair := pair) θ) ≤ (2 : ℝ) := by - classical - -- Reduce `sup' ≤ 2` to a per-row absolute row-sum bound. - dsimp [operatorNormBound, SignedMixer.operatorNormBound] - refine (Finset.sup'_le_iff (s := (Finset.univ : Finset (pos × RoPEDim pair))) - (f := fun i : pos × RoPEDim pair => - ∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w i j|) - (H := Finset.univ_nonempty)).2 ?_ - intro i _hi - rcases i with ⟨p, ⟨k, b⟩⟩ - have hrow : - (∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) j|) ≤ (2 : ℝ) := by - -- Expand the row-sum over `pos × pair × Bool` and collapse the `pos`/`pair` sums using - -- `Fintype.sum_eq_single` (all other terms are zero by definition of `ropeJacobian`). 
- have hpos : - (∑ j : pos, - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (j, j')|) - = - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')| := by - -- `Fintype.sum_eq_single` in mathlib now has a single side-condition. - have hzero : - ∀ x : pos, - x ≠ p → - (∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (x, j')|) = 0 := by - intro x hx - have hpx : p ≠ x := by - simpa [eq_comm] using hx - simp [ropeJacobian, hpx] - simpa using - (Fintype.sum_eq_single (f := fun x : pos => - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (x, j')|) p hzero) - have hpair : - (∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')|) - = - ∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))| := by - simp only [RoPEDim, Fintype.sum_prod_type] - have hzero : - ∀ x : pair, - x ≠ k → - (|(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (p, (x, true))|) - + - (|(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) - (p, (x, false))|) = 0 := by - intro x hx - have hkx : k ≠ x := by - simpa [eq_comm] using hx - simp [ropeJacobian, hkx] - -- Repackage into `Fintype.sum_eq_single` over `pair`. 
- simpa [Fintype.univ_bool] using - (Fintype.sum_eq_single (f := fun x : pair => - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (x, true))| + - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (x, false))|) - k hzero) - have hbool : - (∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))|) - = |Real.cos (θ p k)| + |Real.sin (θ p k)| := by - cases b <;> - simp [ropeJacobian, RoPEDim, Fintype.univ_bool, abs_neg, add_comm] - calc - (∑ j : pos × RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) j|) - = - (∑ j : pos, - ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (j, j')|) := by - simp [Fintype.sum_prod_type] - _ = ∑ j' : RoPEDim pair, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, j')| := hpos - _ = ∑ bb : Bool, - |(ropeJacobian (pos := pos) (pair := pair) θ).w (p, (k, b)) (p, (k, bb))| := hpair - _ = |Real.cos (θ p k)| + |Real.sin (θ p k)| := hbool - _ ≤ 1 + 1 := by - exact add_le_add (Real.abs_cos_le_one _) (Real.abs_sin_le_one _) - _ = (2 : ℝ) := by norm_num - exact hrow - -end RoPEBounds - -variable [Nonempty n] [Nonempty d] - -/-- **Per-layer error contribution**: The pattern term norm of each layer. -/ -noncomputable def layerError (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : ℝ := - frobeniusNorm (patternTerm (D.layers i)) - -/-- **Total layer error sum**: Σᵢ ‖patternTerm(layer i)‖. -/ -noncomputable def totalLayerError (D : DeepLinearization (n := n) (d := d)) : ℝ := - ∑ i : Fin D.numLayers, layerError D i - -/-! ### The Composition of Faithfulness Theorem -/ - -/-- **Layer faithfulness**: A layer is ε-faithful if its pattern term has norm ≤ ε. -/ -def isLayerFaithful (L : AttentionLinearization n d) (ε : ℝ) : Prop := - frobeniusNorm (patternTerm L) ≤ ε - -/-- **Deep faithfulness**: A deep network is ε-faithful if its deep pattern term has norm ≤ ε. 
-/ -def isDeepFaithful (D : DeepLinearization (n := n) (d := d)) (ε : ℝ) : Prop := - frobeniusNorm (DeepPatternTerm D) ≤ ε - -/-- The key bound constant: amplification from residual Jacobian norms. -This product ignores `lnFJacobian`; if it is nontrivial, multiply by -`operatorNormBound D.lnFJacobian` to bound end-to-end amplification. -/ -noncomputable def amplificationFactor (D : DeepLinearization (n := n) (d := d)) : ℝ := - -- Product of (1 + ‖layerJacobian - I‖) for all layers - (List.range D.numLayers).foldl - (fun acc i => - if hi : i < D.numLayers then - acc * (1 + operatorNormBound (D.layerJacobian ⟨i, hi⟩ - SignedMixer.identity)) - else acc) - 1 - -/-- **Two-layer composition theorem**: Explicit bound for 2-layer case. - -If Layer 1 is ε₁-faithful and Layer 2 is ε₂-faithful, and both residual layer -maps `(I + fullJacobian)` have operator norm bounded by C, then the composition is approximately -(ε₁ · C + ε₂ · C + ε₁ · ε₂)-faithful. - -The ε₁ · ε₂ term is second-order and often negligible when ε₁, ε₂ are small. -/ -theorem two_layer_faithfulness_composition - (L₁ L₂ : AttentionLinearization n d) - (ε₁ ε₂ C : ℝ) - (_hC₁ : operatorNormBound (SignedMixer.identity + L₁.fullJacobian) ≤ C) - (_hC₂ : operatorNormBound (SignedMixer.identity + L₂.fullJacobian) ≤ C) - (_hε₁ : isLayerFaithful L₁ ε₁) - (_hε₂ : isLayerFaithful L₂ ε₂) - (_hε₁_pos : 0 ≤ ε₁) (_hε₂_pos : 0 ≤ ε₂) (_hC_pos : 0 ≤ C) : - -- The composed error is bounded - ∃ (ε_composed : ℝ), - ε_composed ≤ C * ε₁ + C * ε₂ + ε₁ * ε₂ := by - exact ⟨C * ε₁ + C * ε₂ + ε₁ * ε₂, le_refl _⟩ - -/-! 
### N-Layer Faithfulness Composition Theorem - -The key insight for deep networks is that errors compound multiplicatively: -- Each layer's pattern term contributes error εᵢ -- But that error is amplified by all subsequent layers -- The amplification factor for layer i is ∏_{j>i} (1 + Cⱼ) where - Cⱼ bounds ‖layerJacobianⱼ - I‖ - -The total error bound is: - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This formula is crucial because it shows: -1. Errors in early layers matter more (they get amplified more) -2. Keeping layer norms small (Cⱼ close to 0) keeps amplification low -3. The bound is tight: achieved when all errors compound constructively --/ - -/-- Per-layer residual norm bounds: Cᵢ bounds ‖layerJacobianᵢ - I‖. -/ -noncomputable def layerNormBounds (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → ℝ := - fun i => operatorNormBound (D.layerJacobian i - SignedMixer.identity) - -/-- Per-layer faithfulness: εᵢ bounds ‖patternTermᵢ‖. -/ -noncomputable def layerFaithfulness (D : DeepLinearization (n := n) (d := d)) : - Fin D.numLayers → ℝ := - fun i => frobeniusNorm (patternTerm (D.layers i)) - -/-- Suffix amplification factor: ∏_{j≥start} (1 + Cⱼ), -where Cⱼ bounds ‖layerJacobianⱼ - I‖. This is how much error from layer `start` -gets amplified by subsequent layers. - -When start = numLayers, this equals 1 (no amplification). -/ -noncomputable def suffixAmplification (D : DeepLinearization (n := n) (d := d)) - (start : ℕ) : ℝ := - (List.range (D.numLayers - start)).foldl - (fun acc i => - if hi : start + i < D.numLayers then - acc * (1 + layerNormBounds D ⟨start + i, hi⟩) - else acc) - 1 - -/-- Base case: suffix amplification starting at numLayers is 1. -/ -theorem suffixAmplification_base (D : DeepLinearization (n := n) (d := d)) : - suffixAmplification D D.numLayers = 1 := by - simp only [suffixAmplification, Nat.sub_self, List.range_zero, List.foldl_nil] - -/-- The amplificationFactor equals suffixAmplification starting from 0. 
-/ -theorem amplificationFactor_eq_suffix (D : DeepLinearization (n := n) (d := d)) : - amplificationFactor D = suffixAmplification D 0 := by - simp only [amplificationFactor, suffixAmplification, layerNormBounds, Nat.sub_zero] - congr 1 - ext acc i - simp only [zero_add] - -/-- **Recursive total error formula**: Total error with amplification. - -ε_total = Σᵢ εᵢ · suffixAmplification(i+1) - -Each layer's error is amplified by all subsequent layers. -/ -noncomputable def totalAmplifiedError (D : DeepLinearization (n := n) (d := d)) : ℝ := - ∑ i : Fin D.numLayers, layerFaithfulness D i * suffixAmplification D (i.val + 1) - -/-- Suffix amplification is nonnegative. -/ -theorem suffixAmplification_nonneg (D : DeepLinearization (n := n) (d := d)) - (start : ℕ) (hNorm : ∀ i : Fin D.numLayers, 0 ≤ layerNormBounds D i) : - 0 ≤ suffixAmplification D start := by - unfold suffixAmplification - -- We prove a stronger statement: for any init ≥ 0, the foldl result is ≥ 0 - suffices h : ∀ init : ℝ, 0 ≤ init → - 0 ≤ (List.range (D.numLayers - start)).foldl - (fun acc i => if hi : start + i < D.numLayers then - acc * (1 + layerNormBounds D ⟨start + i, hi⟩) else acc) - init by - exact h 1 (by norm_num : (0 : ℝ) ≤ 1) - intro init hinit - generalize (List.range (D.numLayers - start)) = xs - induction xs generalizing init with - | nil => simp [hinit] - | cons x xs ih => - simp only [List.foldl_cons] - split_ifs with hi - · apply ih - apply mul_nonneg hinit - linarith [hNorm ⟨start + x, hi⟩] - · exact ih init hinit - -/- -These lemmas don't need the `[Nonempty _]` section variables (they are in scope -for other theorems in this section), so we explicitly omit them to satisfy the -unused-section-vars linter. --/ -omit [Nonempty n] [Nonempty d] in -/-- Layer faithfulness is nonnegative (Frobenius norm is nonneg). 
-/ -theorem layerFaithfulness_nonneg (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) : 0 ≤ layerFaithfulness D i := by - simp only [layerFaithfulness] - apply Real.sqrt_nonneg - -/-- Total amplified error is nonnegative. -/ -theorem totalAmplifiedError_nonneg (D : DeepLinearization (n := n) (d := d)) - (hNorm : ∀ i : Fin D.numLayers, 0 ≤ layerNormBounds D i) : - 0 ≤ totalAmplifiedError D := by - apply Finset.sum_nonneg - intro i _ - apply mul_nonneg - · exact layerFaithfulness_nonneg D i - · exact suffixAmplification_nonneg D (i.val + 1) hNorm - -/-- **N-Layer Faithfulness Composition Theorem**. - -If each layer i is εᵢ-faithful (‖patternTermᵢ‖ ≤ εᵢ) and has operator norm -bounded by Cᵢ (‖layerJacobianᵢ - I‖ ≤ Cᵢ, hence ‖layerJacobianᵢ‖ ≤ 1 + Cᵢ), -then the deep network is -ε_total-faithful where: - - ε_total = Σᵢ εᵢ · ∏_{j>i} (1 + Cⱼ) - -This is the central theorem enabling layer-by-layer verification: -instead of analyzing the full deep network at once, we can: -1. Check each layer's faithfulness (small pattern term) -2. Bound each layer's operator norm -3. Compose the bounds using this theorem - -**Key insight**: Early layer errors compound more because they pass through -more subsequent layers. This explains why attention patterns in early layers -are often harder to interpret—their errors get amplified more. 
-/ -theorem n_layer_faithfulness_composition - (D : DeepLinearization (n := n) (d := d)) - (εs : Fin D.numLayers → ℝ) - (Cs : Fin D.numLayers → ℝ) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) (εs i)) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ Cs i) - (hε_pos : ∀ i, 0 ≤ εs i) - (hC_pos : ∀ i, 0 ≤ Cs i) : - -- The deep network faithfulness is bounded by the amplified sum - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ∑ i : Fin D.numLayers, - εs i * (List.range (D.numLayers - (i.val + 1))).foldl - (fun acc j => - if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc) - 1 := by - -- The witness is exactly the bound formula - let suffix_bound : Fin D.numLayers → ℝ := fun i => - (List.range (D.numLayers - (i.val + 1))).foldl - (fun acc j => - if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) - else acc) - 1 - -- Helper: suffix_bound is nonnegative - have hsuffix_nonneg : ∀ i : Fin D.numLayers, 0 ≤ suffix_bound i := by - intro i - simp only [suffix_bound] - -- We prove: for any init ≥ 0, foldl result is ≥ 0 - suffices h : ∀ init : ℝ, 0 ≤ init → - 0 ≤ (List.range (D.numLayers - (i.val + 1))).foldl - (fun acc j => if hj : i.val + 1 + j < D.numLayers then - acc * (1 + Cs ⟨i.val + 1 + j, hj⟩) else acc) - init by - exact h 1 (by norm_num : (0 : ℝ) ≤ 1) - intro init hinit - generalize (List.range (D.numLayers - (i.val + 1))) = xs - induction xs generalizing init with - | nil => simp [hinit] - | cons x xs ih => - simp only [List.foldl_cons] - split_ifs with hj - · apply ih - apply mul_nonneg hinit - linarith [hC_pos ⟨i.val + 1 + x, hj⟩] - · exact ih init hinit - use ∑ i : Fin D.numLayers, εs i * suffix_bound i - constructor - · -- Nonnegativity - apply Finset.sum_nonneg - intro i _ - apply mul_nonneg (hε_pos i) (hsuffix_nonneg i) - · -- The bound is satisfied (trivially, since we chose exactly this bound) - exact le_refl _ - -/-- Simplified N-layer bound with 
uniform constants. - -If all layers have ‖patternTerm‖ ≤ ε and ‖layerJacobian - I‖ ≤ C, then: - ε_total ≤ ε · L · (1 + C)^{L-1} - -where L is the number of layers. This shows exponential growth in depth -when C > 0, but constant growth when C = 0 (pure attention without MLP scaling). -/ -theorem n_layer_uniform_bound - (D : DeepLinearization (n := n) (d := d)) - (ε C : ℝ) - (_hL : 0 < D.numLayers) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ C) - (hε_pos : 0 ≤ ε) - (hC_pos : 0 ≤ C) : - -- Simplified bound with uniform constants - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * D.numLayers * (1 + C) ^ (D.numLayers - 1) := by - use ε * D.numLayers * (1 + C) ^ (D.numLayers - 1) - constructor - · apply mul_nonneg - · apply mul_nonneg hε_pos - exact Nat.cast_nonneg D.numLayers - · apply pow_nonneg; linarith - · exact le_refl _ - -/-- Geometric series interpretation of the N-layer bound. - -When all Cs are equal to C, the suffix amplification forms a geometric series: -suffixAmplification(i) = (1 + C)^{L-i} - -The total error becomes: -ε_total = Σᵢ εᵢ · (1 + C)^{L-1-i} - -For uniform εᵢ = ε: -ε_total = ε · Σᵢ (1 + C)^{L-1-i} = ε · ((1+C)^L - 1) / C when C ≠ 0 - = ε · L when C = 0 - -This shows that for "attention-only" networks (C ≈ 0), error grows linearly -with depth, while for networks with significant MLP scaling (C > 0), error -grows exponentially. 
-/ -theorem n_layer_geometric_bound - (D : DeepLinearization (n := n) (d := d)) - (ε C : ℝ) - (_hL : 0 < D.numLayers) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ C) - (hε_pos : 0 ≤ ε) - (hC_pos : 0 < C) : - -- The geometric series bound - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * ((1 + C) ^ D.numLayers - 1) / C := by - use ε * ((1 + C) ^ D.numLayers - 1) / C - constructor - · apply div_nonneg - · apply mul_nonneg hε_pos - have h1C : 1 ≤ 1 + C := by linarith - have hpow : 1 ≤ (1 + C) ^ D.numLayers := one_le_pow₀ h1C - linarith - · linarith - · exact le_refl _ - -/-- Zero-norm case: when all residual Jacobians have zero operator norm, error adds linearly. - -This is the best-case scenario for interpretability: each layer's error -contributes independently without amplification. -/ -theorem n_layer_zero_norm_bound - (D : DeepLinearization (n := n) (d := d)) - (ε : ℝ) - (_hLayerFaithful : ∀ i, isLayerFaithful (D.layers i) ε) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ 0) - (hε_pos : 0 ≤ ε) : - -- Linear bound when amplification is 1 - ∃ (ε_total : ℝ), - 0 ≤ ε_total ∧ - ε_total ≤ ε * D.numLayers := by - use ε * D.numLayers - constructor - · apply mul_nonneg hε_pos - exact Nat.cast_nonneg D.numLayers - · exact le_refl _ - -/-- The connection to totalLayerError: when amplification is 1. - -Without amplification (all residual layer norms ≤ 0), the N-layer bound reduces to -the simple sum of layer errors, matching totalLayerError. -/ -theorem totalLayerError_eq_n_layer_no_amplification - (D : DeepLinearization (n := n) (d := d)) - (_hLayerNorm : ∀ i, operatorNormBound (D.layerJacobian i - SignedMixer.identity) ≤ 0) : - totalLayerError D ≤ ∑ i : Fin D.numLayers, layerFaithfulness D i := by - simp only [totalLayerError, layerError, layerFaithfulness] - exact le_refl _ - -/-! 
### Certified Virtual Attention -/ - -/-- **Certified Virtual Head**: A virtual head is ε-certified if the composition -of value terms approximates the true composed Jacobian within ε. - -This is the key definition for "interpretability certification": -when we claim "this is an induction head," we can certify that the -attention-based explanation is within ε of the true mechanism. -/ -def isCertifiedVirtualHead - (L₂ L₁ : AttentionLinearization n d) - (composedJacobian : SignedMixer (n × d) (n × d)) - (ε : ℝ) : Prop := - frobeniusNorm (composedJacobian - VirtualHead L₂ L₁) ≤ ε - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Virtual head error budget from layer faithfulness**. - -This packages a combined ε bound (ε ≤ ε₁ + ε₂ + ε₁·ε₂); it does not -assert `isCertifiedVirtualHead` for a specific composed Jacobian. -/ -theorem virtual_head_certification - (L₂ L₁ : AttentionLinearization n d) - (ε₁ ε₂ : ℝ) - (_hε₁ : isLayerFaithful L₁ ε₁) - (_hε₂ : isLayerFaithful L₂ ε₂) : - -- The virtual head approximation has bounded error - ∃ (ε : ℝ), ε ≤ ε₁ + ε₂ + ε₁ * ε₂ := by - exact ⟨ε₁ + ε₂ + ε₁ * ε₂, le_refl _⟩ - -/-! ### Induction Head Formalization -/ - -/-- **Induction Head Pattern**: Layer 2 follows the attention structure created by Layer 1. - -An induction head occurs when: -- Layer 1 (previous-token head, simplified): A₁[i, i] is high (self-attention stand-in for i-1) -- Layer 2 (induction head, simplified): attention weights are nonnegative (softmax), - with token-matching handled by external witnesses. - -The composed pattern A₂ · A₁ creates "in-context learning" behavior. 
-/ -structure InductionHeadPattern where - /-- Layer 1: the previous-token attention head -/ - layer1 : AttentionLinearization n d - /-- Layer 2: the induction attention head -/ - layer2 : AttentionLinearization n d - /-- Layer 1 strongly attends to previous position -/ - prevTokenStrong : ∀ i : n, 0.5 ≤ layer1.state.attentionWeights i i - -- In practice, this would be i attending to i-1, but we simplify - /-- Layer 2 has nonnegative attention weights (softmax); token matching is handled elsewhere. -/ - inductionStrong : ∀ q k : n, layer2.state.attentionWeights q k ≥ 0 - -/-- The effective "induction pattern" created by composing the heads. -/ -noncomputable def inductionPattern (H : InductionHeadPattern (n := n) (d := d)) : - SignedMixer n n := - PositionVirtualHead H.layer2 H.layer1 - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- **Induction head error budget**: combine per-layer bounds into ε. - -This provides a concrete ε bound (ε ≤ ε₁ + ε₂ + ε₁·ε₂); it does not, by itself, -certify a specific composed Jacobian. -/ -theorem induction_head_certified (H : InductionHeadPattern (n := n) (d := d)) - (ε₁ ε₂ : ℝ) - (_hε₁ : isLayerFaithful H.layer1 ε₁) - (_hε₂ : isLayerFaithful H.layer2 ε₂) - (hε₁_pos : 0 ≤ ε₁) (hε₂_pos : 0 ≤ ε₂) : - -- The virtual head computation is approximately correct - ∃ (ε : ℝ), 0 ≤ ε ∧ ε ≤ ε₁ + ε₂ + ε₁ * ε₂ := by - refine ⟨ε₁ + ε₂ + ε₁ * ε₂, ?_, le_refl _⟩ - nlinarith - -/-! ### Interpretability Illusion Detection -/ - -/-- **Interpretability Illusion**: When pattern terms dominate value terms. - -A discovered "circuit" might be an illusion if the pattern term is large -relative to the value term—meaning the attention patterns are unstable -and the simple attention-based explanation is misleading. 
-/ -def isInterpretabilityIllusion (L : AttentionLinearization n d) (threshold : ℝ) : Prop := - frobeniusNorm (patternTerm L) > threshold * frobeniusNorm (valueTerm L) - -/-- **Genuine Mechanism**: When value terms dominate pattern terms. - -A mechanism is "genuine" (not an illusion) when the value term captures -most of the Jacobian and the pattern term is relatively small. -/ -def isGenuineMechanism (L : AttentionLinearization n d) (threshold : ℝ) : Prop := - frobeniusNorm (patternTerm L) ≤ threshold * frobeniusNorm (valueTerm L) - -omit [DecidableEq n] [DecidableEq d] [Nonempty n] [Nonempty d] in -/-- Mechanisms are either illusions or genuine (assuming reasonable threshold). -/ -theorem mechanism_trichotomy (L : AttentionLinearization n d) (threshold : ℝ) - (_hpos : 0 < threshold) : - isGenuineMechanism L threshold ∨ isInterpretabilityIllusion L threshold := by - by_cases h : frobeniusNorm (patternTerm L) ≤ threshold * frobeniusNorm (valueTerm L) - · left; exact h - · right - push_neg at h - exact h - -/-- **Deep mechanism certification**: A multi-layer mechanism is genuine if -all constituent layers have small pattern terms. -/ -def isDeepGenuineMechanism (D : DeepLinearization (n := n) (d := d)) (threshold : ℝ) : Prop := - ∀ i : Fin D.numLayers, isGenuineMechanism (D.layers i) threshold - -omit [Nonempty n] [Nonempty d] in -/-- If all layers are genuine, the deep pattern term is bounded. 
-/ -theorem deep_genuine_implies_bounded (D : DeepLinearization (n := n) (d := d)) - (threshold : ℝ) - (hGenuine : isDeepGenuineMechanism D threshold) - (_hthreshold_pos : 0 ≤ threshold) : - -- Deep pattern term is bounded by layer value terms and threshold - totalLayerError D ≤ threshold * - (∑ i : Fin D.numLayers, frobeniusNorm (valueTerm (D.layers i))) := by - simp only [totalLayerError, layerError, isDeepGenuineMechanism, isGenuineMechanism] at * - rw [Finset.mul_sum] - apply Finset.sum_le_sum - intro i _ - exact hGenuine i - -end DeepLinearization - -end Nfp diff --git a/Nfp/Mixer.lean b/Nfp/Mixer.lean index f4604df..66abc23 100644 --- a/Nfp/Mixer.lean +++ b/Nfp/Mixer.lean @@ -1,198 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.Set.Basic -import Aesop -import Nfp.Prob +module -/-! -Forward mixers (Appendix A.1): abstract, finite, row-stochastic operators and -their basic closure properties. This formalizes the “forward mixer” notion from -the Appendix and shows that they preserve probability mass and are closed under -composition. --/ - -namespace Nfp - -open scoped BigOperators - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- A row-stochastic nonnegative mixer from `S` to `T` (Appendix A.1). -/ -structure Mixer (S T : Type*) [Fintype S] [Fintype T] where - w : S → T → NNReal - row_sum_one : ∀ i : S, (∑ j, w i j) = 1 - -namespace Mixer - -variable (M : Mixer S T) - -@[ext] -theorem ext {M N : Mixer S T} (h : ∀ i j, M.w i j = N.w i j) : M = N := by - cases M; cases N; simp only [mk.injEq]; funext i j; exact h i j - -/-- Apply a mixer to a probability row `p` to produce a probability row on `T`. -This corresponds to pushing tracer mass through a forward mixer (Appendix A.1). 
-/ -noncomputable def push (p : ProbVec S) : ProbVec T := - { - mass := fun j => ∑ i, p.mass i * M.w i j, - norm_one := by - classical - -- Show ∑_j (∑_i p_i * w_{i,j}) = 1 by swapping sums and using row sums = 1 - have : (∑ j, (∑ i, p.mass i * M.w i j)) = 1 := by - calc - (∑ j, (∑ i, p.mass i * M.w i j)) - = (∑ i, (∑ j, p.mass i * M.w i j)) := by - simpa [mul_comm, mul_left_comm, mul_assoc] using Finset.sum_comm - _ = (∑ i, p.mass i * (∑ j, M.w i j)) := by - simp [Finset.mul_sum] - _ = (∑ i, p.mass i * 1) := by - simp [M.row_sum_one] - _ = (∑ i, p.mass i) := by - simp - _ = 1 := by - simp [ProbVec.sum_mass (ι:=S) p] - simp [this] } - -/-- The `i`-th row of a mixer as a probability vector on the target type. -/ -noncomputable def row (i : S) : ProbVec T := - { - mass := fun j => M.w i j - norm_one := by - classical - simpa using M.row_sum_one i - } - -@[simp] lemma row_mass (i : S) (j : T) : (M.row i).mass j = M.w i j := rfl - -/-- Pushing a point-mass probability vector through a mixer selects the corresponding row. -/ -theorem push_pure (i : S) : M.push (ProbVec.pure (ι := S) i) = M.row i := by - ext j - classical - simp [Mixer.push, Mixer.row, ProbVec.pure, Finset.sum_ite_eq', Finset.mem_univ] - -/-- `push` is affine: it commutes with convex mixtures of probability vectors. -/ -theorem push_mix (c : NNReal) (hc : c ≤ 1) (p q : ProbVec S) : - M.push (ProbVec.mix (ι := S) c hc p q) = - ProbVec.mix (ι := T) c hc (M.push p) (M.push q) := by - ext j - classical - simp [Mixer.push, ProbVec.mix, Finset.sum_add_distrib, add_mul, mul_assoc, Finset.mul_sum] - -/-- Composition of mixers is again a mixer (closure under composition). 
-/ -noncomputable def comp (N : Mixer T U) : Mixer S U := - { - w := fun i k => ∑ j, M.w i j * N.w j k, - row_sum_one := by - classical - intro i - -- ∑_k ∑_j M i j * N j k = ∑_j M i j * (∑_k N j k) = ∑_j M i j * 1 = 1 - calc - (∑ k, (∑ j, M.w i j * N.w j k)) = 1 := by - -- Same calculation as above but stated in one calc for direct closure - calc - (∑ k, (∑ j, M.w i j * N.w j k)) - = (∑ j, (∑ k, M.w i j * N.w j k)) := by - simpa [mul_comm, mul_left_comm, mul_assoc] using Finset.sum_comm - _ = (∑ j, M.w i j * (∑ k, N.w j k)) := by - simp [Finset.mul_sum] - _ = (∑ j, M.w i j * 1) := by - simp [N.row_sum_one] - _ = (∑ j, M.w i j) := by - simp - _ = 1 := by - simpa using M.row_sum_one i - } - -end Mixer - -end Nfp +public import Nfp.Mixer.Basic +public import Nfp.Mixer.Operations /-! -Support-restricted mixers (Appendix A.1, support restriction): compact helpers -to state and propagate that a mixer has zero weights outside a given binary -relation `R : S → T → Prop`. These do not change the existing proofs; they -encode the “only route along executed edges” constraint as a property. +Row-stochastic mixers. -/ - -namespace Nfp - -open scoped BigOperators - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -namespace Mixer - -/-- A mixer `M : Mixer S T` is `supported` by `R` if every weight outside `R` is zero. -/ -def supported (M : Mixer S T) (R : S → T → Prop) : Prop := - ∀ i j, ¬ R i j → M.w i j = 0 - -@[simp] lemma supported_zero {M : Mixer S T} {R : S → T → Prop} - (h : supported (S := S) (T := T) M R) {i : S} {j : T} (hij : ¬ R i j) : M.w i j = 0 := - h i j hij - -attribute [aesop safe] supported_zero - -/-- Relational composition of supports: `R ⋆ Q` allows an edge `i → k` iff there -exists `j` with `i → j` allowed by `R` and `j → k` allowed by `Q`. -/ -def compSupport (R : S → T → Prop) (Q : T → U → Prop) : S → U → Prop := - fun i k => ∃ j, R i j ∧ Q j k - -/-- If `M` is supported by `R` and `N` by `Q`, then `M.comp N` is supported by `R ⋆ Q`. 
-/ -lemma supported_comp {M : Mixer S T} {N : Mixer T U} - {R : S → T → Prop} {Q : T → U → Prop} - (hM : supported (S := S) (T := T) M R) (hN : supported (S := T) (T := U) N Q) : - supported (S := S) (T := U) (M.comp N) (compSupport R Q) := by - classical - intro i k hnot - -- For every j, either ¬R i j or ¬Q j k; hence the product weight vanishes. - have hforall : ∀ j, M.w i j * N.w j k = 0 := by - intro j - have hnot_and : ¬ (R i j ∧ Q j k) := by - -- `¬ ∃ j, R i j ∧ Q j k` ⇒ `¬ (R i j ∧ Q j k)` for each `j`. - intro hAnd - exact hnot ⟨j, hAnd⟩ - have hdisj : ¬ R i j ∨ ¬ Q j k := (not_and_or.mp hnot_and) - aesop (add safe [hM, hN]) - -- Sum of zero terms equals 0 - have : (∑ j, M.w i j * N.w j k) = 0 := by - have hfun : (fun j => M.w i j * N.w j k) = (fun _ => (0 : NNReal)) := by - funext j; simpa using hforall j - simp [hfun] - -- This is exactly the weight on `(i,k)` inside `M.comp N`. - simp [Mixer.comp, this] - -/-- The support (positions with nonzero mass) of a probability vector. -/ -def supp (p : ProbVec S) : Set S := fun i => p.mass i ≠ 0 - -/-- Image of a set of sources along a binary support relation. -/ -def image (R : S → T → Prop) (A : Set S) : Set T := fun k => ∃ i, A i ∧ R i k - -/-- Support propagation: if `M` is supported by `R`, then any nonzero pushed mass at `k` -originates from some source `i` with nonzero mass and an allowed edge `R i k`. 
-/ -lemma supp_push_subset_image {M : Mixer S T} {R : S → T → Prop} - (hM : supported (S := S) (T := T) M R) (p : ProbVec S) : - supp (S := T) (M.push p) ⊆ image (S := S) (T := T) R (supp (S := S) p) := by - classical - intro k hk - have hk_mass : (M.push p).mass k ≠ 0 := hk - have hk' : (∑ i, p.mass i * M.w i k) ≠ 0 := by - simpa [Mixer.push] using hk_mass - obtain ⟨i, _hi, hne⟩ := Finset.exists_ne_zero_of_sum_ne_zero hk' - have hpi : p.mass i ≠ 0 := by - intro h0 - exact hne (by simp [h0]) - have hwi : M.w i k ≠ 0 := by - intro h0 - exact hne (by simp [h0]) - have hR : R i k := by - by_contra hR - exact hwi (hM i k hR) - exact ⟨i, ⟨hpi, hR⟩⟩ - -end Mixer - -end Nfp diff --git a/Nfp/Mixer/Basic.lean b/Nfp/Mixer/Basic.lean new file mode 100644 index 0000000..6f07d65 --- /dev/null +++ b/Nfp/Mixer/Basic.lean @@ -0,0 +1,43 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Prob.Basic + +/-! +Row-stochastic mixers. +-/ + +public section + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A row-stochastic mixer from `ι` to `κ`. -/ +structure Mixer (ι κ : Type u) [Fintype ι] [Fintype κ] where + /-- Nonnegative weights for each source/target pair. -/ + weight : ι → κ → Mass + /-- Each row is a probability vector. -/ + row_sum : ∀ i, (∑ k, weight i k) = 1 + +attribute [simp] Mixer.row_sum + +namespace Mixer + +variable {ι κ : Type u} [Fintype ι] [Fintype κ] + +instance : CoeFun (Mixer ι κ) (fun _ => ι → κ → Mass) := ⟨Mixer.weight⟩ + +/-- The row of a mixer as a probability vector. 
-/ +def row (M : Mixer ι κ) (i : ι) : ProbVec κ := + { mass := fun k => M.weight i k + sum_mass := M.row_sum i } + +end Mixer + +end Nfp + +end diff --git a/Nfp/Mixer/Operations.lean b/Nfp/Mixer/Operations.lean new file mode 100644 index 0000000..ee88bec --- /dev/null +++ b/Nfp/Mixer/Operations.lean @@ -0,0 +1,63 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Mixer.Basic +public import Nfp.Prob.Operations +public import Mathlib.Algebra.BigOperators.Ring.Finset + +/-! +Mixer operations (pushforward, composition, identity). +-/ + +public section + +open scoped BigOperators + +namespace Nfp +namespace Mixer + +universe u + +variable {ι κ α : Type u} [Fintype ι] [Fintype κ] [Fintype α] + +/-- Swap a double sum and factor the inner sum out by multiplication. -/ +private lemma sum_mul_sum (p : ι → Mass) (w : ι → κ → Mass) : + (∑ k, ∑ i, p i * w i k) = ∑ i, p i * ∑ k, w i k := by + classical + calc + ∑ k, ∑ i, p i * w i k = ∑ i, ∑ k, p i * w i k := by + exact + (Finset.sum_comm (s := (Finset.univ : Finset κ)) (t := (Finset.univ : Finset ι)) + (f := fun k i => p i * w i k)) + _ = ∑ i, p i * ∑ k, w i k := by + refine Finset.sum_congr rfl ?_ + intro i _ + simp [Finset.mul_sum] + +/-- Push a probability vector forward along a mixer. -/ +def push (M : Mixer ι κ) (p : ProbVec ι) : ProbVec κ := + { mass := fun k => ∑ i, p.mass i * M.weight i k + sum_mass := by + classical + simp [sum_mul_sum] } + +/-- Composition of two mixers. -/ +def comp (M : Mixer ι κ) (N : Mixer κ α) : Mixer ι α := + { weight := fun i a => ∑ k, M.weight i k * N.weight k a + row_sum := by + classical + intro i + simp [sum_mul_sum] } + +/-- Identity mixer. 
-/ +def id (ι : Type u) [Fintype ι] [DecidableEq ι] : Mixer ι ι := + { weight := fun i j => (ProbVec.pure i).mass j + row_sum := by + intro i + simp } + +end Mixer +end Nfp + +end diff --git a/Nfp/MixerLocalSystem.lean b/Nfp/MixerLocalSystem.lean deleted file mode 100644 index e3f2486..0000000 --- a/Nfp/MixerLocalSystem.lean +++ /dev/null @@ -1,68 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Fintype.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Logic.Equiv.Defs -import Mathlib.Data.Fintype.Card -import Nfp.Mixer -import Nfp.Uniqueness - -/- -Bridge from row-stochastic mixers on finite DAGs to `LocalSystem`. This is -useful for reusing the `tracer_unique` theorem on mixers that come equipped -with a topological ordering. --/ - -namespace Nfp - -open Finset - -variable {Site : Type*} [Fintype Site] [DecidableEq Site] - -namespace LocalSystem - -/-- Interpret a mixer as a `LocalSystem` using an explicit numbering of sites. -/ -noncomputable def ofMixerIdx {n : ℕ} (M : Mixer Site Site) (e : Site ≃ Fin n) - (acyclic : ∀ s t, M.w s t ≠ 0 → e s < e t) : - LocalSystem n := by - classical - let siteOf : Fin n → Site := e.symm - refine - { - Pa := fun i => - (Finset.univ.filter - (fun u : Fin n => - M.w (siteOf u) (siteOf i) ≠ 0)) - c := fun i u => M.w (siteOf u) (siteOf i) - topo := by - intro i u hu - have hmem := Finset.mem_filter.mp hu - have hweight : M.w (siteOf u) (siteOf i) ≠ 0 := hmem.2 - have htopo : e (siteOf u) < e (siteOf i) := acyclic _ _ hweight - simpa [siteOf] using htopo - } - -/-- Interpret a mixer as a `LocalSystem`, given a topological index `topo` and -a compatibility witness `respect` showing that the canonical `Fintype` ordering -aligns with `topo`. The `acyclic` assumption enforces the DAG constraint -`topo s < topo t` whenever `M.w s t` is nonzero. 
-/ -noncomputable def ofMixer (M : Mixer Site Site) (topo : Site → ℕ) - (acyclic : ∀ s t, M.w s t ≠ 0 → topo s < topo t) - (respect : - ∀ {s t}, topo s < topo t → - (Fintype.equivFin Site s).1 < (Fintype.equivFin Site t).1) : - LocalSystem (Fintype.card Site) := by - classical - let e : Site ≃ Fin (Fintype.card Site) := Fintype.equivFin Site - have hindex : - ∀ s t, M.w s t ≠ 0 → e s < e t := by - intro s t hwt - have htopo := acyclic s t hwt - have horder : (Fintype.equivFin Site s).1 < (Fintype.equivFin Site t).1 := - respect htopo - simpa [e] using horder - exact ofMixerIdx (Site := Site) (M := M) (e := e) hindex - -end LocalSystem - -end Nfp diff --git a/Nfp/Model.lean b/Nfp/Model.lean new file mode 100644 index 0000000..cd4f980 --- /dev/null +++ b/Nfp/Model.lean @@ -0,0 +1,12 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Model.Gpt2 +public import Nfp.Model.InductionHead +public import Nfp.Model.InductionPrompt +public import Nfp.Model.InductionCircuit + +/-! +Model-specific data containers for the NFP rewrite. +-/ diff --git a/Nfp/Model/Gpt2.lean b/Nfp/Model/Gpt2.lean new file mode 100644 index 0000000..1f14211 --- /dev/null +++ b/Nfp/Model/Gpt2.lean @@ -0,0 +1,128 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.ValueRange + +/-! +Exact GPT-2 slices for induction certification. + +This module holds token embeddings, head projection weights, and per-layer +MLP/LayerNorm parameters used to define `InductionHeadInputs` and bound +computations. +-/ + +public section + +namespace Nfp + +namespace Model + +open Nfp.Circuit + +/-- Token indices describing a logit-diff direction (target minus negative). -/ +structure DirectionTokens (vocab : Nat) where + /-- Target token index. -/ + target : Fin vocab + /-- Negative token index. -/ + negative : Fin vocab + +/-- Convert `DirectionTokens` to a `DirectionSpec`. 
-/ +def DirectionTokens.spec {vocab : Nat} (dir : DirectionTokens vocab) : DirectionSpec := + { target := dir.target.val, negative := dir.negative.val } + +/-- Exact GPT-2 head slice needed to build induction-head inputs. -/ +structure Gpt2HeadSlice (seq dModel dHead vocab : Nat) where + /-- Softmax scale factor (e.g. `1/8` for head dim 64). -/ + scale : Rat + /-- Token ids for the prompt. -/ + tokens : Fin seq → Fin vocab + /-- Token embedding matrix. -/ + wte : Fin vocab → Fin dModel → Rat + /-- Positional embedding matrix. -/ + wpe : Fin seq → Fin dModel → Rat + /-- Query projection weights. -/ + wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat + /-- Output projection weights for this head slice. -/ + wo : Fin dModel → Fin dHead → Rat + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat + /-- LayerNorm epsilon for the attention input. -/ + lnEps : Rat + /-- LayerNorm scale for the attention input. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for the attention input. -/ + ln1Beta : Fin dModel → Rat + /-- Direction tokens for logit-diff certification. -/ + direction : DirectionTokens vocab + +/-- Exact per-head attention weights and biases. -/ +structure Gpt2HeadWeights (dModel dHead : Nat) where + /-- Query projection weights. -/ + wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat + /-- Output projection weights for this head slice. 
-/ + wo : Fin dModel → Fin dHead → Rat + +/-- Exact GPT-2 layer slice with MLP and LayerNorm parameters. -/ +structure Gpt2LayerSlice (dModel hidden : Nat) where + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat + /-- MLP input projection weights. -/ + mlpWIn : Fin dModel → Fin hidden → Rat + /-- MLP input projection bias. -/ + mlpBIn : Fin hidden → Rat + /-- MLP output projection weights. -/ + mlpWOut : Fin hidden → Fin dModel → Rat + /-- MLP output projection bias. -/ + mlpBOut : Fin dModel → Rat + /-- LayerNorm scale for the attention input. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for the attention input. -/ + ln1Beta : Fin dModel → Rat + /-- LayerNorm scale for the MLP input. -/ + ln2Gamma : Fin dModel → Rat + /-- LayerNorm bias for the MLP input. -/ + ln2Beta : Fin dModel → Rat + +/-- Final LayerNorm parameters applied before unembedding. -/ +structure Gpt2FinalLayerNorm (dModel : Nat) where + /-- LayerNorm scale. -/ + gamma : Fin dModel → Rat + /-- LayerNorm bias. -/ + beta : Fin dModel → Rat + +/-- Token-plus-position embeddings for a GPT-2 head slice. -/ +def Gpt2HeadSlice.embed {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) : + Fin seq → Fin dModel → Rat := + fun q d => slice.wte (slice.tokens q) d + slice.wpe q d + +/-- Direction vector in model space for a GPT-2 head slice. -/ +def Gpt2HeadSlice.directionVec {seq dModel dHead vocab : Nat} + (slice : Gpt2HeadSlice seq dModel dHead vocab) : Fin dModel → Rat := + fun d => slice.wte slice.direction.target d - slice.wte slice.direction.negative d + +end Model + +end Nfp diff --git a/Nfp/Model/InductionCircuit.lean b/Nfp/Model/InductionCircuit.lean new file mode 100644 index 0000000..18a0e56 --- /dev/null +++ b/Nfp/Model/InductionCircuit.lean @@ -0,0 +1,60 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Model.InductionPrompt + +/-! +Circuit-level induction prompt specifications. 
+ +These wrappers name the *shifted* prev/active maps that correspond to the +canonical induction-head circuit (previous-token head feeding induction head). +They are definitional aliases over `InductionPrevSpec*Shift`, but make the +intended mechanistic interpretation explicit for later lemmas. +-/ + +public section + +namespace Nfp + +namespace Model + +/-- +Circuit-level shifted-prev spec for periodic prompts. + +This matches the canonical induction circuit: a previous-token head shifts +the match by one position, so the induction head attends to `q - period + 1`. +-/ +structure InductionCircuitSpecPeriodShift {seq dModel dHead : Nat} + (period : Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- The underlying shifted-period prev/active spec. -/ + prev_spec : InductionPrevSpecPeriodShift (seq := seq) period inputs + +/-- +Circuit-level shifted-prev spec for token-based prompts. + +This corresponds to the canonical induction circuit with a previous-token head, +so the induction head uses the shifted-token map. +-/ +structure InductionCircuitSpecTokensShift {seq dModel dHead : Nat} + (tokens : Fin seq → Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- The underlying shifted-token prev/active spec. -/ + prev_spec : InductionPrevSpecTokensShift (seq := seq) tokens inputs + +/-- +Lift a shifted-period circuit spec to the token-based spec for diagnostic prompts. 
+-/ +theorem InductionCircuitSpecTokensShift_of_diag {dModel dHead : Nat} {period : Nat} + (tokens : Fin (2 * period) → Nat) + (inputs : InductionHeadInputs (2 * period) dModel dHead) + (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) + (hspec : InductionCircuitSpecPeriodShift (seq := 2 * period) period inputs) : + InductionCircuitSpecTokensShift (seq := 2 * period) tokens inputs := by + refine ⟨?_,⟩ + exact + InductionPrevSpecTokensShift_of_diag (tokens := tokens) hdiag hper hspec.prev_spec + +end Model + +end Nfp diff --git a/Nfp/Model/InductionHead.lean b/Nfp/Model/InductionHead.lean new file mode 100644 index 0000000..f2d879a --- /dev/null +++ b/Nfp/Model/InductionHead.lean @@ -0,0 +1,67 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Basic +public import Nfp.Core.Basic +public import Nfp.Circuit.Cert.ValueRange + +/-! +Exact inputs for induction-head scoring and value-direction computations. + +These structures store exact rational inputs (embeddings and weights) for a +single attention head. They are intended to be consumed by sound builders. +-/ + +public section + +namespace Nfp + +namespace Model + +open Nfp.Circuit + +/-- Exact head inputs for induction certification. -/ +structure InductionHeadInputs (seq dModel dHead : Nat) where + /-- Softmax scale factor (e.g. `1/8` for GPT-2-small head dim 64). -/ + scale : Rat + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Token embeddings for the sequence. -/ + embed : Fin seq → Fin dModel → Rat + /-- LayerNorm epsilon used before attention. -/ + lnEps : Rat + /-- LayerNorm scale for pre-attention normalization. -/ + ln1Gamma : Fin dModel → Rat + /-- LayerNorm bias for pre-attention normalization. -/ + ln1Beta : Fin dModel → Rat + /-- Query projection weights. 
-/ + wq : Fin dModel → Fin dHead → Rat + /-- Query projection bias. -/ + bq : Fin dHead → Rat + /-- Key projection weights. -/ + wk : Fin dModel → Fin dHead → Rat + /-- Key projection bias. -/ + bk : Fin dHead → Rat + /-- Value projection weights. -/ + wv : Fin dModel → Fin dHead → Rat + /-- Value projection bias. -/ + bv : Fin dHead → Rat + /-- Output projection weights (head slice). -/ + wo : Fin dModel → Fin dHead → Rat + /-- Attention output bias (shared across heads). -/ + attnBias : Fin dModel → Rat + /-- Whether to apply a causal mask to attention scores. -/ + maskCausal : Bool + /-- Score value for masked entries (e.g. `-10000` for GPT-2 causal masking). -/ + maskValue : Rat + /-- Logit-diff direction metadata. -/ + directionSpec : DirectionSpec + /-- Logit-diff direction vector in model space. -/ + direction : Fin dModel → Rat + +end Model + +end Nfp diff --git a/Nfp/Model/InductionPrompt.lean b/Nfp/Model/InductionPrompt.lean new file mode 100644 index 0000000..fe50221 --- /dev/null +++ b/Nfp/Model/InductionPrompt.lean @@ -0,0 +1,465 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Finset.Max +public import Mathlib.Data.Fintype.Basic +public import Nfp.Model.InductionHead + +/-! +Helpers for induction-style prompts. + +These are small, deterministic utilities for constructing the `prev` map and +active-query set from a fixed period. They keep the prompt bookkeeping +separate from the model weights. +-/ + +public section + +namespace Nfp + +namespace Model + +/-- `prev` map for a periodic induction prompt: `q ↦ q - period` (truncated at 0). -/ +def prevOfPeriod {seq : Nat} (period : Nat) (q : Fin seq) : Fin seq := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + +/-- Active queries for a periodic induction prompt (`period ≤ q`). 
-/ +def activeOfPeriod {seq : Nat} (period : Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => period ≤ q.val) + +/-- Membership characterization for `activeOfPeriod`. -/ +theorem mem_activeOfPeriod {seq : Nat} {period : Nat} {q : Fin seq} : + q ∈ activeOfPeriod (seq := seq) period ↔ period ≤ q.val := by + simp [activeOfPeriod] + +/-- +Shifted `prev` map for a periodic induction prompt: if `0 < period` and +`period ≤ q`, return `q - period + 1`; otherwise default to `0`. +-/ +def prevOfPeriodShift {seq : Nat} (period : Nat) (q : Fin seq) : Fin seq := by + classical + by_cases hq : period ≤ q.val + · by_cases hper : 0 < period + · have hlt : q.val - period + 1 < seq := by + have hsub : q.val - period < q.val := Nat.sub_lt_of_pos_le hper hq + have hle : q.val - period + 1 ≤ q.val := Nat.succ_le_of_lt hsub + exact lt_of_le_of_lt hle q.isLt + exact ⟨q.val - period + 1, hlt⟩ + · have hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + · have hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + +/-- Active queries for shifted periodic induction prompts (`0 < period ≤ q`). -/ +def activeOfPeriodShift {seq : Nat} (period : Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => 0 < period ∧ period ≤ q.val) + +/-- Membership characterization for `activeOfPeriodShift`. -/ +theorem mem_activeOfPeriodShift {seq : Nat} {period : Nat} {q : Fin seq} : + q ∈ activeOfPeriodShift (seq := seq) period ↔ 0 < period ∧ period ≤ q.val := by + simp [activeOfPeriodShift] + +/-- `prev` map induced by token repeats (defaulting to `0` when no prior match exists). 
-/ +def prevOfTokens {seq : Nat} (tokens : Fin seq → Nat) (q : Fin seq) : Fin seq := by + classical + let hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + let zero : Fin seq := ⟨0, hpos⟩ + let candidates : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun k => + k.val < q.val ∧ tokens k = tokens q) + by_cases h : candidates.Nonempty + · exact Finset.max' candidates h + · exact zero + +/-- Active queries induced by token repeats. -/ +def activeOfTokens {seq : Nat} (tokens : Fin seq → Nat) : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun q => + ∃ k, k.val < q.val ∧ tokens k = tokens q) + +/-- Membership characterization for `activeOfTokens`. -/ +theorem mem_activeOfTokens {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} : + q ∈ activeOfTokens tokens ↔ ∃ k, k.val < q.val ∧ tokens k = tokens q := by + simp [activeOfTokens] + +/-- If a prior matching token exists, `prevOfTokens` picks a matching index and is maximal. -/ +theorem prevOfTokens_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (h : ∃ k, k < q ∧ tokens k = tokens q) : + let p := prevOfTokens tokens q + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + classical + let candidates : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).filter (fun k => + k < q ∧ tokens k = tokens q) + have hnonempty : candidates.Nonempty := by + rcases h with ⟨k, hk, htok⟩ + exact ⟨k, by simp [candidates, hk, htok]⟩ + by_cases h' : candidates.Nonempty + · have hmem : Finset.max' candidates h' ∈ candidates := + Finset.max'_mem candidates h' + have hcond : + Finset.max' candidates h' < q ∧ + tokens (Finset.max' candidates h') = tokens q := by + have hmem' := (Finset.mem_filter.1 hmem).2 + simpa using hmem' + have hmax : + ∀ k, k < q → tokens k = tokens q → + k ≤ Finset.max' candidates h' := by + intro k hk htok + have hk_mem : k ∈ candidates := by + simp [candidates, hk, htok] + have hk_mem' : k ∈ (candidates : Set (Fin seq)) := by + simpa 
using hk_mem + exact (Finset.isGreatest_max' (s := candidates) h').2 hk_mem' + simpa [prevOfTokens, candidates, h'] using + And.intro hcond.1 (And.intro hcond.2 hmax) + · exact (h' hnonempty).elim + +/-- Active queries imply the `prevOfTokens` maximal-match specification. -/ +theorem prevOfTokens_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (hq : q ∈ activeOfTokens tokens) : + let p := prevOfTokens tokens q + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + have h := (mem_activeOfTokens (tokens := tokens) (q := q)).1 hq + rcases h with ⟨k, hk, htok⟩ + have hk' : k < q := by + exact (Fin.lt_def).2 hk + exact prevOfTokens_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ + +/-- +Shifted `prev` map for induction: match the current token to its previous +occurrence and return the following position (`A B ... A -> B`). + +Example (tokens = [1,2,1,3,2,1]): + prevShift = [0,0,0,1,0,2], activeShift = {3,5} +-/ +def prevOfTokensShift {seq : Nat} (tokens : Fin seq → Nat) (q : Fin seq) : Fin seq := by + classical + by_cases hq : q ∈ activeOfTokens tokens + · let p := prevOfTokens tokens q + have hp : + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + simpa [p] using + (prevOfTokens_spec_of_active (tokens := tokens) (q := q) hq) + have hpv : p.val < q.val := (Fin.lt_def).1 hp.1 + have hle : p.val + 1 ≤ q.val := Nat.succ_le_of_lt hpv + have hlt : p.val + 1 < seq := lt_of_le_of_lt hle q.isLt + exact ⟨p.val + 1, hlt⟩ + · let hpos : 0 < seq := lt_of_le_of_lt (Nat.zero_le _) q.isLt + exact ⟨0, hpos⟩ + +/-- Active queries for shifted-token induction (same witness condition). -/ +def activeOfTokensShift {seq : Nat} (tokens : Fin seq → Nat) : Finset (Fin seq) := + activeOfTokens tokens + +/-- Membership characterization for `activeOfTokensShift`. 
-/ +theorem mem_activeOfTokensShift {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} : + q ∈ activeOfTokensShift tokens ↔ ∃ k, k.val < q.val ∧ tokens k = tokens q := by + simp [activeOfTokensShift, activeOfTokens] + +/-- Shifted `prev` agrees with the maximal previous match, advanced by one. -/ +theorem prevOfTokensShift_spec {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (h : ∃ k, k < q ∧ tokens k = tokens q) : + let p := prevOfTokensShift tokens q + let p0 := prevOfTokens tokens q + p.val = p0.val + 1 ∧ + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + classical + have hactive : q ∈ activeOfTokens tokens := by + rcases h with ⟨k, hk, htok⟩ + exact (mem_activeOfTokens (tokens := tokens) (q := q)).2 + ⟨k, (Fin.lt_def).1 hk, htok⟩ + let p0 := prevOfTokens tokens q + have hp0 : + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + simpa [p0] using (prevOfTokens_spec (tokens := tokens) (q := q) h) + have hpval : (prevOfTokensShift tokens q).val = p0.val + 1 := by + simp [prevOfTokensShift, hactive, p0] + have hpval' : + (prevOfTokensShift tokens q).val = (prevOfTokens tokens q).val + 1 := by + simpa [p0] using hpval + have hp0' : + prevOfTokens tokens q < q ∧ tokens (prevOfTokens tokens q) = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ prevOfTokens tokens q := by + simpa [p0] using hp0 + simpa using And.intro hpval' hp0' + +/-- Active shifted queries imply the shifted `prev` maximal-match specification. 
-/ +theorem prevOfTokensShift_spec_of_active {seq : Nat} {tokens : Fin seq → Nat} {q : Fin seq} + (hq : q ∈ activeOfTokensShift tokens) : + let p := prevOfTokensShift tokens q + let p0 := prevOfTokens tokens q + p.val = p0.val + 1 ∧ + p0 < q ∧ tokens p0 = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p0 := by + have h := (mem_activeOfTokensShift (tokens := tokens) (q := q)).1 hq + rcases h with ⟨k, hk, htok⟩ + have hk' : k < q := by + exact (Fin.lt_def).2 hk + exact prevOfTokensShift_spec (tokens := tokens) (q := q) ⟨k, hk', htok⟩ + +/-- Active queries select a `prev` strictly in the past. -/ +def InductionPrevInPast {seq dModel dHead : Nat} + (inputs : InductionHeadInputs seq dModel dHead) : Prop := + ∀ q, q ∈ inputs.active → inputs.prev q < q + +/-- +Canonical shifted-prev spec for periodic prompts. + +Note: when `1 < period`, every active query has `prev q < q`. +-/ +structure InductionPrevSpecPeriodShift {seq dModel dHead : Nat} + (period : Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- Active queries are the shifted-period active set. -/ + active_eq : inputs.active = activeOfPeriodShift (seq := seq) period + /-- Prev map matches the shifted-period definition. -/ + prev_eq : inputs.prev = prevOfPeriodShift (seq := seq) period + +/-- +Canonical shifted-prev spec for token-based prompts. + +Note: if successive tokens repeat, the shifted target can coincide with `q`. +-/ +structure InductionPrevSpecTokensShift {seq dModel dHead : Nat} + (tokens : Fin seq → Nat) (inputs : InductionHeadInputs seq dModel dHead) : Prop where + /-- Active queries match the shifted-token definition. -/ + active_eq : inputs.active = activeOfTokensShift (seq := seq) tokens + /-- Prev map matches the shifted-token definition. -/ + prev_eq : inputs.prev = prevOfTokensShift (seq := seq) tokens + +/-- Helper: lift a first-half index into `Fin (2 * period)`. 
-/ +lemma lt_double_of_lt_period {period i : Nat} (hi : i < period) : i < 2 * period := by + have hle : period ≤ 2 * period := by + have hpos : 0 < (2 : Nat) := by decide + exact Nat.le_mul_of_pos_left period hpos + exact Nat.lt_of_lt_of_le hi hle + +/-- +Tokens are a repeated pattern of length `period` with no repeats in the first half. + +This matches the usual induction diagnostic: a random pattern of length `period` +repeated twice. +-/ +structure InductionDiagnosticTokens (period : Nat) (tokens : Fin (2 * period) → Nat) : Prop where + /-- Second half repeats the first half with period `period`. -/ + repeat_tok : ∀ q : Fin (2 * period), period ≤ q.val → + tokens q = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + /-- The first-half tokens are pairwise distinct. -/ + inj : ∀ {i j : Nat} (hi : i < period) (hj : j < period), + tokens ⟨i, lt_double_of_lt_period hi⟩ = + tokens ⟨j, lt_double_of_lt_period hj⟩ → i = j + +/-- +In a diagnostic prompt (repeated distinct pattern), the previous matching token +for any query in the second half is exactly `q - period`. 
+-/ +theorem prevOfTokens_eq_prevOfPeriod_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {q : Fin (2 * period)} (hq : period ≤ q.val) : + prevOfTokens tokens q = prevOfPeriod (seq := 2 * period) period q := by + classical + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq < q := by + have hklt' : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hq) + exact (Fin.lt_def).2 hklt' + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hq + simpa [kq] using this.symm + have hspec := prevOfTokens_spec (tokens := tokens) (q := q) ⟨kq, hklt, htok⟩ + let p := prevOfTokens tokens q + have hp : + p < q ∧ tokens p = tokens q ∧ + ∀ k, k < q → tokens k = tokens q → k ≤ p := by + simpa [p] using hspec + have hqsub : q.val - period < period := by + have hq2 : q.val < 2 * period := q.isLt + have hq2' : q.val < period + period := by simpa [two_mul] using hq2 + exact (Nat.sub_lt_iff_lt_add hq).2 (by simpa [Nat.add_comm] using hq2') + have huniq : + ∀ r : Fin (2 * period), r < q → tokens r = tokens q → r.val = q.val - period := by + intro r hr htokr + by_cases hrper : period ≤ r.val + · have hrsub : r.val - period < period := by + have hr2 : r.val < 2 * period := r.isLt + have hr2' : r.val < period + period := by simpa [two_mul] using hr2 + exact (Nat.sub_lt_iff_lt_add hrper).2 (by simpa [Nat.add_comm] using hr2') + have htok_r : + tokens ⟨r.val - period, lt_of_le_of_lt (Nat.sub_le _ _) r.isLt⟩ = tokens r := by + have := hdiag.repeat_tok r hrper + simpa using this.symm + have htok_q : + tokens q = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using hdiag.repeat_tok q hq + have htok_first : + tokens ⟨r.val - period, lt_of_le_of_lt (Nat.sub_le _ _) r.isLt⟩ = + tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + calc + tokens ⟨r.val - period, lt_of_le_of_lt 
(Nat.sub_le _ _) r.isLt⟩ + = tokens r := by simpa using htok_r + _ = tokens q := by simpa using htokr + _ = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using htok_q + have hkeq : r.val - period = q.val - period := by + apply hdiag.inj hrsub hqsub + simpa using htok_first + have hrval : r.val = q.val := by + have h := congrArg (fun x => x + period) hkeq + simpa [Nat.sub_add_cancel hrper, Nat.sub_add_cancel hq] using h + have hrlt : r.val < q.val := (Fin.lt_def).1 hr + have hrlt' : r.val < r.val := by + have hrlt' := hrlt + rw [← hrval] at hrlt' + exact hrlt' + exact (False.elim (lt_irrefl _ hrlt')) + · have hrlt : r.val < period := lt_of_not_ge hrper + have hrfin : + (⟨r.val, lt_double_of_lt_period hrlt⟩ : Fin (2 * period)) = r := by + apply Fin.ext + rfl + have htok_q : + tokens q = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using hdiag.repeat_tok q hq + have htok_first : + tokens ⟨r.val, lt_double_of_lt_period hrlt⟩ = + tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + calc + tokens ⟨r.val, lt_double_of_lt_period hrlt⟩ = tokens r := by + simp [hrfin] + _ = tokens q := by simpa using htokr + _ = tokens ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ := by + simpa using htok_q + have hkeq : r.val = q.val - period := by + apply hdiag.inj hrlt hqsub + simpa using htok_first + exact hkeq + have hpval : p.val = q.val - period := by + have := hp.2.2 p hp.1 hp.2.1 + have huniq' := huniq p hp.1 hp.2.1 + exact huniq' + apply Fin.ext + simp [prevOfPeriod, hpval, p] + +/-- Shifted `prev` map matches the period-shifted map under diagnostic tokens. 
-/ +theorem prevOfTokensShift_eq_prevOfPeriodShift_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {q : Fin (2 * period)} (hq : period ≤ q.val) : + prevOfTokensShift tokens q = prevOfPeriodShift (seq := 2 * period) period q := by + have hprev := + prevOfTokens_eq_prevOfPeriod_of_diag (tokens := tokens) hdiag hper (q := q) hq + have hactive : q ∈ activeOfTokensShift tokens := by + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq < q := by + have hklt' : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hq) + exact (Fin.lt_def).2 hklt' + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hq + simpa [kq] using this.symm + exact (mem_activeOfTokensShift (tokens := tokens) (q := q)).2 + ⟨kq, hklt, htok⟩ + have hactive' : q ∈ activeOfTokens tokens := by + simpa [activeOfTokensShift] using hactive + simp [prevOfTokensShift, hactive', hprev, prevOfPeriodShift, prevOfPeriod, hq, hper] + +/-- Active shifted queries coincide with the periodic active set in diagnostics. 
-/ +theorem activeOfTokensShift_eq_activeOfPeriodShift_of_diag {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) : + activeOfTokensShift (seq := 2 * period) tokens = + activeOfPeriodShift (seq := 2 * period) period := by + ext q + constructor + · intro hq + have hq' := (mem_activeOfTokensShift (tokens := tokens) (q := q)).1 hq + rcases hq' with ⟨k, hk, htok⟩ + have hkper : period ≤ q.val := by + by_contra hlt + have hqlt : q.val < period := lt_of_not_ge hlt + have hklt : k.val < period := lt_of_lt_of_le hk (Nat.le_of_lt hqlt) + have hkfin : + (⟨k.val, lt_double_of_lt_period hklt⟩ : Fin (2 * period)) = k := by + apply Fin.ext + rfl + have hqfin : + (⟨q.val, lt_double_of_lt_period hqlt⟩ : Fin (2 * period)) = q := by + apply Fin.ext + rfl + have hkeq : k.val = q.val := by + apply hdiag.inj hklt hqlt + simpa [hkfin, hqfin] using htok + have hk' : k.val < k.val := by + have hk' := hk + rw [← hkeq] at hk' + exact hk' + exact (False.elim (lt_irrefl _ hk')) + exact (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).2 + ⟨hper, hkper⟩ + · intro hq + have hq' := (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).1 hq + rcases hq' with ⟨_, hqper⟩ + let kq : Fin (2 * period) := + ⟨q.val - period, lt_of_le_of_lt (Nat.sub_le _ _) q.isLt⟩ + have hklt : kq.val < q.val := by + simpa [kq] using (Nat.sub_lt_of_pos_le hper hqper) + have htok : tokens kq = tokens q := by + have := hdiag.repeat_tok q hqper + simpa [kq] using this.symm + exact (mem_activeOfTokensShift (tokens := tokens) (q := q)).2 + ⟨kq, hklt, htok⟩ + +/-- Diagnostic prompts align shifted-token `prev` with the period-shifted map. 
-/ +theorem prevOfTokensShift_eq_prevOfPeriodShift_of_diag_all {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) : + prevOfTokensShift tokens = prevOfPeriodShift (seq := 2 * period) period := by + funext q + by_cases hqper : period ≤ q.val + · simpa using + (prevOfTokensShift_eq_prevOfPeriodShift_of_diag (tokens := tokens) hdiag hper + (q := q) hqper) + · have hqnot_period : q ∉ activeOfPeriodShift (seq := 2 * period) period := by + intro hqmem + have hcond := + (mem_activeOfPeriodShift (seq := 2 * period) (period := period) (q := q)).1 hqmem + exact (hqper hcond.2).elim + have hactive_eq := + activeOfTokensShift_eq_activeOfPeriodShift_of_diag (tokens := tokens) hdiag hper + have hqnot_tokens_shift : q ∉ activeOfTokensShift (seq := 2 * period) tokens := by + simpa [hactive_eq] using hqnot_period + have hqnot_tokens : q ∉ activeOfTokens tokens := by + simpa [activeOfTokensShift] using hqnot_tokens_shift + simp [prevOfTokensShift, hqnot_tokens, prevOfPeriodShift, hqper] + +/-- Diagnostic prompts let period-shift specs re-express as token-shift specs. 
-/ +theorem InductionPrevSpecTokensShift_of_diag {dModel dHead : Nat} {period : Nat} + {tokens : Fin (2 * period) → Nat} (hdiag : InductionDiagnosticTokens period tokens) + (hper : 0 < period) {inputs : InductionHeadInputs (2 * period) dModel dHead} + (hspec : InductionPrevSpecPeriodShift (seq := 2 * period) period inputs) : + InductionPrevSpecTokensShift (seq := 2 * period) tokens inputs := by + refine ⟨?active, ?prev⟩ + · have hactive_eq := + activeOfTokensShift_eq_activeOfPeriodShift_of_diag (tokens := tokens) hdiag hper + calc + inputs.active = activeOfPeriodShift (seq := 2 * period) period := hspec.active_eq + _ = activeOfTokensShift (seq := 2 * period) tokens := by + symm + exact hactive_eq + · have hprev_eq := + prevOfTokensShift_eq_prevOfPeriodShift_of_diag_all (tokens := tokens) hdiag hper + calc + inputs.prev = prevOfPeriodShift (seq := 2 * period) period := hspec.prev_eq + _ = prevOfTokensShift tokens := by + symm + exact hprev_eq + +end Model + +end Nfp diff --git a/Nfp/PCC.lean b/Nfp/PCC.lean deleted file mode 100644 index 3eca79a..0000000 --- a/Nfp/PCC.lean +++ /dev/null @@ -1,228 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Algebra.Order.Monoid.Defs -import Aesop -import Nfp.Prob -import Nfp.Reroute.Heat - -/-! -# PCC Helpers (Appendix A.4) - -This module provides small probability/contribution utilities used in the -formalization of Appendix A.4 of the accompanying documentation: - -* `tracerOfContrib` – builds a probability vector from nonnegative contributions. -* `sum_monotone_chain` / `monotone_removed_mass` – monotonicity of accumulated mass - along nested mask chains (with tiny nonnegativity side-conditions handled by `aesop`). - -All proofs are elementary (`simp`, small `aesop` calls on nonnegativity), and avoid `sorry`. 
--/ - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S : Type*} - -/-- Nonnegative contributions over `S`. -/ -abbrev Contrib (S : Type*) [Fintype S] := S → NNReal - -/-- Lemma A.4: The tracer distribution equals normalized absolute contributions. -/ -noncomputable def tracerOfContrib [Fintype S] (m : Contrib S) (h : (∑ i, m i) ≠ 0) : ProbVec S := - { - mass := fun i => m i / (∑ j, m j), - norm_one := by - classical - have hsum : (∑ j, m j) ≠ 0 := h - have : (∑ i, m i / (∑ j, m j)) = (∑ i, m i) / (∑ j, m j) := by - simp [Finset.sum_div] - simp [this, div_self hsum] - } - -@[simp] lemma tracerOfContrib_mass [Fintype S] (m : Contrib S) (h : (∑ i, m i) ≠ 0) : - (tracerOfContrib (S:=S) m h).mass = fun i => m i / (∑ j, m j) := rfl - --- PCC monotonicity surrogate: if masks grow, removed normalized mass grows. -/-- Appendix A.4 (monotonicity helper): if masks grow, the removed mass sum grows. -/ -lemma sum_monotone_chain [Fintype S] (A : ℕ → Finset S) (w : S → NNReal) - (hchain : ∀ k, A k ⊆ A (k + 1)) : Monotone (fun k => (A k).sum (fun i => w i)) := by - classical - intro k₁ k₂ hk - refine Nat.le_induction ?base ?step _ hk - · exact le_rfl - · intro k₂ _ ih - have hstep : (A k₂).sum (fun i => w i) ≤ (A (k₂+1)).sum (fun i => w i) := by - refine Finset.sum_le_sum_of_subset_of_nonneg (hchain k₂) ?_ - intro i hi _ - aesop - exact ih.trans hstep - -/-- Appendix A.4 (monotonicity helper): removed mass is monotone along a nested mask chain. -/ -lemma monotone_removed_mass [Fintype S] (A : ℕ → Finset S) (m : Contrib S) - (hchain : ∀ k, A k ⊆ A (k + 1)) : Monotone (fun k => (A k).sum m) := by - simpa using (sum_monotone_chain (A:=A) (w:=m) hchain) - -namespace PCC - -/-- Finite interval `[lower, upper]` used to accumulate PCC area. We only require -`lower ≤ upper` so that the width is realizable in `NNReal`. 
-/ -structure AInterval where - lower : NNReal - upper : NNReal - hle : lower ≤ upper - -namespace AInterval - -@[simp] lemma width_nonneg (I : AInterval) : 0 ≤ I.upper - I.lower := - (I.upper - I.lower).property - -@[simp] lemma coe_width (I : AInterval) : (I.upper - I.lower : NNReal) = I.upper - I.lower := rfl - -end AInterval - -noncomputable def intervalsFromWeightsAux : NNReal → NNReal → List NNReal → List AInterval - | _, _, [] => [] - | total, acc, w :: ws => - let width := w / total - have hle : acc ≤ acc + width := - le_add_of_nonneg_right (show 0 ≤ width from bot_le) - { lower := acc, upper := acc + width, hle := hle } :: - intervalsFromWeightsAux total (acc + width) ws - -/-- Build the discrete PCC intervals from a list of widths and a total scale. -/ -noncomputable def intervalsFromWeights (total : NNReal) (weights : List NNReal) : List AInterval := - intervalsFromWeightsAux total 0 weights - -variable (f : NNReal → NNReal) - -/-- Evaluate the discrete AUC directly from widths (without constructing intervals). -/ -noncomputable def evalFromWeightsAux : - NNReal → NNReal → List NNReal → NNReal - | _, _, [] => 0 - | total, acc, w :: ws => - let width := w / total - width * f acc + evalFromWeightsAux total (acc + width) ws - -noncomputable def evalFromWeights (total : NNReal) (weights : List NNReal) : NNReal := - evalFromWeightsAux (f:=f) total 0 weights - -/-- Discrete AUC over a finite list of intervals for an arbitrary nonnegative -function `f`. Each interval contributes `(upper - lower) * f lower`. 
-/ -def AUC (L : List AInterval) : NNReal := - L.foldr (fun I acc => (I.upper - I.lower) * f I.lower + acc) 0 - -@[simp] lemma AUC_nil : AUC (f:=f) [] = 0 := rfl - -@[simp] lemma AUC_cons (I : AInterval) (L : List AInterval) : - AUC (f:=f) (I :: L) = (I.upper - I.lower) * f I.lower + AUC (f:=f) L := rfl - -lemma AUC_append (L₁ L₂ : List AInterval) : - AUC (f:=f) (L₁ ++ L₂) = AUC (f:=f) L₁ + AUC (f:=f) L₂ := by - induction L₁ with - | nil => simp [AUC] - | cons I L₁ ih => - calc - AUC (f:=f) ((I :: L₁) ++ L₂) - = (I.upper - I.lower) * f I.lower + AUC (f:=f) (L₁ ++ L₂) := by - simp [List.cons_append] - _ = (I.upper - I.lower) * f I.lower + (AUC (f:=f) L₁ + AUC (f:=f) L₂) := by - simp [ih] - _ = ((I.upper - I.lower) * f I.lower + AUC (f:=f) L₁) + AUC (f:=f) L₂ := by - ac_rfl - _ = AUC (f:=f) (I :: L₁) + AUC (f:=f) L₂ := by - simp - -lemma AUC_nonneg (L : List AInterval) : 0 ≤ AUC (f:=f) L := by - induction L with - | nil => simp - | cons I L ih => - have hterm : 0 ≤ (I.upper - I.lower) * f I.lower := by - exact mul_nonneg I.width_nonneg (show 0 ≤ f I.lower from bot_le) - have hsum : 0 ≤ (I.upper - I.lower) * f I.lower + AUC (f:=f) L := - add_nonneg hterm ih - calc - 0 ≤ (I.upper - I.lower) * f I.lower + AUC (f:=f) L := hsum - _ = AUC (f:=f) (I :: L) := (AUC_cons (f:=f) I L).symm - -lemma AUC_monotone_append (L₁ L₂ : List AInterval) : - AUC (f:=f) L₁ ≤ AUC (f:=f) (L₁ ++ L₂) := by - have hnonneg := AUC_nonneg (f:=f) L₂ - have hle : AUC (f:=f) L₁ ≤ AUC (f:=f) L₁ + AUC (f:=f) L₂ := by - exact le_add_of_nonneg_right hnonneg - calc - AUC (f:=f) L₁ - ≤ AUC (f:=f) L₁ + AUC (f:=f) L₂ := hle - _ = AUC (f:=f) (L₁ ++ L₂) := (AUC_append (f:=f) L₁ L₂).symm - -lemma AUC_add (L₁ L₂ : List AInterval) : - AUC (f:=f) (L₁ ++ L₂) = AUC (f:=f) L₁ + AUC (f:=f) L₂ := - AUC_append (f:=f) L₁ L₂ - -lemma AUC_intervalsFromWeightsAux (total acc : NNReal) (weights : List NNReal) : - AUC (f:=f) (intervalsFromWeightsAux total acc weights) - = evalFromWeightsAux (f:=f) total acc weights := by - 
induction weights generalizing acc with - | nil => - simp [intervalsFromWeightsAux, evalFromWeightsAux, AUC] - | cons w ws ih => - have htail := ih (acc + w / total) - simp [intervalsFromWeightsAux, evalFromWeightsAux, htail] - -lemma AUC_intervalsFromWeights (total : NNReal) (weights : List NNReal) : - AUC (f:=f) (intervalsFromWeights total weights) - = evalFromWeights (f:=f) total weights := by - simpa [intervalsFromWeights, evalFromWeights] using - AUC_intervalsFromWeightsAux (f:=f) total 0 weights - -lemma intervalsFromWeightsAux_take (total acc : NNReal) (weights : List NNReal) : - ∀ k, - intervalsFromWeightsAux total acc (weights.take k) - = (intervalsFromWeightsAux total acc weights).take k - | 0 => by simp [intervalsFromWeightsAux] - | Nat.succ k => - by - cases weights with - | nil => simp [intervalsFromWeightsAux] - | cons w ws => - have ih := - intervalsFromWeightsAux_take (total:=total) (acc:=acc + w / total) (weights:=ws) k - simp [intervalsFromWeightsAux, ih, Nat.succ_eq_add_one] - -lemma AUC_intervalsFromWeights_take (total : NNReal) (weights : List NNReal) (k : ℕ) : - AUC (f:=f) ((intervalsFromWeights total weights).take k) - = evalFromWeights (f:=f) total (weights.take k) := by - have h := - AUC_intervalsFromWeightsAux (f:=f) total 0 (weights.take k) - simpa [intervalsFromWeights, evalFromWeights, - intervalsFromWeightsAux_take (total:=total) (acc:=0) (weights:=weights)] - using h - -variable {f} - -end PCC - -namespace WeightedReroutePlan - -open PCC - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Normalized PCC intervals for a weighted reroute plan (widths sum to one). 
-/ -noncomputable def aucIntervals (P : WeightedReroutePlan (S := S)) : List PCC.AInterval := - intervalsFromWeights (total:=P.weightsSum) P.weights - -lemma auc_eval (P : WeightedReroutePlan (S := S)) (f : NNReal → NNReal) : - PCC.AUC (f:=f) (P.aucIntervals) - = PCC.evalFromWeights (f:=f) P.weightsSum P.weights := by - simpa [aucIntervals] - using PCC.AUC_intervalsFromWeights (f:=f) P.weightsSum P.weights - -end WeightedReroutePlan - -end Nfp diff --git a/Nfp/Prob.lean b/Nfp/Prob.lean index ea652a0..c4ffe1f 100644 --- a/Nfp/Prob.lean +++ b/Nfp/Prob.lean @@ -1,90 +1,10 @@ -- SPDX-License-Identifier: AGPL-3.0-or-later -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic +module -/- -Basic probability-friendly definitions used across the NFP development. -We work with finite types and nonnegative reals `NNReal` from mathlib. --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -/-- A probability vector on a finite type `ι` is a nonnegative function summing to 1. -/ -structure ProbVec (ι : Type*) [Fintype ι] where - mass : ι → NNReal - norm_one : (∑ i, mass i) = (1 : NNReal) - -namespace ProbVec - -variable {ι : Type*} [Fintype ι] - -@[simp] theorem sum_mass (p : ProbVec ι) : (∑ i, p.mass i) = 1 := p.norm_one - -@[ext] -theorem ext {p q : ProbVec ι} (h : ∀ i, p.mass i = q.mass i) : p = q := by - cases p; cases q; simp only [mk.injEq]; funext i; exact h i +public import Nfp.Prob.Basic +public import Nfp.Prob.Operations -theorem mass_le_one (p : ProbVec ι) (i : ι) : p.mass i ≤ 1 := by - have h := p.sum_mass - calc p.mass i ≤ ∑ j, p.mass j := Finset.single_le_sum (by simp) (Finset.mem_univ i) - _ = 1 := h - -/-- The Dirac/point-mass probability vector at `i0`. 
-/ -noncomputable def pure (i0 : ι) : ProbVec ι := by - classical - refine - { - mass := fun i => if i = i0 then 1 else 0 - norm_one := ?_ - } - simp - -@[simp] lemma pure_mass_self (i0 : ι) : (pure (ι := ι) i0).mass i0 = 1 := by - classical - simp [pure] - -@[simp] lemma pure_mass_ne_self {i0 i : ι} (h : i ≠ i0) : (pure (ι := ι) i0).mass i = 0 := by - classical - simp [pure, h] - -/-- Convex mixture of probability vectors using coefficient `c ∈ [0,1]`. -/ -noncomputable def mix (c : NNReal) (hc : c ≤ 1) (p q : ProbVec ι) : ProbVec ι := - { - mass := fun i => c * p.mass i + (1 - c) * q.mass i - norm_one := by - classical - calc - (∑ i, (c * p.mass i + (1 - c) * q.mass i)) - = (∑ i, c * p.mass i) + (∑ i, (1 - c) * q.mass i) := by - simp [Finset.sum_add_distrib] - _ = c * (∑ i, p.mass i) + (1 - c) * (∑ i, q.mass i) := by - have hp : (∑ i, c * p.mass i) = c * (∑ i, p.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => p.mass i) - (a := c)).symm - have hq : - (∑ i, (1 - c) * q.mass i) = (1 - c) * (∑ i, q.mass i) := by - simpa using - (Finset.mul_sum (s := (Finset.univ : Finset ι)) (f := fun i : ι => q.mass i) - (a := (1 - c))).symm - simp [hp, hq] - _ = c * 1 + (1 - c) * 1 := by - simp [ProbVec.sum_mass] - _ = c + (1 - c) := by - simp - _ = 1 := by - simpa using (add_tsub_cancel_of_le hc) - } - -@[simp] lemma mix_mass (c : NNReal) (hc : c ≤ 1) (p q : ProbVec ι) (i : ι) : - (mix (ι := ι) c hc p q).mass i = c * p.mass i + (1 - c) * q.mass i := rfl - -end ProbVec - -end Nfp +/-! +Probability vectors. +-/ diff --git a/Nfp/Prob/Basic.lean b/Nfp/Prob/Basic.lean new file mode 100644 index 0000000..586db72 --- /dev/null +++ b/Nfp/Prob/Basic.lean @@ -0,0 +1,39 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Core +public import Mathlib.Data.Fintype.BigOperators + +/-! +Probability vectors on finite types. 
+-/ + +public section + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A probability vector on a finite type. -/ +structure ProbVec (ι : Type u) [Fintype ι] where + /-- Mass assigned to each point. -/ + mass : ι → Mass + /-- Total mass is exactly one. -/ + sum_mass : (∑ i, mass i) = 1 + +attribute [simp] ProbVec.sum_mass + +namespace ProbVec + +variable {ι : Type u} [Fintype ι] + +instance : CoeFun (ProbVec ι) (fun _ => ι → Mass) := ⟨ProbVec.mass⟩ + +end ProbVec + +end Nfp + +end diff --git a/Nfp/Prob/Operations.lean b/Nfp/Prob/Operations.lean new file mode 100644 index 0000000..fbf0f0c --- /dev/null +++ b/Nfp/Prob/Operations.lean @@ -0,0 +1,51 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Prob.Basic +public import Mathlib.Algebra.BigOperators.Ring.Finset + +/-! +Basic constructions on probability vectors. +-/ + +public section + +open scoped BigOperators + +namespace Nfp +namespace ProbVec + +universe u + +variable {ι : Type u} [Fintype ι] + +/-- Factor a constant out of a sum. -/ +private lemma sum_mul_const (a : Mass) (p : ι → Mass) : + (∑ i, a * p i) = a * ∑ i, p i := by + simpa using + (Finset.mul_sum (a := a) (s := (Finset.univ : Finset ι)) + (f := fun i => p i)).symm + +/-- The pure distribution at a single point. -/ +def pure (i0 : ι) [DecidableEq ι] : ProbVec ι := by + refine + { mass := Pi.single i0 (1 : Mass) + sum_mass := ?_ } + simp + +@[simp] theorem mass_pure (i0 i : ι) [DecidableEq ι] : + (pure i0).mass i = if i = i0 then 1 else 0 := by + by_cases h : i = i0 <;> simp [pure, Pi.single, h] + +/-- Convex combination of two probability vectors with weights that sum to one. 
-/ +def mix (a b : Mass) (h : a + b = 1) (p q : ProbVec ι) : ProbVec ι := + { mass := fun i => a * p.mass i + b * q.mass i + sum_mass := by + classical + simp [Finset.sum_add_distrib, sum_mul_const, h] } + +end ProbVec +end Nfp + +end diff --git a/Nfp/Reroute/Heat.lean b/Nfp/Reroute/Heat.lean deleted file mode 100644 index 2b6803b..0000000 --- a/Nfp/Reroute/Heat.lean +++ /dev/null @@ -1,528 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Algebra.BigOperators.Ring.Finset -import Mathlib.Algebra.BigOperators.Field -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Algebra.Order.Ring.Defs -import Mathlib.Algebra.Field.Defs -import Nfp.Prob -import Nfp.Reroute.Partition - -namespace Nfp - -namespace List - -lemma mem_left_of_zip {α β : Type*} : - ∀ {xs : List α} {ys : List β} {a : α} {b : β}, - (a, b) ∈ xs.zip ys → a ∈ xs - | [], _, _, _, h => by cases h - | _ :: _, [], _, _, h => by cases h - | x :: xs, y :: ys, a, b, h => by - have h' : (a, b) = (x, y) ∨ (a, b) ∈ xs.zip ys := by - simpa [List.zip_cons_cons] using h - rcases h' with h | h - · rcases h with ⟨rfl, _⟩ - simp - · exact List.mem_cons.mpr (Or.inr (mem_left_of_zip h)) - -lemma get_mem_zip {α β : Type*} : - ∀ {xs : List α} {ys : List β} {k : Nat} - (hk : k < xs.length) (hy : k < ys.length), - (xs.get ⟨k, hk⟩, ys.get ⟨k, hy⟩) ∈ xs.zip ys - | [], _, k, hk, _ => by - cases hk - | _ :: _, [], k, _, hy => by - cases hy - | x :: xs, y :: ys, k, hk, hy => by - cases k with - | zero => - simp [List.zip_cons_cons] - | succ k => - have hk' : k < xs.length := Nat.lt_of_succ_lt_succ hk - have hy' : k < ys.length := Nat.lt_of_succ_lt_succ hy - have htail := get_mem_zip (xs:=xs) (ys:=ys) hk' hy' - have hpair_eq : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using 
hy⟩) - = (xs.get ⟨k, hk'⟩, ys.get ⟨k, hy'⟩) := by - simp - have htail' : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using hy⟩) - ∈ xs.zip ys := by - simpa [hpair_eq] using htail - have hzip : (x :: xs).zip (y :: ys) = (x, y) :: xs.zip ys := by - simp [List.zip_cons_cons] - have hmem' : - ((x :: xs).get ⟨Nat.succ k, by simpa [List.length_cons] using hk⟩, - (y :: ys).get ⟨Nat.succ k, by simpa [List.length_cons] using hy⟩) - ∈ (x, y) :: xs.zip ys := by - exact List.mem_cons.mpr (Or.inr htail') - simpa [hzip] using hmem' - -end List - -/- -Weighted reroute plans and the induced “heat” probability vector (Stage 2). -The structure couples a reroute plan with per-step weights and enforces the -alignment invariants required to distribute each weight over the incremental -masks. `rerouteHeat` sums the per-block shares, normalizes by the total drop, -and produces a `ProbVec` on `S`. --/ - -open scoped BigOperators -open Finset - -noncomputable section - -section Weighted - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- A reroute plan together with per-step nonnegative weights (logit drops). -/ -structure WeightedReroutePlan (S : Type*) [Fintype S] [DecidableEq S] where - plan : ReroutePlan S - weights : List NNReal - length_eq : weights.length = plan.steps.length - zero_weight_of_empty : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ plan.increments.zip weights → - A.card = 0 → w = 0 - weights_sum_pos : 0 < weights.foldr (fun w acc => w + acc) 0 - -namespace WeightedReroutePlan - -variable (P : WeightedReroutePlan (S := S)) - -/-- Helper: the total step weight (used for normalization). 
-/ -def weightsSum : NNReal := - P.weights.foldr (fun w acc => w + acc) 0 - -@[simp] lemma weightsSum_pos : 0 < P.weightsSum := P.weights_sum_pos - -lemma weightsSum_ne_zero : P.weightsSum ≠ 0 := - ne_of_gt P.weightsSum_pos - -lemma length_eq_increments : - P.weights.length = P.plan.increments.length := by - simpa [P.plan.increments_length] using P.length_eq - -private def share (A : Finset S) (w : NNReal) (i : S) : NNReal := - if A.card = 0 then 0 else if i ∈ A then w / (A.card : NNReal) else 0 - -lemma sum_share (A : Finset S) (w : NNReal) : - (∑ i : S, share (S:=S) A w i) = if A.card = 0 then 0 else w := by - classical - by_cases hA : A.card = 0 - · simp [share, hA] - · have hcard : (A.card : NNReal) ≠ 0 := by - exact_mod_cast hA - have hshare_zero : ∀ i ∉ A, share (S:=S) A w i = 0 := by - intro i hi - simp [share, hA, hi] - have hsplit : - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ A, share (S:=S) A w i := by - classical - let U : Finset S := Finset.univ - have hdisj : Disjoint A (U \ A) := by - refine Finset.disjoint_left.mpr ?_ - intro x hxA hxDiff - exact (Finset.mem_sdiff.mp hxDiff).2 hxA - have hcover : A ∪ (U \ A) = U := by - ext x - by_cases hx : x ∈ A <;> simp [U, hx] - calc - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ U, share (S:=S) A w i := by simp [U] - _ = ∑ i ∈ A ∪ (U \ A), share (S:=S) A w i := by - simp [U, hcover] - _ = (∑ i ∈ A, share (S:=S) A w i) + - ∑ i ∈ U \ A, share (S:=S) A w i := by - simpa using - (Finset.sum_union hdisj - (f := fun i => share (S:=S) A w i)) - _ = (∑ i ∈ A, share (S:=S) A w i) + 0 := by - refine congrArg (fun t => (∑ i ∈ A, share (S:=S) A w i) + t) ?_ - classical - refine Finset.sum_eq_zero ?_ - intro i hi - have hiA : i ∉ A := (Finset.mem_sdiff.mp hi).2 - simpa [U] using hshare_zero i hiA - _ = ∑ i ∈ A, share (S:=S) A w i := by simp - have hsum' : - (∑ i : S, share (S:=S) A w i) - = (A.card : NNReal) * (w / (A.card : NNReal)) := by - have hconst : - (∑ i ∈ A, share (S:=S) A w i) - = (A.card : NNReal) * (w / (A.card : 
NNReal)) := by - classical - simp [share, hA] - exact hsplit.trans hconst - have hratio : - (A.card : NNReal) * (w / (A.card : NNReal)) = w := - mul_div_cancel₀ _ hcard - have := hsum'.trans hratio - simpa [hA] using this - -lemma sum_share_self (A : Finset S) (w : NNReal) : - (∑ i ∈ A, share (S:=S) A w i) = if A.card = 0 then 0 else w := by - classical - have hshare := sum_share (S:=S) A w - by_cases hA : A.card = 0 - · have hA' : A = ∅ := Finset.card_eq_zero.mp hA - simp [share, hA'] - · have hzero : - (∑ i ∈ ((Finset.univ : Finset S) \ A), share (S:=S) A w i) = 0 := by - refine Finset.sum_eq_zero ?_ - intro i hi - have hiA : i ∉ A := (Finset.mem_sdiff.mp hi).2 - simp [share, hiA] - have hdisj : - Disjoint A ((Finset.univ : Finset S) \ A) := by - refine Finset.disjoint_left.mpr ?_ - intro i hiA hiDiff - exact (Finset.mem_sdiff.mp hiDiff).2 hiA - have hcover : - A ∪ ((Finset.univ : Finset S) \ A) = (Finset.univ : Finset S) := by - ext i - by_cases hi : i ∈ A <;> simp [hi] - have hsplit := - Finset.sum_union hdisj (f := fun i => share (S:=S) A w i) - have hx_univ : - (∑ i : S, share (S:=S) A w i) - = ∑ i ∈ (Finset.univ : Finset S), share (S:=S) A w i := by - simp - have hxA : - (∑ i ∈ (Finset.univ : Finset S), share (S:=S) A w i) - = (∑ i ∈ A, share (S:=S) A w i) := by - simpa [hcover, hzero] using hsplit - have hx := hx_univ.trans hxA - have hx' := hx.symm - simpa using hx'.trans hshare - -omit [Fintype S] in -lemma sum_share_of_disjoint {A B : Finset S} (w : NNReal) - (hdisj : Disjoint A B) : - (∑ i ∈ A, share (S:=S) B w i) = 0 := by - classical - refine Finset.sum_eq_zero ?_ - intro i hi - have hiB : i ∉ B := by - intro hmem - exact (Finset.disjoint_left.mp hdisj) hi hmem - simp [share, hiB] - -private def heatRawAux : - ∀ (parts : List (Finset S)) (weights : List NNReal), - weights.length = parts.length → S → NNReal - | [], [], _, _ => 0 - | A :: parts, w :: weights, hlen, i => - let htail : weights.length = parts.length := by - have hlen' : - Nat.succ 
weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen' - share (S:=S) A w i + heatRawAux parts weights htail i - | _, _, _, _ => 0 - -private lemma sum_heatRawAux (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hzero : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ parts.zip weights → A.card = 0 → w = 0) : - (∑ i : S, heatRawAux (S:=S) parts weights hlen i) - = weights.foldr (fun w acc => w + acc) 0 := by - classical - revert weights hlen - induction parts with - | nil => - intro weights hlen hzero - cases weights with - | nil => - simp [heatRawAux] - | cons w weights => - cases hlen - | cons A parts ih => - intro weights hlen hzero - cases weights with - | nil => - cases hlen - | cons w weights => - have hlen' : - weights.length = parts.length := by - have hlen'' : - Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - have hzero_head : - A.card = 0 → w = 0 := by - intro hcard - have hpair : - (A, w) ∈ (A :: parts).zip (w :: weights) := by - simp [List.zip_cons_cons] - exact hzero hpair hcard - have hzero_tail : - ∀ {B : Finset S} {w' : NNReal}, - (B, w') ∈ parts.zip weights → B.card = 0 → w' = 0 := by - intro B w' hpair hcard - have hpair' : - (B, w') ∈ (A :: parts).zip (w :: weights) := by - have : (B, w') ∈ (A, w) :: parts.zip weights := - List.mem_cons.mpr (Or.inr hpair) - simpa [List.zip_cons_cons] using this - exact hzero hpair' hcard - have hsum_tail := - ih weights hlen' hzero_tail - have hsum_head : - (∑ i : S, share (S:=S) A w i) = w := by - have hshare := sum_share (S:=S) A w - by_cases hcard : A.card = 0 - · have hw0 : w = 0 := hzero_head hcard - have : (∑ i : S, share (S:=S) A w i) = 0 := by - simp [share, hcard] - simpa [hshare, hcard, hw0, this] - · simp [hshare, hcard] - have hsum_current : - (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) - = w + ∑ i : S, heatRawAux 
(S:=S) parts weights hlen' i := by - simp [heatRawAux, Finset.sum_add_distrib, hsum_head] - have hfold : - (w :: weights).foldr (fun w acc => w + acc) 0 - = w + weights.foldr (fun w acc => w + acc) 0 := by - simp - calc - (∑ i : S, heatRawAux (S:=S) (A :: parts) (w :: weights) hlen i) - = w + ∑ i : S, heatRawAux (S:=S) parts weights hlen' i := - hsum_current - _ = w + weights.foldr (fun w acc => w + acc) 0 := by - simp [hsum_tail] - _ = (w :: weights).foldr (fun w acc => w + acc) 0 := - by simp [hfold] - -omit [Fintype S] in -private lemma sum_heatRawAux_disjoint - (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - {A : Finset S} - (hdisj : ∀ B ∈ parts, Disjoint A B) : - A.sum (fun i => heatRawAux (S:=S) parts weights hlen i) = 0 := by - classical - revert weights hlen hpair hdisj - induction parts generalizing A with - | nil => - intro weights hlen _ _ - cases weights with - | nil => simp [heatRawAux] - | cons _ _ => cases hlen - | cons B parts ih => - intro weights hlen hpair hdisj - cases weights with - | nil => cases hlen - | cons w weights => - have hlen' : weights.length = parts.length := by - have hlen'' : Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - rcases List.pairwise_cons.mp hpair with ⟨hB, htail⟩ - have hdisj_tail : ∀ C ∈ parts, Disjoint A C := fun C hC => hdisj C (by simp [hC]) - have hdisj_head : Disjoint A B := hdisj B (by simp) - have hshare_zero : - A.sum (fun i => share (S:=S) B w i) = 0 := - sum_share_of_disjoint (S:=S) (A:=A) (B:=B) (w:=w) hdisj_head - have htail_zero := ih weights hlen' htail hdisj_tail - simp [heatRawAux, Finset.sum_add_distrib, hshare_zero, htail_zero] - -private lemma sum_heatRawAux_mem_zip - (parts : List (Finset S)) (weights : List NNReal) - (hlen : weights.length = parts.length) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - (hzero : - ∀ {B : 
Finset S} {w : NNReal}, - (B, w) ∈ parts.zip weights → B.card = 0 → w = 0) - {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ parts.zip weights) : - A.sum (fun i => heatRawAux (S:=S) parts weights hlen i) = w := by - classical - revert weights hlen hpair hzero hmem - induction parts generalizing A w with - | nil => - intro weights hlen _ _ hmem - cases weights with - | nil => cases hmem - | cons _ _ => cases hlen - | cons B parts ih => - intro weights hlen hpair hzero hmem - cases weights with - | nil => cases hlen - | cons z weights => - have hlen' : weights.length = parts.length := by - have hlen'' : Nat.succ weights.length = Nat.succ parts.length := by - simpa [List.length_cons] using hlen - exact Nat.succ.inj hlen'' - rcases List.pairwise_cons.mp hpair with ⟨hB, htail⟩ - have hzero_head : B.card = 0 → z = 0 := by - intro hcard - have : (B, z) ∈ (B :: parts).zip (z :: weights) := by - simp [List.zip_cons_cons] - exact hzero this hcard - have hzero_tail : - ∀ {C : Finset S} {w' : NNReal}, - (C, w') ∈ parts.zip weights → C.card = 0 → w' = 0 := by - intro C w' hC hcard - have : (C, w') ∈ (B :: parts).zip (z :: weights) := by - have : (C, w') ∈ (B, z) :: parts.zip weights := - List.mem_cons.mpr (Or.inr hC) - simpa [List.zip_cons_cons] using this - exact hzero this hcard - have hmem_cons : (A, w) ∈ (B, z) :: parts.zip weights := by - simpa [List.zip_cons_cons] using hmem - rcases List.mem_cons.mp hmem_cons with hhead | htail_mem - · cases hhead - have hshare := sum_share_self (S:=S) (A:=B) (w:=w) - have htail_zero : - B.sum (fun i => heatRawAux (S:=S) parts weights hlen' i) = 0 := by - have hdisjB : ∀ C ∈ parts, Disjoint B C := fun C hC => hB C hC - exact sum_heatRawAux_disjoint (S:=S) - parts weights hlen' htail hdisjB - set rest := fun i => heatRawAux (S:=S) parts weights hlen' i - have hsum_split : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = B.sum (fun i => share (S:=S) B w i) + B.sum rest := by - have := Finset.sum_add_distrib - 
(s:=B) (f:=fun i => share (S:=S) B w i) (g:=rest) - simpa [rest, heatRawAux] - using this - by_cases hcard : B.card = 0 - · have hw : w = 0 := hzero_head hcard - have hshare_zero : - B.sum (fun i => share (S:=S) B w i) = 0 := by - simpa [hcard] using hshare - have hBempty : B = (∅ : Finset S) := Finset.card_eq_zero.mp hcard - have hx : - B.sum (fun i => share (S:=S) B w i) + B.sum rest = 0 := by - have hxshare := hshare_zero - have hxrest := htail_zero - rw [hxshare, hxrest] - simp - have hs : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = 0 := by - simpa [hsum_split] using hx - simpa [hw] using hs - · have hshare_eq : - B.sum (fun i => share (S:=S) B w i) = w := by - simpa [hcard] using hshare - have hx : - B.sum (fun i => share (S:=S) B w i) + B.sum rest = w := by - have hxrest := htail_zero - rw [hshare_eq, hxrest] - simp - have hs : - B.sum (fun i => heatRawAux (S:=S) (B :: parts) (w :: weights) hlen i) - = w := by - simpa [hsum_split] using hx - exact hs - · have htail_result := - ih weights hlen' htail hzero_tail htail_mem - have hA_mem : A ∈ parts := List.mem_left_of_zip htail_mem - have hdisjBA : Disjoint B A := hB A hA_mem - have hdisjAB : Disjoint A B := by - refine Finset.disjoint_left.mpr ?_ - intro i hiA hiB - exact (Finset.disjoint_left.mp hdisjBA) hiB hiA - have hshare_zero : - A.sum (fun i => share (S:=S) B z i) = 0 := - sum_share_of_disjoint (S:=S) (A:=A) (B:=B) (w:=z) hdisjAB - set rest := fun i => heatRawAux (S:=S) parts weights hlen' i - have hsum_split : - A.sum (fun i => heatRawAux (S:=S) (B :: parts) (z :: weights) hlen i) - = A.sum (fun i => share (S:=S) B z i) + A.sum rest := by - have := Finset.sum_add_distrib - (s:=A) (f:=fun i => share (S:=S) B z i) (g:=rest) - simpa [rest, heatRawAux] - using this - have hx : - A.sum (fun i => share (S:=S) B z i) + A.sum rest = w := by - rw [hshare_zero, htail_result] - simp - have hs : - A.sum (fun i => heatRawAux (S:=S) (B :: parts) (z :: weights) hlen i) - = w := by - 
simpa [hsum_split] using hx - exact hs - -def heatRaw (i : S) : NNReal := - heatRawAux (S:=S) P.plan.increments P.weights - P.length_eq_increments i - -lemma sum_heatRaw : - (∑ i : S, P.heatRaw i) = P.weightsSum := by - classical - have hzero : - ∀ {A : Finset S} {w : NNReal}, - (A, w) ∈ P.plan.increments.zip P.weights → - A.card = 0 → w = 0 := - fun {A} {w} => P.zero_weight_of_empty - exact - sum_heatRawAux (S:=S) P.plan.increments P.weights - P.length_eq_increments hzero - -noncomputable def rerouteHeat : ProbVec S := - { - mass := fun i => P.heatRaw i / P.weightsSum, - norm_one := by - classical - have hdiv : - (∑ i : S, P.heatRaw i / P.weightsSum) - = (∑ i : S, P.heatRaw i) / P.weightsSum := - (Finset.sum_div (s:=Finset.univ) (f:=fun i => P.heatRaw i) - (a:=P.weightsSum)).symm - have hsum := P.sum_heatRaw - have hne := P.weightsSum_ne_zero - change (∑ i : S, P.heatRaw i / P.weightsSum) = 1 - simp [hdiv, hsum, hne] - } - -@[simp] lemma rerouteHeat_mass (i : S) : - (P.rerouteHeat).mass i = P.heatRaw i / P.weightsSum := rfl - -lemma heatRaw_sum_increment {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - A.sum (fun i => P.heatRaw i) = w := by - classical - have hpair := ReroutePlan.increments_pairwise (P:=P.plan) - have hzero : - ∀ {B : Finset S} {w' : NNReal}, - (B, w') ∈ P.plan.increments.zip P.weights → B.card = 0 → w' = 0 := - fun {B} {w'} => P.zero_weight_of_empty - simpa [heatRaw] using - sum_heatRawAux_mem_zip (S:=S) - P.plan.increments P.weights P.length_eq_increments - hpair hzero hmem - -lemma rerouteHeat_sum_increment {A : Finset S} {w : NNReal} - (hmem : (A, w) ∈ P.plan.increments.zip P.weights) : - A.sum (fun i => (P.rerouteHeat).mass i) = w / P.weightsSum := by - classical - have hsum := heatRaw_sum_increment (P:=P) hmem - have hdiv : - A.sum (fun i => P.heatRaw i / P.weightsSum) - = (A.sum fun i => P.heatRaw i) / P.weightsSum := by - simpa using - (Finset.sum_div (s:=A) (f:=fun i => P.heatRaw i) - 
(a:=P.weightsSum)).symm - simp [rerouteHeat_mass, hdiv, hsum] - -end WeightedReroutePlan - -end Weighted - -end - -end Nfp diff --git a/Nfp/Reroute/Partition.lean b/Nfp/Reroute/Partition.lean deleted file mode 100644 index 4fc7822..0000000 --- a/Nfp/Reroute/Partition.lean +++ /dev/null @@ -1,410 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Finset.Basic -import Mathlib.Data.List.Basic -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Tactic.FinCases - -/-! -Partitions and reroute plans. This file hosts the element-level lemmas -needed by the reroute weighting development (Stage 1). It provides: - -* a recursive definition of `unionParts` with rewrite-friendly lemmas -* helpers showing membership existence/uniqueness under disjointness -* an intrinsic notion of `ReroutePlan` -* incremental masks (`increments`) together with coverage and disjointness -* per-element sum lemmas that decompose along disjoint parts --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -section Partitions - -variable {S : Type*} [DecidableEq S] - -/-- Union a list of disjoint components. The recursion is `simp`-friendly so -Lean can peel the head element while reasoning about membership. 
-/ -@[simp] def unionParts : List (Finset S) → Finset S - | [] => ∅ - | A :: parts => A ∪ unionParts parts - -@[simp] lemma unionParts_nil : unionParts ([] : List (Finset S)) = (∅ : Finset S) := rfl - -@[simp] lemma unionParts_cons (A : Finset S) (parts : List (Finset S)) : - unionParts (A :: parts) = A ∪ unionParts parts := rfl - -@[simp] lemma unionParts_singleton (A : Finset S) : - unionParts [A] = A := by simp - -lemma unionParts_cons_mem {A : Finset S} {parts : List (Finset S)} {i : S} : - i ∈ unionParts (A :: parts) ↔ i ∈ A ∨ i ∈ unionParts parts := by - classical - simp [unionParts, Finset.mem_union] - -lemma mem_unionParts_iff_exists_mem (parts : List (Finset S)) (i : S) : - i ∈ unionParts parts ↔ ∃ A ∈ parts, i ∈ A := by - classical - induction parts with - | nil => - simp - | cons A parts ih => - simp [unionParts, ih, List.mem_cons] - -lemma mem_unionParts_of_mem {parts : List (Finset S)} {A : Finset S} (hA : A ∈ parts) - {i : S} (hi : i ∈ A) : i ∈ unionParts parts := by - classical - exact (mem_unionParts_iff_exists_mem (parts:=parts) (i:=i)).2 ⟨A, hA, hi⟩ - -lemma disjoint_unionParts (A : Finset S) (parts : List (Finset S)) - (h : ∀ B ∈ parts, Disjoint A B) : - Disjoint A (unionParts parts) := by - classical - induction parts with - | nil => - simp - | cons B parts ih => - have hAB : Disjoint A B := h B (by simp) - have htail : ∀ C ∈ parts, Disjoint A C := by - intro C hC - exact h C (by simp [hC]) - have hArest := ih htail - refine Finset.disjoint_left.mpr ?_ - intro x hxA hxUnion - have : x ∈ B ∨ x ∈ unionParts parts := by - simpa [unionParts] using hxUnion - cases this with - | inl hxB => - exact (Finset.disjoint_left.mp hAB) hxA hxB - | inr hxRest => - exact (Finset.disjoint_left.mp hArest) hxA hxRest - -example (A B : Finset S) (i : S) : - i ∈ unionParts [A, B] ↔ i ∈ A ∨ i ∈ B := by - classical - simp [unionParts] - -end Partitions - -lemma pairwise_disjoint_cons_elim {S : Type*} {A : Finset S} {parts : List (Finset S)} - (hpair : (A :: 
parts).Pairwise (fun B C => Disjoint B C)) : - (∀ B ∈ parts, Disjoint A B) ∧ - parts.Pairwise (fun B C => Disjoint B C) := - List.pairwise_cons.mp hpair - -lemma mem_unionParts_unique {S : Type*} (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) - {A B : Finset S} {i : S} - (hA : A ∈ parts) (hB : B ∈ parts) - (hiA : i ∈ A) (hiB : i ∈ B) : - A = B := by - classical - revert A B - induction parts with - | nil => - intro A B hA hB _ _ - cases hA - | cons C parts ih => - intro A B hA hB hiA hiB - rcases List.pairwise_cons.mp hpair with ⟨hC, htail⟩ - rcases List.mem_cons.mp hA with hAC | hA' - · rcases List.mem_cons.mp hB with hBC | hB' - · cases hAC; cases hBC; exact rfl - · have hiC : i ∈ C := by simpa [hAC] using hiA - have hCB : Disjoint C B := hC B hB' - exact ((Finset.disjoint_left.mp hCB) hiC hiB).elim - · rcases List.mem_cons.mp hB with hBC | hB' - · have hiC : i ∈ C := by simpa [hBC] using hiB - have hCA : Disjoint C A := hC A hA' - exact - ((Finset.disjoint_left.mp hCA) hiC - (by simpa [hA'] using hiA)).elim - · exact ih htail hA' hB' hiA hiB - -section PartitionsFinite - -variable {S : Type*} [Fintype S] [DecidableEq S] - -lemma mem_unionParts_of_univ (parts : List (Finset S)) - (hcover : unionParts parts = (Finset.univ : Finset S)) - (i : S) : i ∈ unionParts parts := by - classical - simp [hcover] - -end PartitionsFinite - -section SumLemmas - -variable {S : Type*} [DecidableEq S] - -lemma sum_unionParts_eq (m : S → NNReal) (parts : List (Finset S)) - (hpair : parts.Pairwise (fun A B => Disjoint A B)) : - (unionParts parts).sum m = - parts.foldr (fun A acc => A.sum m + acc) 0 := by - classical - induction parts with - | nil => - simp [unionParts] - | cons A parts ih => - rcases List.pairwise_cons.mp hpair with ⟨hA, htail⟩ - have hdisj : Disjoint A (unionParts parts) := - disjoint_unionParts (A:=A) (parts:=parts) (fun B hB => hA _ (by simpa using hB)) - have hsum := Finset.sum_union (s₁:=A) (s₂:=unionParts parts) (f:=m) hdisj - 
have hfold : - (A :: parts).foldr (fun B acc => B.sum m + acc) 0 - = A.sum m + parts.foldr (fun B acc => B.sum m + acc) 0 := by - simp [List.foldr] - calc - (unionParts (A :: parts)).sum m - = (A ∪ unionParts parts).sum m := rfl - _ = A.sum m + (unionParts parts).sum m := hsum - _ = A.sum m + parts.foldr (fun B acc => B.sum m + acc) 0 := by - simp [ih htail] - _ = (A :: parts).foldr (fun B acc => B.sum m + acc) 0 := by - simp [hfold] - -end SumLemmas - -section ReroutePlanStruct - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- A reroute plan is a list of disjoint masks that cover the whole space. -/ -structure ReroutePlan (S : Type*) [Fintype S] [DecidableEq S] where - steps : List (Finset S) - pairwise_disjoint : steps.Pairwise (fun A B => Disjoint A B) - covers_univ : unionParts steps = (Finset.univ : Finset S) - -namespace ReroutePlan - -variable (S) - -/-- Re-export the list of masks. Useful when inferring implicit arguments. -/ -def masks (P : ReroutePlan (S := S)) : List (Finset S) := P.steps - -variable {S} - -lemma mem_cover (P : ReroutePlan (S := S)) (i : S) : - i ∈ unionParts P.steps := - mem_unionParts_of_univ (parts:=P.steps) P.covers_univ i - -lemma exists_block (P : ReroutePlan (S := S)) (i : S) : - ∃ A ∈ P.steps, i ∈ A := - (mem_unionParts_iff_exists_mem (parts:=P.steps) (i:=i)).1 (P.mem_cover i) - -lemma unique_block (P : ReroutePlan (S := S)) (i : S) : - ∃! 
A, A ∈ P.steps ∧ i ∈ A := by - classical - obtain ⟨A, hA, hiA⟩ := P.exists_block i - refine ⟨A, ⟨hA, hiA⟩, ?_⟩ - intro B hB - apply mem_unionParts_unique (parts:=P.steps) P.pairwise_disjoint <;> - tauto - -end ReroutePlan - -end ReroutePlanStruct - -section Increments - -namespace ReroutePlan - -section Aux - -variable {S : Type*} [DecidableEq S] - -private def incrementsAux : List (Finset S) → Finset S → List (Finset S) - | [], _ => [] - | A :: parts, seen => (A \ seen) :: incrementsAux parts (seen ∪ A) - -@[simp] lemma incrementsAux_nil (seen : Finset S) : - incrementsAux ([] : List (Finset S)) seen = [] := rfl - -@[simp] lemma incrementsAux_cons (A : Finset S) (parts : List (Finset S)) - (seen : Finset S) : - incrementsAux (A :: parts) seen = - (A \ seen) :: incrementsAux parts (seen ∪ A) := rfl - -lemma incrementsAux_length (parts : List (Finset S)) (seen : Finset S) : - (incrementsAux parts seen).length = parts.length := by - classical - induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - simp [incrementsAux, ih] - -private lemma incrementsAux_mem_disjoint_seen : - ∀ {parts : List (Finset S)} {seen B : Finset S}, - B ∈ incrementsAux parts seen → Disjoint B seen := by - classical - intro parts - induction parts with - | nil => - intro seen B hB - cases hB - | cons A parts ih => - intro seen B hB - dsimp [incrementsAux] at hB - rcases List.mem_cons.mp hB with hHead | hTail - · subst hHead - refine Finset.disjoint_left.mpr ?_ - intro x hxB hxSeen - exact (Finset.mem_sdiff.mp hxB).2 hxSeen - · have h := ih (seen := seen ∪ A) (B := B) hTail - refine Finset.disjoint_left.mpr ?_ - intro x hxB hxSeen - have hxUnion : x ∈ seen ∪ A := by - have : x ∈ seen := hxSeen - exact (Finset.mem_union.mpr (Or.inl this)) - exact (Finset.disjoint_left.mp h) hxB hxUnion - -private lemma incrementsAux_pairwise (parts : List (Finset S)) (seen : Finset S) : - (incrementsAux parts seen).Pairwise (fun A B => Disjoint A B) := by - classical - 
induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - refine List.pairwise_cons.2 ?_ - constructor - · intro B hB - have hDisjoint := - incrementsAux_mem_disjoint_seen (parts:=parts) (seen:=seen ∪ A) (B:=B) hB - refine Finset.disjoint_left.mpr ?_ - intro x hxHead hxB - aesop (add simp [Finset.disjoint_left, incrementsAux]) - · simpa using ih (seen ∪ A) - -private lemma sdiff_union_left (A B C : Finset S) : - (A \ C) ∪ (B \ (C ∪ A)) = (A ∪ B) \ C := by - classical - ext i - constructor - · intro hx - rcases Finset.mem_union.mp hx with hxA | hxB - · rcases Finset.mem_sdiff.mp hxA with ⟨hiA, hiC⟩ - exact Finset.mem_sdiff.mpr ⟨Finset.mem_union.mpr (Or.inl hiA), hiC⟩ - · rcases Finset.mem_sdiff.mp hxB with ⟨hiB, hiCompl⟩ - have hiC : i ∉ C := by - exact fun hCi => hiCompl (Finset.mem_union.mpr (Or.inl hCi)) - exact Finset.mem_sdiff.mpr ⟨Finset.mem_union.mpr (Or.inr hiB), hiC⟩ - · intro hx - rcases Finset.mem_sdiff.mp hx with ⟨hiUnion, hiC⟩ - by_cases hiA : i ∈ A - · exact Finset.mem_union.mpr (Or.inl (Finset.mem_sdiff.mpr ⟨hiA, hiC⟩)) - · have hiB : i ∈ B := by - have : i ∈ A ∨ i ∈ B := (Finset.mem_union.mp hiUnion) - exact this.resolve_left hiA - have hiNot : i ∉ C ∪ A := by - intro hmem - rcases Finset.mem_union.mp hmem with hC | hA - · exact hiC hC - · exact hiA hA - exact Finset.mem_union.mpr (Or.inr (Finset.mem_sdiff.mpr ⟨hiB, hiNot⟩)) - -private lemma unionParts_incrementsAux (parts : List (Finset S)) (seen : Finset S) : - unionParts (incrementsAux parts seen) = unionParts parts \ seen := by - classical - induction parts generalizing seen with - | nil => - simp [incrementsAux] - | cons A parts ih => - simp [incrementsAux, unionParts, ih, sdiff_union_left] - -@[simp] lemma unionParts_incrementsAux_empty (parts : List (Finset S)) : - unionParts (incrementsAux parts (∅ : Finset S)) = unionParts parts := by - classical - have h := - unionParts_incrementsAux (parts:=parts) (seen:=(∅ : Finset S)) - have hzero : unionParts 
parts \ (∅ : Finset S) = unionParts parts := by simp - exact h.trans hzero - -end Aux - -section WithPlan - -variable {S : Type*} [Fintype S] [DecidableEq S] - -/-- Incremental “delta” masks: each block removes only the new elements from the -corresponding reroute step. -/ -def increments (P : ReroutePlan (S := S)) : List (Finset S) := - incrementsAux P.steps ∅ - -/-- The incremental masks list has the same length as `steps`. -/ -lemma increments_length (P : ReroutePlan (S := S)) : - P.increments.length = P.steps.length := by - classical - simp [increments, incrementsAux_length] - -lemma increments_pairwise (P : ReroutePlan (S := S)) : - P.increments.Pairwise (fun A B => Disjoint A B) := by - classical - unfold increments - exact - incrementsAux_pairwise (parts:=P.steps) (seen:=(∅ : Finset S)) - -@[simp] lemma unionParts_increments (P : ReroutePlan (S := S)) : - unionParts P.increments = unionParts P.steps := by - classical - have h := - unionParts_incrementsAux (parts:=P.steps) (seen:=(∅ : Finset S)) - have hzero : unionParts P.steps \ (∅ : Finset S) = unionParts P.steps := by simp - unfold increments - exact h.trans hzero - -lemma sum_over_increments (P : ReroutePlan (S := S)) (m : S → NNReal) : - (unionParts P.steps).sum m = - P.increments.foldr (fun A acc => A.sum m + acc) 0 := by - classical - have hpair := increments_pairwise (P:=P) - have hsum := - sum_unionParts_eq (m:=m) (parts:=P.increments) hpair - simpa [unionParts_increments (P:=P)] using hsum - -end WithPlan - -end ReroutePlan - -end Increments - -section Examples - -open ReroutePlan - -@[simp] def fin2Plan : ReroutePlan (Fin 2) := -by - classical - refine { - steps := [{0}, {1}], - pairwise_disjoint := ?_, - covers_univ := ?_ } - · refine List.pairwise_cons.2 ?_ - refine ⟨?_, ?_⟩ - · intro B hB - have : B = ({1} : Finset (Fin 2)) := by simpa using hB - subst this - simp - · simp - · ext i - fin_cases i <;> simp [unionParts] - -example : - (unionParts (fin2Plan.steps)).sum - (fun i => if i = 0 then (2 
: NNReal) else 5) = - fin2Plan.increments.foldr - (fun A acc => - A.sum (fun i => if i = 0 then (2 : NNReal) else 5) + acc) - 0 := by - classical - exact - (ReroutePlan.sum_over_increments - (P:=fin2Plan) (m:=fun i => if i = 0 then (2 : NNReal) else 5)) - -end Examples - -end Nfp diff --git a/Nfp/SignedMixer.lean b/Nfp/SignedMixer.lean deleted file mode 100644 index 9d438a3..0000000 --- a/Nfp/SignedMixer.lean +++ /dev/null @@ -1,638 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Real.Basic -import Mathlib.Data.Real.Sign -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.Group.Defs -import Mathlib.Order.MinMax -import Nfp.Prob -import Nfp.Mixer - -/-! -# Signed Mixers and Affine Transformations - -This module extends the mixer framework to support real neural network operations -that involve negative weights and biases. While the original `Mixer` type captures -attention (row-stochastic, nonnegative), real networks also use: - -1. **Signed linear maps**: Value projections, MLPs with negative weights -2. **Affine maps**: Operations with bias terms -3. **Decompositions**: Splitting signed maps into positive/negative parts - -## Key insight for interpretation - -For attribution, we care about *how much* each input contributes to each output. -With signed weights, a negative contribution means "increasing the input decreases -the output." The framework here tracks both positive and negative contributions -separately, enabling precise attribution analysis. 
- -## Main definitions - -* `SignedMixer`: Linear map with real (possibly negative) weights -* `SignedMixer.positivePart`, `negativePart`: Decomposition into nonnegative parts -* `AffineMixer`: Signed mixer plus bias term -* `SignedMixer.toInfluence`: Convert to influence matrix for attribution - -## References - -This connects to: -- Integrated Gradients (uses signed gradients for attribution) -- SHAP values (can be positive or negative) -- Attention with negative weights (some transformer variants) --/ - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-! ## Signed Mixer -/ - -/-- A signed mixer: a linear map between finite types with real weights. -Unlike `Mixer`, weights can be negative and rows need not sum to 1. - -This captures operations like: -- Value projections in attention -- MLP layers -- Any linear layer in a neural network -/ -structure SignedMixer (S T : Type*) [Fintype S] [Fintype T] where - /-- The weight matrix. `w i j` is the weight from input `i` to output `j`. -/ - w : S → T → ℝ - -namespace SignedMixer - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Extensionality for signed mixers. -/ -@[ext] -theorem ext {M N : SignedMixer S T} (h : ∀ i j, M.w i j = N.w i j) : M = N := by - cases M; cases N; simp only [mk.injEq]; funext i j; exact h i j - -/-- The zero signed mixer. -/ -instance : Zero (SignedMixer S T) where - zero := ⟨fun _ _ => 0⟩ - -/-- Addition of signed mixers (pointwise). -/ -instance : Add (SignedMixer S T) where - add M N := ⟨fun i j => M.w i j + N.w i j⟩ - -/-- Scalar multiplication for signed mixers. -/ -instance : SMul ℝ (SignedMixer S T) where - smul c M := ⟨fun i j => c * M.w i j⟩ - -/-- Negation of signed mixers. -/ -instance : Neg (SignedMixer S T) where - neg M := ⟨fun i j => -M.w i j⟩ - -/-- Subtraction of signed mixers. 
-/ -instance : Sub (SignedMixer S T) where - sub M N := ⟨fun i j => M.w i j - N.w i j⟩ - -@[simp] lemma zero_w (i : S) (j : T) : (0 : SignedMixer S T).w i j = 0 := rfl -@[simp] lemma add_w (M N : SignedMixer S T) (i : S) (j : T) : - (M + N).w i j = M.w i j + N.w i j := rfl -@[simp] lemma smul_w (c : ℝ) (M : SignedMixer S T) (i : S) (j : T) : - (c • M).w i j = c * M.w i j := rfl -@[simp] lemma neg_w (M : SignedMixer S T) (i : S) (j : T) : (-M).w i j = -M.w i j := rfl -@[simp] lemma sub_w (M N : SignedMixer S T) (i : S) (j : T) : - (M - N).w i j = M.w i j - N.w i j := rfl - -/-- The identity signed mixer. -/ -noncomputable def identity [DecidableEq S] : SignedMixer S S where - w := fun i j => if i = j then 1 else 0 - -@[simp] lemma identity_diag [DecidableEq S] (i : S) : identity.w i i = 1 := by simp [identity] - -@[simp] lemma identity_off_diag [DecidableEq S] {i j : S} (h : i ≠ j) : - identity.w i j = 0 := by simp [identity, h] - -/-- Composition of signed mixers (matrix multiplication). -/ -noncomputable def comp (M : SignedMixer S T) (N : SignedMixer T U) : SignedMixer S U where - w := fun i k => ∑ j, M.w i j * N.w j k - -@[simp] lemma comp_w (M : SignedMixer S T) (N : SignedMixer T U) (i : S) (k : U) : - (M.comp N).w i k = ∑ j, M.w i j * N.w j k := rfl - -/-- Identity is a left unit for composition. -/ -@[simp] theorem identity_comp [DecidableEq S] (M : SignedMixer S T) : - identity.comp M = M := by - ext i j - simp only [comp_w, identity] - simp [Finset.sum_ite_eq, Finset.mem_univ] - -/-- Identity is a right unit for composition. -/ -@[simp] theorem comp_identity [DecidableEq T] (M : SignedMixer S T) : - M.comp identity = M := by - ext i j - simp only [comp_w, identity] - simp [Finset.sum_ite_eq', Finset.mem_univ] - -/-- Composition is associative. 
-/ -theorem comp_assoc {V : Type*} [Fintype V] - (M : SignedMixer S T) (N : SignedMixer T U) (P : SignedMixer U V) : - (M.comp N).comp P = M.comp (N.comp P) := by - ext i l - simp only [comp_w] - -- LHS: ∑_k (∑_j M_ij * N_jk) * P_kl - -- RHS: ∑_j M_ij * (∑_k N_jk * P_kl) - conv_lhs => - arg 2 - ext k - rw [Finset.sum_mul] - conv_rhs => - arg 2 - ext j - rw [Finset.mul_sum] - rw [Finset.sum_comm] - congr 1 - ext j - congr 1 - ext k - ring - -/-- Composition distributes over addition on the left. -/ -theorem comp_add_left (M₁ M₂ : SignedMixer S T) (N : SignedMixer T U) : - (M₁ + M₂).comp N = M₁.comp N + M₂.comp N := by - ext i k - simp [comp_w, add_w, add_mul, Finset.sum_add_distrib] - -/-- Composition distributes over addition on the right. -/ -theorem comp_add_right (M : SignedMixer S T) (N₁ N₂ : SignedMixer T U) : - M.comp (N₁ + N₂) = M.comp N₁ + M.comp N₂ := by - ext i k - simp [comp_w, add_w, mul_add, Finset.sum_add_distrib] - -/-! ## Decomposition into positive and negative parts -/ - -/-- The positive part of a signed mixer: max(w, 0) for each weight. -/ -noncomputable def positivePart (M : SignedMixer S T) : SignedMixer S T where - w := fun i j => max (M.w i j) 0 - -/-- The negative part of a signed mixer: max(-w, 0) for each weight. -Note: This is nonnegative; it represents the magnitude of negative weights. -/ -noncomputable def negativePart (M : SignedMixer S T) : SignedMixer S T where - w := fun i j => max (-M.w i j) 0 - -@[simp] lemma positivePart_w (M : SignedMixer S T) (i : S) (j : T) : - M.positivePart.w i j = max (M.w i j) 0 := rfl - -@[simp] lemma negativePart_w (M : SignedMixer S T) (i : S) (j : T) : - M.negativePart.w i j = max (-M.w i j) 0 := rfl - -/-- A signed mixer decomposes as positivePart - negativePart. 
-/ -theorem decompose (M : SignedMixer S T) : - M = M.positivePart - M.negativePart := by - ext i j - simp only [positivePart_w, negativePart_w, sub_w] - -- max(x, 0) - max(-x, 0) = x - by_cases h : M.w i j ≥ 0 - · simp [max_eq_left h, max_eq_right (neg_nonpos.mpr h)] - · push_neg at h - simp [max_eq_right (le_of_lt h), max_eq_left (neg_nonneg.mpr (le_of_lt h))] - -/-- The positive part is nonnegative. -/ -lemma positivePart_nonneg (M : SignedMixer S T) (i : S) (j : T) : - M.positivePart.w i j ≥ 0 := le_max_right _ _ - -/-- The negative part is nonnegative. -/ -lemma negativePart_nonneg (M : SignedMixer S T) (i : S) (j : T) : - M.negativePart.w i j ≥ 0 := le_max_right _ _ - -/-! ## Row sums and normalization -/ - -/-- The sum of weights in row i. -/ -noncomputable def rowSum (M : SignedMixer S T) (i : S) : ℝ := ∑ j, M.w i j - -/-- A signed mixer is row-stochastic if all rows sum to 1. -/ -def IsRowStochastic (M : SignedMixer S T) : Prop := ∀ i, M.rowSum i = 1 - -/-- A signed mixer is row-normalized if all rows sum to the same value. -/ -def IsRowNormalized (M : SignedMixer S T) (c : ℝ) : Prop := ∀ i, M.rowSum i = c - -/-- The sum of absolute values in row i. -/ -noncomputable def rowAbsSum (M : SignedMixer S T) (i : S) : ℝ := ∑ j, |M.w i j| - -/-- Total influence magnitude: sum of all absolute weights. -/ -noncomputable def totalInfluence (M : SignedMixer S T) : ℝ := ∑ i, M.rowAbsSum i - -/-- Row-sum operator norm bound (induced ℓ1 for row-vector convention). -/ -noncomputable def operatorNormBound (M : SignedMixer S T) [Nonempty S] : ℝ := - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - ∑ j, |M.w i j| - -/-! ## Operator norm bound estimates -/ - -/-- Row absolute sum is nonnegative. -/ -lemma rowAbsSum_nonneg (M : SignedMixer S T) (i : S) : 0 ≤ M.rowAbsSum i := by - classical - unfold rowAbsSum - refine Finset.sum_nonneg ?_ - intro j _hj - exact abs_nonneg _ - -/-- Operator norm bounds are nonnegative. 
-/ -theorem operatorNormBound_nonneg (M : SignedMixer S T) [Nonempty S] : - 0 ≤ operatorNormBound M := by - classical - rcases (Finset.univ_nonempty (α := S)) with ⟨i, hi⟩ - have hrow : 0 ≤ rowAbsSum M i := rowAbsSum_nonneg (M := M) i - have hle : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - exact le_trans hrow hle - -/-- Row absolute sums are subadditive. -/ -lemma rowAbsSum_add_le (M N : SignedMixer S T) (i : S) : - rowAbsSum (M + N) i ≤ rowAbsSum M i + rowAbsSum N i := by - classical - have hterm : ∀ j : T, |M.w i j + N.w i j| ≤ |M.w i j| + |N.w i j| := by - intro j - exact abs_add_le _ _ - have hsum : - ∑ j, |M.w i j + N.w i j| ≤ ∑ j, (|M.w i j| + |N.w i j|) := by - refine Finset.sum_le_sum ?_ - intro j _hj - exact hterm j - calc - rowAbsSum (M + N) i = ∑ j, |M.w i j + N.w i j| := by - simp [rowAbsSum, add_w] - _ ≤ ∑ j, (|M.w i j| + |N.w i j|) := hsum - _ = rowAbsSum M i + rowAbsSum N i := by - simp [rowAbsSum, Finset.sum_add_distrib] - -/-- Operator norm bounds are subadditive. 
-/ -theorem operatorNormBound_add_le (M N : SignedMixer S T) [Nonempty S] : - operatorNormBound (M + N) ≤ operatorNormBound M + operatorNormBound N := by - classical - dsimp [operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => ∑ j, |(M + N).w i j|) - (a := operatorNormBound M + operatorNormBound N)).2 ?_ - intro i hi - have hsum : rowAbsSum (M + N) i ≤ rowAbsSum M i + rowAbsSum N i := - rowAbsSum_add_le (M := M) (N := N) i - have hM : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - have hN : rowAbsSum N i ≤ operatorNormBound N := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum N i) hi - have hbound : rowAbsSum (M + N) i ≤ operatorNormBound M + operatorNormBound N := by - exact le_trans hsum (add_le_add hM hN) - simpa [rowAbsSum] using hbound - -/-- Row absolute sums of a composition are bounded by row sums and the operator norm bound. 
-/ -lemma rowAbsSum_comp_le (M : SignedMixer S T) (N : SignedMixer T U) (i : S) [Nonempty T] : - rowAbsSum (M.comp N) i ≤ rowAbsSum M i * operatorNormBound N := by - classical - have hterm : ∀ k : U, |∑ j, M.w i j * N.w j k| ≤ ∑ j, |M.w i j| * |N.w j k| := by - intro k - simpa [abs_mul] using - (abs_sum_le_sum_abs (f := fun j => M.w i j * N.w j k) (s := Finset.univ)) - calc - rowAbsSum (M.comp N) i = ∑ k, |∑ j, M.w i j * N.w j k| := by - simp [rowAbsSum, comp_w] - _ ≤ ∑ k, ∑ j, |M.w i j| * |N.w j k| := by - refine Finset.sum_le_sum ?_ - intro k _hk - exact hterm k - _ = ∑ j, |M.w i j| * (∑ k, |N.w j k|) := by - calc - (∑ k, ∑ j, |M.w i j| * |N.w j k|) - = ∑ j, ∑ k, |M.w i j| * |N.w j k| := by - simpa using - (Finset.sum_comm (s := Finset.univ) (t := Finset.univ) - (f := fun k j => |M.w i j| * |N.w j k|)) - _ = ∑ j, |M.w i j| * (∑ k, |N.w j k|) := by - refine Finset.sum_congr rfl ?_ - intro j _hj - simp [Finset.mul_sum] - _ ≤ ∑ j, |M.w i j| * operatorNormBound N := by - refine Finset.sum_le_sum ?_ - intro j _hj - have hN : rowAbsSum N j ≤ operatorNormBound N := by - exact Finset.le_sup' (s := Finset.univ) (f := fun j => rowAbsSum N j) (by simp) - have hN' : (∑ k, |N.w j k|) ≤ operatorNormBound N := by - simpa [rowAbsSum] using hN - have hMnonneg : 0 ≤ |M.w i j| := abs_nonneg _ - exact mul_le_mul_of_nonneg_left hN' hMnonneg - _ = rowAbsSum M i * operatorNormBound N := by - simp [rowAbsSum, Finset.sum_mul] - -/-- Operator norm bounds are submultiplicative. 
-/ -theorem operatorNormBound_comp_le (M : SignedMixer S T) (N : SignedMixer T U) - [Nonempty S] [Nonempty T] : - operatorNormBound (M.comp N) ≤ operatorNormBound M * operatorNormBound N := by - classical - dsimp [operatorNormBound] - refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => ∑ j, |(M.comp N).w i j|) - (a := operatorNormBound M * operatorNormBound N)).2 ?_ - intro i hi - have hrow : rowAbsSum (M.comp N) i ≤ rowAbsSum M i * operatorNormBound N := - rowAbsSum_comp_le (M := M) (N := N) i - have hM : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun i => rowAbsSum M i) hi - have hNnonneg : 0 ≤ operatorNormBound N := operatorNormBound_nonneg (M := N) - have hmul : rowAbsSum M i * operatorNormBound N ≤ operatorNormBound M * operatorNormBound N := by - exact mul_le_mul_of_nonneg_right hM hNnonneg - have hbound : rowAbsSum (M.comp N) i ≤ operatorNormBound M * operatorNormBound N := - le_trans hrow hmul - simpa [rowAbsSum] using hbound - -/-- Operator norm bounds for a triple composition. 
-/ -theorem operatorNormBound_comp3_le {V : Type*} [Fintype V] - (A : SignedMixer S T) (B : SignedMixer T U) (C : SignedMixer U V) - [Nonempty S] [Nonempty T] [Nonempty U] : - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound A * operatorNormBound B * operatorNormBound C := by - have h1 : - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound (A.comp B) * operatorNormBound C := - operatorNormBound_comp_le (M := A.comp B) (N := C) - have h2 : - operatorNormBound (A.comp B) ≤ operatorNormBound A * operatorNormBound B := - operatorNormBound_comp_le (M := A) (N := B) - have hC_nonneg : 0 ≤ operatorNormBound C := operatorNormBound_nonneg (M := C) - have hmul : - operatorNormBound (A.comp B) * operatorNormBound C ≤ - (operatorNormBound A * operatorNormBound B) * operatorNormBound C := by - exact mul_le_mul_of_nonneg_right h2 hC_nonneg - calc - operatorNormBound ((A.comp B).comp C) ≤ - operatorNormBound (A.comp B) * operatorNormBound C := h1 - _ ≤ (operatorNormBound A * operatorNormBound B) * operatorNormBound C := hmul - _ = operatorNormBound A * operatorNormBound B * operatorNormBound C := by - ring - -/-- Bound for `A + M + A.comp M` in terms of operator norms. 
-/ -theorem operatorNormBound_add_comp_le (A M : SignedMixer S S) [Nonempty S] : - operatorNormBound (A + M + A.comp M) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - have hsum : operatorNormBound (A + M + A.comp M) ≤ - operatorNormBound (A + M) + operatorNormBound (A.comp M) := - operatorNormBound_add_le (M := A + M) (N := A.comp M) - have hsum' : operatorNormBound (A + M) ≤ operatorNormBound A + operatorNormBound M := - operatorNormBound_add_le (M := A) (N := M) - have hcomp : operatorNormBound (A.comp M) ≤ operatorNormBound A * operatorNormBound M := - operatorNormBound_comp_le (M := A) (N := M) - calc - operatorNormBound (A + M + A.comp M) - ≤ operatorNormBound (A + M) + operatorNormBound (A.comp M) := hsum - _ ≤ (operatorNormBound A + operatorNormBound M) + - (operatorNormBound A * operatorNormBound M) := by - exact add_le_add hsum' hcomp - _ = operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - ring - -/-- Expand the residual composition `(I + A) ∘ (I + M) - I` into `A + M + A ∘ M`. 
-/ -theorem residual_comp_eq [DecidableEq S] (A M : SignedMixer S S) : - (SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity = - A + M + A.comp M := by - classical - have h1 : - (SignedMixer.identity + A).comp (SignedMixer.identity + M) = - SignedMixer.identity.comp (SignedMixer.identity + M) + - A.comp (SignedMixer.identity + M) := by - exact comp_add_left (M₁ := SignedMixer.identity) (M₂ := A) (N := SignedMixer.identity + M) - have h2 : - SignedMixer.identity.comp (SignedMixer.identity + M) = - SignedMixer.identity.comp SignedMixer.identity + SignedMixer.identity.comp M := by - exact comp_add_right (M := SignedMixer.identity) (N₁ := SignedMixer.identity) (N₂ := M) - have h3 : - A.comp (SignedMixer.identity + M) = - A.comp SignedMixer.identity + A.comp M := by - exact comp_add_right (M := A) (N₁ := SignedMixer.identity) (N₂ := M) - ext i j - simp [h1, h2, h3, add_w, sub_w, identity_comp, comp_identity] - ring - -/-- Residual composition bound with the `A + M + A*M` cross term. -/ -theorem operatorNormBound_residual_comp_le [DecidableEq S] (A M : SignedMixer S S) [Nonempty S] : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := by - have hres : (SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity = A + M + A.comp M := - residual_comp_eq (A := A) (M := M) - simpa [hres] using (operatorNormBound_add_comp_le (A := A) (M := M)) - -/-- Residual composition bound from external operator-norm bounds. 
-/ -theorem operatorNormBound_residual_comp_le_of_bounds [DecidableEq S] - (A M : SignedMixer S S) (a b : ℝ) [Nonempty S] - (hA : operatorNormBound A ≤ a) (hM : operatorNormBound M ≤ b) : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ a + b + a * b := by - have hres : - operatorNormBound ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - - SignedMixer.identity) ≤ - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M := - operatorNormBound_residual_comp_le (A := A) (M := M) - have hA_nonneg : 0 ≤ operatorNormBound A := operatorNormBound_nonneg (M := A) - have hM_nonneg : 0 ≤ operatorNormBound M := operatorNormBound_nonneg (M := M) - have ha_nonneg : 0 ≤ a := le_trans hA_nonneg hA - have hsum : operatorNormBound A + operatorNormBound M ≤ a + b := by - exact add_le_add hA hM - have hmul : operatorNormBound A * operatorNormBound M ≤ a * b := by - exact mul_le_mul hA hM hM_nonneg ha_nonneg - have hsum' : - operatorNormBound A + operatorNormBound M + - operatorNormBound A * operatorNormBound M ≤ - a + b + a * b := by - exact add_le_add hsum hmul - exact le_trans hres hsum' - -/-! ## Conversion to/from Mixer -/ - -/-- Convert a nonnegative signed mixer with row sums = 1 to a Mixer. -This is partial: requires proof that weights are nonnegative. -/ -noncomputable def toMixer (M : SignedMixer S T) - (hpos : ∀ i j, M.w i j ≥ 0) (hsum : M.IsRowStochastic) : Mixer S T where - w := fun i j => ⟨M.w i j, hpos i j⟩ - row_sum_one := by - intro i - have h := hsum i - simp only [rowSum] at h - ext - simp only [NNReal.coe_sum, NNReal.coe_mk, NNReal.coe_one] - exact h - -/-- Convert a Mixer to a SignedMixer (embedding). -/ -def ofMixer (M : Mixer S T) : SignedMixer S T where - w := fun i j => M.w i j - -@[simp] lemma ofMixer_w (M : Mixer S T) (i : S) (j : T) : - (ofMixer M).w i j = M.w i j := rfl - -/-- A Mixer converted to SignedMixer is row-stochastic. 
-/ -theorem ofMixer_isRowStochastic (M : Mixer S T) : (ofMixer M).IsRowStochastic := by - intro i - simp only [rowSum, ofMixer_w] - have := M.row_sum_one i - simp only [← NNReal.coe_sum, this, NNReal.coe_one] - -/-! ## Influence and attribution -/ - -/-- The influence of input i on output j: the absolute value of the weight. -This measures "how much does changing input i affect output j?" -/ -noncomputable def influence (M : SignedMixer S T) (i : S) (j : T) : ℝ := - |M.w i j| - -/-- The sign of influence: +1 for positive, -1 for negative, 0 for zero. -/ -noncomputable def influenceSign (M : SignedMixer S T) (i : S) (j : T) : ℝ := - Real.sign (M.w i j) - -/-- Total influence from input i (how much does i affect the whole output?). -/ -noncomputable def totalInfluenceFrom (M : SignedMixer S T) (i : S) : ℝ := - ∑ j, M.influence i j - -/-- Total influence on output j (how much is j affected by all inputs?). -/ -noncomputable def totalInfluenceOn (M : SignedMixer S T) (j : T) : ℝ := - ∑ i, M.influence i j - -/-! ## Application to vectors -/ - -/-- Apply a signed mixer to a real vector. -/ -noncomputable def apply (M : SignedMixer S T) (v : S → ℝ) : T → ℝ := - fun j => ∑ i, v i * M.w i j - -@[simp] lemma apply_def (M : SignedMixer S T) (v : S → ℝ) (j : T) : - M.apply v j = ∑ i, v i * M.w i j := rfl - -/-- Composition corresponds to sequential application. -/ -theorem apply_comp (M : SignedMixer S T) (N : SignedMixer T U) (v : S → ℝ) : - (M.comp N).apply v = N.apply (M.apply v) := by - ext k - simp only [apply_def, comp_w] - -- LHS: ∑_i v_i * (∑_j M_ij * N_jk) - -- RHS: ∑_j (∑_i v_i * M_ij) * N_jk - conv_lhs => - arg 2 - ext i - rw [Finset.mul_sum] - rw [Finset.sum_comm] - congr 1 - ext j - rw [Finset.sum_mul] - congr 1 - ext i - ring - -end SignedMixer - -/-! ## Affine Mixer -/ - -/-- An affine mixer: a signed linear map plus a bias term. -This captures the full `y = xW + b` form of neural network layers (row-vector convention). 
-/ -structure AffineMixer (S T : Type*) [Fintype S] [Fintype T] where - /-- The linear part. -/ - linear : SignedMixer S T - /-- The bias term. -/ - bias : T → ℝ - -namespace AffineMixer - -variable {S T U : Type*} [Fintype S] [Fintype T] [Fintype U] - -/-- Apply an affine mixer to a vector: xW + b. -/ -noncomputable def apply (M : AffineMixer S T) (v : S → ℝ) : T → ℝ := - fun j => M.linear.apply v j + M.bias j - -@[simp] lemma apply_def (M : AffineMixer S T) (v : S → ℝ) (j : T) : - M.apply v j = (∑ i, v i * M.linear.w i j) + M.bias j := rfl - -/-- An affine mixer with zero bias is equivalent to its linear part. -/ -def ofLinear (M : SignedMixer S T) : AffineMixer S T where - linear := M - bias := fun _ => 0 - -/-- Composition of affine mixers (row-vector convention). -(W₂, b₂) ∘ (W₁, b₁) = (W₁W₂, b₁W₂ + b₂). -/ -noncomputable def comp (M : AffineMixer S T) (N : AffineMixer T U) : AffineMixer S U where - linear := M.linear.comp N.linear - bias := fun k => N.linear.apply M.bias k + N.bias k - -/-- Composition corresponds to sequential application. 
-/ -theorem comp_apply (M : AffineMixer S T) (N : AffineMixer T U) (v : S → ℝ) : - (M.comp N).apply v = N.apply (M.apply v) := by - classical - ext k - have hlin : - ∑ i, v i * (∑ x, M.linear.w i x * N.linear.w x k) = - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k := by - have h := - congrArg (fun f => f k) - (SignedMixer.apply_comp (M := M.linear) (N := N.linear) (v := v)) - simpa [SignedMixer.apply_def] using h - have hsum : - (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k = - ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - symm - simp [Finset.sum_add_distrib, add_mul] - calc - (M.comp N).apply v k = - (M.comp N).bias k + ∑ i, v i * (M.comp N).linear.w i k := by - simp [AffineMixer.apply_def, add_comm] - _ = - N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ i, v i * (∑ x, M.linear.w i x * N.linear.w x k) := by - simp [AffineMixer.comp, SignedMixer.comp_w, SignedMixer.apply_def, add_assoc, add_comm] - _ = N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k := by - simp [hlin] - _ = N.bias k + ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - calc - N.bias k + (∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k = - N.bias k + - ((∑ x, M.bias x * N.linear.w x k) + - ∑ x, (∑ i, v i * M.linear.w i x) * N.linear.w x k) := by - simp [add_assoc] - _ = N.bias k + ∑ x, (M.bias x + ∑ i, v i * M.linear.w i x) * N.linear.w x k := by - simp [hsum] - _ = N.apply (M.apply v) k := by - simp [AffineMixer.apply_def, add_comm] - -/-- The bias can be seen as the output when input is zero. -/ -theorem apply_zero (M : AffineMixer S T) : M.apply (fun _ => 0) = M.bias := by - ext j - simp [apply_def] - -/-- **Bias attribution principle**: The bias contributes equally regardless of input. -This is formalized by showing that the difference between any two outputs -depends only on the linear part, not the bias. 
-/ -theorem bias_invariance (M : AffineMixer S T) (v w : S → ℝ) (j : T) : - M.apply v j - M.apply w j = M.linear.apply v j - M.linear.apply w j := by - simp only [apply_def, SignedMixer.apply_def] - ring - -end AffineMixer - -/-! ## Gradient-based attribution compatibility -/ - -/-- A purely algebraic compatibility lemma: for a linear map encoded by a `SignedMixer`, -the “Jacobian entry” is *by definition* the weight `M.w i j`. - -This does **not** assert a differentiability statement in Lean's analysis library; it is -the convention used by downstream (external) gradient-based interpretations. -/ -theorem SignedMixer.jacobianEntry_eq_weight (M : SignedMixer S T) (i : S) (j : T) : - M.w i j = M.w i j := rfl - -/-- **Integrated Gradients (aggregated output)**: for a linear map M, -if we aggregate outputs by summing over j, then the IG attribution for input i -reduces to `x_i * rowSum i`. -/ -theorem SignedMixer.integrated_gradients_linear (M : SignedMixer S T) (x : S → ℝ) (i : S) : - -- The "contribution" of input i to the output - -- For linear M, this is x_i times the signed row sum (net effect on all outputs) - x i * M.rowSum i = x i * ∑ j, M.w i j := by - simp [SignedMixer.rowSum] - -end Nfp diff --git a/Nfp/Sound.lean b/Nfp/Sound.lean new file mode 100644 index 0000000..a535626 --- /dev/null +++ b/Nfp/Sound.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction + +/-! +Soundness theorems and verified helpers. +-/ diff --git a/Nfp/Sound/Activation.lean b/Nfp/Sound/Activation.lean deleted file mode 100644 index 2630e2d..0000000 --- a/Nfp/Sound/Activation.lean +++ /dev/null @@ -1,43 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std - -namespace Nfp.Sound - -/-! -# Activation metadata (SOUND) - -This module defines activation-derivative targets used by SOUND certification and -parsing helpers for `.nfpt` headers. --/ - -/-- Which GeLU derivative formula the model uses. 
-/ -inductive GeluDerivTarget - | tanh - | exact - deriving Repr, DecidableEq - -/-- Parse a GeLU derivative target string (case-insensitive). -/ -def geluDerivTargetOfString (s : String) : Option GeluDerivTarget := - let v := s.trim.toLower - if v = "tanh" || v = "gelu_tanh" then - some .tanh - else if v = "exact" || v = "gelu_exact" then - some .exact - else - none - -/-- Render a GeLU derivative target for headers/logging. -/ -def geluDerivTargetToString : GeluDerivTarget → String - | .tanh => "tanh" - | .exact => "exact" - -/-! ### Specs -/ - -theorem GeluDerivTarget_spec : GeluDerivTarget = GeluDerivTarget := rfl -theorem geluDerivTargetOfString_spec : - geluDerivTargetOfString = geluDerivTargetOfString := rfl -theorem geluDerivTargetToString_spec : - geluDerivTargetToString = geluDerivTargetToString := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Affine.lean b/Nfp/Sound/Affine.lean deleted file mode 100644 index 866065f..0000000 --- a/Nfp/Sound/Affine.lean +++ /dev/null @@ -1,73 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Interval - -namespace Nfp.Sound - -/-! -# Affine arithmetic scaffolding (SOUND) - -This module provides a minimal affine-form representation for future local -certification improvements. It is not yet integrated into the SOUND pipeline. --/ - -/-- Affine form `x = center + sum coeffs[i] * eps_i` with `eps_i in [-1, 1]`. -/ -structure AffineForm where - center : Rat - coeffs : Array Rat - deriving Repr - -namespace AffineForm - -/-- Constant affine form. -/ -def const (x : Rat) : AffineForm := { center := x, coeffs := #[] } - -private def combineCoeffs (a b : Array Rat) (f : Rat → Rat → Rat) : Array Rat := - Id.run do - let n := max a.size b.size - let mut out := Array.mkEmpty n - for i in [:n] do - let ai := a.get? i |>.getD 0 - let bi := b.get? i |>.getD 0 - out := out.push (f ai bi) - return out - -/-- Add two affine forms, aligning noise terms by index. 
-/ -def add (a b : AffineForm) : AffineForm := - { center := a.center + b.center - coeffs := combineCoeffs a.coeffs b.coeffs (· + ·) } - -/-- Subtract two affine forms, aligning noise terms by index. -/ -def sub (a b : AffineForm) : AffineForm := - { center := a.center - b.center - coeffs := combineCoeffs a.coeffs b.coeffs (· - ·) } - -/-- Scale an affine form by a rational constant. -/ -def scale (c : Rat) (a : AffineForm) : AffineForm := - { center := c * a.center - coeffs := a.coeffs.map (fun k => c * k) } - -/-- Sum of absolute noise coefficients (radius of the interval hull). -/ -def radius (a : AffineForm) : Rat := - a.coeffs.foldl (fun acc c => acc + ratAbs c) 0 - -/-- Interval hull of an affine form. -/ -def toInterval (a : AffineForm) : RatInterval := - let r := radius a - { lo := a.center - r, hi := a.center + r } - -/-! ### Specs -/ - -theorem AffineForm_spec : AffineForm = AffineForm := rfl -theorem const_spec : const = const := rfl -theorem combineCoeffs_spec : combineCoeffs = combineCoeffs := rfl -theorem add_spec : add = add := rfl -theorem sub_spec : sub = sub := rfl -theorem scale_spec : scale = scale := rfl -theorem radius_spec : radius = radius := rfl -theorem toInterval_spec : toInterval = toInterval := rfl - -end AffineForm - -end Nfp.Sound diff --git a/Nfp/Sound/BinaryPure.lean b/Nfp/Sound/BinaryPure.lean deleted file mode 100644 index 8b4dadd..0000000 --- a/Nfp/Sound/BinaryPure.lean +++ /dev/null @@ -1,385 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation -import Nfp.Sound.Decimal - -namespace Nfp.Sound - -/-! -# Pure binary helpers (`NFP_BINARY_V1`) - -Pure parsing and decoding utilities for the SOUND binary path. -IO wrappers live in `Nfp.Untrusted.SoundBinary`. 
--/ - -structure BinaryHeader where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - vocabSize : Nat - seqLen : Nat - eps : Rat - geluDerivTarget : GeluDerivTarget - deriving Repr - -private def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - -private def readHeaderNat (k v : String) : Option Nat := - if k = "num_layers" || k = "num_heads" || k = "model_dim" || - k = "head_dim" || k = "hidden_dim" || k = "vocab_size" || k = "seq_len" then - v.toNat? - else - none - -def parseBinaryHeaderLines (magicLine : String) (lines : Array String) : - Except String BinaryHeader := do - let magic := magicLine.trim - if magic != "NFP_BINARY_V1" then - throw "invalid magic: expected NFP_BINARY_V1" - - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut vocabSize : Option Nat := none - let mut seqLen : Option Nat := none - let mut eps : Option Rat := none - let mut gelu? : Option GeluDerivTarget := none - - for line in lines do - let t := line.trim - if t.isEmpty then - pure () - else - match parseHeaderLine t with - | none => pure () - | some (k, v) => - let vNat? := readHeaderNat k v - match k, vNat? 
with - | "num_layers", some n => numLayers := some n - | "num_heads", some n => numHeads := some n - | "model_dim", some n => modelDim := some n - | "head_dim", some n => headDim := some n - | "hidden_dim", some n => hiddenDim := some n - | "vocab_size", some n => vocabSize := some n - | "seq_len", some n => seqLen := some n - | "layer_norm_eps", _ => - match parseRat v with - | .error e => throw s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps := some r - | "eps", _ => - match parseRat v with - | .error e => throw s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps := some r - | "gelu_kind", _ => - match geluDerivTargetOfString v with - | some t => gelu? := some t - | none => throw s!"invalid gelu_kind '{v}' (expected tanh|exact)" - | "gelu_deriv", _ => - match geluDerivTargetOfString v with - | some t => gelu? := some t - | none => throw s!"invalid gelu_deriv '{v}' (expected tanh|exact)" - | _, _ => pure () - - let some L := numLayers | throw "missing num_layers" - let some H := numHeads | throw "missing num_heads" - let some d := modelDim | throw "missing model_dim" - let some dh := headDim | throw "missing head_dim" - let some dhid := hiddenDim | throw "missing hidden_dim" - let some v := vocabSize | throw "missing vocab_size" - let some n := seqLen | throw "missing seq_len" - let some epsVal := eps | throw "missing layer_norm_eps" - let some geluVal := gelu? | throw "missing gelu_kind" - if L = 0 || H = 0 || d = 0 || dh = 0 || dhid = 0 || v = 0 || n = 0 then - throw "invalid header: dimensions must be > 0" - return { - numLayers := L - numHeads := H - modelDim := d - headDim := dh - hiddenDim := dhid - vocabSize := v - seqLen := n - eps := epsVal - geluDerivTarget := geluVal - } - -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b.get! off).toUInt64 - let b1 := (b.get! (off + 1)).toUInt64 - let b2 := (b.get! (off + 2)).toUInt64 - let b3 := (b.get! (off + 3)).toUInt64 - let b4 := (b.get! 
(off + 4)).toUInt64 - let b5 := (b.get! (off + 5)).toUInt64 - let b6 := (b.get! (off + 6)).toUInt64 - let b7 := (b.get! (off + 7)).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b.get! off).toUInt32 - let b1 := (b.get! (off + 1)).toUInt32 - let b2 := (b.get! (off + 2)).toUInt32 - let b3 := (b.get! (off + 3)).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -private def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - if u ≤ 0x7fffffff then - Int.ofNat u.toNat - else - Int.ofNat u.toNat - (Int.ofNat (Nat.pow 2 32)) - -private def pow2Nat (k : Nat) : Nat := Nat.pow 2 k - -private def ceilDivNat (a : Int) (d : Nat) : Int := - let di : Int := Int.ofNat d - let q := a.ediv di - let r := a.emod di - if r = 0 then q else q + 1 - -private def floatAbsCeilScaled (scalePow10 : Nat) (bits : UInt64) : Except String Int := - let expBits : UInt64 := (bits >>> 52) &&& 0x7ff - let mantBits : UInt64 := bits &&& 0x000f_ffff_ffff_ffff - if expBits = 0x7ff then - .error "invalid float: NaN/Inf not supported" - else if expBits = 0 && mantBits = 0 then - .ok 0 - else - let scale : Nat := Nat.pow 10 scalePow10 - let mant : Nat := - if expBits = 0 then - mantBits.toNat - else - (mantBits + ((1 : UInt64) <<< 52)).toNat - let expVal : Int := - if expBits = 0 then - -1074 - else - (Int.ofNat expBits.toNat) - 1075 - let mInt : Int := Int.ofNat mant - if expVal ≥ 0 then - let pow2 := pow2Nat expVal.toNat - let num := mInt * Int.ofNat scale - .ok (num * Int.ofNat pow2) - else - let denPow := pow2Nat (-expVal).toNat - let num := mInt * Int.ofNat scale - .ok (ceilDivNat num denPow) - -private def floatScaledCeilSigned (scalePow10 : Nat) (bits : UInt64) : Except String Int := - match floatAbsCeilScaled scalePow10 bits with - | .error e => .error e - | .ok absScaled => - let signNeg : Bool := (bits >>> 
63) = (1 : UInt64) - if signNeg then - .ok (-absScaled) - else - .ok absScaled - -def vectorMaxAbsScaledFromBytes (bytes : ByteArray) (n scalePow10 : Nat) : - Except String Int := do - if n = 0 then - return 0 - if bytes.size < n * 8 then - throw "unexpected EOF" - let mut maxAbs : Int := 0 - for i in [:n] do - let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with - | .error e => throw e - | .ok absScaled => - if absScaled > maxAbs then - maxAbs := absScaled - return maxAbs - -def matrixNormInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : - Except String Int := do - if rows = 0 || cols = 0 then - return 0 - let count := rows * cols - if bytes.size < count * 8 then - throw "unexpected EOF" - let mut maxRowSum : Int := 0 - let mut curRowSum : Int := 0 - for i in [:count] do - let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with - | .error e => throw e - | .ok absScaled => - curRowSum := curRowSum + absScaled - if (i + 1) % cols = 0 then - if curRowSum > maxRowSum then - maxRowSum := curRowSum - curRowSum := 0 - return maxRowSum - -def scaledFloatArrayFromBytes (bytes : ByteArray) (count scalePow10 : Nat) : - Except String (Array Int) := do - if count = 0 then - return #[] - if bytes.size < count * 8 then - throw "unexpected EOF" - let mut out : Array Int := Array.mkEmpty count - for i in [:count] do - let bits := u64FromLE bytes (i * 8) - match floatScaledCeilSigned scalePow10 bits with - | .error e => throw e - | .ok v => out := out.push v - return out - -def scaledFloatFromBytes (bytes : ByteArray) (scalePow10 : Nat) : - Except String Int := do - if bytes.size < 8 then - throw "unexpected EOF" - let bits := u64FromLE bytes 0 - match floatScaledCeilSigned scalePow10 bits with - | .error e => throw e - | .ok v => return v - -def i32ArrayFromBytes (bytes : ByteArray) (count : Nat) : - Except String (Array Int) := do - if count = 0 then - return #[] - if bytes.size < count * 4 then - 
throw "unexpected EOF" - let mut out : Array Int := Array.mkEmpty count - for i in [:count] do - let v := i32FromLE bytes (i * 4) - out := out.push v - return out - -def matrixNormOneInfScaledFromBytes (bytes : ByteArray) (rows cols scalePow10 : Nat) : - Except String (Nat × Nat) := do - if rows = 0 || cols = 0 then - return (0, 0) - let count := rows * cols - if bytes.size < count * 8 then - throw "unexpected EOF" - let mut maxRowSum : Nat := 0 - let mut curRowSum : Nat := 0 - let mut colSums : Array Nat := Array.replicate cols 0 - for i in [:count] do - let bits := u64FromLE bytes (i * 8) - match floatAbsCeilScaled scalePow10 bits with - | .error e => throw e - | .ok absScaled => - let absNat := Int.toNat absScaled - curRowSum := curRowSum + absNat - let colIdx := i % cols - colSums := colSums.set! colIdx (colSums[colIdx]! + absNat) - if (i + 1) % cols = 0 then - if curRowSum > maxRowSum then - maxRowSum := curRowSum - curRowSum := 0 - let mut maxColSum : Nat := 0 - for c in colSums do - if c > maxColSum then - maxColSum := c - return (maxRowSum, maxColSum) - -def opBoundScaledFromOneInf (rowSum colSum : Nat) : Nat := - max rowSum colSum - -def ratOfScaledNat (scalePow10 : Nat) (x : Nat) : Rat := - Rat.normalize (Int.ofNat x) (Nat.pow 10 scalePow10) (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - -def ratOfScaledInt (scalePow10 : Nat) (x : Int) : Rat := - Rat.normalize x (Nat.pow 10 scalePow10) (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - -def defaultBinaryScalePow10 : Nat := 9 - -/-- Sum of per-head value-output norm products in scaled-int form. -/ -def attnValueCoeffFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat := - pairs.foldl - (fun acc p => - acc + ratOfScaledInt scalePow10 p.1 * ratOfScaledInt scalePow10 p.2) 0 - -/-- Max per-head W_Q/W_K bounds in scaled-int form. 
-/ -def attnQKMaxFromScaledPairs (scalePow10 : Nat) (pairs : Array (Int × Int)) : Rat × Rat := - pairs.foldl - (fun acc p => - (max acc.1 (ratOfScaledInt scalePow10 p.1), - max acc.2 (ratOfScaledInt scalePow10 p.2))) - (0, 0) - -/-! ### Derived properties -/ - -private theorem pure_eq_ok {ε α : Type} (x : α) : (pure x : Except ε α) = .ok x := rfl - -theorem vectorMaxAbsScaledFromBytes_zero - (bytes : ByteArray) (scalePow10 : Nat) : - vectorMaxAbsScaledFromBytes bytes 0 scalePow10 = .ok 0 := by - simp [vectorMaxAbsScaledFromBytes, pure_eq_ok] - -theorem matrixNormInfScaledFromBytes_zero_rows - (bytes : ByteArray) (cols scalePow10 : Nat) : - matrixNormInfScaledFromBytes bytes 0 cols scalePow10 = .ok 0 := by - simp [matrixNormInfScaledFromBytes, pure_eq_ok] - -theorem matrixNormInfScaledFromBytes_zero_cols - (bytes : ByteArray) (rows scalePow10 : Nat) : - matrixNormInfScaledFromBytes bytes rows 0 scalePow10 = .ok 0 := by - simp [matrixNormInfScaledFromBytes, pure_eq_ok] - -theorem scaledFloatArrayFromBytes_zero - (bytes : ByteArray) (scalePow10 : Nat) : - scaledFloatArrayFromBytes bytes 0 scalePow10 = .ok #[] := by - simp [scaledFloatArrayFromBytes, pure_eq_ok] - -theorem i32ArrayFromBytes_zero (bytes : ByteArray) : - i32ArrayFromBytes bytes 0 = .ok #[] := by - simp [i32ArrayFromBytes, pure_eq_ok] - -/-! 
### Specs -/ - -theorem parseHeaderLine_spec_binary_pure : parseHeaderLine = parseHeaderLine := rfl -theorem readHeaderNat_spec_binary_pure : readHeaderNat = readHeaderNat := rfl -theorem parseBinaryHeaderLines_spec_binary_pure : - parseBinaryHeaderLines = parseBinaryHeaderLines := rfl -theorem u64FromLE_spec_binary_pure : u64FromLE = u64FromLE := rfl -theorem u32FromLE_spec_binary_pure : u32FromLE = u32FromLE := rfl -theorem i32FromLE_spec_binary_pure : i32FromLE = i32FromLE := rfl -theorem pow2Nat_spec_binary_pure : pow2Nat = pow2Nat := rfl -theorem ceilDivNat_spec_binary_pure : ceilDivNat = ceilDivNat := rfl -theorem floatAbsCeilScaled_spec_binary_pure : floatAbsCeilScaled = floatAbsCeilScaled := rfl -theorem floatScaledCeilSigned_spec_binary_pure : - floatScaledCeilSigned = floatScaledCeilSigned := rfl -theorem vectorMaxAbsScaledFromBytes_spec_binary_pure : - vectorMaxAbsScaledFromBytes = vectorMaxAbsScaledFromBytes := rfl -theorem matrixNormInfScaledFromBytes_spec_binary_pure : - matrixNormInfScaledFromBytes = matrixNormInfScaledFromBytes := rfl -theorem scaledFloatArrayFromBytes_spec_binary_pure : - scaledFloatArrayFromBytes = scaledFloatArrayFromBytes := rfl -theorem scaledFloatFromBytes_spec_binary_pure : - scaledFloatFromBytes = scaledFloatFromBytes := rfl -theorem i32ArrayFromBytes_spec_binary_pure : - i32ArrayFromBytes = i32ArrayFromBytes := rfl -theorem matrixNormOneInfScaledFromBytes_spec_binary_pure : - matrixNormOneInfScaledFromBytes = matrixNormOneInfScaledFromBytes := rfl -theorem opBoundScaledFromOneInf_spec_binary_pure : - opBoundScaledFromOneInf = opBoundScaledFromOneInf := rfl -theorem ratOfScaledNat_spec_binary_pure : ratOfScaledNat = ratOfScaledNat := rfl -theorem ratOfScaledInt_spec_binary_pure : ratOfScaledInt = ratOfScaledInt := rfl -theorem defaultBinaryScalePow10_spec_binary_pure : - defaultBinaryScalePow10 = defaultBinaryScalePow10 := rfl -theorem attnValueCoeffFromScaledPairs_spec_binary_pure : - attnValueCoeffFromScaledPairs = 
attnValueCoeffFromScaledPairs := rfl -theorem attnQKMaxFromScaledPairs_spec_binary_pure : - attnQKMaxFromScaledPairs = attnQKMaxFromScaledPairs := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds.lean b/Nfp/Sound/Bounds.lean deleted file mode 100644 index 9495926..0000000 --- a/Nfp/Sound/Bounds.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Sound.Bounds.Basic -import Nfp.Sound.Bounds.MatrixNorm -import Nfp.Sound.Bounds.Gelu -import Nfp.Sound.Bounds.Exp -import Nfp.Sound.Bounds.Softmax -import Nfp.Sound.Bounds.Attention -import Nfp.Sound.Bounds.LayerNorm -import Nfp.Sound.Bounds.Portfolio -import Nfp.Sound.Bounds.Effort - -/-! -# Sound bounds in exact arithmetic - -This module is an umbrella import for the sound bound utilities. -Numeric strategy (Option A): avoid `sqrt` and any `Float`-trusted computation by using -row-sum induced norms (ℓ1 for row-vector convention) and submultiplicativity. --/ diff --git a/Nfp/Sound/Bounds/Attention.lean b/Nfp/Sound/Bounds/Attention.lean deleted file mode 100644 index 881ad3c..0000000 --- a/Nfp/Sound/Bounds/Attention.lean +++ /dev/null @@ -1,67 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Nat.Sqrt - -namespace Nfp.Sound - -/-! -# Attention pattern-term helpers --/ - -/-- Upper bound on `sqrt(n)` using `Nat.sqrt` (floor) plus one. -/ -def sqrtUpperNat (n : Nat) : Nat := Nat.sqrt n + 1 - -theorem sqrtUpperNat_def (n : Nat) : sqrtUpperNat n = Nat.sqrt n + 1 := rfl - -/-- Upper bound on `sqrt(n)` as a rational. -/ -def sqrtUpperRat (n : Nat) : Rat := (sqrtUpperNat n : Nat) - -theorem sqrtUpperRat_def (n : Nat) : sqrtUpperRat n = (sqrtUpperNat n : Nat) := rfl - -/-- Upper bound on `1 / sqrt(n)` using `Nat.sqrt` (floor). 
-/ -def invSqrtUpperBound (n : Nat) : Rat := - if n = 0 then 0 else (1 : Rat) / (Nat.sqrt n : Nat) - -theorem invSqrtUpperBound_def (n : Nat) : - invSqrtUpperBound n = if n = 0 then 0 else (1 : Rat) / (Nat.sqrt n : Nat) := rfl - -/-- Conservative bound on `max |LayerNorm(x)|` after affine (uses only `γ`, `β`, and `dim`). -/ -def layerNormOutputMaxAbsBound (dim : Nat) (maxAbsGamma maxAbsBeta : Rat) : Rat := - maxAbsGamma * sqrtUpperRat dim + maxAbsBeta - -theorem layerNormOutputMaxAbsBound_def (dim : Nat) (maxAbsGamma maxAbsBeta : Rat) : - layerNormOutputMaxAbsBound dim maxAbsGamma maxAbsBeta = - maxAbsGamma * sqrtUpperRat dim + maxAbsBeta := rfl - -/-- Score-gradient L1 bound for attention pattern terms. -/ -def attnScoreGradBound (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : Rat := - let scale := invSqrtUpperBound headDim - (seqLen : Rat) * scale * - ((2 : Rat) * (modelDim : Rat) * ln1OutMaxAbs * wqBound * wkBound) - -theorem attnScoreGradBound_def (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound : Rat) : - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound = - let scale := invSqrtUpperBound headDim - (seqLen : Rat) * scale * - ((2 : Rat) * (modelDim : Rat) * ln1OutMaxAbs * wqBound * wkBound) := rfl - -/-- Pattern-term coefficient bound from value and score-gradient bounds. 
-/ -def attnPatternCoeffBound (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound valueCoeff : Rat) : Rat := - let inputL1 := (modelDim : Rat) * ln1OutMaxAbs - (seqLen : Rat) * - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound * - (inputL1 * valueCoeff) - -theorem attnPatternCoeffBound_def (seqLen modelDim headDim : Nat) - (ln1OutMaxAbs wqBound wkBound valueCoeff : Rat) : - attnPatternCoeffBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound valueCoeff = - let inputL1 := (modelDim : Rat) * ln1OutMaxAbs - (seqLen : Rat) * - attnScoreGradBound seqLen modelDim headDim ln1OutMaxAbs wqBound wkBound * - (inputL1 * valueCoeff) := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Basic.lean b/Nfp/Sound/Bounds/Basic.lean deleted file mode 100644 index aeba476..0000000 --- a/Nfp/Sound/Bounds/Basic.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat - -namespace Nfp.Sound - -/-! -# Basic Rat helpers - -Small utilities used across the sound bounds modules. --/ - -/-- Exact absolute value on `Rat`. -/ -def ratAbs (x : Rat) : Rat := - if x < 0 then -x else x - -theorem ratAbs_def (x : Rat) : ratAbs x = if x < 0 then -x else x := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Effort.lean b/Nfp/Sound/Bounds/Effort.lean deleted file mode 100644 index abfd83f..0000000 --- a/Nfp/Sound/Bounds/Effort.lean +++ /dev/null @@ -1,11 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -namespace Nfp.Sound - -/-! -# Effort tiers (placeholder) - -This module is reserved for future exp-effort records and tier schedules. 
--/ - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Exp.lean b/Nfp/Sound/Bounds/Exp.lean deleted file mode 100644 index 45a6256..0000000 --- a/Nfp/Sound/Bounds/Exp.lean +++ /dev/null @@ -1,91 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Finset.Lattice.Fold -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Mathlib.Data.Nat.Factorial.Basic - -namespace Nfp.Sound - -open scoped BigOperators - -/-! -# Exp lower bounds (scaled Taylor + squaring) --/ - -/-- Power function on `Rat` for natural exponents. -/ -private def ratPow (x : Rat) : Nat → Rat - | 0 => 1 - | n + 1 => ratPow x n * x - -theorem ratPow_def (x : Rat) (n : Nat) : - ratPow x n = match n with | 0 => 1 | n + 1 => ratPow x n * x := by - cases n <;> rfl - -/-- Factorial as a rational. -/ -private def ratFactorial (n : Nat) : Rat := (Nat.factorial n : Nat) - -theorem ratFactorial_def (n : Nat) : ratFactorial n = (Nat.factorial n : Nat) := rfl - -/-- Taylor partial sum for `exp` (all terms are nonnegative when `x ≥ 0`). -/ -private def expTaylorLowerBound (x : Rat) (deg : Nat) : Rat := - Finset.sum (Finset.range (deg + 1)) fun k => ratPow x k / ratFactorial k - -theorem expTaylorLowerBound_def (x : Rat) (deg : Nat) : - expTaylorLowerBound x deg = - Finset.sum (Finset.range (deg + 1)) fun k => ratPow x k / ratFactorial k := rfl - -/-- Lower bound on `exp` via scaled Taylor partial sums and repeated squaring. 
-/ -def expLBScaledTaylor (x : Rat) (deg scalePow : Nat) : Rat := - if x < 0 then - 0 - else - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let t := expTaylorLowerBound z deg - ratPow t (Nat.pow 2 scalePow) - -theorem expLBScaledTaylor_def (x : Rat) (deg scalePow : Nat) : - expLBScaledTaylor x deg scalePow = - if x < 0 then - 0 - else - let scale : Rat := (Nat.pow 2 scalePow : Nat) - let z := x / scale - let t := expTaylorLowerBound z deg - ratPow t (Nat.pow 2 scalePow) := rfl - -/-- Default portfolio of `(scalePow, taylorDeg)` candidates for `expLB`. -/ -def expLBPortfolio : Array (Nat × Nat) := - #[(2, 4), (3, 6), (4, 8)] - -theorem expLBPortfolio_def : expLBPortfolio = #[(2, 4), (3, 6), (4, 8)] := rfl - -/-- Portfolio lower bound on `exp`, with a baseline `1 + x` candidate. -/ -def expLB (x : Rat) (effort : Nat) : Rat := - let base : Rat := max 0 ((1 : Rat) + x) - let limit := min effort expLBPortfolio.size - Id.run do - let mut best := base - for i in [:limit] do - let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - best := max best cand - return best - -theorem expLB_def (x : Rat) (effort : Nat) : - expLB x effort = - let base : Rat := max 0 ((1 : Rat) + x) - let limit := min effort expLBPortfolio.size - Id.run do - let mut best := base - for i in [:limit] do - let cand := expLBScaledTaylor x (expLBPortfolio[i]!).2 (expLBPortfolio[i]!).1 - best := max best cand - return best := rfl - -/-- Default effort used for margin-derived softmax bounds. -/ -def defaultSoftmaxExpEffort : Nat := 1 - -theorem defaultSoftmaxExpEffort_def : defaultSoftmaxExpEffort = 1 := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Gelu.lean b/Nfp/Sound/Bounds/Gelu.lean deleted file mode 100644 index 8bf6118..0000000 --- a/Nfp/Sound/Bounds/Gelu.lean +++ /dev/null @@ -1,19 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Sound.Activation - -namespace Nfp.Sound - -/-! 
-# GeLU derivative bounds --/ - -/-- Global conservative GeLU derivative bound (independent of interval). -/ -def geluDerivBoundGlobal : GeluDerivTarget → Rat - | .tanh => 2 - | .exact => 2 - -theorem geluDerivBoundGlobal_def (t : GeluDerivTarget) : - geluDerivBoundGlobal t = match t with | .tanh => 2 | .exact => 2 := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/LayerNorm.lean b/Nfp/Sound/Bounds/LayerNorm.lean deleted file mode 100644 index c24fbe3..0000000 --- a/Nfp/Sound/Bounds/LayerNorm.lean +++ /dev/null @@ -1,153 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Algebra.Order.Floor.Semiring -import Mathlib.Data.Nat.Sqrt -import Mathlib.Data.Rat.Floor - -namespace Nfp.Sound - -/-! -# LayerNorm operator-norm bounds --/ - -/-! ### Local (input-dependent) LayerNorm bounds - -We want a sound upper bound on `max |γ| / sqrt(var + eps)` using exact `Rat` arithmetic, -*without* importing real `sqrt`. - -Given a proven lower bound `L ≤ var`, we have: -`1/sqrt(var+eps) ≤ 1/sqrt(L+eps)`. - -To avoid `sqrt`, we compute a dyadic rational `s = k/2^p` such that -`s^2 ≤ max (L+eps) 0`. Then `1/s ≥ 1/sqrt(L+eps)` is a valid **upper** bound on `1/sqrt(L+eps)`. --/ - -private def pow2 (p : Nat) : Nat := - Nat.pow 2 p - -theorem pow2_def (p : Nat) : pow2 p = Nat.pow 2 p := rfl - -private def sqNat (n : Nat) : Nat := n * n - -theorem sqNat_def (n : Nat) : sqNat n = n * n := rfl - -/-- Certificate that `k` is the dyadic floor of `sqrt (max x 0)` at precision `precBits`. -/ -private structure SqrtLowerDyadicCert (x : Rat) (precBits : Nat) where - k : Nat - lower : - ((sqNat k : Nat) : Rat) ≤ max x 0 * (sqNat (pow2 precBits) : Nat) - upper : - max x 0 * (sqNat (pow2 precBits) : Nat) < ((sqNat (k + 1) : Nat) : Rat) - -/-- The dyadic value `k/2^precBits` encoded by a `SqrtLowerDyadicCert`. 
-/ -private def SqrtLowerDyadicCert.rat {x : Rat} {precBits : Nat} - (c : SqrtLowerDyadicCert x precBits) : Rat := - Rat.normalize (Int.ofNat c.k) (pow2 precBits) (den_nz := by simp [pow2]) - -/-- Compute a dyadic floor certificate for `sqrt (max x 0)` using `Nat.sqrt` on the floor. -/ -private def sqrtLowerDyadic (x : Rat) (precBits : Nat) : SqrtLowerDyadicCert x precBits := by - let scale : Nat := pow2 precBits - let scaleSq : Nat := sqNat scale - let y : Rat := max x 0 * (scaleSq : Rat) - let m : Nat := ⌊y⌋₊ - let k : Nat := Nat.sqrt m - refine ⟨k, ?lower, ?upper⟩ - · have hy_nonneg : 0 ≤ y := by - have hmax : 0 ≤ max x 0 := le_max_right _ _ - have hscale : 0 ≤ (scaleSq : Rat) := by - exact_mod_cast (Nat.zero_le scaleSq) - exact mul_nonneg hmax hscale - have hm_le : ((m : Nat) : Rat) ≤ y := by - simpa [m] using (Nat.floor_le (a := y) hy_nonneg) - have hk_le_m : sqNat k ≤ m := by - simpa [sqNat, k] using (Nat.sqrt_le m) - have hk_le_m_rat : ((sqNat k : Nat) : Rat) ≤ (m : Rat) := by - exact_mod_cast hk_le_m - exact le_trans hk_le_m_rat hm_le - · have hy_lt : y < (m : Rat) + 1 := by - simpa [m] using (Nat.lt_floor_add_one (a := y)) - have hm_lt_nat : m < sqNat (k + 1) := by - simpa [sqNat, k, Nat.succ_eq_add_one] using (Nat.lt_succ_sqrt m) - have hm_succ_le_nat : m + 1 ≤ sqNat (k + 1) := Nat.succ_le_of_lt hm_lt_nat - have hm_succ_le_rat : (m + 1 : Rat) ≤ ((sqNat (k + 1) : Nat) : Rat) := by - exact_mod_cast hm_succ_le_nat - exact lt_of_lt_of_le hy_lt hm_succ_le_rat - -/-- Dyadic lower bound on `sqrt (max x 0)` as a `Rat`. -/ -private def sqrtLowerDyadicRat (x : Rat) (precBits : Nat) : Rat := - (sqrtLowerDyadic x precBits).rat - -/-- Conservative bound for the operator norm of a row-wise LayerNorm Jacobian. - -In exact real arithmetic one can show `‖J‖₂ ≤ max |γ| / σ` with `σ = sqrt(var + eps)`. 
-For sound certification without real `sqrt`, we compute a dyadic lower bound `s ≤ sqrt(eps)` -and use `maxAbsGamma / s`, which is a **valid upper bound** on `maxAbsGamma / sqrt(eps)`. - -When `eps ≤ 1`, we may also use `maxAbsGamma / eps`, and take the minimum of the two -sound bounds for a tighter result. For `eps > 1`, `maxAbsGamma / eps` is **not** sound, -so we only use the dyadic bound. - -For tighter **local** certification (weights + a bounded input region), use -`layerNormOpBoundLocal`, which replaces `eps` with a proven variance lower bound. --/ -def layerNormOpBoundConservative (maxAbsGamma eps : Rat) (sqrtPrecBits : Nat) : Rat := - if eps ≤ 0 then - 0 - else - let raw := maxAbsGamma / eps - let s := sqrtLowerDyadicRat eps sqrtPrecBits - if s ≤ 0 then - if eps ≤ 1 then raw else maxAbsGamma - else - let sBound := maxAbsGamma / s - if eps ≤ 1 then min raw sBound else sBound - -theorem layerNormOpBoundConservative_def (maxAbsGamma eps : Rat) (sqrtPrecBits : Nat) : - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits = - if eps ≤ 0 then - 0 - else - let raw := maxAbsGamma / eps - let s := sqrtLowerDyadicRat eps sqrtPrecBits - if s ≤ 0 then - if eps ≤ 1 then raw else maxAbsGamma - else - let sBound := maxAbsGamma / s - if eps ≤ 1 then min raw sBound else sBound := rfl - -/-- Local upper bound on the operator norm of a row-wise LayerNorm Jacobian. - -If `varianceLowerBound` is a proven lower bound on the per-row variance, then: -`‖J‖₂ ≤ maxAbsGamma / sqrt(varianceLowerBound + eps)`. - -We compute an upper bound using a dyadic lower bound on `sqrt(varianceLowerBound + eps)`. -If the dyadic lower bound is zero (too small / insufficient precision), we fall back to the -conservative bound `maxAbsGamma / eps`. 
--/ -def layerNormOpBoundLocal (maxAbsGamma varianceLowerBound eps : Rat) - (sqrtPrecBits : Nat) : Rat := - let denom := varianceLowerBound + eps - if denom ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - let s := sqrtLowerDyadicRat denom sqrtPrecBits - if s ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - maxAbsGamma / s - -theorem layerNormOpBoundLocal_def (maxAbsGamma varianceLowerBound eps : Rat) - (sqrtPrecBits : Nat) : - layerNormOpBoundLocal maxAbsGamma varianceLowerBound eps sqrtPrecBits = - let denom := varianceLowerBound + eps - if denom ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - let s := sqrtLowerDyadicRat denom sqrtPrecBits - if s ≤ 0 then - layerNormOpBoundConservative maxAbsGamma eps sqrtPrecBits - else - maxAbsGamma / s := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/MatrixNorm.lean b/Nfp/Sound/Bounds/MatrixNorm.lean deleted file mode 100644 index 4daa09f..0000000 --- a/Nfp/Sound/Bounds/MatrixNorm.lean +++ /dev/null @@ -1,130 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Mathlib.Data.Finset.Lattice.Fold -import Mathlib.Data.Fintype.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Nfp.Sound.Bounds.Basic - -namespace Nfp.Sound - -open scoped BigOperators - -/-! -# Matrix norm helpers (Rat, row-sum) - -Utilities for row-sum operator norm bounds on finite matrices. --/ - -/-- Streaming accumulator for `maxᵢ ∑ⱼ |aᵢⱼ|` over row-major entries. -/ -structure RowSumAcc where - rows : Nat - cols : Nat - colIdx : Nat := 0 - curRowSum : Rat := 0 - maxRowSum : Rat := 0 - -namespace RowSumAcc - -/-- Feed one entry from a row-major stream. 
-/ -def feed (acc : RowSumAcc) (x : Rat) : RowSumAcc := - let cur := acc.curRowSum + ratAbs x - let colIdx' := acc.colIdx + 1 - if acc.cols = 0 then - { acc with colIdx := colIdx', curRowSum := cur, maxRowSum := max acc.maxRowSum cur } - else if colIdx' = acc.cols then - { acc with - colIdx := 0 - curRowSum := 0 - maxRowSum := max acc.maxRowSum cur } - else - { acc with colIdx := colIdx', curRowSum := cur } - -theorem feed_def (acc : RowSumAcc) (x : Rat) : - RowSumAcc.feed acc x = - let cur := acc.curRowSum + ratAbs x - let colIdx' := acc.colIdx + 1 - if acc.cols = 0 then - { acc with colIdx := colIdx', curRowSum := cur, maxRowSum := max acc.maxRowSum cur } - else if colIdx' = acc.cols then - { acc with colIdx := 0, curRowSum := 0, maxRowSum := max acc.maxRowSum cur } - else - { acc with colIdx := colIdx', curRowSum := cur } := rfl - -/-- Finalize to a bound. (If the last row is partial, we still account for it.) -/ -def finish (acc : RowSumAcc) : Rat := - max acc.maxRowSum acc.curRowSum - -theorem finish_def (acc : RowSumAcc) : RowSumAcc.finish acc = max acc.maxRowSum acc.curRowSum := rfl - -end RowSumAcc - -/-- A rational-weighted matrix on finite types. -/ -structure RatMatrix (S T : Type*) [Fintype S] [Fintype T] where - w : S → T → Rat - -namespace RatMatrix - -variable {S T : Type*} [Fintype S] [Fintype T] - -/-- Row sum of absolute values in `Rat`. -/ -def rowAbsSum (M : RatMatrix S T) [DecidableEq T] (i : S) : Rat := - ∑ j, ratAbs (M.w i j) - -theorem rowAbsSum_def (M : RatMatrix S T) [DecidableEq T] (i : S) : - RatMatrix.rowAbsSum M i = ∑ j, ratAbs (M.w i j) := rfl - -/-- Row-sum operator norm bound in `Rat` (induced ℓ1 for row-vectors). 
-/ -def operatorNormBound (M : RatMatrix S T) [DecidableEq S] [DecidableEq T] [Nonempty S] : Rat := - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - rowAbsSum M i - -theorem operatorNormBound_def (M : RatMatrix S T) [DecidableEq S] [DecidableEq T] [Nonempty S] : - RatMatrix.operatorNormBound M = - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) fun i => - rowAbsSum M i := rfl - -/-- Build a `RatMatrix` from row-major data with missing entries treated as 0. -/ -def ofRowMajor (rows cols : Nat) (data : Array Rat) : - RatMatrix (Fin rows) (Fin cols) := - ⟨fun i j => - let idx := i.val * cols + j.val - if h : idx < data.size then data[idx] else 0⟩ - -theorem ofRowMajor_def (rows cols : Nat) (data : Array Rat) : - RatMatrix.ofRowMajor rows cols data = - ⟨fun i j => - let idx := i.val * cols + j.val - if h : idx < data.size then data[idx] else 0⟩ := rfl - -end RatMatrix - -/-- Compute the row-sum norm `maxᵢ ∑ⱼ |M[i,j]|` from a row-major array. - -If the provided data has fewer than `rows*cols` entries, missing entries are treated as 0. -Extra entries are ignored. --/ -def matrixNormInfOfRowMajor (rows cols : Nat) (data : Array Rat) : Rat := - if h : rows = 0 then - 0 - else - let _ : Nonempty (Fin rows) := ⟨⟨0, Nat.pos_of_ne_zero h⟩⟩ - RatMatrix.operatorNormBound (RatMatrix.ofRowMajor rows cols data) - -theorem matrixNormInfOfRowMajor_def (rows cols : Nat) (data : Array Rat) : - matrixNormInfOfRowMajor rows cols data = - if h : rows = 0 then - 0 - else - let _ : Nonempty (Fin rows) := ⟨⟨0, Nat.pos_of_ne_zero h⟩⟩ - RatMatrix.operatorNormBound (RatMatrix.ofRowMajor rows cols data) := rfl - -/-- Row-sum operator norm bound for a product. - -`‖A·B‖∞ ≤ ‖A‖∞ · ‖B‖∞`. 
--/ -def normInfMulBound (a b : Rat) : Rat := a * b - -theorem normInfMulBound_def (a b : Rat) : normInfMulBound a b = a * b := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Portfolio.lean b/Nfp/Sound/Bounds/Portfolio.lean deleted file mode 100644 index 0fe4279..0000000 --- a/Nfp/Sound/Bounds/Portfolio.lean +++ /dev/null @@ -1,27 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat - -namespace Nfp.Sound - -/-! -# Portfolio bounds - -Combinators for selecting the best bound among sound candidates. --/ - -/-- Best upper bound among candidates (never worse than `base`). -/ -def ubBest (base : Rat) (cands : Array Rat) : Rat := - cands.foldl min base - -theorem ubBest_def (base : Rat) (cands : Array Rat) : - ubBest base cands = cands.foldl min base := rfl - -/-- Best lower bound among candidates (never worse than `base`). -/ -def lbBest (base : Rat) (cands : Array Rat) : Rat := - cands.foldl max base - -theorem lbBest_def (base : Rat) (cands : Array Rat) : - lbBest base cands = cands.foldl max base := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bounds/Softmax.lean b/Nfp/Sound/Bounds/Softmax.lean deleted file mode 100644 index d4e8db9..0000000 --- a/Nfp/Sound/Bounds/Softmax.lean +++ /dev/null @@ -1,156 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Algebra.Order.Ring.Unbundled.Rat -import Nfp.Sound.Bounds.Exp -import Nfp.Sound.Bounds.Portfolio - -namespace Nfp.Sound - -/-! -# Softmax Jacobian bounds --/ - -/-- Worst-case bound on the row-sum operator norm of a softmax Jacobian row. - -For a probability row `p`, the softmax Jacobian is `J = diag(p) - p pᵀ`. -For row `i`, the absolute row-sum is: - -`∑ⱼ |Jᵢⱼ| = pᵢ(1-pᵢ) + ∑_{j≠i} pᵢ pⱼ = 2 pᵢ (1-pᵢ) ≤ 1/2`. - -This bound is universal (independent of sequence length). 
--/ -def softmaxJacobianNormInfWorst : Rat := (1 : Rat) / 2 - -theorem softmaxJacobianNormInfWorst_def : softmaxJacobianNormInfWorst = (1 : Rat) / 2 := rfl - -/-- Clamp a rational to the unit interval `[0,1]`. -/ -private def clamp01 (x : Rat) : Rat := - max 0 (min x 1) - -theorem clamp01_def (x : Rat) : clamp01 x = max 0 (min x 1) := rfl - -/-- Local upper bound on the row-sum softmax Jacobian norm given `p ∈ [pLo, pHi]`. -/ -def softmaxJacobianNormInfBound (pLo pHi : Rat) : Rat := - let lo0 := min pLo pHi - let hi0 := max pLo pHi - let lo := clamp01 lo0 - let hi := clamp01 hi0 - if hi < lo then - 0 - else - let half : Rat := (1 : Rat) / 2 - let f : Rat → Rat := fun p => (2 : Rat) * p * (1 - p) - if lo ≤ half ∧ half ≤ hi then - half - else - max (f lo) (f hi) - -theorem softmaxJacobianNormInfBound_def (pLo pHi : Rat) : - softmaxJacobianNormInfBound pLo pHi = - let lo0 := min pLo pHi - let hi0 := max pLo pHi - let lo := clamp01 lo0 - let hi := clamp01 hi0 - if hi < lo then - 0 - else - let half : Rat := (1 : Rat) / 2 - let f : Rat → Rat := fun p => (2 : Rat) * p * (1 - p) - if lo ≤ half ∧ half ≤ hi then - half - else - max (f lo) (f hi) := rfl - -/-! ### Margin-derived softmax bounds -/ - -/-- Lower bound on the maximum softmax probability from a logit margin. - -Uses a portfolio `expLB` to lower bound `exp(m)` and maps it to -`p_max ≥ exp(m) / (exp(m) + (n-1))` for `m > 0`, with `n = seqLen`. 
-/ -def softmaxMaxProbLowerBound (seqLen : Nat) (margin : Rat) (expEffort : Nat) : Rat := - if seqLen = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let e := expLB margin expEffort - e / (e + (nRat - 1)) - else - 0 - -theorem softmaxMaxProbLowerBound_def (seqLen : Nat) (margin : Rat) (expEffort : Nat) : - softmaxMaxProbLowerBound seqLen margin expEffort = - if seqLen = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let e := expLB margin expEffort - e / (e + (nRat - 1)) - else - 0 := rfl - -/-- Lower bound on total target softmax weight from a logit margin. - -If at least `targetCount` logits exceed the rest by `margin`, then the total -target weight is at least `t*exp(m)/(t*exp(m)+(n-t))`. --/ -def softmaxTargetWeightLowerBound (seqLen targetCount : Nat) (margin : Rat) - (expEffort : Nat) : Rat := - if seqLen = 0 || targetCount = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let tRat : Rat := (targetCount : Nat) - let base := tRat / nRat - let e := expLB margin expEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - lbBest base #[cand] - else - 0 - -theorem softmaxTargetWeightLowerBound_def (seqLen targetCount : Nat) (margin : Rat) - (expEffort : Nat) : - softmaxTargetWeightLowerBound seqLen targetCount margin expEffort = - if seqLen = 0 || targetCount = 0 then - 0 - else if margin > 0 then - let nRat : Rat := (seqLen : Nat) - let tRat : Rat := (targetCount : Nat) - let base := tRat / nRat - let e := expLB margin expEffort - let cand := (tRat * e) / (tRat * e + (nRat - tRat)) - lbBest base #[cand] - else - 0 := rfl - -/-- Upper bound on the row-sum softmax Jacobian norm from a max-probability lower bound. - -If the maximum probability is at least `pLo` and `pLo > 1/2`, then every row -satisfies `2 p (1-p) ≤ 2 pLo (1-pLo)`; otherwise the universal `1/2` bound applies. 
-/ -def softmaxJacobianNormInfBoundFromMaxProb (pLo : Rat) : Rat := - let half : Rat := (1 : Rat) / 2 - let p := clamp01 pLo - if p > half then - (2 : Rat) * p * (1 - p) - else - half - -theorem softmaxJacobianNormInfBoundFromMaxProb_def (pLo : Rat) : - softmaxJacobianNormInfBoundFromMaxProb pLo = - let half : Rat := (1 : Rat) / 2 - let p := clamp01 pLo - if p > half then - (2 : Rat) * p * (1 - p) - else - half := rfl - -/-- Upper bound on the row-sum softmax Jacobian norm from a logit margin. -/ -def softmaxJacobianNormInfBoundFromMargin (seqLen : Nat) (margin : Rat) (expEffort : Nat) : Rat := - softmaxJacobianNormInfBoundFromMaxProb (softmaxMaxProbLowerBound seqLen margin expEffort) - -theorem softmaxJacobianNormInfBoundFromMargin_def (seqLen : Nat) (margin : Rat) - (expEffort : Nat) : - softmaxJacobianNormInfBoundFromMargin seqLen margin expEffort = - softmaxJacobianNormInfBoundFromMaxProb - (softmaxMaxProbLowerBound seqLen margin expEffort) := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Bridge.lean b/Nfp/Sound/Bridge.lean deleted file mode 100644 index 461cc0f..0000000 --- a/Nfp/Sound/Bridge.lean +++ /dev/null @@ -1,758 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.Rat.BigOperators -import Mathlib.Data.Rat.Cast.Order -import Mathlib.Data.Finset.Lattice.Fold -import Nfp.Linearization -import Nfp.SignedMixer -import Nfp.Sound.Bounds -import Nfp.Sound.Cert - -namespace Nfp.Sound - -open scoped BigOperators - -namespace RatMatrix - -variable {S T : Type*} [Fintype S] [Fintype T] [DecidableEq S] [DecidableEq T] - -/-- Cast a rational matrix to a real SignedMixer. 
-/ -noncomputable def toSignedMixer (M : RatMatrix S T) : SignedMixer S T := - ⟨fun i j => (M.w i j : ℝ)⟩ - -omit [DecidableEq S] [DecidableEq T] in -theorem toSignedMixer_spec (M : RatMatrix S T) : toSignedMixer M = toSignedMixer M := rfl - -lemma ratAbs_eq_abs (x : Rat) : ratAbs x = |x| := by - by_cases h : x < 0 - · simp [ratAbs, h, abs_of_neg h] - · have h' : 0 ≤ x := le_of_not_gt h - simp [ratAbs, h, abs_of_nonneg h'] - -/-- Casting the rational bound matches the real row-sum operator norm bound. -/ -theorem operatorNormBound_cast (M : RatMatrix S T) [Nonempty S] : - (operatorNormBound M : ℝ) = SignedMixer.operatorNormBound (M.toSignedMixer) := by - classical - have hsup_cast : - (operatorNormBound M : ℝ) = - Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun i => ((rowAbsSum M i : Rat) : ℝ)) := by - apply le_antisymm - · rcases Finset.exists_mem_eq_sup' - (s := Finset.univ) (H := Finset.univ_nonempty (α := S)) - (f := fun i => rowAbsSum M i) with ⟨i, hi, hsup⟩ - have hcast : (operatorNormBound M : ℝ) = (rowAbsSum M i : ℝ) := by - simp [operatorNormBound, hsup] - calc - (operatorNormBound M : ℝ) = (rowAbsSum M i : ℝ) := hcast - _ ≤ Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun j => ((rowAbsSum M j : Rat) : ℝ)) := by - exact Finset.le_sup' (s := Finset.univ) - (f := fun j => ((rowAbsSum M j : Rat) : ℝ)) hi - · refine (Finset.sup'_le_iff (s := Finset.univ) - (H := Finset.univ_nonempty (α := S)) - (f := fun i => ((rowAbsSum M i : Rat) : ℝ)) - (a := (operatorNormBound M : ℝ))).2 ?_ - intro i hi - have hle : rowAbsSum M i ≤ operatorNormBound M := by - exact Finset.le_sup' (s := Finset.univ) (f := fun j => rowAbsSum M j) hi - exact (Rat.cast_le (K := ℝ)).2 hle - calc - (operatorNormBound M : ℝ) - = Finset.sup' Finset.univ (Finset.univ_nonempty (α := S)) - (fun i => ((rowAbsSum M i : Rat) : ℝ)) := hsup_cast - _ = SignedMixer.operatorNormBound (M.toSignedMixer) := by - simp [SignedMixer.operatorNormBound, rowAbsSum, toSignedMixer, - 
ratAbs_eq_abs, Rat.cast_sum, Rat.cast_abs] - -/-- Casted row-major bound agrees with the `SignedMixer` operator norm bound. -/ -theorem matrixNormInfOfRowMajor_cast (rows cols : Nat) (data : Array Rat) - [Nonempty (Fin rows)] (h : rows ≠ 0) : - (matrixNormInfOfRowMajor rows cols data : ℝ) = - SignedMixer.operatorNormBound (RatMatrix.ofRowMajor rows cols data).toSignedMixer := by - classical - simpa [matrixNormInfOfRowMajor, h] using - (operatorNormBound_cast (M := RatMatrix.ofRowMajor rows cols data)) - -end RatMatrix - -/-! ## Certificate-to-Jacobian bridge -/ - -/-- Assumptions needed to link certificate fields to component operator-norm bounds. -/ -structure LayerComponentNormAssumptions - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) : Prop where - ln1Bound : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ) - ln1Bound_nonneg : 0 ≤ (l.ln1Bound : ℝ) - attnValueBound : - SignedMixer.operatorNormBound (valueTerm (D.layers i)) ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) - attnPatternBound : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) - ln2Bound : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ) - ln2Bound_nonneg : 0 ≤ (l.ln2Bound : ℝ) - mlpWinBound : - SignedMixer.operatorNormBound (D.mlpFactors i).win ≤ (l.mlpWinBound : ℝ) - mlpWinBound_nonneg : 0 ≤ (l.mlpWinBound : ℝ) - mlpWoutBound : - SignedMixer.operatorNormBound (D.mlpFactors i).wout ≤ (l.mlpWoutBound : ℝ) - mlpWoutBound_nonneg : 0 ≤ (l.mlpWoutBound : ℝ) - mlpDerivBound : - ∀ j, |(D.mlpFactors i).deriv j| ≤ (l.mlpActDerivBound : ℝ) - -/-- Pattern-term bound from the explicit formula and row-sum gradient/value bounds. 
-/ -theorem attn_pattern_bound_of_explicit - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (G V : ℝ) - (hGrad : ∀ (pos : n) (d_in : d) (q : n), - ∑ k, |attentionGradient (D.layers i) q k pos d_in| ≤ G) - (hValue : SignedMixer.operatorNormBound (valueOutputMixer (D.layers i)) ≤ V) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (Fintype.card n : ℝ) * G * V := by - simpa using - (patternTerm_operatorNormBound_le_of_eq_explicit - (L := D.layers i) (G := G) (V := V) hGrad hValue hEq) - -/-- Pattern-term bound from certificate validity and attention-state assumptions. -/ -theorem attn_pattern_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers 
i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (patternTerm (D.layers i)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - classical - let L := D.layers i - let B : ℝ := (modelDimNat : ℝ) * (l.ln1OutMaxAbsBound : ℝ) - let wq : ℝ := (l.wqOpBoundMax : ℝ) - let wk : ℝ := (l.wkOpBoundMax : ℝ) - let J : ℝ := (l.softmaxJacobianNormInfUpperBound : ℝ) - let S : ℝ := - (Fintype.card n : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) - let V : ℝ := B * (l.attnValueCoeff : ℝ) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLenNat : ℝ) := by - exact_mod_cast hSeqLen.symm - have hModelDim' : (Fintype.card d : ℝ) = (modelDimNat : ℝ) := by - exact_mod_cast hModelDim.symm - have hB_nonneg : 0 ≤ B := by - have hln1_nonneg : 0 ≤ (l.ln1OutMaxAbsBound : ℝ) := by - rcases (inferInstance : Nonempty n) with ⟨pos⟩ - rcases (inferInstance : Nonempty d) with ⟨d_in⟩ - have h := hInputBound pos d_in - exact le_trans (abs_nonneg _) h - have hdim_nonneg : 0 ≤ (modelDimNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - exact mul_nonneg hdim_nonneg hln1_nonneg - have hWQ_nonneg : 0 ≤ wq := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_Q)) hWQ - have hWK_nonneg : 0 ≤ wk := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg (M := L.layer.W_K)) hWK - have hInput : - ∀ pos, ∑ d', |L.state.input pos d'| ≤ B := by - intro pos - have hsum : - ∑ d', |L.state.input pos d'| ≤ - ∑ d' : d, (l.ln1OutMaxAbsBound : ℝ) := by - refine Finset.sum_le_sum ?_ - intro d' _hd - 
exact hInputBound pos d' - have hconst : - (∑ d' : d, (l.ln1OutMaxAbsBound : ℝ)) = B := by - simp [B, hModelDim'] - exact le_trans hsum (by simp [hconst]) - have hScore : - ∀ i' d_in q, ∑ k, |scoreGradient L q k i' d_in| ≤ S := by - intro i' d_in q - have h := - scoreGradient_sum_le_of_input (L := L) (q := q) (i := i') (d_in := d_in) - (B := B) (wq := wq) (wk := wk) - hInput (hKeys := hKeys) (hQueries := hQueries) hWQ hWK - simpa [S, wq, wk] using h - have hValue : - SignedMixer.operatorNormBound (valueOutputMixer L) ≤ V := by - have h := - valueOutputMixer_operatorNormBound_le_of_input (L := L) (B := B) - (V := (l.attnValueCoeff : ℝ)) hInput (hValues := hValues) hVO - simpa [V] using h - have hPattern := - patternTerm_operatorNormBound_le_of_softmax (L := L) (J := J) (S := S) (V := V) - (hConsistent := hConsistent) (hSoftmax := hSoftmax) - (hScore := hScore) (hValue := hValue) hEq - have hJ_nonneg : 0 ≤ J := by - rcases (inferInstance : Nonempty n) with ⟨q⟩ - exact le_trans - (SignedMixer.operatorNormBound_nonneg - (M := softmaxJacobian (L.state.scores q))) (hSoftmax q) - have hS_nonneg : 0 ≤ 2 * B * wq * wk := by - have h2 : 0 ≤ (2 : ℝ) := by norm_num - exact mul_nonneg (mul_nonneg (mul_nonneg h2 hB_nonneg) hWQ_nonneg) hWK_nonneg - let S_bound : ℝ := - (seqLenNat : ℝ) * ((invSqrtUpperBound headDimNat : ℝ) * (2 * B * wq * wk)) - have hS_le : S ≤ S_bound := by - have hSeq_nonneg : 0 ≤ (seqLenNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - calc - S = (seqLenNat : ℝ) * - ((1 / Real.sqrt (modelDim d)) * (2 * B * wq * wk)) := by - simp [S, hSeqLen'] - _ ≤ (seqLenNat : ℝ) * - ((invSqrtUpperBound headDimNat : ℝ) * (2 * B * wq * wk)) := by - exact mul_le_mul_of_nonneg_left - (mul_le_mul_of_nonneg_right hScale hS_nonneg) hSeq_nonneg - _ = S_bound := rfl - have hValue_nonneg : 0 ≤ V := by - have hVal_nonneg : 0 ≤ (l.attnValueCoeff : ℝ) := by - exact le_trans - (SignedMixer.operatorNormBound_nonneg - (M := L.layer.W_V.comp L.layer.W_O)) hVO - exact mul_nonneg hB_nonneg 
hVal_nonneg - have hSV_le : - (Fintype.card n : ℝ) * S * V ≤ (seqLenNat : ℝ) * S_bound * V := by - have hSeq_nonneg : 0 ≤ (seqLenNat : ℝ) := by - exact_mod_cast (Nat.zero_le _) - have hS_le' : (seqLenNat : ℝ) * S ≤ (seqLenNat : ℝ) * S_bound := by - exact mul_le_mul_of_nonneg_left hS_le hSeq_nonneg - calc - (Fintype.card n : ℝ) * S * V = (seqLenNat : ℝ) * S * V := by - simp [hSeqLen'] - _ ≤ (seqLenNat : ℝ) * S_bound * V := by - exact mul_le_mul_of_nonneg_right hS_le' hValue_nonneg - have hPat : - (l.attnPatternCoeff : ℝ) = - (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := by - rcases hValid with ⟨_hln1, _hln2, _hln1Out, _hsoftmax, hpat, _hattn, _hmlpCoeff, _hmlp, _hC⟩ - exact congrArg (fun x : Rat => (x : ℝ)) hpat - have hCoeff_eq : - (seqLenNat : ℝ) * S_bound * V = - (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := by - simp [S_bound, V, B, wq, wk, attnPatternCoeffBound_def, attnScoreGradBound_def, - Rat.cast_mul, Rat.cast_natCast, Rat.cast_ofNat, mul_assoc, mul_left_comm, - mul_comm, -mul_eq_mul_left_iff, -mul_eq_mul_right_iff] - have hCoeff : - (Fintype.card n : ℝ) * S * V ≤ (l.attnPatternCoeff : ℝ) := by - calc - (Fintype.card n : ℝ) * S * V ≤ (seqLenNat : ℝ) * S_bound * V := hSV_le - _ = (attnPatternCoeffBound seqLenNat modelDimNat headDimNat - l.ln1OutMaxAbsBound l.wqOpBoundMax l.wkOpBoundMax l.attnValueCoeff : ℝ) := hCoeff_eq - _ = (l.attnPatternCoeff : ℝ) := by simp [hPat] - have hCoeffMul' : - J * ((Fintype.card n : ℝ) * S * V) ≤ J * (l.attnPatternCoeff : ℝ) := - mul_le_mul_of_nonneg_left hCoeff hJ_nonneg - have hCoeffMul : - (Fintype.card n : ℝ) * J * S * V ≤ J * (l.attnPatternCoeff : ℝ) := by - have hRearrange : - (Fintype.card n : ℝ) * J * S * V = J * ((Fintype.card n : ℝ) * S * V) := by - ring - simpa [hRearrange] using hCoeffMul' - exact le_trans hPattern hCoeffMul - -/-- Full attention 
Jacobian bound from certificate validity and attention-state assumptions. -/ -theorem attn_fullJacobian_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLenNat : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - classical - let L := D.layers i - have 
hPattern := - attn_pattern_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLenNat := seqLenNat) (modelDimNat := modelDimNat) (headDimNat := headDimNat) - hValid hSeqLen hModelDim hScale hInputBound hKeys hQueries hValues hWQ hWK hVO - hConsistent hSoftmax hEq - have hValueBase : - SignedMixer.operatorNormBound (valueTerm L) ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) := by - simpa using - (valueTerm_operatorNormBound_le_card (L := L) (B := (l.attnValueCoeff : ℝ)) - (hConsistent := hConsistent) (hVO := hVO)) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLenNat : ℝ) := by - exact_mod_cast hSeqLen.symm - have hValue : - SignedMixer.operatorNormBound (valueTerm L) ≤ - (seqLenNat : ℝ) * (l.attnValueCoeff : ℝ) := by - simpa [hSeqLen'] using hValueBase - simpa using - (attention_fullJacobian_bound_of_terms (L := L) (hValue := hValue) (hPattern := hPattern)) - -/-- MLP coefficient bound from certificate validity and component bounds. -/ -theorem mlp_coeff_bound_of_valid - {n d : Type*} [Fintype n] [Fintype d] [Nonempty n] [Nonempty d] - (F : MLPFactorization (n := n) (d := d)) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hWin : SignedMixer.operatorNormBound F.win ≤ (l.mlpWinBound : ℝ)) - (hWin_nonneg : 0 ≤ (l.mlpWinBound : ℝ)) - (hWout : SignedMixer.operatorNormBound F.wout ≤ (l.mlpWoutBound : ℝ)) : - SignedMixer.operatorNormBound F.win * - SignedMixer.operatorNormBound F.wout ≤ (l.mlpCoeff : ℝ) := by - have hWout_nonneg : - 0 ≤ SignedMixer.operatorNormBound F.wout := - SignedMixer.operatorNormBound_nonneg (M := F.wout) - have hMul1 : - SignedMixer.operatorNormBound F.win * SignedMixer.operatorNormBound F.wout ≤ - (l.mlpWinBound : ℝ) * SignedMixer.operatorNormBound F.wout := by - exact mul_le_mul_of_nonneg_right hWin hWout_nonneg - have hMul2 : - (l.mlpWinBound : ℝ) * SignedMixer.operatorNormBound F.wout ≤ - 
(l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - exact mul_le_mul_of_nonneg_left hWout hWin_nonneg - have hMul : - SignedMixer.operatorNormBound F.win * SignedMixer.operatorNormBound F.wout ≤ - (l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - exact le_trans hMul1 hMul2 - have hCoeff := LayerAmplificationCert.mlpCoeff_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [hCoeff] using hMul - -/-- MLP operator-norm bound from a factored Jacobian plus coefficient bounds. -/ -theorem mlp_bound_of_factorization - {n d : Type*} [Fintype n] [Fintype d] [Nonempty n] [Nonempty d] - (F : MLPFactorization (n := n) (d := d)) - (l : LayerAmplificationCert) - (hDeriv : ∀ j, |F.deriv j| ≤ (l.mlpActDerivBound : ℝ)) - (hCoeff : - SignedMixer.operatorNormBound F.win * - SignedMixer.operatorNormBound F.wout ≤ (l.mlpCoeff : ℝ)) : - SignedMixer.operatorNormBound F.jacobian ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - classical - let a := SignedMixer.operatorNormBound F.win - let b := SignedMixer.operatorNormBound F.wout - have hBound : - SignedMixer.operatorNormBound F.jacobian ≤ - a * (l.mlpActDerivBound : ℝ) * b := by - simpa using - (operatorNormBound_comp_diagMixer_comp_le - (A := F.win) (B := F.wout) (d := F.deriv) - (a := a) (c := (l.mlpActDerivBound : ℝ)) (b := b) - (hA := by simp [a]) (hB := by simp [b]) hDeriv) - have hAct_nonneg : 0 ≤ (l.mlpActDerivBound : ℝ) := by - rcases (inferInstance : Nonempty F.hidden) with ⟨j⟩ - have h := hDeriv j - exact le_trans (abs_nonneg _) h - have hMul : - a * (l.mlpActDerivBound : ℝ) * b ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - have hMul' : - (a * b) * (l.mlpActDerivBound : ℝ) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - exact mul_le_mul_of_nonneg_right hCoeff hAct_nonneg - have hRearrange : - a * (l.mlpActDerivBound : ℝ) * b = - (a * b) * (l.mlpActDerivBound : ℝ) := by - calc - a * (l.mlpActDerivBound : 
ℝ) * b = - a * b * (l.mlpActDerivBound : ℝ) := by - simpa [mul_assoc] using (mul_right_comm a (l.mlpActDerivBound : ℝ) b) - _ = (a * b) * (l.mlpActDerivBound : ℝ) := by simp [mul_assoc] - simpa [hRearrange] using hMul' - exact le_trans hBound hMul - -/-- Attention-component bound from certificate identities and component bounds. -/ -theorem attn_component_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ (l.attnJacBound : ℝ) := by - classical - let attnTotal : ℝ := - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) - have hcomp : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) ≤ - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian := by - simpa using - (SignedMixer.operatorNormBound_comp_le - (M := D.ln1Jacobians i) (N := (D.layers i).fullJacobian)) - have hFull_nonneg : - 0 ≤ SignedMixer.operatorNormBound (D.layers i).fullJacobian := - SignedMixer.operatorNormBound_nonneg (M := (D.layers i).fullJacobian) - have hmul1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * SignedMixer.operatorNormBound (D.layers i).fullJacobian := by - exact 
mul_le_mul_of_nonneg_right hLn1 hFull_nonneg - have hmul2 : - (l.ln1Bound : ℝ) * SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * attnTotal := by - simpa [attnTotal] using - (mul_le_mul_of_nonneg_left hFull hLn1_nonneg) - have hmul : - SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (l.ln1Bound : ℝ) * attnTotal := by - exact le_trans hmul1 hmul2 - have hAttn : - (l.ln1Bound : ℝ) * attnTotal = - (l.attnJacBound : ℝ) := by - have hAttn := - LayerAmplificationCert.attnJacBound_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [attnTotal, mul_assoc, add_assoc] using hAttn.symm - calc - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ SignedMixer.operatorNormBound (D.ln1Jacobians i) * - SignedMixer.operatorNormBound (D.layers i).fullJacobian := hcomp - _ ≤ (l.ln1Bound : ℝ) * attnTotal := hmul - _ = (l.attnJacBound : ℝ) := hAttn - -/-- MLP-component bound from certificate identities and component bounds. 
-/ -theorem mlp_component_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ (l.mlpJacBound : ℝ) := by - classical - let mlpBound : ℝ := (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) - have hcomp : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) ≤ - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) := by - simpa using - (SignedMixer.operatorNormBound_comp_le - (M := D.ln2Jacobians i) (N := D.mlpJacobians i)) - have hMlp_nonneg : 0 ≤ SignedMixer.operatorNormBound (D.mlpJacobians i) := - SignedMixer.operatorNormBound_nonneg (M := D.mlpJacobians i) - have hmul1 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * SignedMixer.operatorNormBound (D.mlpJacobians i) := by - exact mul_le_mul_of_nonneg_right hLn2 hMlp_nonneg - have hmul2 : - (l.ln2Bound : ℝ) * SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * mlpBound := by - simpa [mlpBound] using - (mul_le_mul_of_nonneg_left hMlp hLn2_nonneg) - have hmul : - SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.ln2Bound : ℝ) * mlpBound := by - exact le_trans hmul1 hmul2 - have hMlpEq : - (l.ln2Bound : ℝ) * mlpBound = (l.mlpJacBound : ℝ) := by - have hMlpEq := - 
LayerAmplificationCert.mlpJacBound_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - simpa [mlpBound, mul_assoc] using hMlpEq.symm - calc - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ SignedMixer.operatorNormBound (D.ln2Jacobians i) * - SignedMixer.operatorNormBound (D.mlpJacobians i) := hcomp - _ ≤ (l.ln2Bound : ℝ) * mlpBound := hmul - _ = (l.mlpJacBound : ℝ) := hMlpEq - -/-- Layer Jacobian residual bound from a layer amplification certificate. -/ -theorem layerJacobian_residual_bound_of_cert - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hA : - SignedMixer.operatorNormBound - ((D.ln1Jacobians i).comp (D.layers i).fullJacobian) - ≤ (l.attnJacBound : ℝ)) - (hM : - SignedMixer.operatorNormBound - ((D.ln2Jacobians i).comp (D.mlpJacobians i)) - ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hC := LayerAmplificationCert.c_eq_cast_of_valid - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) (l := l) hValid - have hres := - DeepLinearization.layerJacobian_residual_bound - (D := D) (i := i) - (A := (l.attnJacBound : ℝ)) - (M := (l.mlpJacBound : ℝ)) hA hM - simpa [hC] using hres - -/-- Layer Jacobian residual bound from a certificate, with component-norm assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_components - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hA := - attn_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hLn1 hLn1_nonneg hFull - have hM := - mlp_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hLn2 hLn2_nonneg hMlp - exact layerJacobian_residual_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) - hValid hA hM - -/-- Layer Jacobian residual bound from certificate validity and attention-state assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_from_state - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLenNat modelDimNat headDimNat : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLenNat modelDimNat headDimNat) - (hSeqLen : seqLenNat = Fintype.card n) - (hModelDim : modelDimNat = Fintype.card d) - (hScale : (1 / Real.sqrt (modelDim d)) ≤ (invSqrtUpperBound headDimNat : ℝ)) - (hInputBound : - ∀ pos d_in, |(D.layers i).state.input pos d_in| ≤ (l.ln1OutMaxAbsBound : ℝ)) - (hKeys : - ∀ pos d', (D.layers i).state.keys pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_K.w d_in d') - (hQueries : - ∀ pos d', (D.layers i).state.queries pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_Q.w d_in d') - (hValues : - ∀ pos d', (D.layers i).state.values pos d' = - ∑ d_in, (D.layers i).state.input pos d_in * - (D.layers i).layer.W_V.w d_in d') - (hWQ : - SignedMixer.operatorNormBound (D.layers i).layer.W_Q ≤ (l.wqOpBoundMax : ℝ)) - (hWK : - SignedMixer.operatorNormBound (D.layers i).layer.W_K ≤ (l.wkOpBoundMax : ℝ)) - (hVO : - SignedMixer.operatorNormBound - ((D.layers i).layer.W_V.comp (D.layers i).layer.W_O) ≤ - (l.attnValueCoeff : ℝ)) - (hConsistent : - ∀ q, (D.layers i).state.attentionWeights q = - softmax ((D.layers i).state.scores q)) - (hSoftmax : - ∀ q, SignedMixer.operatorNormBound - (softmaxJacobian ((D.layers i).state.scores q)) ≤ - (l.softmaxJacobianNormInfUpperBound : ℝ)) - (hEq : patternTerm (D.layers i) = patternTermExplicit (D.layers i)) - (hLn1 : - SignedMixer.operatorNormBound (D.ln1Jacobians i) ≤ (l.ln1Bound : ℝ)) - (hLn1_nonneg : 0 ≤ (l.ln1Bound : ℝ)) - (hLn2 : - SignedMixer.operatorNormBound (D.ln2Jacobians i) ≤ (l.ln2Bound : ℝ)) - (hLn2_nonneg : 0 ≤ (l.ln2Bound : ℝ)) - (hMlp : - SignedMixer.operatorNormBound 
(D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hFull := - attn_fullJacobian_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLenNat := seqLenNat) (modelDimNat := modelDimNat) (headDimNat := headDimNat) - hValid hSeqLen hModelDim hScale hInputBound hKeys hQueries hValues hWQ hWK hVO - hConsistent hSoftmax hEq - have hA := - attn_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hLn1 hLn1_nonneg hFull - have hM := - mlp_component_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hLn2 hLn2_nonneg hMlp - exact layerJacobian_residual_bound_of_cert - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLenNat) (modelDim := modelDimNat) (headDim := headDimNat) - hValid hA hM - -/-- Layer Jacobian residual bound from a certificate plus component-norm assumptions. 
-/ -theorem layerJacobian_residual_bound_of_cert_assuming - {n d : Type*} [Fintype n] [Fintype d] [DecidableEq n] [DecidableEq d] - [Nonempty n] [Nonempty d] - (D : DeepLinearization (n := n) (d := d)) - (i : Fin D.numLayers) - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hSeqLen : seqLen = Fintype.card n) - (h : LayerComponentNormAssumptions (D := D) (i := i) (l := l)) : - SignedMixer.operatorNormBound - (D.layerJacobian i - SignedMixer.identity) ≤ (l.C : ℝ) := by - have hCoeff : - SignedMixer.operatorNormBound (D.mlpFactors i).win * - SignedMixer.operatorNormBound (D.mlpFactors i).wout ≤ (l.mlpCoeff : ℝ) := - mlp_coeff_bound_of_valid - (F := D.mlpFactors i) (l := l) - (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) (modelDim := modelDim) (headDim := headDim) hValid - h.mlpWinBound h.mlpWinBound_nonneg h.mlpWoutBound - have hMlp : - SignedMixer.operatorNormBound (D.mlpJacobians i) ≤ - (l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ) := by - simpa using - (mlp_bound_of_factorization (F := D.mlpFactors i) (l := l) - h.mlpDerivBound hCoeff) - have hFull : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - have hFullBase : - SignedMixer.operatorNormBound (D.layers i).fullJacobian ≤ - (Fintype.card n : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ) := by - simpa using - (attention_fullJacobian_bound_of_terms (L := D.layers i) - (hValue := h.attnValueBound) (hPattern := h.attnPatternBound)) - have hSeqLen' : (Fintype.card n : ℝ) = (seqLen : ℝ) := by - exact_mod_cast hSeqLen.symm - simpa [hSeqLen'] using hFullBase - exact layerJacobian_residual_bound_of_cert_components - (D := D) (i := i) (l := l) (eps := eps) (sqrtPrecBits := sqrtPrecBits) - (seqLen := seqLen) 
(modelDim := modelDim) (headDim := headDim) - hValid h.ln1Bound h.ln1Bound_nonneg hFull - h.ln2Bound h.ln2Bound_nonneg hMlp - -end Nfp.Sound diff --git a/Nfp/Sound/CachePure.lean b/Nfp/Sound/CachePure.lean deleted file mode 100644 index e6e77a3..0000000 --- a/Nfp/Sound/CachePure.lean +++ /dev/null @@ -1,666 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Init.System.IO -import Init.Data.ByteArray.Lemmas -import Nfp.Sound.Decimal -import Nfp.Sound.Fixed - -namespace Nfp.Sound - -/-! -# SOUND fixed-point cache (pure helpers) - -Pure parsing and encoding utilities for the SOUND cache format. -IO wrappers live in `Nfp.Untrusted.SoundCacheIO`. --/ - -namespace SoundCache - -def version : UInt32 := 1 -def magic : ByteArray := "NFP_SND_CACHE_V1\n".toUTF8 - -structure Header where - modelHash : UInt64 - modelSize : UInt64 - scalePow10 : UInt32 - numLayers : UInt32 - numHeads : UInt32 - modelDim : UInt32 - headDim : UInt32 - hiddenDim : UInt32 - deriving Repr - -private def u32le (x : UInt32) : ByteArray := - let b0 := (x &&& 0xFF).toUInt8 - let b1 := ((x >>> 8) &&& 0xFF).toUInt8 - let b2 := ((x >>> 16) &&& 0xFF).toUInt8 - let b3 := ((x >>> 24) &&& 0xFF).toUInt8 - ByteArray.mk #[b0, b1, b2, b3] - -private def u64le (x : UInt64) : ByteArray := - let b0 := (x &&& 0xFF).toUInt8 - let b1 := ((x >>> 8) &&& 0xFF).toUInt8 - let b2 := ((x >>> 16) &&& 0xFF).toUInt8 - let b3 := ((x >>> 24) &&& 0xFF).toUInt8 - let b4 := ((x >>> 32) &&& 0xFF).toUInt8 - let b5 := ((x >>> 40) &&& 0xFF).toUInt8 - let b6 := ((x >>> 48) &&& 0xFF).toUInt8 - let b7 := ((x >>> 56) &&& 0xFF).toUInt8 - ByteArray.mk #[b0, b1, b2, b3, b4, b5, b6, b7] - -private def i32le (x : Int) : ByteArray := - let ux : UInt32 := UInt32.ofInt x - u32le ux - -private def u32FromLE (b : ByteArray) (off : Nat) : UInt32 := - let b0 := (b.get! off).toUInt32 - let b1 := (b.get! (off + 1)).toUInt32 - let b2 := (b.get! (off + 2)).toUInt32 - let b3 := (b.get! 
(off + 3)).toUInt32 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) - -private def u64FromLE (b : ByteArray) (off : Nat) : UInt64 := - let b0 := (b.get! off).toUInt64 - let b1 := (b.get! (off + 1)).toUInt64 - let b2 := (b.get! (off + 2)).toUInt64 - let b3 := (b.get! (off + 3)).toUInt64 - let b4 := (b.get! (off + 4)).toUInt64 - let b5 := (b.get! (off + 5)).toUInt64 - let b6 := (b.get! (off + 6)).toUInt64 - let b7 := (b.get! (off + 7)).toUInt64 - b0 ||| (b1 <<< 8) ||| (b2 <<< 16) ||| (b3 <<< 24) ||| - (b4 <<< 32) ||| (b5 <<< 40) ||| (b6 <<< 48) ||| (b7 <<< 56) - -def i32FromLE (b : ByteArray) (off : Nat) : Int := - let u := u32FromLE b off - let half : UInt32 := 0x80000000 - if u < half then - Int.ofNat u.toNat - else - let two32 : Int := Int.ofNat (Nat.pow 2 32) - (Int.ofNat u.toNat) - two32 - -def encodeHeader (hdr : Header) : ByteArray := - magic - ++ u32le version - ++ u64le hdr.modelHash - ++ u64le hdr.modelSize - ++ u32le hdr.scalePow10 - ++ u32le hdr.numLayers - ++ u32le hdr.numHeads - ++ u32le hdr.modelDim - ++ u32le hdr.headDim - ++ u32le hdr.hiddenDim - -def headerBytes : Nat := - magic.size + 4 + 8 + 8 + 4 + 4 + 4 + 4 + 4 + 4 - -def decodeHeader (bytes : ByteArray) : Except String Header := do - if bytes.size < headerBytes then - throw "unexpected EOF while reading cache header" - let m := bytes.extract 0 magic.size - if m ≠ magic then - throw "invalid cache magic" - let off0 := magic.size - let v := u32FromLE bytes off0 - if v ≠ version then - throw s!"unsupported cache version {v}" - let off1 := off0 + 4 - let modelHash := u64FromLE bytes off1 - let off2 := off1 + 8 - let modelSize := u64FromLE bytes off2 - let off3 := off2 + 8 - let scalePow10 := u32FromLE bytes off3 - let off4 := off3 + 4 - let numLayers := u32FromLE bytes off4 - let off5 := off4 + 4 - let numHeads := u32FromLE bytes off5 - let off6 := off5 + 4 - let modelDim := u32FromLE bytes off6 - let off7 := off6 + 4 - let headDim := u32FromLE bytes off7 - let off8 := off7 + 4 - let hiddenDim 
:= u32FromLE bytes off8 - return { modelHash, modelSize, scalePow10, numLayers, numHeads, modelDim, headDim, hiddenDim } - -def cacheDir : System.FilePath := "sound_cache" - -def cachePath (modelPath : System.FilePath) (modelHash : UInt64) (scalePow10 : Nat) : - System.FilePath := - let stem := modelPath.fileStem - cacheDir / s!"{stem}_{modelHash.toNat}_p{scalePow10}.nfpc" - -def expectedI32Count (hdr : Header) : Nat := - let L := hdr.numLayers.toNat - let H := hdr.numHeads.toNat - let d := hdr.modelDim.toNat - let dh := hdr.headDim.toNat - let dhid := hdr.hiddenDim.toNat - let perHead := d * dh + dh + dh * d - let perLayer := - (4 * d) + (H * perHead) + d + (d * dhid) + dhid + (dhid * d) + d - L * perLayer - -def expectedCacheBytes (hdr : Header) : UInt64 := - UInt64.ofNat headerBytes + (UInt64.ofNat (expectedI32Count hdr) * (4 : UInt64)) - -def fnv1a64Init : UInt64 := 14695981039346656037 - -def fnv1a64Update (hash : UInt64) (chunk : ByteArray) : UInt64 := - Id.run do - let prime : UInt64 := 1099511628211 - let mut h := hash - for b in chunk.data do - h := (h ^^^ (UInt64.ofNat b.toNat)) * prime - return h - -def fnv1a64 (bytes : ByteArray) : UInt64 := - fnv1a64Update fnv1a64Init bytes - -private def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - -private def findLineIdxFrom (lines : Array String) (start : Nat) (p : String → Bool) : - Option Nat := - Id.run do - let mut i := start - while i < lines.size do - if p (lines[i]!.trim) then - return some i - i := i + 1 - return none - -private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - match findLineIdxFrom lines start p with - | some i => i - | none => lines.size - -private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Id.run do - let mut i := start - while i < lines.size && lines[i]!.trim.isEmpty do - i := 
i + 1 - return i - -private def countWsTokens (s : String) : Nat := - Id.run do - let bytes := s.toUTF8 - let mut i : Nat := 0 - let mut inTok : Bool := false - let mut cnt : Nat := 0 - while i < bytes.size do - let b := bytes[i]! - let isWs : Bool := b = 32 || b = 9 - if isWs then - inTok := false - else if !inTok then - inTok := true - cnt := cnt + 1 - i := i + 1 - return cnt - -private def skipTokensFast (lines : Array String) (start : Nat) (numTokens : Nat) : - Except String Nat := - Id.run do - let mut iLine := start - let mut remaining := numTokens - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () - else - let c := countWsTokens line - if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c - return .ok iLine - -private def consumeFixedBytes - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) : Except String (ByteArray × Nat) := - Id.run do - let mut iLine := start - let mut remaining := count - let mut buf : ByteArray := ByteArray.empty - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while reading fixed tokens" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () - else - let bytes := line.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && (bytes[j]! = 32 || bytes[j]! = 9) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && (bytes[j]! ≠ 32 && bytes[j]! 
≠ 9) do - j := j + 1 - let tokStop := j - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - buf := buf ++ i32le x - remaining := remaining - 1 - return .ok (buf, iLine) - -private def readHeaderFromLines (lines : Array String) : Except String (Header × Nat) := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if i ≥ lines.size then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - - let mut numLayers : Option UInt32 := none - let mut numHeads : Option UInt32 := none - let mut modelDim : Option UInt32 := none - let mut headDim : Option UInt32 := none - let mut hiddenDim : Option UInt32 := none - - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := (v.toNat?.map UInt32.ofNat) - | "num_heads" => numHeads := (v.toNat?.map UInt32.ofNat) - | "model_dim" => modelDim := (v.toNat?.map UInt32.ofNat) - | "head_dim" => headDim := (v.toNat?.map UInt32.ofNat) - | "hidden_dim" => hiddenDim := (v.toNat?.map UInt32.ofNat) - | _ => pure () - i := i + 1 - - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let hdr : Header := - { modelHash := 0, modelSize := 0, scalePow10 := 0, numLayers := L, numHeads := H - modelDim := d, headDim := dh, hiddenDim := dhid } - return .ok (hdr, i) - -private structure LNParamsFixed where - gamma : Array Int - beta : Array Int - -private instance : Inhabited LNParamsFixed := - ⟨{ gamma := 
#[], beta := #[] }⟩ - -private def collectLayerNormParamsFixed - (scalePow10 : Nat) (lines : Array String) (numLayers modelDim : Nat) : - Except String (Array LNParamsFixed × Array LNParamsFixed) := - Id.run do - let defP : LNParamsFixed := - { gamma := Array.replicate modelDim (0 : Int), beta := Array.replicate modelDim (0 : Int) } - let mut ln1 : Array LNParamsFixed := Array.replicate numLayers defP - let mut ln2 : Array LNParamsFixed := Array.replicate numLayers defP - let mut curLayer : Nat := 0 - let mut i : Nat := 0 - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD curLayer - i := i + 1 - else if line = "LN1_GAMMA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln1.size then - let old := ln1[curLayer]! - ln1 := ln1.set! curLayer { old with gamma := xs } - i := next - else if line = "LN1_BETA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln1.size then - let old := ln1[curLayer]! - ln1 := ln1.set! curLayer { old with beta := xs } - i := next - else if line = "LN2_GAMMA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln2.size then - let old := ln2[curLayer]! - ln2 := ln2.set! curLayer { old with gamma := xs } - i := next - else if line = "LN2_BETA" then - match foldFixed10Tokens scalePow10 lines (i + 1) modelDim (Array.mkEmpty modelDim) - (fun a x => a.push x) with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < ln2.size then - let old := ln2[curLayer]! 
- ln2 := ln2.set! curLayer { old with beta := xs } - i := next - else - i := i + 1 - return .ok (ln1, ln2) - -private def encodeIntArray (xs : Array Int) : ByteArray := - xs.foldl (fun acc x => acc ++ i32le x) ByteArray.empty - -private def repeatBytes (b : ByteArray) (n : Nat) : ByteArray := - Id.run do - let mut out := ByteArray.empty - for _ in [:n] do - out := out ++ b - return out - -def buildCacheBytes - (lines : Array String) - (scalePow10 : Nat) - (modelHash modelSize : UInt64) : Except String ByteArray := - Id.run do - let hdr0E := readHeaderFromLines lines - let (hdr0, _afterHdr) ← - match hdr0E with - | .error e => return .error e - | .ok x => pure x - - let L : Nat := hdr0.numLayers.toNat - let H : Nat := hdr0.numHeads.toNat - let d : Nat := hdr0.modelDim.toNat - let dh : Nat := hdr0.headDim.toNat - let dhid : Nat := hdr0.hiddenDim.toNat - - let (ln1, ln2) ← - match collectLayerNormParamsFixed scalePow10 lines L d with - | .error e => return .error e - | .ok x => pure x - - let hdr : Header := - { hdr0 with - modelHash := modelHash - modelSize := modelSize - scalePow10 := UInt32.ofNat scalePow10 } - - let mut out : ByteArray := encodeHeader hdr - let mut pos : Nat := skipUntil lines 0 (fun s => s.startsWith "LAYER") - let zeroBytes := i32le 0 - - for l in [:L] do - let p1 := ln1.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } - let p2 := ln2.getD l { gamma := Array.replicate d (0 : Int), beta := Array.replicate d 0 } - out := out ++ encodeIntArray p1.gamma - out := out ++ encodeIntArray p1.beta - out := out ++ encodeIntArray p2.gamma - out := out ++ encodeIntArray p2.beta - - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - if pos ≥ lines.size then - return .error s!"unexpected EOF while scanning layer {l}" - pos := pos + 1 - - for _h in [:H] do - pos := skipBlankLines lines pos - if !(pos < lines.size && (lines[pos]!.trim.startsWith "HEAD")) then - return .error "expected HEAD" - pos := pos + 1 - - pos := 
skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_Q") then - return .error "missing W_Q" - match skipTokensFast lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_Q" then - match skipTokensFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_K") then - return .error "missing W_K" - match skipTokensFast lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_K" then - match skipTokensFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_V") then - return .error "missing W_V" - match consumeFixedBytes scalePow10 lines (pos + 1) (d * dh) with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_V" then - match consumeFixedBytes scalePow10 lines (pos + 1) dh with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - else - out := out ++ repeatBytes zeroBytes dh - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_O") then - return .error "missing W_O" - match consumeFixedBytes scalePow10 lines (pos + 1) (dh * d) with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "ATTN_BIAS") then - return .error "missing ATTN_BIAS" - match consumeFixedBytes scalePow10 lines (pos + 1) d with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - 
pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "MLP") then - return .error "missing MLP" - pos := pos + 1 - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_in") then - return .error "missing W_in" - match consumeFixedBytes scalePow10 lines (pos + 1) (d * dhid) with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_in") then - return .error "missing b_in" - match consumeFixedBytes scalePow10 lines (pos + 1) dhid with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_out") then - return .error "missing W_out" - match consumeFixedBytes scalePow10 lines (pos + 1) (dhid * d) with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_out") then - return .error "missing b_out" - match consumeFixedBytes scalePow10 lines (pos + 1) d with - | .error e => return .error e - | .ok (bytes, next) => - out := out ++ bytes - pos := next - - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - - return .ok out - -private def isMaybeNumberStart (b : UInt8) : Bool := - b = 45 || b = 43 || b = 46 || (48 ≤ b && b ≤ 57) - -def checkTextTokenEnvelopeLines - (lines : Array String) - (scalePow10 : Nat := 9) - (maxTokens : Nat := 0) : Except String Unit := - Id.run do - let cfg : Fixed10Cfg := { scalePow10 := scalePow10 } - let S : Nat := cfg.scaleNat - let mut checked : Nat := 0 - let mut done : Bool := false - for line in lines do - if done then - break - let s := line.trim - if s.isEmpty then - pure () - else - let bytes := s.toUTF8 - let mut i : Nat := 0 - while i < bytes.size do - while i < bytes.size && (bytes[i]! 
= 32 || bytes[i]! = 9) do - i := i + 1 - if i ≥ bytes.size then - i := bytes.size - let tokStart := i - while i < bytes.size && (bytes[i]! ≠ 32 && bytes[i]! ≠ 9) do - i := i + 1 - let tokStop := i - if tokStart < tokStop && isMaybeNumberStart (bytes[tokStart]!) then - let tok := String.Pos.Raw.extract s ⟨tokStart⟩ ⟨tokStop⟩ - match parseRat tok with - | .error _ => - pure () - | .ok r => - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok w => - let lo : Rat := Rat.normalize (w - 1) S (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - let hi : Rat := Rat.normalize (w + 1) S (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := scalePow10) h10pos)) - if lo ≤ r ∧ r ≤ hi then - checked := checked + 1 - else - return .error s!"token '{tok}' out of envelope: {lo} ≤ {r} ≤ {hi} failed" - if maxTokens ≠ 0 && checked ≥ maxTokens then - done := true - i := bytes.size - return .ok () - -structure I32Reader where - h : IO.FS.Handle - buf : ByteArray - pos : Nat - -def i32FromBuffer (buf : ByteArray) (pos : Nat) : Int := - i32FromLE buf pos - -/-! ### Derived properties -/ - -theorem u32le_size (x : UInt32) : (u32le x).size = 4 := by - rfl - -theorem u64le_size (x : UInt64) : (u64le x).size = 8 := by - rfl - -/-- `encodeHeader` has the exact byte length advertised by `headerBytes`. -/ -theorem encodeHeader_size (hdr : Header) : (encodeHeader hdr).size = headerBytes := by - simp [encodeHeader, headerBytes, ByteArray.size_append, u32le_size, u64le_size] - -/-- `encodeHeader` always begins with the cache magic prefix. -/ -theorem encodeHeader_magic_prefix (hdr : Header) : - (encodeHeader hdr).extract 0 magic.size = magic := by - simp [encodeHeader, ByteArray.append_assoc, ByteArray.extract_append_eq_left] - --- TODO: Prove round-trip lemmas for `u32FromLE`/`u64FromLE` and `decodeHeader (encodeHeader _)`. - -/-! 
### Specs -/ - -theorem version_spec_cache_pure : version = version := rfl -theorem magic_spec_cache_pure : magic = magic := rfl -theorem Header_spec_cache_pure : Header = Header := rfl -theorem u32le_spec_cache_pure : u32le = u32le := rfl -theorem u64le_spec_cache_pure : u64le = u64le := rfl -theorem i32le_spec_cache_pure : i32le = i32le := rfl -theorem u32FromLE_spec_cache_pure : u32FromLE = u32FromLE := rfl -theorem u64FromLE_spec_cache_pure : u64FromLE = u64FromLE := rfl -theorem i32FromLE_spec_cache_pure : i32FromLE = i32FromLE := rfl -theorem encodeHeader_spec_cache_pure : encodeHeader = encodeHeader := rfl -theorem headerBytes_spec_cache_pure : headerBytes = headerBytes := rfl -theorem decodeHeader_spec_cache_pure : decodeHeader = decodeHeader := rfl -theorem cacheDir_spec_cache_pure : cacheDir = cacheDir := rfl -theorem cachePath_spec_cache_pure : cachePath = cachePath := rfl -theorem expectedI32Count_spec_cache_pure : expectedI32Count = expectedI32Count := rfl -theorem expectedCacheBytes_spec_cache_pure : expectedCacheBytes = expectedCacheBytes := rfl -theorem fnv1a64Init_spec_cache_pure : fnv1a64Init = fnv1a64Init := rfl -theorem fnv1a64Update_spec_cache_pure : fnv1a64Update = fnv1a64Update := rfl -theorem fnv1a64_spec_cache_pure : fnv1a64 = fnv1a64 := rfl -theorem parseHeaderLine_spec_cache_pure : parseHeaderLine = parseHeaderLine := rfl -theorem findLineIdxFrom_spec_cache_pure : findLineIdxFrom = findLineIdxFrom := rfl -theorem skipUntil_spec_cache_pure : skipUntil = skipUntil := rfl -theorem skipBlankLines_spec_cache_pure : skipBlankLines = skipBlankLines := rfl -theorem countWsTokens_spec_cache_pure : countWsTokens = countWsTokens := rfl -theorem skipTokensFast_spec_cache_pure : skipTokensFast = skipTokensFast := rfl -theorem consumeFixedBytes_spec_cache_pure : consumeFixedBytes = consumeFixedBytes := rfl -theorem readHeaderFromLines_spec_cache_pure : readHeaderFromLines = readHeaderFromLines := rfl -theorem LNParamsFixed_spec_cache_pure : 
LNParamsFixed = LNParamsFixed := rfl -theorem collectLayerNormParamsFixed_spec_cache_pure : - collectLayerNormParamsFixed = collectLayerNormParamsFixed := rfl -theorem encodeIntArray_spec_cache_pure : encodeIntArray = encodeIntArray := rfl -theorem repeatBytes_spec_cache_pure : repeatBytes = repeatBytes := rfl -theorem buildCacheBytes_spec_cache_pure : buildCacheBytes = buildCacheBytes := rfl -theorem isMaybeNumberStart_spec_cache_pure : isMaybeNumberStart = isMaybeNumberStart := rfl -theorem checkTextTokenEnvelopeLines_spec_cache_pure : - checkTextTokenEnvelopeLines = checkTextTokenEnvelopeLines := rfl -theorem I32Reader_spec_cache_pure : I32Reader = I32Reader := rfl -theorem i32FromBuffer_spec_cache_pure : i32FromBuffer = i32FromBuffer := rfl - -end SoundCache - -end Nfp.Sound diff --git a/Nfp/Sound/Cert.lean b/Nfp/Sound/Cert.lean deleted file mode 100644 index aff9ca9..0000000 --- a/Nfp/Sound/Cert.lean +++ /dev/null @@ -1,369 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Mathlib.Data.Rat.Cast.Order -import Nfp.SignedMixer -import Nfp.Sound.Bounds - -namespace Nfp.Sound - -/-! -# Sound certification report structures - -The sound path computes conservative bounds using exact `Rat` arithmetic. -These values are meant to be *trusted* (up to Lean's kernel/runtime), but may be -much looser than the Float-based heuristic analysis. --/ - - -/-- Per-layer conservative residual amplification constant `Cᵢ` -(bounds ‖layerJacobian - I‖) and its components. -`C` uses the safe algebraic form `attn + mlp + attn*mlp`. -/ -structure LayerAmplificationCert where - layerIdx : Nat - ln1MaxAbsGamma : Rat - ln1MaxAbsBeta : Rat - ln2MaxAbsGamma : Rat - /-- Optional local variance lower bound used for LN1 (if available). -/ - ln1VarianceLowerBound? : Option Rat - /-- Optional local variance lower bound used for LN2 (if available). -/ - ln2VarianceLowerBound? 
: Option Rat - ln1Bound : Rat - ln2Bound : Rat - /-- Upper bound on `max |LN1 output|` after affine. -/ - ln1OutMaxAbsBound : Rat - /-- Lower bound on the softmax probability interval used for Jacobian bounds. -/ - softmaxProbLo : Rat - /-- Upper bound on the softmax probability interval used for Jacobian bounds. -/ - softmaxProbHi : Rat - /-- Lower bound on the softmax logit margin (0 means no margin evidence). -/ - softmaxMarginLowerBound : Rat - /-- Effort level for the exp lower bound used in margin-derived softmax bounds. -/ - softmaxExpEffort : Nat - /-- Upper bound on the softmax Jacobian row-sum norm (portfolio bound). -/ - softmaxJacobianNormInfUpperBound : Rat - /-- Maximum operator-norm bound on W_Q across heads (row-sum). -/ - wqOpBoundMax : Rat - /-- Maximum operator-norm bound on W_K across heads (row-sum). -/ - wkOpBoundMax : Rat - /-- Value-term coefficient (sum of per-head `W_V`/`W_O` bounds). -/ - attnValueCoeff : Rat - /-- Pattern-term coefficient (score-gradient L1 × input L1 × value coefficient). -/ - attnPatternCoeff : Rat - mlpCoeff : Rat - /-- Upper bound on the operator norm of the MLP input weights. -/ - mlpWinBound : Rat - /-- Upper bound on the operator norm of the MLP output weights. -/ - mlpWoutBound : Rat - /-- Upper bound on the max GeLU derivative over this layer's preactivations. -/ - mlpActDerivBound : Rat - /-- Upper bound on the attention residual Jacobian contribution. -/ - attnJacBound : Rat - /-- Upper bound on the MLP residual Jacobian contribution. -/ - mlpJacBound : Rat - /-- Combined residual amplification bound: `attn + mlp + attn*mlp`. -/ - C : Rat - deriving Repr - -namespace LayerAmplificationCert - -instance : Inhabited LayerAmplificationCert := - ⟨{ - layerIdx := 0 - ln1MaxAbsGamma := 0 - ln1MaxAbsBeta := 0 - ln2MaxAbsGamma := 0 - ln1VarianceLowerBound? := none - ln2VarianceLowerBound? 
:= none - ln1Bound := 0 - ln2Bound := 0 - ln1OutMaxAbsBound := 0 - softmaxProbLo := 0 - softmaxProbHi := 1 - softmaxMarginLowerBound := 0 - softmaxExpEffort := defaultSoftmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxJacobianNormInfWorst - wqOpBoundMax := 0 - wkOpBoundMax := 0 - attnValueCoeff := 0 - attnPatternCoeff := 0 - mlpCoeff := 0 - mlpWinBound := 0 - mlpWoutBound := 0 - mlpActDerivBound := 0 - attnJacBound := 0 - mlpJacBound := 0 - C := 0 - }⟩ - -/-- Portfolio softmax Jacobian bound from interval and margin candidates. -/ -def softmaxJacobianNormInfPortfolioBound (seqLen : Nat) (l : LayerAmplificationCert) : Rat := - ubBest (softmaxJacobianNormInfBound l.softmaxProbLo l.softmaxProbHi) - #[softmaxJacobianNormInfBoundFromMargin seqLen l.softmaxMarginLowerBound l.softmaxExpEffort] - -theorem softmaxJacobianNormInfPortfolioBound_def (seqLen : Nat) (l : LayerAmplificationCert) : - softmaxJacobianNormInfPortfolioBound seqLen l = - ubBest (softmaxJacobianNormInfBound l.softmaxProbLo l.softmaxProbHi) - #[softmaxJacobianNormInfBoundFromMargin seqLen l.softmaxMarginLowerBound - l.softmaxExpEffort] := rfl - -/-- Internal consistency checks for per-layer bounds. -/ -def Valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : Prop := - l.ln1Bound = - (match l.ln1VarianceLowerBound? with - | some v => layerNormOpBoundLocal l.ln1MaxAbsGamma v eps sqrtPrecBits - | none => layerNormOpBoundConservative l.ln1MaxAbsGamma eps sqrtPrecBits) ∧ - l.ln2Bound = - (match l.ln2VarianceLowerBound? 
with - | some v => layerNormOpBoundLocal l.ln2MaxAbsGamma v eps sqrtPrecBits - | none => layerNormOpBoundConservative l.ln2MaxAbsGamma eps sqrtPrecBits) ∧ - l.ln1OutMaxAbsBound = - layerNormOutputMaxAbsBound modelDim l.ln1MaxAbsGamma l.ln1MaxAbsBeta ∧ - l.softmaxJacobianNormInfUpperBound = - softmaxJacobianNormInfPortfolioBound seqLen l ∧ - l.attnPatternCoeff = - attnPatternCoeffBound seqLen modelDim headDim l.ln1OutMaxAbsBound l.wqOpBoundMax - l.wkOpBoundMax l.attnValueCoeff ∧ - l.attnJacBound = - l.ln1Bound * - ((seqLen : Rat) * l.attnValueCoeff + - l.softmaxJacobianNormInfUpperBound * l.attnPatternCoeff) ∧ - l.mlpCoeff = l.mlpWinBound * l.mlpWoutBound ∧ - l.mlpJacBound = - l.ln2Bound * (l.mlpCoeff * l.mlpActDerivBound) ∧ - l.C = - l.attnJacBound + l.mlpJacBound + - l.attnJacBound * l.mlpJacBound - -instance (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : - Decidable (Valid eps sqrtPrecBits seqLen modelDim headDim l) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : Bool := - decide (Valid eps sqrtPrecBits seqLen modelDim headDim l) - -theorem check_iff (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) : - l.check eps sqrtPrecBits seqLen modelDim headDim = true ↔ - l.Valid eps sqrtPrecBits seqLen modelDim headDim := by - simp [check] - -/-- Extract the `C` identity from `Valid`. -/ -theorem c_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.C = l.attnJacBound + l.mlpJacBound + - l.attnJacBound * l.mlpJacBound := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, _hmlpCoeff, _hmlp, hC⟩ - exact hC - -/-- Extract the attention contribution identity from `Valid`. 
-/ -theorem attnJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.attnJacBound = - l.ln1Bound * - ((seqLen : Rat) * l.attnValueCoeff + - l.softmaxJacobianNormInfUpperBound * l.attnPatternCoeff) := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, hattn, _hmlpCoeff, _hmlp, _hC⟩ - exact hattn - -/-- Extract the MLP coefficient identity from `Valid`. -/ -theorem mlpCoeff_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.mlpCoeff = l.mlpWinBound * l.mlpWoutBound := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, hCoeff, _hmlp, _hC⟩ - exact hCoeff - -/-- Extract the MLP contribution identity from `Valid`. -/ -theorem mlpJacBound_eq_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - l.mlpJacBound = l.ln2Bound * (l.mlpCoeff * l.mlpActDerivBound) := by - rcases h with - ⟨_hln1, _hln2, _hln1Out, _hsoftmax, _hpat, _hattn, _hmlpCoeff, hmlp, _hC⟩ - exact hmlp - -/-- Cast the `C` identity to `ℝ` using `Valid`. -/ -theorem c_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.C : ℝ) = - (l.attnJacBound : ℝ) + - (l.mlpJacBound : ℝ) + - (l.attnJacBound : ℝ) * (l.mlpJacBound : ℝ) := by - have hC := c_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hC' := congrArg (fun (x : Rat) => (x : ℝ)) hC - simpa [Rat.cast_add, Rat.cast_mul, add_assoc] using hC' - -/-- Cast the attention contribution identity to `ℝ` using `Valid`. 
-/ -theorem attnJacBound_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.attnJacBound : ℝ) = - (l.ln1Bound : ℝ) * - ((seqLen : ℝ) * (l.attnValueCoeff : ℝ) + - (l.softmaxJacobianNormInfUpperBound : ℝ) * (l.attnPatternCoeff : ℝ)) := by - have hAttn := attnJacBound_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hAttn' := congrArg (fun (x : Rat) => (x : ℝ)) hAttn - simpa [Rat.cast_mul, Rat.cast_add, Rat.cast_natCast, mul_assoc, add_assoc] using hAttn' - -/-- Cast the MLP coefficient identity to `ℝ` using `Valid`. -/ -theorem mlpCoeff_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) (seqLen modelDim headDim : Nat) - (l : LayerAmplificationCert) (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.mlpCoeff : ℝ) = (l.mlpWinBound : ℝ) * (l.mlpWoutBound : ℝ) := by - have hCoeff := mlpCoeff_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hCoeff' := congrArg (fun (x : Rat) => (x : ℝ)) hCoeff - simpa [Rat.cast_mul] using hCoeff' - -/-- Cast the MLP contribution identity to `ℝ` using `Valid`. -/ -theorem mlpJacBound_eq_cast_of_valid (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (l : LayerAmplificationCert) - (h : l.Valid eps sqrtPrecBits seqLen modelDim headDim) : - (l.mlpJacBound : ℝ) = - (l.ln2Bound : ℝ) * ((l.mlpCoeff : ℝ) * (l.mlpActDerivBound : ℝ)) := by - have hMlp := mlpJacBound_eq_of_valid eps sqrtPrecBits seqLen modelDim headDim l h - have hMlp' := congrArg (fun (x : Rat) => (x : ℝ)) hMlp - simpa [Rat.cast_mul, mul_assoc] using hMlp' - -/-- Residual composition bound from component bounds and a cast `C` identity. 
-/ -theorem residual_bound_of_component_bounds - {S : Type*} [Fintype S] [DecidableEq S] [Nonempty S] - (l : LayerAmplificationCert) - (A M : SignedMixer S S) - (hC : (l.C : ℝ) = - (l.attnJacBound : ℝ) + - (l.mlpJacBound : ℝ) + - (l.attnJacBound : ℝ) * (l.mlpJacBound : ℝ)) - (hA : SignedMixer.operatorNormBound A ≤ (l.attnJacBound : ℝ)) - (hM : SignedMixer.operatorNormBound M ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity) - ≤ (l.C : ℝ) := by - have hres := - SignedMixer.operatorNormBound_residual_comp_le_of_bounds - (A := A) (M := M) - (a := (l.attnJacBound : ℝ)) - (b := (l.mlpJacBound : ℝ)) hA hM - simpa [hC] using hres - -/-- Residual composition bound from `Valid` plus component bounds. -/ -theorem residual_bound_of_component_bounds_valid - {S : Type*} [Fintype S] [DecidableEq S] [Nonempty S] - (l : LayerAmplificationCert) (eps : Rat) (sqrtPrecBits : Nat) - (seqLen modelDim headDim : Nat) (A M : SignedMixer S S) - (hValid : l.Valid eps sqrtPrecBits seqLen modelDim headDim) - (hA : SignedMixer.operatorNormBound A ≤ (l.attnJacBound : ℝ)) - (hM : SignedMixer.operatorNormBound M ≤ (l.mlpJacBound : ℝ)) : - SignedMixer.operatorNormBound - ((SignedMixer.identity + A).comp (SignedMixer.identity + M) - SignedMixer.identity) - ≤ (l.C : ℝ) := by - have hC := c_eq_cast_of_valid eps sqrtPrecBits seqLen modelDim headDim l hValid - exact residual_bound_of_component_bounds (l := l) (A := A) (M := M) hC hA hM - - -theorem Valid_spec : Valid = Valid := rfl -theorem check_spec : check = check := rfl - -end LayerAmplificationCert - -/-- Model-level certification report. -/ -structure ModelCert where - modelPath : String - inputPath? : Option String - inputDelta : Rat - eps : Rat - /-- Sequence length from the model header. -/ - seqLen : Nat - /-- Model dimension from the model header. -/ - modelDim : Nat - /-- Head dimension from the model header. 
-/ - headDim : Nat - /-- Precision in dyadic bits for local LayerNorm bounds. -/ - soundnessBits : Nat - /-- Which GeLU derivative target the model uses. -/ - geluDerivTarget : GeluDerivTarget - actDerivBound : Rat - softmaxJacobianNormInfWorst : Rat - layers : Array LayerAmplificationCert - totalAmplificationFactor : Rat - deriving Repr - -namespace ModelCert - -/-- Internal consistency checks for a reported sound certificate. -/ -def Valid (c : ModelCert) : Prop := - 0 < c.eps ∧ - c.softmaxJacobianNormInfWorst = Nfp.Sound.softmaxJacobianNormInfWorst ∧ - c.actDerivBound = c.layers.foldl (fun acc l => max acc l.mlpActDerivBound) 0 ∧ - List.Forall₂ - (fun i l => - l.layerIdx = i ∧ - LayerAmplificationCert.Valid c.eps c.soundnessBits c.seqLen c.modelDim c.headDim l) - (List.range c.layers.size) c.layers.toList ∧ - c.totalAmplificationFactor = - c.layers.foldl (fun acc l => acc * (1 + l.C)) 1 - -instance (c : ModelCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : ModelCert) : Bool := - decide (Valid c) - -theorem check_iff (c : ModelCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -/-- Pretty printer. -/ -def toString (c : ModelCert) : String := - let header := - s!"SOUND mode: conservative bounds; may be much looser than heuristic analysis.\n" ++ - (match c.inputPath? 
with - | some p => s!"input={p}, delta={c.inputDelta}\n" - | none => "") ++ - s!"eps={c.eps}, seqLen={c.seqLen}, modelDim={c.modelDim}, headDim={c.headDim}, " ++ - s!"soundnessBits={c.soundnessBits}, " ++ - s!"geluDerivTarget={geluDerivTargetToString c.geluDerivTarget}, " ++ - s!"actDerivBound={c.actDerivBound}, " ++ - s!"softmaxJacobianNormInfWorst={c.softmaxJacobianNormInfWorst}\n" ++ - s!"totalAmplificationFactor={c.totalAmplificationFactor}\n" - let body := - c.layers.foldl (fun acc l => - acc ++ - s!"Layer {l.layerIdx}: C={l.C} (attn={l.attnJacBound}, \ -mlp={l.mlpJacBound}, cross={l.attnJacBound * l.mlpJacBound}, \ -attnValueCoeff={l.attnValueCoeff}, attnPatternCoeff={l.attnPatternCoeff}, \ -wqOpBoundMax={l.wqOpBoundMax}, wkOpBoundMax={l.wkOpBoundMax}, \ -ln1OutMaxAbsBound={l.ln1OutMaxAbsBound}, \ -mlpWinBound={l.mlpWinBound}, mlpWoutBound={l.mlpWoutBound}, \ -mlpActDerivBound={l.mlpActDerivBound}, \ -softmaxJacobianNormInfUpperBound={l.softmaxJacobianNormInfUpperBound}, \ -softmaxMarginLowerBound={l.softmaxMarginLowerBound}, softmaxExpEffort={l.softmaxExpEffort}, \ -ln1Bound={l.ln1Bound}, ln2Bound={l.ln2Bound}" ++ - (match l.ln1VarianceLowerBound? with - | some v => s!", ln1Var≥{v}" - | none => "") ++ - (match l.ln2VarianceLowerBound? with - | some v => s!", ln2Var≥{v}" - | none => "") ++ - ")\n") "" - header ++ body - -instance : ToString ModelCert := ⟨toString⟩ - -theorem Valid_spec : Valid = Valid := rfl -theorem check_spec : check = check := rfl -theorem toString_spec : toString = toString := rfl - -end ModelCert - -/-! ### Specs -/ - -end Nfp.Sound diff --git a/Nfp/Sound/Decimal.lean b/Nfp/Sound/Decimal.lean deleted file mode 100644 index 4ed69fd..0000000 --- a/Nfp/Sound/Decimal.lean +++ /dev/null @@ -1,136 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std - -namespace Nfp.Sound - -/-! -# Exact decimal/scientific parsing for sound certification - -This module parses decimal and scientific-notation numerals (e.g. 
`-1.25e-3`) into `Rat`. - -Design goal: avoid `Float` as a source of truth in the sound certification path. -We only use exact integer arithmetic and powers of 10. --/ - -/-- Parse a signed integer written in base-10 (optional leading `+`/`-`). -/ -def parseInt10 (s : String) : Except String Int := - let s := s.trim - if s.isEmpty then - .error "empty integer" - else - let (neg, rest) := - if s.startsWith "-" then (true, s.drop 1) - else if s.startsWith "+" then (false, s.drop 1) - else (false, s) - if rest.isEmpty then - .error s!"invalid integer '{s}'" - else - match rest.toNat? with - | none => .error s!"invalid integer '{s}'" - | some n => - let i : Int := Int.ofNat n - .ok (if neg then -i else i) - -/-- Parse a base-10 natural number; empty string is treated as 0. -/ -def parseNat10OrZero (s : String) : Except String Nat := - let s := s.trim - if s.isEmpty then - .ok 0 - else - match s.toNat? with - | none => .error s!"invalid natural '{s}'" - | some n => .ok n - -/-- Parse a decimal/scientific numeral into an exact `Rat`. - -Supported forms: -- `123`, `-123`, `+123` -- `1.25`, `.25`, `2.` -- `1e3`, `1E3`, `-1.25e-3` - -This is intended for `.nfpt` parsing in sound mode. --/ -def parseRat (s : String) : Except String Rat := do - let s := s.trim - if s.isEmpty then - throw "empty numeral" - - -- sign - let (neg, rest) := - if s.startsWith "-" then (true, s.drop 1) - else if s.startsWith "+" then (false, s.drop 1) - else (false, s) - - -- optional exponent - let (mantissaStr, expStr?) : String × Option String := - match rest.splitOn "e" with - | [m, e] => (m, some e) - | _ => - match rest.splitOn "E" with - | [m, e] => (m, some e) - | _ => (rest, none) - - -- mantissa: intPart.fracPart - let parts := mantissaStr.splitOn "." 
- let (intPart, fracPart) ← - match parts with - | [i] => pure (i, "") - | [i, f] => pure (i, f) - | _ => throw s!"invalid numeral '{s}'" - - let iNat ← parseNat10OrZero intPart - let fNat ← parseNat10OrZero fracPart - let fracLen := fracPart.trim.length - - let expInt : Int ← - match expStr? with - | none => pure 0 - | some e => parseInt10 e - - -- Construct `Rat` in a single normalization step (avoids repeated gcd normalization). - let denomBase : Nat := Nat.pow 10 fracLen - let mantissaNat : Nat := iNat * denomBase + fNat - let num0 : Int := if neg then -(Int.ofNat mantissaNat) else (Int.ofNat mantissaNat) - let expAbs : Nat := Int.natAbs expInt - let pow10Nat : Nat := Nat.pow 10 expAbs - - let den : Nat := - if expInt < 0 then denomBase * pow10Nat else denomBase - let num : Int := - if expInt > 0 then num0 * (Int.ofNat pow10Nat) else num0 - - have den_nz : den ≠ 0 := by - have h10pos : (0 : Nat) < 10 := by decide - have hpow1 : denomBase ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := fracLen) h10pos) - have hpow2 : pow10Nat ≠ 0 := by - exact Nat.ne_of_gt (Nat.pow_pos (n := expAbs) h10pos) - by_cases hneg : expInt < 0 - · -- `den = denomBase * pow10Nat` - simpa [den, hneg] using Nat.mul_ne_zero hpow1 hpow2 - · -- `den = denomBase` - simpa [den, hneg] using hpow1 - - return Rat.normalize num den (den_nz := den_nz) - -/-- Parse a line of space-separated rationals, failing on the first invalid token. -/ -def parseRatLine (line : String) : Except String (Array Rat) := do - let parts := line.splitOn " " |>.filter (· ≠ "") - let mut out : Array Rat := #[] - for p in parts do - let r ← parseRat p - out := out.push r - return out - -/-! 
### Specs -/ - -theorem parseInt10_spec (s : String) : parseInt10 s = parseInt10 s := rfl - -theorem parseNat10OrZero_spec (s : String) : parseNat10OrZero s = parseNat10OrZero s := rfl - -theorem parseRat_spec (s : String) : parseRat s = parseRat s := rfl - -theorem parseRatLine_spec (line : String) : parseRatLine line = parseRatLine line := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Demo.lean b/Nfp/Sound/Demo.lean deleted file mode 100644 index 83e28d0..0000000 --- a/Nfp/Sound/Demo.lean +++ /dev/null @@ -1,103 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Tactic.Linarith -import Mathlib.Tactic.NormNum -import Mathlib.Tactic.FinCases -import Nfp.Linearization -import Nfp.Sound.Cert - -namespace Nfp.Sound - -open scoped BigOperators - -open Nfp - -/-- A tiny 2×2 signed mixer used to demonstrate the sound row-sum bound story. -/ -noncomputable def demoMixer : SignedMixer (Fin 1 × Fin 2) (Fin 1 × Fin 2) := - ⟨fun i j => - match i.2.val, j.2.val with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0⟩ - -/-- End-to-end toy lemma: the abstract `Linearization.operatorNormBound` is controlled by an -explicit row-sum bound for a concrete small matrix. - -This is intentionally tiny (2×2) but exercises the same definition (`max row sum of abs`). --/ -theorem demo_operatorNormBound_le : - Nfp.operatorNormBound demoMixer ≤ (7 : ℝ) := by - classical - -- Unfold to a `Finset.sup'` of row sums and check each row explicitly. - dsimp [Nfp.operatorNormBound, SignedMixer.operatorNormBound, demoMixer] - refine (Finset.sup'_le_iff (s := (Finset.univ : Finset (Fin 1 × Fin 2))) - (f := fun i : Fin 1 × Fin 2 => - ∑ x : Fin 1 × Fin 2, - abs - (match (i.2 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) - (H := Finset.univ_nonempty)).2 ?_ - intro i _hi - rcases i with ⟨i1, i2⟩ - fin_cases i1 - fin_cases i2 - · -- Row 0: |1| + |2| = 3 ≤ 7. 
- have hsum : - (∑ x : Fin 1 × Fin 2, - abs - (match (0 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) = 3 := by - simp [Fintype.sum_prod_type, Fin.sum_univ_two] - norm_num - nlinarith [hsum] - · -- Row 1: |−3| + |4| = 7 ≤ 7. - have hsum : - (∑ x : Fin 1 × Fin 2, - abs - (match (1 : Nat), (x.2 : Nat) with - | 0, 0 => (1 : ℝ) - | 0, 1 => (2 : ℝ) - | 1, 0 => (-3 : ℝ) - | 1, 1 => (4 : ℝ) - | _, _ => 0)) = 7 := by - simp [Fintype.sum_prod_type, Fin.sum_univ_two] - norm_num - nlinarith [hsum] - -/-! ## Executable checker sanity test -/ - -/-- A tiny inconsistent certificate used to sanity-check the boolean checker. -/ -def demoBadCert : ModelCert := - { modelPath := "" - inputPath? := none - inputDelta := 0 - eps := 0 - seqLen := 0 - modelDim := 0 - headDim := 0 - soundnessBits := 20 - geluDerivTarget := .tanh - actDerivBound := 0 - softmaxJacobianNormInfWorst := 0 - layers := #[] - totalAmplificationFactor := 1 } - -theorem demoBadCert_check : ModelCert.check demoBadCert = false := by - simp [ModelCert.check, ModelCert.Valid, demoBadCert, softmaxJacobianNormInfWorst] - -/-! ### Specs -/ - -theorem demoMixer_spec : demoMixer = demoMixer := rfl -theorem demoBadCert_spec : demoBadCert = demoBadCert := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/Fixed.lean b/Nfp/Sound/Fixed.lean deleted file mode 100644 index 88c31e1..0000000 --- a/Nfp/Sound/Fixed.lean +++ /dev/null @@ -1,398 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation - -namespace Nfp.Sound - -/-! -# Fixed-point (base-10) arithmetic for SOUND-mode streaming - -We represent real numbers as scaled integers with a **global** scale `S = 10^p`. - -This file is intentionally `Int`/`Nat`-only: it is used to avoid `Rat.normalize`/gcd costs on -hot paths (matrix streaming / IBP), while preserving mathematical rigor by using conservative -outward rounding when rescaling after multiplication/division. 
--/ - -structure Fixed10Cfg where - /-- Scale exponent `p` in `S = 10^p`. -/ - scalePow10 : Nat - deriving Repr - -namespace Fixed10Cfg - -def scaleNat (cfg : Fixed10Cfg) : Nat := Nat.pow 10 cfg.scalePow10 -def scaleInt (cfg : Fixed10Cfg) : Int := Int.ofNat cfg.scaleNat - -theorem scaleNat_def (cfg : Fixed10Cfg) : scaleNat cfg = Nat.pow 10 cfg.scalePow10 := rfl - -theorem scaleInt_def (cfg : Fixed10Cfg) : scaleInt cfg = Int.ofNat cfg.scaleNat := rfl - -end Fixed10Cfg - -/-- Fixed-point scalar encoded as an `Int` meaning `x / S`. -/ -abbrev Fixed10 := Int - -/-- Closed fixed-point interval `[lo, hi]` (both in scaled integer units). -/ -structure Fixed10Interval where - lo : Fixed10 - hi : Fixed10 - deriving Repr - -namespace Fixed10Interval - -instance : Inhabited Fixed10Interval := ⟨{ lo := 0, hi := 0 }⟩ - -def const (x : Fixed10) : Fixed10Interval := { lo := x, hi := x } - -def union (a b : Fixed10Interval) : Fixed10Interval := - { lo := min a.lo b.lo, hi := max a.hi b.hi } - -def add (a b : Fixed10Interval) : Fixed10Interval := - { lo := a.lo + b.lo, hi := a.hi + b.hi } - -def sub (a b : Fixed10Interval) : Fixed10Interval := - { lo := a.lo - b.hi, hi := a.hi - b.lo } - -def relu (a : Fixed10Interval) : Fixed10Interval := - { lo := max 0 a.lo, hi := max 0 a.hi } - -/-- Conservative GeLU hull: `GeLU(x) ∈ [min(x,0), max(x,0)]`. -/ -def geluOverapprox (a : Fixed10Interval) : Fixed10Interval := - { lo := min a.lo 0, hi := max a.hi 0 } - -private def absInt (x : Int) : Int := if x < 0 then -x else x - -/-- Maximum absolute endpoint (in scaled integer units). -/ -def absUpper (a : Fixed10Interval) : Int := - max (absInt a.lo) (absInt a.hi) - -/-- Upper bound on `|x - μ|` for any `x, μ ∈ [lo, hi]` (scaled units). -/ -def centeredAbsBound (a : Fixed10Interval) : Int := - absInt (a.hi - a.lo) - -/-- Upper bound on `max |gelu'(x)|` over a fixed-point interval. 
-/ -def geluDerivBound (cfg : Fixed10Cfg) (target : GeluDerivTarget) (a : Fixed10Interval) : Rat := - let maxAbsInt := absUpper a - let maxAbsRat : Rat := - Rat.normalize maxAbsInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let maxAbsSq := maxAbsRat * maxAbsRat - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbsRat) 2 - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbsRat * slope - min localBound 2 - -/-- Floor division by a positive `Nat` divisor. -/ -private def floorDivNat (a : Int) (d : Nat) : Int := - -- `Int.ediv` is Euclidean division (for positive divisor): `a = q*d + r`, `0 ≤ r < d`. - a.ediv (Int.ofNat d) - -/-- Ceil division by a positive `Nat` divisor. -/ -private def ceilDivNat (a : Int) (d : Nat) : Int := - let di : Int := Int.ofNat d - let q := a.ediv di - let r := a.emod di - if r = 0 then q else q + 1 - -/-- Rescale an interval from scale `S^2` down to `S` using conservative rounding. -/ -private def rescaleFromSq (cfg : Fixed10Cfg) (loSq hiSq : Int) : Fixed10Interval := - let S : Nat := cfg.scaleNat - { lo := floorDivNat loSq S, hi := ceilDivNat hiSq S } - -/-- Multiply two fixed-point intervals, returning an interval at the same scale. - -If `a,b` are in units of `1/S`, then their product is in units of `1/S^2`; we rescale back to `1/S` -with outward rounding to remain conservative. --/ -def mul (cfg : Fixed10Cfg) (a b : Fixed10Interval) : Fixed10Interval := - Id.run do - let p1 := a.lo * b.lo - let p2 := a.lo * b.hi - let p3 := a.hi * b.lo - let p4 := a.hi * b.hi - let loSq := min (min p1 p2) (min p3 p4) - let hiSq := max (max p1 p2) (max p3 p4) - return rescaleFromSq cfg loSq hiSq - -/-- Add a constant vector to a vector of intervals. 
-/ -def addConstVec (xs : Array Fixed10Interval) (c : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if xs.size ≠ c.size then - return xs - let mut out := Array.mkEmpty xs.size - for i in [:xs.size] do - out := out.push (add xs[i]! c[i]!) - return out - -/-- Elementwise union of two interval vectors. -/ -def unionVec (a b : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if a.size ≠ b.size then - return a - let mut out := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (union a[i]! b[i]!) - return out - -/-! ### Specs -/ - -theorem Fixed10_spec : Fixed10 = Int := rfl - -theorem const_def (x : Fixed10) : Fixed10Interval.const x = { lo := x, hi := x } := rfl - -theorem union_def (a b : Fixed10Interval) : - Fixed10Interval.union a b = { lo := min a.lo b.lo, hi := max a.hi b.hi } := rfl - -theorem add_def (a b : Fixed10Interval) : - Fixed10Interval.add a b = { lo := a.lo + b.lo, hi := a.hi + b.hi } := rfl - -theorem sub_def (a b : Fixed10Interval) : - Fixed10Interval.sub a b = { lo := a.lo - b.hi, hi := a.hi - b.lo } := rfl - -theorem relu_def (a : Fixed10Interval) : - Fixed10Interval.relu a = { lo := max 0 a.lo, hi := max 0 a.hi } := rfl - -theorem geluOverapprox_def (a : Fixed10Interval) : - Fixed10Interval.geluOverapprox a = { lo := min a.lo 0, hi := max a.hi 0 } := rfl - -theorem absInt_spec (x : Int) : absInt x = absInt x := rfl - -theorem absUpper_def (a : Fixed10Interval) : - Fixed10Interval.absUpper a = max (absInt a.lo) (absInt a.hi) := rfl - -theorem centeredAbsBound_def (a : Fixed10Interval) : - Fixed10Interval.centeredAbsBound a = absInt (a.hi - a.lo) := rfl - -theorem geluDerivBound_def (cfg : Fixed10Cfg) (target : GeluDerivTarget) (a : Fixed10Interval) : - Fixed10Interval.geluDerivBound cfg target a = - let maxAbsInt := absUpper a - let maxAbsRat : Rat := - Rat.normalize maxAbsInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) 
h10pos)) - let maxAbsSq := maxAbsRat * maxAbsRat - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbsRat) 2 - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbsRat * slope - min localBound 2 := rfl - -theorem floorDivNat_spec (a : Int) (d : Nat) : floorDivNat a d = floorDivNat a d := rfl - -theorem ceilDivNat_spec (a : Int) (d : Nat) : ceilDivNat a d = ceilDivNat a d := rfl - -theorem rescaleFromSq_spec (cfg : Fixed10Cfg) (loSq hiSq : Int) : - rescaleFromSq cfg loSq hiSq = rescaleFromSq cfg loSq hiSq := rfl - -theorem mul_spec (cfg : Fixed10Cfg) (a b : Fixed10Interval) : - Fixed10Interval.mul cfg a b = Fixed10Interval.mul cfg a b := rfl - -theorem addConstVec_spec (xs : Array Fixed10Interval) (c : Array Fixed10Interval) : - Fixed10Interval.addConstVec xs c = Fixed10Interval.addConstVec xs c := rfl - -theorem unionVec_spec (a b : Array Fixed10Interval) : - Fixed10Interval.unionVec a b = Fixed10Interval.unionVec a b := rfl - -end Fixed10Interval - -/-! -## Fast decimal → fixed-point parsing - -We parse a decimal/scientific numeral token into a **rounded** scaled integer at scale `S = 10^p` -without constructing a `Rat` (and therefore without gcd normalization). - -Correctness contract (soundness): -- The returned integer `r` is a rounding of the exact scaled value `x*S`. -- If later we treat the true scaled value as lying in `[r-1, r+1]`, then this interval always - contains the exact scaled value (since the exact value lies between `floor` and `ceil`). --/ - -private def isDigit (b : UInt8) : Bool := (48 ≤ b) && (b ≤ 57) -private def digitVal (b : UInt8) : Nat := (b.toNat - 48) - -private def pow10Nat (k : Nat) : Nat := Nat.pow 10 k - -/-- Parse an `Int` exponent written in base-10 from a byte slice. 
-/ -private def parseExpInt (bytes : ByteArray) (start stop : Nat) : Except String Int := - if start ≥ stop then - .error "invalid exponent" - else - Id.run do - let mut i := start - let mut neg : Bool := false - let b0 := bytes[i]! - if b0 = 45 then -- '-' - neg := true - i := i + 1 - else if b0 = 43 then -- '+' - i := i + 1 - if i ≥ stop then - return .error "invalid exponent" - let mut acc : Int := 0 - while i < stop do - let b := bytes[i]! - if !isDigit b then - return .error "invalid exponent digit" - acc := acc * 10 + (Int.ofNat (digitVal b)) - i := i + 1 - return .ok (if neg then -acc else acc) - -/-- Parse a token into a rounded scaled integer at scale `10^scalePow10`. -/ -def parseFixed10Rounded (scalePow10 : Nat) (bytes : ByteArray) (start stop : Nat) : - Except String Int := - if start ≥ stop then - .error "empty token" - else - Id.run do - let mut i := start - -- sign - let mut neg : Bool := false - let b0 := bytes[i]! - if b0 = 45 then -- '-' - neg := true - i := i + 1 - else if b0 = 43 then - i := i + 1 - - -- mantissa with optional '.' - let mut mant : Int := 0 - let mut fracLen : Nat := 0 - let mut seenDot : Bool := false - let mut anyDigit : Bool := false - while i < stop do - let b := bytes[i]! - if b = 46 then -- '.' - if seenDot then - return .error "invalid numeral (multiple dots)" - seenDot := true - i := i + 1 - else if b = 101 || b = 69 then -- 'e' or 'E' - break - else if isDigit b then - anyDigit := true - mant := mant * 10 + (Int.ofNat (digitVal b)) - if seenDot then - fracLen := fracLen + 1 - i := i + 1 - else - return .error "invalid numeral" - if !anyDigit then - return .error "invalid numeral (no digits)" - - -- optional exponent - let mut exp : Int := 0 - if i < stop then - let b := bytes[i]! 
- if b = 101 || b = 69 then - match parseExpInt bytes (i + 1) stop with - | .error e => return .error e - | .ok e => exp := e - - -- scaled value: mant * 10^(scalePow10 + exp - fracLen) - let expTotal : Int := (Int.ofNat scalePow10) + exp - (Int.ofNat fracLen) - let num0 : Int := if neg then -mant else mant - if expTotal ≥ 0 then - let eNat : Nat := Int.toNat expTotal - let pow : Int := Int.ofNat (pow10Nat eNat) - return .ok (num0 * pow) - else - let eNat : Nat := Int.toNat (-expTotal) - let denNat : Nat := pow10Nat eNat - let den : Int := Int.ofNat denNat - let q := num0.ediv den - let r := num0.emod den - if r = 0 then - return .ok q - -- Round-to-nearest (ties up). Always within 1 of the exact scaled value. - let twoR := (2 : Int) * r - if twoR < den then - return .ok q - else - return .ok (q + 1) - -/-! -### Token folding helpers (line-based) - -These helpers mirror the `foldRatTokens` pattern used elsewhere, but avoid allocating token -substrings by scanning whitespace boundaries in the UTF-8 byte array of each line. --/ - -private def isWs (b : UInt8) : Bool := b = 32 || b = 9 -- ' ' or '\t' - -def foldFixed10Tokens {α : Type} - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Int → α) : Except String (α × Nat) := - Id.run do - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i ≥ lines.size then - return .error "unexpected end of file while reading fixed tokens" - let line := lines[i]!.trim - i := i + 1 - if line.isEmpty then - pure () - else - let bytes := line.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && isWs (bytes[j]!) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && !isWs (bytes[j]!) 
do - j := j + 1 - let tokStop := j - match parseFixed10Rounded scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - st := step st x - remaining := remaining - 1 - return .ok (st, i) - -/-! ### Specs -/ - -theorem isDigit_spec (b : UInt8) : isDigit b = isDigit b := rfl - -theorem digitVal_spec (b : UInt8) : digitVal b = digitVal b := rfl - -theorem pow10Nat_spec (k : Nat) : pow10Nat k = pow10Nat k := rfl - -theorem parseExpInt_spec (bytes : ByteArray) (start stop : Nat) : - parseExpInt bytes start stop = parseExpInt bytes start stop := rfl - -theorem parseFixed10Rounded_spec (scalePow10 : Nat) (bytes : ByteArray) (start stop : Nat) : - parseFixed10Rounded scalePow10 bytes start stop = - parseFixed10Rounded scalePow10 bytes start stop := rfl - -theorem isWs_spec (b : UInt8) : isWs b = isWs b := rfl - -theorem foldFixed10Tokens_spec {α : Type} - (scalePow10 : Nat) - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Int → α) : - foldFixed10Tokens scalePow10 lines start count state step = - foldFixed10Tokens scalePow10 lines start count state step := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/HeadCert.lean b/Nfp/Sound/HeadCert.lean deleted file mode 100644 index c0b1c6c..0000000 --- a/Nfp/Sound/HeadCert.lean +++ /dev/null @@ -1,490 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Induction -import Nfp.Sound.Bounds - -namespace Nfp.Sound - -/-! -# Sound per-head contribution certificates - -This module defines a minimal, checkable certificate for per-head weight-only -contribution bounds. These are intended as a lightweight starting point for -sound circuit certification. --/ - -/-- Weight-only per-head operator-norm bounds and derived factors. 
-/ -structure HeadContributionCert where - layerIdx : Nat - headIdx : Nat - wqOpBound : Rat - wkOpBound : Rat - wvOpBound : Rat - woOpBound : Rat - qkFactorBound : Rat - voFactorBound : Rat - deriving Repr - -namespace HeadContributionCert - -/-- Internal consistency checks for derived factor bounds. -/ -def Valid (c : HeadContributionCert) : Prop := - c.qkFactorBound = c.wqOpBound * c.wkOpBound ∧ - c.voFactorBound = c.wvOpBound * c.woOpBound - -instance (c : HeadContributionCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadContributionCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadContributionCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadContributionCert - -/-- Local (input-dependent) per-head attention contribution bounds. -/ -structure HeadLocalContributionCert where - layerIdx : Nat - headIdx : Nat - /-- Precision in dyadic bits for local LayerNorm bounds. -/ - soundnessBits : Nat - ln1MaxAbsGamma : Rat - ln1VarianceLowerBound : Rat - ln1Bound : Rat - wqOpBound : Rat - wkOpBound : Rat - wvOpBound : Rat - woOpBound : Rat - qkFactorBound : Rat - /-- Upper bound on the per-head attention Jacobian contribution. -/ - attnJacBound : Rat - deriving Repr - -namespace HeadLocalContributionCert - -/-- Internal consistency checks for local per-head bounds. -/ -def Valid (eps : Rat) (c : HeadLocalContributionCert) : Prop := - 0 < eps ∧ - c.ln1Bound = - (if c.ln1VarianceLowerBound > 0 then - layerNormOpBoundLocal c.ln1MaxAbsGamma c.ln1VarianceLowerBound eps c.soundnessBits - else - layerNormOpBoundConservative c.ln1MaxAbsGamma eps c.soundnessBits) ∧ - c.qkFactorBound = c.wqOpBound * c.wkOpBound ∧ - c.attnJacBound = - c.ln1Bound * softmaxJacobianNormInfWorst * c.wvOpBound * c.woOpBound - -instance (eps : Rat) (c : HeadLocalContributionCert) : Decidable (Valid eps c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (eps : Rat) (c : HeadLocalContributionCert) : Bool := - decide (Valid eps c) - -theorem check_iff (eps : Rat) (c : HeadLocalContributionCert) : - c.check eps = true ↔ c.Valid eps := by - simp [check] - -end HeadLocalContributionCert - -/-- Local per-head attention pattern certificate (target logit dominance). -/ -structure HeadPatternCert where - layerIdx : Nat - headIdx : Nat - seqLen : Nat - targetOffset : Int - targetCountLowerBound : Nat - targetLogitLowerBound : Rat - otherLogitUpperBound : Rat - marginLowerBound : Rat - /-- Effort level for the `expLB` portfolio used in margin-to-weight bounds. -/ - softmaxExpEffort : Nat - targetWeightLowerBound : Rat - deriving Repr - -namespace HeadPatternCert - -/-- Internal consistency checks for pattern bounds. -/ -def Valid (c : HeadPatternCert) : Prop := - c.seqLen > 0 ∧ - c.targetCountLowerBound ≤ c.seqLen ∧ - c.marginLowerBound = c.targetLogitLowerBound - c.otherLogitUpperBound ∧ - c.targetWeightLowerBound = - softmaxTargetWeightLowerBound c.seqLen c.targetCountLowerBound - c.marginLowerBound c.softmaxExpEffort - -instance (c : HeadPatternCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadPatternCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadPatternCert) : c.check = true ↔ c.Valid := by - simp [check] - -end HeadPatternCert - -/-! ## Local value-direction bounds -/ - -/-- Safe lower bound for a convex mixture when only a lower bound on the match weight is known. -/ -def mixLowerBound (w m n : Rat) : Rat := - min m (w * m + (1 - w) * n) - -/-- Local per-head output lower bound for a single coordinate. 
-/ -structure HeadValueLowerBoundCert where - layerIdx : Nat - headIdx : Nat - coord : Nat - matchWeightLowerBound : Rat - matchCoordLowerBound : Rat - nonmatchCoordLowerBound : Rat - outputCoordLowerBound : Rat - deriving Repr - -namespace HeadValueLowerBoundCert - -/-- Internal consistency checks for the coordinate lower bound. -/ -def Valid (c : HeadValueLowerBoundCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.outputCoordLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchCoordLowerBound c.nonmatchCoordLowerBound - -instance (c : HeadValueLowerBoundCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadValueLowerBoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadValueLowerBoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadValueLowerBoundCert - -/-! ## Logit-direction bounds -/ - -/-- Local per-head logit-difference lower bound for a target direction. -/ -structure HeadLogitDiffLowerBoundCert where - layerIdx : Nat - headIdx : Nat - targetToken : Nat - negativeToken : Nat - matchWeightLowerBound : Rat - matchLogitLowerBound : Rat - nonmatchLogitLowerBound : Rat - logitDiffLowerBound : Rat - deriving Repr - -namespace HeadLogitDiffLowerBoundCert - -/-- Internal consistency checks for the logit-difference lower bound. -/ -def Valid (c : HeadLogitDiffLowerBoundCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.logitDiffLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchLogitLowerBound c.nonmatchLogitLowerBound - -instance (c : HeadLogitDiffLowerBoundCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (c : HeadLogitDiffLowerBoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadLogitDiffLowerBoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadLogitDiffLowerBoundCert - -/-! ## Best-match pattern bounds -/ - -/-- Local per-head attention pattern certificate (best-match, single query position). -/ -structure HeadBestMatchPatternCert where - layerIdx : Nat - headIdx : Nat - seqLen : Nat - queryPos : Nat - targetOffset : Int - targetToken : Int - bestMatchLogitLowerBound : Rat - bestNonmatchLogitUpperBound : Rat - marginLowerBound : Rat - /-- Effort level for the `expLB` portfolio used in margin-to-probability bounds. -/ - softmaxExpEffort : Nat - bestMatchWeightLowerBound : Rat - /-- Softmax Jacobian row-sum bound derived from the max-probability lower bound. -/ - softmaxJacobianNormInfUpperBound : Rat - deriving Repr - -namespace HeadBestMatchPatternCert - -/-- Internal consistency checks for best-match pattern bounds. -/ -def Valid (c : HeadBestMatchPatternCert) : Prop := - c.seqLen > 0 ∧ - c.queryPos < c.seqLen ∧ - c.marginLowerBound = c.bestMatchLogitLowerBound - c.bestNonmatchLogitUpperBound ∧ - c.bestMatchWeightLowerBound = - softmaxMaxProbLowerBound c.seqLen c.marginLowerBound c.softmaxExpEffort ∧ - c.softmaxJacobianNormInfUpperBound = - softmaxJacobianNormInfBoundFromMargin c.seqLen c.marginLowerBound c.softmaxExpEffort - -instance (c : HeadBestMatchPatternCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadBestMatchPatternCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadBestMatchPatternCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadBestMatchPatternCert - -/-! ## Best-match value/logit bounds -/ - -/-- Local per-head output lower bound for a single coordinate (single query position). 
-/ -structure HeadValueLowerBoundPosCert where - layerIdx : Nat - headIdx : Nat - queryPos : Nat - coord : Nat - matchWeightLowerBound : Rat - matchCoordLowerBound : Rat - nonmatchCoordLowerBound : Rat - outputCoordLowerBound : Rat - deriving Repr - -namespace HeadValueLowerBoundPosCert - -/-- Internal consistency checks for the coordinate lower bound. -/ -def Valid (c : HeadValueLowerBoundPosCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.outputCoordLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchCoordLowerBound c.nonmatchCoordLowerBound - -instance (c : HeadValueLowerBoundPosCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : HeadValueLowerBoundPosCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadValueLowerBoundPosCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadValueLowerBoundPosCert - -/-- Local per-head logit-difference lower bound (single query position). -/ -structure HeadLogitDiffLowerBoundPosCert where - layerIdx : Nat - headIdx : Nat - queryPos : Nat - targetToken : Nat - negativeToken : Nat - matchWeightLowerBound : Rat - matchLogitLowerBound : Rat - nonmatchLogitLowerBound : Rat - logitDiffLowerBound : Rat - deriving Repr - -namespace HeadLogitDiffLowerBoundPosCert - -/-- Internal consistency checks for the logit-difference lower bound. -/ -def Valid (c : HeadLogitDiffLowerBoundPosCert) : Prop := - 0 ≤ c.matchWeightLowerBound ∧ - c.matchWeightLowerBound ≤ 1 ∧ - c.logitDiffLowerBound = - mixLowerBound c.matchWeightLowerBound c.matchLogitLowerBound c.nonmatchLogitLowerBound - -instance (c : HeadLogitDiffLowerBoundPosCert) : Decidable (Valid c) := by - unfold Valid - infer_instance - -/-- Boolean checker for `Valid`. 
-/ -def check (c : HeadLogitDiffLowerBoundPosCert) : Bool := - decide (Valid c) - -theorem check_iff (c : HeadLogitDiffLowerBoundPosCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end HeadLogitDiffLowerBoundPosCert - -namespace HeadPatternCert - -/-- Convert a sound head pattern certificate into a token-match witness. -/ -def toTokenMatchPattern (c : HeadPatternCert) : Nfp.TokenMatchPattern := { - seqLen := c.seqLen - targetOffset := c.targetOffset - targetCountLowerBound := c.targetCountLowerBound - softmaxExpEffort := c.softmaxExpEffort - targetWeightLowerBound := c.targetWeightLowerBound - marginLowerBound := c.marginLowerBound -} - -theorem toTokenMatchPattern_valid (c : HeadPatternCert) (h : c.Valid) : - (toTokenMatchPattern c).Valid := by - rcases h with ⟨hseq, hcount, _hmargin, hweight⟩ - exact ⟨hseq, hcount, by simpa [toTokenMatchPattern] using hweight⟩ - -def toInductionPatternWitness - (c : HeadPatternCert) (h : c.Valid) (hm : c.marginLowerBound > 0) - (hcount : 0 < c.targetCountLowerBound) (hoff : c.targetOffset = -1) : - Nfp.InductionPatternWitness := - Nfp.TokenMatchPattern.toInductionPatternWitness - (toTokenMatchPattern c) (toTokenMatchPattern_valid c h) hm hcount hoff - -end HeadPatternCert - -/-! ## Induction head sound certificates -/ - -/-- Combined sound certificate for an induction-style head pair. -/ -structure InductionHeadSoundCert where - layer1Pattern : HeadPatternCert - layer2Pattern : HeadPatternCert - layer2Value : HeadValueLowerBoundCert - layer2Logit? : Option HeadLogitDiffLowerBoundCert - deltaLowerBound : Rat - deriving Repr - -namespace InductionHeadSoundCert - -/-- Internal consistency checks for the combined certificate. 
-/ -def Valid (c : InductionHeadSoundCert) : Prop := - HeadPatternCert.Valid c.layer1Pattern ∧ - HeadPatternCert.Valid c.layer2Pattern ∧ - HeadValueLowerBoundCert.Valid c.layer2Value ∧ - c.layer2Value.layerIdx = c.layer2Pattern.layerIdx ∧ - c.layer2Value.headIdx = c.layer2Pattern.headIdx ∧ - c.layer2Value.matchWeightLowerBound = c.layer2Pattern.targetWeightLowerBound ∧ - c.deltaLowerBound = c.layer2Value.outputCoordLowerBound ∧ - (match c.layer2Logit? with - | none => True - | some logit => - HeadLogitDiffLowerBoundCert.Valid logit ∧ - logit.layerIdx = c.layer2Pattern.layerIdx ∧ - logit.headIdx = c.layer2Pattern.headIdx ∧ - logit.matchWeightLowerBound = c.layer2Pattern.targetWeightLowerBound) - -instance (c : InductionHeadSoundCert) : Decidable (Valid c) := by - classical - unfold Valid - cases c.layer2Logit? <;> infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : InductionHeadSoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : InductionHeadSoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end InductionHeadSoundCert - -/-! ## Best-match induction head certificates -/ - -/-- Combined sound certificate for an induction-style head pair (best-match pattern). -/ -structure InductionHeadBestMatchSoundCert where - layer1Pattern : HeadBestMatchPatternCert - layer2Pattern : HeadBestMatchPatternCert - layer2Value : HeadValueLowerBoundPosCert - layer2Logit? : Option HeadLogitDiffLowerBoundPosCert - deltaLowerBound : Rat - deriving Repr - -namespace InductionHeadBestMatchSoundCert - -/-- Internal consistency checks for the combined certificate. 
-/ -def Valid (c : InductionHeadBestMatchSoundCert) : Prop := - HeadBestMatchPatternCert.Valid c.layer1Pattern ∧ - HeadBestMatchPatternCert.Valid c.layer2Pattern ∧ - HeadValueLowerBoundPosCert.Valid c.layer2Value ∧ - c.layer2Value.layerIdx = c.layer2Pattern.layerIdx ∧ - c.layer2Value.headIdx = c.layer2Pattern.headIdx ∧ - c.layer2Value.queryPos = c.layer2Pattern.queryPos ∧ - c.layer2Value.matchWeightLowerBound = c.layer2Pattern.bestMatchWeightLowerBound ∧ - c.deltaLowerBound = c.layer2Value.outputCoordLowerBound ∧ - (match c.layer2Logit? with - | none => True - | some logit => - HeadLogitDiffLowerBoundPosCert.Valid logit ∧ - logit.layerIdx = c.layer2Pattern.layerIdx ∧ - logit.headIdx = c.layer2Pattern.headIdx ∧ - logit.queryPos = c.layer2Pattern.queryPos ∧ - logit.matchWeightLowerBound = c.layer2Pattern.bestMatchWeightLowerBound) - -instance (c : InductionHeadBestMatchSoundCert) : Decidable (Valid c) := by - classical - unfold Valid - cases c.layer2Logit? <;> infer_instance - -/-- Boolean checker for `Valid`. -/ -def check (c : InductionHeadBestMatchSoundCert) : Bool := - decide (Valid c) - -theorem check_iff (c : InductionHeadBestMatchSoundCert) : c.check = true ↔ c.Valid := by - simp [check, Valid] - -end InductionHeadBestMatchSoundCert - -/-! 
### Specs -/ - -theorem HeadContributionCert.Valid_spec : - HeadContributionCert.Valid = HeadContributionCert.Valid := rfl -theorem HeadContributionCert.check_spec : - HeadContributionCert.check = HeadContributionCert.check := rfl -theorem HeadLocalContributionCert.Valid_spec : - HeadLocalContributionCert.Valid = HeadLocalContributionCert.Valid := rfl -theorem HeadLocalContributionCert.check_spec : - HeadLocalContributionCert.check = HeadLocalContributionCert.check := rfl -theorem mixLowerBound_spec : - mixLowerBound = mixLowerBound := rfl -theorem HeadPatternCert.Valid_spec : - HeadPatternCert.Valid = HeadPatternCert.Valid := rfl -theorem HeadPatternCert.check_spec : - HeadPatternCert.check = HeadPatternCert.check := rfl -theorem HeadPatternCert.toTokenMatchPattern_spec : - HeadPatternCert.toTokenMatchPattern = HeadPatternCert.toTokenMatchPattern := rfl -theorem HeadPatternCert.toInductionPatternWitness_spec : - HeadPatternCert.toInductionPatternWitness = HeadPatternCert.toInductionPatternWitness := rfl -theorem HeadValueLowerBoundCert.Valid_spec : - HeadValueLowerBoundCert.Valid = HeadValueLowerBoundCert.Valid := rfl -theorem HeadValueLowerBoundCert.check_spec : - HeadValueLowerBoundCert.check = HeadValueLowerBoundCert.check := rfl -theorem HeadLogitDiffLowerBoundCert.Valid_spec : - HeadLogitDiffLowerBoundCert.Valid = HeadLogitDiffLowerBoundCert.Valid := rfl -theorem HeadLogitDiffLowerBoundCert.check_spec : - HeadLogitDiffLowerBoundCert.check = HeadLogitDiffLowerBoundCert.check := rfl -theorem HeadBestMatchPatternCert.Valid_spec : - HeadBestMatchPatternCert.Valid = HeadBestMatchPatternCert.Valid := rfl -theorem HeadBestMatchPatternCert.check_spec : - HeadBestMatchPatternCert.check = HeadBestMatchPatternCert.check := rfl -theorem HeadValueLowerBoundPosCert.Valid_spec : - HeadValueLowerBoundPosCert.Valid = HeadValueLowerBoundPosCert.Valid := rfl -theorem HeadValueLowerBoundPosCert.check_spec : - HeadValueLowerBoundPosCert.check = HeadValueLowerBoundPosCert.check := 
rfl -theorem HeadLogitDiffLowerBoundPosCert.Valid_spec : - HeadLogitDiffLowerBoundPosCert.Valid = HeadLogitDiffLowerBoundPosCert.Valid := rfl -theorem HeadLogitDiffLowerBoundPosCert.check_spec : - HeadLogitDiffLowerBoundPosCert.check = HeadLogitDiffLowerBoundPosCert.check := rfl -theorem InductionHeadSoundCert.Valid_spec : - InductionHeadSoundCert.Valid = InductionHeadSoundCert.Valid := rfl -theorem InductionHeadSoundCert.check_spec : - InductionHeadSoundCert.check = InductionHeadSoundCert.check := rfl -theorem InductionHeadBestMatchSoundCert.Valid_spec : - InductionHeadBestMatchSoundCert.Valid = InductionHeadBestMatchSoundCert.Valid := rfl -theorem InductionHeadBestMatchSoundCert.check_spec : - InductionHeadBestMatchSoundCert.check = InductionHeadBestMatchSoundCert.check := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/IO.lean b/Nfp/Sound/IO.lean deleted file mode 100644 index 088c003..0000000 --- a/Nfp/Sound/IO.lean +++ /dev/null @@ -1,570 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.BinaryPure -import Nfp.Sound.Cert -import Nfp.Sound.HeadCert -import Nfp.Sound.ModelHeader -import Nfp.Sound.TextPure -import Nfp.Untrusted.SoundBinary -import Nfp.Untrusted.SoundCompute - -namespace Nfp.Sound - -open IO - -/-! -# SOUND IO wrappers (trusted verification only) - -This module is intentionally thin: it delegates witness generation to -`Nfp.Untrusted.SoundCompute` and **verifies** returned certificates locally. 
--/ - -private def readTextModelHeader (path : System.FilePath) : - IO (Except String TextHeader) := do - let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray - return Nfp.Sound.parseTextHeader lines - -private def readBinaryModelHeader (path : System.FilePath) : - IO (Except String Nfp.Sound.BinaryHeader) := do - IO.FS.withFile path IO.FS.Mode.read fun h => do - match ← Nfp.Untrusted.SoundBinary.readBinaryHeader h with - | .error e => return .error e - | .ok hdr => return .ok hdr - -private def readModelHeader (path : System.FilePath) : - IO (Except String (Rat × GeluDerivTarget)) := do - let firstLine ← - IO.FS.withFile path IO.FS.Mode.read fun h => h.getLine - if firstLine.trim = "NFP_BINARY_V1" then - match ← readBinaryModelHeader path with - | .error e => return .error e - | .ok hdr => return .ok (hdr.eps, hdr.geluDerivTarget) - else - match ← readTextModelHeader path with - | .error e => return .error e - | .ok hdr => return .ok (hdr.eps, hdr.geluDerivTarget) - -private def readModelEps (path : System.FilePath) : IO (Except String Rat) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, _) => return .ok eps - -private def checkAttnWeightBounds - (cert : ModelCert) - (expected : AttnWeightBounds) : Except String Unit := - Id.run do - if expected.attnValueCoeff.size ≠ cert.layers.size then - return .error "attnValueCoeff layer count mismatch" - if expected.wqOpBoundMax.size ≠ cert.layers.size then - return .error "wqOpBoundMax layer count mismatch" - if expected.wkOpBoundMax.size ≠ cert.layers.size then - return .error "wkOpBoundMax layer count mismatch" - for idx in [:cert.layers.size] do - let expValue := expected.attnValueCoeff[idx]! - let expWq := expected.wqOpBoundMax[idx]! - let expWk := expected.wkOpBoundMax[idx]! - let layer := cert.layers[idx]! 
- if expValue ≠ layer.attnValueCoeff then - return .error s!"attnValueCoeff mismatch at layer {idx}" - if expWq ≠ layer.wqOpBoundMax then - return .error s!"wqOpBoundMax mismatch at layer {idx}" - if expWk ≠ layer.wkOpBoundMax then - return .error s!"wkOpBoundMax mismatch at layer {idx}" - return .ok () - -theorem checkAttnWeightBounds_spec_io : - checkAttnWeightBounds = checkAttnWeightBounds := rfl - -private def checkSoftmaxMarginZero (cert : ModelCert) : Except String Unit := - Id.run do - for idx in [:cert.layers.size] do - let layer := cert.layers[idx]! - if layer.softmaxMarginLowerBound ≠ 0 then - return .error s!"softmaxMarginLowerBound is unverified (layer {idx})" - return .ok () - -theorem checkSoftmaxMarginZero_spec_io : - checkSoftmaxMarginZero = checkSoftmaxMarginZero := rfl - -private def recomputeAttnWeightBoundsBinary - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← Nfp.Untrusted.SoundBinary.readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - let scalePow10 := defaultBinaryScalePow10 - match ← Nfp.Untrusted.SoundBinary.skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut coeffs : Array Rat := Array.mkEmpty hdr.numLayers - let mut wqMaxs : Array Rat := Array.mkEmpty hdr.numLayers - let mut wkMaxs : Array Rat := Array.mkEmpty hdr.numLayers - for _l in [:hdr.numLayers] do - let mut valuePairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads - let mut qkPairs : Array (Int × Int) := Array.mkEmpty hdr.numHeads - for _h in [:hdr.numHeads] do - let wqScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← 
Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let nvScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.headDim scalePow10 - let nvScaled ← - match nvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let noScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.headDim hdr.modelDim scalePow10 - let noScaled ← - match noScaledE with - | .error e => return .error e - | .ok v => pure v - qkPairs := qkPairs.push (wqScaled, wkScaled) - valuePairs := valuePairs.push (nvScaled, noScaled) - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let nWinScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.modelDim hdr.hiddenDim scalePow10 - let nWinScaled ← - match nWinScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - let nWoutScaledE ← - Nfp.Untrusted.SoundBinary.readMatrixNormInfScaled - h hdr.hiddenDim hdr.modelDim scalePow10 - let nWoutScaled ← - match nWoutScaledE with - | .error e => return .error e - | .ok v => pure v - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1GammaScaled ← - match ln1GammaScaledE 
with - | .error e => return .error e - | .ok v => pure v - let ln1BetaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1BetaScaled ← - match ln1BetaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2GammaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln2GammaScaled ← - match ln2GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2BetaScaledE ← - Nfp.Untrusted.SoundBinary.readVectorMaxAbsScaled h hdr.modelDim scalePow10 - match ln2BetaScaledE with - | .error e => return .error e - | .ok _ => pure () - let _ := nWinScaled - let _ := nWoutScaled - let _ := ln1GammaScaled - let _ := ln1BetaScaled - let _ := ln2GammaScaled - let coeff := attnValueCoeffFromScaledPairs scalePow10 valuePairs - let (wqMax, wkMax) := attnQKMaxFromScaledPairs scalePow10 qkPairs - coeffs := coeffs.push coeff - wqMaxs := wqMaxs.push wqMax - wkMaxs := wkMaxs.push wkMax - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← Nfp.Untrusted.SoundBinary.skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - return .ok { - attnValueCoeff := coeffs - wqOpBoundMax := wqMaxs - wkOpBoundMax := wkMaxs - } - -private def recomputeAttnWeightBoundsText - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do - let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray - return attnWeightBoundsFromTextLines lines - -private def recomputeAttnWeightBounds - (path : System.FilePath) : IO (Except String AttnWeightBounds) := do - let firstLine ← - IO.FS.withFile path IO.FS.Mode.read fun h => h.getLine - if firstLine.trim = "NFP_BINARY_V1" then - recomputeAttnWeightBoundsBinary 
path - else - recomputeAttnWeightBoundsText path - -/-- Compute weight-only per-head contribution bounds from a binary `.nfpt`. -/ -def certifyHeadBoundsBinary - (path : System.FilePath) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadContributionCert)) := do - match ← Nfp.Untrusted.SoundCompute.certifyHeadBoundsBinary path scalePow10 with - | .error e => return .error e - | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head contribution certificate failed internal checks" - -/-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ -def certifyModelFileGlobal - (path : System.FilePath) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, geluTarget) => - match ← - Nfp.Untrusted.SoundCompute.certifyModelFileGlobal - path eps geluTarget soundnessBits inputPath? 
inputDelta partitionDepth - softmaxMarginLowerBound softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.eps ≠ eps then - return .error "model header eps mismatch" - if cert.soundnessBits ≠ soundnessBits then - return .error "soundness bits mismatch" - if cert.geluDerivTarget ≠ geluTarget then - return .error "model header gelu_kind mismatch" - if cert.check then - match checkSoftmaxMarginZero cert with - | .error e => return .error e - | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with - | .error e => return .error e - | .ok _ => return .ok cert - return .error "sound certificate failed internal consistency checks" - -/-- Entry point for sound certification (global or local). -/ -def certifyModelFile - (path : System.FilePath) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - match ← readModelHeader path with - | .error e => return .error e - | .ok (eps, geluTarget) => - match ← - Nfp.Untrusted.SoundCompute.certifyModelFile - path eps geluTarget soundnessBits inputPath? 
inputDelta partitionDepth - softmaxMarginLowerBound softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.eps ≠ eps then - return .error "model header eps mismatch" - if cert.soundnessBits ≠ soundnessBits then - return .error "soundness bits mismatch" - if cert.geluDerivTarget ≠ geluTarget then - return .error "model header gelu_kind mismatch" - if cert.check then - match checkSoftmaxMarginZero cert with - | .error e => return .error e - | .ok _ => - match ← recomputeAttnWeightBounds path with - | .error e => - return .error s!"attnWeightBounds verification failed: {e}" - | .ok bounds => - match checkAttnWeightBounds cert bounds with - | .error e => return .error e - | .ok _ => return .ok cert - return .error "sound certificate failed internal consistency checks" - -/-- Compute per-head contribution bounds (global). -/ -def certifyHeadBounds - (path : System.FilePath) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadContributionCert)) := do - match ← Nfp.Untrusted.SoundCompute.certifyHeadBounds path scalePow10 with - | .error e => return .error e - | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head contribution certificate failed internal checks" - -/-- Compute local per-head attention contribution bounds. -/ -def certifyHeadBoundsLocal - (path : System.FilePath) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (scalePow10 : Nat := 9) : - IO (Except String (Array HeadLocalContributionCert)) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadBoundsLocal - path eps inputPath? 
inputDelta soundnessBits scalePow10 with - | .error e => return .error e - | .ok certs => - let ok := - certs.foldl (fun acc c => - acc && c.soundnessBits = soundnessBits && c.check eps) true - if ok then - return .ok certs - return .error "local head contribution certificate failed internal checks" - -/-- Compute local attention pattern bounds for a specific head. -/ -def certifyHeadPatternLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadPatternCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternLocal - path layerIdx headIdx eps soundnessBits inputPath? inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "head pattern certificate failed internal checks" - -/-- Compute local best-match pattern bounds for a specific head. -/ -def certifyHeadPatternBestMatchLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat := none) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadBestMatchPatternCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocal - path layerIdx headIdx queryPos? eps soundnessBits inputPath? inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "head best-match pattern certificate failed internal checks" - -/-- Compute local best-match pattern bounds for a sweep of heads. -/ -def certifyHeadPatternBestMatchLocalSweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadPatternBestMatchLocalSweep - path layerIdx headIdx eps soundnessBits inputPath? 
inputDelta targetOffset maxSeqLen - tightPattern tightPatternLayers perRowPatternLayers scalePow10 softmaxExpEffort with - | .error e => return .error e - | .ok certs => - let ok := certs.foldl (fun acc c => acc && c.check) true - if ok then - return .ok certs - return .error "head best-match sweep certificate failed internal checks" - -/-- Compute local head output lower bounds. -/ -def certifyHeadValueLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (coord : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) : - IO (Except String HeadValueLowerBoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadValueLowerBoundLocal - path layerIdx headIdx coord eps soundnessBits inputPath? inputDelta targetOffset - maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "head value lower bound certificate failed internal checks" - -/-- Compute local head logit-difference lower bounds. -/ -def certifyHeadLogitDiffLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (targetToken negativeToken : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := 9) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyHeadLogitDiffLowerBoundLocal - path layerIdx headIdx targetToken negativeToken eps soundnessBits inputPath? inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "head logit-diff lower bound certificate failed internal checks" - -/-- Sound induction-head certification (local path). -/ -def certifyInductionSound - (path : System.FilePath) - (layer1 head1 layer2 head2 : Nat) - (coord : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (offset1 : Int := -1) - (offset2 : Int := -1) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := 9) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String InductionHeadSoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyInductionSound - path layer1 head1 layer2 head2 coord eps soundnessBits inputPath? inputDelta - offset1 offset2 maxSeqLen scalePow10 tightPattern tightPatternLayers - perRowPatternLayers targetToken? negativeToken? 
softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "induction head certificate failed internal checks" - -/-- Sound best-match induction-head certification (local path). -/ -def certifyInductionSoundBestMatch - (path : System.FilePath) - (layer1 head1 layer2 head2 : Nat) - (coord : Nat) - (queryPos? : Option Nat := none) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (offset1 : Int := -1) - (offset2 : Int := -1) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := 9) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String InductionHeadBestMatchSoundCert) := do - match ← readModelEps path with - | .error e => return .error e - | .ok eps => - match ← - Nfp.Untrusted.SoundCompute.certifyInductionSoundBestMatch - path layer1 head1 layer2 head2 coord queryPos? eps soundnessBits inputPath? inputDelta - offset1 offset2 maxSeqLen scalePow10 tightPattern - tightPatternLayers perRowPatternLayers targetToken? negativeToken? softmaxExpEffort with - | .error e => return .error e - | .ok cert => - if cert.check then - return .ok cert - return .error "best-match induction head certificate failed internal checks" - -end Nfp.Sound diff --git a/Nfp/Sound/Induction.lean b/Nfp/Sound/Induction.lean new file mode 100644 index 0000000..c0352fe --- /dev/null +++ b/Nfp/Sound/Induction.lean @@ -0,0 +1,15 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.Core +public import Nfp.Sound.Induction.HeadOutput +public import Nfp.Sound.Induction.LogitDiff +public import Nfp.Sound.Induction.OneHot + +/-! +Soundness lemmas for induction certificates. 
+ +This module re-exports the core definitions, head-output helpers, and logit-diff +lemmas that operate on explicit certificates. +-/ diff --git a/Nfp/Sound/Induction/Core.lean b/Nfp/Sound/Induction/Core.lean new file mode 100644 index 0000000..b67eca5 --- /dev/null +++ b/Nfp/Sound/Induction/Core.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.CoreDefs + +/-! +Core definitions for induction certificates. +-/ diff --git a/Nfp/Sound/Induction/CoreDefs.lean b/Nfp/Sound/Induction/CoreDefs.lean new file mode 100644 index 0000000..d48984a --- /dev/null +++ b/Nfp/Sound/Induction/CoreDefs.lean @@ -0,0 +1,274 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Batteries.Data.Vector.Lemmas +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.Softmax +public import Nfp.Core.Basic +public import Nfp.Model.InductionHead +public import Nfp.Bounds.LayerNorm +public import Nfp.Linear.FinFold + +/-! +Core definitions for induction-head certificates. + +These definitions are shared across induction certificate checkers and proofs. +-/ + +public section + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit +open Nfp.Bounds + +variable {seq : Nat} + +/-- Cached direction head for head inputs. -/ +def dirHeadVecOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Vector Rat dHead := + Vector.ofFn (fun d : Fin dHead => + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j)) + +/-- Unfolding lemma for `dirHeadVecOfInputs`. 
-/ +theorem dirHeadVecOfInputs_get {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (d : Fin dHead) : + (dirHeadVecOfInputs inputs).get d = + Linear.dotFin dModel (fun j => inputs.wo j d) (fun j => inputs.direction j) := by + simp [dirHeadVecOfInputs] + +/-- Real-valued LayerNorm outputs for head inputs. -/ +noncomputable def lnRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun q => + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) + +/-- Unfolding lemma for `lnRealOfInputs`. -/ +theorem lnRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (i : Fin dModel) : + lnRealOfInputs inputs q i = + Bounds.layerNormReal inputs.lnEps inputs.ln1Gamma inputs.ln1Beta (inputs.embed q) i := by + simp [lnRealOfInputs] + +/-- Real-valued query projections for head inputs. -/ +noncomputable def qRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + (inputs.bq d : Real) + +/-- Unfolding lemma for `qRealOfInputs`. -/ +theorem qRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + qRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wq j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bq d : Real) := by + simp [qRealOfInputs] + +/-- Real-valued key projections for head inputs. -/ +noncomputable def kRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + (inputs.bk d : Real) + +/-- Unfolding lemma for `kRealOfInputs`. 
-/ +theorem kRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + kRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wk j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bk d : Real) := by + simp [kRealOfInputs] + +/-- Real-valued value projections for head inputs. -/ +noncomputable def vRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dHead → Real := + fun q d => + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + (inputs.bv d : Real) + +/-- Unfolding lemma for `vRealOfInputs`. -/ +theorem vRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) (d : Fin dHead) : + vRealOfInputs inputs q d = + dotProduct (fun j => (inputs.wv j d : Real)) (lnRealOfInputs inputs q) + + (inputs.bv d : Real) := by + simp [vRealOfInputs] + +/-- Real-valued attention scores for head inputs. -/ +noncomputable def scoresRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin seq → Real := + fun q k => + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + if inputs.maskCausal then + if k ≤ q then + base + else + (inputs.maskValue : Real) + else + base + +/-- Unfolding lemma for `scoresRealOfInputs`. -/ +theorem scoresRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (q k : Fin seq) : + scoresRealOfInputs inputs q k = + let base := + (inputs.scale : Real) * + dotProduct (fun d => qRealOfInputs inputs q d) (fun d => kRealOfInputs inputs k d) + if inputs.maskCausal then + if k ≤ q then + base + else + (inputs.maskValue : Real) + else + base := by + simp [scoresRealOfInputs] + +/-- Real-valued per-key head outputs in model space. 
-/ +noncomputable def headValueRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Fin dModel → Real := + fun k i => + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) + +/-- Unfolding lemma for `headValueRealOfInputs`. -/ +theorem headValueRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) (i : Fin dModel) : + headValueRealOfInputs inputs k i = + dotProduct (fun d => (inputs.wo i d : Real)) (fun d => vRealOfInputs inputs k d) := by + simp [headValueRealOfInputs] + +/-- Real-valued direction scores for head inputs. -/ +noncomputable def valsRealOfInputs {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) : Fin seq → Real := + let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d + fun k => dotProduct dirHead (fun d => vRealOfInputs inputs k d) + +/-- Unfolding lemma for `valsRealOfInputs`. -/ +theorem valsRealOfInputs_def {seq dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : + valsRealOfInputs inputs k = + let dirHead : Fin dHead → Real := fun d => (dirHeadVecOfInputs inputs).get d + dotProduct dirHead (fun d => vRealOfInputs inputs k d) := by + simp [valsRealOfInputs] + +/-- Interval data for direction values. -/ +structure ValueInterval (seq : Nat) where + /-- Lower bound for values. -/ + lo : Rat + /-- Upper bound for values. -/ + hi : Rat + /-- Lower bounds on per-key values. -/ + valsLo : Fin seq → Rat + /-- Upper bounds on per-key values. -/ + valsHi : Fin seq → Rat + /-- Optional logit-diff direction metadata (ignored by the checker). -/ + direction : Option DirectionSpec + +/-- Soundness predicate for direction-value interval data. -/ +structure ValueIntervalBounds {seq : Nat} (vals : Fin seq → Real) + (c : ValueInterval seq) : Prop where + /-- Interval endpoints are ordered. 
-/ + lo_le_hi : c.lo ≤ c.hi + /-- `lo` is below every lower bound. -/ + lo_le_valsLo : ∀ k, (c.lo : Real) ≤ (c.valsLo k : Real) + /-- Bounds sandwich the real values. -/ + vals_bounds : + ∀ k, (c.valsLo k : Real) ≤ vals k ∧ vals k ≤ (c.valsHi k : Real) + /-- `hi` is above every upper bound. -/ + valsHi_le_hi : ∀ k, (c.valsHi k : Real) ≤ (c.hi : Real) + +/-- Split-budget knobs for sign-splitting bounds in induction-head certificates. -/ +structure InductionHeadSplitConfig where + /-- Split budget for query dims. -/ + splitBudgetQ : Nat + /-- Split budget for key dims. -/ + splitBudgetK : Nat + /-- Split budget for base diff dims. -/ + splitBudgetDiffBase : Nat + /-- Split budget for refined diff dims. -/ + splitBudgetDiffRefined : Nat + +/-- Default split budgets for induction-head sign-splitting bounds. -/ +def defaultInductionHeadSplitConfig : InductionHeadSplitConfig := + { splitBudgetQ := 2 + splitBudgetK := 2 + splitBudgetDiffBase := 0 + splitBudgetDiffRefined := 12 } + +/-- Unfolding lemma for `defaultInductionHeadSplitConfig`. -/ +private theorem defaultInductionHeadSplitConfig_def : + defaultInductionHeadSplitConfig = + { splitBudgetQ := 2 + splitBudgetK := 2 + splitBudgetDiffBase := 0 + splitBudgetDiffRefined := 12 } := rfl + +/-- Sound induction-certificate payload built from exact head inputs. -/ +structure InductionHeadCert (seq : Nat) where + /-- Weight tolerance. -/ + eps : Rat + /-- Per-query weight tolerance derived from local margins. -/ + epsAt : Fin seq → Rat + /-- Per-key weight bounds derived from score gaps. -/ + weightBoundAt : Fin seq → Fin seq → Rat + /-- Score margin used to justify the weight tolerance. -/ + margin : Rat + /-- Active queries for which bounds are required. -/ + active : Finset (Fin seq) + /-- `prev` selector for induction-style attention. -/ + prev : Fin seq → Fin seq + /-- Value-interval certificate for the direction values. -/ + values : ValueInterval seq + +/-- Extensionality lemma for `InductionHeadCert`. 
-/ +@[ext] theorem InductionHeadCert.ext {seq : Nat} {c₁ c₂ : InductionHeadCert seq} + (hε : c₁.eps = c₂.eps) + (hεAt : c₁.epsAt = c₂.epsAt) + (hweight : c₁.weightBoundAt = c₂.weightBoundAt) + (hmargin : c₁.margin = c₂.margin) + (hactive : c₁.active = c₂.active) + (hprev : c₁.prev = c₂.prev) + (hvalues : c₁.values = c₂.values) : + c₁ = c₂ := by + cases c₁ + cases c₂ + cases hε + cases hεAt + cases hweight + cases hmargin + cases hactive + cases hprev + cases hvalues + rfl + +/-- Soundness predicate for `InductionHeadCert`. -/ +structure InductionHeadCertSound [NeZero seq] {dModel dHead : Nat} + (inputs : Model.InductionHeadInputs seq dModel dHead) (c : InductionHeadCert seq) : Prop where + /-- Softmax weights respect the derived margin bounds. -/ + softmax_bounds : + Layers.SoftmaxMarginBoundsOn (Val := Real) (c.eps : Real) (c.margin : Real) + (fun q => q ∈ c.active) c.prev + (scoresRealOfInputs inputs) + (fun q k => Circuit.softmax (scoresRealOfInputs inputs q) k) + /-- Per-query one-hot bounds derived from local margins. -/ + oneHot_bounds_at : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (c.epsAt q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k) + /-- Per-key weight bounds derived from local score gaps. -/ + weight_bounds_at : + ∀ q, q ∈ c.active → ∀ k, k ≠ c.prev q → + Circuit.softmax (scoresRealOfInputs inputs q) k ≤ (c.weightBoundAt q k : Real) + /-- Interval bounds hold for the direction values. -/ + value_bounds : + ValueIntervalBounds (vals := valsRealOfInputs inputs) c.values + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/HeadOutput.lean b/Nfp/Sound/Induction/HeadOutput.lean new file mode 100644 index 0000000..9035efc --- /dev/null +++ b/Nfp/Sound/Induction/HeadOutput.lean @@ -0,0 +1,59 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.CoreDefs + +/-! +Head-output definitions for induction heads. 
+-/ + +public section + +namespace Nfp + +namespace Sound + +open Nfp.Circuit + +variable {seq dModel dHead : Nat} + +noncomputable section + +/-- Real-valued head output using explicit score inputs. -/ +def headOutputWithScores (scores : Fin seq → Fin seq → Real) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i + dotProduct (weights q) vals + +/-- Unfolding lemma for `headOutputWithScores`. -/ +theorem headOutputWithScores_def (scores : Fin seq → Fin seq → Real) + (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutputWithScores scores inputs q i = + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scores q) k + let vals : Fin seq → Real := fun k => headValueRealOfInputs inputs k i + dotProduct (weights q) vals := by + simp [headOutputWithScores] + +/-- Real-valued head output for a query and model dimension. -/ +def headOutput (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : Real := + headOutputWithScores (scoresRealOfInputs inputs) inputs q i + +/-- Unfolding lemma for `headOutput`. 
-/ +theorem headOutput_def (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) (i : Fin dModel) : + headOutput inputs q i = + headOutputWithScores (scoresRealOfInputs inputs) inputs q i := by + simp [headOutput] + +end + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/LogitDiff.lean b/Nfp/Sound/Induction/LogitDiff.lean new file mode 100644 index 0000000..0f6c987 --- /dev/null +++ b/Nfp/Sound/Induction/LogitDiff.lean @@ -0,0 +1,1011 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Aesop +public import Mathlib.Data.List.MinMax +public import Mathlib.Data.Vector.Basic +public import Nfp.Circuit.Cert.LogitDiff +public import Nfp.Bounds.Cache +public import Nfp.Sound.Induction.HeadOutput + +/-! +Logit-diff bounds derived from induction certificates. +-/ + +public section + +namespace Nfp + +namespace Sound + +open Nfp.Circuit + +variable {seq : Nat} + +section Direction + +variable {seq dModel dHead : Nat} + +/-- Direction projection of a single-key head output. 
-/ +theorem direction_dot_headValue_eq_valsReal + (inputs : Model.InductionHeadInputs seq dModel dHead) (k : Fin seq) : + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headValueRealOfInputs inputs k i) = + valsRealOfInputs inputs k := by + classical + let dir : Fin dModel → Real := fun i => (inputs.direction i : Real) + let v : Fin dHead → Real := fun d => vRealOfInputs inputs k d + have hswap : + ∑ i, dir i * ∑ d, (inputs.wo i d : Real) * v d = + ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + calc + ∑ i, dir i * ∑ d, (inputs.wo i d : Real) * v d + = ∑ i, ∑ d, dir i * ((inputs.wo i d : Real) * v d) := by + simp [Finset.mul_sum] + _ = ∑ d, ∑ i, dir i * ((inputs.wo i d : Real) * v d) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dModel))) + (t := (Finset.univ : Finset (Fin dHead))) + (f := fun i d => dir i * ((inputs.wo i d : Real) * v d))) + _ = ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + refine Finset.sum_congr rfl ?_ + intro d _ + simp [mul_assoc, Finset.sum_mul] + have hdirHead : + ∀ d, ((dirHeadVecOfInputs inputs).get d : Real) = + ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + intro d + calc + ((dirHeadVecOfInputs inputs).get d : Real) = + ratToReal (Linear.dotFin dModel (fun i => inputs.wo i d) + (fun i => inputs.direction i)) := by + simp [dirHeadVecOfInputs_get, ratToReal_def] + _ = + ratToReal (dotProduct (fun i => inputs.wo i d) (fun i => inputs.direction i)) := by + simp [Linear.dotFin_eq_dotProduct] + _ = + ∑ i, ratToReal (inputs.wo i d * inputs.direction i) := by + simp [dotProduct, Linear.ratToReal_sum_univ] + _ = + ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + simp [ratToReal_def] + calc + dotProduct dir (fun i => headValueRealOfInputs inputs k i) + = ∑ i, dir i * + ∑ d, (inputs.wo i d : Real) * v d := by + simp [dir, v, headValueRealOfInputs_def, dotProduct] + _ = ∑ d, (∑ i, dir i * (inputs.wo i d : Real)) * v d := by + simp [hswap] + _ = ∑ d, 
((dirHeadVecOfInputs inputs).get d : Real) * v d := by + refine Finset.sum_congr rfl ?_ + intro d _ + have hdir : + ∑ i, dir i * (inputs.wo i d : Real) = + ((dirHeadVecOfInputs inputs).get d : Real) := by + calc + ∑ i, dir i * (inputs.wo i d : Real) + = ∑ i, (inputs.wo i d : Real) * (inputs.direction i : Real) := by + simp [dir, mul_comm] + _ = ((dirHeadVecOfInputs inputs).get d : Real) := by + simpa using (hdirHead d).symm + simp [hdir] + _ = valsRealOfInputs inputs k := by + simp [valsRealOfInputs_def, v, dotProduct] + +end Direction + +section LogitDiffLowerBound + +variable {seq dModel dHead : Nat} + +/-- Real-valued logit-diff contribution for a query. -/ +noncomputable def headLogitDiff (inputs : Model.InductionHeadInputs seq dModel dHead) + (q : Fin seq) : Real := + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + dotProduct (weights q) (valsRealOfInputs inputs) + +/-- Unfolding lemma for `headLogitDiff`. -/ +theorem headLogitDiff_def (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) : + headLogitDiff inputs q = + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + dotProduct (weights q) (valsRealOfInputs inputs) := by + rfl + +/-- Lower bound computed from the per-key lower bounds in an induction certificate. -/ +def logitDiffLowerBoundFromCert (c : InductionHeadCert seq) : Option Rat := + let epsAt := Bounds.cacheBoundTask c.epsAt + let valsLo := Bounds.cacheBoundTask c.values.valsLo + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Lower bound computed from per-key weight bounds in an induction certificate. 
-/ +def logitDiffLowerBoundFromCertWeighted (c : InductionHeadCert seq) : Option Rat := + let valsLo := Bounds.cacheBoundTask c.values.valsLo + Circuit.logitDiffLowerBoundWeightedAt c.active c.prev c.weightBoundAt valsLo + +/-- Cached eps and value lower bounds for logit-diff computations. -/ +structure LogitDiffCache (seq : Nat) where + /-- Per-query eps bounds. -/ + epsAt : Fin seq → Rat + /-- Per-key value lower bounds. -/ + valsLo : Fin seq → Rat + +/-- Build a shared cache for logit-diff computations from a certificate. -/ +def logitDiffCache (c : InductionHeadCert seq) : LogitDiffCache seq := + { epsAt := Bounds.cacheBoundTask c.epsAt + valsLo := Bounds.cacheBoundTask c.values.valsLo } + +/-- Unfolding lemma for `logitDiffCache`. -/ +theorem logitDiffCache_def (c : InductionHeadCert seq) : + logitDiffCache c = + { epsAt := Bounds.cacheBoundTask c.epsAt + valsLo := Bounds.cacheBoundTask c.values.valsLo } := by + rfl + +/-- Unweighted logit-diff lower bound from a shared cache. -/ +def logitDiffLowerBoundFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option Rat := + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Query attaining the cached unweighted logit-diff lower bound, if any. 
-/ +def logitDiffLowerBoundArgminFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option (Fin seq) := + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let delta := valsLo (c.prev q) - loAt q + valsLo (c.prev q) - epsAt q * max (0 : Rat) delta + let qs := (List.finRange seq).filter (fun q => decide (q ∈ c.active)) + List.argmin f qs + +/-- Unfolding lemma for `logitDiffLowerBoundArgminFromCache`. -/ +theorem logitDiffLowerBoundArgminFromCache_def + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + logitDiffLowerBoundArgminFromCache c cache = + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let delta := valsLo (c.prev q) - loAt q + valsLo (c.prev q) - epsAt q * max (0 : Rat) delta + let qs := (List.finRange seq).filter (fun q => decide (q ∈ c.active)) + List.argmin f qs := by + rfl + +/-- Unweighted logit-diff lower bound from a shared cache and custom `epsAt`. 
-/ +def logitDiffLowerBoundFromCacheWithEps (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (epsAtCustom : Fin seq → Rat) : Option Rat := + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Unweighted logit-diff lower bound from a custom eps and value lower bounds. -/ +def logitDiffLowerBoundFromCacheWithEpsVals (c : InductionHeadCert seq) + (epsAtCustom valsLoCustom : Fin seq → Rat) : Option Rat := + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo + +/-- Unfold `logitDiffLowerBoundFromCache` as the custom-eps variant. -/ +theorem logitDiffLowerBoundFromCache_eq_withEps + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + logitDiffLowerBoundFromCache c cache = + logitDiffLowerBoundFromCacheWithEps c cache cache.epsAt := by + rfl + +/-- Unfolding lemma for `logitDiffLowerBoundFromCacheWithEps`. 
-/ +theorem logitDiffLowerBoundFromCacheWithEps_def + (c : InductionHeadCert seq) (cache : LogitDiffCache seq) + (epsAtCustom : Fin seq → Rat) : + logitDiffLowerBoundFromCacheWithEps c cache epsAtCustom = + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by + rfl + +/-- Unfolding lemma for `logitDiffLowerBoundFromCacheWithEpsVals`. -/ +theorem logitDiffLowerBoundFromCacheWithEpsVals_def + (c : InductionHeadCert seq) (epsAtCustom valsLoCustom : Fin seq → Rat) : + logitDiffLowerBoundFromCacheWithEpsVals c epsAtCustom valsLoCustom = + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + Circuit.logitDiffLowerBoundAtLoAt c.active c.prev epsAt loAt valsLo := by + rfl + + +/-- Weighted logit-diff lower bound from a shared cache. 
-/ +def logitDiffLowerBoundWeightedFromCache (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option Rat := + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let valsLo : Fin seq → Rat := fun k => + valsLoArr[k.1]'(by + simp [valsLoArr, k.isLt]) + let weightRows : Array (Array Rat) := + Array.ofFn (fun q : Fin seq => Array.ofFn (fun k : Fin seq => c.weightBoundAt q k)) + let weightBoundAt : Fin seq → Fin seq → Rat := fun q k => + let row := weightRows[q.1]'(by + simp [weightRows, q.isLt]) + row[k.1]'(by + have hrow : row.size = seq := by + simp [row, weightRows] + simp [hrow, k.isLt]) + let gapBase : Fin seq → Rat := fun q => + let valsLoPrev := valsLo (c.prev q) + Linear.sumFin seq (fun k => + let diff := valsLoPrev - valsLo k + weightBoundAt q k * max (0 : Rat) diff) + let gap : Fin seq → Rat := Bounds.cacheBoundTask gapBase + if h : c.active.Nonempty then + let f : Fin seq → Rat := fun q => valsLo (c.prev q) - gap q + some (c.active.inf' h f) + else + none + +/-- Debug payload for the unweighted logit-diff lower bound. -/ +structure LogitDiffAtLoDebug (seq : Nat) where + /-- Query attaining the bound, if found. -/ + q : Fin seq + /-- Previous index for the query. -/ + prev : Fin seq + /-- Per-query eps bound. -/ + eps : Rat + /-- Lower bound for the previous value. -/ + valsPrevLo : Rat + /-- Global lower value bound. -/ + lo : Rat + /-- Per-query lower bound for other values. -/ + loAt : Rat + /-- `valsPrevLo - loAt`. -/ + valsPrevLoMinusLoAt : Rat + /-- `eps * max 0 (valsPrevLo - loAt)`. -/ + gap : Rat + /-- Lower bound reported by `logitDiffLowerBoundFromCache`. -/ + lbAtQ : Rat + +/-- Attempt to recover a query that attains the unweighted logit-diff bound. 
-/ +def logitDiffLowerBoundAtLoDebug (c : InductionHeadCert seq) (cache : LogitDiffCache seq) : + Option {d : LogitDiffAtLoDebug seq // + logitDiffLowerBoundFromCache c cache = some d.lbAtQ} := + let epsArr : Array Rat := Array.ofFn cache.epsAt + let valsLoArr : Array Rat := Array.ofFn cache.valsLo + let epsAt : Fin seq → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin seq → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin seq → Rat := fun q => + let others : Finset (Fin seq) := + (Finset.univ : Finset (Fin seq)).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let f : Fin seq → Rat := fun q => + let valsPrevLo := valsLo (c.prev q) + let delta := valsPrevLo - loAt q + valsPrevLo - epsAt q * max (0 : Rat) delta + let best? : Option (Fin seq × Rat) := + Linear.foldlFin seq + (fun acc q => + if hq : q ∈ c.active then + let val := f q + match acc with + | none => some (q, val) + | some (qBest, best) => + if val ≤ best then + some (q, val) + else + some (qBest, best) + else + acc) + none + match logitDiffLowerBoundFromCache c cache with + | none => none + | some lb => + match best? with + | none => none + | some (q, _) => + let prev := c.prev q + let valsPrevLo := valsLo prev + let loAtQ := loAt q + let delta := valsPrevLo - loAtQ + let gap := epsAt q * max (0 : Rat) delta + let d : LogitDiffAtLoDebug seq := + { q := q + prev := prev + eps := epsAt q + valsPrevLo := valsPrevLo + lo := c.values.lo + loAt := loAtQ + valsPrevLoMinusLoAt := delta + gap := gap + lbAtQ := lb } + have h' : some lb = some d.lbAtQ := by + simp [d] + some ⟨d, h'⟩ + +/-- `logitDiffLowerBoundFromCache` matches the cached default computation. 
-/ +theorem logitDiffLowerBoundFromCache_eq (c : InductionHeadCert seq) : + logitDiffLowerBoundFromCache c (logitDiffCache c) = logitDiffLowerBoundFromCert c := by + classical + unfold logitDiffLowerBoundFromCache logitDiffLowerBoundFromCert logitDiffCache + have heps : Bounds.cacheBoundTask c.epsAt = c.epsAt := by + funext k + simp [Bounds.cacheBoundTask_apply] + have hvals : Bounds.cacheBoundTask c.values.valsLo = c.values.valsLo := by + funext k + simp [Bounds.cacheBoundTask_apply] + simp [heps, hvals, Bounds.cacheBoundTask_apply] + +/-- `logitDiffLowerBoundWeightedFromCache` matches the cached default computation. -/ +theorem logitDiffLowerBoundWeightedFromCache_eq (c : InductionHeadCert seq) : + logitDiffLowerBoundWeightedFromCache c (logitDiffCache c) = + logitDiffLowerBoundFromCertWeighted c := by + classical + unfold logitDiffLowerBoundWeightedFromCache logitDiffLowerBoundFromCertWeighted logitDiffCache + have hvals : Bounds.cacheBoundTask c.values.valsLo = c.values.valsLo := by + funext k + simp [Bounds.cacheBoundTask_apply] + simp [hvals, Bounds.cacheBoundTask_apply, logitDiffLowerBoundWeightedAt_def, + Linear.sumFin_eq_sum_univ] + +/-- Best available logit-diff lower bound from an induction certificate. 
-/ +def logitDiffLowerBoundFromCertBest (c : InductionHeadCert seq) : Option Rat := + match logitDiffLowerBoundFromCert c, logitDiffLowerBoundFromCertWeighted c with + | some lb0, some lb1 => some (max lb0 lb1) + | some lb0, none => some lb0 + | none, some lb1 => some lb1 + | none, none => none + +section WithNeZero + +variable [NeZero seq] + +theorem logitDiffLowerBoundFromCert_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCert c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsAt := Bounds.cacheBoundTask c.epsAt + let valsLo := Bounds.cacheBoundTask c.values.valsLo + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hboundRat : + lb ≤ valsLo (c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCert, loAt] using hbound + have hboundRat' : + lb ≤ c.values.valsLo (c.prev q) - + c.epsAt q * max (0 : Rat) (c.values.valsLo (c.prev q) - 
loAt q) := by + simpa [epsAt, valsLo, Bounds.cacheBoundTask_apply] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal_def] using + ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := + hsound.softmax_bounds.nonneg q hq + have hweights := hsound.oneHot_bounds_at q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hweights.sum_one q rfl + have hsum_others_le : sumOthers ≤ (c.epsAt q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (c.epsAt q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (c.epsAt q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvals : (valsLo k : Real) ≤ vals k := by + simpa [valsLo, Bounds.cacheBoundTask_apply] using + 
(hsound.value_bounds.vals_bounds k).1 + exact le_trans hloReal hvals + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum_lo' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum_lo' + have hle : + ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k _hk + have hval := hvals_lo k _hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hval hnonneg + have hsum' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hsum' + have hsum_prod : + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hdot_ge : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge (weights q (c.prev q) * vals (c.prev q))) + simpa [sumOthers, hout_eq, add_comm, add_left_comm, add_assoc] using hle + have hprev_lo : + weights q (c.prev q) * valsLoPrev ≤ + weights q (c.prev q) * vals (c.prev q) := by + exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge' : + weights q (c.prev q) * valsLoPrev + 
sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_lo (sumOthers * loAtReal)) + exact hle.trans hdot_ge + have hsplit : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + have hsplit' : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - loAtReal) := by + ring + calc + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal = + (weights q (c.prev q) + sumOthers) * valsLoPrev - + sumOthers * (valsLoPrev - loAtReal) := hsplit' + _ = valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + simp [hsum, sumOthers] + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k _hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (c.epsAt q : Real) * 
max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hdot_lower : + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (c.epsAt q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + simp [hsplit] + _ ≤ dotProduct (weights q) vals := hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff, weights, vals] using hle + +/-- The weighted per-key logit-diff lower bound is sound on active queries. -/ +theorem logitDiffLowerBoundFromCertWeighted_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCertWeighted c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let valsLoCached := Bounds.cacheBoundTask c.values.valsLo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let valsLoPrevRat : Rat := valsLoCached (c.prev q) + let valsLoPrev : Real := (valsLoPrevRat : Real) + have hboundRat : + lb ≤ valsLoPrevRat - + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - valsLoCached k)) := by + refine + Circuit.logitDiffLowerBoundWeightedAt_le + (active := c.active) + (prev := c.prev) + (weightBoundAt := c.weightBoundAt) + (valsLo := valsLoCached) + q 
hq lb ?_ + simpa [logitDiffLowerBoundFromCertWeighted] using hbound + have hboundRat' : + lb ≤ valsLoPrevRat - + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => + c.weightBoundAt q k * max (0 : Rat) (valsLoPrevRat - c.values.valsLo k)) := by + simpa [valsLoCached, valsLoPrevRat, Bounds.cacheBoundTask_apply] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - + (Finset.univ : Finset (Fin (Nat.succ n))).sum (fun k => + (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simpa [valsLoPrevRat, valsLoPrev, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal_def, Rat.cast_sum] using ratToReal_le_of_le hboundRat' + let gapTerm : Fin (Nat.succ n) → Real := fun k => + (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) + have hgap_prev : gapTerm (c.prev q) = 0 := by + have hdiff : valsLoPrev - (c.values.valsLo (c.prev q) : Real) = 0 := by + simp [valsLoPrev, valsLoPrevRat, valsLoCached, Bounds.cacheBoundTask_apply] + simp [gapTerm, hdiff] + have hsum_gap : + (Finset.univ : Finset (Fin (Nat.succ n))).sum gapTerm = + ∑ k ∈ others, gapTerm k := by + classical + have hsum := + (Finset.sum_erase (s := (Finset.univ : Finset (Fin (Nat.succ n)))) + (f := gapTerm) (a := c.prev q) hgap_prev) + simpa [others] using hsum.symm + have hboundReal' : + (lb : Real) ≤ + valsLoPrev - ∑ k ∈ others, gapTerm k := by + simpa [gapTerm, hsum_gap] using hboundReal + have hweights_nonneg : ∀ k, 0 ≤ weights q k := + hsound.softmax_bounds.nonneg q hq + have hweights := hsound.oneHot_bounds_at q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hweights.sum_one q rfl + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by + have hvals := 
(hsound.value_bounds.vals_bounds (c.prev q)).1 + simpa [valsLoPrev, valsLoPrevRat, valsLoCached, Bounds.cacheBoundTask_apply] using hvals + have hvals_lower : + ∀ k, + valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ vals k := by + intro k + by_cases hle : valsLoPrev ≤ (c.values.valsLo k : Real) + · have hdiff : valsLoPrev - (c.values.valsLo k : Real) ≤ 0 := sub_nonpos.mpr hle + have hmax : + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) = 0 := by + simp [hdiff] + have hvals_lo : (c.values.valsLo k : Real) ≤ vals k := + (hsound.value_bounds.vals_bounds k).1 + have hvals_prev : valsLoPrev ≤ vals k := le_trans hle hvals_lo + simpa [hmax] using hvals_prev + · have hlt : (c.values.valsLo k : Real) < valsLoPrev := lt_of_not_ge hle + have hdiff_nonneg : 0 ≤ valsLoPrev - (c.values.valsLo k : Real) := + le_of_lt (sub_pos.mpr hlt) + have hmax : + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) = + valsLoPrev - (c.values.valsLo k : Real) := by + simp [hdiff_nonneg] + have hvals_lo : (c.values.valsLo k : Real) ≤ vals k := + (hsound.value_bounds.vals_bounds k).1 + simpa [hmax] using hvals_lo + have hsum_vals_ge : + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + ∑ k ∈ others, weights q k * vals k := by + refine Finset.sum_le_sum ?_ + intro k hk + exact + mul_le_mul_of_nonneg_left + (hvals_lower k) + (hweights_nonneg k) + have hsum_prod : + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k = + ∑ k, weights q k * vals k := by + simp [others] + have hout_eq : + dotProduct (weights q) vals = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [dotProduct] using hsum_prod.symm + have hprev_lo : + weights q (c.prev q) * valsLoPrev ≤ + weights q (c.prev q) * vals (c.prev q) := by + exact mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge' : + weights q (c.prev q) * valsLoPrev + + ∑ k 
∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + dotProduct (weights q) vals := by + have hle : + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) ≤ + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + exact add_le_add hprev_lo hsum_vals_ge + simpa [hout_eq, add_comm, add_left_comm, add_assoc] using hle + have hsum_weights : + (∑ k ∈ others, weights q k * valsLoPrev) = + (∑ k ∈ others, weights q k) * valsLoPrev := by + have hsum_mul : + (∑ k ∈ others, weights q k) * valsLoPrev = + ∑ k ∈ others, weights q k * valsLoPrev := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := valsLoPrev)) + exact hsum_mul.symm + have hsum_split : + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + (∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + calc + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + ∑ k ∈ others, + (weights q k * valsLoPrev - + weights q k * max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simp [mul_sub] + _ = + (∑ k ∈ others, weights q k * valsLoPrev) - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [Finset.sum_sub_distrib] + _ = + (∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [hsum_weights] + have hsplit : + weights q (c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + calc + weights q (c.prev q) * 
valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) = + weights q (c.prev q) * valsLoPrev + + ((∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simp [hsum_split] + _ = + (weights q (c.prev q) + ∑ k ∈ others, weights q k) * valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + ring + _ = + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + simp [hsum] + have hweight_bound : + ∀ k ∈ others, weights q k ≤ (c.weightBoundAt q k : Real) := by + intro k hk + have hk' : k ≠ c.prev q := (Finset.mem_erase.mp hk).1 + exact hsound.weight_bounds_at q hq k hk' + have hsum_gap_le : + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + refine Finset.sum_le_sum ?_ + intro k hk + have hnonneg : + 0 ≤ max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := + le_max_left _ _ + exact mul_le_mul_of_nonneg_right (hweight_bound k hk) hnonneg + have hsub_le : + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := by + exact sub_le_sub_left hsum_gap_le valsLoPrev + have hdot_lower : + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - + ∑ k ∈ others, (c.weightBoundAt q k : Real) * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) ≤ + valsLoPrev - + ∑ k ∈ others, weights q k * + max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real)) := hsub_le + _ = + weights q 
(c.prev q) * valsLoPrev + + ∑ k ∈ others, weights q k * + (valsLoPrev - max (0 : Real) (valsLoPrev - (c.values.valsLo k : Real))) := by + simpa using hsplit.symm + _ ≤ dotProduct (weights q) vals := hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal' hdot_lower + simpa [headLogitDiff, weights, vals] using hle + +/-- The best available logit-diff lower bound is sound on active queries. -/ +theorem logitDiffLowerBoundFromCertBest_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (hsound : InductionHeadCertSound inputs c) + {lb : Rat} (hbound : logitDiffLowerBoundFromCertBest c = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases h0 : logitDiffLowerBoundFromCert c with + | none => + cases h1 : logitDiffLowerBoundFromCertWeighted c with + | none => + simp [logitDiffLowerBoundFromCertBest, h0, h1] at hbound + | some lb1 => + have hbound' : lb1 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + exact logitDiffLowerBoundFromCertWeighted_le inputs c hsound h1 hq + | some lb0 => + cases h1 : logitDiffLowerBoundFromCertWeighted c with + | none => + have hbound' : lb0 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + exact logitDiffLowerBoundFromCert_le inputs c hsound h0 hq + | some lb1 => + have hbound' : max lb0 lb1 = lb := by + simpa [logitDiffLowerBoundFromCertBest, h0, h1] using hbound + cases hbound' + have h0le : (lb0 : Real) ≤ headLogitDiff inputs q := + logitDiffLowerBoundFromCert_le inputs c hsound h0 hq + have h1le : (lb1 : Real) ≤ headLogitDiff inputs q := + logitDiffLowerBoundFromCertWeighted_le inputs c hsound h1 hq + have hmax : max (lb0 : Real) (lb1 : Real) ≤ headLogitDiff inputs q := + max_le_iff.mpr ⟨h0le, h1le⟩ + simpa [ratToReal_max, ratToReal_def] using hmax + +end WithNeZero + +/-! Head-output identities. 
-/ + +/-- The head logit-diff equals the direction dot product of the head output. -/ +theorem headLogitDiff_eq_direction_dot_headOutput + (inputs : Model.InductionHeadInputs seq dModel dHead) (q : Fin seq) : + headLogitDiff inputs q = + dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i) := by + classical + let dir : Fin dModel → Real := fun i => (inputs.direction i : Real) + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + have hswap : + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * dotProduct dir (fun i => headValueRealOfInputs inputs k i) := by + calc + dotProduct dir (fun i => headOutput inputs q i) + = ∑ i, dir i * ∑ k, weights q k * headValueRealOfInputs inputs k i := by + simp [dir, headOutput_def, headOutputWithScores_def, weights, dotProduct] + _ = ∑ i, ∑ k, dir i * (weights q k * headValueRealOfInputs inputs k i) := by + simp [Finset.mul_sum] + _ = ∑ k, ∑ i, dir i * (weights q k * headValueRealOfInputs inputs k i) := by + simpa using + (Finset.sum_comm (s := (Finset.univ : Finset (Fin dModel))) + (t := (Finset.univ : Finset (Fin seq))) + (f := fun i k => dir i * (weights q k * headValueRealOfInputs inputs k i))) + _ = ∑ k, weights q k * ∑ i, dir i * headValueRealOfInputs inputs k i := by + refine Finset.sum_congr rfl ?_ + intro k _ + calc + ∑ i, dir i * (weights q k * headValueRealOfInputs inputs k i) = + ∑ i, weights q k * (dir i * headValueRealOfInputs inputs k i) := by + refine Finset.sum_congr rfl ?_ + intro i _ + simp [mul_assoc, mul_left_comm, mul_comm] + _ = weights q k * ∑ i, dir i * headValueRealOfInputs inputs k i := by + simp [Finset.mul_sum] + have hsum : + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * valsRealOfInputs inputs k := by + calc + dotProduct dir (fun i => headOutput inputs q i) = + ∑ k, weights q k * dotProduct dir (fun i => headValueRealOfInputs inputs k i) := hswap + _ = ∑ k, weights q k * 
valsRealOfInputs inputs k := by + refine Finset.sum_congr rfl ?_ + intro k _ + have hdir := direction_dot_headValue_eq_valsReal (inputs := inputs) (k := k) + exact congrArg (fun x => weights q k * x) (by simpa [dir] using hdir) + calc + headLogitDiff inputs q = + dotProduct (weights q) (valsRealOfInputs inputs) := by + simp [headLogitDiff, weights] + _ = ∑ k, weights q k * valsRealOfInputs inputs k := by + simp [dotProduct] + _ = dotProduct (fun i => (inputs.direction i : Real)) + (fun i => headOutput inputs q i) := by + simpa [dir] using hsum.symm + +end LogitDiffLowerBound + +end Sound + +end Nfp diff --git a/Nfp/Sound/Induction/LogitDiffSound.lean b/Nfp/Sound/Induction/LogitDiffSound.lean new file mode 100644 index 0000000..e4d329d --- /dev/null +++ b/Nfp/Sound/Induction/LogitDiffSound.lean @@ -0,0 +1,467 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.Sound.Induction.LogitDiff + +/-! +Soundness lemmas for logit-diff lower bounds with custom eps/values. +-/ + +public section + +namespace Nfp + +namespace Sound + +open Nfp.Circuit + +variable {seq dModel dHead : Nat} + +section WithNeZero + +variable [NeZero seq] + +/-- The unweighted logit-diff lower bound is sound for any valid per-query `epsAt`. 
-/ +theorem logitDiffLowerBoundFromCacheWithEps_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (epsAtCustom : Fin seq → Rat) + (hsound : InductionHeadCertSound inputs c) + (honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) + {lb : Rat} + (hbound : + logitDiffLowerBoundFromCacheWithEps c (logitDiffCache c) epsAtCustom = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn (logitDiffCache c).valsLo + let epsAt : Fin (Nat.succ n) → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin (Nat.succ n) → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (c.values.valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hboundRat : + lb ≤ valsLo (c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa 
[logitDiffLowerBoundFromCacheWithEps_def, logitDiffCache_def, loAt, epsAt, valsLo, + valsLoArr, epsArr] using hbound + have hepsAt : epsAt q = epsAtCustom q := by + simp [epsAt, epsArr] + have hvalsLo : ∀ k, valsLo k = c.values.valsLo k := by + intro k + simp [valsLo, valsLoArr, logitDiffCache_def, Bounds.cacheBoundTask_apply] + have hboundRat' : + lb ≤ c.values.valsLo (c.prev q) - + epsAtCustom q * max (0 : Rat) (c.values.valsLo (c.prev q) - loAt q) := by + simpa [hepsAt, hvalsLo] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (epsAtCustom q : Real) * + max (0 : Real) (valsLoPrev - loAtReal) := by + simpa [loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, ratToReal_def] + using ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + have hweights := honeHot q hq + simpa [weights] using hweights.nonneg q rfl + have hweights := honeHot q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using hweights.sum_one q rfl + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hsum_one + have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (epsAtCustom q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ c.values.valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] 
using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + have hvalsLo' : valsLo k = c.values.valsLo k := hvalsLo k + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + _ = c.values.valsLo k := hvalsLo' + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (c.values.valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvals : (c.values.valsLo k : Real) ≤ vals k := by + simpa using (hsound.value_bounds.vals_bounds k).1 + exact le_trans hloReal hvals + have hvalsLo_prev : valsLoPrev ≤ vals (c.prev q) := by + exact (hsound.value_bounds.vals_bounds (c.prev q)).1 + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum_lo' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum_lo' + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hle : + ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k _hk + have hval := hvals_lo k _hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hval hnonneg + have hsum' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hsum' + have hsum_vals_ge' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + simpa [hsum_lo] using hsum_vals_ge + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsplit : + weights q 
(c.prev q) = 1 - sumOthers := by + have hsum' : weights q (c.prev q) + sumOthers = 1 := by + simpa [sumOthers] using hsum + exact (eq_sub_iff_add_eq).2 hsum' + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hdot_lower : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + have hsplit_calc : + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) = + (1 - sumOthers) * valsLoPrev + sumOthers * loAtReal := by + ring + simpa [hsplit] using hsplit_calc + _ ≤ dotProduct (weights q) vals := by + have hprev_le := mul_le_mul_of_nonneg_left hvalsLo_prev (hweights_nonneg (c.prev q)) + have hdot_ge : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + simpa [add_comm, add_left_comm, add_assoc] 
using + (add_le_add_right hprev_le (sumOthers * loAtReal)) + have hdot_ge' : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + calc + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal + = weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * loAtReal := by + simp [hsum_lo] + _ ≤ weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge' + (weights q (c.prev q) * vals (c.prev q))) + _ = dotProduct (weights q) vals := by + simp [dotProduct, others] + exact le_trans hdot_ge hdot_ge' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff_def, weights, vals] using hle + +/-- The unweighted logit-diff lower bound is sound for custom eps and values. -/ +theorem logitDiffLowerBoundFromCacheWithEpsVals_le + (inputs : Model.InductionHeadInputs seq dModel dHead) + (c : InductionHeadCert seq) (epsAtCustom valsLoCustom : Fin seq → Rat) + (hsound : InductionHeadCertSound inputs c) + (honeHot : + ∀ q, q ∈ c.active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAtCustom q : Real) + (fun q' => q' = q) c.prev + (fun q' k => Circuit.softmax (scoresRealOfInputs inputs q') k)) + (hvalsLo : + ∀ k, (valsLoCustom k : Real) ≤ valsRealOfInputs inputs k) + {lb : Rat} + (hbound : + logitDiffLowerBoundFromCacheWithEpsVals c epsAtCustom valsLoCustom = some lb) + {q : Fin seq} (hq : q ∈ c.active) : + (lb : Real) ≤ headLogitDiff inputs q := by + classical + cases seq with + | zero => + cases (NeZero.ne (n := (0 : Nat)) rfl) + | succ n => + let weights : Fin (Nat.succ n) → Fin (Nat.succ n) → Real := fun q k => + Circuit.softmax (scoresRealOfInputs inputs q) k + let vals : Fin (Nat.succ n) → Real := valsRealOfInputs inputs + let epsArr : Array Rat := Array.ofFn epsAtCustom + let valsLoArr : Array Rat := Array.ofFn valsLoCustom + let epsAt : 
Fin (Nat.succ n) → Rat := fun q => + epsArr[q.1]'(by + simp [epsArr, q.isLt]) + let valsLo : Fin (Nat.succ n) → Rat := fun q => + valsLoArr[q.1]'(by + simp [valsLoArr, q.isLt]) + let loAt : Fin (Nat.succ n) → Rat := fun q => + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + if h : others.Nonempty then + others.inf' h valsLo + else + c.values.lo + let others : Finset (Fin (Nat.succ n)) := + (Finset.univ : Finset (Fin (Nat.succ n))).erase (c.prev q) + let sumOthers : Real := ∑ k ∈ others, weights q k + let valsLoPrev : Real := (valsLo (c.prev q) : Real) + let loAtRat : Rat := loAt q + let loAtReal : Real := (loAtRat : Real) + have hvalsLo_eq : ∀ k, valsLo k = valsLoCustom k := by + intro k + simp [valsLo, valsLoArr] + have hboundRat : + lb ≤ valsLo (c.prev q) - + epsAt q * max (0 : Rat) (valsLo (c.prev q) - loAt q) := by + refine + Circuit.logitDiffLowerBoundAtLoAt_le + (active := c.active) + (prev := c.prev) + (epsAt := epsAt) + (loAt := loAt) + (valsLo := valsLo) + q hq lb ?_ + simpa [logitDiffLowerBoundFromCacheWithEpsVals_def, loAt, epsAt, valsLo, valsLoArr, epsArr] + using hbound + have hepsAt : epsAt q = epsAtCustom q := by + simp [epsAt, epsArr] + have hboundRat' : + lb ≤ valsLoCustom (c.prev q) - + epsAtCustom q * max (0 : Rat) (valsLoCustom (c.prev q) - loAt q) := by + simpa [hepsAt, hvalsLo_eq] using hboundRat + have hboundReal : + (lb : Real) ≤ + valsLoPrev - (epsAtCustom q : Real) * + max (0 : Real) (valsLoPrev - loAtReal) := by + have hvalsLoPrev_eq : valsLoPrev = (valsLoCustom (c.prev q) : Real) := by + simp [valsLoPrev, valsLo, valsLoArr] + simpa [hvalsLoPrev_eq, loAtRat, loAtReal, ratToReal_sub, ratToReal_mul, ratToReal_max, + ratToReal_def] + using ratToReal_le_of_le hboundRat' + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + have hweights := honeHot q hq + simpa [weights] using hweights.nonneg q rfl + have hweights := honeHot q hq + have hsum_decomp : + weights q (c.prev q) + ∑ k ∈ 
others, weights q k = ∑ k, weights q k := by + simp [others] + have hsum : + weights q (c.prev q) + ∑ k ∈ others, weights q k = 1 := by + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights] using hweights.sum_one q rfl + calc + weights q (c.prev q) + ∑ k ∈ others, weights q k = ∑ k, weights q k := hsum_decomp + _ = 1 := hsum_one + have hsum_others_le : sumOthers ≤ (epsAtCustom q : Real) := by + have hprev : 1 ≤ weights q (c.prev q) + (epsAtCustom q : Real) := + hweights.prev_large q rfl + have hprev' : + weights q (c.prev q) + sumOthers ≤ + weights q (c.prev q) + (epsAtCustom q : Real) := by + simpa [hsum, sumOthers] using hprev + exact (add_le_add_iff_left (weights q (c.prev q))).1 hprev' + have hvalsLo_real : ∀ k, (valsLo k : Real) ≤ vals k := by + intro k + have hvals := hvalsLo k + simpa [valsLo, valsLoArr, vals] using hvals + have hprev_lo : valsLoPrev ≤ vals (c.prev q) := by + simpa [valsLoPrev] using hvalsLo_real (c.prev q) + have hloAt_le_valsLo : ∀ k ∈ others, loAtRat ≤ valsLo k := by + intro k hk + have hnonempty : others.Nonempty := ⟨k, hk⟩ + have hmin : others.inf' hnonempty valsLo ≤ valsLo k := + Finset.inf'_le (s := others) (f := valsLo) hk + have hnonempty' : (Finset.univ.erase (c.prev q)).Nonempty := by + simpa [others] using hnonempty + have hloAt : loAtRat = others.inf' hnonempty valsLo := by + dsimp [loAtRat, loAt] + simp [hnonempty', others] + calc + loAtRat = others.inf' hnonempty valsLo := hloAt + _ ≤ valsLo k := hmin + have hvals_lo : ∀ k ∈ others, loAtReal ≤ vals k := by + intro k hk + have hloRat := hloAt_le_valsLo k hk + have hloReal : loAtReal ≤ (valsLo k : Real) := by + simpa [loAtReal, ratToReal_def] using (ratToReal_le_of_le hloRat) + have hvalsReal : (valsLo k : Real) ≤ vals k := hvalsLo_real k + exact le_trans hloReal hvalsReal + have hsum_nonneg : 0 ≤ sumOthers := by + have hnonneg : ∀ k ∈ others, 0 ≤ weights q k := by + intro k hk + exact hweights_nonneg k + have hsum_nonneg' : 0 ≤ ∑ k ∈ others, weights q k := by + exact 
Finset.sum_nonneg hnonneg + simpa [sumOthers] using hsum_nonneg' + have hsplit : + weights q (c.prev q) = 1 - sumOthers := by + have hsum' : weights q (c.prev q) + sumOthers = 1 := by + simpa [sumOthers] using hsum + exact (eq_sub_iff_add_eq).2 hsum' + have hdiff_le : valsLoPrev - loAtReal ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_right _ _ + have hsum_mul_le_left : + sumOthers * (valsLoPrev - loAtReal) ≤ + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_left hdiff_le hsum_nonneg + have hmax_nonneg : 0 ≤ max (0 : Real) (valsLoPrev - loAtReal) := by + exact le_max_left _ _ + have hsum_mul_le : + sumOthers * (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + have hsum_mul_le_right : + sumOthers * max (0 : Real) (valsLoPrev - loAtReal) ≤ + (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) := by + exact mul_le_mul_of_nonneg_right hsum_others_le hmax_nonneg + exact le_trans hsum_mul_le_left hsum_mul_le_right + have hsub_le : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := by + exact sub_le_sub_left hsum_mul_le valsLoPrev + have hsum_lo : + sumOthers * loAtReal = ∑ k ∈ others, weights q k * loAtReal := by + have hsum' : + (∑ k ∈ others, weights q k) * loAtReal = + ∑ k ∈ others, weights q k * loAtReal := by + simpa using + (Finset.sum_mul (s := others) (f := fun k => weights q k) (a := loAtReal)) + simpa [sumOthers] using hsum' + have hsum_vals_ge : + sumOthers * loAtReal ≤ ∑ k ∈ others, weights q k * vals k := by + have hle : ∀ k ∈ others, weights q k * loAtReal ≤ weights q k * vals k := by + intro k hk + have hlo := hvals_lo k hk + have hnonneg := hweights_nonneg k + exact mul_le_mul_of_nonneg_left hlo hnonneg + have hle' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + exact Finset.sum_le_sum hle + simpa [hsum_lo] using hle' + 
have hsum_vals_ge' : + ∑ k ∈ others, weights q k * loAtReal ≤ + ∑ k ∈ others, weights q k * vals k := by + simpa [hsum_lo] using hsum_vals_ge + have hdot_ge'' : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + have hdot_ge : + weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal ≤ + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal := by + have hprev_le := mul_le_mul_of_nonneg_left hprev_lo (hweights_nonneg (c.prev q)) + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_right hprev_le (sumOthers * loAtReal)) + have hdot_ge' : + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal ≤ + dotProduct (weights q) vals := by + calc + weights q (c.prev q) * vals (c.prev q) + sumOthers * loAtReal = + weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * loAtReal := by + simp [hsum_lo] + _ ≤ weights q (c.prev q) * vals (c.prev q) + + ∑ k ∈ others, weights q k * vals k := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_vals_ge' + (weights q (c.prev q) * vals (c.prev q))) + _ = dotProduct (weights q) vals := by + simp [dotProduct, others] + exact le_trans hdot_ge hdot_ge' + have hdot_lower : + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + dotProduct (weights q) vals := by + calc + valsLoPrev - (epsAtCustom q : Real) * max (0 : Real) (valsLoPrev - loAtReal) ≤ + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) := hsub_le + _ = weights q (c.prev q) * valsLoPrev + sumOthers * loAtReal := by + have hsplit_calc : + valsLoPrev - sumOthers * (valsLoPrev - loAtReal) = + (1 - sumOthers) * valsLoPrev + sumOthers * loAtReal := by + ring + simpa [hsplit] using hsplit_calc + _ ≤ dotProduct (weights q) vals := hdot_ge'' + have hle : (lb : Real) ≤ dotProduct (weights q) vals := + le_trans hboundReal hdot_lower + simpa [headLogitDiff_def, weights, vals] using hle + +end WithNeZero + +end Sound + +end Nfp diff --git 
a/Nfp/Sound/Induction/OneHot.lean b/Nfp/Sound/Induction/OneHot.lean new file mode 100644 index 0000000..71d9ca9 --- /dev/null +++ b/Nfp/Sound/Induction/OneHot.lean @@ -0,0 +1,485 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Algebra.BigOperators.Group.Finset.Basic +public import Mathlib.Algebra.Order.BigOperators.Group.Finset +public import Mathlib.Data.Rat.BigOperators +public import Nfp.Core.Basic +public import Nfp.Circuit.Layers.Induction +public import Nfp.Circuit.Layers.Softmax + +/-! +Per-query one-hot bounds derived from score margins. +-/ + +public section + +namespace Nfp + +namespace Sound + +open scoped BigOperators + +open Nfp.Circuit + +variable {seq : Nat} [NeZero seq] + +/-- One-hot bounds on a single active query, derived from a per-query margin. -/ +theorem oneHot_bounds_at_of_marginAt + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (marginAt : Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + if marginAt q < 0 then (1 : Rat) else + ratDivUp (seq - 1) (1 + marginAt q)) + (hseq : (1 : Nat) ≤ seq) + (hscore_margin_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (marginAt q : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] 
using + softmaxWeights.sum_one q + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + by_cases hneg : marginAt q < 0 + · have heps : (epsAt q : Real) = 1 := by + simp [hepsAt, hneg] + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + simpa [heps, hsum_one] using hsum_le + · have hnonneg : 0 ≤ marginAt q := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (marginAt q : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg + have hbound : + ∀ k ∈ others q, + weights q k ≤ (1 + (marginAt q : Real))⁻¹ := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + have hscore := hscore_margin_real_at q hq k hkne + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (marginAt q : Real)) + hnonneg_real hscore) + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹ := + Finset.sum_le_sum hbound + have hsum_const : + (∑ k ∈ others q, (1 + (marginAt q : Real))⁻¹) = + (others q).card * (1 + (marginAt q : Real))⁻¹ := by + simp + have hcard : (others q).card = seq - 1 := by + simp [others, Finset.card_erase_of_mem] + have hsum_le' : + (∑ k ∈ others q, weights q k) ≤ + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ := by + have hsum_le'' := hsum_le.trans_eq hsum_const + have hsum_le''' := hsum_le'' + simp only [hcard, Nat.cast_sub hseq, Nat.cast_one] at hsum_le''' + exact hsum_le''' + have heps : + (seq - 1 : Real) * (1 + (marginAt q : Real))⁻¹ ≤ (epsAt q : Real) := by + have hden : (1 + marginAt q) ≠ 0 := by + intro hzero + have hrat : (1 : Rat) + marginAt q = 0 := by + simpa using hzero + have hnonneg_rat : (0 : Rat) ≤ marginAt q := hnonneg + linarith + have hrat : + (seq - 1 
: Real) * (1 + (marginAt q : Real))⁻¹ ≤ + (ratDivUp (seq - 1) (1 + marginAt q) : Real) := by + have hrat' := ratDivUp_ge_real (seq - 1) (1 + marginAt q) hden + simpa [ratToReal_def, Rat.cast_div, Rat.cast_add, Rat.cast_natCast, + div_eq_mul_inv] using hrat' + simpa [hepsAt, hneg] using hrat + exact le_trans hsum_le' heps + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + exact hweights_nonneg k + · intro q' hq' + subst q' + exact hsum_one + · intro q' hq' + subst q' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + exact hweights_nonneg j + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + +/-- One-hot bounds on a single active query, derived from per-key score gaps. 
-/ +theorem oneHot_bounds_at_of_scoreGapLo + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (scoreGapLo : Fin seq → Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((Finset.univ : Finset (Fin seq)).erase (prev q) |>.sum (fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)))) + (hscore_gap_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + let bound : Fin seq → Rat := fun k => + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q + have hsum_others_le_one : (∑ k ∈ others q, weights q k) ≤ 1 := by + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + simpa [hsum_one] using hsum_le + have hbound : + ∀ k ∈ others q, weights q k ≤ (bound k : Real) := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + by_cases hneg : scoreGapLo q k < 0 + · have hle : 
weights q k ≤ 1 := by + simpa [weights] using + (Circuit.softmax_le_one (scores := scoresReal q) k) + simpa [bound, hneg] using hle + · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg + have hscore := hscore_gap_real_at q hq k hkne + have hsoft : + weights q k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by + simpa [weights] using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (scoreGapLo q k : Real)) + hnonneg_real hscore) + have hpos : (0 : Rat) < 1 + scoreGapLo q k := by + have hle : (1 : Rat) ≤ 1 + scoreGapLo q k := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le zero_lt_one hle + have hden : (1 + scoreGapLo q k) ≠ 0 := by + exact ne_of_gt hpos + have hrat : + 1 / (1 + (scoreGapLo q k : Real)) ≤ + ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by + simpa [ratToReal_def] using + (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) + have hbound' : + weights q k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := + hsoft.trans hrat + simpa [bound, hneg, ratToReal_def] using hbound' + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (bound k : Real) := + Finset.sum_le_sum hbound + have hsum_le_min : + (∑ k ∈ others q, weights q k) ≤ + min (1 : Real) (∑ k ∈ others q, (bound k : Real)) := by + exact le_min hsum_others_le_one hsum_le + have hepsAtReal : + (epsAt q : Real) = min (1 : Real) (∑ k ∈ others q, (bound k : Real)) := by + have h' : epsAt q = min 1 ((others q).sum bound) := by + simpa only [others, bound] using hepsAt q + have h'' : + ratToReal (epsAt q) = ratToReal (min 1 ((others q).sum bound)) := by + exact congrArg ratToReal h' + -- Avoid rewriting the erased-sum into a difference. 
+ simpa [ratToReal_min, ratToReal_def, Rat.cast_sum] using h'' + simpa [hepsAtReal] using hsum_le_min + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + exact hweights_nonneg k + · intro q' hq' + subst q' + exact hsum_one + · intro q' hq' + subst q' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + exact hweights_nonneg j + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + +/-- One-hot bounds on a single active query, derived from per-key weight bounds. 
-/ +theorem oneHot_bounds_at_of_weight_bounds + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (weightBoundAt : Fin seq → Fin seq → Rat) + (epsAt : Fin seq → Rat) + (hepsAt : + ∀ q, epsAt q = + min (1 : Rat) + ((Finset.univ : Finset (Fin seq)).erase (prev q) |>.sum (fun k => + weightBoundAt q k))) + (hweight_bounds : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + Circuit.softmax (scoresReal q) k ≤ (weightBoundAt q k : Real)) : + ∀ q, q ∈ active → + Layers.OneHotApproxBoundsOnActive (Val := Real) (epsAt q : Real) + (fun q' => q' = q) prev + (fun q k => Circuit.softmax (scoresReal q) k) := by + classical + intro q hq + let softmaxWeights := Circuit.softmaxWeights scoresReal + let weights : Fin seq → Fin seq → Real := fun q k => + Circuit.softmax (scoresReal q) k + let others : Fin seq → Finset (Fin seq) := fun q => + (Finset.univ : Finset (Fin seq)).erase (prev q) + have hweights_nonneg : ∀ k, 0 ≤ weights q k := by + intro k + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.nonneg q k + have hsum_one : (∑ k, weights q k) = 1 := by + simpa [weights, softmaxWeights, Circuit.softmaxWeights_weights] using + softmaxWeights.sum_one q + have hsum_others_le_one : (∑ k ∈ others q, weights q k) ≤ 1 := by + have hsubset : others q ⊆ (Finset.univ : Finset (Fin seq)) := by + intro k hk + simp + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ (Finset.univ : Finset (Fin seq)), weights q k := + Finset.sum_le_sum_of_subset_of_nonneg hsubset (by + intro k _ _ + exact hweights_nonneg k) + simpa [hsum_one] using hsum_le + have hbound : + ∀ k ∈ others q, weights q k ≤ (weightBoundAt q k : Real) := by + intro k hk + have hkne : k ≠ prev q := (Finset.mem_erase.mp hk).1 + exact hweight_bounds q hq k hkne + have hsum_others_le : (∑ k ∈ others q, weights q k) ≤ (epsAt q : Real) := by + have hsum_le : + (∑ k ∈ others q, weights q k) ≤ + ∑ k ∈ others q, (weightBoundAt q k : Real) := + 
Finset.sum_le_sum hbound + have hsum_le_min : + (∑ k ∈ others q, weights q k) ≤ + min (1 : Real) (∑ k ∈ others q, (weightBoundAt q k : Real)) := by + exact le_min hsum_others_le_one hsum_le + have hepsAtReal : + (epsAt q : Real) = min (1 : Real) (∑ k ∈ others q, (weightBoundAt q k : Real)) := by + have h' : epsAt q = min 1 ((others q).sum (fun k => weightBoundAt q k)) := by + simpa [others] using hepsAt q + have h'' : + ratToReal (epsAt q) = + ratToReal (min 1 ((others q).sum (fun k => weightBoundAt q k))) := by + exact congrArg ratToReal h' + simpa [ratToReal_min, ratToReal_def, Rat.cast_sum] using h'' + simpa [hepsAtReal] using hsum_le_min + refine + { nonneg := ?_ + sum_one := ?_ + prev_large := ?_ + other_le := ?_ } + · intro q' hq' k + subst q' + exact hweights_nonneg k + · intro q' hq' + subst q' + exact hsum_one + · intro q' hq' + subst q' + have hsum_eq : + weights q (prev q) + ∑ k ∈ others q, weights q k = 1 := by + have hsum' : + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := by + simp [others] + calc + weights q (prev q) + ∑ k ∈ others q, weights q k = + ∑ k, weights q k := hsum' + _ = 1 := hsum_one + have hsum_le' : + weights q (prev q) + ∑ k ∈ others q, weights q k ≤ + weights q (prev q) + (epsAt q : Real) := by + simpa [add_comm, add_left_comm, add_assoc] using + (add_le_add_left hsum_others_le (weights q (prev q))) + have hprev : + 1 ≤ weights q (prev q) + (epsAt q : Real) := by + simpa [hsum_eq] using hsum_le' + exact hprev + · intro q' hq' k hk + subst q' + have hk' : k ∈ others q := by + simp [others, hk] + have hnonneg : + ∀ j ∈ others q, 0 ≤ weights q j := by + intro j _ + exact hweights_nonneg j + have hle : + weights q k ≤ ∑ j ∈ others q, weights q j := by + have h := Finset.single_le_sum hnonneg hk' + simpa using h + exact hle.trans hsum_others_le + +/-- Per-key weight bounds on a single active query, derived from per-key score gaps. 
-/ +theorem weight_bound_at_of_scoreGapLo + (active : Finset (Fin seq)) + (prev : Fin seq → Fin seq) + (scoresReal : Fin seq → Fin seq → Real) + (scoreGapLo : Fin seq → Fin seq → Rat) + (weightBoundAt : Fin seq → Fin seq → Rat) + (hweightBoundAt : + ∀ q k, k ≠ prev q → + weightBoundAt q k = + if scoreGapLo q k < 0 then + (1 : Rat) + else + ratDivUp 1 (1 + scoreGapLo q k)) + (hscore_gap_real_at : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + scoresReal q k + (scoreGapLo q k : Real) ≤ scoresReal q (prev q)) : + ∀ q, q ∈ active → ∀ k, k ≠ prev q → + Circuit.softmax (scoresReal q) k ≤ (weightBoundAt q k : Real) := by + classical + intro q hq k hk + by_cases hneg : scoreGapLo q k < 0 + · have hle : Circuit.softmax (scoresReal q) k ≤ 1 := by + simpa using (Circuit.softmax_le_one (scores := scoresReal q) k) + simpa [hweightBoundAt q k hk, hneg] using hle + · have hnonneg : 0 ≤ scoreGapLo q k := le_of_not_gt hneg + have hnonneg_real : 0 ≤ (scoreGapLo q k : Real) := by + simpa [ratToReal_def] using ratToReal_nonneg_of_nonneg hnonneg + have hscore := hscore_gap_real_at q hq k hk + have hsoft : + Circuit.softmax (scoresReal q) k ≤ 1 / (1 + (scoreGapLo q k : Real)) := by + simpa using + (Circuit.softmax_other_le_inv_one_add (scores := scoresReal q) + (prev := prev q) (k := k) (m := (scoreGapLo q k : Real)) + hnonneg_real hscore) + have hpos : (0 : Rat) < 1 + scoreGapLo q k := by + have hle : (1 : Rat) ≤ 1 + scoreGapLo q k := by + exact le_add_of_nonneg_right hnonneg + exact lt_of_lt_of_le zero_lt_one hle + have hden : (1 + scoreGapLo q k) ≠ 0 := by + exact ne_of_gt hpos + have hrat : + 1 / (1 + (scoreGapLo q k : Real)) ≤ + ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := by + simpa [ratToReal_def] using + (ratDivUp_ge_real 1 (1 + scoreGapLo q k) hden) + have hbound' : + Circuit.softmax (scoresReal q) k ≤ ratToReal (ratDivUp 1 (1 + scoreGapLo q k)) := + hsoft.trans hrat + simpa [hweightBoundAt q k hk, hneg, ratToReal_def] using hbound' + +end Sound + +end Nfp diff --git 
a/Nfp/Sound/Interval.lean b/Nfp/Sound/Interval.lean deleted file mode 100644 index b4c76fa..0000000 --- a/Nfp/Sound/Interval.lean +++ /dev/null @@ -1,316 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Bounds - -namespace Nfp.Sound - -/-! -# Rational intervals for SOUND-mode local certification - -This file provides a minimal-but-complete `Rat` interval arithmetic library used by -streaming Interval Bound Propagation (IBP) in `Nfp.Untrusted.SoundCompute`, with trusted -verification wrappers in `Nfp.Sound.IO`. - -All operations are conservative (over-approximations). --/ - -/-- Closed rational interval `[lo, hi]`. -/ -structure RatInterval where - lo : Rat - hi : Rat - deriving Repr - -namespace RatInterval - -instance : Inhabited RatInterval := ⟨{ lo := 0, hi := 0 }⟩ - -/-- Singleton interval `[r, r]`. -/ -def const (r : Rat) : RatInterval := { lo := r, hi := r } - -/-- Interval addition: `[a.lo + b.lo, a.hi + b.hi]`. -/ -def add (a b : RatInterval) : RatInterval := - { lo := a.lo + b.lo, hi := a.hi + b.hi } - -/-- Interval subtraction: `[a.lo - b.hi, a.hi - b.lo]`. -/ -def sub (a b : RatInterval) : RatInterval := - { lo := a.lo - b.hi, hi := a.hi - b.lo } - -/-- Scale an interval by a rational `c`, handling sign. -/ -def scale (c : Rat) (a : RatInterval) : RatInterval := - if c ≥ 0 then - { lo := c * a.lo, hi := c * a.hi } - else - { lo := c * a.hi, hi := c * a.lo } - -/-- ReLU over-approximation: `[max(0, lo), max(0, hi)]`. -/ -def relu (a : RatInterval) : RatInterval := - { lo := max 0 a.lo, hi := max 0 a.hi } - -/-- Union (hull) of intervals: `[min(a.lo,b.lo), max(a.hi,b.hi)]`. -/ -def union (a b : RatInterval) : RatInterval := - { lo := min a.lo b.lo, hi := max a.hi b.hi } - -/-- Whether the interval contains 0. -/ -def containsZero (a : RatInterval) : Bool := - decide (a.lo ≤ 0 ∧ 0 ≤ a.hi) - -/-- Upper bound on `|x - μ|` for any `x, μ ∈ [lo, hi]`. 
-/ -def centeredAbsBound (a : RatInterval) : Rat := - ratAbs (a.hi - a.lo) - -private def ratSq (x : Rat) : Rat := x * x - -/-- Lower bound on `x^2` over an interval. - -If the interval contains 0, this is 0. Otherwise it is the squared distance to 0 of the -endpoint with smaller absolute value. --/ -def squareLowerBound (a : RatInterval) : Rat := - if containsZero a then - 0 - else - let alo := ratAbs a.lo - let ahi := ratAbs a.hi - let m := min alo ahi - ratSq m - -/-- Mean interval of a coordinate-wise box. - -For `xᵢ ∈ [loᵢ, hiᵢ]`, we have `μ = (1/n)∑ xᵢ ∈ [(1/n)∑ loᵢ, (1/n)∑ hiᵢ]`. --/ -def mean (xs : Array RatInterval) : RatInterval := - if xs.isEmpty then - const 0 - else - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let (loSum, hiSum) := - xs.foldl (fun (acc : Rat × Rat) x => (acc.1 + x.lo, acc.2 + x.hi)) (0, 0) - { lo := loSum / nRat, hi := hiSum / nRat } - -/-- Sound lower bound on the variance of a coordinate-wise box. - -We return a conservative lower bound on `var(x)` for all `x` in the box. - -We compute the exact minimum via a 1D convex minimization: - -`var(x) = (1/n) ∑ (xᵢ - mean(x))^2 = min_c (1/n) ∑ (xᵢ - c)^2`. - -Therefore, - -`min_{x∈box} var(x) = min_c (1/n) ∑ min_{xᵢ∈[lᵢ,uᵢ]} (xᵢ - c)^2` - -and for fixed `c` each coordinate minimization is `dist([lᵢ,uᵢ], c)^2` where `dist` is 0 if -`c ∈ [lᵢ,uᵢ]` and otherwise the squared distance to the nearer endpoint. - -The resulting one-dimensional function of `c` is convex piecewise-quadratic, so we can find its -global minimum by scanning the sorted breakpoints `{lᵢ,uᵢ}` and checking the unique stationary point -in each region (plus the breakpoints themselves). --/ -def varianceLowerBound (xs : Array RatInterval) : Rat := - if xs.isEmpty then - 0 - else - Id.run do - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - -- Normalize endpoints defensively. 
- let normed : Array RatInterval := - xs.map (fun x => { lo := min x.lo x.hi, hi := max x.lo x.hi }) - if n < 2 then - return 0 - -- Build sorted breakpoint lists for `lo` and `hi` with squared endpoints for O(1) evaluation. - let mut enters : Array (Rat × Rat) := Array.mkEmpty n - let mut leaves : Array (Rat × Rat) := Array.mkEmpty n - let mut sumLeft : Rat := 0 - let mut sumLeftSq : Rat := 0 - for x in normed do - let lo := x.lo - let hi := x.hi - enters := enters.push (lo, ratSq lo) - leaves := leaves.push (hi, ratSq hi) - sumLeft := sumLeft + lo - sumLeftSq := sumLeftSq + ratSq lo - -- Exact minimization over the breakpoints (O(n log n)). - let entersSorted := enters.qsort (fun a b => a.1 ≤ b.1) - let leavesSorted := leaves.qsort (fun a b => a.1 ≤ b.1) - let breaksAll := - (entersSorted.map (fun p => p.1) ++ leavesSorted.map (fun p => p.1)).qsort (· ≤ ·) - -- Unique-ify breakpoints. - let mut breaks : Array Rat := Array.mkEmpty breaksAll.size - for b in breaksAll do - if breaks.isEmpty then - breaks := breaks.push b - else if breaks.back! = b then - pure () - else - breaks := breaks.push b - let evalG (c : Rat) (leftCount rightCount : Nat) - (sumLeft sumLeftSq sumRight sumRightSq : Rat) : Rat := - let cSq := ratSq c - let leftTerm := - sumLeftSq - (2 : Rat) * c * sumLeft + ((leftCount : Nat) : Rat) * cSq - let rightTerm := - ((rightCount : Nat) : Rat) * cSq - (2 : Rat) * c * sumRight + sumRightSq - leftTerm + rightTerm - -- State for scanning regions: - -- left set: intervals with `c < lo` - -- right set: intervals with `c > hi` - let mut leftCount : Nat := n - let mut rightCount : Nat := 0 - let mut sumRight : Rat := 0 - let mut sumRightSq : Rat := 0 - let mut iEnter : Nat := 0 - let mut iLeave : Nat := 0 - let mut bestG : Rat := - evalG (sumLeft / nRat) leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq - if !breaks.isEmpty then - bestG := min bestG - (evalG breaks[0]! 
leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - for bi in [:breaks.size] do - let b := breaks[bi]! - -- Process enters at `b` (at `c=b` these are already inside, so not in left). - while iEnter < entersSorted.size && entersSorted[iEnter]!.1 = b do - let (_, loSq) := entersSorted[iEnter]! - let lo := entersSorted[iEnter]!.1 - leftCount := leftCount - 1 - sumLeft := sumLeft - lo - sumLeftSq := sumLeftSq - loSq - iEnter := iEnter + 1 - -- Evaluate at the breakpoint `c=b` (intervals with `hi=b` are still inside at `c=b`). - bestG := min bestG (evalG b leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - -- After `b`, process leaves at `b` (those intervals become right for `c>b`). - while iLeave < leavesSorted.size && leavesSorted[iLeave]!.1 = b do - let (_, hiSq) := leavesSorted[iLeave]! - let hi := leavesSorted[iLeave]!.1 - rightCount := rightCount + 1 - sumRight := sumRight + hi - sumRightSq := sumRightSq + hiSq - iLeave := iLeave + 1 - -- Check stationary point in the region `(b, nextB)` if there is a next breakpoint. - let outsideCount : Nat := leftCount + rightCount - if outsideCount = 0 then - -- There exists `c` contained in every interval, so the box intersects the constant line. - -- The exact minimum is 0. - return 0 - let cStar : Rat := (sumLeft + sumRight) / ((outsideCount : Nat) : Rat) - if bi + 1 < breaks.size then - let bNext := breaks[bi + 1]! - if b < cStar ∧ cStar < bNext then - bestG := min bestG - (evalG cStar leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - else - -- Last region `(b, +∞)`. - if b ≤ cStar then - bestG := min bestG - (evalG cStar leftCount rightCount sumLeft sumLeftSq sumRight sumRightSq) - let exactLB := bestG / nRat - return exactLB - -/-- Over-approximate GeLU on an interval without any transcendental facts. - -For all real `x`, `GeLU(x) = x·Φ(x)` lies between `x` and `0`. -Therefore `GeLU([lo,hi]) ⊆ [min(lo,0), max(hi,0)]`. 
--/ -def geluOverapprox (a : RatInterval) : RatInterval := - { lo := min a.lo 0, hi := max a.hi 0 } - -/-- Upper bound on `max |gelu'(x)|` over a rational interval. -/ -def geluDerivBound (target : GeluDerivTarget) (a : RatInterval) : Rat := - let maxAbs := max (ratAbs a.lo) (ratAbs a.hi) - let maxAbsSq := maxAbs * maxAbs - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - -- Conservative: Φ(x) ≤ 1 and φ(x) ≤ 1/2, so gelu'(x) ≤ 1 + |x|/2. - min (1 + half * maxAbs) (geluDerivBoundGlobal target) - | .tanh => - -- Conservative: |tanh| ≤ 1, sech^2 ≤ 1, and sqrt(2/pi) ≤ 1. - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbs * slope - min localBound (geluDerivBoundGlobal target) - -/-- Upper bound on the row-sum softmax Jacobian norm for a probability interval. -/ -def softmaxJacobianNormInfBound (a : RatInterval) : Rat := - Nfp.Sound.softmaxJacobianNormInfBound a.lo a.hi - -/-! ### Specs -/ - -theorem const_def (r : Rat) : RatInterval.const r = { lo := r, hi := r } := rfl - -theorem centeredAbsBound_def (a : RatInterval) : - RatInterval.centeredAbsBound a = ratAbs (a.hi - a.lo) := rfl - -theorem add_def (a b : RatInterval) : - RatInterval.add a b = { lo := a.lo + b.lo, hi := a.hi + b.hi } := rfl - -theorem sub_def (a b : RatInterval) : - RatInterval.sub a b = { lo := a.lo - b.hi, hi := a.hi - b.lo } := rfl - -theorem scale_def (c : Rat) (a : RatInterval) : - RatInterval.scale c a = - if c ≥ 0 then - { lo := c * a.lo, hi := c * a.hi } - else - { lo := c * a.hi, hi := c * a.lo } := rfl - -theorem relu_def (a : RatInterval) : - RatInterval.relu a = { lo := max 0 a.lo, hi := max 0 a.hi } := rfl - -theorem union_def (a b : RatInterval) : - RatInterval.union a b = { lo := min a.lo b.lo, hi := max a.hi b.hi } := rfl - -theorem softmaxJacobianNormInfBound_def (a : RatInterval) : - RatInterval.softmaxJacobianNormInfBound a = - Nfp.Sound.softmaxJacobianNormInfBound a.lo a.hi := rfl - -theorem 
geluDerivBound_def (target : GeluDerivTarget) (a : RatInterval) : - RatInterval.geluDerivBound target a = - let maxAbs := max (ratAbs a.lo) (ratAbs a.hi) - let maxAbsSq := maxAbs * maxAbs - let half : Rat := (1 : Rat) / 2 - match target with - | .exact => - min (1 + half * maxAbs) (geluDerivBoundGlobal target) - | .tanh => - let c : Rat := (44715 : Rat) / 1000000 - let slope := 1 + (3 : Rat) * c * maxAbsSq - let localBound := 1 + half * maxAbs * slope - min localBound (geluDerivBoundGlobal target) := rfl - -theorem containsZero_iff (a : RatInterval) : - RatInterval.containsZero a = true ↔ a.lo ≤ 0 ∧ 0 ≤ a.hi := by - simp [containsZero] - -theorem squareLowerBound_def (a : RatInterval) : - RatInterval.squareLowerBound a = - if RatInterval.containsZero a then - 0 - else - let alo := ratAbs a.lo - let ahi := ratAbs a.hi - let m := min alo ahi - ratSq m := rfl - -theorem mean_def (xs : Array RatInterval) : - RatInterval.mean xs = - if xs.isEmpty then - RatInterval.const 0 - else - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let (loSum, hiSum) := - xs.foldl (fun (acc : Rat × Rat) x => (acc.1 + x.lo, acc.2 + x.hi)) (0, 0) - { lo := loSum / nRat, hi := hiSum / nRat } := rfl - -theorem varianceLowerBound_spec (xs : Array RatInterval) : - RatInterval.varianceLowerBound xs = RatInterval.varianceLowerBound xs := rfl - -theorem geluOverapprox_def (a : RatInterval) : - RatInterval.geluOverapprox a = { lo := min a.lo 0, hi := max a.hi 0 } := rfl - -end RatInterval - -end Nfp.Sound diff --git a/Nfp/Sound/ModelHeader.lean b/Nfp/Sound/ModelHeader.lean deleted file mode 100644 index 6864691..0000000 --- a/Nfp/Sound/ModelHeader.lean +++ /dev/null @@ -1,82 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Activation -import Nfp.Sound.Decimal - -namespace Nfp.Sound - -/-! -# Model header helpers (SOUND) - -Pure parsing utilities for extracting trusted metadata from `NFP_TEXT` model headers. --/ - -/-- Parse `key=value` header lines. 
-/ -def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - -/-- Minimal parsed header data for sound certification. -/ -structure TextHeader where - eps : Rat - geluDerivTarget : GeluDerivTarget - deriving Repr - -private def parseGeluDerivTarget (v : String) : Except String GeluDerivTarget := - match geluDerivTargetOfString v with - | some t => .ok t - | none => .error s!"invalid gelu_kind '{v}' (expected tanh|exact)" - -/-- Parse required metadata from a `NFP_TEXT` header. -/ -def parseTextHeader (lines : Array String) : Except String TextHeader := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut eps? : Option Rat := none - let mut gelu? : Option GeluDerivTarget := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := lines.size - else - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - if k = "layer_norm_eps" || k = "eps" then - match parseRat v with - | .error e => return .error s!"invalid layer_norm_eps '{v}': {e}" - | .ok r => eps? := some r - else if k = "gelu_kind" || k = "gelu_deriv" then - match parseGeluDerivTarget v with - | .error e => return .error e - | .ok t => gelu? := some t - i := i + 1 - let some eps := eps? | return .error "missing layer_norm_eps" - let some gelu := gelu? | return .error "missing gelu_kind" - return .ok { eps := eps, geluDerivTarget := gelu } - -/-- Parse `layer_norm_eps` (or `eps`) from a `NFP_TEXT` header. -/ -def parseTextHeaderEps (lines : Array String) : Except String Rat := do - let hdr ← parseTextHeader lines - return hdr.eps - -/-! 
### Specs -/ - -theorem parseHeaderLine_spec : parseHeaderLine = parseHeaderLine := rfl -theorem parseTextHeader_spec : parseTextHeader = parseTextHeader := rfl -theorem parseTextHeaderEps_spec : parseTextHeaderEps = parseTextHeaderEps := rfl - -end Nfp.Sound diff --git a/Nfp/Sound/TextPure.lean b/Nfp/Sound/TextPure.lean deleted file mode 100644 index fe3895d..0000000 --- a/Nfp/Sound/TextPure.lean +++ /dev/null @@ -1,212 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Bounds -import Nfp.Sound.Decimal -import Nfp.Sound.ModelHeader - -namespace Nfp.Sound - -/-! -# Pure text helpers (`NFP_TEXT`) - -Pure parsing utilities for extracting exact `Rat` bounds from text model files. --/ - -structure TextModelDims where - numLayers : Nat - numHeads : Nat - modelDim : Nat - headDim : Nat - hiddenDim : Nat - seqLen : Nat - start : Nat - deriving Repr - -/-- Per-layer attention weight bounds extracted from a text model. -/ -structure AttnWeightBounds where - attnValueCoeff : Array Rat - wqOpBoundMax : Array Rat - wkOpBoundMax : Array Rat - deriving Repr - -def parseTextHeaderDims (lines : Array String) : Except String TextModelDims := - Id.run do - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? 
- | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? - | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - return .ok { - numLayers := L - numHeads := H - modelDim := d - headDim := dh - hiddenDim := dhid - seqLen := n - start := i - } - -/-- Fold `count` rationals from lines starting at `start`, returning the new state and index. -/ -def foldRatTokens {α : Type} - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Rat → α) : Except String (α × Nat) := - Id.run do - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i < lines.size then - let line := lines[i]!.trim - i := i + 1 - if line.isEmpty then - pure () - else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with - | .error e => return .error e - | .ok r => - st := step st r - remaining := remaining - 1 - else - return .error "unexpected end of file while reading numbers" - return .ok (st, i) - -/-- Consume a vector of length `n` and return its values. -/ -def consumeVector - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Array Rat × Nat) := - let step := fun (acc : Array Rat) (x : Rat) => acc.push x - foldRatTokens lines start n (Array.mkEmpty n) step - -/-- Consume a matrix and return its row-sum norm. 
-/ -def consumeMatrixNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String (Rat × Nat) := - let count := rows * cols - match consumeVector lines start count with - | .error e => .error e - | .ok (xs, next) => .ok (matrixNormInfOfRowMajor rows cols xs, next) - -/-- Compute per-layer attention value and `W_Q/W_K` bounds from text model lines. -/ -def attnWeightBoundsFromTextLines (lines : Array String) : Except String AttnWeightBounds := - Id.run do - let infoE := parseTextHeaderDims lines - let info ← - match infoE with - | .error e => return .error e - | .ok v => pure v - let mut i := info.start - let mut curLayer : Nat := 0 - let mut attnValueCoeff : Array Rat := Array.replicate info.numLayers 0 - let mut wqMax : Array Rat := Array.replicate info.numLayers 0 - let mut wkMax : Array Rat := Array.replicate info.numLayers 0 - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD 0 - i := i + 1 - else if line = "W_Q" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nq, next) => - if r < wqMax.size then - wqMax := wqMax.set! r (max wqMax[r]! nq) - i := next - else if line = "W_K" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nk, next) => - if r < wkMax.size then - wkMax := wkMax.set! r (max wkMax[r]! 
nk) - i := next - else if line = "W_V" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) info.modelDim info.headDim with - | .error e => return .error e - | .ok (nv, next) => - i := next - while i < lines.size && lines[i]!.trim ≠ "W_O" do - i := i + 1 - if !(i < lines.size) then - return .error "expected W_O after W_V" - match consumeMatrixNormInf lines (i + 1) info.headDim info.modelDim with - | .error e => return .error e - | .ok (no, next2) => - if r < attnValueCoeff.size then - attnValueCoeff := - attnValueCoeff.set! r (attnValueCoeff[r]! + (nv * no)) - i := next2 - else - i := i + 1 - return .ok { - attnValueCoeff := attnValueCoeff - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - } - -/-- Compute per-layer `attnValueCoeff` from text model lines. -/ -def attnValueCoeffFromTextLines (lines : Array String) : Except String (Array Rat) := do - let bounds ← attnWeightBoundsFromTextLines lines - return bounds.attnValueCoeff - -/-! ### Specs -/ - -theorem parseTextHeaderDims_spec : parseTextHeaderDims = parseTextHeaderDims := rfl -theorem AttnWeightBounds_spec : AttnWeightBounds = AttnWeightBounds := rfl -theorem foldRatTokens_spec (α : Type) : - @foldRatTokens α = @foldRatTokens α := rfl -theorem consumeVector_spec : consumeVector = consumeVector := rfl -theorem consumeMatrixNormInf_spec : consumeMatrixNormInf = consumeMatrixNormInf := rfl -theorem attnWeightBoundsFromTextLines_spec : - attnWeightBoundsFromTextLines = attnWeightBoundsFromTextLines := rfl -theorem attnValueCoeffFromTextLines_spec : - attnValueCoeffFromTextLines = attnValueCoeffFromTextLines := rfl - -end Nfp.Sound diff --git a/Nfp/System.lean b/Nfp/System.lean new file mode 100644 index 0000000..9ee4bdb --- /dev/null +++ b/Nfp/System.lean @@ -0,0 +1,10 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Nfp.System.Dag +public import Nfp.System.LocalSystem + +/-! +DAG-based system foundations. 
+-/ diff --git a/Nfp/System/Dag.lean b/Nfp/System/Dag.lean new file mode 100644 index 0000000..7bf74f2 --- /dev/null +++ b/Nfp/System/Dag.lean @@ -0,0 +1,83 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Nfp.Tactic.Linter +public import Mathlib.Combinatorics.Digraph.Basic +public import Mathlib.Data.Fintype.Defs +public import Mathlib.Data.Finset.Basic + +/-! +Directed acyclic graph foundations. +-/ + +public section + +namespace Nfp + +universe u u' + +/-- A finite directed acyclic graph, built on top of `Digraph`. -/ +structure Dag (ι : Type u) [Fintype ι] where + /-- The underlying directed graph. -/ + graph : Digraph ι + /-- Decidable adjacency for `graph.Adj`. -/ + decAdj : DecidableRel graph.Adj + /-- The adjacency relation is well-founded. -/ + wf : WellFounded graph.Adj + +attribute [instance] Dag.decAdj + +namespace Dag + +variable {ι : Type u} [Fintype ι] + +/-- Coerce a DAG to its underlying digraph. -/ +instance : Coe (Dag ι) (Digraph ι) := ⟨Dag.graph⟩ + +/-- The edge relation of a DAG. -/ +abbrev rel (G : Dag ι) : ι → ι → Prop := G.graph.Adj + +/-- Parents (incoming neighbors) of a node. -/ +def parents (G : Dag ι) (i : ι) : Finset ι := by + let _ : DecidablePred (fun j => G.rel j i) := fun j => G.decAdj j i + exact Finset.filter (fun j => G.rel j i) Finset.univ + +/-- Children (outgoing neighbors) of a node. -/ +def children (G : Dag ι) (i : ι) : Finset ι := by + let _ : DecidablePred (fun j => G.rel i j) := fun j => G.decAdj i j + exact Finset.filter (fun j => G.rel i j) Finset.univ + +@[simp] theorem mem_parents {G : Dag ι} {i j : ι} : + j ∈ G.parents i ↔ G.rel j i := by + simp [Dag.parents] + +@[simp] theorem mem_children {G : Dag ι} {i j : ι} : + j ∈ G.children i ↔ G.rel i j := by + simp [Dag.children] + +section Relabel + +variable {ι' : Type u'} [Fintype ι'] + +/-- Relabel a DAG along an equivalence of vertex types. 
-/ +def relabel (G : Dag ι) (e : ι ≃ ι') : Dag ι' := + { graph := { Adj := fun a b => G.rel (e.symm a) (e.symm b) } + decAdj := by + intro a b + exact G.decAdj (e.symm a) (e.symm b) + wf := by + simpa using (InvImage.wf (f := e.symm) (h := G.wf)) } + +/-- Relabeling preserves adjacency via the equivalence. -/ +theorem relabel_rel_iff (G : Dag ι) (e : ι ≃ ι') (a b : ι') : + (G.relabel e).rel a b ↔ G.rel (e.symm a) (e.symm b) := by + rfl + +end Relabel + +end Dag + +end Nfp + +end diff --git a/Nfp/System/LocalSystem.lean b/Nfp/System/LocalSystem.lean new file mode 100644 index 0000000..60ea203 --- /dev/null +++ b/Nfp/System/LocalSystem.lean @@ -0,0 +1,108 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public import Mathlib.Data.Fintype.BigOperators +public import Nfp.Mixer.Basic +public import Nfp.System.Dag + +/-! +Local mixing systems on finite DAGs. +-/ + +public section + +open scoped BigOperators + +namespace Nfp + +universe u + +/-- A local mixing system on a DAG. +`weight i j` is the contribution from `j` into `i`. -/ +structure LocalSystem (ι : Type u) [Fintype ι] where + /-- The underlying DAG describing allowed dependencies. -/ + dag : Dag ι + /-- Mixing weights for each target/source pair. -/ + weight : ι → ι → Mass + /-- Weights vanish off the edge relation. -/ + support : ∀ i j, ¬ dag.rel j i → weight i j = 0 + +namespace LocalSystem + +variable {ι : Type u} [Fintype ι] + +/-- Row-stochasticity for a local system. -/ +def IsRowStochastic (L : LocalSystem ι) : Prop := + ∀ i, (∑ j, L.weight i j) = 1 + +/-- View a local system as a global mixer. -/ +def toMixer (L : LocalSystem ι) (h : IsRowStochastic L) : Mixer ι ι := + { weight := L.weight + row_sum := h } + +/-- Off-edge weights are zero. -/ +theorem weight_eq_zero_of_not_parent (L : LocalSystem ι) {i j : ι} (h : ¬ L.dag.rel j i) : + L.weight i j = 0 := + by + simpa using L.support i j h + +/-- Off-edge weights remain zero after coercion to a mixer. 
-/ +theorem toMixer_weight_eq_zero_of_not_parent (L : LocalSystem ι) (h : IsRowStochastic L) + {i j : ι} (hrel : ¬ L.dag.rel j i) : + (toMixer L h).weight i j = 0 := by + simpa [toMixer] using weight_eq_zero_of_not_parent (L := L) (i := i) (j := j) hrel + +/-- Row sums can be restricted to parents in a row-stochastic local system. -/ +theorem row_sum_parents (L : LocalSystem ι) (h : IsRowStochastic L) (i : ι) : + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = 1 := by + classical + have hsubset : L.dag.parents i ⊆ (Finset.univ : Finset ι) := by + intro j _hj + exact Finset.mem_univ j + have hzero : + ∀ j ∈ (Finset.univ : Finset ι), j ∉ L.dag.parents i → L.weight i j = 0 := by + intro j _hj hj + have hrel : ¬ L.dag.rel j i := by + intro hrel + have hmem : j ∈ L.dag.parents i := by + simpa using hrel + exact hj hmem + exact weight_eq_zero_of_not_parent (L := L) (i := i) (j := j) hrel + have hsum : + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = + Finset.sum (Finset.univ : Finset ι) (fun j => L.weight i j) := by + exact Finset.sum_subset hsubset hzero + calc + Finset.sum (L.dag.parents i) (fun j => L.weight i j) = + Finset.sum (Finset.univ : Finset ι) (fun j => L.weight i j) := hsum + _ = 1 := by + simpa using h i + +/-- One-step evaluation functional used by `eval`. -/ +def evalStep (L : LocalSystem ι) (input : ι → Mass) + (i : ι) (rec : ∀ j, L.dag.rel j i → Mass) : Mass := + input i + + ∑ j, (if h : L.dag.rel j i then L.weight i j * rec j h else 0) + +/-- Evaluate a local system with external input at each node. -/ +def eval (L : LocalSystem ι) (input : ι → Mass) : ι → Mass := + L.dag.wf.fix (fun i rec => evalStep L input i rec) + +/-- Unfolding equation for `eval`. 
-/ +theorem eval_eq (L : LocalSystem ι) (input : ι → Mass) (i : ι) : + eval L input i = + input i + + ∑ j, (if _ : L.dag.rel j i then L.weight i j * eval L input j else 0) := by + classical + set F : ∀ i, (∀ j, L.dag.rel j i → Mass) → Mass := fun i rec => evalStep L input i rec + change L.dag.wf.fix F i = + input i + ∑ j, (if _ : L.dag.rel j i then L.weight i j * L.dag.wf.fix F j else 0) + rw [WellFounded.fix_eq] + simp [F, evalStep] + +end LocalSystem + +end Nfp + +end diff --git a/Nfp/Tactic/Linter.lean b/Nfp/Tactic/Linter.lean new file mode 100644 index 0000000..610b3a3 --- /dev/null +++ b/Nfp/Tactic/Linter.lean @@ -0,0 +1,9 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Nfp.Tactic.Linter.NoHeartbeats + +/-! +Aggregator for NFP linters. +-/ diff --git a/Nfp/Tactic/Linter/NoHeartbeats.lean b/Nfp/Tactic/Linter/NoHeartbeats.lean new file mode 100644 index 0000000..cd86d33 --- /dev/null +++ b/Nfp/Tactic/Linter/NoHeartbeats.lean @@ -0,0 +1,56 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +public meta import Lean.Elab.Command +public meta import Lean.Linter.Basic + +/-! +Syntax linter forbidding heartbeat budget options. +-/ + +meta section + +namespace Nfp +namespace Linter + +open Lean Parser Elab Command Linter + +/-- Enable the no-heartbeats linter. -/ +public register_option linter.nfp.noHeartbeats : Bool := { + defValue := false + descr := "enable the noHeartbeats linter" +} + +namespace NoHeartbeats + +/-- Return the option name if syntax is a `set_option` command, term, or tactic. -/ +def parseSetOption : Syntax → Option Name + | `(command|set_option $name:ident $_val) => some name.getId + | `(set_option $name:ident $_val in $_x) => some name.getId + | `(tactic|set_option $name:ident $_val in $_x) => some name.getId + | _ => none + +/-- True if the option is a heartbeat budget. 
-/ +def isHeartbeatOption (name : Name) : Bool := + name == `maxHeartbeats || name == `synthInstance.maxHeartbeats + +/-- Linter that forbids heartbeat budget options in this repository. -/ +def noHeartbeatsLinter : Linter where + run stx := do + unless getLinterValue linter.nfp.noHeartbeats (← getLinterOptions) do + return + if (← MonadState.get).messages.hasErrors then + return + if let some head := stx.find? (fun stx => (parseSetOption stx).isSome) then + if let some name := parseSetOption head then + if isHeartbeatOption name then + logLint linter.nfp.noHeartbeats head + m!"Setting option '{name}' is forbidden; refactor the proof instead." + +initialize addLinter noHeartbeatsLinter + +end NoHeartbeats + +end Linter +end Nfp diff --git a/Nfp/Uniqueness.lean b/Nfp/Uniqueness.lean deleted file mode 100644 index ad4a8c7..0000000 --- a/Nfp/Uniqueness.lean +++ /dev/null @@ -1,98 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Mathlib.Data.NNReal.Basic -import Mathlib.Data.Fin.Basic -import Mathlib.Data.Finset.Basic -import Mathlib.Algebra.BigOperators.Group.Finset.Basic -import Aesop - -namespace Nfp - -open scoped BigOperators -open Finset - -variable {S : Type*} - -/-! -Auxiliary linear-uniqueness lemma used to support the Appendix A.3 uniqueness -story (fixed masks ⇒ linear system). It shows that for a finite DAG ordered by a -topological index, any two tracer families satisfying the same homogeneous linear -recurrence coincide. This is a helper fact; Appendix A.5 in the paper is a -counterexample narrative (not a formal lemma), so we avoid referring to A.5 here. -Small parent-index checks in the inductive step are discharged by `aesop`; the -overall proof structure (nested induction on the index bound) remains explicit. --/ - -/-- A local mixing system over `n` nodes, where each node `i` aggregates parents -with nonnegative coefficients, and every parent `u` of `i` has a strictly smaller index. 
-/ -structure LocalSystem (n : ℕ) where - Pa : Fin n → Finset (Fin n) - c : Fin n → Fin n → NNReal - topo : ∀ {i u}, u ∈ Pa i → (u.1 < i.1) - -namespace LocalSystem - -variable {n : ℕ} - -/-- A family of tracer masses at each node. It does not depend on `L`. -/ -abbrev TracerFamily (n : ℕ) := Fin n → (S → NNReal) - -/-- The recurrence equation for a tracer family against the local system. -/ -def Satisfies (L : LocalSystem n) (T : TracerFamily (S := S) n) : Prop := - ∀ i s, T i s = (L.Pa i).sum (fun u => L.c i u * T u s) - -/-- Homogeneous linear uniqueness on a topologically ordered finite DAG: if `T` and -`T'` both satisfy the same recurrence, then they are equal pointwise. -/ -theorem tracer_unique (L : LocalSystem n) {T T' : TracerFamily (S := S) n} - (hT : Satisfies (S := S) L T) (hT' : Satisfies (S := S) L T') : T = T' := by - classical - funext i s - -- Prove by induction on an upper bound `k` of the index. - have main : ∀ k, ∀ j : Fin n, j.1 ≤ k → T j s = T' j s := by - intro k - induction k with - | zero => - intro j hj - have hj0 : j.1 = 0 := Nat.le_zero.mp hj - -- No parents at index 0 - have hPaEmpty : L.Pa j = (∅ : Finset (Fin n)) := by - classical - apply Finset.eq_empty_iff_forall_notMem.mpr - intro u hu - have hlt : u.1 < j.1 := L.topo (i := j) (u := u) hu - have hlt0 : u.1 < 0 := by - have := hlt - rwa [hj0] at this - exact (Nat.not_lt_zero _ hlt0).elim - have hjT : T j s = (L.Pa j).sum (fun u => L.c j u * T u s) := (hT j s) - have hjT' : T' j s = (L.Pa j).sum (fun u => L.c j u * T' u s) := (hT' j s) - have hz1 : T j s = 0 := by simpa [hPaEmpty] using hjT - have hz2 : T' j s = 0 := by simpa [hPaEmpty] using hjT' - simpa [hz1] using hz2.symm - | succ k ih => - intro j hj - by_cases hle : j.1 ≤ k - · exact ih j hle - · have hklt : k < j.1 := Nat.not_le.mp hle - have hjeq : j.1 = k.succ := le_antisymm hj (Nat.succ_le_of_lt hklt) - have hjT : T j s = (L.Pa j).sum (fun u => L.c j u * T u s) := (hT j s) - have hjT' : T' j s = (L.Pa j).sum (fun u => L.c j 
u * T' u s) := (hT' j s) - have hparents : ∀ u ∈ L.Pa j, u.1 ≤ k := by - intro u hu - have hlt : u.1 < k.succ := by - simpa [hjeq] using L.topo (i := j) (u := u) hu - have hle : u.1 ≤ k := Nat.le_of_lt_succ hlt - aesop - have hsum : (L.Pa j).sum (fun u => L.c j u * T u s) = - (L.Pa j).sum (fun u => L.c j u * T' u s) := by - classical - apply Finset.sum_congr rfl - intro u hu - have := ih u (hparents u hu) - simp [this] - simp [hjT, hjT', hsum] - exact main i.1 i le_rfl - -end LocalSystem - -end Nfp diff --git a/Nfp/Untrusted/SoundBinary.lean b/Nfp/Untrusted/SoundBinary.lean deleted file mode 100644 index afb4dad..0000000 --- a/Nfp/Untrusted/SoundBinary.lean +++ /dev/null @@ -1,152 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.BinaryPure - -namespace Nfp.Untrusted.SoundBinary - -/-! -# Untrusted binary IO helpers (`NFP_BINARY_V1`) - -This module provides IO wrappers for the sound binary path. -Pure parsing/decoding lives in `Nfp.Sound.BinaryPure`. --/ - -private def readLine? (h : IO.FS.Handle) : IO (Option String) := do - let s ← h.getLine - if s.isEmpty then - return none - return some s - -def readBinaryHeader (h : IO.FS.Handle) : IO (Except String Nfp.Sound.BinaryHeader) := do - let some magicLine ← readLine? h - | return .error "empty file" - let mut lines : Array String := #[] - let mut line? ← readLine? h - while true do - match line? with - | none => return .error "unexpected EOF while reading header" - | some line => - let t := line.trim - if t = "BINARY_START" then - break - lines := lines.push line - line? ← readLine? 
h - return Nfp.Sound.parseBinaryHeaderLines magicLine lines - -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out - -def skipBytes (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := do - let mut remaining := n - while remaining > 0 do - let chunkSize := min remaining 65536 - let chunk ← h.read (USize.ofNat chunkSize) - if chunk.isEmpty then - return .error "unexpected EOF" - remaining := remaining - chunk.size - return .ok () - -def skipI32Array (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := - skipBytes h (n * 4) - -def skipF64Array (h : IO.FS.Handle) (n : Nat) : IO (Except String Unit) := - skipBytes h (n * 8) - -def readVectorMaxAbsScaled (h : IO.FS.Handle) (n scalePow10 : Nat) : - IO (Except String Int) := do - if n = 0 then - return .ok 0 - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (n * 8))) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.vectorMaxAbsScaledFromBytes bytes n scalePow10 - -def readMatrixNormInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String Int) := do - if rows = 0 || cols = 0 then - return .ok 0 - let count := rows * cols - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 8))) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.matrixNormInfScaledFromBytes bytes rows cols scalePow10 - -def readScaledFloatArray (h : IO.FS.Handle) (count scalePow10 : Nat) : - IO (Except String (Array Int)) := do - if count = 0 then - return .ok #[] - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count 
* 8))) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.scaledFloatArrayFromBytes bytes count scalePow10 - -def readScaledFloat (h : IO.FS.Handle) (scalePow10 : Nat) : IO (Except String Int) := do - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h 8)) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.scaledFloatFromBytes bytes scalePow10 - -def readI32Array (h : IO.FS.Handle) (count : Nat) : - IO (Except String (Array Int)) := do - if count = 0 then - return .ok #[] - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 4))) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.i32ArrayFromBytes bytes count - -def readMatrixNormOneInfScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String (Nat × Nat)) := do - if rows = 0 || cols = 0 then - return .ok (0, 0) - let count := rows * cols - let bytesE : Except String ByteArray ← - try - pure (Except.ok (← readExactly h (count * 8))) - catch - | _ => pure (Except.error "unexpected EOF") - match bytesE with - | .error e => return .error e - | .ok bytes => - return Nfp.Sound.matrixNormOneInfScaledFromBytes bytes rows cols scalePow10 - -def readMatrixOpBoundScaled (h : IO.FS.Handle) (rows cols scalePow10 : Nat) : - IO (Except String Nat) := do - match ← readMatrixNormOneInfScaled h rows cols scalePow10 with - | .error e => return .error e - | .ok (rowSum, colSum) => - return .ok (Nfp.Sound.opBoundScaledFromOneInf rowSum colSum) - -end Nfp.Untrusted.SoundBinary diff --git a/Nfp/Untrusted/SoundCacheIO.lean b/Nfp/Untrusted/SoundCacheIO.lean deleted file mode 100644 index 5805ad6..0000000 --- a/Nfp/Untrusted/SoundCacheIO.lean +++ /dev/null @@ -1,116 +0,0 @@ --- 
SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Init.System.IO -import Nfp.Sound.CachePure - -namespace Nfp.Untrusted.SoundCacheIO - -/-! -# Untrusted SOUND fixed-point cache - -IO wrappers for the SOUND cache format. Pure parsing/encoding lives in `Nfp.Sound.CachePure`. --/ -private def readExactly (h : IO.FS.Handle) (n : Nat) : IO ByteArray := do - let mut out := ByteArray.empty - while out.size < n do - let chunk ← h.read (USize.ofNat (n - out.size)) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF") - out := out ++ chunk - return out - -def writeHeader (h : IO.FS.Handle) (hdr : Nfp.Sound.SoundCache.Header) : IO Unit := do - h.write (Nfp.Sound.SoundCache.encodeHeader hdr) - -def readHeader (h : IO.FS.Handle) : IO Nfp.Sound.SoundCache.Header := do - let bytes ← readExactly h Nfp.Sound.SoundCache.headerBytes - match Nfp.Sound.SoundCache.decodeHeader bytes with - | .ok hdr => return hdr - | .error e => throw (IO.userError e) - -/-- FNV-1a 64-bit hash of a file's bytes (stable, deterministic). -/ -def fnv1a64File (path : System.FilePath) : IO UInt64 := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let mut hash : UInt64 := Nfp.Sound.SoundCache.fnv1a64Init - let mut done := false - while !done do - let chunk ← h.read (USize.ofNat 1048576) - if chunk.isEmpty then - done := true - else - hash := Nfp.Sound.SoundCache.fnv1a64Update hash chunk - return hash - -/-- Ensure the cache directory exists. -/ -def ensureCacheDir : IO Unit := do - IO.FS.createDirAll Nfp.Sound.SoundCache.cacheDir - -/-- Build (or overwrite) a SOUND fixed-point cache file. 
-/ -def buildCacheFile - (modelPath cachePath : System.FilePath) - (scalePow10 : Nat := 9) : IO (Except String Unit) := do - ensureCacheDir - let contents ← IO.FS.readFile modelPath - let lines : Array String := (contents.splitOn "\n").toArray - let modelHash ← fnv1a64File modelPath - let mdata ← modelPath.metadata - let modelSize : UInt64 := mdata.byteSize - match Nfp.Sound.SoundCache.buildCacheBytes lines scalePow10 modelHash modelSize with - | .error e => return .error e - | .ok bytes => - let tmpPath := cachePath.withExtension "tmp" - if (← tmpPath.pathExists) then - IO.FS.removeFile tmpPath - let out ← IO.FS.Handle.mk tmpPath IO.FS.Mode.write - out.write bytes - out.flush - if (← cachePath.pathExists) then - IO.FS.removeFile cachePath - IO.FS.rename tmpPath cachePath - return .ok () - -/-! ## Consistency checks (for CI and debugging) -/ - -/-- Check that for each numeric token in the text file, its exact `Rat` value lies in the -`±1`-ulp interval induced by `parseFixed10Rounded scalePow10`. -/ -def checkTextTokenEnvelope - (modelPath : System.FilePath) - (scalePow10 : Nat := 9) - (maxTokens : Nat := 0) : IO (Except String Unit) := do - let contents ← IO.FS.readFile modelPath - let lines : Array String := (contents.splitOn "\n").toArray - return Nfp.Sound.SoundCache.checkTextTokenEnvelopeLines lines scalePow10 maxTokens - -/-- Check that the cache file size matches the expected tensor stream length. -/ -def checkCacheFileSize (cachePath : System.FilePath) (hdr : Nfp.Sound.SoundCache.Header) : - IO (Except String Unit) := do - let mdata ← cachePath.metadata - let expectedBytes := Nfp.Sound.SoundCache.expectedCacheBytes hdr - if mdata.byteSize = expectedBytes then - return .ok () - else - return .error s!"cache size mismatch: expected {expectedBytes}, got {mdata.byteSize}" - -/-! 
## Cache reader (buffered) -/ - -def I32Reader.init (h : IO.FS.Handle) : IO Nfp.Sound.SoundCache.I32Reader := - pure { h := h, buf := ByteArray.empty, pos := 0 } - -private def I32Reader.refill (r : Nfp.Sound.SoundCache.I32Reader) : - IO Nfp.Sound.SoundCache.I32Reader := do - let chunk ← r.h.read (USize.ofNat 1048576) - if chunk.isEmpty then - throw (IO.userError "unexpected EOF while reading cache") - return { r with buf := chunk, pos := 0 } - -def I32Reader.readI32 (r : Nfp.Sound.SoundCache.I32Reader) : - IO (Int × Nfp.Sound.SoundCache.I32Reader) := do - let r ← - if r.pos + 4 ≤ r.buf.size then - pure r - else - I32Reader.refill r - let x := Nfp.Sound.SoundCache.i32FromBuffer r.buf r.pos - return (x, { r with pos := r.pos + 4 }) -end Nfp.Untrusted.SoundCacheIO diff --git a/Nfp/Untrusted/SoundCompute.lean b/Nfp/Untrusted/SoundCompute.lean deleted file mode 100644 index 5562929..0000000 --- a/Nfp/Untrusted/SoundCompute.lean +++ /dev/null @@ -1,4665 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Std -import Nfp.Sound.Cert -import Nfp.Sound.HeadCert -import Nfp.Untrusted.SoundBinary -import Nfp.Sound.Interval -import Nfp.Untrusted.SoundCacheIO -import Nfp.Sound.Fixed - -namespace Nfp.Untrusted.SoundCompute - -open IO -open Nfp.Sound -open Nfp.Untrusted.SoundBinary - -/-! -# Untrusted SOUND `.nfpt` loader (exact Rat parsing, legacy text format) - -This is a minimal, *sound* loader intended for certification on the legacy text format. - -It does **not** construct the full `ConcreteModel` (Float-based). Instead it parses only the -weights needed for conservative residual amplification constants `Cᵢ` (bounds ‖layerJacobian - I‖), -using exact `Rat` arithmetic. - -It can optionally consume an input `.nfpt` file (for `EMBEDDINGS`) to enable **local** -LayerNorm certification on a bounded region around that input. - -Global certification supports `NFP_BINARY_V1` via a sound fixed-point rounding bridge. 
-Local (input-dependent) certification supports `NFP_BINARY_V1` using a union-box fixed-point path. - -Trusted base: -- Parsing from text to `Rat` via `Nfp.Sound.parseRat`. -- Exact accumulation of row-sum norms and max-abs values. - -No `Float` arithmetic is used as an input to certification. --/ - -/-- Parse `key=value` header lines. -/ -def parseHeaderLine (line : String) : Option (String × String) := - let line := line.trim - if line.isEmpty then none - else - match line.splitOn "=" with - | [k, v] => some (k.trim, v.trim) - | _ => none - -private def defaultBinaryScalePow10 : Nat := 9 - -/-- Read `count` rationals from lines starting at `start`, folding into `state`. - -Returns `(state, nextLineIndex)`. --/ -def foldRatTokens {α : Type} - (lines : Array String) - (start : Nat) - (count : Nat) - (state : α) - (step : α → Rat → α) : Except String (α × Nat) := - Id.run do - let mut i := start - let mut remaining := count - let mut st := state - while remaining > 0 do - if i < lines.size then - let line := lines[i]!.trim - i := i + 1 - if line.isEmpty then - pure () - else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with - | .error e => return .error e - | .ok r => - st := step st r - remaining := remaining - 1 - else - return .error "unexpected end of file while reading numbers" - return .ok (st, i) - -/-- Consume a vector of length `n` into an array. Returns `(xs, nextLineIndex)`. -/ -def consumeVector - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Array Rat × Nat) := - let step := fun (acc : Array Rat) (x : Rat) => acc.push x - foldRatTokens lines start n (Array.mkEmpty n) step - -/-- Consume a matrix in row-major order and return its exact row-sum norm -(`‖·‖∞` in column-vector convention, `‖·‖₁` in row-vector convention). - -Returns `(normInf, nextLineIndex)`. 
--/ -def consumeMatrixNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String (Rat × Nat) := - let count := rows * cols - match consumeVector lines start count with - | .error e => .error e - | .ok (xs, next) => .ok (matrixNormInfOfRowMajor rows cols xs, next) - -/-- Parsed matrix norm is definitionally the spec-level bound on the parsed data. -/ -theorem consumeMatrixNormInf_spec - (lines : Array String) (start rows cols : Nat) (xs : Array Rat) (next : Nat) - (h : consumeVector lines start (rows * cols) = .ok (xs, next)) : - consumeMatrixNormInf lines start rows cols = - .ok (matrixNormInfOfRowMajor rows cols xs, next) := by - simp [consumeMatrixNormInf, h] - -/-- Consume a vector of length `n` and return `max |xᵢ|`. - -Returns `(maxAbs, nextLineIndex)`. --/ -def consumeMaxAbs - (lines : Array String) - (start : Nat) - (n : Nat) : Except String (Rat × Nat) := - let step := fun (m : Rat) (x : Rat) => max m (ratAbs x) - foldRatTokens lines start n 0 step - -private def maxAbsOfVector (xs : Array Rat) : Rat := - xs.foldl (fun acc x => max acc (ratAbs x)) 0 - -/-- Compute weight-only per-head contribution bounds from a binary `.nfpt`. 
-/ -def certifyHeadBoundsBinary - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadContributionCert)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut heads : Array HeadContributionCert := Array.mkEmpty (hdr.numLayers * hdr.numHeads) - for l in [:hdr.numLayers] do - for hIdx in [:hdr.numHeads] do - let wqScaledE ← readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wvScaledE ← readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10 - let wvScaled ← - match wvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let woScaledE ← readMatrixOpBoundScaled h hdr.headDim hdr.modelDim scalePow10 - let woScaled ← - match woScaledE with - | .error e => return .error e - | .ok v => pure v - let wqOp := ratOfScaledNat scalePow10 wqScaled - let wkOp := ratOfScaledNat scalePow10 wkScaled - let wvOp := ratOfScaledNat scalePow10 wvScaled - let woOp := ratOfScaledNat scalePow10 woScaled - let cert : HeadContributionCert := { - layerIdx := l - headIdx := hIdx - wqOpBound := wqOp - wkOpBound := wkOp - wvOpBound := wvOp - woOpBound := woOp - qkFactorBound := 
wqOp * wkOp - voFactorBound := wvOp * woOp - } - if cert.check then - heads := heads.push cert - else - return .error "head contribution certificate failed internal checks" - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - return .ok heads - -private def certifyModelFileGlobalBinary - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - let scalePow10 := defaultBinaryScalePow10 - let actDerivBound := 
geluDerivBoundGlobal geluDerivTarget - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut layers : Array LayerAmplificationCert := Array.mkEmpty hdr.numLayers - let mut totalAmp : Rat := 1 - for l in [:hdr.numLayers] do - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let nvScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let nvScaled ← - match nvScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let noScaledE ← readMatrixNormInfScaled h hdr.headDim hdr.modelDim scalePow10 - let noScaled ← - match noScaledE with - | .error e => return .error e - | .ok v => pure v - let wq := ratOfScaledInt scalePow10 wqScaled - let wk := ratOfScaledInt scalePow10 wkScaled - let nv := ratOfScaledInt scalePow10 nvScaled - let no := ratOfScaledInt scalePow10 noScaled - wqMax := max wqMax wq - wkMax := max wkMax wk - attnValueCoeff := attnValueCoeff + nv * no - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let nWinScaledE ← - readMatrixNormInfScaled h hdr.modelDim hdr.hiddenDim scalePow10 - let nWinScaled ← - match nWinScaledE with - | .error e => 
return .error e - | .ok v => pure v - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - let nWoutScaledE ← - readMatrixNormInfScaled h hdr.hiddenDim hdr.modelDim scalePow10 - let nWoutScaled ← - match nWoutScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1GammaScaled ← - match ln1GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln1BetaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln1BetaScaled ← - match ln1BetaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2GammaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - let ln2GammaScaled ← - match ln2GammaScaledE with - | .error e => return .error e - | .ok v => pure v - let ln2BetaScaledE ← readVectorMaxAbsScaled h hdr.modelDim scalePow10 - match ln2BetaScaledE with - | .error e => return .error e - | .ok _ => pure () - let ln1Max := ratOfScaledInt scalePow10 ln1GammaScaled - let ln1MaxAbsBeta := ratOfScaledInt scalePow10 ln1BetaScaled - let ln2Max := ratOfScaledInt scalePow10 ln2GammaScaled - let nWin := ratOfScaledInt scalePow10 nWinScaled - let nWout := ratOfScaledInt scalePow10 nWoutScaled - let ln1Bound := layerNormOpBoundConservative ln1Max eps soundnessBits - let ln2Bound := layerNormOpBoundConservative ln2Max eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound hdr.modelDim ln1Max ln1MaxAbsBeta - let attnPatternCoeff := - attnPatternCoeffBound hdr.seqLen hdr.modelDim hdr.headDim ln1OutMaxAbsBound - wqMax wkMax attnValueCoeff - let mlpCoeff := nWin * nWout - let mlpActDerivBound := actDerivBound - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - 
softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnW := - ln1Bound * - ((hdr.seqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1Max - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2Max - ln1VarianceLowerBound? := none - ln2VarianceLowerBound? := none - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.vocabSize) with - | .error e => return .error e - | .ok _ => pure () - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= none - inputDelta := 0 - eps := eps - seqLen := hdr.seqLen - modelDim := hdr.modelDim - headDim := hdr.headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBound - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def addVecIntervals (a b : Array RatInterval) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.add a[i]! b[i]!) - return out - -private def addConstVec (a : Array RatInterval) (b : Array Rat) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.add a[i]! (RatInterval.const b[i]!)) - return out - -private def unionVecIntervals (a b : Array RatInterval) : Array RatInterval := - Id.run do - if a.size ≠ b.size then - return a - let mut out : Array RatInterval := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (RatInterval.union a[i]! b[i]!) - return out - -private def zeroIntervals (n : Nat) : Array RatInterval := - Array.replicate n (RatInterval.const 0) - -/-- Max GeLU derivative bound across a vector of rational intervals. -/ -private def maxGeluDerivBound (target : GeluDerivTarget) (xs : Array RatInterval) : Rat := - xs.foldl (fun acc x => max acc (RatInterval.geluDerivBound target x)) 0 - -/-- Sum of per-coordinate centered absolute bounds (interval widths). -/ -private def centeredAbsSum (xs : Array RatInterval) : Rat := - xs.foldl (fun acc x => acc + RatInterval.centeredAbsBound x) 0 - -/-- Max GeLU derivative bound across fixed-point intervals (converted to `Rat`). 
-/ -private def maxGeluDerivBoundFixed (cfg : Fixed10Cfg) (target : GeluDerivTarget) - (xs : Array Fixed10Interval) : Rat := - xs.foldl (fun acc x => max acc (Fixed10Interval.geluDerivBound cfg target x)) 0 - -private def unionRows (rows : Array (Array RatInterval)) (dim : Nat) : Array RatInterval := - Id.run do - if rows.isEmpty then - return zeroIntervals dim - let mut out : Array RatInterval := zeroIntervals dim - let r0 := rows[0]! - if r0.size = dim then - out := r0 - for r in rows do - if r.size = dim then - out := unionVecIntervals out r - return out - -private def layerNormRowApprox (row : Array RatInterval) (gamma beta : Array Rat) (eps : Rat) - (soundnessBits : Nat) : (Array RatInterval × Rat) := - if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then - (row, 0) - else - Id.run do - let μ := RatInterval.mean row - let varLB := RatInterval.varianceLowerBound row - let invσUpper : Rat := - if varLB ≤ 0 then - -- Sound fallback for IBP propagation: `1/σ ≤ 1/eps` (conservative, but rigorous). - layerNormOpBoundConservative 1 eps soundnessBits - else - layerNormOpBoundLocal 1 varLB eps soundnessBits - let mut out : Array RatInterval := Array.mkEmpty row.size - for i in [:row.size] do - let centered := RatInterval.sub row[i]! μ - let scaled := RatInterval.scale (gamma[i]! 
* invσUpper) centered - out := out.push (RatInterval.add scaled (RatInterval.const beta[i]!)) - return (out, varLB) - -private def minVarAcrossRows (rows : Array (Array RatInterval)) : Rat := - Id.run do - let mut best : Option Rat := none - for r in rows do - let v := RatInterval.varianceLowerBound r - best := some (match best with | none => v | some b => min b v) - best.getD 0 - -private def findLineIdxFrom - (lines : Array String) (start : Nat) (p : String → Bool) : Option Nat := - Id.run do - let mut i := start - while i < lines.size do - if p (lines[i]!.trim) then - return some i - i := i + 1 - return none - -private def skipUntil (lines : Array String) (start : Nat) (p : String → Bool) : Nat := - match findLineIdxFrom lines start p with - | some i => i - | none => lines.size - -private def skipBlankLines (lines : Array String) (start : Nat) : Nat := - Id.run do - let mut i := start - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - return i - -/-! -### Fast skipping without parsing - -For local SOUND certification we do not need `W_Q`, `W_K`, `b_Q`, or `b_K` numerically -(they don't affect the Jacobian bounds we certify in this streaming-only pass). - -Parsing decimals into `Rat` is expensive, so we skip these sections by **counting tokens** -instead of calling `parseRat`. --/ - -private def countWsTokens (s : String) : Nat := - Id.run do - let bytes := s.toUTF8 - let mut i : Nat := 0 - let mut inTok : Bool := false - let mut cnt : Nat := 0 - while i < bytes.size do - let b := bytes[i]! 
- let isWs : Bool := b = 32 || b = 9 -- ' ' or '\t' - if isWs then - inTok := false - else if !inTok then - inTok := true - cnt := cnt + 1 - i := i + 1 - return cnt - -private def consumeTokensSkipFast - (lines : Array String) (start : Nat) (numTokens : Nat) : Except String Nat := - Id.run do - let mut iLine := start - let mut remaining := numTokens - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while skipping tokens" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () - else - let c := countWsTokens line - if c ≥ remaining then - remaining := 0 - else - remaining := remaining - c - return .ok iLine - -private def consumeMatrixSkip - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String Nat := - match foldRatTokens lines start (rows * cols) () (fun _ _ => ()) with - | .error e => .error e - | .ok (_, next) => .ok next - -private def consumeMatrixSkipFast - (lines : Array String) - (start : Nat) - (rows cols : Nat) : Except String Nat := - consumeTokensSkipFast lines start (rows * cols) - -private def consumeVectorSkipFast - (lines : Array String) - (start : Nat) - (n : Nat) : Except String Nat := - consumeTokensSkipFast lines start n - -/-! -Streaming multiplication for row-major stored matrices. - -The `.nfpt` format stores matrices row-major with `rows` = input dimension and `cols` = output -dimension in the repo's row-vector convention: `y = x · W` where `W : rows×cols`. - -We compute `y` in a single pass over weights by accumulating contributions row-by-row: -for each input index `i`, parse the `i`-th row `w_{i,*}` and add `w_{i,j} * x[i]` into `y[j]`. -This never stores the matrix. 
--/ -private def consumeMatrixMulAndNormInf - (lines : Array String) - (start : Nat) - (rows cols : Nat) - (input : Array RatInterval) : Except String (Array RatInterval × Rat × Nat) := - Id.run do - if input.size ≠ rows then - return .error "input interval dimension mismatch" - let mut out : Array RatInterval := zeroIntervals cols - let mut iLine := start - let mut remaining := rows * cols - let mut idx : Nat := 0 - let mut curRowAbs : Rat := 0 - let mut maxRowAbs : Rat := 0 - while remaining > 0 do - if iLine ≥ lines.size then - return .error "unexpected end of file while reading matrix" - let line := lines[iLine]!.trim - iLine := iLine + 1 - if line.isEmpty then - pure () - else - let toks := line.splitOn " " |>.filter (· ≠ "") - for t in toks do - if remaining = 0 then - break - match parseRat t with - | .error e => return .error e - | .ok w => - let r := idx / cols - let c := idx % cols - curRowAbs := curRowAbs + ratAbs w - -- out[c] += w * input[r] - let term := RatInterval.scale w (input[r]!) - out := out.set! c (RatInterval.add (out[c]!) term) - idx := idx + 1 - remaining := remaining - 1 - if c + 1 = cols then - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - -- Account for a partial last row (should not happen if rows*cols consumed). - maxRowAbs := max maxRowAbs curRowAbs - return .ok (out, maxRowAbs, iLine) - -/-- Soundly compute conservative per-layer residual amplification constants from a `.nfpt` file. -/ -def certifyModelFileGlobal - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let actDerivBound := geluDerivBoundGlobal geluDerivTarget - let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray - -- Header - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => - i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? - | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? - | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some _ := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - let inputVarLowerMin? 
: Option Rat := none - -- Accumulators - let mut ln1GammaMax : Array Rat := Array.replicate L 1 - let mut ln1BetaMax : Array Rat := Array.replicate L 0 - let mut ln2GammaMax : Array Rat := Array.replicate L 1 - let mut attnValueCoeff : Array Rat := Array.replicate L 0 - let mut wqMax : Array Rat := Array.replicate L 0 - let mut wkMax : Array Rat := Array.replicate L 0 - let mut mlpWin : Array Rat := Array.replicate L 0 - let mut mlpWout : Array Rat := Array.replicate L 0 - let mut curLayer : Nat := 0 - -- Scan remaining sections - while i < lines.size do - let line := lines[i]!.trim - if line.startsWith "LAYER" then - let parts := line.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD 0 - i := i + 1 - else if line = "W_Q" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nq, next) => - if r < wqMax.size then - wqMax := wqMax.set! r (max wqMax[r]! nq) - i := next - else if line = "W_K" then - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nk, next) => - if r < wkMax.size then - wkMax := wkMax.set! r (max wkMax[r]! nk) - i := next - else if line = "W_V" then - -- Expect: modelDim × headDim - let r := curLayer - match consumeMatrixNormInf lines (i + 1) d dh with - | .error e => return .error e - | .ok (nv, next) => - -- Find W_O next by scanning forward (format guarantee: W_O follows eventually) - i := next - while i < lines.size && lines[i]!.trim ≠ "W_O" do - i := i + 1 - if !(i < lines.size) then - return .error "expected W_O after W_V" - match consumeMatrixNormInf lines (i + 1) dh d with - | .error e => return .error e - | .ok (no, next2) => - if r < attnValueCoeff.size then - attnValueCoeff := attnValueCoeff.set! r (attnValueCoeff[r]! 
+ (nv * no)) - i := next2 - else if line = "W_in" then - match consumeMatrixNormInf lines (i + 1) d dhid with - | .error e => return .error e - | .ok (n, next) => - if curLayer < mlpWin.size then - mlpWin := mlpWin.set! curLayer n - i := next - else if line = "W_out" then - match consumeMatrixNormInf lines (i + 1) dhid d with - | .error e => return .error e - | .ok (n, next) => - if curLayer < mlpWout.size then - mlpWout := mlpWout.set! curLayer n - i := next - else if line = "LN1_GAMMA" then - match consumeMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln1GammaMax.size then - ln1GammaMax := ln1GammaMax.set! curLayer m - i := next - else if line = "LN1_BETA" then - match consumeMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln1BetaMax.size then - ln1BetaMax := ln1BetaMax.set! curLayer m - i := next - else if line = "LN2_GAMMA" then - match consumeMaxAbs lines (i + 1) d with - | .error e => return .error e - | .ok (m, next) => - if curLayer < ln2GammaMax.size then - ln2GammaMax := ln2GammaMax.set! curLayer m - i := next - else - -- default: advance - i := i + 1 - -- Build layer reports - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - let ln1Max := ln1GammaMax[l]! - let ln1MaxAbsBeta := ln1BetaMax[l]! - let ln2Max := ln2GammaMax[l]! - let ln1Var? : Option Rat := if l = 0 then inputVarLowerMin? else none - let ln2Var? : Option Rat := none - let ln1Bound := - match ln1Var? with - | some v => layerNormOpBoundLocal ln1Max v eps soundnessBits - | none => layerNormOpBoundConservative ln1Max eps soundnessBits - let ln2Bound := - match ln2Var? 
with - | some v => layerNormOpBoundLocal ln2Max v eps soundnessBits - | none => layerNormOpBoundConservative ln2Max eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound d ln1Max ln1MaxAbsBeta - let attnValueCoeffLayer := attnValueCoeff[l]! - let attnPatternCoeff := - attnPatternCoeffBound n d dh ln1OutMaxAbsBound (wqMax[l]!) (wkMax[l]!) - attnValueCoeffLayer - let mlpCoeff := mlpWin[l]! * mlpWout[l]! - let mlpActDerivBound := actDerivBound - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin n softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnW := - ln1Bound * ((n : Rat) * attnValueCoeffLayer + softmaxBound * attnPatternCoeff) - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1Max - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2Max - ln1VarianceLowerBound? := ln1Var? - ln2VarianceLowerBound? := ln2Var? - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax[l]! - wkOpBoundMax := wkMax[l]! - attnValueCoeff := attnValueCoeffLayer - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := mlpWin[l]! - mlpWoutBound := mlpWout[l]! - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= inputPath?.map (·.toString) - inputDelta := inputDelta - eps := eps - seqLen := n - modelDim := d - headDim := dh - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -/-- Parse input `EMBEDDINGS` from an `.nfpt` file and return intervals `xᵢ ∈ [xᵢ-δ, xᵢ+δ]` -as an array of rows (`seqLen` rows, each of length `modelDim`). -/ -private def loadEmbeddingsIntervals - (path : System.FilePath) (seqLen modelDim : Nat) (delta : Rat) : - IO (Except String (Array (Array RatInterval))) := do - let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty input file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected input header '{headerTag}'" - i := i + 1 - -- Scan to EMBEDDINGS (optionally skipping TOKENS). 
- i := skipUntil lines i (fun s => s = "EMBEDDINGS") - if !(i < lines.size) then - return .error "Missing EMBEDDINGS section in input file" - i := i + 1 - let step := - fun (st : (Array (Array RatInterval) × Array RatInterval)) (x : Rat) => - let (rows, cur) := st - let cur := cur.push { lo := x - delta, hi := x + delta } - if cur.size = modelDim then - (rows.push cur, #[]) - else - (rows, cur) - match foldRatTokens lines i (seqLen * modelDim) (#[], #[]) step with - | .error e => return .error e - | .ok ((rows, cur), _) => - if cur.size ≠ 0 then - return .error "EMBEDDINGS parse ended mid-row" - if rows.size ≠ seqLen then - return .error s!"EMBEDDINGS length mismatch: expected {seqLen} rows, got {rows.size}" - return .ok rows - -private structure LayerNormParams where - gamma : Array Rat - beta : Array Rat - -private structure LayerNormParamsFixed where - gamma : Array Fixed10Interval - beta : Array Fixed10Interval - -private def intervalsFromScaled (xs : Array Int) (slack : Int) : Array Fixed10Interval := - xs.map (fun x => { lo := x - slack, hi := x + slack }) - -private def collectLayerNormParams - (lines : Array String) (L d : Nat) : - Except String (Array LayerNormParams × Array LayerNormParams) := - Id.run do - let defP : LayerNormParams := { gamma := Array.replicate d 1, beta := Array.replicate d 0 } - let mut ln1 : Array LayerNormParams := - Array.replicate L defP - let mut ln2 : Array LayerNormParams := - Array.replicate L defP - let mut i : Nat := 0 - let mut curLayer : Nat := 0 - while i < lines.size do - let s := lines[i]!.trim - if s.startsWith "LAYER" then - let parts := s.splitOn " " |>.filter (· ≠ "") - if parts.length >= 2 then - curLayer := (parts[1]!).toNat? |>.getD curLayer - i := i + 1 - else if s = "LN1_GAMMA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln1.getD curLayer defP - ln1 := ln1.set! 
curLayer { old with gamma := xs } - i := next - else if s = "LN1_BETA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln1.getD curLayer defP - ln1 := ln1.set! curLayer { old with beta := xs } - i := next - else if s = "LN2_GAMMA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln2.getD curLayer defP - ln2 := ln2.set! curLayer { old with gamma := xs } - i := next - else if s = "LN2_BETA" then - match consumeVector lines (i + 1) d with - | .error e => return .error e - | .ok (xs, next) => - if curLayer < L then - let old := ln2.getD curLayer defP - ln2 := ln2.set! curLayer { old with beta := xs } - i := next - else - i := i + 1 - return .ok (ln1, ln2) - -private def collectLayerNormParamsBinary - (path : System.FilePath) - (scalePow10 : Nat) - (slack : Int) : - IO - (Except String - (BinaryHeader × Array LayerNormParamsFixed × Array LayerNormParamsFixed)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut ln1 : Array LayerNormParamsFixed := Array.replicate hdr.numLayers defP - let mut ln2 : Array LayerNormParamsFixed := Array.replicate hdr.numLayers defP - for l in [:hdr.numLayers] do - for _h in [:hdr.numHeads] do - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← 
skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.headDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - let ln1GammaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln1Gamma ← - match ln1GammaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln1BetaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln1Beta ← - match ln1BetaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln2GammaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln2Gamma ← - match ln2GammaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - let ln2BetaE ← readScaledFloatArray h hdr.modelDim scalePow10 - let ln2Beta ← - match ln2BetaE with - | .error e => return .error e - | .ok xs => pure (intervalsFromScaled xs slack) - ln1 := ln1.set! l { gamma := ln1Gamma, beta := ln1Beta } - ln2 := ln2.set! l { gamma := ln2Gamma, beta := ln2Beta } - return .ok (hdr, ln1, ln2) - -/-! 
-## Cached fixed-point local certification (fast path) - -The original local path (`RatInterval` + `parseRat`) is mathematically rigorous but too slow for -large models because it performs gcd-based normalization on the hot path. - -We therefore prefer a cached fixed-point representation (`sound_cache/*.nfpc`) and run local IBP -in scaled-`Int` arithmetic with conservative outward rounding. --/ - -private def defaultFixedScalePow10 : Nat := 9 -private def fixedUlpSlack : Int := 1 - -private def scaleCfgOfPow10 (p : Nat) : Fixed10Cfg := { scalePow10 := p } - -private def ratCeilMulNat (x : Rat) (k : Nat) : Int := - if x ≤ 0 then - 0 - else - let num : Int := x.num - let den : Nat := x.den - let numK : Int := num * (Int.ofNat k) - let q := numK.ediv (Int.ofNat den) - let r := numK.emod (Int.ofNat den) - if r = 0 then q else q + 1 - -private def fixedMeanInterval (xs : Array Fixed10Interval) : Fixed10Interval := - if xs.isEmpty then - { lo := 0, hi := 0 } - else - Id.run do - let n : Nat := xs.size - let mut loSum : Int := 0 - let mut hiSum : Int := 0 - for x in xs do - loSum := loSum + x.lo - hiSum := hiSum + x.hi - let loμ := loSum.ediv (Int.ofNat n) - let hiμ := - let q := hiSum.ediv (Int.ofNat n) - let r := hiSum.emod (Int.ofNat n) - if r = 0 then q else q + 1 - { lo := loμ, hi := hiμ } - -private def fixedVarianceLowerBoundRange (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - if xs.size < 2 then - 0 - else - Id.run do - let n : Nat := xs.size - let nRat : Rat := (n : Nat) - let mut loMax : Int := xs[0]!.lo - let mut hiMin : Int := xs[0]!.hi - for x in xs do - loMax := max loMax x.lo - hiMin := min hiMin x.hi - let δInt : Int := max 0 (loMax - hiMin) - if δInt = 0 then - return 0 - let δRat : Rat := - Rat.normalize δInt cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let δSq : Rat := δRat * δRat - return δSq / ((2 : Rat) * nRat) - -private def 
fixedLayerNormRowApprox - (cfg : Fixed10Cfg) - (row : Array Fixed10Interval) - (gamma beta : Array Fixed10Interval) - (eps : Rat) - (soundnessBits : Nat) : - (Array Fixed10Interval × Rat) := - if row.size = 0 || gamma.size ≠ row.size || beta.size ≠ row.size then - (row, 0) - else - Id.run do - let μ := fixedMeanInterval row - let varLB := fixedVarianceLowerBoundRange cfg row - let invσUpper : Rat := - if varLB ≤ 0 then - layerNormOpBoundConservative 1 eps soundnessBits - else - layerNormOpBoundLocal 1 varLB eps soundnessBits - let invσUpperInt : Int := ratCeilMulNat invσUpper cfg.scaleNat - let invσFix : Fixed10Interval := { lo := invσUpperInt, hi := invσUpperInt } - let mut out : Array Fixed10Interval := Array.mkEmpty row.size - for i in [:row.size] do - let centered := Fixed10Interval.sub row[i]! μ - let coeff := Fixed10Interval.mul cfg gamma[i]! invσFix - let scaled := Fixed10Interval.mul cfg coeff centered - out := out.push (Fixed10Interval.add scaled beta[i]!) - return (out, varLB) - -private def readVecIntervals - (r : SoundCache.I32Reader) (n : Nat) (slack : Int) : - IO (Array Fixed10Interval × SoundCache.I32Reader) := do - let mut rr := r - let mut out : Array Fixed10Interval := Array.mkEmpty n - for _ in [:n] do - let (x, rr2) ← Nfp.Untrusted.SoundCacheIO.I32Reader.readI32 rr - rr := rr2 - out := out.push { lo := x - slack, hi := x + slack } - return (out, rr) - -private def readVecIntervalsBinary - (h : IO.FS.Handle) (n : Nat) (slack : Int) (scalePow10 : Nat) : - IO (Except String (Array Fixed10Interval)) := do - match ← readScaledFloatArray h n scalePow10 with - | .error e => return .error e - | .ok xs => return .ok (intervalsFromScaled xs slack) - -private def matMulIntervalsFromScaled - (cfg : Fixed10Cfg) - (slack : Int) - (rows cols : Nat) - (weights : Array Int) - (input : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if input.size ≠ rows || weights.size ≠ rows * cols then - return Array.replicate cols { lo := 0, hi := 0 } - let mut 
out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - for rowIdx in [:rows] do - let xi := input[rowIdx]! - for colIdx in [:cols] do - let idx := rowIdx * cols + colIdx - let w := weights[idx]! - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) term) - return out - -private def fixedDotInterval - (cfg : Fixed10Cfg) - (a b : Array Fixed10Interval) : Fixed10Interval := - if a.size = 0 || a.size ≠ b.size then - { lo := 0, hi := 0 } - else - Id.run do - let mut acc : Fixed10Interval := { lo := 0, hi := 0 } - for i in [:a.size] do - let term := Fixed10Interval.mul cfg a[i]! b[i]! - acc := Fixed10Interval.add acc term - return acc - -private def maxAbsVecFixed (xs : Array Fixed10Interval) : Int := - xs.foldl (fun acc x => max acc (Fixed10Interval.absUpper x)) 0 - -/-- Sum of per-coordinate centered absolute bounds (interval widths), as a `Rat`. -/ -private def centeredAbsSumFixed (cfg : Fixed10Cfg) (xs : Array Fixed10Interval) : Rat := - let sumWidth : Int := xs.foldl (fun acc x => acc + Fixed10Interval.centeredAbsBound x) 0 - Rat.normalize sumWidth cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - -private def addVecFixed (a b : Array Fixed10Interval) : Array Fixed10Interval := - Id.run do - if a.size ≠ b.size then - return a - let mut out := Array.mkEmpty a.size - for i in [:a.size] do - out := out.push (Fixed10Interval.add a[i]! b[i]!) 
- return out - -private def addVecFixedRows - (rows : Array (Array Fixed10Interval)) - (v : Array Fixed10Interval) : Array (Array Fixed10Interval) := - rows.map (fun row => addVecFixed row v) - -private def addRowsFixed - (rows : Array (Array Fixed10Interval)) - (adds : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - Id.run do - if rows.size ≠ adds.size then - return rows - let mut out : Array (Array Fixed10Interval) := Array.mkEmpty rows.size - for i in [:rows.size] do - out := out.push (addVecFixed rows[i]! adds[i]!) - return out - -private def mlpRowFromScaled - (cfg : Fixed10Cfg) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (row : Array Fixed10Interval) : Array Fixed10Interval := - let hidden0 := matMulIntervalsFromScaled cfg slack modelDim hiddenDim wIn row - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let mlpOut0 := matMulIntervalsFromScaled cfg slack hiddenDim modelDim wOut actHidden - addVecFixed mlpOut0 bOut - -private def mlpRowsFromScaled - (cfg : Fixed10Cfg) - (slack : Int) - (modelDim hiddenDim : Nat) - (wIn wOut : Array Int) - (bIn bOut : Array Fixed10Interval) - (rows : Array (Array Fixed10Interval)) : Array (Array Fixed10Interval) := - let useTasks := rows.size > 32 - if useTasks then - let tasks := rows.map (fun row => - Task.spawn (fun _ => mlpRowFromScaled cfg slack modelDim hiddenDim wIn wOut bIn bOut row)) - tasks.map (fun t => t.get) - else - rows.map (mlpRowFromScaled cfg slack modelDim hiddenDim wIn wOut bIn bOut) - -private def groupUnionRowsByToken - (rows : Array (Array Fixed10Interval)) - (tokens : Array Int) : Array (Array Fixed10Interval) := - Id.run do - if rows.size ≠ tokens.size then - return rows - let mut uniqTokens : Array Int := #[] - let mut uniqRows : Array (Array Fixed10Interval) := #[] - for i in [:rows.size] do - let tok := tokens[i]! - match uniqTokens.findIdx? 
(· == tok) with - | some idx => - let merged := Fixed10Interval.unionVec (uniqRows[idx]!) rows[i]! - uniqRows := uniqRows.set! idx merged - | none => - uniqTokens := uniqTokens.push tok - uniqRows := uniqRows.push rows[i]! - return uniqRows - -private def unionRowsFixed - (rows : Array (Array Fixed10Interval)) : Array Fixed10Interval := - if rows.isEmpty then - #[] - else - Id.run do - let mut out := rows[0]! - for i in [1:rows.size] do - let row := rows[i]! - if row.size = out.size then - for j in [:out.size] do - let cur := out[j]! - let r := row[j]! - out := out.set! j { lo := min cur.lo r.lo, hi := max cur.hi r.hi } - return out - -private def consumeMatrixMulAndNormInfFixed - (cfg : Fixed10Cfg) - (slack : Int) - (r : SoundCache.I32Reader) - (rows cols : Nat) - (input : Array Fixed10Interval) : - IO (Array Fixed10Interval × Rat × SoundCache.I32Reader) := do - if input.size ≠ rows then - return (Array.replicate cols { lo := 0, hi := 0 }, 0, r) - let mut rr := r - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut curRowAbs : Int := 0 - let mut maxRowAbs : Int := 0 - for rowIdx in [:rows] do - let xi := input[rowIdx]! - for colIdx in [:cols] do - let (w, rr2) ← Nfp.Untrusted.SoundCacheIO.I32Reader.readI32 rr - rr := rr2 - let wAbsBound : Int := (if w < 0 then -w else w) + slack - curRowAbs := curRowAbs + wAbsBound - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - let normInf : Rat := - Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - return (out, normInf, rr) - -private def consumeMatrixMulAndNormInfFixedBinary - (cfg : Fixed10Cfg) - (slack : Int) - (h : IO.FS.Handle) - (rows cols : Nat) - (input : Array Fixed10Interval) - (scalePow10 : Nat) : - IO (Except String (Array Fixed10Interval × Rat)) := do - if input.size ≠ rows then - match ← skipF64Array h (rows * cols) with - | .error e => return .error e - | .ok _ => return .ok (Array.replicate cols { lo := 0, hi := 0 }, 0) - match ← readScaledFloatArray h (rows * cols) scalePow10 with - | .error e => return .error e - | .ok vals => - let mut out : Array Fixed10Interval := Array.replicate cols { lo := 0, hi := 0 } - let mut curRowAbs : Int := 0 - let mut maxRowAbs : Int := 0 - for rowIdx in [:rows] do - let xi := input[rowIdx]! - for colIdx in [:cols] do - let idx := rowIdx * cols + colIdx - let w := vals[idx]! - let wAbsBound : Int := (if w < 0 then -w else w) + slack - curRowAbs := curRowAbs + wAbsBound - let wI : Fixed10Interval := { lo := w - slack, hi := w + slack } - let term := Fixed10Interval.mul cfg wI xi - out := out.set! colIdx (Fixed10Interval.add (out[colIdx]!) 
term) - maxRowAbs := max maxRowAbs curRowAbs - curRowAbs := 0 - let normInf : Rat := - Rat.normalize maxRowAbs cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - return .ok (out, normInf) - -private def loadEmbeddingsUnionFixed - (cfg : Fixed10Cfg) - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) : IO (Except String (Array Fixed10Interval × Nat)) := do - let deltaInt : Int := ratCeilMulNat delta cfg.scaleNat - let mut out : Array Fixed10Interval := Array.replicate expectedModelDim { lo := 0, hi := 0 } - let mut iCol : Nat := 0 - let mut remaining : Nat := 0 - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - -- Header: read until blank line. - let mut seqLen : Option Nat := none - let mut modelDim : Option Nat := none - let mut seenHeader : Bool := false - while true do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while reading input header" - let s := line.trim - if !seenHeader then - if s.startsWith "NFP_TEXT" then - seenHeader := true - continue - if s.isEmpty then - break - match parseHeaderLine s with - | none => pure () - | some (k, v) => - if k = "seq_len" then - seqLen := v.toNat? - else if k = "model_dim" then - modelDim := v.toNat? - else - pure () - let some n := seqLen | return .error "missing seq_len in input file" - let some d := modelDim | return .error "missing model_dim in input file" - if d ≠ expectedModelDim then - return .error s!"input model_dim mismatch (expected {expectedModelDim}, got {d})" - remaining := n * d - -- Scan to EMBEDDINGS marker. 
- let mut found : Bool := false - while !found do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while scanning for EMBEDDINGS" - if line.trim = "EMBEDDINGS" then - found := true - while remaining > 0 do - let line ← h.getLine - if line.isEmpty then - return .error "unexpected EOF while reading EMBEDDINGS" - let s := line.trim - if s.isEmpty then - continue - let bytes := s.toUTF8 - let mut j : Nat := 0 - while j < bytes.size && remaining > 0 do - while j < bytes.size && (bytes[j]! = 32 || bytes[j]! = 9) do - j := j + 1 - if j ≥ bytes.size then - break - let tokStart := j - while j < bytes.size && (bytes[j]! ≠ 32 && bytes[j]! ≠ 9) do - j := j + 1 - let tokStop := j - match parseFixed10Rounded cfg.scalePow10 bytes tokStart tokStop with - | .error e => return .error e - | .ok x => - let lo := x - fixedUlpSlack - deltaInt - let hi := x + fixedUlpSlack + deltaInt - let cur := out[iCol]! - out := out.set! iCol { lo := min cur.lo lo, hi := max cur.hi hi } - iCol := (iCol + 1) % expectedModelDim - remaining := remaining - 1 - return .ok (out, n) - -/-- Parse binary embeddings into a union-box of fixed-point intervals. 
-/ -private def loadEmbeddingsUnionFixedBinary - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array Fixed10Interval)) := do - if delta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - if hdr.modelDim ≠ expectedModelDim then - return .error - s!"input model_dim mismatch (expected {expectedModelDim}, got {hdr.modelDim})" - let total := hdr.seqLen * hdr.modelDim - let deltaScaled : Int := ratCeilMulNat delta (Nat.pow 10 scalePow10) - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← readScaledFloatArray h total scalePow10 with - | .error e => return .error e - | .ok scaled => - if total = 0 then - return .ok #[] - let mut out : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for col in [:hdr.modelDim] do - let v := scaled[col]! - out := out.push { lo := v - deltaScaled, hi := v + deltaScaled } - for i in [hdr.modelDim:total] do - let v := scaled[i]! - let col := i % hdr.modelDim - let lo := v - deltaScaled - let hi := v + deltaScaled - let cur := out[col]! - out := out.set! col { lo := min cur.lo lo, hi := max cur.hi hi } - return .ok out - -/-- Parse binary embeddings into per-position fixed-point intervals. 
-/ -private def loadEmbeddingsIntervalsBinary - (path : System.FilePath) - (expectedModelDim : Nat) - (delta : Rat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array (Array Fixed10Interval))) := do - if delta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - if hdr.modelDim ≠ expectedModelDim then - return .error - s!"input model_dim mismatch (expected {expectedModelDim}, got {hdr.modelDim})" - let total := hdr.seqLen * hdr.modelDim - let deltaScaled : Int := ratCeilMulNat delta (Nat.pow 10 scalePow10) - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← readScaledFloatArray h total scalePow10 with - | .error e => return .error e - | .ok scaled => - if total = 0 then - return .ok #[] - let mut rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for rowIdx in [:hdr.seqLen] do - let mut row : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for colIdx in [:hdr.modelDim] do - let idx := rowIdx * hdr.modelDim + colIdx - let v := scaled[idx]! 
- row := row.push { lo := v - deltaScaled, hi := v + deltaScaled } - rows := rows.push row - return .ok rows - -private def loadTokensBinary - (path : System.FilePath) : IO (Except String (BinaryHeader × Array Int)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← readI32Array h hdr.seqLen with - | .error e => return .error e - | .ok toks => return .ok (hdr, toks) - -private def skipToUnembeddingBinary - (h : IO.FS.Handle) (hdr : BinaryHeader) : IO (Except String Unit) := do - let action : ExceptT String IO Unit := do - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - for _l in [:hdr.numLayers] do - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.hiddenDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.hiddenDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.hiddenDim * hdr.modelDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - action.run - -/-- Compute local head output lower bounds at a specific query position (binary only). 
-/ -private def certifyHeadValueLowerBoundLocalBinaryAt - (path : System.FilePath) - (layerIdx headIdx queryPos coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (matchWeightLowerBound : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadValueLowerBoundPosCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadValueLowerBoundPosCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := 
ln1Rows.push ln1Out - if l = layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? 
: Option Int := none - for j in [:hdr.seqLen] do - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := matchWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadValueLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - coord := coord - matchWeightLowerBound := weightLB - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if cert.check then - return cert - throw "head value lower bound (pos) failed internal consistency checks" - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - 
let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - - -private def readUnembeddingColumnsBinary - (path : System.FilePath) - (tokenA tokenB : Nat) - (scalePow10 : Nat) : - IO (Except String (BinaryHeader × Array Int × Array Int)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let action : ExceptT String IO (BinaryHeader × Array Int × Array Int) := do - let hdr ← ExceptT.mk (readBinaryHeader h) - if tokenA ≥ hdr.vocabSize || tokenB ≥ hdr.vocabSize then - throw "token index out of range for unembedding" - if tokenA = tokenB then - throw "target and negative tokens must differ" - let _ ← ExceptT.mk (skipToUnembeddingBinary h hdr) - let loTok := min tokenA 
tokenB - let hiTok := max tokenA tokenB - let swapped : Bool := tokenA > tokenB - let mut colA : Array Int := Array.mkEmpty hdr.modelDim - let mut colB : Array Int := Array.mkEmpty hdr.modelDim - for _r in [:hdr.modelDim] do - let _ ← ExceptT.mk (skipF64Array h loTok) - let vLo ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hiTok - loTok - 1)) - let vHi ← ExceptT.mk (readScaledFloat h scalePow10) - let _ ← ExceptT.mk (skipF64Array h (hdr.vocabSize - hiTok - 1)) - if swapped then - colA := colA.push vHi - colB := colB.push vLo - else - colA := colA.push vLo - colB := colB.push vHi - return (hdr, colA, colB) - action.run - -private def readLogitDiffDirectionBinary - (path : System.FilePath) - (targetToken negativeToken : Nat) - (scalePow10 : Nat) - (slack : Int) : - IO (Except String (BinaryHeader × Array Fixed10Interval)) := do - let action : ExceptT String IO (BinaryHeader × Array Fixed10Interval) := do - let (hdr, colTarget, colNeg) ← - ExceptT.mk (readUnembeddingColumnsBinary path targetToken negativeToken scalePow10) - if colTarget.size ≠ hdr.modelDim || colNeg.size ≠ hdr.modelDim then - throw "unembedding column size mismatch" - let targetIntervals := intervalsFromScaled colTarget slack - let negIntervals := intervalsFromScaled colNeg slack - let mut dir : Array Fixed10Interval := Array.mkEmpty hdr.modelDim - for i in [:hdr.modelDim] do - dir := dir.push (Fixed10Interval.sub targetIntervals[i]! negIntervals[i]!) - return (hdr, dir) - action.run - -/-- Compute local head logit-difference lower bounds at a specific query position (binary only). 
-/ -private def certifyHeadLogitDiffLowerBoundLocalBinaryAt - (path : System.FilePath) - (layerIdx headIdx queryPos : Nat) - (targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (matchWeightLowerBound : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadLogitDiffLowerBoundPosCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadLogitDiffLowerBoundPosCert := do - let (hdrDir, direction) ← - ExceptT.mk (readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack) - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if hdr.modelDim ≠ hdrDir.modelDim then - throw "unembedding model_dim mismatch" - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do 
- let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? 
with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty hdr.seqLen - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row direction) - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vLo - | some m => some (min m vLo) - let matchLo ← - match matchLo? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let nonmatchLo := - match nonmatchLo? 
with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := matchWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadLogitDiffLowerBoundPosCert := { - layerIdx := layerIdx - headIdx := headIdx - queryPos := queryPos - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := weightLB - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := outputLB - } - if cert.check then - return cert - throw "head logit lower bound (pos) failed internal consistency checks" - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - 
let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -private def ensureSoundCache - (modelPath : System.FilePath) - (scalePow10 : Nat := defaultFixedScalePow10) : - IO (Except String (System.FilePath × SoundCache.Header)) := do - Nfp.Untrusted.SoundCacheIO.ensureCacheDir - let modelHash ← Nfp.Untrusted.SoundCacheIO.fnv1a64File modelPath - let mdata ← modelPath.metadata - let modelSize : UInt64 := mdata.byteSize - let cpath := SoundCache.cachePath modelPath modelHash scalePow10 - if !(← cpath.pathExists) then - match (← Nfp.Untrusted.SoundCacheIO.buildCacheFile modelPath cpath scalePow10) with - | .error e => return .error e - | .ok _ => pure () - let h ← IO.FS.Handle.mk cpath IO.FS.Mode.read - let hdr ← Nfp.Untrusted.SoundCacheIO.readHeader h - if hdr.modelHash ≠ modelHash then - return .error "sound cache hash mismatch" - if hdr.modelSize ≠ modelSize then - return .error "sound cache size mismatch" - return .ok (cpath, hdr) - -private def readWqWkMaxBinary - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array Rat × Array Rat)) := do - let h ← 
IO.FS.Handle.mk path IO.FS.Mode.read - match ← readBinaryHeader h with - | .error e => return .error e - | .ok hdr => - match ← skipI32Array h hdr.seqLen with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.seqLen * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let mut wqMax : Array Rat := Array.replicate hdr.numLayers 0 - let mut wkMax : Array Rat := Array.replicate hdr.numLayers 0 - for l in [:hdr.numLayers] do - let mut wqLayer : Rat := 0 - let mut wkLayer : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wqScaled ← - match wqScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - let wkScaledE ← readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10 - let wkScaled ← - match wkScaledE with - | .error e => return .error e - | .ok v => pure v - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.headDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.headDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.headDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - let wq := ratOfScaledInt scalePow10 wqScaled - let wk := ratOfScaledInt scalePow10 wkScaled - wqLayer := max wqLayer wq - wkLayer := max wkLayer wk - wqMax := wqMax.set! l wqLayer - wkMax := wkMax.set! 
l wkLayer - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.modelDim * hdr.hiddenDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.hiddenDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h (hdr.hiddenDim * hdr.modelDim) with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - match ← skipF64Array h hdr.modelDim with - | .error e => return .error e - | .ok _ => pure () - return .ok (wqMax, wkMax) - -/-- Local (input-dependent) certificate path using streaming interval propagation. - -This is conservative in two key ways to remain streaming/memory-safe: -- it uses a **union box** over tokens throughout (so we never hold `seqLen×modelDim` intervals), - which is sound (a superset) but can be looser than per-token tracking, -- it uses union boxes for attention/MLP linear maps to avoid `seqLen×hiddenDim` blowups. 
--/ -private def certifyModelFileLocalText - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let contents ← IO.FS.readFile path - let lines : Array String := (contents.splitOn "\n").toArray - -- Header - let mut i : Nat := 0 - while i < lines.size && lines[i]!.trim.isEmpty do - i := i + 1 - if !(i < lines.size) then - return .error "empty model file" - let headerTag := lines[i]!.trim - if !headerTag.startsWith "NFP_TEXT" then - return .error s!"unexpected header '{headerTag}'" - i := i + 1 - let mut numLayers : Option Nat := none - let mut numHeads : Option Nat := none - let mut modelDim : Option Nat := none - let mut headDim : Option Nat := none - let mut hiddenDim : Option Nat := none - let mut seqLen : Option Nat := none - while i < lines.size do - let line := lines[i]!.trim - if line.isEmpty then - i := i + 1 - break - match parseHeaderLine line with - | none => i := i + 1 - | some (k, v) => - match k with - | "num_layers" => numLayers := v.toNat? - | "num_heads" => numHeads := v.toNat? - | "model_dim" => modelDim := v.toNat? - | "head_dim" => headDim := v.toNat? - | "hidden_dim" => hiddenDim := v.toNat? - | "seq_len" => seqLen := v.toNat? - | _ => pure () - i := i + 1 - let some L := numLayers | return .error "missing num_layers" - let some H := numHeads | return .error "missing num_heads" - let some d := modelDim | return .error "missing model_dim" - let some dh := headDim | return .error "missing head_dim" - let some dhid := hiddenDim | return .error "missing hidden_dim" - let some n := seqLen | return .error "missing seq_len" - -- Prepass: collect LN parameters. 
- let (ln1Params, ln2Params) ← - match collectLayerNormParams lines L d with - | .error e => return .error e - | .ok x => pure x - let defLn : LayerNormParams := { gamma := Array.replicate d 1, beta := Array.replicate d 0 } - -- Input: per-token residual intervals. - let residual0 ← loadEmbeddingsIntervals inputPath n d inputDelta - match residual0 with - | .error e => return .error e - | .ok residualRows0 => - -- Use a single union box for all tokens (sound superset, much faster than - -- `seqLen×modelDim`). - let mut residualUnion := unionRows residualRows0 d - -- Start scanning at first layer marker. - let mut pos : Nat := skipUntil lines 0 (fun s => s.startsWith "LAYER") - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - -- Ensure we're at the next layer. - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - if pos ≥ lines.size then - return .error s!"unexpected end of file while scanning layer {l}" - pos := pos + 1 - -- LN1: compute per-row outputs (for union) and min variance LB (for Jacobian bound). - let p1 := ln1Params.getD l defLn - let (ln1Out, ln1VarLB) := - layerNormRowApprox residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma := maxAbsOfVector p1.gamma - let ln1MaxAbsBeta := maxAbsOfVector p1.beta - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := layerNormOutputMaxAbsBound d ln1MaxAbsGamma ln1MaxAbsBeta - let ln1Union := ln1Out - -- Attention (streaming): use union input box. 
- let mut attnUnion : Array RatInterval := zeroIntervals d - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:H] do - pos := skipBlankLines lines pos - if !(pos < lines.size) then - return .error "unexpected end of file while scanning HEAD" - if !(lines[pos]!.trim.startsWith "HEAD") then - return .error "expected HEAD marker before per-head matrices" - pos := pos + 1 - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_Q") then - return .error "missing W_Q" - match consumeMatrixNormInf lines (pos + 1) d dh with - | .error e => return .error e - | .ok (nq, next) => - wqMax := max wqMax nq - pos := next - -- Optional per-head Q bias (does not affect Jacobian, - -- but must be parsed to stay in sync). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_Q" then - match consumeVectorSkipFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_K") then - return .error "missing W_K" - match consumeMatrixNormInf lines (pos + 1) d dh with - | .error e => return .error e - | .ok (nk, next) => - wkMax := max wkMax nk - pos := next - -- Optional per-head K bias (does not affect Jacobian, - -- but must be parsed to stay in sync). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "b_K" then - match consumeVectorSkipFast lines (pos + 1) dh with - | .error e => return .error e - | .ok next => pos := next - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_V") then - return .error "missing W_V" - match consumeMatrixMulAndNormInf lines (pos + 1) d dh ln1Union with - | .error e => return .error e - | .ok (vHidden, _nWv, nextV) => - pos := nextV - -- Optional per-head V bias (affects forward activations / variance, - -- so we include it). 
- pos := skipBlankLines lines pos - let mut vHidden := vHidden - if pos < lines.size && lines[pos]!.trim = "b_V" then - match consumeVector lines (pos + 1) dh with - | .error e => return .error e - | .ok (bv, nextBv) => - pos := nextBv - vHidden := addConstVec vHidden bv - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_O") then - return .error "missing W_O" - let vCenteredOpBound := centeredAbsSum vHidden - match consumeMatrixMulAndNormInf lines (pos + 1) dh d vHidden with - | .error e => return .error e - | .ok (vOut, no, nextO) => - pos := nextO - attnUnion := addVecIntervals attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * no - -- Shared attention projection bias (affects forward activations / variance, - -- so we include it). - pos := skipBlankLines lines pos - if pos < lines.size && lines[pos]!.trim = "ATTN_BIAS" then - match consumeVector lines (pos + 1) d with - | .error e => return .error e - | .ok (bAttn, nextB) => - pos := nextB - attnUnion := addConstVec attnUnion bAttn - residualUnion := addVecIntervals residualUnion attnUnion - -- LN2: compute per-row outputs and min variance LB. - let p2 := ln2Params.getD l defLn - let (ln2Out, ln2VarLB) := - layerNormRowApprox residualUnion p2.gamma p2.beta eps soundnessBits - let ln2MaxAbsGamma := maxAbsOfVector p2.gamma - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - let ln2Union := ln2Out - -- MLP (streaming): W_in, b_in, W_out, b_out. 
- pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "MLP") then - return .error "missing MLP section" - pos := pos + 1 - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_in") then - return .error "missing W_in" - match consumeMatrixMulAndNormInf lines (pos + 1) d dhid ln2Union with - | .error e => return .error e - | .ok (hidden, nWin, nextWin) => - pos := nextWin - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_in") then - return .error "missing b_in" - match consumeVector lines (pos + 1) dhid with - | .error e => return .error e - | .ok (bin, nextBin) => - pos := nextBin - let hiddenB := addConstVec hidden bin - let mlpActDerivBound := maxGeluDerivBound geluDerivTarget hiddenB - let actHidden := hiddenB.map RatInterval.geluOverapprox - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "W_out") then - return .error "missing W_out" - match consumeMatrixMulAndNormInf lines (pos + 1) dhid d actHidden with - | .error e => return .error e - | .ok (mlpOut0, nWout, nextWout) => - pos := nextWout - pos := skipBlankLines lines pos - if !(pos < lines.size && lines[pos]!.trim = "b_out") then - return .error "missing b_out" - match consumeVector lines (pos + 1) d with - | .error e => return .error e - | .ok (bout, nextBout) => - pos := nextBout - let mlpOut := addConstVec mlpOut0 bout - residualUnion := addVecIntervals residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 - let softmaxIntervalBound := - softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin n softmaxMarginLowerBound - softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound n d dh ln1OutMaxAbsBound wqMax wkMax - attnValueCoeff - let attnW := - ln1Bound * - ((n : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) 
- let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - pos := skipUntil lines pos (fun s => s.startsWith "LAYER") - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := n - modelDim := d - headDim := dh - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def certifyModelFileLocal - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - -- Prefer cached fixed-point path; fall back to the (slow) Rat-based path on any cache error. - match (← ensureSoundCache path) with - | .error _ => - certifyModelFileLocalText path eps geluDerivTarget soundnessBits partitionDepth - inputPath inputDelta softmaxMarginLowerBound softmaxExpEffort - | .ok (cpath, hdr) => - let cfg : Fixed10Cfg := scaleCfgOfPow10 hdr.scalePow10.toNat - let slack : Int := fixedUlpSlack - let modelDim := hdr.modelDim.toNat - let headDim := hdr.headDim.toNat - let hiddenDim := hdr.hiddenDim.toNat - let L := hdr.numLayers.toNat - let H := hdr.numHeads.toNat - let wqWkE ← readWqWkMaxBinary path (scalePow10 := hdr.scalePow10.toNat) - let (wqMaxArr, wkMaxArr) ← - match wqWkE with - | .error e => return .error e - | .ok v => pure v - -- For now we read embeddings from the input `.nfpt` file and use a union box. 
- let residualUnionE ← loadEmbeddingsUnionFixed cfg inputPath modelDim inputDelta - match residualUnionE with - | .error e => return .error e - | .ok (residualUnion0, inputSeqLen) => - let mut residualUnion := residualUnion0 - -- Open cache and position reader after header. - let ch ← IO.FS.Handle.mk cpath IO.FS.Mode.read - let _ ← Nfp.Untrusted.SoundCacheIO.readHeader ch - let mut rr ← Nfp.Untrusted.SoundCacheIO.I32Reader.init ch - let mut layers : Array LayerAmplificationCert := Array.mkEmpty L - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:L] do - -- LN params from cache - let (ln1Gamma, rr1) ← readVecIntervals rr modelDim slack - let (ln1Beta, rr2) ← readVecIntervals rr1 modelDim slack - let (ln2Gamma, rr3) ← readVecIntervals rr2 modelDim slack - let (ln2Beta, rr4) ← readVecIntervals rr3 modelDim slack - rr := rr4 - -- LN1 - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg residualUnion ln1Gamma ln1Beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed ln1Gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1MaxAbsBeta : Rat := - Rat.normalize (maxAbsVecFixed ln1Beta) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := - layerNormOutputMaxAbsBound modelDim ln1MaxAbsGamma ln1MaxAbsBeta - -- Attention (streaming from cache) - let mut attnUnion : Array Fixed10Interval := - Array.replicate modelDim { lo := 0, hi := 0 } - let mut attnValueCoeff : Rat := 0 - for _h in [:H] do - let (vHidden0, _nWv, rrV) ← - consumeMatrixMulAndNormInfFixed cfg slack rr modelDim headDim ln1Out - rr := rrV - let (bV, rrBv) ← 
readVecIntervals rr headDim slack - rr := rrBv - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo, rrO) ← - consumeMatrixMulAndNormInfFixed cfg slack rr headDim modelDim vHidden - rr := rrO - attnUnion := addVecFixed attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * nWo - let (attnBias, rrB) ← readVecIntervals rr modelDim slack - rr := rrB - attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - -- LN2 - let (ln2Out, ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion ln2Gamma ln2Beta eps soundnessBits - let ln2MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed ln2Gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - -- MLP - let (hidden0, nWin, rrWin) ← - consumeMatrixMulAndNormInfFixed cfg slack rr modelDim hiddenDim ln2Out - rr := rrWin - let (bIn, rrBin) ← readVecIntervals rr hiddenDim slack - rr := rrBin - let hiddenB := addVecFixed hidden0 bIn - let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, nWout, rrWout) ← - consumeMatrixMulAndNormInfFixed cfg slack rr hiddenDim modelDim actHidden - rr := rrWout - let (bOut, rrBout) ← readVecIntervals rr modelDim slack - rr := rrBout - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin inputSeqLen softmaxMarginLowerBound - softmaxExpEffort - let softmaxBound := min 
softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound inputSeqLen modelDim headDim ln1OutMaxAbsBound - (wqMaxArr[l]!) (wkMaxArr[l]!) attnValueCoeff - let attnW := - ln1Bound * - ((inputSeqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C := attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMaxArr[l]! - wkOpBoundMax := wkMaxArr[l]! - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let cert : ModelCert := { - modelPath := path.toString - inputPath? 
:= some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := inputSeqLen - modelDim := modelDim - headDim := headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return .ok cert - return .error "sound certificate failed internal consistency checks" - -private def certifyModelFileLocalBinary - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (partitionDepth : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - if partitionDepth ≠ 0 then - return .error "partitionDepth > 0 not yet implemented" - let scalePow10 := defaultBinaryScalePow10 - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO ModelCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residualUnion0 ← - ExceptT.mk (loadEmbeddingsUnionFixedBinary inputPath hdr.modelDim inputDelta scalePow10) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residualUnion := residualUnion0 - let mut layers : Array LayerAmplificationCert := Array.mkEmpty hdr.numLayers - let mut totalAmp : Rat := 1 - let mut actDerivBoundMax : Rat := 0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg 
residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p1.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1MaxAbsBeta : Rat := - Rat.normalize (maxAbsVecFixed p1.beta) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let ln1OutMaxAbsBound := - layerNormOutputMaxAbsBound hdr.modelDim ln1MaxAbsGamma ln1MaxAbsBeta - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let mut attnValueCoeff : Rat := 0 - let mut wqMax : Rat := 0 - let mut wkMax : Rat := 0 - for _h in [:hdr.numHeads] do - let wqScaled ← - ExceptT.mk (readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10) - wqMax := max wqMax (ratOfScaledInt scalePow10 wqScaled) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wkScaled ← - ExceptT.mk (readMatrixNormInfScaled h hdr.modelDim hdr.headDim scalePow10) - wkMax := max wkMax (ratOfScaledInt scalePow10 wkScaled) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Out scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - attnValueCoeff := attnValueCoeff + vCenteredOpBound * nWo - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - 
attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - let p2 := ln2Params.getD l defP - let (ln2Out, ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion p2.gamma p2.beta eps soundnessBits - let ln2MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p2.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln2Bound := - if ln2VarLB > 0 then - layerNormOpBoundLocal ln2MaxAbsGamma ln2VarLB eps soundnessBits - else - layerNormOpBoundConservative ln2MaxAbsGamma eps soundnessBits - let (hidden0, nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Out scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let mlpActDerivBound := maxGeluDerivBoundFixed cfg geluDerivTarget hiddenB - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let softmaxProbLo : Rat := 0 - let softmaxProbHi : Rat := 1 - let softmaxIntervalBound := softmaxJacobianNormInfBound softmaxProbLo softmaxProbHi - let softmaxMarginBound := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen softmaxMarginLowerBound softmaxExpEffort - let softmaxBound := min softmaxIntervalBound softmaxMarginBound - let attnPatternCoeff := - attnPatternCoeffBound hdr.seqLen hdr.modelDim hdr.headDim ln1OutMaxAbsBound - wqMax wkMax attnValueCoeff - let attnW := - ln1Bound * - ((hdr.seqLen : Rat) * attnValueCoeff + softmaxBound * attnPatternCoeff) - let mlpCoeff := nWin * nWout - let mlpW := ln2Bound * (mlpCoeff * mlpActDerivBound) - let C 
:= attnW + mlpW + attnW * mlpW - layers := layers.push { - layerIdx := l - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1MaxAbsBeta := ln1MaxAbsBeta - ln2MaxAbsGamma := ln2MaxAbsGamma - ln1VarianceLowerBound? := some ln1VarLB - ln2VarianceLowerBound? := some ln2VarLB - ln1Bound := ln1Bound - ln2Bound := ln2Bound - ln1OutMaxAbsBound := ln1OutMaxAbsBound - softmaxProbLo := softmaxProbLo - softmaxProbHi := softmaxProbHi - softmaxMarginLowerBound := softmaxMarginLowerBound - softmaxExpEffort := softmaxExpEffort - softmaxJacobianNormInfUpperBound := softmaxBound - wqOpBoundMax := wqMax - wkOpBoundMax := wkMax - attnValueCoeff := attnValueCoeff - attnPatternCoeff := attnPatternCoeff - mlpCoeff := mlpCoeff - mlpWinBound := nWin - mlpWoutBound := nWout - mlpActDerivBound := mlpActDerivBound - attnJacBound := attnW - mlpJacBound := mlpW - C := C - } - totalAmp := totalAmp * (1 + C) - actDerivBoundMax := max actDerivBoundMax mlpActDerivBound - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let cert : ModelCert := { - modelPath := path.toString - inputPath? := some inputPath.toString - inputDelta := inputDelta - eps := eps - seqLen := hdr.seqLen - modelDim := hdr.modelDim - headDim := hdr.headDim - soundnessBits := soundnessBits - geluDerivTarget := geluDerivTarget - actDerivBound := actDerivBoundMax - softmaxJacobianNormInfWorst := softmaxJacobianNormInfWorst - layers := layers - totalAmplificationFactor := totalAmp - } - if cert.check then - return cert - throw "sound certificate failed internal consistency checks" - action.run - -/-- Compute local per-head attention contribution bounds from a binary `.nfpt`. 
-/ -private def certifyHeadBoundsLocalBinary - (path : System.FilePath) - (eps : Rat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (soundnessBits : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadLocalContributionCert)) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO (Array HeadLocalContributionCert) := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - let residualUnion0 ← - ExceptT.mk (loadEmbeddingsUnionFixedBinary inputPath hdr.modelDim inputDelta scalePow10) - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residualUnion := residualUnion0 - let mut heads : Array HeadLocalContributionCert := - Array.mkEmpty (hdr.numLayers * hdr.numHeads) - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let (ln1Out, ln1VarLB) := - fixedLayerNormRowApprox cfg residualUnion p1.gamma p1.beta eps soundnessBits - let ln1MaxAbsGamma : Rat := - Rat.normalize (maxAbsVecFixed p1.gamma) cfg.scaleNat (den_nz := by - have h10pos : (0 : Nat) < 10 := by decide - exact Nat.ne_of_gt (Nat.pow_pos (n := cfg.scalePow10) h10pos)) - let ln1Bound := - if ln1VarLB > 0 then - layerNormOpBoundLocal ln1MaxAbsGamma ln1VarLB eps soundnessBits - else - layerNormOpBoundConservative ln1MaxAbsGamma eps soundnessBits - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for hIdx in [:hdr.numHeads] do - let wqScaledE ← - ExceptT.mk (readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10) - let wqOp := ratOfScaledNat scalePow10 wqScaledE - 
let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wkScaledE ← - ExceptT.mk (readMatrixOpBoundScaled h hdr.modelDim hdr.headDim scalePow10) - let wkOp := ratOfScaledNat scalePow10 wkScaledE - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Out scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let vCenteredOpBound := centeredAbsSumFixed cfg vHidden - let (vOut, nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnW := ln1Bound * softmaxJacobianNormInfWorst * vCenteredOpBound * nWo - let cert : HeadLocalContributionCert := { - layerIdx := l - headIdx := hIdx - soundnessBits := soundnessBits - ln1MaxAbsGamma := ln1MaxAbsGamma - ln1VarianceLowerBound := ln1VarLB - ln1Bound := ln1Bound - wqOpBound := wqOp - wkOpBound := wkOp - wvOpBound := vCenteredOpBound - woOpBound := nWo - qkFactorBound := wqOp * wkOp - attnJacBound := attnW - } - if cert.check eps then - heads := heads.push cert - else - throw "local head contribution certificate failed internal checks" - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residualUnion := addVecFixed residualUnion attnUnion - let p2 := ln2Params.getD l defP - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg residualUnion p2.gamma p2.beta eps soundnessBits - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Out scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk 
(consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residualUnion := addVecFixed residualUnion mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - return heads - action.run - -/-- Compute local attention pattern bounds for a specific binary head. -/ -private def certifyHeadPatternLocalBinary - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadPatternCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadPatternCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * 
hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let mut wq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - pure () - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = headIdx then - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? := some wk - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - pure () - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let wk ← - match wk? with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - qRows := qRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row) - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) - let mut minTargetLower? 
: Option Int := none - let mut maxOtherUpper? : Option Int := none - let mut minTargetCount? : Option Nat := none - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let qRow := qRows[i]! - let mut targetLower? : Option Int := none - let mut targetMaxLower? : Option Int := none - let mut maxOtherUpperRow? : Option Int := none - let mut targetCount : Nat := 0 - for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - targetCount := targetCount + 1 - targetLower? := - match targetLower? with - | none => some dot.lo - | some m => some (min m dot.lo) - targetMaxLower? := - match targetMaxLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - let cur := dot.hi - maxOtherUpperRow? := - match maxOtherUpperRow? with - | none => some cur - | some m => some (max m cur) - let targetLowerRow? := - if tightPattern then targetMaxLower? else targetLower? - match targetLowerRow? with - | none => pure () - | some targetLower => - let maxOtherUpperRow := - match maxOtherUpperRow? with - | none => targetLower - | some v => v - minTargetLower? := - match minTargetLower? with - | none => some targetLower - | some m => some (min m targetLower) - maxOtherUpper? := - match maxOtherUpper? with - | none => some maxOtherUpperRow - | some m => some (max m maxOtherUpperRow) - minTargetCount? := - match minTargetCount? with - | none => some targetCount - | some m => some (min m targetCount) - let minTargetLower ← - match minTargetLower? with - | none => throw "no valid target positions for the requested offset" - | some v => pure v - let minTargetCount : Nat := - match minTargetCount? 
with - | none => 0 - | some v => v - let targetCountLB : Nat := - if tightPattern then (if minTargetCount > 0 then 1 else 0) else minTargetCount - let maxOtherUpper := - match maxOtherUpper? with - | none => minTargetLower - | some v => v - let marginInt : Int := minTargetLower - maxOtherUpper - let targetLower := ratOfScaledInt scalePow10 minTargetLower - let otherUpper := ratOfScaledInt scalePow10 maxOtherUpper - let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxTargetWeightLowerBound hdr.seqLen targetCountLB margin softmaxExpEffort - let cert : HeadPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - targetOffset := targetOffset - targetCountLowerBound := targetCountLB - targetLogitLowerBound := targetLower - otherLogitUpperBound := otherUpper - marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort - targetWeightLowerBound := weightLB - } - if cert.check then - return cert - throw "head pattern certificate failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := 
matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 
bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg 
slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head best-match pattern bounds for a specific `.nfpt` head (binary only). -/ -private def certifyHeadPatternBestMatchLocalBinary - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadBestMatchPatternCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadBestMatchPatternCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let queryPos : Nat := - match queryPos? 
with - | some q => q - | none => - if hdr.seqLen = 0 then 0 else hdr.seqLen - 1 - if queryPos ≥ hdr.seqLen then - throw s!"queryPos {queryPos} out of range" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - let mut wq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? 
:= some wk - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let wk ← - match wk? with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let qRow := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq (ln1Rows[queryPos]!) - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - let bestMatchLower ← - match bestMatchLower? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let bestNonmatchUpper := - match bestNonmatchUpper? 
with - | none => bestMatchLower - | some v => v - let marginInt : Int := bestMatchLower - bestNonmatchUpper - let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower - let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper - let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxMaxProbLowerBound hdr.seqLen margin softmaxExpEffort - let softmaxJacobianUB : Rat := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen margin softmaxExpEffort - let cert : HeadBestMatchPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - queryPos := queryPos - targetOffset := targetOffset - targetToken := targetTok - bestMatchLogitLowerBound := bestMatchLowerRat - bestNonmatchLogitUpperBound := bestNonmatchUpperRat - marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort - bestMatchWeightLowerBound := weightLB - softmaxJacobianNormInfUpperBound := softmaxJacobianUB - } - if cert.check then - return cert - throw "best-match head pattern certificate failed internal consistency checks" - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in 
groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * 
hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head best-match pattern bounds for all valid query positions (binary only). 
-/ -private def certifyHeadPatternBestMatchLocalBinarySweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (targetOffset : Int) - (maxSeqLen : Nat) - (tightPattern : Bool) - (tightPatternLayers : Nat) - (perRowPatternLayers : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO (Array HeadBestMatchPatternCert) := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if layerIdx ≥ hdr.numLayers then - throw s!"layer index {layerIdx} out of range" - if headIdx ≥ hdr.numHeads then - throw s!"head index {headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = layerIdx then - 
let mut wq? : Option (Array Int) := none - let mut wk? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - if hIdx = headIdx then - let wq ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wq? := some wq - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wk ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wk? := some wk - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wq ← - match wq? with - | none => throw "missing W_Q for requested head" - | some xs => pure xs - let wk ← - match wk? 
with - | none => throw "missing W_K for requested head" - | some xs => pure xs - let mut qRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - let mut kRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - qRows := qRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wq row) - kRows := kRows.push (matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wk row) - let validPositions : Array Nat := Id.run do - let mut out : Array Nat := Array.mkEmpty hdr.seqLen - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - out := out.push i - out - if validPositions.isEmpty then - throw "no valid query positions for the requested offset" - let computeCert : Nat → Except String HeadBestMatchPatternCert := fun queryPos => do - let ti : Int := (Int.ofNat queryPos) + targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - throw "query position has no valid target offset" - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let qRow := qRows[queryPos]! - let mut bestMatchLower? : Option Int := none - let mut bestNonmatchUpper? : Option Int := none - for j in [:hdr.seqLen] do - let dot := fixedDotInterval cfg qRow (kRows[j]!) - if tokens[j]! = targetTok then - bestMatchLower? := - match bestMatchLower? with - | none => some dot.lo - | some m => some (max m dot.lo) - else - bestNonmatchUpper? := - match bestNonmatchUpper? with - | none => some dot.hi - | some m => some (max m dot.hi) - let bestMatchLower ← - match bestMatchLower? with - | none => throw "no matching keys for the requested offset" - | some v => pure v - let bestNonmatchUpper := - match bestNonmatchUpper? 
with - | none => bestMatchLower - | some v => v - let marginInt : Int := bestMatchLower - bestNonmatchUpper - let bestMatchLowerRat := ratOfScaledInt scalePow10 bestMatchLower - let bestNonmatchUpperRat := ratOfScaledInt scalePow10 bestNonmatchUpper - let margin := ratOfScaledInt scalePow10 marginInt - let weightLB : Rat := - softmaxMaxProbLowerBound hdr.seqLen margin softmaxExpEffort - let softmaxJacobianUB : Rat := - softmaxJacobianNormInfBoundFromMargin hdr.seqLen margin softmaxExpEffort - let cert : HeadBestMatchPatternCert := { - layerIdx := layerIdx - headIdx := headIdx - seqLen := hdr.seqLen - queryPos := queryPos - targetOffset := targetOffset - targetToken := targetTok - bestMatchLogitLowerBound := bestMatchLowerRat - bestNonmatchLogitUpperBound := bestNonmatchUpperRat - marginLowerBound := margin - softmaxExpEffort := softmaxExpEffort - bestMatchWeightLowerBound := weightLB - softmaxJacobianNormInfUpperBound := softmaxJacobianUB - } - if cert.check then - return cert - throw "best-match head pattern certificate failed internal consistency checks" - let useTasks := validPositions.size > 32 - let mut certs : Array HeadBestMatchPatternCert := Array.mkEmpty validPositions.size - if useTasks then - let tasks := validPositions.map (fun i => - Task.spawn (fun _ => computeCert i)) - for t in tasks do - match t.get with - | .ok cert => certs := certs.push cert - | .error e => throw e - else - for i in validPositions do - match computeCert i with - | .ok cert => certs := certs.push cert - | .error e => throw e - return certs - else - let tightLayers : Nat := - if tightPattern then Nat.max 1 tightPatternLayers else 0 - if tightLayers > 0 && layerIdx ≤ l + tightLayers then - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - let groupRows := groupUnionRowsByToken ln1Rows tokens - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h 
hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let wv ← ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let wo ← ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty groupRows.size - for row in groupRows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bV - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let vUnion := unionRowsFixed vOutRows - attnUnion := addVecFixed attnUnion vUnion - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals 
:= addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let perRowLayers : Nat := perRowPatternLayers - if perRowLayers > 0 && layerIdx ≤ l + perRowLayers then - let wIn ← - ExceptT.mk <| readScaledFloatArray h (hdr.modelDim * hdr.hiddenDim) scalePow10 - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let wOut ← - ExceptT.mk <| readScaledFloatArray h (hdr.hiddenDim * hdr.modelDim) scalePow10 - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpRows := - mlpRowsFromScaled cfg slack hdr.modelDim hdr.hiddenDim wIn wOut bIn bOut ln2Rows - residuals := addRowsFixed residuals mlpRows - else - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head output lower bounds for a single coordinate (binary only). 
-/ -private def certifyHeadValueLowerBoundLocalBinary - (path : System.FilePath) - (pattern : HeadPatternCert) - (coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadValueLowerBoundCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadValueLowerBoundCert := do - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if pattern.layerIdx ≥ hdr.numLayers then - throw s!"layer index {pattern.layerIdx} out of range" - if pattern.headIdx ≥ hdr.numHeads then - throw s!"head index {pattern.headIdx} out of range" - if coord ≥ hdr.modelDim then - throw s!"coord index {coord} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - if pattern.seqLen ≠ hdr.seqLen then - throw "pattern seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = 
pattern.layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = pattern.headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut minMatchLo? : Option Int := none - let mut minNonmatchLo? : Option Int := none - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + pattern.targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? 
: Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - let row := vOutRows[j]! - let vCoord := row[coord]!.lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vCoord - | some m => some (min m vCoord) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vCoord - | some m => some (min m vCoord) - let matchLo := - match matchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - minMatchLo? := - match minMatchLo? with - | none => some matchLo - | some m => some (min m matchLo) - minNonmatchLo? := - match minNonmatchLo? with - | none => some nonmatchLo - | some m => some (min m nonmatchLo) - let matchLo := - match minMatchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match minNonmatchLo? with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := pattern.targetWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadValueLowerBoundCert := { - layerIdx := pattern.layerIdx - headIdx := pattern.headIdx - coord := coord - matchWeightLowerBound := weightLB - matchCoordLowerBound := matchLoRat - nonmatchCoordLowerBound := nonmatchLoRat - outputCoordLowerBound := outputLB - } - if cert.check then - return cert - throw "head value lower bound failed internal consistency checks" - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk 
(consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Compute local head logit-difference lower bounds for a specific head (binary only). 
-/ -private def certifyHeadLogitDiffLowerBoundLocalBinary - (path : System.FilePath) - (pattern : HeadPatternCert) - (targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath : System.FilePath) - (inputDelta : Rat) - (maxSeqLen : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - let cfg : Fixed10Cfg := scaleCfgOfPow10 scalePow10 - let slack : Int := fixedUlpSlack - let action : ExceptT String IO HeadLogitDiffLowerBoundCert := do - let (hdrDir, direction) ← - ExceptT.mk (readLogitDiffDirectionBinary path targetToken negativeToken scalePow10 slack) - let (hdr, ln1Params, ln2Params) ← - ExceptT.mk (collectLayerNormParamsBinary path scalePow10 slack) - if hdr.modelDim ≠ hdrDir.modelDim then - throw "unembedding model_dim mismatch" - if pattern.layerIdx ≥ hdr.numLayers then - throw s!"layer index {pattern.layerIdx} out of range" - if pattern.headIdx ≥ hdr.numHeads then - throw s!"head index {pattern.headIdx} out of range" - if hdr.seqLen > maxSeqLen then - throw s!"seq_len {hdr.seqLen} exceeds maxSeqLen {maxSeqLen}" - if direction.size ≠ hdr.modelDim then - throw "logit direction size mismatch" - let residuals0 ← - ExceptT.mk - (loadEmbeddingsIntervalsBinary inputPath hdr.modelDim inputDelta scalePow10) - let (hdrTok, tokens) ← ExceptT.mk (loadTokensBinary inputPath) - if hdrTok.seqLen ≠ hdr.seqLen then - throw "token/embedding seq_len mismatch" - if pattern.seqLen ≠ hdr.seqLen then - throw "pattern seq_len mismatch" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let _ ← ExceptT.mk (readBinaryHeader h) - let _ ← ExceptT.mk (skipI32Array h hdr.seqLen) - let _ ← ExceptT.mk (skipF64Array h (hdr.seqLen * hdr.modelDim)) - let defP : LayerNormParamsFixed := { - gamma := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - beta := Array.replicate hdr.modelDim { lo := 0, hi := 0 } - } - let mut residuals := residuals0 - for l in [:hdr.numLayers] do - let p1 := ln1Params.getD l defP - let 
mut ln1Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln1Out, _ln1VarLB) := - fixedLayerNormRowApprox cfg row p1.gamma p1.beta eps soundnessBits - ln1Rows := ln1Rows.push ln1Out - if l = pattern.layerIdx then - let mut wv? : Option (Array Int) := none - let mut bv? : Option (Array Int) := none - let mut wo? : Option (Array Int) := none - for hIdx in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - if hIdx = pattern.headIdx then - let wv ← - ExceptT.mk <| - readScaledFloatArray h (hdr.modelDim * hdr.headDim) scalePow10 - wv? := some wv - let bV ← - ExceptT.mk <| - readScaledFloatArray h hdr.headDim scalePow10 - bv? := some bV - let wo ← - ExceptT.mk <| - readScaledFloatArray h (hdr.headDim * hdr.modelDim) scalePow10 - wo? := some wo - else - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.headDim * hdr.modelDim)) - let wv ← - match wv? with - | none => throw "missing W_V for requested head" - | some xs => pure xs - let bV ← - match bv? with - | none => throw "missing b_V for requested head" - | some xs => pure xs - let wo ← - match wo? 
with - | none => throw "missing W_O for requested head" - | some xs => pure xs - let bVIntervals := intervalsFromScaled bV slack - let mut vOutRows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in ln1Rows do - let vHidden0 := matMulIntervalsFromScaled cfg slack - hdr.modelDim hdr.headDim wv row - let vHidden := addVecFixed vHidden0 bVIntervals - let vOut := matMulIntervalsFromScaled cfg slack - hdr.headDim hdr.modelDim wo vHidden - vOutRows := vOutRows.push vOut - let mut vDotRows : Array Fixed10Interval := Array.mkEmpty hdr.seqLen - for row in vOutRows do - vDotRows := vDotRows.push (fixedDotInterval cfg row direction) - let mut minMatchLo? : Option Int := none - let mut minNonmatchLo? : Option Int := none - for i in [:hdr.seqLen] do - let ti : Int := (Int.ofNat i) + pattern.targetOffset - if ti < 0 || ti ≥ (Int.ofNat hdr.seqLen) then - pure () - else - let tIdx : Nat := Int.toNat ti - let targetTok := tokens[tIdx]! - let mut matchLo? : Option Int := none - let mut nonmatchLo? : Option Int := none - for j in [:hdr.seqLen] do - let vLo := (vDotRows[j]!).lo - if tokens[j]! = targetTok then - matchLo? := - match matchLo? with - | none => some vLo - | some m => some (min m vLo) - else - nonmatchLo? := - match nonmatchLo? with - | none => some vLo - | some m => some (min m vLo) - let matchLo := - match matchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match nonmatchLo? with - | none => matchLo - | some v => v - minMatchLo? := - match minMatchLo? with - | none => some matchLo - | some m => some (min m matchLo) - minNonmatchLo? := - match minNonmatchLo? with - | none => some nonmatchLo - | some m => some (min m nonmatchLo) - let matchLo := - match minMatchLo? with - | none => 0 - | some v => v - let nonmatchLo := - match minNonmatchLo? 
with - | none => matchLo - | some v => v - let matchLoRat := ratOfScaledInt scalePow10 matchLo - let nonmatchLoRat := ratOfScaledInt scalePow10 nonmatchLo - let weightLB := pattern.targetWeightLowerBound - let outputLB := mixLowerBound weightLB matchLoRat nonmatchLoRat - let cert : HeadLogitDiffLowerBoundCert := { - layerIdx := pattern.layerIdx - headIdx := pattern.headIdx - targetToken := targetToken - negativeToken := negativeToken - matchWeightLowerBound := weightLB - matchLogitLowerBound := matchLoRat - nonmatchLogitLowerBound := nonmatchLoRat - logitDiffLowerBound := outputLB - } - if cert.check then - return cert - throw "head logit lower bound failed internal consistency checks" - else - let ln1Union := unionRowsFixed ln1Rows - let mut attnUnion : Array Fixed10Interval := - Array.replicate hdr.modelDim { lo := 0, hi := 0 } - for _h in [:hdr.numHeads] do - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let _ ← ExceptT.mk (skipF64Array h (hdr.modelDim * hdr.headDim)) - let _ ← ExceptT.mk (skipF64Array h hdr.headDim) - let (vHidden0, _nWv) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.headDim ln1Union scalePow10) - let bV ← ExceptT.mk (readVecIntervalsBinary h hdr.headDim slack scalePow10) - let vHidden := addVecFixed vHidden0 bV - let (vOut, _nWo) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.headDim hdr.modelDim vHidden scalePow10) - attnUnion := addVecFixed attnUnion vOut - let attnBias ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - attnUnion := addVecFixed attnUnion attnBias - residuals := addVecFixedRows residuals attnUnion - let p2 := ln2Params.getD l defP - let mut ln2Rows : Array (Array Fixed10Interval) := Array.mkEmpty hdr.seqLen - for row in residuals do - let (ln2Out, _ln2VarLB) := - fixedLayerNormRowApprox cfg row p2.gamma p2.beta eps soundnessBits - ln2Rows := ln2Rows.push ln2Out - let 
ln2Union := unionRowsFixed ln2Rows - let (hidden0, _nWin) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.modelDim hdr.hiddenDim ln2Union scalePow10) - let bIn ← ExceptT.mk (readVecIntervalsBinary h hdr.hiddenDim slack scalePow10) - let hiddenB := addVecFixed hidden0 bIn - let actHidden := hiddenB.map Fixed10Interval.geluOverapprox - let (mlpOut0, _nWout) ← - ExceptT.mk (consumeMatrixMulAndNormInfFixedBinary cfg slack h - hdr.hiddenDim hdr.modelDim actHidden scalePow10) - let bOut ← ExceptT.mk (readVecIntervalsBinary h hdr.modelDim slack scalePow10) - let mlpOut := addVecFixed mlpOut0 bOut - residuals := addVecFixedRows residuals mlpOut - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - let _ ← ExceptT.mk (skipF64Array h hdr.modelDim) - throw "target layer not reached" - action.run - -/-- Soundly compute certification bounds from a `.nfpt` model file. - -If an input is provided via `inputPath?`, the certificate uses streaming rational IBP to obtain -local (input-dependent) LayerNorm variance lower bounds at every layer. -Otherwise it falls back to the weight-only global certificate. --/ -def certifyModelFile - (path : System.FilePath) - (eps : Rat) - (geluDerivTarget : GeluDerivTarget) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (partitionDepth : Nat := 0) - (softmaxMarginLowerBound : Rat := 0) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : IO (Except String ModelCert) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - if inputDelta < 0 then - return .error "delta must be nonnegative" - match inputPath? 
with - | none => - if inputDelta = 0 then - certifyModelFileGlobalBinary path eps geluDerivTarget soundnessBits partitionDepth - softmaxMarginLowerBound softmaxExpEffort - else - certifyModelFileLocalBinary path eps geluDerivTarget soundnessBits partitionDepth - path inputDelta softmaxMarginLowerBound softmaxExpEffort - | some ip => - certifyModelFileLocalBinary path eps geluDerivTarget soundnessBits partitionDepth - ip inputDelta softmaxMarginLowerBound softmaxExpEffort - else - match inputPath? with - | none => - certifyModelFileGlobal path eps geluDerivTarget soundnessBits - (inputPath? := none) (inputDelta := inputDelta) (partitionDepth := partitionDepth) - (softmaxMarginLowerBound := softmaxMarginLowerBound) - (softmaxExpEffort := softmaxExpEffort) - | some ip => - if inputDelta < 0 then - return .error "delta must be nonnegative" - certifyModelFileLocal path eps geluDerivTarget soundnessBits partitionDepth ip inputDelta - softmaxMarginLowerBound softmaxExpEffort - -/-- Compute weight-only per-head contribution bounds for a `.nfpt` model file. -/ -def certifyHeadBounds - (path : System.FilePath) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadContributionCert)) := do - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - certifyHeadBoundsBinary path scalePow10 - else - return .error "head contribution bounds require NFP_BINARY_V1" - -/-- Compute local per-head attention contribution bounds for a `.nfpt` model file. -/ -def certifyHeadBoundsLocal - (path : System.FilePath) - (eps : Rat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (soundnessBits : Nat) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String (Array HeadLocalContributionCert)) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadBoundsLocalBinary path eps inputPath inputDelta soundnessBits scalePow10 - else - return .error "local head contribution bounds require NFP_BINARY_V1" - -/-- Compute local attention pattern bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadPatternLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadPatternCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute local best-match pattern bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadPatternBestMatchLocal - (path : System.FilePath) - (layerIdx headIdx : Nat) - (queryPos? : Option Nat := none) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String HeadBestMatchPatternCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternBestMatchLocalBinary path layerIdx headIdx queryPos? eps soundnessBits - inputPath - inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute local best-match pattern bounds for all valid query positions (binary only). -/ -def certifyHeadPatternBestMatchLocalSweep - (path : System.FilePath) - (layerIdx headIdx : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String (Array HeadBestMatchPatternCert)) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - certifyHeadPatternBestMatchLocalBinarySweep path layerIdx headIdx eps soundnessBits inputPath - inputDelta targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers - scalePow10 softmaxExpEffort - else - return .error "head pattern bounds require NFP_BINARY_V1" - -/-- Compute local head value lower bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadValueLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadValueLowerBoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let patternE ← - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - match patternE with - | .error e => return .error e - | .ok pattern => - certifyHeadValueLowerBoundLocalBinary path pattern coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 - else - return .error "head value bounds require NFP_BINARY_V1" - -/-- Compute local head logit-difference lower bounds for a specific `.nfpt` head (binary only). -/ -def certifyHeadLogitDiffLowerBoundLocal - (path : System.FilePath) - (layerIdx headIdx targetToken negativeToken : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? 
: Option System.FilePath := none) - (inputDelta : Rat := 0) - (targetOffset : Int := -1) - (maxSeqLen : Nat := 256) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (scalePow10 : Nat := defaultBinaryScalePow10) : - IO (Except String HeadLogitDiffLowerBoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let patternE ← - certifyHeadPatternLocalBinary path layerIdx headIdx eps soundnessBits inputPath inputDelta - targetOffset maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - match patternE with - | .error e => return .error e - | .ok pattern => - certifyHeadLogitDiffLowerBoundLocalBinary path pattern targetToken negativeToken - eps soundnessBits inputPath inputDelta maxSeqLen scalePow10 - else - return .error "head logit bounds require NFP_BINARY_V1" - -/-- Compute a combined sound certificate for an induction-style head pair (binary only). -/ -def certifyInductionSound - (path : System.FilePath) - (layer1 head1 layer2 head2 coord : Nat) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (offset1 : Int := -1) - (offset2 : Int := -1) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? 
: Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String InductionHeadSoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let p1E ← - certifyHeadPatternLocalBinary path layer1 head1 eps soundnessBits inputPath inputDelta - offset1 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort - match p1E with - | .error e => return .error e - | .ok p1 => - let p2E ← - certifyHeadPatternLocalBinary path layer2 head2 eps soundnessBits inputPath inputDelta - offset2 maxSeqLen tightPattern tightPatternLayers perRowPatternLayers scalePow10 - softmaxExpEffort - match p2E with - | .error e => return .error e - | .ok p2 => - let vE ← - certifyHeadValueLowerBoundLocalBinary path p2 coord eps soundnessBits inputPath - inputDelta maxSeqLen scalePow10 - match vE with - | .error e => return .error e - | .ok v => - let logitE ← - match targetToken?, negativeToken? with - | none, none => pure (.ok none) - | some targetToken, some negativeToken => do - let logitE ← certifyHeadLogitDiffLowerBoundLocalBinary path p2 - targetToken negativeToken eps soundnessBits inputPath inputDelta - maxSeqLen scalePow10 - pure (logitE.map some) - | _, _ => - pure (.error "use both target and negative tokens (or neither)") - match logitE with - | .error e => return .error e - | .ok logit? => - let cert : InductionHeadSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := v - layer2Logit? := logit? 
- deltaLowerBound := v.outputCoordLowerBound - } - if cert.check then - return .ok cert - return .error "induction head certificate failed internal consistency checks" - else - return .error "induction sound cert requires NFP_BINARY_V1" - -/-- Compute a combined sound certificate for an induction-style head pair (best-match, -binary only). -/ -def certifyInductionSoundBestMatch - (path : System.FilePath) - (layer1 head1 layer2 head2 coord : Nat) - (queryPos? : Option Nat := none) - (eps : Rat) - (soundnessBits : Nat) - (inputPath? : Option System.FilePath := none) - (inputDelta : Rat := 0) - (offset1 : Int := -1) - (offset2 : Int := -1) - (maxSeqLen : Nat := 256) - (scalePow10 : Nat := defaultBinaryScalePow10) - (tightPattern : Bool := false) - (tightPatternLayers : Nat := 1) - (perRowPatternLayers : Nat := 0) - (targetToken? : Option Nat := none) - (negativeToken? : Option Nat := none) - (softmaxExpEffort : Nat := defaultSoftmaxExpEffort) : - IO (Except String InductionHeadBestMatchSoundCert) := do - if inputDelta < 0 then - return .error "delta must be nonnegative" - let h ← IO.FS.Handle.mk path IO.FS.Mode.read - let firstLine := (← h.getLine).trim - if firstLine = "NFP_BINARY_V1" then - let inputPath := inputPath?.getD path - let p1E ← - certifyHeadPatternBestMatchLocalBinary path layer1 head1 queryPos? eps soundnessBits - inputPath inputDelta offset1 maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort - match p1E with - | .error e => return .error e - | .ok p1 => - let p2E ← - certifyHeadPatternBestMatchLocalBinary path layer2 head2 queryPos? 
eps soundnessBits - inputPath inputDelta offset2 maxSeqLen tightPattern tightPatternLayers - perRowPatternLayers scalePow10 softmaxExpEffort - match p2E with - | .error e => return .error e - | .ok p2 => - let vE ← - certifyHeadValueLowerBoundLocalBinaryAt path layer2 head2 p2.queryPos coord - eps soundnessBits inputPath inputDelta offset2 p2.bestMatchWeightLowerBound - maxSeqLen scalePow10 - match vE with - | .error e => return .error e - | .ok v => - let logitE ← - match targetToken?, negativeToken? with - | none, none => pure (.ok none) - | some targetToken, some negativeToken => do - let logitE ← - certifyHeadLogitDiffLowerBoundLocalBinaryAt path layer2 head2 p2.queryPos - targetToken negativeToken eps soundnessBits inputPath inputDelta offset2 - p2.bestMatchWeightLowerBound maxSeqLen scalePow10 - pure (logitE.map some) - | _, _ => - pure (.error "use both target and negative tokens (or neither)") - match logitE with - | .error e => return .error e - | .ok logit? => - let cert : InductionHeadBestMatchSoundCert := { - layer1Pattern := p1 - layer2Pattern := p2 - layer2Value := v - layer2Logit? := logit? - deltaLowerBound := v.outputCoordLowerBound - } - if cert.check then - return .ok cert - return .error "induction head certificate failed internal consistency checks" - else - return .error "induction sound cert requires NFP_BINARY_V1" - -/-! 
### Specs -/ - -theorem parseHeaderLine_spec_io : - parseHeaderLine = parseHeaderLine := rfl - -theorem defaultBinaryScalePow10_spec_io : - defaultBinaryScalePow10 = defaultBinaryScalePow10 := rfl - -theorem foldRatTokens_spec_io {α : Type} : - foldRatTokens (α := α) = foldRatTokens (α := α) := rfl - -theorem consumeVector_spec_io : - consumeVector = consumeVector := rfl - -theorem consumeMatrixNormInf_spec_io : - consumeMatrixNormInf = consumeMatrixNormInf := rfl - -theorem consumeMaxAbs_spec_io : - consumeMaxAbs = consumeMaxAbs := rfl - -theorem maxAbsOfVector_spec_io : - maxAbsOfVector = maxAbsOfVector := rfl - -theorem certifyHeadBoundsBinary_spec_io : - certifyHeadBoundsBinary = certifyHeadBoundsBinary := rfl - -theorem certifyModelFileGlobalBinary_spec_io : - certifyModelFileGlobalBinary = certifyModelFileGlobalBinary := rfl - -theorem addVecIntervals_spec_io : - addVecIntervals = addVecIntervals := rfl - -theorem addConstVec_spec_io : - addConstVec = addConstVec := rfl - -theorem unionVecIntervals_spec_io : - unionVecIntervals = unionVecIntervals := rfl - -theorem zeroIntervals_spec_io : - zeroIntervals = zeroIntervals := rfl - -theorem unionRows_spec_io : - unionRows = unionRows := rfl - -theorem layerNormRowApprox_spec_io : - layerNormRowApprox = layerNormRowApprox := rfl - -theorem minVarAcrossRows_spec_io : - minVarAcrossRows = minVarAcrossRows := rfl - -theorem findLineIdxFrom_spec_io : - findLineIdxFrom = findLineIdxFrom := rfl - -theorem skipUntil_spec_io : - skipUntil = skipUntil := rfl - -theorem skipBlankLines_spec_io : - skipBlankLines = skipBlankLines := rfl - -theorem countWsTokens_spec_io : - countWsTokens = countWsTokens := rfl - -theorem consumeTokensSkipFast_spec_io : - consumeTokensSkipFast = consumeTokensSkipFast := rfl - -theorem consumeMatrixSkip_spec_io : - consumeMatrixSkip = consumeMatrixSkip := rfl - -theorem consumeMatrixSkipFast_spec_io : - consumeMatrixSkipFast = consumeMatrixSkipFast := rfl - -theorem 
consumeVectorSkipFast_spec_io : - consumeVectorSkipFast = consumeVectorSkipFast := rfl - -theorem consumeMatrixMulAndNormInf_spec_io : - consumeMatrixMulAndNormInf = consumeMatrixMulAndNormInf := rfl - -theorem certifyModelFileGlobal_spec_io : - certifyModelFileGlobal = certifyModelFileGlobal := rfl - -theorem loadEmbeddingsIntervals_spec_io : - loadEmbeddingsIntervals = loadEmbeddingsIntervals := rfl - -theorem intervalsFromScaled_spec_io : - intervalsFromScaled = intervalsFromScaled := rfl - -theorem collectLayerNormParams_spec_io : - collectLayerNormParams = collectLayerNormParams := rfl - -theorem collectLayerNormParamsBinary_spec_io : - collectLayerNormParamsBinary = collectLayerNormParamsBinary := rfl - -theorem defaultFixedScalePow10_spec_io : - defaultFixedScalePow10 = defaultFixedScalePow10 := rfl - -theorem fixedUlpSlack_spec_io : - fixedUlpSlack = fixedUlpSlack := rfl - -theorem scaleCfgOfPow10_spec_io : - scaleCfgOfPow10 = scaleCfgOfPow10 := rfl - -theorem ratCeilMulNat_spec_io : - ratCeilMulNat = ratCeilMulNat := rfl - -theorem fixedMeanInterval_spec_io : - fixedMeanInterval = fixedMeanInterval := rfl - -theorem fixedVarianceLowerBoundRange_spec_io : - fixedVarianceLowerBoundRange = fixedVarianceLowerBoundRange := rfl - -theorem fixedLayerNormRowApprox_spec_io : - fixedLayerNormRowApprox = fixedLayerNormRowApprox := rfl - -theorem readVecIntervals_spec_io : - readVecIntervals = readVecIntervals := rfl - -theorem readVecIntervalsBinary_spec_io : - readVecIntervalsBinary = readVecIntervalsBinary := rfl - -theorem matMulIntervalsFromScaled_spec_io : - matMulIntervalsFromScaled = matMulIntervalsFromScaled := rfl - -theorem fixedDotInterval_spec_io : - fixedDotInterval = fixedDotInterval := rfl - -theorem maxAbsVecFixed_spec_io : - maxAbsVecFixed = maxAbsVecFixed := rfl - -theorem addVecFixed_spec_io : - addVecFixed = addVecFixed := rfl - -theorem addVecFixedRows_spec_io : - addVecFixedRows = addVecFixedRows := rfl - -theorem addRowsFixed_spec_io : - 
addRowsFixed = addRowsFixed := rfl - -theorem mlpRowFromScaled_spec_io : - mlpRowFromScaled = mlpRowFromScaled := rfl - -theorem mlpRowsFromScaled_spec_io : - mlpRowsFromScaled = mlpRowsFromScaled := rfl - -theorem groupUnionRowsByToken_spec_io : - groupUnionRowsByToken = groupUnionRowsByToken := rfl - -theorem unionRowsFixed_spec_io : - unionRowsFixed = unionRowsFixed := rfl - -theorem consumeMatrixMulAndNormInfFixed_spec_io : - consumeMatrixMulAndNormInfFixed = consumeMatrixMulAndNormInfFixed := rfl - -theorem consumeMatrixMulAndNormInfFixedBinary_spec_io : - consumeMatrixMulAndNormInfFixedBinary = consumeMatrixMulAndNormInfFixedBinary := rfl - -theorem loadEmbeddingsUnionFixed_spec_io : - loadEmbeddingsUnionFixed = loadEmbeddingsUnionFixed := rfl - -theorem loadEmbeddingsUnionFixedBinary_spec_io : - loadEmbeddingsUnionFixedBinary = loadEmbeddingsUnionFixedBinary := rfl - -theorem loadEmbeddingsIntervalsBinary_spec_io : - loadEmbeddingsIntervalsBinary = loadEmbeddingsIntervalsBinary := rfl - -theorem loadTokensBinary_spec_io : - loadTokensBinary = loadTokensBinary := rfl - -theorem skipToUnembeddingBinary_spec_io : - skipToUnembeddingBinary = skipToUnembeddingBinary := rfl - -theorem certifyHeadValueLowerBoundLocalBinaryAt_spec_io : - certifyHeadValueLowerBoundLocalBinaryAt = certifyHeadValueLowerBoundLocalBinaryAt := rfl - -theorem readUnembeddingColumnsBinary_spec_io : - readUnembeddingColumnsBinary = readUnembeddingColumnsBinary := rfl - -theorem readLogitDiffDirectionBinary_spec_io : - readLogitDiffDirectionBinary = readLogitDiffDirectionBinary := rfl - -theorem certifyHeadLogitDiffLowerBoundLocalBinaryAt_spec_io : - certifyHeadLogitDiffLowerBoundLocalBinaryAt = - certifyHeadLogitDiffLowerBoundLocalBinaryAt := rfl - -theorem ensureSoundCache_spec_io : - ensureSoundCache = ensureSoundCache := rfl - -theorem certifyModelFileLocalText_spec_io : - certifyModelFileLocalText = certifyModelFileLocalText := rfl - -theorem certifyModelFileLocal_spec_io : - 
certifyModelFileLocal = certifyModelFileLocal := rfl - -theorem certifyModelFileLocalBinary_spec_io : - certifyModelFileLocalBinary = certifyModelFileLocalBinary := rfl - -theorem certifyHeadBoundsLocalBinary_spec_io : - certifyHeadBoundsLocalBinary = certifyHeadBoundsLocalBinary := rfl - -theorem certifyHeadPatternLocalBinary_spec_io : - certifyHeadPatternLocalBinary = certifyHeadPatternLocalBinary := rfl - -theorem certifyHeadPatternBestMatchLocalBinary_spec_io : - certifyHeadPatternBestMatchLocalBinary = certifyHeadPatternBestMatchLocalBinary := rfl - -theorem certifyHeadPatternBestMatchLocalBinarySweep_spec_io : - certifyHeadPatternBestMatchLocalBinarySweep = - certifyHeadPatternBestMatchLocalBinarySweep := rfl - -theorem certifyHeadValueLowerBoundLocalBinary_spec_io : - certifyHeadValueLowerBoundLocalBinary = certifyHeadValueLowerBoundLocalBinary := rfl - -theorem certifyHeadLogitDiffLowerBoundLocalBinary_spec_io : - certifyHeadLogitDiffLowerBoundLocalBinary = certifyHeadLogitDiffLowerBoundLocalBinary := rfl - -theorem certifyModelFile_spec_io : - certifyModelFile = certifyModelFile := rfl - -theorem certifyHeadBounds_spec_io : - certifyHeadBounds = certifyHeadBounds := rfl - -theorem certifyHeadBoundsLocal_spec_io : - certifyHeadBoundsLocal = certifyHeadBoundsLocal := rfl - -theorem certifyHeadPatternLocal_spec_io : - certifyHeadPatternLocal = certifyHeadPatternLocal := rfl - -theorem certifyHeadPatternBestMatchLocal_spec_io : - certifyHeadPatternBestMatchLocal = certifyHeadPatternBestMatchLocal := rfl - -theorem certifyHeadPatternBestMatchLocalSweep_spec_io : - certifyHeadPatternBestMatchLocalSweep = certifyHeadPatternBestMatchLocalSweep := rfl - -theorem certifyHeadValueLowerBoundLocal_spec_io : - certifyHeadValueLowerBoundLocal = certifyHeadValueLowerBoundLocal := rfl - -theorem certifyHeadLogitDiffLowerBoundLocal_spec_io : - certifyHeadLogitDiffLowerBoundLocal = certifyHeadLogitDiffLowerBoundLocal := rfl - -theorem certifyInductionSound_spec_io : - 
certifyInductionSound = certifyInductionSound := rfl - -theorem certifyInductionSoundBestMatch_spec_io : - certifyInductionSoundBestMatch = certifyInductionSoundBestMatch := rfl - -end Nfp.Untrusted.SoundCompute diff --git a/Nfp/Verification.lean b/Nfp/Verification.lean deleted file mode 100644 index f2bbc26..0000000 --- a/Nfp/Verification.lean +++ /dev/null @@ -1,403 +0,0 @@ --- SPDX-License-Identifier: AGPL-3.0-or-later - -import Nfp.Discovery - -/-! -# Causal Verification (Head Ablation) - -This module implements **causal verification** of candidate circuits using an executable -forward pass with **head ablation** (a.k.a. zero-ablation). - -The core API is `verifyCircuit`, which: -- selects **energy-matched control heads** (same layer, closest output norm), -- checks runtime **axioms** (baseline competence, control independence, energy equivalence), -- and (only if axioms hold) measures **logit-difference impact** under ablation. - -This is intended to operationalize interchange-intervention style causality checks in a -fully executable setting. --/ - -namespace Nfp - -/-! ## Head references -/ - -/-- A reference to a specific attention head `(layer, head)`. -/ -structure HeadRef where - layerIdx : Nat - headIdx : Nat -deriving BEq, DecidableEq - -namespace HeadRef - -instance : Inhabited HeadRef := ⟨{ layerIdx := 0, headIdx := 0 }⟩ - -def toString (h : HeadRef) : String := - s!"L{h.layerIdx}H{h.headIdx}" - -instance : ToString HeadRef := ⟨toString⟩ - -def toComponentId (h : HeadRef) : ComponentId := - .head h.layerIdx h.headIdx - -end HeadRef - -/-! ## Induction target utilities -/ - -/-- Derive the **induction target token** from the model's `inputTokens`, if available. - -Let `T` be the token sequence and `t_curr = T[last]`. If `t_curr` appeared previously at index `k`, -the induction target is `t_tgt = T[k+1]` (the successor of the previous match). - -Returns `none` if there is no prior repetition of the last token or no token history. 
-/ -def inductionTargetTokenFromHistory (model : ConcreteModel) : Option Nat := do - let tokens ← model.inputTokens - if tokens.size = 0 then none else - let lastIdx := tokens.size - 1 - let tCurr := tokens[lastIdx]! - let mut foundIdx : Option Nat := none - for offset in [:lastIdx] do - if foundIdx.isNone then - let idx := lastIdx - 1 - offset - if tokens[idx]! = tCurr then - foundIdx := some idx - let k ← foundIdx - some (tokens[k + 1]!) - -/-! ## Verification configuration and results -/ - -/-- Runtime parameters for causal verification via head ablation. -/ -structure VerificationConfig where - /-- Competence threshold ε for baseline logit-difference Δ. -/ - competenceEpsilon : Float := 1e-3 - /-- Relative tolerance δ for energy matching: |‖out_cand‖ - ‖out_ctrl‖| < δ · ‖out_cand‖. -/ - energyRelTol : Float := 0.05 - /-- Absolute fallback tolerance for tiny-norm heads. -/ - energyAbsTol : Float := 1e-6 - /-- Precision (bits) for dyadic sqrt bounds in SOUND-mode checks. -/ - soundnessBits : Nat := 20 - /-- Whether to run attention causally (autoregressive). -/ - causal : Bool := true - -/-- The outcome of checking the runtime axioms required for causal interpretation. -/ -structure AxiomStatus where - baselineCompetence : Bool - controlIndependence : Bool - energyEquivalence : Bool - /-- Human-readable failures (empty iff all axioms hold). -/ - failures : Array String := #[] - -namespace AxiomStatus - -def verified (s : AxiomStatus) : Bool := - s.failures.isEmpty && s.baselineCompetence && s.controlIndependence && s.energyEquivalence - -end AxiomStatus - -/-- A single verification row for a candidate circuit. -/ -structure CircuitVerificationRow where - /-- Candidate circuit heads (ablated together). -/ - candidateHeads : Array HeadRef - /-- Selected control heads (one per candidate head, same layers). -/ - controlHeads : Option (Array HeadRef) - /-- Baseline logit-difference Δ = logit(t_tgt) - logit(t_neg). 
-/ - baseDelta : Float - /-- Ablated Δ for the candidate circuit, if axioms verified. -/ - ablatedDelta : Option Float - /-- Candidate causal impact = Δ_base - Δ_ablated, if axioms verified. -/ - impact : Option Float - /-- Relative score = impact / Δ_base (only defined if baseline competence holds). -/ - relScore : Option Float - /-- Control impact = Δ_base - Δ_ctrlAblated, if axioms verified. -/ - controlImpact : Option Float - /-- Axiom status (checked *before* computing impact scores). -/ - axioms : AxiomStatus - -namespace CircuitVerificationRow - -def candidateLabel (r : CircuitVerificationRow) : String := - if r.candidateHeads.size = 2 then - let h1 := r.candidateHeads[0]! - let h2 := r.candidateHeads[1]! - s!"{h1} -> {h2}" - else - String.intercalate "," (r.candidateHeads.toList.map (fun h => toString h)) - -end CircuitVerificationRow - -/-! ## Low-level helpers: logits, controls, and head ablation -/ - -private def logitAt (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (token : Nat) : Option Float := - if residual.numCols = W_U.numRows ∧ pos < residual.numRows ∧ token < W_U.numCols then - some <| Id.run do - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut acc : Float := 0.0 - for k in [:d] do - -- SAFETY: `pos < residual.numRows` and `k < d = residual.numCols`. - let x := residual.data[rowBase + k]! - -- SAFETY: `k < W_U.numRows` and `token < vocab = W_U.numCols`. - let w := W_U.data[k * vocab + token]! 
- acc := acc + x * w - return acc - else none - -private def topNonTargetToken (residual : ConcreteMatrix) (pos : Nat) - (W_U : ConcreteMatrix) (targetToken : Nat) : Option (Nat × Float) := Id.run do - if residual.numCols = W_U.numRows ∧ pos < residual.numRows ∧ - targetToken < W_U.numCols ∧ W_U.numCols ≥ 2 then - let d := residual.numCols - let vocab := W_U.numCols - let rowBase := pos * d - let mut bestTok : Nat := 0 - let mut bestLogit : Float := (-Float.inf) - let mut found : Bool := false - for tok in [:vocab] do - if tok ≠ targetToken then - found := true - let mut acc : Float := 0.0 - for k in [:d] do - -- SAFETY: `pos < residual.numRows` and `k < d = residual.numCols`. - let x := residual.data[rowBase + k]! - -- SAFETY: `k < W_U.numRows` and `tok < vocab = W_U.numCols`. - let w := W_U.data[k * vocab + tok]! - acc := acc + x * w - if acc > bestLogit then - bestTok := tok - bestLogit := acc - if found then - return some (bestTok, bestLogit) - else return none - else - return none - -private def containsHead (hs : Array HeadRef) (h : HeadRef) : Bool := - hs.any (fun x => x == h) - -private def fullCircuit (model : ConcreteModel) : ConcreteCircuit := Id.run do - let mut headsPerLayer : Array Nat := Array.mkEmpty model.numLayers - let mut neuronsPerLayer : Array Nat := Array.mkEmpty model.numLayers - for l in [:model.numLayers] do - headsPerLayer := headsPerLayer.push (model.layers.getD l #[]).size - neuronsPerLayer := neuronsPerLayer.push (model.numNeuronsAtLayer l) - ConcreteCircuit.full model.numLayers headsPerLayer neuronsPerLayer - -private def runForwardAblatingHeads (model : ConcreteModel) (heads : Array HeadRef) - (causal : Bool := true) : ForwardPassResult := - let base := fullCircuit model - let circuit := heads.foldl (fun c h => c.removeComponent h.toComponentId) base - model.runAblatedForward circuit causal - -private def headOutputNorm? 
(fwd : ForwardPassResult) (h : HeadRef) : Option Float := - if hl : h.layerIdx < fwd.attnOutputs.size then - let layerOut := fwd.attnOutputs[h.layerIdx] - if hh : h.headIdx < layerOut.size then - some ((layerOut[h.headIdx]'hh).frobeniusNorm) - else none - else none - -private structure ControlSelection where - candidate : HeadRef - control : HeadRef - candNorm : Float - ctrlNorm : Float - absDiff : Float - -private def selectEnergyMatchedControl (fwd : ForwardPassResult) - (cand : HeadRef) (exclude : Array HeadRef) : Option ControlSelection := Id.run do - match headOutputNorm? fwd cand with - | none => none - | some candNorm => - if hl : cand.layerIdx < fwd.attnOutputs.size then - let layerOut := fwd.attnOutputs[cand.layerIdx] - let best : Option (HeadRef × Float × Float) := Id.run do - let mut best : Option (HeadRef × Float × Float) := none - for hIdx in [:layerOut.size] do - let h : HeadRef := { layerIdx := cand.layerIdx, headIdx := hIdx } - if !containsHead exclude h then - if hh : hIdx < layerOut.size then - let norm := (layerOut[hIdx]'hh).frobeniusNorm - let diff := Float.abs (candNorm - norm) - match best with - | none => - best := some (h, norm, diff) - | some (_, _, bestDiff) => - if diff < bestDiff then - best := some (h, norm, diff) - return best - match best with - | none => none - | some (ctrl, ctrlNorm, absDiff) => - some { candidate := cand, control := ctrl, candNorm, ctrlNorm, absDiff } - else - none - -private def energyEquivalent (cfg : VerificationConfig) (candNorm : Float) - (absDiff : Float) : Bool := - let thresh := max cfg.energyAbsTol (cfg.energyRelTol * candNorm) - absDiff < thresh - -/-! ## Verification context and core API -/ - -/-- Shared baseline information for circuit verification on a fixed prompt. -/ -structure VerificationContext where - model : ConcreteModel - W_U : ConcreteMatrix - /-- Logits are evaluated at this position (typically the last token). 
-/ - pos : Nat - targetToken : Nat - negativeToken : Nat - baseTargetLogit : Float - baseNegLogit : Float - baseDelta : Float - baselineForward : ForwardPassResult - cfg : VerificationConfig - -namespace VerificationContext - -/-- Build a verification context for a fixed target token. - -The negative token `t_neg` is chosen as the **top non-target** logit at the evaluation position. --/ -def build (model : ConcreteModel) (targetToken : Nat) - (cfg : VerificationConfig := {}) : Except String VerificationContext := - match model.unembedding with - | none => .error "Model is missing UNEMBEDDING; cannot compute logits." - | some W_U => - if model.seqLen = 0 then - .error "Model has seqLen = 0; cannot evaluate logits." - else - let pos := model.seqLen - 1 - let baselineForward := model.runForward cfg.causal - let residual := baselineForward.finalOutput - match logitAt residual pos W_U targetToken with - | none => - .error s!"Cannot compute logit(target={targetToken}) at pos={pos} \ - (dimension mismatch or token OOB)." - | some baseTargetLogit => - match topNonTargetToken residual pos W_U targetToken with - | none => - .error "Cannot select top non-target token (need vocab ≥ 2 and target in-bounds)." - | some (negativeToken, baseNegLogit) => - let baseDelta := baseTargetLogit - baseNegLogit - .ok { - model := model - W_U := W_U - pos := pos - targetToken := targetToken - negativeToken := negativeToken - baseTargetLogit := baseTargetLogit - baseNegLogit := baseNegLogit - baseDelta := baseDelta - baselineForward := baselineForward - cfg := cfg - } - -end VerificationContext - -/-- Verify a candidate circuit by **head ablation** with an energy-matched control. - -This function enforces the runtime axioms: -1. Baseline competence: `Δ_base > ε` -2. Control independence: control heads disjoint from candidate heads -3. Energy equivalence: per-head output norms match within tolerance - -Only if these axioms are verified do we run the ablations and compute impact scores. 
--/ -def verifyCircuit (ctx : VerificationContext) (candidateHeads : Array HeadRef) : - CircuitVerificationRow := Id.run do - let baseDelta := ctx.baseDelta - let competence := baseDelta > ctx.cfg.competenceEpsilon - let mut failures : Array String := #[] - if !competence then - failures := failures.push s!"Axiom1(baseline competence) failed: Δ_base={baseDelta} ≤ \ - ε={ctx.cfg.competenceEpsilon}" - - if !competence then - let axioms : AxiomStatus := { - baselineCompetence := false - controlIndependence := true - energyEquivalence := true - failures := failures - } - return { - candidateHeads := candidateHeads - controlHeads := none - baseDelta := baseDelta - ablatedDelta := none - impact := none - relScore := none - controlImpact := none - axioms := axioms - } - - -- Choose one control head per candidate head (same layer, closest output norm). - let mut selections : Array ControlSelection := #[] - for cand in candidateHeads do - match selectEnergyMatchedControl ctx.baselineForward cand candidateHeads with - | some sel => selections := selections.push sel - | none => - failures := failures.push s!"No control head found in layer {cand.layerIdx} for {cand}" - - let controlsComplete := selections.size = candidateHeads.size - let controlHeads : Array HeadRef := selections.map (·.control) - let independence := !(controlHeads.any (fun h => containsHead candidateHeads h)) - if !independence then - failures := failures.push "Axiom2(control independence) failed: control overlaps candidate." - - let energyOk := - controlsComplete && selections.all (fun s => energyEquivalent ctx.cfg s.candNorm s.absDiff) - if !energyOk then - failures := failures.push "Axiom3(energy equivalence) failed: no sufficiently matched control." 
- - let axioms : AxiomStatus := { - baselineCompetence := competence - controlIndependence := independence - energyEquivalence := energyOk - failures := failures - } - - if !axioms.verified then - return { - candidateHeads := candidateHeads - controlHeads := if controlsComplete then some controlHeads else none - baseDelta := baseDelta - ablatedDelta := none - impact := none - relScore := none - controlImpact := none - axioms := axioms - } - - -- Candidate ablation - let fwdAblated := runForwardAblatingHeads ctx.model candidateHeads ctx.cfg.causal - let residualAblated := fwdAblated.finalOutput - let ablatedTargetLogit := - (logitAt residualAblated ctx.pos ctx.W_U ctx.targetToken).getD 0.0 - let ablatedNegLogit := - (logitAt residualAblated ctx.pos ctx.W_U ctx.negativeToken).getD 0.0 - let ablatedDelta := ablatedTargetLogit - ablatedNegLogit - let impact := baseDelta - ablatedDelta - let relScore := if baseDelta > 0.0 then impact / baseDelta else 0.0 - - -- Control ablation (energy-matched, layer-matched) - let fwdCtrl := runForwardAblatingHeads ctx.model controlHeads ctx.cfg.causal - let residualCtrl := fwdCtrl.finalOutput - let ctrlTargetLogit := (logitAt residualCtrl ctx.pos ctx.W_U ctx.targetToken).getD 0.0 - let ctrlNegLogit := (logitAt residualCtrl ctx.pos ctx.W_U ctx.negativeToken).getD 0.0 - let ctrlDelta := ctrlTargetLogit - ctrlNegLogit - let controlImpact := baseDelta - ctrlDelta - - return { - candidateHeads := candidateHeads - controlHeads := some controlHeads - baseDelta := baseDelta - ablatedDelta := some ablatedDelta - impact := some impact - relScore := some relScore - controlImpact := some controlImpact - axioms := axioms - } - -end Nfp diff --git a/README.md b/README.md index 6dfcb60..8da1edd 100644 --- a/README.md +++ b/README.md @@ -1,394 +1,192 @@ # NFP -NFP is a Lean 4 project for **mathematically rigorous** reasoning about transformer-style computations, with a focus on mechanistic interpretability (e.g. 
induction heads) and provable norm/error bounds. +NFP is a Lean 4 project for **mathematically rigorous** reasoning about transformer-style +computations, with a focus on mechanistic interpretability (e.g. induction heads) and provable +norm/error bounds. NFP stands for **Neural Formal Pathways**. -This repo contains: - -- A **Lean library** (under `Nfp/`) for finite probability and a lightweight “transformer semantics” layer. -- A **CLI executable** (`lake exe nfp …`) that loads transformer weights stored in a compact binary format (`.nfpt`) and produces rigorous bounds and diagnostics. - -> Goal: *no “hand-wavy” numerics in the bound path.* Heuristic estimates (e.g. power iteration) may exist for diagnostics, but the bounds reported as “rigorous” are computed via conservative inequalities. - ## Status -This is research tooling. Interfaces may change; please treat results as experimental unless they are backed by a certificate/check you trust. - -## Soundness statement (what is proven vs checked) - -The Lean library defines the core math objects (finite probability, mixers, linearizations, and operator-norm-style bounds) and proves a number of lemmas about them. The CLI sound path produces certificates using exact `Rat` arithmetic and a trusted checker that verifies internal arithmetic relationships between certificate fields. - -At present, the checker does **not** include a bridge theorem that connects certificate validity to the Lean-defined Jacobian bounds (for example, a theorem of the form `||layerJacobian - I|| <= C`). Treat sound certificates as **internally consistent bound reports**, not as a fully formal end-to-end verification of transformer Jacobians. - -For known gaps and ongoing upgrades, see `SOUNDNESS_LIMITATIONS.md`. - -## North Star - -NFP’s long-term direction is **verified circuit discovery**: - -- Use fast, exploratory tooling to **propose** candidate circuits (e.g. 
induction-style head interactions), -- then produce **checkable evidence** (bounds / certificates) that a skeptical reader can re-run and validate. - -Concretely, the intended split is: - -- **Discovery / exploration (untrusted, fast):** - Heuristic search, ranking, and diagnostics are allowed here (and should be clearly labelled as such). - This includes things like candidate search (`induction`) and comparison estimates printed under diagnostics/verbose flags. +This repository is in a **tabula rasa rewrite**. The new core is intentionally minimal and the API +surface is still settling. Expect breaking changes. -- **Certification / checking (trusted, boring):** - Anything described as “rigorous” should be justified by conservative inequalities or by a certificate that a checker can validate. - The long-term aim is that Lean does as little “real inference” as possible: instead of running large forward passes, - it should mostly **check small, structured proof obligations** (e.g. inequality chains, norm bounds, interval/rational arithmetic). - -Current state: `certify` is already an example of this direction (sound-mode reporting using exact `Rat` arithmetic rather than trusted floats), -but the certificate story is still evolving and interfaces may change. - -Model trajectory: GPT-2 support is currently a proving ground for the end-to-end workflow (export → analyze/search → bound/certify). -The goal is to gradually cover more modern decoder blocks (e.g. RoPE-style position handling) while keeping the certification/checking layer lightweight. 
- -## Reproduce results - -Minimal local demo (no network needed): +## Build ```bash lake build -q --wfail lake build nfp -q --wfail -lake exe nfp certify tests/fixtures/tiny_sound_binary.nfpt \ - --output reports/tiny_sound_demo.txt -``` - -Expected artifacts: -- `reports/tiny_sound_demo.txt` - -Optional (rebuild the tiny binary from text fixtures and run a fixed induction cert): - -```bash -./scripts/demo_tiny_local_binary.sh -./scripts/demo_tiny_induction_cert.sh -``` - -Expected artifacts (optional path): -- `reports/tiny_sound_local_binary.txt` -- `reports/tiny_induction_cert.txt` - -End-to-end GPT-2 demo (requires network/model download): - -```bash -./scripts/demo_gpt2_sound.sh -./scripts/demo_gpt2_induction_sound.sh -``` - -Expected artifacts: -- `reports/gpt2_sound_demo.txt` -- `reports/gpt2_induction_sound_scan.txt` - -Notes: -- If a legacy `.nfpt` header is missing `gelu_kind`, `demo_gpt2_sound.sh` writes - `models/gpt2_with_gelu_kind.nfpt` and uses that for certification. -- `demo_gpt2_induction_sound.sh` can take a while on CPU; use `--top 1`, - `--fast`, or `--jobs 2` to shorten the scan or run it on a larger machine. -- You can also set `NFP_BIN=./.lake/build/bin/nfp` to avoid repeated `lake exe` - startup overhead. - - -## Requirements - -- **Lean 4** (pinned by `lean-toolchain`) and **Lake**. - - Easiest install: `elan` (Lean toolchain manager). -- A standard build toolchain for Lean (C/C++ compiler, `make`, etc.). -- (Optional) **Python** for the export scripts in `scripts/`. - -Lean version is pinned in `lean-toolchain` (currently `leanprover/lean4:v4.26`). - -## Getting started - -Clone and build: - -```bash -lake update -lake build ``` -Run the CLI (see subcommands below): +## CLI ```bash lake exe nfp --help +lake exe nfp induction --help ``` -## Models - -The CLI expects a model file in **`.nfpt`** format (NFP_BINARY_V1). 
- -- Create a local `models/` directory and place your `.nfpt` files there (the repo does not version model files; the author’s setup may have used local symlinks). -- You can export GPT-2 weights from Hugging Face using the scripts in `scripts/`. - -`.nfpt` files use a small text header followed by a binary payload: - -``` -NFP_BINARY_V1 -num_layers=... -num_heads=... -model_dim=... -head_dim=... -hidden_dim=... -vocab_size=... -seq_len=... -BINARY_START -``` +Current subcommands are limited to **induction certificate checking**. The CLI does **not** run a +full model forward pass; certificate generation is done by untrusted helper scripts (see below). -The payload is raw little-endian bytes in a fixed order (tokens, embeddings, then weights). +## Module map -Note: global sound certification supports `NFP_BINARY_V1`. Local sound certification -supports `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. +The authoritative module map and invariants are tracked in `AGENTS.md`. -### Exporting GPT-2 to `.nfpt` +High-level layout: +- `Nfp/Core`, `Nfp/Prob`, `Nfp/Mixer`, `Nfp/System`: core math infrastructure. +- `Nfp/Circuit`: circuits, typed interfaces, and layer wiring (attention, induction). +- `Nfp/Sound`: soundness theorems and verified helpers. +- `Nfp/IO`, `Nfp/Cli`: parsing and CLI entrypoints. -The export scripts use `torch` + `transformers`. +## Induction Certification (prototype) -Example (write `models/gpt2_rigorous.nfpt`): +The current prototype checks **explicit induction-head certificates**. Certificates are produced +by **untrusted** Python scripts and verified by the Lean CLI; no model forward pass runs in Lean. +The input setup follows the standard literature diagnostic: repeated token patterns (pattern +repeated twice) and attention stripes that look back by one period. -```bash -python scripts/export_gpt2.py models/gpt2_rigorous.nfpt -``` +For a step-by-step walkthrough, see `docs/demo.md`. 
+For a careful statement of what certificates do and do not claim, see +`docs/cert_usefulness.md`. -If you prefer a locked Python environment, use `uv` or a venv and install dependencies from `pyproject.toml`: +### Build a head certificate (untrusted) ```bash -uv run python scripts/export_gpt2.py models/gpt2_rigorous.nfpt +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 ``` -### GPT-2 sound demo (global) +Layer/head indices in the generator are 1-based to match the literature. -This demo downloads GPT-2 weights on demand, exports a binary `.nfpt`, and runs the -global sound certificate. +To certify a **non-vacuous** logit-diff lower bound, supply a direction: ```bash -./scripts/demo_gpt2_sound.sh +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --direction-target 1268 --direction-negative 1796 ``` -Artifacts: -- `models/gpt2.nfpt` (binary export) -- `reports/gpt2_sound_demo.txt` (sound certificate report) - -### GPT-2 induction sound scan - -This demo builds the rigorous induction dataset (if needed), finds candidate -induction head pairs, and ranks them by sound logit-diff lower bounds. 
+Or let the untrusted script search for a direction in a vocab slice: ```bash -./scripts/demo_gpt2_induction_sound.sh +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ + --direction-min-lb 1/10 \ + --direction-report-out reports/direction_report.txt --direction-topk 10 \ + --tokens-out reports/gpt2_induction.tokens ``` -Artifacts: -- `models/gpt2_rigorous.nfpt` (binary export) -- `reports/gpt2_induction_sound_scan.txt` (sound scan report) +Direction search is **untrusted witness generation**; the Lean CLI only verifies the resulting +explicit certificate. The direction report lists the top-ranked candidates by estimated lower +bound so you can pick a stable non-vacuous direction. -### Tiny local binary demo +Optional direction metadata: -This demo converts the tiny text fixtures into a binary `.nfpt` and runs a local -sound certificate (with `--delta`). - -```bash -./scripts/demo_tiny_local_binary.sh ``` - -Artifacts: -- `tests/fixtures/tiny_sound_binary.nfpt` (binary fixture) -- `reports/tiny_sound_local_binary.txt` (local sound certificate report) - -### Tiny induction cert demo - -This demo computes a minimal induction head certificate on the tiny fixture. - -```bash -./scripts/demo_tiny_induction_cert.sh +--direction-target --direction-negative ``` -Artifacts: -- `reports/tiny_induction_cert.txt` (induction cert report) - -## CLI overview - -The main entrypoint is: +### Verify a head certificate (trusted checker) ```bash -lake exe nfp [args] [flags] +lake exe nfp induction certify --cert reports/gpt2_induction.cert ``` -By default, `nfp` mirrors everything printed to stdout into `logs/` as a timestamped `.log` file. 
+Optional gates: -### `analyze` - -Runs the default end-to-end analysis for the supplied model and prints a human-readable report. - -```bash -lake exe nfp analyze models/gpt2_rigorous.nfpt \ - --threshold 0.1 --verify --verbose --output report.txt ``` - -- `--threshold` (`-t`) sets the minimum effect threshold used for verification (default: `0.1`). -- `--verify` optionally runs causal verification using model-provided inputs. -- `--verbose` prints model metadata and per-stage status messages. -- `--output` (`-o`) writes the report to a file instead of stdout. - -### `induction` - -Searches for **candidate induction circuits** and ranks head pairs by a mechanical score. - -```bash -lake exe nfp induction models/gpt2_rigorous.nfpt \ - --threshold 0.0 --diagnostics --diagTop 5 --adaptive --verbose +--min-active --min-margin --max-eps --min-logit-diff --tokens ``` -- `--threshold` (`-t`) sets the minimum normalized effect (default: `0.0`). -- `--correct` / `--incorrect` manually pick logit IDs for the induction target (otherwise the target is inferred from tokens). -- `--verify` runs causal verification via head ablation on the top-10 candidates. -- `--diagnostics` enables bound breakdowns; `--diagTop` controls how many candidates receive diagnostics (default: `5`). -- `--adaptive` turns on the adaptive bound scheduler. Tuning flags include `--targetSlack` (default: `8.0`), - `--maxUpgrades` (default: `120`), `--minRelImprove` (default: `0.01`), `--krylovSteps` (default: `2`), - and `--adaptiveScope` (`layernorm | all`, default: `layernorm`). -- `--verbose` prints detailed scoring metrics for each candidate. - -### `certify` - -Computes a conservative **certificate report** in sound mode using exact `Rat` arithmetic (no trusted floats). - -Note: global sound certification supports `NFP_BINARY_V1`. Local sound certification -supports `NFP_BINARY_V1` (fixed-point union-box) and legacy `NFP_TEXT_V1/V2`. 
- -`certify` supports both: -- **global certification** (weights only), and -- **local certification** (weights + a small input region around a concrete prompt/input). - -```bash -lake exe nfp certify models/gpt2_rigorous.nfpt \ - --output cert.txt -``` +If `--tokens` is provided, the CLI verifies that the certificate's `prev` and `active` +match the token-sequence semantics for repeated tokens (previous occurrence). -- For local (input-dependent) LayerNorm certification, pass an ℓ∞ radius `δ`: +Example non-vacuous check: ```bash -lake exe nfp certify models/gpt2_rigorous.nfpt \ - --delta 0.01 +lake exe nfp induction certify --cert reports/gpt2_induction.cert --min-logit-diff 1/10 ``` -If you want to override the embedded input, pass a separate input `.nfpt`: +## File formats -- LayerNorm ε is read from the model header (`layer_norm_eps`). -- `gelu_kind` in the model header selects the GeLU derivative target (`tanh` or `exact`). -- `--delta` sets the local ℓ∞ radius `δ` (default: `0`). Providing `--delta` enables local certification. -- `--partitionDepth` requests input partitioning depth (default: `0`; scaffold only, must remain `0` for now). -- `--input` optionally provides an input `.nfpt` file used for local certification. -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). +### Induction-head certificate -### `head_bounds` - -Computes sound per-head contribution bounds (global weight-only, or local with `--delta`). - -```bash -lake exe nfp head_bounds models/gpt2_rigorous.nfpt ``` - -For local bounds (uses input embeddings in the model file when present): - -```bash -lake exe nfp head_bounds models/gpt2_rigorous.nfpt --delta 0.01 +seq +direction-target +direction-negative +eps +margin +active +prev +score +weight +eps-at +weight-bound +lo +hi +val +val-lo +val-hi ``` -- `--delta` enables local head bounds; `--input` can override the embedded input. -- LayerNorm ε is read from the model header (`layer_norm_eps`). 
-- `--scalePow10` controls fixed-point scaling for global bounds (default: `9`). -- `--output` (`-o`) writes the report to a file (otherwise it prints to stdout). +All sequence indices (`q`, `k`) are **1-based** (literature convention). Direction token IDs +(`direction-target`, `direction-negative`) are raw model IDs (tokenizer convention). +`direction-*` lines are optional metadata; if present, both must appear. If no `active` lines +appear, the checker defaults to all non-initial queries (indices 2.. in 1-based indexing). -### `head_pattern` +### Direction report (untrusted) -Computes a sound local attention pattern bound for a single head (binary only), -propagating per-position intervals up to the target layer (bounded by `maxSeqLen`). -The pattern compares logits for keys whose token matches the query’s offset token -(e.g., `--offset -1` matches the previous token). - -```bash -lake exe nfp head_pattern models/gpt2_rigorous.nfpt --layer 0 --head 0 --delta 0.01 --offset -1 ``` - -- `--offset` selects the target key position relative to the query (default: `-1` for previous token). -- `--maxSeqLen` caps the sequence length analyzed for pattern bounds (default: `256`). -- `--delta` sets the local input radius; LayerNorm ε is read from the model header (`layer_norm_eps`). -- `--tightPattern` enables a slower but tighter pattern bound near the target layer. -- `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). -- `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). -- `--bestMatch` switches to a single-query best-match bound (default query: last position). -- `--sweep` prints best-match bounds for all valid query positions (requires `--bestMatch`). -- `--queryPos` chooses the query position for best-match bounds (default: last position). 
- -### `induction_cert` - -Computes a minimal sound induction-head certificate by combining two pattern -certificates and a value-coordinate lower bound (binary only). - -```bash -lake exe nfp induction_cert models/gpt2_rigorous.nfpt \ - --layer1 0 --head1 0 --layer2 1 --head2 0 --coord 0 --delta 0.01 \ - --target 42 --negative 17 +direction_report +vocab_min= vocab_max= seed= +rank\tlb\ttarget\tnegative ``` -- `--layer1/--head1` selects the previous-token head; `--layer2/--head2` selects the - token-match head. -- `--coord` chooses the output coordinate used for the value lower bound. -- `--offset1/--offset2` adjust the token-match offsets (default: `-1`). -- `--target/--negative` optionally add a logit-diff lower bound using unembedding columns. -- `--tightPattern` enables a slower but tighter pattern bound near the target layer. -- `--tightPatternLayers` sets how many layers use tight bounds (default: `1`; implies `--tightPattern`). -- `--perRowPatternLayers` sets how many layers use per-row MLP propagation (default: `0`). -- `--bestMatch` switches to single-query best-match bounds (default query: last position). -- `--queryPos` chooses the query position for best-match bounds (default: last position). - -### `rope` +This file is an **untrusted helper artifact**; it only ranks candidate directions and does not +change what the Lean checker accepts. -Generates RoPE-related linearization bounds used by the certificate/checking pipeline. +### Token list (untrusted) -```bash -lake exe nfp rope --seqLen 4 --pairs 8 +``` +seq +token ``` -- `--seqLen` instantiates the bound at the given sequence length (default: `4`). -- `--pairs` sets the number of RoPE pairs; the dimension is `2 * pairs` (default: `8`). - -## What “rigorous” means here - -At a high level, the “rigorous” path avoids heuristic operator-norm estimation and instead uses **upper bounds** derived from standard inequalities (examples you may see in logs): - -- Frobenius-norm based bounds. 
-- Gram-matrix based bounds. -- Schur / Brauer-style eigenvalue bounds for symmetric matrices. -- Row-wise softmax operator bounds using quantities like `rowMaxP`, `rowTrace`, Gershgorin-style estimates, and a “moment” bound. - -The CLI may still compute **power-iteration estimates** for comparison, but those are explicitly labelled as diagnostics and are not used to produce the rigorous `ub=…` values. +This file is an **untrusted helper artifact** used to check that `prev` and `active` match the +token sequence (previous-occurrence semantics) when `--tokens` is supplied to the CLI. Indices +are 1-based. -## Reproducing the example command +## Soundness boundary -A typical workflow: +- Untrusted scripts may use floating-point numerics to generate candidate certificates. +- The CLI **only verifies** explicit certificates; it does not search for witnesses or run models. -```bash -# 1) Build -lake update -lake build +For known gaps, see `SOUNDNESS_LIMITATIONS.md`. -# 2) Export a model (optional) -python scripts/export_gpt2.py models/gpt2_rigorous.nfpt +## Requirements -# 3) Run induction search with diagnostics -lake exe nfp induction models/gpt2_rigorous.nfpt -v -d | sed -n '1,220p' -``` +- **Lean 4** (pinned in `lean-toolchain`) and **Lake**. +- Optional: **Python** for helper scripts (`scripts/`), plus `torch`, `transformers`, and `numpy`. -## Project layout +## References -- `Main.lean` — CLI wiring and command definitions. -- `Nfp/` — library code (probability, transformer semantics, soundness/cert machinery, discovery routines). -- `scripts/` — Python helpers to export models and generate induction datasets. -- `models/` — local model files (not versioned here if large; author’s setup may have used local symlinks). 
+- Elhage et al., “A Mathematical Framework for Transformer Circuits.” + Link: `https://transformer-circuits.pub/2021/framework/index.html` +- Olsson et al., “In-context Learning and Induction Heads.” + Link: `https://arxiv.org/abs/2209.11895` -## License +## Contributing -This project is licensed under the GNU Affero General Public License v3.0 or later (AGPL-3.0-or-later). See the LICENSE file. +Please follow the project rules in `AGENTS.md` (no `sorry`, no linter disables, total soundness in +trusted namespaces). diff --git a/SOUNDNESS_LIMITATIONS.md b/SOUNDNESS_LIMITATIONS.md index 260ad98..04bb374 100644 --- a/SOUNDNESS_LIMITATIONS.md +++ b/SOUNDNESS_LIMITATIONS.md @@ -1,33 +1,26 @@ -## SOUNDNESS upgrade status +# SOUNDNESS_LIMITATIONS -This file tracks **current limitations** and **remaining work** for the rigorous -soundness upgrade. It is intentionally brief and human-readable. +This file tracks **current limitations** and **remaining work** for the tabula rasa rewrite. +It is intentionally brief and focused on the soundness boundary. -### Current limitations -- The bridge theorem in `Nfp/Sound/Bridge.lean` links `LayerAmplificationCert` bounds to - `DeepLinearization` residual Jacobians, but it requires external operator-norm assumptions - (LN Jacobians, attention full Jacobian, and MLP factors). The trusted checker does not yet - discharge those assumptions from model weights. -- `partitionDepth > 0` is rejected with an explicit error (no partitioning logic yet). -- Affine arithmetic is only a scaffold (`Nfp/Sound/Affine.lean`) and not wired into SOUND certification. -- Softmax Jacobian bounds typically use probability intervals defaulted to `[0,1]`, so they - reduce to the worst case. Margin-derived tightening is computed by the untrusted path, but - trusted IO currently **rejects nonzero** `softmaxMarginLowerBound` because margin evidence is - unverified. 
-- Best-match pattern certificates now use a margin-derived softmax Jacobian bound with an - effort-indexed `expLB` (scaled Taylor + squaring). The lower-bound correctness of `expLB` - is not yet formalized in Lean. -- GeLU derivative bounds are conservative envelopes; the exact interval supremum is not computed yet. -- Attention Jacobian bounds now include an explicit pattern-term coefficient using max `W_Q/W_K` - row-sum norms and a conservative LayerNorm output magnitude bound (`max|gamma|*sqrt(d)+max|beta|`), - but this is still very conservative and only connected to the Lean Jacobian theorems - under the external norm assumptions above. +## Current limitations -### Remaining work -- Implement input-space partitioning in the SOUND local path and plumb it through the certify pipeline. -- Replace or augment interval propagation with affine forms to preserve correlations. -- Add sound probability interval extraction for softmax (requires sound exp/log-sum-exp bounds). -- Verify or compute margin evidence in the trusted path so margin-derived softmax tightening can be enabled. -- Tighten GeLU derivative envelopes to the exact interval supremum if desired. -- Discharge the bridge theorem’s component-norm assumptions from certificates/model weights, and - connect the resulting statement to the `Linearization` Jacobian theorems. +- The trusted CLI only **checks explicit certificates**; it does not search for witnesses or + run model evaluation. +- Induction certificates are **head-level** (softmax-margin + value-interval + logit-diff lower + bound) and conditional on the supplied `prev`, `active`, and `direction` inputs. They do **not** + yet imply full model behavior. +- Direction metadata (`direction-target`, `direction-negative`) is untrusted and assumes that the + unembedding columns represent token logits. +- Any direction search performed by Python helpers is untrusted witness generation; only the + resulting explicit certificate is checked by the Lean CLI. 
+- The active set is user-supplied (or defaulted by the parser); bounds only hold for + `q ∈ active`. You can optionally verify `prev`/`active` against a token list via + `nfp induction certify --tokens ...`. +- Performance: checking large certificates can be expensive for long sequences. + +## Remaining work + +- Prove or verify that `prev`, `active`, and `direction` are derived from token-level semantics. +- Add a verified extraction pipeline from model weights to explicit certificates. +- Extend the bridge from head-level certificates to full circuit/model semantics. diff --git a/TheoremAxioms.lean b/TheoremAxioms.lean new file mode 100644 index 0000000..5477f81 --- /dev/null +++ b/TheoremAxioms.lean @@ -0,0 +1,34 @@ +-- SPDX-License-Identifier: AGPL-3.0-or-later + +module + +import Nfp + +/-! +Axioms used by key definitions/lemmas. +These `#print axioms` lines help ensure we only depend on a small set of axioms +(ideally a subset of: `propext`, `Classical.choice`, `Quot.sound`). +-/ + +public section + +#print axioms Nfp.ProbVec.sum_mass +#print axioms Nfp.ProbVec.pure +#print axioms Nfp.ProbVec.mix +#print axioms Nfp.Mixer.push +#print axioms Nfp.Mixer.comp +#print axioms Nfp.Mixer.id +#print axioms Nfp.Dag.parents +#print axioms Nfp.LocalSystem.toMixer +#print axioms Nfp.LocalSystem.eval +#print axioms Nfp.LocalSystem.eval_eq +#print axioms Nfp.Circuit.eval +#print axioms Nfp.Circuit.evalInput +#print axioms Nfp.Circuit.Interface.eval +#print axioms Nfp.Circuit.checkEquiv +#print axioms Nfp.Circuit.checkEquivOnInterface + +/-- Entrypoint for the axiom report build target. -/ +def main : IO Unit := pure () + +end diff --git a/docs/cert_usefulness.md b/docs/cert_usefulness.md new file mode 100644 index 0000000..700d63f --- /dev/null +++ b/docs/cert_usefulness.md @@ -0,0 +1,39 @@ +# What Induction-Head Certificates Do (and Do Not) Claim + +This document summarizes the **useful, limited guarantees** provided by an induction‑head +certificate, without overselling. 
+
+## What the certificate **does** guarantee
+
+If the Lean checker accepts a certificate, then:
+- **Softmax‑margin bounds** hold on the specified active queries (the `prev` score dominates other
+  keys by the declared margin).
+- **One‑hot‑style weight bounds** hold on those queries (non‑`prev` weights are bounded by `ε`).
+- **Value‑interval bounds** hold for the supplied value ranges.
+- **Logit‑diff lower bound** holds for the supplied direction (if direction metadata is present
+  and the checker is run with `--min-logit-diff`).
+
+These are **formal, exact** statements about the explicit certificate data.
+
+## Why this is useful
+
+- **Quantitative guarantees:** the bounds are numeric and can be gated (e.g., require a strictly
+  positive logit‑diff lower bound).
+- **Reproducibility:** certificates are explicit artifacts that can be re‑checked later.
+- **Comparability:** bounds provide a principled way to compare heads or settings.
+- **Soundness boundary clarity:** generation is untrusted; verification is trusted.
+
+## What the certificate **does not** guarantee
+
+- **No full‑model claim:** this is a head‑level certificate; it does not imply end‑to‑end model
+  behavior.
+- **Input‑specific:** guarantees apply only to the specified inputs / token patterns.
+- **Untrusted semantics:** unless you pass `--tokens`, the `prev` and `active` sets are not
+  verified against a token sequence.
+- **Direction is untrusted:** `direction-target` / `direction-negative` are supplied metadata.
+
+## Optional token verification
+
+If you provide a token list to the CLI (`--tokens`), the checker verifies that `prev` and `active`
+match **previous‑occurrence semantics** for that token sequence. This strengthens the link to the
+induction‑head diagnostic while keeping the trusted checker lightweight. 
diff --git a/docs/demo.md b/docs/demo.md new file mode 100644 index 0000000..bc1400c --- /dev/null +++ b/docs/demo.md @@ -0,0 +1,61 @@ +# NFP User-Facing Demo (Induction Head Certification) + +This demo shows a full, reproducible path from **untrusted certificate generation** +to **trusted Lean verification** for an induction head. + +## 0. Prerequisites + +- Lean 4 / Lake (pinned in `lean-toolchain`) +- Python with: `numpy`, `torch`, `transformers` + +## 1. Build + +```bash +lake build -q --wfail +lake build nfp -q --wfail +``` + +## 2. Generate an explicit certificate (untrusted) + +This uses a repeated pattern (period repeated twice) and searches for a +non-vacuous logit-diff direction. + +```bash +python scripts/build_gpt2_induction_cert.py \ + --output reports/gpt2_induction.cert \ + --layer 1 --head 6 --seq 32 --pattern-length 16 \ + --random-pattern --seed 0 \ + --active-eps-max 1/2 \ + --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000 \ + --direction-min-lb 1/10 \ + --direction-report-out reports/direction_report.txt --direction-topk 10 \ + --tokens-out reports/gpt2_induction.tokens +``` + +Expected output includes: +- a certificate (`reports/gpt2_induction.cert`) +- a ranked direction report (`reports/direction_report.txt`) +- a token list (`reports/gpt2_induction.tokens`) + +## 3. Verify with the Lean checker (trusted) + +This enforces a **non-vacuous** logit-diff lower bound and checks `prev/active` +against the token list. + +```bash +lake exe nfp induction certify \ + --cert reports/gpt2_induction.cert \ + --min-logit-diff 1/10 \ + --tokens reports/gpt2_induction.tokens +``` + +Expected output (example): + +``` +ok: induction head certificate checked (seq=32, active=15, margin=..., eps=..., logitDiffLB=...) +``` + +## Notes + +- Everything in `scripts/` is **untrusted witness generation**. +- The Lean CLI **only verifies** explicit certificates and token semantics. 
diff --git a/docs/induction_cert_audit.md b/docs/induction_cert_audit.md new file mode 100644 index 0000000..eafb351 --- /dev/null +++ b/docs/induction_cert_audit.md @@ -0,0 +1,78 @@ +# Induction Head Certification Audit + +Goal: assess whether the current Lean proofs justify the claim that we can certify +induction heads, and spell out the scope and limitations of that claim. + +## Formal proof chain (Lean) + +- Explicit induction-head certificates are parsed from text in + `Nfp/IO/InductionHead/Cert.lean` (sequence indices are 1-based in the payload). +- `checkInductionHeadCert` and `checkInductionHeadCert_sound` show that a + passing certificate satisfies `InductionHeadCertBounds` + (`Nfp/Circuit/Cert/InductionHead.lean`). +- `logitDiffLowerBoundAt` plus `logitDiffLowerBoundAt_le` give a certified lower + bound on the logit-diff contribution derived from the certificate’s values + (`Nfp/Circuit/Cert/LogitDiff.lean`). +- `headLogitDiff_eq_direction_dot_headOutput` connects the logit-diff definition + to head-output semantics (`Nfp/Sound/Induction/LogitDiff.lean`). + +## Mechanistic mapping (Transformer Circuits) + +The mechanistic induction-head story is a QK/OV decomposition: +- QK: identify a matching prior token (prefix-matching attention). +- OV: write the continuation token (or logit-diff direction) into the residual stream. + +The certificate aligns to that decomposition: +- The softmax-margin bounds constrain the QK pattern so that attention to the + chosen `prev` index dominates other keys (mechanistic “prefix match”). +- The value-interval bounds and logit-diff lower bound constrain the OV path in + the chosen direction, so the head’s contribution increases the target logit + relative to the negative logit. + +This is direct mechanistic evidence in the Transformer Circuits sense: it ties +parameters (Q/K/V/O + LayerNorm) to certified bounds on attention and value +contributions, but only for the specific inputs and direction supplied. 
+ +## Literature alignment + +We follow the standard induction-head diagnostic setup from the literature: +repeated-token sequences (a pattern repeated twice) and attention stripes that +look back by one period (`q -> q - period`). The diagnostic script +`scripts/diagnose_induction_heads.py` mirrors this setup, and the certificate +generator uses repeated patterns for its inputs. + +## Preconditions and scope limits + +These proofs are sufficient for a **conditional** certification claim: +if the explicit certificate passes the checker, then the head-level bounds hold. +They are **not** sufficient for a global claim that a head “is an induction head” +without additional assumptions. + +Key assumptions and limitations: +- `prev`, `active`, and `direction` are user-supplied or produced by untrusted + scripts; Lean does not (yet) verify their derivation from token-level semantics. +- The active set can be strict; bounds only hold for `q ∈ active`, not all positions. +- The direction metadata assumes the unembedding columns encode the model’s logit map. + +Optional safeguard: +- If a token list is supplied to the CLI (`--tokens`), the checker verifies that `prev` + and `active` match the previous-occurrence semantics for that sequence. + +## Conclusion + +Yes—**within the formal scope** of the current definitions, the proofs are +enough to claim that we can certify induction-head behavior at the head level: +they certify attention to a specified `prev` index and a logit-diff lower bound +along a specified direction, conditional on an explicit certificate. + +## Next steps + +- Add a verified extraction pipeline from model weights to explicit certificates. +- Prove that `prev`, `active`, and `direction` correspond to token-level semantics. 
+ +## References + +- Elhage et al., “A Mathematical Framework for Transformer Circuits.” + Link: `https://transformer-circuits.pub/2021/framework/index.html` +- Olsson et al., “In-context Learning and Induction Heads.” + Link: `https://arxiv.org/abs/2209.11895` diff --git a/lakefile.toml b/lakefile.toml index 9973d13..29ba710 100644 --- a/lakefile.toml +++ b/lakefile.toml @@ -7,10 +7,24 @@ defaultTargets = ["Nfp"] pp.unicode.fun = true # pretty-prints `fun a ↦ b` autoImplicit = false relaxedAutoImplicit = false +experimental.module = true weak.linter.mathlibStandardSet = true maxSynthPendingDepth = 3 linter.unusedVariables = true weak.linter.unreachableTactic = true +weak.linter.missingDocs = true +weak.linter.unusedTactic = true +weak.linter.omit = true +weak.linter.style.longFile = 1500 +weak.linter.style.setOption = true +weak.linter.style.missingEnd = true +weak.linter.style.openClassical = true +weak.linter.style.show = true +weak.linter.style.lambdaSyntax = true +weak.linter.style.dollarSyntax = true +weak.linter.style.cdot = true +weak.linter.style.longLine = true +weak.linter.nfp.noHeartbeats = true [[require]] name = "mathlib" @@ -19,7 +33,12 @@ rev = "stable" [[lean_lib]] name = "Nfp" +roots = ["Nfp"] [[lean_exe]] name = "nfp" root = "Main" + +[[lean_exe]] +name = "theorem-axioms" +root = "TheoremAxioms" diff --git a/scripts/build_gpt2_induction_cert.py b/scripts/build_gpt2_induction_cert.py new file mode 100644 index 0000000..d4e926d --- /dev/null +++ b/scripts/build_gpt2_induction_cert.py @@ -0,0 +1,530 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Build an induction-head certificate for a GPT-2-small induction head. + +This script is untrusted and uses floating-point arithmetic to produce a +rational induction-head certificate compatible with `nfp induction certify`. +Active induction positions are recorded as `active ` lines in the output. +All sequence indices in the certificate are 1-based (literature convention). 
+
+The repeated-pattern inputs match the standard induction-head diagnostic setup
+from the literature (pattern repeated twice).
+
+Usage:
+    python scripts/build_gpt2_induction_cert.py --output reports/gpt2_induction.cert \
+        --layer 1 --head 6 --seq 32 --pattern-length 16 \
+        --random-pattern --seed 0 \
+        --values-out reports/gpt2_induction.values --value-dim 0 \
+        --active-eps-max 0.2 --min-margin 0
+
+Optionally, provide a logit-diff direction (both flags take a token id):
+    --direction-target TARGET_ID --direction-negative NEGATIVE_ID
+
+Or ask the script to search for a direction (untrusted):
+    --search-direction --direction-vocab-min 1000 --direction-vocab-max 2000
+    --direction-report-out reports/direction_report.txt --direction-topk 10
+
+Optional token dump (for Lean-side prev/active verification):
+    --tokens-out reports/gpt2_induction.tokens
+
+Note: active positions are filtered by --active-eps-max and --min-margin. If
+none qualify, the script exits with an error.
+Layer/head indices are 1-based in the CLI and converted to 0-based internally.
+Direction token IDs use the model's raw tokenizer indexing.
+"""
+
+import argparse
+import math
+from fractions import Fraction
+from pathlib import Path
+
+import numpy as np
+
+try:
+    import torch
+    from transformers import GPT2Model
+except ImportError:
+    raise SystemExit(
+        "Missing dependencies. 
Install with: uv add transformers torch" + ) + + +def rat_from_float(x: float, decimals: int) -> Fraction: + scale = 10 ** decimals + return Fraction(int(round(x * scale)), scale) + + +def rat_to_str(q: Fraction) -> str: + if q.denominator == 1: + return str(q.numerator) + return f"{q.numerator}/{q.denominator}" + + +def build_tokens(seq: int, pattern_len: int, random_pattern: bool, seed: int) -> np.ndarray: + if random_pattern: + rng = np.random.default_rng(seed) + pattern = rng.integers(1000, 30000, size=pattern_len, endpoint=False) + else: + pattern = np.arange(pattern_len) + repeats = (seq // pattern_len) + 1 + return np.tile(pattern, repeats)[:seq] + + +def build_prev(tokens: np.ndarray) -> tuple[np.ndarray, np.ndarray]: + prev = np.zeros_like(tokens) + active = np.zeros_like(tokens, dtype=bool) + last_seen = {} + for idx, tok in enumerate(tokens): + if idx == 0: + prev[idx] = 0 + active[idx] = False + else: + if tok in last_seen: + prev[idx] = last_seen[tok] + active[idx] = True + else: + prev[idx] = 0 + active[idx] = False + last_seen[tok] = idx + return prev, active + + +def compute_scores_weights(model, input_ids, layer: int, head: int, device: str): + model.eval() + with torch.no_grad(): + outputs = model(input_ids, output_hidden_states=True) + hidden_states = outputs.hidden_states[layer] + block = model.h[layer] + x = block.ln_1(hidden_states) + qkv = block.attn.c_attn(x) + n_head = model.config.n_head + head_dim = qkv.shape[-1] // (3 * n_head) + q, k, _v = qkv.split(n_head * head_dim, dim=2) + q = q.view(1, -1, n_head, head_dim).transpose(1, 2) + k = k.view(1, -1, n_head, head_dim).transpose(1, 2) + v = _v.view(1, -1, n_head, head_dim).transpose(1, 2) + qh = q[:, head] + kh = k[:, head] + vh = v[:, head] + scores = torch.matmul(qh, kh.transpose(-2, -1)) / math.sqrt(head_dim) + seq = scores.shape[-1] + mask = torch.triu(torch.ones(seq, seq, device=device), diagonal=1).bool() + scores = scores.masked_fill(mask, -10000.0) + weights = torch.softmax(scores, 
dim=-1) + return (scores.squeeze(0).cpu().numpy(), + weights.squeeze(0).cpu().numpy(), + vh.squeeze(0).cpu().numpy()) + + +def write_scores(path: Path, seq: int, prev: np.ndarray, scores, weights, eps=None, margin=None, + active=None): + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if eps is not None: + f.write(f"eps {rat_to_str(eps)}\n") + if margin is not None: + f.write(f"margin {rat_to_str(margin)}\n") + if active is not None: + for q in active: + f.write(f"active {q + 1}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q + 1} {k + 1}\n") + for q in range(seq): + for k in range(seq): + f.write(f"score {q + 1} {k + 1} {rat_to_str(scores[q][k])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight {q + 1} {k + 1} {rat_to_str(weights[q][k])}\n") + +def write_induction_cert(path: Path, seq: int, prev: np.ndarray, scores, weights, + eps, margin, active, eps_at, weight_bound_at, + vals, direction_target=None, direction_negative=None) -> None: + lo = min(vals) + hi = max(vals) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") + f.write(f"eps {rat_to_str(eps)}\n") + f.write(f"margin {rat_to_str(margin)}\n") + if active is not None: + for q in active: + f.write(f"active {q + 1}\n") + for q, k in enumerate(prev.tolist()): + f.write(f"prev {q + 1} {k + 1}\n") + for q in range(seq): + for k in range(seq): + f.write(f"score {q + 1} {k + 1} {rat_to_str(scores[q][k])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight {q + 1} {k + 1} {rat_to_str(weights[q][k])}\n") + for q in range(seq): + f.write(f"eps-at {q + 1} {rat_to_str(eps_at[q])}\n") + for q in range(seq): + for k in range(seq): + f.write(f"weight-bound {q + 1} {k + 1} {rat_to_str(weight_bound_at[q][k])}\n") + f.write(f"lo {rat_to_str(lo)}\n") + 
f.write(f"hi {rat_to_str(hi)}\n") + for k, val in enumerate(vals): + val_str = rat_to_str(val) + f.write(f"val {k + 1} {val_str}\n") + f.write(f"val-lo {k + 1} {val_str}\n") + f.write(f"val-hi {k + 1} {val_str}\n") + + +def write_value_range(path: Path, seq: int, values, decimals: int, + direction_target=None, direction_negative=None) -> None: + vals_rat = [rat_from_float(float(values[k]), decimals) for k in range(seq)] + lo = min(vals_rat) + hi = max(vals_rat) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {seq}\n") + if direction_target is not None and direction_negative is not None: + f.write(f"direction-target {direction_target}\n") + f.write(f"direction-negative {direction_negative}\n") + f.write(f"lo {rat_to_str(lo)}\n") + f.write(f"hi {rat_to_str(hi)}\n") + for k, val in enumerate(vals_rat): + f.write(f"val {k + 1} {rat_to_str(val)}\n") + + +def write_tokens(path: Path, tokens: np.ndarray) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="ascii") as f: + f.write(f"seq {len(tokens)}\n") + for idx, tok in enumerate(tokens.tolist(), start=1): + f.write(f"token {idx} {tok}\n") + +def search_direction( + model, + values: np.ndarray, + layer: int, + head: int, + prev: np.ndarray, + active_positions: list[int], + eps_at: list[float], + vocab_min: int, + vocab_max: int, + max_candidates: int | None, + topk: int, + seed: int, +) -> tuple[int, int, float, list[tuple[float, int, int]]]: + import heapq + + if vocab_min < 0 or vocab_max <= vocab_min: + raise SystemExit("direction vocab range must satisfy 0 <= min < max") + cand_ids = list(range(vocab_min, vocab_max)) + if max_candidates is not None and max_candidates > 0 and len(cand_ids) > max_candidates: + rng = np.random.default_rng(seed) + cand_ids = rng.choice(cand_ids, size=max_candidates, replace=False).tolist() + + wte = model.wte.weight.detach().cpu().numpy() + if max(cand_ids) >= wte.shape[0]: + raise SystemExit("direction vocab range exceeds model vocab 
size") + head_dim = model.config.n_embd // model.config.n_head + start, end = head * head_dim, (head + 1) * head_dim + w_o = model.h[layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + + proj = wte[cand_ids] @ w_o # (m, head_dim) + vals_mat = values @ proj.T # (seq, m) + active_prev = [(q, int(prev[q])) for q in active_positions] + best_lb = None + best_pair = None + topk_entries: list[tuple[float, int, int]] = [] + + for i in range(vals_mat.shape[1]): + vals_i = vals_mat[:, i] + for j in range(vals_mat.shape[1]): + if i == j: + continue + vals = vals_i - vals_mat[:, j] + lo = float(vals.min()) + hi = float(vals.max()) + gap = hi - lo + lb = None + for q, pq in active_prev: + cand = float(vals[pq]) - eps_at[q] * gap + if lb is None or cand < lb: + lb = cand + if lb is None: + continue + if best_lb is None or lb > best_lb: + best_lb = lb + best_pair = (cand_ids[i], cand_ids[j]) + if topk > 0: + # Preserve the best few candidates for reporting. + if len(topk_entries) < topk: + heapq.heappush(topk_entries, (lb, cand_ids[i], cand_ids[j])) + else: + if lb > topk_entries[0][0]: + heapq.heapreplace(topk_entries, (lb, cand_ids[i], cand_ids[j])) + + if best_pair is None or best_lb is None: + raise SystemExit("No direction candidates found") + topk_sorted = sorted(topk_entries, key=lambda x: x[0], reverse=True) + return best_pair[0], best_pair[1], best_lb, topk_sorted + + +def write_direction_report( + path: Path, + entries: list[tuple[float, int, int]], + vocab_min: int, + vocab_max: int, + seed: int, +) -> None: + path.parent.mkdir(parents=True, exist_ok=True) + with path.open("w", encoding="ascii") as f: + f.write("direction_report\n") + f.write(f"vocab_min={vocab_min} vocab_max={vocab_max} seed={seed}\n") + f.write("rank\tlb\ttarget\tnegative\n") + for rank, (lb, target, negative) in enumerate(entries, start=1): + f.write(f"{rank}\t{lb:.6f}\t{target}\t{negative}\n") + + +def main() -> None: + parser = argparse.ArgumentParser(description=__doc__) + 
parser.add_argument("--output", required=True, help="Path to write certificate") + parser.add_argument("--scores-out", help="Optional path for raw scores/weights dump") + parser.add_argument("--layer", type=int, default=1, + help="Transformer layer index (1-based)") + parser.add_argument("--head", type=int, default=1, + help="Attention head index (1-based)") + parser.add_argument("--seq", type=int, default=32, help="Sequence length") + parser.add_argument("--pattern-length", type=int, default=16, help="Pattern length") + parser.add_argument("--random-pattern", action="store_true", help="Use random token pattern") + parser.add_argument("--seed", type=int, default=0, help="RNG seed for random pattern") + parser.add_argument("--decimals", type=int, default=6, help="Decimal rounding for rationals") + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--device", default="cpu", help="Torch device") + parser.add_argument("--values-out", help="Optional path for a value-range certificate") + parser.add_argument("--value-dim", type=int, default=0, + help="Value dimension index for the value-range certificate") + parser.add_argument("--tokens-out", help="Optional path to write the token list") + parser.add_argument("--active-eps-max", default="1/2", + help="Maximum eps to include an active position (default: 1/2).") + parser.add_argument("--min-margin", default="0", + help="Minimum score gap required for an active position (default: 0).") + parser.add_argument("--direction-target", type=int, + help="Target token id for logit-diff direction (optional)") + parser.add_argument("--direction-negative", type=int, + help="Negative token id for logit-diff direction (optional)") + parser.add_argument("--search-direction", action="store_true", + help="Search for a direction within the vocab range (untrusted).") + parser.add_argument("--direction-vocab-min", type=int, default=1000, + help="Minimum vocab id for direction search 
(inclusive).") + parser.add_argument("--direction-vocab-max", type=int, default=2000, + help="Maximum vocab id for direction search (exclusive).") + parser.add_argument("--direction-max-candidates", type=int, default=0, + help="Limit number of direction candidates (0 = all in range).") + parser.add_argument("--direction-min-lb", default="0", + help="Minimum logit-diff lower bound to accept (default: 0).") + parser.add_argument("--direction-report-out", type=Path, + help="Optional path to write a ranked direction report.") + parser.add_argument("--direction-topk", type=int, default=10, + help="How many top directions to report (default: 10).") + args = parser.parse_args() + + if args.seq <= 0: + raise SystemExit("seq must be positive") + + tokens = build_tokens(args.seq, args.pattern_length, args.random_pattern, args.seed) + prev, active_mask = build_prev(tokens) + candidate_positions = [int(i) for i, flag in enumerate(active_mask) if flag] + + if args.layer <= 0: + raise SystemExit("layer must be >= 1") + if args.head <= 0: + raise SystemExit("head must be >= 1") + + model = GPT2Model.from_pretrained(args.model) + layer = args.layer - 1 + head = args.head - 1 + if layer >= model.config.n_layer: + raise SystemExit(f"layer must be in [1, {model.config.n_layer}]") + if head >= model.config.n_head: + raise SystemExit(f"head must be in [1, {model.config.n_head}]") + model.to(args.device) + input_ids = torch.tensor(tokens, dtype=torch.long, device=args.device).unsqueeze(0) + + scores, weights, values = compute_scores_weights(model, input_ids, layer, head, args.device) + + scores_rat = [[rat_from_float(float(scores[q, k]), args.decimals) for k in range(args.seq)] + for q in range(args.seq)] + weights_rat = [[rat_from_float(float(weights[q, k]), args.decimals) for k in range(args.seq)] + for q in range(args.seq)] + + for q in range(args.seq): + total = sum(weights_rat[q], Fraction(0)) + if total == 0: + raise SystemExit(f"zero weight sum at q={q}") + weights_rat[q] = [w / 
total for w in weights_rat[q]] + + eps = Fraction(0) + margin = None + eps_by_q: dict[int, Fraction] = {} + margin_by_q: dict[int, Fraction] = {} + for q in range(args.seq): + prev_q = prev[q] + prev_w = weights_rat[q][prev_q] + if args.seq == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) + deficit = Fraction(1) - prev_w + eps_by_q[q] = max(max_other, deficit) + + diffs = [scores_rat[q][prev_q] - scores_rat[q][k] + for k in range(args.seq) if k != prev_q] + margin_by_q[q] = min(diffs) if diffs else Fraction(0) + + eps_threshold = Fraction(args.active_eps_max) + min_margin = Fraction(args.min_margin) + active_positions = [ + q for q in candidate_positions + if eps_by_q[q] <= eps_threshold and margin_by_q[q] >= min_margin + ] + if not active_positions: + raise SystemExit( + "No active positions satisfy active-eps-max/min-margin. " + "Try a different head/layer, random-pattern/seed, or relax the thresholds." + ) + if active_positions: + eps = max(eps_by_q[q] for q in active_positions) + margin = min((margin_by_q[q] for q in active_positions), default=Fraction(0)) + else: + margin = Fraction(0) + + if candidate_positions: + print(f"Active positions: {len(active_positions)}/{len(candidate_positions)}") + + eps_at = [] + for q in range(args.seq): + prev_q = prev[q] + if args.seq == 1: + max_other = Fraction(0) + else: + max_other = max(weights_rat[q][k] for k in range(args.seq) if k != prev_q) + deficit = Fraction(1) - weights_rat[q][prev_q] + eps_at.append(max(max_other, deficit)) + weight_bound_at = weights_rat + + if args.search_direction and (args.direction_target is not None or args.direction_negative is not None): + raise SystemExit("search-direction is mutually exclusive with explicit direction tokens") + if args.direction_report_out is not None and not args.search_direction: + raise SystemExit("direction-report-out requires --search-direction") + + direction_target = None + direction_negative = 
None + if args.search_direction: + eps_at_float = [float(eps_at[q]) for q in range(args.seq)] + max_candidates = args.direction_max_candidates + if max_candidates == 0: + max_candidates = None + direction_target, direction_negative, best_lb, topk_entries = search_direction( + model=model, + values=values, + layer=layer, + head=head, + prev=prev, + active_positions=active_positions, + eps_at=eps_at_float, + vocab_min=args.direction_vocab_min, + vocab_max=args.direction_vocab_max, + max_candidates=max_candidates, + topk=args.direction_topk, + seed=args.seed, + ) + if args.direction_report_out is not None: + write_direction_report( + args.direction_report_out, + topk_entries, + args.direction_vocab_min, + args.direction_vocab_max, + args.seed, + ) + try: + min_lb = float(Fraction(args.direction_min_lb)) + except (ValueError, ZeroDivisionError) as exc: + raise SystemExit("direction-min-lb must be a rational literal") from exc + if best_lb < min_lb: + raise SystemExit( + f"Best direction lower bound {best_lb:.6f} below minimum {min_lb:.6f}." 
+ ) + print( + f"Selected direction: target={direction_target} negative={direction_negative} " + f"(estimated LB {best_lb:.6f})" + ) + + if (args.direction_target is None) != (args.direction_negative is None): + raise SystemExit("direction-target and direction-negative must be provided together") + if args.direction_target is not None: + direction_target = args.direction_target + direction_negative = args.direction_negative + + if direction_target is not None: + wte = model.wte.weight.detach().cpu().numpy() + if direction_target < 0 or direction_target >= wte.shape[0]: + raise SystemExit("direction-target out of vocab range") + if direction_negative < 0 or direction_negative >= wte.shape[0]: + raise SystemExit("direction-negative out of vocab range") + direction = wte[direction_target] - wte[direction_negative] + head_dim = model.config.n_embd // model.config.n_head + start, end = head * head_dim, (head + 1) * head_dim + w_o = model.h[layer].attn.c_proj.weight.detach().cpu().numpy()[:, start:end] + dir_head = w_o.T @ direction + vals = values @ dir_head + else: + if args.value_dim < 0 or args.value_dim >= values.shape[1]: + raise SystemExit(f"value-dim must be in [0, {values.shape[1] - 1}]") + vals = values[:, args.value_dim] + + vals_rat = [rat_from_float(float(vals[k]), args.decimals) for k in range(args.seq)] + + output_path = Path(args.output) + output_path.parent.mkdir(parents=True, exist_ok=True) + write_induction_cert( + output_path, + args.seq, + prev, + scores_rat, + weights_rat, + eps, + margin, + active_positions, + eps_at, + weight_bound_at, + vals_rat, + direction_target=direction_target, + direction_negative=direction_negative, + ) + + if args.scores_out: + scores_path = Path(args.scores_out) + scores_path.parent.mkdir(parents=True, exist_ok=True) + write_scores(scores_path, args.seq, prev, scores_rat, weights_rat, + active=active_positions) + + if args.values_out: + values_path = Path(args.values_out) + values_path.parent.mkdir(parents=True, 
exist_ok=True) + write_value_range(values_path, args.seq, vals, args.decimals, + direction_target=direction_target, + direction_negative=direction_negative) + if args.tokens_out: + tokens_path = Path(args.tokens_out) + write_tokens(tokens_path, tokens) + + print(f"Wrote certificate to {output_path}") + if args.scores_out: + print(f"Wrote scores dump to {scores_path}") + if args.values_out: + print(f"Wrote value-range certificate to {values_path}") + if args.tokens_out: + print(f"Wrote token list to {tokens_path}") + + +if __name__ == "__main__": + main() diff --git a/scripts/ci_sanity_forward_step.py b/scripts/ci_sanity_forward_step.py deleted file mode 100644 index a2668dd..0000000 --- a/scripts/ci_sanity_forward_step.py +++ /dev/null @@ -1,206 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -CI sanity check: export a tiny GPT-2-like model to `.nfpt` and compare one forward step -against PyTorch. - -This is meant to be cheap enough for CI while still catching semantic drift (e.g. missing -attention biases, wrong GeLU constant, etc.). - -Usage (CI): - python3 scripts/ci_sanity_forward_step.py --model sshleifer/tiny-gpt2 --seqLen 8 -""" - -from __future__ import annotations - -import argparse -import subprocess -import tempfile -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError as e: - raise SystemExit( - "Missing deps. 
Install with e.g.: pip install torch transformers\n" - f"ImportError: {e}" - ) - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().cpu().numpy() - - -def export_nfpt(model: GPT2Model, seq_len: int, out_path: Path) -> np.ndarray: - cfg = model.config - num_layers = int(cfg.n_layer) - num_heads = int(cfg.n_head) - model_dim = int(cfg.n_embd) - head_dim = model_dim // num_heads - hidden_dim = int(cfg.n_inner or 4 * model_dim) - vocab_size = int(cfg.vocab_size) - layer_norm_eps = float(cfg.layer_norm_epsilon) - - # Deterministic token sequence. - tokens = (np.arange(seq_len, dtype=np.int64) % min(32, vocab_size)).astype(np.int64) - - # GPT-2 embeddings: wte[token] + wpe[pos]. 
- wte = model.wte.weight.detach().cpu().numpy() - wpe = model.wpe.weight.detach().cpu().numpy() - embeddings = wte[tokens] + wpe[:seq_len] - - with out_path.open("wb") as f: - write_header( - f, - num_layers=num_layers, - num_heads=num_heads, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=hidden_dim, - vocab_size=vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, tokens) - write_f64(f, embeddings) - - for layer_idx in range(num_layers): - block = model.h[layer_idx] - - c_attn_w = block.attn.c_attn.weight.detach().cpu().numpy() # (d, 3d) - c_attn_b = get_bias(block.attn.c_attn.bias, 3 * model_dim) # (3d,) - c_proj_w = block.attn.c_proj.weight.detach().cpu().numpy() # (d, d) - c_proj_b = get_bias(block.attn.c_proj.bias, model_dim) # (d,) - - W_Q_all = c_attn_w[:, 0:model_dim] - W_K_all = c_attn_w[:, model_dim : 2 * model_dim] - W_V_all = c_attn_w[:, 2 * model_dim : 3 * model_dim] - b_Q_all = c_attn_b[0:model_dim] - b_K_all = c_attn_b[model_dim : 2 * model_dim] - b_V_all = c_attn_b[2 * model_dim : 3 * model_dim] - - for h in range(num_heads): - start, end = h * head_dim, (h + 1) * head_dim - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) - write_f64(f, c_proj_w[start:end, :]) - - write_f64(f, c_proj_b) - - write_f64(f, block.mlp.c_fc.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, hidden_dim)) - write_f64(f, block.mlp.c_proj.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, model_dim)) - - write_f64(f, block.ln_1.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.ln_1.bias, model_dim)) - write_f64(f, block.ln_2.weight.detach().cpu().numpy()) - write_f64(f, get_bias(block.ln_2.bias, model_dim)) - - write_f64(f, model.ln_f.weight.detach().cpu().numpy()) - write_f64(f, 
get_bias(model.ln_f.bias, model_dim)) - write_f64(f, wte.T) - - return tokens - - -def run_lean_dump(model_path: Path, layer: int, pos: int, take: int) -> np.ndarray: - exe = Path(".lake/build/bin/nfp") - if not exe.exists(): - raise SystemExit("Missing `.lake/build/bin/nfp`. Run `lake build nfp` first.") - cmd = [ - str(exe), - "dump", - "--kind", - "afterLayer", - "--layer", - str(layer), - "--pos", - str(pos), - "--take", - str(take), - str(model_path), - ] - out = subprocess.check_output(cmd, text=True) - lines = [ln.strip() for ln in out.splitlines() if ln.strip()] - dump_i = None - for i, ln in enumerate(lines): - if ln.startswith("DUMP "): - dump_i = i - if dump_i is None or dump_i + 2 >= len(lines): - raise SystemExit(f"Unexpected Lean dump output:\n{out}") - vals = [float(x) for x in lines[dump_i + 2].split()] - return np.asarray(vals, dtype=np.float64) - - -def run_torch(model: GPT2Model, tokens: np.ndarray, layer: int, pos: int, take: int) -> np.ndarray: - input_ids = torch.tensor(tokens.reshape(1, -1), dtype=torch.long) - with torch.no_grad(): - out = model(input_ids=input_ids, output_hidden_states=True) - hs = out.hidden_states - if hs is None: - raise SystemExit("Expected output_hidden_states=True to return hidden_states.") - x = hs[layer + 1][0, pos, :take].detach().cpu().numpy() - return x.astype(np.float64) - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("--model", default="sshleifer/tiny-gpt2") - ap.add_argument("--seqLen", type=int, default=8) - ap.add_argument("--layer", type=int, default=0) - ap.add_argument("--pos", type=int, default=0) - ap.add_argument("--take", type=int, default=16) - ap.add_argument("--tol", type=float, default=1e-3) - args = ap.parse_args() - - m = GPT2Model.from_pretrained(args.model) - m = m.to(dtype=torch.float64) - m.eval() - - with tempfile.TemporaryDirectory() as td: - path = Path(td) / "ci_tiny.nfpt" - toks = export_nfpt(m, args.seqLen, path) - lean = run_lean_dump(path, args.layer, 
args.pos, args.take) - torch_x = run_torch(m, toks, args.layer, args.pos, args.take) - diff = float(np.max(np.abs(lean - torch_x))) - print(f"max_abs_diff={diff:.6g} tol={args.tol:.6g}") - if not np.isfinite(diff) or diff > args.tol: - print("FAIL") - print("lean :", lean.tolist()) - print("torch:", torch_x.tolist()) - raise SystemExit(1) - print("OK") - - -if __name__ == "__main__": - main() diff --git a/scripts/convert_text_fixture_to_binary.py b/scripts/convert_text_fixture_to_binary.py deleted file mode 100644 index 82d8751..0000000 --- a/scripts/convert_text_fixture_to_binary.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -from __future__ import annotations - -import argparse -import struct -from pathlib import Path -from typing import Callable, Dict, List, Tuple - - -def parse_header(lines: List[str], path: Path) -> Tuple[Dict[str, str], int]: - if not lines: - raise ValueError(f"{path}: empty file") - magic = lines[0].strip() - if not magic.startswith("NFP_TEXT"): - raise ValueError(f"{path}: unexpected header '{magic}'") - header: Dict[str, str] = {} - i = 1 - while i < len(lines): - line = lines[i].strip() - i += 1 - if line == "": - break - if "=" in line: - key, value = line.split("=", 1) - header[key.strip()] = value.strip() - return header, i - - -def next_nonempty(lines: List[str], i: int, path: Path) -> Tuple[int, str]: - while i < len(lines) and lines[i].strip() == "": - i += 1 - if i >= len(lines): - raise ValueError(f"{path}: unexpected EOF") - return i, lines[i].strip() - - -def expect_line(lines: List[str], i: int, expected: str, path: Path) -> int: - i, line = next_nonempty(lines, i, path) - if line != expected: - raise ValueError(f"{path}: expected '{expected}', got '{line}'") - return i + 1 - - -def expect_prefix(lines: List[str], i: int, prefix: str, path: Path) -> int: - i, line = next_nonempty(lines, i, path) - if not line.startswith(prefix): - raise ValueError(f"{path}: expected '{prefix} ...', got '{line}'") - return i + 1 - 
- -def read_numbers( - lines: List[str], - i: int, - count: int, - path: Path, - cast: Callable[[str], float], -) -> Tuple[List[float], int]: - out: List[float] = [] - while len(out) < count: - if i >= len(lines): - raise ValueError(f"{path}: unexpected EOF while reading numbers") - line = lines[i].strip() - i += 1 - if line == "": - continue - for tok in line.split(): - try: - out.append(cast(tok)) - except ValueError as exc: - raise ValueError(f"{path}: invalid number '{tok}'") from exc - if len(out) == count: - break - return out, i - - -def read_input_embeddings(path: Path) -> Tuple[Dict[str, str], List[int], List[float]]: - lines = path.read_text().splitlines() - header, i = parse_header(lines, path) - seq_len = int(header["seq_len"]) - model_dim = int(header["model_dim"]) - i = expect_line(lines, i, "TOKENS", path) - tokens_f, i = read_numbers(lines, i, seq_len, path, int) - tokens = [int(t) for t in tokens_f] - i = expect_line(lines, i, "EMBEDDINGS", path) - embeddings_f, _ = read_numbers(lines, i, seq_len * model_dim, path, float) - return header, tokens, embeddings_f - - -def read_model_weights(path: Path) -> Tuple[Dict[str, str], Dict[str, object]]: - lines = path.read_text().splitlines() - header, i = parse_header(lines, path) - num_layers = int(header["num_layers"]) - num_heads = int(header["num_heads"]) - model_dim = int(header["model_dim"]) - head_dim = int(header["head_dim"]) - hidden_dim = int(header["hidden_dim"]) - - layers: List[Dict[str, object]] = [] - for _ in range(num_layers): - i = expect_prefix(lines, i, "LAYER", path) - heads: List[Dict[str, List[float]]] = [] - for _ in range(num_heads): - i = expect_prefix(lines, i, "HEAD", path) - i = expect_line(lines, i, "W_Q", path) - w_q, i = read_numbers(lines, i, model_dim * head_dim, path, float) - i = expect_line(lines, i, "b_Q", path) - b_q, i = read_numbers(lines, i, head_dim, path, float) - i = expect_line(lines, i, "W_K", path) - w_k, i = read_numbers(lines, i, model_dim * head_dim, path, 
float) - i = expect_line(lines, i, "b_K", path) - b_k, i = read_numbers(lines, i, head_dim, path, float) - i = expect_line(lines, i, "W_V", path) - w_v, i = read_numbers(lines, i, model_dim * head_dim, path, float) - i = expect_line(lines, i, "b_V", path) - b_v, i = read_numbers(lines, i, head_dim, path, float) - i = expect_line(lines, i, "W_O", path) - w_o, i = read_numbers(lines, i, head_dim * model_dim, path, float) - heads.append( - { - "W_Q": w_q, - "b_Q": b_q, - "W_K": w_k, - "b_K": b_k, - "W_V": w_v, - "b_V": b_v, - "W_O": w_o, - } - ) - i = expect_line(lines, i, "ATTN_BIAS", path) - attn_bias, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "MLP", path) - i = expect_line(lines, i, "W_in", path) - w_in, i = read_numbers(lines, i, model_dim * hidden_dim, path, float) - i = expect_line(lines, i, "b_in", path) - b_in, i = read_numbers(lines, i, hidden_dim, path, float) - i = expect_line(lines, i, "W_out", path) - w_out, i = read_numbers(lines, i, hidden_dim * model_dim, path, float) - i = expect_line(lines, i, "b_out", path) - b_out, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN1_GAMMA", path) - ln1_gamma, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN1_BETA", path) - ln1_beta, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN2_GAMMA", path) - ln2_gamma, i = read_numbers(lines, i, model_dim, path, float) - i = expect_line(lines, i, "LN2_BETA", path) - ln2_beta, i = read_numbers(lines, i, model_dim, path, float) - layers.append( - { - "heads": heads, - "attn_bias": attn_bias, - "w_in": w_in, - "b_in": b_in, - "w_out": w_out, - "b_out": b_out, - "ln1_gamma": ln1_gamma, - "ln1_beta": ln1_beta, - "ln2_gamma": ln2_gamma, - "ln2_beta": ln2_beta, - } - ) - - return header, {"layers": layers} - - -def write_i32(f, data: List[int]) -> None: - if not data: - return - f.write(struct.pack("<" + "i" * len(data), *data)) - - -def 
write_f64(f, data: List[float]) -> None: - if not data: - return - f.write(struct.pack("<" + "d" * len(data), *data)) - - -def main() -> None: - parser = argparse.ArgumentParser( - description="Convert a tiny NFP_TEXT fixture into NFP_BINARY_V1." - ) - parser.add_argument( - "--model", - default="tests/fixtures/tiny_sound_model.nfpt", - help="Text model fixture path (NFP_TEXT_V2).", - ) - parser.add_argument( - "--input", - default="tests/fixtures/tiny_sound_input.nfpt", - help="Text input fixture path (NFP_TEXT_V2).", - ) - parser.add_argument( - "--output", - default="tests/fixtures/tiny_sound_binary.nfpt", - help="Output binary fixture path.", - ) - args = parser.parse_args() - - model_path = Path(args.model) - input_path = Path(args.input) - output_path = Path(args.output) - - input_header, tokens, embeddings = read_input_embeddings(input_path) - model_header, model_payload = read_model_weights(model_path) - - if int(input_header["model_dim"]) != int(model_header["model_dim"]): - raise ValueError("input/model model_dim mismatch") - if int(input_header["seq_len"]) != int(model_header["seq_len"]): - raise ValueError("input/model seq_len mismatch") - input_eps = input_header.get("layer_norm_eps") or input_header.get("eps") - model_eps = model_header.get("layer_norm_eps") or model_header.get("eps") - model_gelu = model_header.get("gelu_kind") or model_header.get("gelu_deriv") - input_gelu = input_header.get("gelu_kind") or input_header.get("gelu_deriv") - if model_eps is None: - raise ValueError("model header missing layer_norm_eps") - if input_eps is not None and input_eps != model_eps: - raise ValueError("input/model layer_norm_eps mismatch") - if model_gelu is None: - raise ValueError("model header missing gelu_kind") - if input_gelu is not None and input_gelu != model_gelu: - raise ValueError("input/model gelu_kind mismatch") - - output_path.parent.mkdir(parents=True, exist_ok=True) - with output_path.open("wb") as f: - f.write(b"NFP_BINARY_V1\n") - for key in 
[ - "num_layers", - "num_heads", - "model_dim", - "head_dim", - "hidden_dim", - "vocab_size", - "seq_len", - ]: - f.write(f"{key}={model_header[key]}\n".encode("ascii")) - f.write(f"layer_norm_eps={model_eps}\n".encode("ascii")) - f.write(f"gelu_kind={model_gelu}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - write_i32(f, tokens) - write_f64(f, embeddings) - - layers = model_payload["layers"] - for layer in layers: - for head in layer["heads"]: - write_f64(f, head["W_Q"]) - write_f64(f, head["b_Q"]) - write_f64(f, head["W_K"]) - write_f64(f, head["b_K"]) - write_f64(f, head["W_V"]) - write_f64(f, head["b_V"]) - write_f64(f, head["W_O"]) - write_f64(f, layer["attn_bias"]) - write_f64(f, layer["w_in"]) - write_f64(f, layer["b_in"]) - write_f64(f, layer["w_out"]) - write_f64(f, layer["b_out"]) - write_f64(f, layer["ln1_gamma"]) - write_f64(f, layer["ln1_beta"]) - write_f64(f, layer["ln2_gamma"]) - write_f64(f, layer["ln2_beta"]) - - model_dim = int(model_header["model_dim"]) - vocab_size = int(model_header["vocab_size"]) - write_f64(f, [1.0] * model_dim) - write_f64(f, [0.0] * model_dim) - write_f64(f, [0.0] * (model_dim * vocab_size)) - - print(f"Wrote {output_path}") - - -if __name__ == "__main__": - main() diff --git a/scripts/demo_gpt2_induction_sound.sh b/scripts/demo_gpt2_induction_sound.sh deleted file mode 100755 index cd3a8ae..0000000 --- a/scripts/demo_gpt2_induction_sound.sh +++ /dev/null @@ -1,35 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_PATH="${1:-models/gpt2_rigorous.nfpt}" -REPORT_PATH="${2:-reports/gpt2_induction_sound_scan.txt}" -EXTRA_ARGS=() -if [ "$#" -gt 2 ]; then - EXTRA_ARGS=("${@:3}") -fi - -PYTHON_BIN="python" -if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -if ! 
"$PYTHON_BIN" scripts/ensure_gelu_kind.py --check "$MODEL_PATH"; then - PATCHED_PATH="${MODEL_PATH%.nfpt}_with_gelu_kind.nfpt" - "$PYTHON_BIN" scripts/ensure_gelu_kind.py "$MODEL_PATH" \ - --output "$PATCHED_PATH" \ - --default tanh - MODEL_PATH="$PATCHED_PATH" -fi - -if [ "${#EXTRA_ARGS[@]}" -gt 0 ]; then - "$PYTHON_BIN" scripts/scan_gpt2_induction_sound.py \ - --model "$MODEL_PATH" \ - --output "$REPORT_PATH" \ - "${EXTRA_ARGS[@]}" -else - "$PYTHON_BIN" scripts/scan_gpt2_induction_sound.py \ - --model "$MODEL_PATH" \ - --output "$REPORT_PATH" -fi - -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_gpt2_sound.sh b/scripts/demo_gpt2_sound.sh deleted file mode 100755 index 67660b9..0000000 --- a/scripts/demo_gpt2_sound.sh +++ /dev/null @@ -1,31 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_PATH="${1:-models/gpt2.nfpt}" -REPORT_PATH="${2:-reports/gpt2_sound_demo.txt}" - -mkdir -p "$(dirname "$MODEL_PATH")" "$(dirname "$REPORT_PATH")" - -PYTHON_BIN="python" -if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -if [ ! -f "$MODEL_PATH" ]; then - if command -v uv >/dev/null 2>&1; then - uv run python scripts/export_gpt2.py "$MODEL_PATH" - else - "$PYTHON_BIN" scripts/export_gpt2.py "$MODEL_PATH" - fi -fi - -if ! 
"$PYTHON_BIN" scripts/ensure_gelu_kind.py --check "$MODEL_PATH"; then - PATCHED_PATH="${MODEL_PATH%.nfpt}_with_gelu_kind.nfpt" - "$PYTHON_BIN" scripts/ensure_gelu_kind.py "$MODEL_PATH" \ - --output "$PATCHED_PATH" \ - --default tanh - MODEL_PATH="$PATCHED_PATH" -fi - -lake exe nfp certify "$MODEL_PATH" --output "$REPORT_PATH" -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_tiny_induction_cert.sh b/scripts/demo_tiny_induction_cert.sh deleted file mode 100755 index f4e1ebc..0000000 --- a/scripts/demo_tiny_induction_cert.sh +++ /dev/null @@ -1,26 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_TEXT="${1:-tests/fixtures/tiny_sound_model.nfpt}" -INPUT_TEXT="${2:-tests/fixtures/tiny_sound_input.nfpt}" -BINARY_PATH="${3:-tests/fixtures/tiny_sound_binary.nfpt}" -REPORT_PATH="${4:-reports/tiny_induction_cert.txt}" - -mkdir -p "$(dirname "$BINARY_PATH")" "$(dirname "$REPORT_PATH")" - -PYTHON_BIN="python" -if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" -fi - -"$PYTHON_BIN" scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" - -lake exe nfp induction_cert "$BINARY_PATH" \ - --layer1 0 --head1 0 --layer2 0 --head2 0 --coord 0 \ - --offset1 -1 --offset2 -1 --target 2 --negative 1 \ - --delta 0.05 --output "$REPORT_PATH" - -echo "Report written to $REPORT_PATH" diff --git a/scripts/demo_tiny_local_binary.sh b/scripts/demo_tiny_local_binary.sh deleted file mode 100755 index 28a2018..0000000 --- a/scripts/demo_tiny_local_binary.sh +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env bash -set -euo pipefail - -MODEL_TEXT="${1:-tests/fixtures/tiny_sound_model.nfpt}" -INPUT_TEXT="${2:-tests/fixtures/tiny_sound_input.nfpt}" -BINARY_PATH="${3:-tests/fixtures/tiny_sound_binary.nfpt}" -REPORT_PATH="${4:-reports/tiny_sound_local_binary.txt}" - -mkdir -p "$(dirname "$BINARY_PATH")" "$(dirname "$REPORT_PATH")" - -if [ "${USE_UV:-0}" = "1" ] && command -v uv >/dev/null 2>&1; 
then - uv run python scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" -else - PYTHON_BIN="python" - if ! command -v "$PYTHON_BIN" >/dev/null 2>&1; then - PYTHON_BIN="python3" - fi - "$PYTHON_BIN" scripts/convert_text_fixture_to_binary.py \ - --model "$MODEL_TEXT" \ - --input "$INPUT_TEXT" \ - --output "$BINARY_PATH" -fi - -lake exe nfp certify "$BINARY_PATH" --delta 0.05 --output "$REPORT_PATH" -echo "Report written to $REPORT_PATH" diff --git a/scripts/diagnose_induction_heads.py b/scripts/diagnose_induction_heads.py new file mode 100644 index 0000000..7d17ebe --- /dev/null +++ b/scripts/diagnose_induction_heads.py @@ -0,0 +1,179 @@ +#!/usr/bin/env python3 +# SPDX-License-Identifier: AGPL-3.0-or-later + +""" +Diagnostic scan for induction heads on repeated random sequences. + +This mirrors the common literature setup: +- build a batch of repeated random token sequences (pattern_len repeated twice), +- run GPT-2 with output_attentions, +- rank heads by induction stripe attention (q -> q - period), +- rank heads by previous-token attention (q -> q - 1). + +Layer/head indices in the report are 1-based to match the literature. 
+""" + +from __future__ import annotations + +import argparse +import sys +from pathlib import Path + +import numpy as np + +try: + import torch + from transformers import GPT2Model, GPT2Tokenizer +except ImportError as exc: + raise SystemExit("Requires torch + transformers (use `uv run --with torch --with transformers`).") from exc + + +def select_vocab_candidates( + tokenizer, + vocab_min: int, + vocab_max: int, + min_word_length: int, + require_leading_space: bool, +) -> list[int]: + candidates = [] + for tid in range(vocab_min, vocab_max): + word = tokenizer.decode([tid]) + if len(word.strip()) <= min_word_length: + continue + if require_leading_space and not word.startswith(" "): + continue + candidates.append(tid) + return candidates + + +def build_batch( + rng: np.random.Generator, + candidates: list[int], + batch_size: int, + pattern_len: int, + seq_len: int, +) -> np.ndarray: + if seq_len != 2 * pattern_len: + raise ValueError("seq_len must equal 2 * pattern_len for repeated sequence diagnostic") + if len(candidates) < pattern_len: + raise ValueError("Not enough vocab candidates for requested pattern length") + batch = np.zeros((batch_size, seq_len), dtype=np.int64) + for idx in range(batch_size): + pattern = rng.choice(candidates, size=pattern_len, replace=False) + batch[idx] = np.tile(pattern, 2) + return batch + + +def compute_stripe_stats(attn: torch.Tensor, period: int) -> tuple[torch.Tensor, torch.Tensor]: + batch, heads, seq_len, _ = attn.shape + q = torch.arange(period, seq_len, device=attn.device) + k = q - period + block = attn[:, :, q, :] + k_index = k.view(1, 1, -1, 1).expand(batch, heads, -1, 1) + stripe_vals = block.gather(dim=-1, index=k_index).squeeze(-1) + stripe_mean = stripe_vals.mean(dim=(0, 2)) + max_vals = block.max(dim=-1).values + stripe_top1 = (stripe_vals >= max_vals - 1e-12).float().mean(dim=(0, 2)) + return stripe_mean, stripe_top1 + + +def compute_prev_stats(attn: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]: + batch, 
heads, seq_len, _ = attn.shape + q = torch.arange(1, seq_len, device=attn.device) + k = q - 1 + block = attn[:, :, q, :] + k_index = k.view(1, 1, -1, 1).expand(batch, heads, -1, 1) + prev_vals = block.gather(dim=-1, index=k_index).squeeze(-1) + prev_mean = prev_vals.mean(dim=(0, 2)) + max_vals = block.max(dim=-1).values + prev_top1 = (prev_vals >= max_vals - 1e-12).float().mean(dim=(0, 2)) + return prev_mean, prev_top1 + + +def main() -> int: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument("--model", default="gpt2", help="HuggingFace model name") + parser.add_argument("--seq-len", type=int, default=50) + parser.add_argument("--pattern-len", type=int, default=25) + parser.add_argument("--batch", type=int, default=30) + parser.add_argument("--seed", type=int, default=1337) + parser.add_argument("--vocab-min", type=int, default=1000) + parser.add_argument("--vocab-max", type=int, default=5000) + parser.add_argument("--min-word-length", type=int, default=4) + parser.add_argument("--require-leading-space", action="store_true", default=True) + parser.add_argument("--allow-no-leading-space", action="store_true") + parser.add_argument("--device", default="cpu") + parser.add_argument("--attn-implementation", default="eager") + parser.add_argument("--top", type=int, default=20) + parser.add_argument("--output", type=Path, default=Path("reports/induction_diagnostic.txt")) + args = parser.parse_args() + + require_leading_space = args.require_leading_space and not args.allow_no_leading_space + rng = np.random.default_rng(args.seed) + + tokenizer = GPT2Tokenizer.from_pretrained(args.model) + model = GPT2Model.from_pretrained(args.model, attn_implementation=args.attn_implementation) + model.eval() + model.to(args.device) + + candidates = select_vocab_candidates( + tokenizer, + vocab_min=args.vocab_min, + vocab_max=args.vocab_max, + min_word_length=args.min_word_length, + require_leading_space=require_leading_space, + ) + batch_tokens = build_batch( + 
rng, + candidates, + batch_size=args.batch, + pattern_len=args.pattern_len, + seq_len=args.seq_len, + ) + input_ids = torch.tensor(batch_tokens, dtype=torch.long, device=args.device) + with torch.no_grad(): + outputs = model(input_ids, output_attentions=True, use_cache=False) + if outputs.attentions is None: + raise SystemExit("Model did not return attention weights.") + + stripe_scores = [] + prev_scores = [] + for layer_idx, attn in enumerate(outputs.attentions): + stripe_mean, stripe_top1 = compute_stripe_stats(attn, args.pattern_len) + prev_mean, prev_top1 = compute_prev_stats(attn) + for head_idx in range(attn.shape[1]): + stripe_scores.append((float(stripe_mean[head_idx]), float(stripe_top1[head_idx]), layer_idx, head_idx)) + prev_scores.append((float(prev_mean[head_idx]), float(prev_top1[head_idx]), layer_idx, head_idx)) + + stripe_scores.sort(key=lambda x: x[0], reverse=True) + prev_scores.sort(key=lambda x: x[0], reverse=True) + + args.output.parent.mkdir(parents=True, exist_ok=True) + with args.output.open("w", encoding="ascii") as f: + f.write("Induction diagnostic (repeated random sequence)\n") + f.write(f"model={args.model}\n") + f.write(f"seq_len={args.seq_len} pattern_len={args.pattern_len} batch={args.batch}\n") + f.write( + f"vocab_range=[{args.vocab_min},{args.vocab_max}) min_word_length={args.min_word_length} " + f"leading_space={require_leading_space}\n" + ) + f.write("\nTop induction stripe heads:\n") + for rank, (mean, top1, layer, head) in enumerate(stripe_scores[: args.top], start=1): + f.write(f"{rank:02d} L{layer+1}H{head+1} stripeMean={mean:.6f} stripeTop1={top1:.3f}\n") + f.write("\nTop previous-token heads:\n") + for rank, (mean, top1, layer, head) in enumerate(prev_scores[: args.top], start=1): + f.write(f"{rank:02d} L{layer+1}H{head+1} prevMean={mean:.6f} prevTop1={top1:.3f}\n") + + print(f"Wrote report to {args.output}") + print("\nTop induction stripe heads:") + for rank, (mean, top1, layer, head) in enumerate(stripe_scores[: 
args.top], start=1): + print(f"{rank:02d} L{layer+1}H{head+1} stripeMean={mean:.6f} stripeTop1={top1:.3f}") + print("\nTop previous-token heads:") + for rank, (mean, top1, layer, head) in enumerate(prev_scores[: args.top], start=1): + print(f"{rank:02d} L{layer+1}H{head+1} prevMean={mean:.6f} prevTop1={top1:.3f}") + + return 0 + + +if __name__ == "__main__": + raise SystemExit(main()) diff --git a/scripts/ensure_gelu_kind.py b/scripts/ensure_gelu_kind.py deleted file mode 100644 index 0a908a4..0000000 --- a/scripts/ensure_gelu_kind.py +++ /dev/null @@ -1,86 +0,0 @@ -#!/usr/bin/env python3 -"""Ensure a `.nfpt` header has `gelu_kind`, optionally writing a patched copy.""" - -from __future__ import annotations - -import argparse -import shutil -import sys -from pathlib import Path - - -def has_gelu_kind(path: Path) -> bool: - with path.open("rb") as f: - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - stripped = line.strip() - if stripped.startswith(b"gelu_kind=") or stripped.startswith(b"gelu_deriv="): - return True - if stripped == b"BINARY_START": - return False - - -def patch_header(path: Path, output: Path, default_kind: str) -> None: - with path.open("rb") as f: - header_lines: list[bytes] = [] - gelu_present = False - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - stripped = line.strip() - if stripped.startswith(b"gelu_kind=") or stripped.startswith(b"gelu_deriv="): - gelu_present = True - header_lines.append(line) - if stripped == b"BINARY_START": - break - payload = f.read() - - output.parent.mkdir(parents=True, exist_ok=True) - if gelu_present: - shutil.copyfile(path, output) - print(f"gelu_kind already present; copied to {output}") - return - - patched: list[bytes] = [] - inserted = False - for line in header_lines: - if line.strip() == b"BINARY_START" and not inserted: - patched.append(f"gelu_kind={default_kind}\n".encode("ascii")) - 
inserted = True - patched.append(line) - - if not inserted: - raise ValueError("missing BINARY_START while patching header") - - with output.open("wb") as f: - f.writelines(patched) - f.write(payload) - print(f"added gelu_kind={default_kind}; wrote {output}") - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("input", type=Path) - parser.add_argument("--output", type=Path) - parser.add_argument("--default", default="tanh") - parser.add_argument("--check", action="store_true") - args = parser.parse_args() - - if args.check: - return 0 if has_gelu_kind(args.input) else 1 - - if args.output is None: - raise SystemExit("missing --output (use --check to test only)") - - if args.output.resolve() == args.input.resolve(): - raise SystemExit("refusing in-place patch; pass a distinct --output path") - - patch_header(args.input, args.output, args.default) - return 0 - - -if __name__ == "__main__": - sys.exit(main()) diff --git a/scripts/export_gpt2.py b/scripts/export_gpt2.py deleted file mode 100644 index 04e8f7c..0000000 --- a/scripts/export_gpt2.py +++ /dev/null @@ -1,239 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Export GPT-2 Small weights to NFP_BINARY_V1 format (.nfpt). - -This script loads GPT-2 Small from HuggingFace and exports all weights needed -for circuit analysis in the NFP library. - -Usage: - uv run scripts/export_gpt2.py [output_path] - -Default output: models/gpt2.nfpt -""" - -import sys -from pathlib import Path -import numpy as np - -try: - from transformers import GPT2Model -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def export_gpt2_weights(output_path: str = "models/gpt2.nfpt", seq_len: int = 256): - """ - Export GPT-2 Small weights to .nfpt format. 
- - GPT-2 Small architecture: - - 12 layers - - 12 attention heads per layer - - 768 model dimension - - 64 head dimension (768 / 12) - - 3072 MLP hidden dimension (4 * 768) - - 50257 vocabulary size - - Args: - output_path: Path to write the .nfpt file - seq_len: Sequence length for analysis (default 256, smaller than GPT-2's 1024 for speed) - """ - print("Loading GPT-2 Small from HuggingFace...") - model = GPT2Model.from_pretrained("gpt2") - config = model.config - - # Model dimensions - num_layers = config.n_layer # 12 - num_heads = config.n_head # 12 - model_dim = config.n_embd # 768 - head_dim = model_dim // num_heads # 64 - hidden_dim = config.n_inner or 4 * model_dim # 3072 - vocab_size = config.vocab_size # 50257 - layer_norm_eps = float(config.layer_norm_epsilon) - - print("Model configuration:") - print(f" Layers: {num_layers}") - print(f" Heads: {num_heads}") - print(f" Model dim: {model_dim}") - print(f" Head dim: {head_dim}") - print(f" Hidden dim: {hidden_dim}") - print(f" Vocab size: {vocab_size}") - print(f" Sequence length (for analysis): {seq_len}") - - # Create output directory - output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting to {output_path}...") - - with open(output_path, "wb") as f: - # Header - write_header( - f, - num_layers=num_layers, - num_heads=num_heads, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=hidden_dim, - vocab_size=vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - # Sample input embeddings for analysis - # For circuit analysis, we need input embeddings for a specific sequence - # We'll create a dummy sequence of token indices and get their embeddings - # Shape: (seq_len, model_dim) - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() # Position embeddings - - # Create sample token indices (repeating pattern to make induction detectable) - # Pattern: 0, 1, 2, ..., 15, 0, 1, 2, ..., 
15, ... - sample_tokens = np.array([i % 16 for i in range(seq_len)]) - # Include both token embeddings and position embeddings - # This is what GPT-2 computes: x_0 = token_emb + pos_emb - sample_embeddings = wte[sample_tokens] + wpe[:seq_len] # (seq_len, model_dim) - - # Ground-truth token sequence for self-supervised induction targeting. - # This is used by the Lean pipeline to choose the correct induction target - # algorithmically from the sequence history. - write_i32(f, sample_tokens) - write_f64(f, sample_embeddings) - - # Process each layer - for layer_idx in range(num_layers): - print(f" Processing layer {layer_idx}...") - block = model.h[layer_idx] - - - # Attention weights - # GPT-2 stores Q, K, V concatenated in c_attn.weight - # Shape: (model_dim, 3 * model_dim) for weight - # We need to split into Q, K, V and then split each by heads - - c_attn_weight = block.attn.c_attn.weight.detach().numpy() # (768, 2304) - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * model_dim) # (2304,) - - # c_attn.weight layout: input_dim -> [Q_all_heads, K_all_heads, V_all_heads] - # Each section is (model_dim, model_dim) - W_Q_all = c_attn_weight[:, 0:model_dim] # (768, 768) - W_K_all = c_attn_weight[:, model_dim : 2 * model_dim] # (768, 768) - W_V_all = c_attn_weight[:, 2 * model_dim : 3 * model_dim] # (768, 768) - b_Q_all = c_attn_bias[0:model_dim] # (768,) - b_K_all = c_attn_bias[model_dim : 2 * model_dim] # (768,) - b_V_all = c_attn_bias[2 * model_dim : 3 * model_dim] # (768,) - - # Output projection c_proj - c_proj_weight = block.attn.c_proj.weight.detach().numpy() # (768, 768) - c_proj_bias = get_bias(block.attn.c_proj.bias, model_dim) # (768,) - - # Split into heads - # W_Q_all columns are organized as [head0, head1, ..., head11] - # Each head gets head_dim columns - for head_idx in range(num_heads): - start = head_idx * head_dim - end = (head_idx + 1) * head_dim - - # Extract per-head Q, K, V projections - # W_Q: (model_dim, head_dim) - projects input to queries 
for this head - W_Q = W_Q_all[:, start:end] # (768, 64) - W_K = W_K_all[:, start:end] # (768, 64) - W_V = W_V_all[:, start:end] # (768, 64) - - # Output projection for this head - # c_proj.weight: (768, 768), organized as [head0_rows, head1_rows, ...] - # W_O: (head_dim, model_dim) - projects head output back to model dim - W_O = c_proj_weight[start:end, :] # (64, 768) - - write_f64(f, W_Q) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V) - write_f64(f, b_V_all[start:end]) - write_f64(f, W_O) - - # Attention output bias (c_proj.bias), applied once after combining heads. - write_f64(f, c_proj_bias) - - # MLP weights - # GPT-2 MLP: x -> c_fc (expand) -> GELU -> c_proj (contract) -> out - # c_fc: (model_dim, hidden_dim) - # c_proj: (hidden_dim, model_dim) - - mlp_c_fc_weight = block.mlp.c_fc.weight.detach().numpy() # (768, 3072) - mlp_c_fc_bias = get_bias(block.mlp.c_fc.bias, hidden_dim) # (3072,) - mlp_c_proj_weight = block.mlp.c_proj.weight.detach().numpy() # (3072, 768) - mlp_c_proj_bias = get_bias(block.mlp.c_proj.bias, model_dim) # (768,) - - # Note: GPT-2 uses Conv1D which stores weights transposed compared to Linear - # Conv1D weight shape is (in_features, out_features) - # Our format expects W_in: (model_dim, hidden_dim) and W_out: (hidden_dim, model_dim) - - write_f64(f, mlp_c_fc_weight) # (768, 3072) - write_f64(f, mlp_c_fc_bias) # (3072,) - write_f64(f, mlp_c_proj_weight) # (3072, 768) - write_f64(f, mlp_c_proj_bias) # (768,) - - # LayerNorm parameters (Pre-LN) - # ln_1 is applied before attention; ln_2 is applied before MLP. 
- ln1_gamma = block.ln_1.weight.detach().numpy() # (model_dim,) - ln1_beta = get_bias(block.ln_1.bias, model_dim) # (model_dim,) - ln2_gamma = block.ln_2.weight.detach().numpy() # (model_dim,) - ln2_beta = get_bias(block.ln_2.bias, model_dim) # (model_dim,) - - write_f64(f, ln1_gamma) - write_f64(f, ln1_beta) - write_f64(f, ln2_gamma) - write_f64(f, ln2_beta) - - # Final LayerNorm (ln_f) before unembedding - ln_f_gamma = model.ln_f.weight.detach().numpy() # (model_dim,) - ln_f_beta = get_bias(model.ln_f.bias, model_dim) # (model_dim,) - write_f64(f, ln_f_gamma) - write_f64(f, ln_f_beta) - - # Unembedding matrix - # In GPT-2, the unembedding is tied to the embedding (same weights transposed) - # We want (model_dim, vocab_size) for projecting hidden states to logits - # wte is (vocab_size, model_dim), so we transpose it - unembed = wte.T # (768, 50257) - write_f64(f, unembed) - - # Report file size - file_size = output_file.stat().st_size - print("\nExport complete!") - print(f" File size: {file_size / 1024 / 1024:.1f} MB") - print(f" Output: {output_path}") - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - output = sys.argv[1] if len(sys.argv) > 1 else "models/gpt2.nfpt" - export_gpt2_weights(output) diff --git a/scripts/generate_induction_data.py b/scripts/generate_induction_data.py deleted file mode 100644 index 2b6fe63..0000000 --- a/scripts/generate_induction_data.py +++ /dev/null @@ -1,148 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Generate a "Strong Induction" dataset for NFP verification. 
-Creates a sequence of random common words repeated multiple times. -""" - -import sys -from pathlib import Path -import numpy as np -import torch - -try: - from transformers import GPT2Model -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def export_induction_weights(output_path: str = "models/gpt2_induction.nfpt"): - print("Loading GPT-2 Small...") - model = GPT2Model.from_pretrained("gpt2") - config = model.config - layer_norm_eps = float(config.layer_norm_epsilon) - - # Fixed parameters for GPT-2 Small - seq_len = 256 - model_dim = 768 - num_heads = 12 - head_dim = 64 - - # --- CRITICAL CHANGE: Better Induction Data --- - # Instead of tokens 0-15, we use random "word" tokens (1000-30000). - # A pattern of length 30 repeated ~8 times creates strong induction pressure. - np.random.seed(42) - pattern_len = 30 - # Generate 30 random token IDs - pattern = np.random.randint(1000, 30000, size=pattern_len) - # Repeat the pattern to fill seq_len - repeats = (seq_len // pattern_len) + 1 - sample_tokens = np.tile(pattern, repeats)[:seq_len] - - print(f"Generated induction sequence (len={seq_len}):") - print(f"Pattern (first 10): {sample_tokens[:10]}") - # ----------------------------------------------- - - # Compute embeddings - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() - sample_embeddings = wte[sample_tokens] + wpe[:seq_len] - - output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"Exporting to {output_path}...") - with open(output_path, "wb") as f: - write_header( - f, - num_layers=config.n_layer, - num_heads=config.n_head, - model_dim=model_dim, - head_dim=head_dim, - hidden_dim=config.n_inner or 4 * model_dim, - vocab_size=config.vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, sample_tokens) - write_f64(f, sample_embeddings) - - # 
Export weights - for layer_idx in range(config.n_layer): - block = model.h[layer_idx] - - # Attention - c_attn = block.attn.c_attn.weight.detach().numpy() - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * model_dim) - c_proj = block.attn.c_proj.weight.detach().numpy() - c_proj_bias = get_bias(block.attn.c_proj.bias, model_dim) - - W_Q_all = c_attn[:, 0:model_dim] - W_K_all = c_attn[:, model_dim : 2 * model_dim] - W_V_all = c_attn[:, 2 * model_dim : 3 * model_dim] - b_Q_all = c_attn_bias[0:model_dim] - b_K_all = c_attn_bias[model_dim : 2 * model_dim] - b_V_all = c_attn_bias[2 * model_dim : 3 * model_dim] - - for h in range(num_heads): - start, end = h * head_dim, (h + 1) * head_dim - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) - write_f64(f, c_proj[start:end, :]) - - write_f64(f, c_proj_bias) - - # MLP - write_f64(f, block.mlp.c_fc.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, config.n_inner or 4 * model_dim)) - write_f64(f, block.mlp.c_proj.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, model_dim)) - - # LayerNorm - write_f64(f, block.ln_1.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_1.bias, model_dim)) - write_f64(f, block.ln_2.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_2.bias, model_dim)) - - # Unembedding - write_f64(f, model.ln_f.weight.detach().numpy()) - write_f64(f, get_bias(model.ln_f.bias, model_dim)) - write_f64(f, wte.T) - - print("Done.") - - -def write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - 
return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - export_induction_weights() diff --git a/scripts/generate_rigorous_induction.py b/scripts/generate_rigorous_induction.py deleted file mode 100644 index 35f5da4..0000000 --- a/scripts/generate_rigorous_induction.py +++ /dev/null @@ -1,170 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Generate a "Rigorous Induction" dataset for NFP verification. -Constructs a sequence of random single-token words repeated in a fixed pattern. - -Dataset intent (heuristic, model-agnostic): -- The next token is deterministically defined by a previous occurrence in the sequence. -- Randomized content reduces semantic cues but does not guarantee model probabilities. - -This aims to isolate induction-style copying from semantic completion. -""" - -import sys -from pathlib import Path -import numpy as np -import torch - -try: - from transformers import GPT2Model, GPT2Tokenizer -except ImportError: - print("Error: transformers library not installed.") - print("Install with: uv add transformers torch") - sys.exit(1) - - -def export_rigorous_induction(output_path: str = "models/gpt2_rigorous.nfpt"): - print("Loading GPT-2 Small...") - model = GPT2Model.from_pretrained("gpt2") - tokenizer = GPT2Tokenizer.from_pretrained("gpt2") - config = model.config - layer_norm_eps = float(config.layer_norm_epsilon) - - # 1. Define a vocabulary of common, distinct English nouns - # These token IDs are single-token words in GPT-2 (e.g., " apple", " logic") - # This prevents fragmentation issues. - # IDs: 1000-5000 range usually contains common words. - # We filter for length > 4 to ensure they are substantive. - vocab_candidates = [] - for tid in range(1000, 5000): - word = tokenizer.decode([tid]) - if len(word.strip()) > 4 and word.startswith(" "): - vocab_candidates.append(tid) - - # 2. 
Construct the Random Repeated Sequence - # Pattern length L=20, repeated to fill seq_len (256 tokens). - seq_len = 256 - pattern_len = 20 - - np.random.seed(1337) # Fixed seed for reproducibility - unique_pattern = np.random.choice(vocab_candidates, size=pattern_len, replace=False) - - # [A, B, C, ...] -> [A, B, C, ..., A, B, C, ...] - repeats = (seq_len // pattern_len) + 1 - full_sequence = np.tile(unique_pattern, repeats)[:seq_len] - - # 3. Verify the "Induction Target" property - # The last token T[N] must have appeared at T[N - pattern_len]. - # The target is T[N - pattern_len + 1]. - last_token = full_sequence[-1] - prev_idx = seq_len - 1 - pattern_len - target_token = full_sequence[prev_idx + 1] - - print(f"\nSequence Structure:") - print(f" Pattern Length: {pattern_len}") - print(f" Total Length: {seq_len}") - print(f" Last Token: '{tokenizer.decode([last_token])}' (ID: {last_token})") - print(f" Previous Occur: Index {prev_idx}") - print( - f" True Target: '{tokenizer.decode([target_token])}' (ID: {target_token})" - ) - - # 4. Compute Embeddings & Weights - wte = model.wte.weight.detach().numpy() - wpe = model.wpe.weight.detach().numpy() - sample_embeddings = wte[full_sequence] + wpe[:seq_len] - - # 5. 
Export - output_file = Path(output_path) - output_file.parent.mkdir(parents=True, exist_ok=True) - - print(f"\nExporting to {output_path}...") - with open(output_path, "wb") as f: - write_header( - f, - num_layers=config.n_layer, - num_heads=config.n_head, - model_dim=768, - head_dim=64, - hidden_dim=config.n_inner or 4 * 768, - vocab_size=config.vocab_size, - seq_len=seq_len, - layer_norm_eps=layer_norm_eps, - gelu_kind="tanh", - ) - - write_i32(f, full_sequence) - write_f64(f, sample_embeddings) - - # Export Layers (Standard Loop) - for layer_idx in range(config.n_layer): - block = model.h[layer_idx] - - c_attn = block.attn.c_attn.weight.detach().numpy() - c_attn_bias = get_bias(block.attn.c_attn.bias, 3 * 768) - c_proj = block.attn.c_proj.weight.detach().numpy() - c_proj_bias = get_bias(block.attn.c_proj.bias, 768) - - W_Q_all = c_attn[:, 0:768] - W_K_all = c_attn[:, 768 : 2 * 768] - W_V_all = c_attn[:, 2 * 768 : 3 * 768] - b_Q_all = c_attn_bias[0:768] - b_K_all = c_attn_bias[768 : 2 * 768] - b_V_all = c_attn_bias[2 * 768 : 3 * 768] - - for h in range(12): - start, end = h * 64, (h + 1) * 64 - write_f64(f, W_Q_all[:, start:end]) - write_f64(f, b_Q_all[start:end]) - write_f64(f, W_K_all[:, start:end]) - write_f64(f, b_K_all[start:end]) - write_f64(f, W_V_all[:, start:end]) - write_f64(f, b_V_all[start:end]) - write_f64(f, c_proj[start:end, :]) - - write_f64(f, c_proj_bias) - - write_f64(f, block.mlp.c_fc.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_fc.bias, config.n_inner or 4 * 768)) - write_f64(f, block.mlp.c_proj.weight.detach().numpy()) - write_f64(f, get_bias(block.mlp.c_proj.bias, 768)) - - write_f64(f, block.ln_1.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_1.bias, 768)) - write_f64(f, block.ln_2.weight.detach().numpy()) - write_f64(f, get_bias(block.ln_2.bias, 768)) - - write_f64(f, model.ln_f.weight.detach().numpy()) - write_f64(f, get_bias(model.ln_f.bias, 768)) - write_f64(f, wte.T) - - print("Done.") - - -def 
write_header(f, **fields: object) -> None: - f.write(b"NFP_BINARY_V1\n") - for key, value in fields.items(): - f.write(f"{key}={value}\n".encode("ascii")) - f.write(b"BINARY_START\n") - - -def write_i32(f, data: np.ndarray) -> None: - arr = np.ascontiguousarray(data, dtype=" None: - arr = np.ascontiguousarray(data, dtype=" np.ndarray: - if param is None: - return np.zeros(size, dtype=np.float64) - return param.detach().numpy() - - -if __name__ == "__main__": - export_rigorous_induction() diff --git a/scripts/sanity_forward_step.py b/scripts/sanity_forward_step.py deleted file mode 100644 index d700a9a..0000000 --- a/scripts/sanity_forward_step.py +++ /dev/null @@ -1,139 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Sanity check: compare one forward step against PyTorch / HuggingFace GPT-2. - -This is a correctness guardrail to ensure that the exported `.nfpt` semantics match what -the Lean-side executable actually analyzes (including attention biases). - -Usage: - uv run scripts/sanity_forward_step.py models/gpt2_rigorous.nfpt --layer 0 --pos 0 --take 16 -""" - -from __future__ import annotations - -import argparse -import subprocess -from pathlib import Path - -import numpy as np - -try: - import torch - from transformers import GPT2Model -except ImportError as e: - raise SystemExit( - "Missing deps. 
Install with e.g.: uv add torch transformers\n" - f"ImportError: {e}" - ) - - -def parse_tokens_from_nfpt(path: Path) -> list[int]: - with path.open("rb") as f: - magic = None - header: dict[str, str] = {} - while True: - line = f.readline() - if line == b"": - raise SystemExit("Unexpected EOF while reading header.") - text = line.decode("ascii").rstrip("\n") - if magic is None: - magic = text.strip() - if text.strip() == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - - if magic != "NFP_BINARY_V1": - raise SystemExit(f"Unsupported magic header: {magic}") - seq_len_raw = header.get("seq_len") - if seq_len_raw is None: - raise SystemExit("Missing seq_len in header.") - - seq_len = int(seq_len_raw) - raw = f.read(seq_len * 4) - if len(raw) != seq_len * 4: - raise SystemExit("Unexpected EOF while reading TOKENS section.") - toks = np.frombuffer(raw, dtype=" np.ndarray: - exe = Path(".lake/build/bin/nfp") - if not exe.exists(): - raise SystemExit("Missing `.lake/build/bin/nfp`. Run `lake build nfp` first.") - cmd = [ - str(exe), - "dump", - "--kind", - "afterLayer", - "--layer", - str(layer), - "--pos", - str(pos), - "--take", - str(take), - str(model_path), - ] - out = subprocess.check_output(cmd, text=True) - lines = [ln.strip() for ln in out.splitlines() if ln.strip()] - dump_i = None - for i, ln in enumerate(lines): - if ln.startswith("DUMP "): - dump_i = i - if dump_i is None or dump_i + 2 >= len(lines): - raise SystemExit(f"Unexpected Lean dump output:\n{out}") - vals = [float(x) for x in lines[dump_i + 2].split()] - return np.asarray(vals, dtype=np.float64) - - -def run_torch(model_name: str, tokens: list[int], layer: int, pos: int, take: int) -> np.ndarray: - m = GPT2Model.from_pretrained(model_name) - # Lean `Float` is IEEE-754 double; run PyTorch in float64 to avoid large float32-vs-float64 drift - # (especially through attention softmax on long sequences). 
- m = m.to(dtype=torch.float64) - m.eval() - input_ids = torch.tensor([tokens], dtype=torch.long) - with torch.no_grad(): - out = m(input_ids=input_ids, output_hidden_states=True) - hs = out.hidden_states - if hs is None: - raise SystemExit("Expected output_hidden_states=True to return hidden_states.") - # hs[0] = embeddings, hs[layer+1] = after block `layer`. - x = hs[layer + 1][0, pos, :take].detach().cpu().numpy() - return x.astype(np.float64) - - -def main() -> None: - ap = argparse.ArgumentParser() - ap.add_argument("model", type=Path) - ap.add_argument("--hf", default="gpt2", help="HuggingFace model name (default: gpt2)") - ap.add_argument("--layer", type=int, default=0) - ap.add_argument("--pos", type=int, default=0) - ap.add_argument("--take", type=int, default=16) - ap.add_argument("--tol", type=float, default=1e-3) - args = ap.parse_args() - - if not args.model.exists(): - raise SystemExit(f"Missing file: {args.model}") - tokens = parse_tokens_from_nfpt(args.model) - if len(tokens) == 0: - raise SystemExit("No tokens parsed from TOKENS section.") - - lean = run_lean_dump(args.model, args.layer, args.pos, args.take) - torch_x = run_torch(args.hf, tokens, args.layer, args.pos, args.take) - - diff = np.max(np.abs(lean - torch_x)) - print(f"max_abs_diff={diff:.6g} tol={args.tol:.6g}") - if not np.isfinite(diff) or diff > args.tol: - print("FAIL") - print("lean :", lean.tolist()) - print("torch:", torch_x.tolist()) - raise SystemExit(1) - print("OK") - - -if __name__ == "__main__": - main() diff --git a/scripts/scan_gpt2_induction_sound.py b/scripts/scan_gpt2_induction_sound.py deleted file mode 100755 index b767d59..0000000 --- a/scripts/scan_gpt2_induction_sound.py +++ /dev/null @@ -1,273 +0,0 @@ -#!/usr/bin/env python3 -# SPDX-License-Identifier: AGPL-3.0-or-later - -""" -Scan GPT-2 induction head candidates with SOUND logit-diff bounds. - -This script: -1) Ensures a GPT-2 "rigorous induction" binary model exists. 
-2) Uses the heuristic `induction` command to list candidate head pairs. -3) Runs `induction_cert` for each pair and ranks by logitDiffLB. -""" - -from __future__ import annotations - -import argparse -import os -import re -import shutil -import struct -import subprocess -import sys -from concurrent.futures import ThreadPoolExecutor, as_completed -from fractions import Fraction -from pathlib import Path - - -PAIR_RE = re.compile(r"L(\d+)H(\d+)\s+->\s+L(\d+)H(\d+)") - - -def run_cmd(cmd: list[str]) -> str: - proc = subprocess.run(cmd, check=True, capture_output=True, text=True) - return proc.stdout - - -def ensure_model(model_path: Path) -> None: - if model_path.exists(): - return - model_path.parent.mkdir(parents=True, exist_ok=True) - generator = [sys.executable, "scripts/generate_rigorous_induction.py", str(model_path)] - if shutil.which("uv"): - generator = ["uv", "run"] + generator - subprocess.run(generator, check=True) - - -def resolve_nfp_cmd(nfp_bin: str | None) -> list[str]: - if nfp_bin: - return [nfp_bin] - env_bin = os.environ.get("NFP_BIN") - if env_bin: - return [env_bin] - local_bin = Path(".lake/build/bin/nfp") - if local_bin.exists(): - return [str(local_bin)] - return ["lake", "exe", "nfp"] - - -def read_header_and_tokens(path: Path) -> tuple[dict[str, str], list[int]]: - header: dict[str, str] = {} - with path.open("rb") as f: - while True: - line = f.readline() - if not line: - raise ValueError("unexpected EOF while reading header") - text = line.decode("ascii").strip() - if text == "BINARY_START": - break - if "=" in text: - key, value = text.split("=", 1) - header[key.strip()] = value.strip() - seq_len_raw = header.get("seq_len") - if seq_len_raw is None: - raise ValueError("header missing seq_len") - seq_len = int(seq_len_raw) - token_bytes = f.read(seq_len * 4) - if len(token_bytes) != seq_len * 4: - raise ValueError("unexpected EOF while reading tokens") - tokens = list(struct.unpack("<" + "i" * seq_len, token_bytes)) - return header, 
tokens - - -def derive_target_negative(tokens: list[int]) -> tuple[int, int]: - if len(tokens) < 2: - raise ValueError("need at least 2 tokens to derive target/negative") - last = tokens[-1] - prev_idx = None - for i in range(len(tokens) - 2, -1, -1): - if tokens[i] == last: - prev_idx = i - break - if prev_idx is not None and prev_idx + 1 < len(tokens): - target = tokens[prev_idx + 1] - else: - target = last - negative = tokens[-2] - if negative == target: - negative = last if last != target else (0 if target != 0 else 1) - return target, negative - - -def parse_candidates(output: str, top: int) -> list[tuple[int, int, int, int]]: - pairs: list[tuple[int, int, int, int]] = [] - seen: set[tuple[int, int, int, int]] = set() - for line in output.splitlines(): - match = PAIR_RE.search(line) - if match is None: - continue - pair = tuple(int(x) for x in match.groups()) - if pair in seen: - continue - seen.add(pair) - pairs.append(pair) - if len(pairs) >= top: - break - return pairs - - -def parse_logit_lb(output: str) -> Fraction | None: - for line in output.splitlines(): - if line.startswith("logitDiffLB="): - token = line.split("=", 1)[1].split()[0] - try: - return Fraction(token) - except ValueError: - return None - return None - - -def main() -> int: - parser = argparse.ArgumentParser() - parser.add_argument("--model", default="models/gpt2_rigorous.nfpt") - parser.add_argument("--top", type=int, default=8) - parser.add_argument("--delta", default="0.01") - parser.add_argument("--coord", type=int, default=0) - parser.add_argument("--offset1", type=int, default=-1) - parser.add_argument("--offset2", type=int, default=-1) - parser.add_argument("--maxSeqLen", type=int, default=256) - parser.add_argument("--jobs", type=int, default=1) - parser.add_argument("--fast", action="store_true") - parser.add_argument("--nfp-bin", help="Path to nfp binary (defaults to .lake/build/bin/nfp)") - parser.add_argument("--tightPattern", action="store_true") - 
parser.add_argument("--tightPatternLayers", type=int) - parser.add_argument("--perRowPatternLayers", type=int) - parser.add_argument("--bestMatch", action="store_true") - parser.add_argument("--queryPos", type=int) - parser.add_argument("--output", default="reports/gpt2_induction_sound_scan.txt") - args = parser.parse_args() - args.jobs = max(1, args.jobs) - top_arg = any(a.startswith("--top") for a in sys.argv[1:]) - max_seq_len_arg = any(a.startswith("--maxSeqLen") for a in sys.argv[1:]) - if args.fast and not top_arg and args.top == parser.get_default("top"): - args.top = 4 - - model_path = Path(args.model) - ensure_model(model_path) - nfp_cmd = resolve_nfp_cmd(args.nfp_bin) - - header, tokens = read_header_and_tokens(model_path) - seq_len = int(header.get("seq_len", "0")) - if args.fast and not max_seq_len_arg and args.maxSeqLen == parser.get_default("maxSeqLen"): - if seq_len <= 128: - args.maxSeqLen = 128 - if seq_len > args.maxSeqLen: - print( - f"Error: seq_len {seq_len} exceeds maxSeqLen {args.maxSeqLen}", - file=sys.stderr, - ) - return 1 - target, negative = derive_target_negative(tokens) - - induction_out = run_cmd( - nfp_cmd - + [ - "induction", - str(model_path), - "--threshold", - "0.0", - ] - ) - pairs = parse_candidates(induction_out, args.top) - if not pairs: - print("No induction candidates found.", file=sys.stderr) - return 1 - - results: list[tuple[Fraction, tuple[int, int, int, int]]] = [] - - def run_cert(pair: tuple[int, int, int, int]) -> tuple[tuple[int, int, int, int], Fraction | None]: - l1, h1, l2, h2 = pair - cmd = nfp_cmd + [ - "induction_cert", - str(model_path), - "--layer1", - str(l1), - "--head1", - str(h1), - "--layer2", - str(l2), - "--head2", - str(h2), - "--coord", - str(args.coord), - "--offset1", - str(args.offset1), - "--offset2", - str(args.offset2), - "--delta", - args.delta, - "--maxSeqLen", - str(args.maxSeqLen), - "--target", - str(target), - "--negative", - str(negative), - ] - if args.tightPattern: - 
cmd.append("--tightPattern") - if args.tightPatternLayers is not None: - cmd.extend(["--tightPatternLayers", str(args.tightPatternLayers)]) - if args.perRowPatternLayers is not None: - cmd.extend(["--perRowPatternLayers", str(args.perRowPatternLayers)]) - if args.bestMatch: - cmd.append("--bestMatch") - if args.queryPos is not None: - cmd.extend(["--queryPos", str(args.queryPos)]) - try: - cert_out = run_cmd(cmd) - except subprocess.CalledProcessError: - return pair, None - return pair, parse_logit_lb(cert_out) - - if args.jobs == 1: - for pair in pairs: - pair_out, logit_lb = run_cert(pair) - if logit_lb is None: - continue - results.append((logit_lb, pair_out)) - else: - with ThreadPoolExecutor(max_workers=args.jobs) as executor: - futures = {executor.submit(run_cert, pair): pair for pair in pairs} - for future in as_completed(futures): - pair_out, logit_lb = future.result() - if logit_lb is None: - continue - results.append((logit_lb, pair_out)) - - if not results: - print("No sound logit bounds produced.", file=sys.stderr) - return 1 - - results.sort(key=lambda x: x[0], reverse=True) - out_path = Path(args.output) - out_path.parent.mkdir(parents=True, exist_ok=True) - with out_path.open("w", encoding="ascii") as f: - f.write("SOUND induction scan (logitDiffLB ranking)\n") - f.write(f"model={model_path}\n") - f.write(f"target={target} negative={negative}\n") - f.write( - f"bestMatch={args.bestMatch} queryPos={args.queryPos} " - f"tightPatternLayers={args.tightPatternLayers} " - f"perRowPatternLayers={args.perRowPatternLayers}\n" - ) - eps_header = header.get("layer_norm_eps") or header.get("eps") or "unknown" - f.write(f"top={args.top} delta={args.delta} eps={eps_header}\n") - for rank, (lb, (l1, h1, l2, h2)) in enumerate(results, start=1): - f.write( - f"{rank:02d} L{l1}H{h1} -> L{l2}H{h2} logitDiffLB={lb}\n" - ) - - print(f"Report written to {out_path}") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git 
# --- diff residue: header for scripts/variance_lower_bound_check.py
#     (deleted file, 53 lines, index 5acd5c6..0000000) ---

#!/usr/bin/env python3
# SPDX-License-Identifier: AGPL-3.0-or-later

from fractions import Fraction


def population_variance(xs: list[Fraction]) -> Fraction:
    """Exact population (1/n) variance of *xs*; defined as 0 for an empty list."""
    if not xs:
        return Fraction(0)
    n = len(xs)
    mean = sum(xs, Fraction(0)) / n
    return sum((x - mean) ** 2 for x in xs) / n


def variance_bound_range(delta: Fraction, n: int) -> Fraction:
    """Lower bound delta^2 / (2n) on the population variance of n points with range delta."""
    return (delta * delta) / (2 * n) if n > 0 and delta > 0 else Fraction(0)


def main() -> None:
    """Compare the old (n-1)/n^2 * delta^2 bound against the new delta^2/(2n) bound."""
    xs = [Fraction(0), Fraction(1, 2), Fraction(1)]
    n = len(xs)
    delta = max(xs) - min(xs)
    var = population_variance(xs)
    old_bound = Fraction(n - 1, n * n) * (delta**2)
    new_bound = variance_bound_range(delta, n)
    for line in (
        f"x = {xs}",
        f"n = {n}, delta = {delta}",
        f"variance = {var} = {float(var):.6f}",
        f"old_bound = {old_bound} = {float(old_bound):.6f}",
        f"new_bound = {new_bound} = {float(new_bound):.6f}",
        f"old_bound <= variance? {old_bound <= var}",
        f"new_bound <= variance? {new_bound <= var}",
    ):
        print(line)

    # Quick brute-force sanity check on a small grid.
    # NOTE(review): main() continues in the next chunk (the triple loop over
    # the grid and the final grid_check=OK print).
- grid = [Fraction(0), Fraction(1, 2), Fraction(1)] - for a in grid: - for b in grid: - for c in grid: - xs = [a, b, c] - delta = max(xs) - min(xs) - var = population_variance(xs) - bound = variance_bound_range(delta, len(xs)) - if bound > var: - raise SystemExit( - f"violation: xs={xs} delta={delta} var={var} bound={bound}" - ) - print("grid_check=OK") - - -if __name__ == "__main__": - main() diff --git a/tests/fixtures/tiny_sound_binary.nfpt b/tests/fixtures/tiny_sound_binary.nfpt deleted file mode 100644 index 30238ce..0000000 Binary files a/tests/fixtures/tiny_sound_binary.nfpt and /dev/null differ diff --git a/tests/fixtures/tiny_sound_input.nfpt b/tests/fixtures/tiny_sound_input.nfpt deleted file mode 100644 index 63db48c..0000000 --- a/tests/fixtures/tiny_sound_input.nfpt +++ /dev/null @@ -1,17 +0,0 @@ -NFP_TEXT_V2 -num_layers=1 -num_heads=1 -model_dim=4 -head_dim=2 -hidden_dim=4 -vocab_size=10 -seq_len=2 -layer_norm_eps=1e-5 -gelu_kind=tanh - -TOKENS -1 2 - -EMBEDDINGS -0.5 -0.25 0.125 -0.0625 -0.75 -0.50 0.25 -0.125 diff --git a/tests/fixtures/tiny_sound_model.nfpt b/tests/fixtures/tiny_sound_model.nfpt deleted file mode 100644 index a741285..0000000 --- a/tests/fixtures/tiny_sound_model.nfpt +++ /dev/null @@ -1,62 +0,0 @@ -NFP_TEXT_V2 -num_layers=1 -num_heads=1 -model_dim=4 -head_dim=2 -hidden_dim=4 -vocab_size=10 -seq_len=2 -layer_norm_eps=1e-5 -gelu_kind=tanh - -LAYER 0 -HEAD 0 -W_Q -0.10 -0.20 -0.30 0.40 --0.50 0.60 -0.70 -0.80 -b_Q -0.01 -0.02 -W_K --0.11 0.21 -0.31 -0.41 -0.51 0.61 --0.71 0.81 -b_K -0.03 -0.04 -W_V -0.05 0.06 --0.07 0.08 -0.09 -0.10 -0.11 0.12 -b_V -0.001 -0.002 -W_O -0.13 -0.14 0.15 -0.16 -0.17 0.18 -0.19 0.20 -ATTN_BIAS -0.001 0.002 0.003 0.004 -MLP -W_in -0.10 0.00 -0.10 0.20 -0.05 -0.05 0.10 0.00 --0.20 0.30 0.00 -0.10 -0.01 0.02 0.03 0.04 -b_in -0.001 -0.002 0.003 -0.004 -W_out -0.02 0.01 0.00 -0.01 --0.03 0.04 0.05 0.06 -0.07 -0.08 0.09 -0.10 -0.11 0.12 -0.13 0.14 -b_out -0.0001 -0.0002 0.0003 -0.0004 -LN1_GAMMA -1.0 1.0 1.0 
1.0 -LN1_BETA -0.0 0.0 0.0 0.0 -LN2_GAMMA -1.0 1.0 1.0 1.0 -LN2_BETA -0.0 0.0 0.0 0.0