diff --git a/configs/example/kmhv3.py b/configs/example/kmhv3.py index 71844d9478..7d99a6b5c4 100644 --- a/configs/example/kmhv3.py +++ b/configs/example/kmhv3.py @@ -99,6 +99,8 @@ def setKmhV3Params(args, system): cpu.branchPred.tage.resolvedUpdate = True cpu.branchPred.ittage.resolvedUpdate = True + cpu.branchPred.tage.enableBankConflict = False + cpu.branchPred.ubtb.enabled = True cpu.branchPred.abtb.enabled = True cpu.branchPred.microtage.enabled = False diff --git a/docs/Gem5_Docs/frontend/block-tage.md b/docs/Gem5_Docs/frontend/block-tage.md new file mode 100644 index 0000000000..3a00c85633 --- /dev/null +++ b/docs/Gem5_Docs/frontend/block-tage.md @@ -0,0 +1,383 @@ +本文档是一份偏 PRD/架构说明的设计稿,目标读者包括: + +- 架构/RTL 同学:希望快速理解为什么要做、要做成什么语义、关键状态机如何跑。 +- Gem5/模型同学:希望能直接映射到现有 DecoupledBTB/BTBTAGE 的接口与更新时序。 + +本文聚焦第一阶段:**Block-Based Exit-Slot TAGE(Cond Exit)**。Two-Taken 仅保留为后续扩展方向。 + +--- + +# 架构演进提案:基于 Block 粒度的 Exit-Slot TAGE(Cond Exit)与 Two-Taken 预测机制 + +## 1. 背景与动机 (Motivation) + +在 Gem5/香山高性能核的 SPEC06 性能分析中,我们发现前端带宽(Instruction Delivery)在 8-wide 架构下存在显著瓶颈。现有的 BTB-TAGE 组合方式存在以下痛点: + +1. **信息密度低 / 资源浪费**:当前做法是“Block 内每条 Cond 分支都单独预测方向”,但很多 Block 的真实行为往往是“最多只有一条 Cond Taken”。这会造成训练样本被稀释、表项被无效占用。 +2. **同 Block 多分支的 Set 压力与互相污染**:当前索引主要由 `StartPC + PHR` 决定,同一个 Fetch Block 内多条分支会落在同一个 set(靠 tag 中 XOR 的 position 来区分)。在 way 数较小(如 2-way)时,多分支会互相挤占/替换,等价于你文档里想表达的“aliasing/冲突”问题(这不是传统意义的 bank conflict,而是 set-assoc 压力)。 +3. **Two-Taken 缺失**:由于 BPU 内部 override 机制导致的流水线气泡,无法满足后端 8 发射的饥渴需求。(Two-Taken 本文先不展开实现细节) + +**本提案旨在通过实现 "Block-Based Exit-Slot TAGE(Cond Exit)" 和 "Speculative Two-Taken" 机制,将 BPU 的有效吞吐提升至 >1 Block/Cycle。** + +--- + +## 2. 
核心架构设计 (Architecture Overview) + +我们将 TAGE 从 **"Per-Branch Direction Predictor"** 重构为 **"Block-Based Cond-Exit Predictor"**: + +- TAGE 只负责 **Cond 分支的“退出点”选择**(即:Block 内哪一个 Cond 分支会是第一条 Taken)。 +- Uncond/Indirect/Return 的处理保持现有 BTB 流水线逻辑,不在本次改动范围内。 + +“Block-Based(Exit-Slot)” 相比“Per-Branch”的真正优势,不在于“多Pattern”时的容量,而在于“单Pattern”时的效率、抗干扰能力以及带宽匹配度。 +目的是提升信息密度:你的方法输出的是一个向量 (Vector) [T/NT, T/NT, T/NT];Exit-Slot 输出的是一个标量 (Scalar) slot 编码。对于单目标跳转体系,标量比向量更抗噪。 + +### 2.0 设计目标 / 非目标(PRD) + +**目标(Goals)** + +1. 将 Cond 分支方向预测从 “每条分支一个表项/一个预测” 转为 “每个 Fetch Block 一个 payload(ExitSlotEnc)”。 +2. 保持与现有 GEM5 DecoupledBTB 的接口兼容:仍输出 `condTakens`,上层仍通过 “按 PC 顺序选择第一条 taken” 得到最终控制流出口。 +3. 更新与分配以 “每个 Fetch Block 一次训练样本” 为粒度,避免对 exit 之后不可达 cond 分支做 NT 训练(减少系统性噪声)。 + +**非目标(Non-Goals)** + +1. 不改变 Uncond/Indirect/Return 的预测与选择规则。 +2. 第一阶段不引入 Two-Taken 的细节实现(但文档保留扩展点)。 +3. 第一阶段不引入复杂的多 payload/向量输出(例如同时预测多个 cond 的 T/NT)。 + +### 2.1 概念定义(与现有 GEM5 BTB 模型对齐) + +* **Fetch Block**: 取指块。当前 DecoupledBTB 模型里 `predictWidth = 64B`,并按 PC 顺序返回该范围内的 BTB entries。 +* **Slot(指令位置槽)**:以 **2B 粒度**划分 64B block,共 `32` 个 slot,范围 `0..31`。slot 计算方式与当前实现一致:`slot = (branchPC - alignedStartPC) >> instShiftAmt`,其中 `instShiftAmt=1`。其中 `alignedStartPC` 取 fetch 起始地址按 32B 对齐(MBTB half-aligned),因此 slot 覆盖的地址范围是 `[alignedStartPC, alignedStartPC+64B)`。 +* **Cond Exit Slot**:指示该 Fetch Block 内 **第一条 Taken 的 Cond 分支**位于哪个 slot。 +* **No-Cond-Exit(本文仍沿用“fallthrough”术语)**:表示该 Fetch Block 内 **没有 Cond Taken**(注意:这不排除 block 内存在 Uncond/Indirect/Return 导致的控制流退出;本提案的 TAGE 只负责 Cond Exit)。 + +### 2.1.1 关键语义澄清(给 RTL/模型同学) + +- “fallthrough / No-Cond-Exit” 在本文中仅表示 **cond 维度的 fallthrough**:即该 block 内没有 cond taken。 +- 若 block 内存在 uncond/indirect/return,它们依然可能成为最终控制流出口;这不由 Exit-Slot TAGE 决定。 + +### 2.1.2 兼容现有接口的落地方式 + +现有框架最终通过扫描 `btbEntries` 并结合 `condTakens` 选出第一条 taken entry。为了最小改动: + +- Exit-Slot TAGE 仍然生成 `condTakens`; +- 但不再为每条 cond 输出方向,而是 **最多只标记 1 条 cond taken**(对应预测的 exit slot);其余 cond 默认不在 `condTakens` 中出现,等价 NT。 + + + +### 2.2 组件交互图(概念) 
+ +```text +[PC] ^ [PHR] + | + v ++------------------------+ +------------------------+ +| Block-Based TAGE | | Auxiliary GShare | +| (Main Predictor) | | (For 2nd Taken) | +| Output: ExitSlot_1 | | Output: Is_Taken_2? | ++------------------------+ +------------------------+ + | | + +---------------+ +---------------+ + | | + v v ++------------------------------------------------+ +| MBTB (Multi-Target BTB) | +| Lookup(PC1) -> { Branch_1..N_Info, Targets } | ++------------------------------------------------+ + | + v + Final Decision Logic + 1. Taken 1: TAGE predicts ExitSlot_1 (Cond exit slot). + Then select the corresponding cond branch entry from MBTB entries (by slot) and mark it taken. + 2. Taken 2: If Taken 1 is Taken AND GShare says Taken: + Get First_Branch Target from MBTB (Next Line Logic). + +``` +但在 GEM5 当前模型架构中,是先查 Main BTB 结构得到一个 block 内命中的 BTB entries(按 PC 顺序),再交给方向预测器填充 `condTakens`。本提案的第一阶段会保持接口兼容:仍然输出 `condTakens`,只是由 “per-branch” 改为 “per-block 选中一个 cond exit”。 + +### 2.3 设计约束与假设(Implementation Constraints) + +- **slot 编码选择**:使用指令位置 slot(0..31),而非“第 N 条分支”。原因:slot 语义稳定,不随 MBTB 命中条目集合变化而漂移。 +- **payload 编码**:由于 32 个 slot + 1 个 No-Cond-Exit,推荐使用 **6 bits 的 `ExitSlotEnc`**: + - `ExitSlotEnc==0`:No-Cond-Exit + - `ExitSlotEnc in [1..32]`:slot = ExitSlotEnc - 1 +- **训练粒度**:每个 fetch block 只训练一次(围绕真实的 cond exit),不训练 exit 之后不可达 cond。 +- **回退策略**:payload 不可映射(找不到该 slot 的 cond entry)时,优先回退到 base(MBTB entry 的 `ctr`)。 +- **保持经验法则**:保留 useAltOnNa “provider 弱态时是否用 alt/base” 的机制;但索引从 branchPC 改为 startPC(block 粒度)。 + +--- + +## 3. 
详细设计:Block-Based TAGE(Cond Exit / Taken 1) + +### 3.1 表项结构 (Entry Structure) + +不再存储 1-bit Direction,而是存储 “Cond Exit Slot(或 No-Cond-Exit)”。 + +> 关键点:64B block 有 32 个 slot(0..31),“No-Cond-Exit” 是额外的一个状态,因此**单独用 5 bits 无法同时表示全部 slot + fallthrough**。 +> +> 推荐采用 **6 bits 编码**(或等价的 `5bits slot + 1bit is_fallthrough`)。 + +| Field | Bits | Description | +| --- | --- | --- | +| **Tag** | 8-16 | `Hash(StartPC, PHR)`,用于匹配 Block。 | +| **Conf** | 2-3 | **置信计数器(建议 3 bits)**:表示该 payload 在该相关历史下是否稳定可靠。
弱态阈值建议沿用现有经验:`Conf in {0, -1}` 视为 weak。
更新规则与 per-branch 的 taken/nt 不同:**用 “是否预测正确” 来更新 Conf**(见 3.3)。 | +| **ExitSlotEnc** | 6 | **Payload**(推荐编码):
`0`: No-Cond-Exit(本文仍称 fallthrough)
`1..32`: 表示 `slot = ExitSlotEnc - 1`,范围 `0..31` | +| **U** | 1 | Useful bit,用于替换策略 (Clock/Ageing)。 | + +**Conf 与 U 的分工(必须写清楚)** + +- `Conf`:回答“这个 payload 在这个相关历史下是否稳定可靠”,主要用于 **useAlt 门控**、**防抖(是否允许 rewrite)**、以及 **是否值得 allocate 长历史**。 +- `U`:回答“这条表项是否相对 alt/base 提供了增益”,主要用于 **替换/分配候选选择**(例如优先替换 `U==0` 的 entry)。 + +### 3.2 预测逻辑 (Prediction Stage) + +本节描述 **预测阶段**在一个 fetch block 上的完整行为:如何从 TAGE 表项得到 `ExitSlotEnc`,以及如何将其落地到 `condTakens`。 + +#### 3.2.1 Index/Tag(与现有实现对齐) + +1. **Index**:仅使用 `StartPC + FoldedPHR`(不加入 branch offset)。 +2. **Tag**:使用 `StartPC + FoldedPHR`;无需再 XOR position(因为一个 block 只对应一个 payload)。 + +> 说明:现有 per-branch TAGE 的 tag 会 XOR position 来区分同一 block 内的多条分支;Exit-Slot TAGE 的目的正是把这些分支“压缩”为一个 block-level payload,因此不再需要 position 进入 tag。 + +#### 3.2.2 Provider/Alt 选择(最长历史优先) + +- 从最长历史表向短历史表扫描命中: + - 第一命中为 Provider + - 第二命中为 Alt Provider(用于弱态/冲突时回退) + +#### 3.2.3 useAltOnNa 门控(沿用经验,但索引换成 startPC) + +- Provider miss:直接回退 Base。 +- Provider hit 且 `Conf` 为 weak(建议 `Conf in {0,-1}`): + - 查询 `useAltOnNa[startPC]` 决定使用 Alt(若存在)或 Base; +- Provider hit 且 `Conf` 非 weak:使用 Provider payload。 + +#### 3.2.4 将 payload 落地为 `condTakens`(接口兼容的关键) + +解码得到 `(is_no_cond_exit, pred_slot)`: + +- 若 `ExitSlotEnc==0`: + - 不写入任何 cond 的 taken(等价所有 cond NT) +- 若 `ExitSlotEnc in [1..32]`: + 1. 在 MBTB 返回的 `btbEntries` 中寻找 `isCond==true` 且 `slot(entry.pc)==pred_slot` 的 entry; + 2. 找到则仅写入这一条 `condTakens[entry.pc]=true`; + 3. 
其余 cond entry 不写入 `condTakens`(等价 NT)。 + +**Fallback(payload 不可映射)**: + +- 若找不到 `pred_slot` 对应的 cond entry(MBTB miss/过滤/未学到等): + - 回退 Base:对每条 cond entry 使用 MBTB 的 `ctr>=0` 作为方向预测,生成 `condTakens`; + - 这是为了避免 “payload 不可映射 ⇒ 强制 No-Cond-Exit” 带来的不必要性能退化。 + +**Base 的精确定义(便于 RTL/模型一致)** + +- 对每条 `btbEntries` 中的 cond entry: + - `pred_taken = entry.alwaysTaken || (entry.ctr >= 0)` + - 写入 `condTakens[entry.pc] = pred_taken` +- 若某条 cond entry 没写入 `condTakens`,上层会按 “未找到即视为 NT” 处理。 + +#### 3.2.5 伪代码(预测阶段) + +```text +predict_block(startPC, btbEntries, PHR): + provider, alt = tage_lookup(startPC, PHR) + if provider.miss: + return base_condTakens(btbEntries) + + if is_weak(provider.Conf) and useAltOnNa[startPC] says "use alt": + if alt.hit: + enc = alt.ExitSlotEnc + else: + return base_condTakens(btbEntries) + else: + enc = provider.ExitSlotEnc + + if enc == 0: + return {} // all cond NT + + pred_slot = enc - 1 + e = find_cond_entry_by_slot(btbEntries, pred_slot) + if e.exists: + return { e.pc : true } // only one taken + else: + return base_condTakens(btbEntries) +``` + + + +### 3.3 更新逻辑 (Update Stage) + +每个 Fetch Block **只更新/分配一次**,并且**不对 exit 之后的 cond 分支进行“NT 训练”**(它们在该动态 instance 中不可达)。 + +本节给出 **可直接给 RTL 同学实现** 的更新/分配状态机:什么时候只训练 Conf,什么时候 rewrite payload,什么时候 allocate 长历史表项。 + +#### 3.3.1 真实标签 `RealEnc` 的定义(Cond 维度) + +- 若 `stream.exeTaken==true` 且 `stream.exeBranchInfo.isCond==true`: + - `real_slot = slot(stream.exeBranchInfo.pc)` + - `RealEnc = real_slot + 1` +- 否则: + - `RealEnc = 0`(No-Cond-Exit) + +> 说明:若最终出口是 uncond/indirect/return,本提案把 `RealEnc` 视为 0,因为 TAGE 只负责 cond exit。 + +#### 3.3.2 预测标签 `PredEnc` 的定义(与预测阶段保持一致) + +更新时应使用“预测阶段最终生效的决策”来计算 `PredEnc`: + +- 若最终使用了某个 TAGE provider/alt 的 payload:`PredEnc = ExitSlotEnc` +- 若走了 Base 回退: + - 令 `PredEnc = base_exit_slot_enc(btbEntries)`: + - 若 base 在该 block 内预测到某条 cond taken:`PredEnc = slot(pc_first_taken_cond)+1` + - 否则:`PredEnc = 0` + +其中 `base_exit_slot_enc` 的计算方式为: + +1. 按 `btbEntries` 的 PC 顺序扫描 cond entry; +2. 
对每条 cond 计算 `pred_taken = entry.alwaysTaken || (entry.ctr >= 0)`; +3. 返回第一条 `pred_taken==true` 的 cond 的 `slot(pc)+1`;若不存在则返回 0。 + +#### 3.3.3 Conf/U 的更新(正确性驱动,而非 taken/nt 驱动) + +令 `correct = (PredEnc == RealEnc)`。 + +- 若 `correct`: + - `Conf = sat_inc(Conf)` + - `U`:当 **provider 被选用** 且 provider 正确,并且 alt/base 的结果会不同/更差时置 1(表示这条表项“提供了增益”)。一种可执行的定义是:\ + `provider_used && correct && (AltOrBasePredEnc != RealEnc) => U=1`。 +- 若 `!correct`: + - `Conf = sat_dec(Conf)` + - `U`:可在进入弱态时清 0(更保守),或直接清 0(更激进,利于替换)。 + +> 关键点:Conf 的更新以 “payload 是否正确” 为准;这与 per-branch TAGE 里 “按 taken/nt 更新 counter” 不同,是本 PRD 的核心变化之一。 + +#### 3.3.4 分配/重写策略(建议的三条硬规则) + +为兼顾收敛速度与稳定性,推荐采用下述三条规则: + +1. **weak-but-correct:不分配** + - 若 provider hit,且 `is_weak(Conf)`,但 `correct==true`: + - 只训练 `Conf++`(“还不够自信,继续训练”),不 allocate 长历史表,避免浪费与 ping-pong。 + +2. **strong-but-wrong:倾向分配长历史表项** + - 若 provider hit,且错误发生前 `Conf` 为 strong(非 weak 且接近饱和),但 `correct==false`: + - 解释:此时错往往代表 “短历史不足以区分多模式/aliasing”,allocate 长历史更可能解决。 + - 行为:在更长历史表中尝试 allocate 写入 `RealEnc`,原 entry payload 不立刻改(防抖)。 + +3. 
**weak-and-wrong:倾向原地重写 payload** + - 若 provider hit 且 `correct==false`,并且 `Conf` 已经掉到 weak(进入/处于 weak 区间): + - 解释:该 entry 现阶段不可信,继续“死守旧 payload”只会制造持续噪声; + - 行为:允许 **原地 rewrite payload = RealEnc**,并将 `Conf` 重新初始化到 weak(例如 0 或 -1),`U=0`。 + +#### 3.3.5 Provider miss 时的分配策略 + +- 若 provider miss: + - 直接在若干个更长历史表(或从最短表起)尝试 allocate 新 entry,payload 写入 `RealEnc`; + - `Conf` 初始化为 weak;`U=0`。 + +#### 3.3.6 useAltOnNa 的更新(沿用经验,但以 block 粒度) + +- 仅当 provider hit 且 provider 在预测时处于 weak,才更新 `useAltOnNa[startPC]`: + - 若 alt/base 的决策更接近真实 `RealEnc`,则向 “use alt/base” 方向更新; + - 否则向相反方向更新。 + +#### 3.3.7 伪代码(更新/分配) + +```text +update_block(startPC, btbEntries, RealEnc, provider, alt, PredEnc): + correct = (PredEnc == RealEnc) + + if provider.hit: + if correct: + provider.Conf++ + if provider_decision_differs_from_alt_or_base: + provider.U = 1 + if is_weak(provider.Conf): // weak-but-correct + return // no allocation + else: + provider.Conf-- + if becomes_or_is_weak(provider.Conf): // weak-and-wrong + provider.ExitSlotEnc = RealEnc + provider.Conf = WEAK_INIT + provider.U = 0 + return + else: // strong-but-wrong + try_allocate_longer_tables(startPC, RealEnc) + return + else: + try_allocate_tables(startPC, RealEnc) // miss allocation +``` + +#### 3.3.8 参数建议(给 RTL 一个可落地的默认配置) + +- **Conf 位宽**:建议先沿用现有 3-bit 饱和计数器(实现成本低,便于快速原型),并将更新从 “taken/nt 驱动” 改为 “correct/incorrect 驱动”: + - `sat_inc`:上饱和到 `CONF_MAX` + - `sat_dec`:下饱和到 `CONF_MIN` +- **weak 判定**:默认 `Conf in {0, -1}` 为 weak(与现有经验一致)。 +- **WEAK_INIT**:allocate/rewrite 时可统一初始化为 `0`(weak),并将 `U=0`。 +- **strong-but-wrong 判定**:默认可用 `Conf` 接近饱和作为 strong(例如 `Conf >= CONF_MAX-1`)。 + + + + + + + +--- + +下面的 Two-Taken 细节先不考虑实现,但文档保留作为后续扩展方向。 + +## 4. 
详细设计:Two-Taken 机制 (Taken 2) + +为了解决 BPU 带宽不足,我们引入轻量级 GShare 预测紧随其后的第二个 Block。 + +### 4.1 索引策略 (Speculative Indexing) + +为了避免时序依赖,**不使用 Block 2 的 PC,而是使用 Block 1 的 PC**。 + +* **Index**: `Hash(PC_Block1, PHR)` +* 注意:这里假设 Block 1 Taken 后的 PHR 更新模式是固定的(或者忽略 Block 1 的 PHR 更新影响,直接用当前 PHR)。 + + +* **Rationale**: 我们在预测 Block 1 时,顺便问一句:“在这种历史路径下,Block 1 跳完后的下一个块,大概率会跳吗?” + +### 4.2 GShare 结构 + +* **Table Size**: 4K - 8K Entries (小容量,单读写口)。 +* **Entry**: 2-bit Sat Counter (Taken / Not Taken)。 +* **Output**: 仅指示 Block 2 **是否发生跳转**。 + +### 4.3 生成逻辑 + +1. **Condition**: 仅当 TAGE 预测 Block 1 为 **Taken** 时,启用 Two-Taken 逻辑。 +2. **Check**: 读取辅助 GShare。 +* 如果 GShare = **Not Taken**: 只发 Taken 1。 +* If GShare = **Taken**: 尝试发 Taken 2。 + + +3. **Taken 2 Target**: +* 利用 MBTB 的 **Next-Line** 能力或者 **Way 0 (First Branch)** 的信息。 +* 假设 Block 2 中最早遇到的那个分支是跳转点(这是统计学上的大概率事件)。 +* *注:如果 BTB 无法提供 Block 2 的 Target,则放弃 Taken 2。* + + + + +--- + +## 5. 讨论点 (For Discussion) + +1. **MBTB Miss / 不可映射 payload 的处理**: 若 TAGE 预测的 `ExitSlotEnc` 在当前 `btbEntries` 中找不到对应的 cond entry(MBTB miss/未学到/过滤导致),推荐回退到 Base(按 MBTB 内 cond 的 `ctr` 方向预测),而不是强制 fallthrough;否则可能出现不必要的性能回退。 +2. **Two-Taken 的 Target 精度**: 对于 Taken 2,我们只预测了“跳”,但默认它从第一条分支跳。对于复杂控制流(如 Taken 2 是一个 `if-else` 块),这可能不准。是否值得为 Taken 2 引入更复杂的逻辑? +3. **Loop Handling**: 这种 Exit-Slot 结构天然支持 Loop(ExitSlot 往往会稳定落在 Loop Back 的那条 cond 分支位置)。是否还需要单独的 Loop Predictor? + +--- + +### 下一步行动计划 + +1. **Pattern 分析 (QEMU)**: 运行脚本,确认 SPEC06 中 `Cond -> Cond` 的比例以及 Block 2 的默认跳转倾向。 +2. 
**原型开发**: +* 第一阶段:将 TAGE 改为 Block-Based(Exit-Slot / Cond Exit)模式,验证单 Taken 性能与资源节省情况。 +* 第二阶段:加入 GShare 辅助预测器,开启 Two-Taken 发射。 diff --git a/docs/Gem5_Docs/frontend/upperbound_report2.md b/docs/Gem5_Docs/frontend/upperbound_report2.md new file mode 100644 index 0000000000..49cc184ef2 --- /dev/null +++ b/docs/Gem5_Docs/frontend/upperbound_report2.md @@ -0,0 +1,107 @@ +# Upperbound Report: /tmp/debug/tage-new6 + +## What This Report Measures + +- This is an *offline separability upper bound* computed from `bp.db`. +- For each chosen feature key (e.g., `(startPC, history)`), we compute the best possible + accuracy under 0/1 loss by always predicting the *most frequent label* for that key + (majority vote). This is Bayes-optimal given only that key. +- It is **NOT** an oracle that peeks at the future; it quantifies whether the available + features contain enough information to separate patterns. + +### Exit-slot (per-block) label + +- Uses `TAGEMISSTRACE.realEnc` (0..32) as the true label for Exit-Slot multi-class classification. +- `UB_exit(startPC,hist)`: key is `(startPC, indexFoldedHist)`. +- `UB_exit(startPC,H)`: key is `(startPC, history_string)` (low 50 bits in current logging). + +### Direction (per-branch) label + +- Uses `TAGEMISSTRACE.actualTaken` (0/1) as the true label for direction prediction. +- `acc_dir(ref)`: measured accuracy `predTaken==actualTaken` in ref trace (if `predTaken` exists). +- `UB_dir(ref startPC,slot,hist)`: key is `(startPC, slot, indexFoldedHist)`, where + `slot = ((branchPC - startPC) >> 1) & 31` approximates in-block position identity. +- `UB_dir(ref startPC,slot,H)`: key is `(startPC, slot, history_string)`. + +### About `n/a` + +- `n/a` means the db does not have usable samples for that metric (missing table/columns, + or `TAGEMISSTRACE` exists but has 0 rows for that run). 
+ +| bench | BP mispred opt | BP mispred ref | delta | n_exit(opt) | acc_exit(opt) | UB_exit(startPC,hist) | UB_exit(startPC,H) | n_dir(ref) | acc_dir(ref) | UB_dir(ref startPC,slot,hist) | UB_dir(ref startPC,slot,H) | +|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:| +| 2fetch | 0.01% | 0.01% | +0.00% | 20.0k | 99.9% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| 2fetch_self | 0.02% | 0.02% | +0.00% | 10.0k | 100.0% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| alias_branches | 0.37% | 0.31% | +0.06% | 135.7k | 99.6% | 98.7% | 99.7% | 90.9k | 99.6% | 99.4% | 100.0% | +| aliasing_pattern_test | 3.71% | 0.76% | +2.95% | 3.1k | 96.7% | 97.1% | 97.4% | 983 | 98.5% | 98.4% | 100.0% | +| all_patterns_test | 3.27% | 0.75% | +2.52% | 38.3k | 97.3% | 96.2% | 96.5% | 8.8k | 98.0% | 98.7% | 99.9% | +| alternating_test | 0.36% | 0.28% | +0.08% | 2.5k | 99.6% | 99.8% | 99.8% | 997 | 99.7% | 100.0% | 100.0% | +| aluwidth | 0.96% | 0.96% | +0.00% | 209 | 99.0% | 99.0% | 99.0% | n/a | n/a | n/a | n/a | +| always_taken_test | 0.07% | 0.07% | +0.00% | 3.0k | 99.9% | 99.9% | 99.9% | n/a | n/a | n/a | n/a | +| bias_critical | 2.21% | 0.79% | +1.43% | 57.8k | 97.0% | 97.4% | 97.6% | 59.7k | 99.1% | 98.4% | 99.1% | +| brnum | 1.71% | 1.67% | +0.03% | 1.8k | 94.7% | 98.6% | 99.2% | 1.3k | 100.0% | 100.0% | 100.0% | +| brnum2 | 0.71% | 0.71% | +0.00% | 1.4k | 96.2% | 98.8% | 99.9% | 959 | 100.0% | 100.0% | 100.0% | +| brnum2_uftb | 0.25% | 0.23% | +0.02% | 14.1k | 99.2% | 99.8% | 100.0% | 9.4k | 100.0% | 100.0% | 100.0% | +| brnum3 | 0.43% | 0.43% | +0.00% | 3.2k | 98.0% | 99.5% | 99.9% | 960 | 100.0% | 100.0% | 100.0% | +| brsimple | 1.85% | 1.85% | +0.00% | 109 | 98.2% | 98.2% | 98.2% | n/a | n/a | n/a | n/a | +| brwidth | 0.02% | 0.02% | +0.00% | 217 | 99.1% | 99.1% | 99.1% | n/a | n/a | n/a | n/a | +| call_branch | 0.92% | 0.66% | +0.26% | 5.0k | 98.6% | 99.4% | 99.7% | 2.2k | 98.5% | 99.2% | 100.0% | +| confidence_trap | 4.48% | 2.51% | +1.97% | 4.6k | 94.3% | 96.2% | 
97.7% | 3.8k | 97.0% | 92.2% | 99.4% | +| coremark10 | 5.54% | 3.62% | +1.92% | 599.2k | 92.7% | 93.9% | 97.1% | 551.7k | 95.3% | 96.1% | 99.1% | +| early_exits_test | 0.38% | 0.38% | +0.00% | 1.0k | 99.6% | 99.7% | 99.7% | 11 | 100.0% | 100.0% | 100.0% | +| fetchfrag | 0.75% | 0.75% | +0.00% | 30.2k | 99.0% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| forloop | 0.72% | 0.33% | +0.39% | 10.4k | 99.1% | 99.6% | 98.2% | 11.1k | 99.7% | 99.8% | 99.0% | +| fpuwidth | 1.85% | 1.85% | +0.00% | 109 | 98.2% | 98.2% | 98.2% | n/a | n/a | n/a | n/a | +| gradual_transition_test | 0.12% | 0.12% | +0.00% | 2.5k | 99.9% | 99.9% | 80.1% | n/a | n/a | n/a | n/a | +| ifuwidth | 1.85% | 1.85% | +0.00% | 109 | 98.2% | 98.2% | 98.2% | n/a | n/a | n/a | n/a | +| imli_fixed_pos | 1.57% | 0.01% | +1.56% | 243.9k | 98.4% | 98.4% | 98.4% | 247.9k | 100.0% | 100.0% | 100.0% | +| imli_iter | 5.64% | 3.27% | +2.37% | 12.3k | 91.3% | 94.2% | 95.0% | 15.6k | 97.1% | 97.8% | 99.8% | +| imli_phase_shift | 1.51% | 0.01% | +1.50% | 517.9k | 98.5% | 98.5% | 98.5% | 511.9k | 100.0% | 100.0% | 99.2% | +| imli_threshold | 3.04% | 1.54% | +1.50% | 164.0k | 95.1% | 100.0% | 95.1% | 230.6k | 100.0% | 100.0% | 98.4% | +| indirect_branch | 0.07% | 0.07% | +0.00% | 3.0k | 99.5% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| indirect_branch_alternating | 0.66% | 0.73% | -0.07% | 3.0k | 99.8% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| indirect_branch_drift | 0.15% | 0.15% | +0.00% | 3.6k | 99.5% | 86.3% | 99.9% | 499 | 100.0% | 100.0% | 100.0% | +| indirect_branch_multi | 5.35% | 6.25% | -0.90% | 3.4k | 99.8% | 100.0% | 100.0% | n/a | n/a | n/a | n/a | +| jump_branch | 0.50% | 25.05% | -24.55% | 2.3k | 99.1% | 99.6% | 99.8% | 2.2k | 55.6% | 77.8% | 100.0% | +| local_mix | 17.63% | 4.64% | +12.99% | 80.7k | 83.2% | 86.0% | 89.2% | 58.0k | 95.1% | 96.1% | 97.6% | +| local_periodic | 0.61% | 0.21% | +0.40% | 18.6k | 99.2% | 99.5% | 99.4% | 13.8k | 99.7% | 99.8% | 100.0% | +| long_period_flip | 6.37% | 3.08% 
| +3.29% | 61.7k | 92.3% | 93.1% | 97.1% | 36.5k | 95.7% | 94.1% | 99.1% | +| majority_vote | 8.48% | 4.04% | +4.45% | 127.8k | 89.4% | 90.9% | 91.0% | 110.9k | 95.7% | 96.5% | 96.1% | +| multi_dim_pattern | 2.22% | 0.32% | +1.90% | 37.9k | 97.6% | 97.7% | 97.8% | 30.0k | 99.6% | 99.7% | 100.0% | +| nested_branches_test | 6.25% | 2.68% | +3.58% | 5.5k | 95.9% | 94.8% | 96.0% | 2.0k | 95.7% | 97.1% | 100.0% | +| never_taken_test | 0.15% | 0.15% | +0.00% | 2.0k | 99.9% | 99.9% | 99.9% | n/a | n/a | n/a | n/a | +| path_history | 1.54% | 0.09% | +1.45% | 7.8k | 98.4% | 98.4% | 97.0% | 4.0k | 99.9% | 100.0% | 100.0% | +| path_signature | 7.12% | 7.18% | -0.06% | 40.5k | 99.8% | 100.0% | 100.0% | 6.0k | 99.6% | 99.7% | 100.0% | +| prime_based_pattern_test | 6.43% | 0.83% | +5.60% | 4.3k | 96.2% | 96.8% | 96.4% | 1.0k | 98.3% | 94.1% | 100.0% | +| rare_branches_test | 0.99% | 0.64% | +0.35% | 2.1k | 99.0% | 99.0% | 98.8% | 902 | 99.0% | 99.2% | 99.0% | +| ras_recursive | 2.22% | 2.22% | +0.00% | 43 | 97.7% | 97.7% | 97.7% | n/a | n/a | n/a | n/a | +| rastest | 0.65% | 0.65% | +0.00% | 309 | 97.1% | 99.7% | 99.7% | n/a | n/a | n/a | n/a | +| renamewidth | 0.39% | 0.39% | +0.00% | 509 | 99.6% | 99.6% | 99.6% | n/a | n/a | n/a | n/a | +| resolve | 2.58% | 2.26% | +0.32% | 316 | 97.5% | 98.4% | 98.4% | 100 | 97.0% | 100.0% | 100.0% | +| return_branch | 0.45% | 0.38% | +0.07% | 4.8k | 99.1% | 99.7% | 99.7% | 2.2k | 99.4% | 99.9% | 100.0% | +| switching_pattern_test | 4.90% | 0.80% | +4.10% | 5.4k | 96.7% | 96.8% | 96.8% | 963 | 97.6% | 98.5% | 100.0% | +| tage1 | 0.49% | 9.67% | -9.18% | 3.6k | 98.5% | 99.9% | 99.9% | 10.0k | 89.5% | 100.0% | 100.0% | +| tage2 | 0.64% | 0.87% | -0.23% | 1.8k | 96.1% | 98.5% | 97.0% | 2.7k | 97.2% | 99.7% | 98.8% | +| tage3 | 0.40% | 0.35% | +0.05% | 1.5k | 99.5% | 99.6% | 99.7% | 998 | 99.7% | 100.0% | 100.0% | +| tage4 | 0.45% | 0.35% | +0.10% | 1.5k | 99.4% | 99.8% | 99.8% | 997 | 99.7% | 100.0% | 100.0% | +| tage5 | 14.29% | 14.29% | +0.00% 
| 52 | 82.7% | 80.8% | 92.3% | 3 | 100.0% | 100.0% | 100.0% | +| tage_aliasing | 1.46% | 0.42% | +1.04% | 33.6k | 98.2% | 98.7% | 98.6% | 39.9k | 99.6% | 99.6% | 99.5% | +| test_stringlen_v1 | 0.98% | 0.62% | +0.36% | 18.3k | 99.4% | 99.6% | 99.9% | 12.1k | 99.3% | 97.9% | 99.9% | +| test_stringlen_v2 | 2.17% | 1.30% | +0.87% | 37.9k | 97.8% | 98.3% | 99.8% | 25.1k | 98.3% | 98.3% | 99.9% | +| test_stringlen_v3 | 4.61% | 1.73% | +2.88% | 26.4k | 95.6% | 95.9% | 96.5% | 14.4k | 97.7% | 98.0% | 99.1% | +| three_bit_pattern_test | 6.93% | 0.52% | +6.41% | 4.4k | 96.0% | 96.2% | 96.1% | 993 | 99.0% | 99.5% | 100.0% | +| two_bit_pattern_test | 10.04% | 0.36% | +9.68% | 4.0k | 93.7% | 94.0% | 93.7% | 995 | 99.5% | 99.7% | 100.0% | +| weak_correlation | 17.55% | 12.96% | +4.59% | 72.3k | 88.6% | 86.3% | 94.2% | 48.1k | 88.0% | 89.7% | 99.8% | +| xor_dependency | 15.31% | 0.22% | +15.09% | 41.2k | 90.0% | 89.3% | 90.1% | 18.0k | 99.7% | 99.8% | 100.0% | + +## Biggest BP mispred regressions (opt - ref) +- xor_dependency: +15.09% +- local_mix: +12.99% +- two_bit_pattern_test: +9.68% +- three_bit_pattern_test: +6.41% +- prime_based_pattern_test: +5.60% +- weak_correlation: +4.59% +- majority_vote: +4.45% +- switching_pattern_test: +4.10% +- nested_branches_test: +3.58% +- long_period_flip: +3.29% diff --git a/src/cpu/pred/btb/btb_ittage.cc b/src/cpu/pred/btb/btb_ittage.cc index 58828467cd..66281d09c5 100644 --- a/src/cpu/pred/btb/btb_ittage.cc +++ b/src/cpu/pred/btb/btb_ittage.cc @@ -485,10 +485,8 @@ BTBITTAGE::doUpdateHist(const boost::dynamic_bitset<> &history, bool taken, Addr boost::to_string(history, buf); DPRINTF(ITTAGEHistory, "in doUpdateHist, taken %d, pc %#lx, history %s\n", taken, pc, buf.c_str()); } - if (!taken) { - DPRINTF(ITTAGEHistory, "not updating folded history, since FB not taken\n"); - return; - } + // Strategy B: keep folded path history evolving even on fall-through by using a pseudo edge. 
+ // (Callers are expected to pass a meaningful (pc,target) when taken==false.) for (int t = 0; t < numPredictors; t++) { for (int type = 0; type < 3; type++) { @@ -531,6 +529,10 @@ void BTBITTAGE::specUpdatePHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) { auto [pc, target, taken] = pred.getPHistInfo(); + if (!taken) { + pc = pred.bbStart; + target = pred.bbStart + blockSize; + } doUpdateHist(history, taken, pc, target); } @@ -556,7 +558,13 @@ BTBITTAGE::recoverPHist(const boost::dynamic_bitset<> &history, const FetchTarge altTagFoldedHist[i].recover(predMeta->altTagFoldedHist[i]); indexFoldedHist[i].recover(predMeta->indexFoldedHist[i]); } - doUpdateHist(history, cond_taken, entry.getControlPC(), entry.getTakenTarget()); + Addr pc = entry.getControlPC(); + Addr target = entry.getTakenTarget(); + if (!cond_taken) { + pc = entry.startPC; + target = entry.startPC + blockSize; + } + doUpdateHist(history, cond_taken, pc, target); } void diff --git a/src/cpu/pred/btb/btb_mgsc.cc b/src/cpu/pred/btb/btb_mgsc.cc index 0211ad7eaa..0da22037a3 100755 --- a/src/cpu/pred/btb/btb_mgsc.cc +++ b/src/cpu/pred/btb/btb_mgsc.cc @@ -1080,6 +1080,11 @@ void BTBMGSC::specUpdatePHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) { auto [pc, target, taken] = pred.getPHistInfo(); + if (!taken) { + // Strategy B: pseudo edge for fall-through to keep PHR/folded PHR evolving. + pc = pred.bbStart; + target = pred.bbStart + blockSize; + } doUpdateHist(history, 2, taken, indexPFoldedHist, pc, target); // only path history needs pc! 
} @@ -1199,7 +1204,13 @@ BTBMGSC::recoverPHist(const boost::dynamic_bitset<> &history, const FetchTarget for (int i = 0; i < pTableNum; i++) { indexPFoldedHist[i].recover(predMeta->indexPFoldedHist[i]); } - doUpdateHist(history, 2, cond_taken, indexPFoldedHist, entry.getControlPC(), entry.getTakenTarget()); + Addr pc = entry.getControlPC(); + Addr target = entry.getTakenTarget(); + if (!cond_taken) { + pc = entry.startPC; + target = entry.startPC + blockSize; + } + doUpdateHist(history, 2, cond_taken, indexPFoldedHist, pc, target); } /** diff --git a/src/cpu/pred/btb/btb_tage.cc b/src/cpu/pred/btb/btb_tage.cc index be55eab771..7282b32479 100644 --- a/src/cpu/pred/btb/btb_tage.cc +++ b/src/cpu/pred/btb/btb_tage.cc @@ -49,7 +49,8 @@ BTBTAGE::BTBTAGE(unsigned numPredictors, unsigned numWays, unsigned tableSize, u indexShift(bankBaseShift + ceilLog2(numBanks)), enableBankConflict(false), lastPredBankId(0), - predBankValid(false) + predBankValid(false), + tageStats() { setNumDelay(1); @@ -148,34 +149,47 @@ void BTBTAGE::setTrace() { #ifndef UNIT_TEST - if (enableDB) { - std::vector> fields_vec = { - std::make_pair("startPC", UINT64), - std::make_pair("branchPC", UINT64), - std::make_pair("wayIdx", UINT64), - std::make_pair("mainFound", UINT64), - std::make_pair("mainCounter", UINT64), - std::make_pair("mainUseful", UINT64), - std::make_pair("mainTable", UINT64), - std::make_pair("mainIndex", UINT64), - std::make_pair("altFound", UINT64), - std::make_pair("altCounter", UINT64), - std::make_pair("altUseful", UINT64), - std::make_pair("altTable", UINT64), - std::make_pair("altIndex", UINT64), - std::make_pair("useAlt", UINT64), - std::make_pair("predTaken", UINT64), - std::make_pair("actualTaken", UINT64), - std::make_pair("allocSuccess", UINT64), - std::make_pair("allocTable", UINT64), - std::make_pair("allocIndex", UINT64), - std::make_pair("allocWay", UINT64), - std::make_pair("history", TEXT), - std::make_pair("indexFoldedHist", UINT64), - }; - tageMissTrace = 
_db->addAndGetTrace("TAGEMISSTRACE", fields_vec); - tageMissTrace->init_table(); - } + if (enableDB) { + std::vector> fields_vec = { + std::make_pair("startPC", UINT64), + std::make_pair("branchPC", UINT64), + std::make_pair("wayIdx", UINT64), + std::make_pair("mainFound", UINT64), + std::make_pair("mainCounter", UINT64), + std::make_pair("mainUseful", UINT64), + std::make_pair("mainTable", UINT64), + std::make_pair("mainIndex", UINT64), + std::make_pair("altFound", UINT64), + std::make_pair("altCounter", UINT64), + std::make_pair("altUseful", UINT64), + std::make_pair("altTable", UINT64), + std::make_pair("altIndex", UINT64), + std::make_pair("useAlt", UINT64), + std::make_pair("predTaken", UINT64), + std::make_pair("actualTaken", UINT64), + std::make_pair("allocSuccess", UINT64), + std::make_pair("allocTable", UINT64), + std::make_pair("allocIndex", UINT64), + std::make_pair("allocWay", UINT64), + std::make_pair("history", TEXT), + std::make_pair("indexFoldedHist", UINT64), + // Exit-slot debug fields (block-level) + std::make_pair("mainTag", UINT64), + std::make_pair("altTag", UINT64), + std::make_pair("mainPayload", UINT64), + std::make_pair("altPayload", UINT64), + std::make_pair("mainPayload1", UINT64), + std::make_pair("altPayload1", UINT64), + std::make_pair("mainSel", UINT64), + std::make_pair("altSel", UINT64), + std::make_pair("baseEnc", UINT64), + std::make_pair("predEnc", UINT64), + std::make_pair("realEnc", UINT64), + std::make_pair("predSource", UINT64), + }; + tageMissTrace = _db->addAndGetTrace("TAGEMISSTRACE", fields_vec); + tageMissTrace->init_table(); + } #endif } @@ -185,140 +199,269 @@ BTBTAGE::tick() {} void BTBTAGE::tickStart() {} +namespace +{ +inline bool +isWeakConf(uint8_t conf) +{ + // 3-bit saturating confidence counter (0..7). + // Weak = 0/1, strong = 6/7. 
+ return conf <= 1; +} + +inline bool +isStrongConf(uint8_t conf) +{ + return conf >= 6; +} + +inline void +satIncConf(uint8_t &conf) +{ + if (conf < 7) { + conf++; + } +} + +inline void +satDecConf(uint8_t &conf) +{ + if (conf > 0) { + conf--; + } +} + +inline void +updateConf(bool correct, uint8_t &conf) +{ + if (correct) { + satIncConf(conf); + } else { + satDecConf(conf); + } +} + +inline void +satIncSel(uint8_t &sel) +{ + if (sel < 3) { + sel++; + } +} + +inline void +satDecSel(uint8_t &sel) +{ + if (sel > 0) { + sel--; + } +} +} // namespace + /** - * @brief Generate prediction for a single BTB entry by searching TAGE tables - * - * @param btb_entry The BTB entry to generate prediction for - * @param startPC The starting PC address for calculating indices and tags - * @param predMeta Optional prediction metadata; if provided, use snapshot for index/tag - * calculation (update path); if nullptr, use current folded history (prediction path) - * @return TagePrediction containing main and alternative predictions + * @brief Lookup provider/alt entries for this fetch block. */ -BTBTAGE::TagePrediction -BTBTAGE::generateSinglePrediction(const BTBEntry &btb_entry, - const Addr &startPC, - std::shared_ptr predMeta) { - DPRINTF(TAGE, "generateSinglePrediction for btbEntry: %#lx\n", btb_entry.pc); - - // Find main and alternative predictions +std::pair +BTBTAGE::lookupProviders(const Addr &startPC, std::shared_ptr predMeta) +{ bool provided = false; bool alt_provided = false; TageTableInfo main_info, alt_info; - // Search from highest to lowest table for matches - // Calculate branch position within the block (like RTL's cfiPosition) - unsigned position = getBranchIndexInBlock(btb_entry.pc, startPC); - for (int i = numPredictors - 1; i >= 0; --i) { - // Calculate index and tag: use snapshot if provided, otherwise use current folded history - // Tag includes position XOR (like RTL: tag = tempTag ^ cfiPosition) Addr index = predMeta ? 
getTageIndex(startPC, i, predMeta->indexFoldedHist[i].get()) - : getTageIndex(startPC, i); + : getTageIndex(startPC, i); Addr tag = predMeta ? getTageTag(startPC, i, - predMeta->tagFoldedHist[i].get(), predMeta->altTagFoldedHist[i].get(), position) - : getTageTag(startPC, i, position); + predMeta->tagFoldedHist[i].get(), + predMeta->altTagFoldedHist[i].get()) + : getTageTag(startPC, i); - bool match = false; // for each table, only one way can be matched + bool match = false; TageEntry matching_entry; unsigned matching_way = 0; - // Search all ways for a matching entry for (unsigned way = 0; way < numWays; way++) { auto &entry = tageTable[i][index][way]; - // entry valid, tag match (position already encoded in tag, no need to check pc) if (entry.valid && tag == entry.tag) { matching_entry = entry; matching_way = way; match = true; - - // Do not use LRU; keep logic simple and align with CBP-style replacement - - DPRINTF(TAGE, "hit table %d[%lu][%u]: valid %d, tag %lu, ctr %d, useful %d, btb_pc %#lx, pos %u\n", - i, index, way, entry.valid, entry.tag, entry.counter, entry.useful, btb_entry.pc, position); - break; // only one way can be matched, aviod multi hit, TODO: RTL how to do this? 
+ DPRINTF(TAGE, + "hit table %d[%lu][%u]: tag %lu, conf %d, u %d, enc0 %u, enc1 %u, sel %u\n", + i, index, way, entry.tag, entry.conf, entry.useful, + entry.exitSlotEnc0, entry.exitSlotEnc1, entry.selCtr); + break; } } if (match) { if (!provided) { - // First match becomes main prediction main_info = TageTableInfo(true, matching_entry, i, index, tag, matching_way); provided = true; } else if (!alt_provided) { - // Second match becomes alternative prediction alt_info = TageTableInfo(true, matching_entry, i, index, tag, matching_way); alt_provided = true; break; } - } else { - DPRINTF(TAGE, "miss table %d[%lu] for tag %lu (with pos %u), btb_pc %#lx\n", - i, index, tag, position, btb_entry.pc); } } - // Generate final prediction - bool main_taken = main_info.taken(); - bool alt_taken = alt_info.taken(); - // Use base table instead of btb_entry.ctr - bool base_taken = btb_entry.ctr >= 0; - //bool base_taken = btb_entry.ctr >= 0; - bool alt_pred = alt_provided ? alt_taken : base_taken; // if alt provided, use alt prediction, otherwise use base + return {main_info, alt_info}; +} - // use_alt_on_na gating: when provider weak, consult per-PC counter - bool use_alt = false; - if (!provided) { - use_alt = true; - } else { - bool main_weak = (main_info.entry.counter == 0 || main_info.entry.counter == -1); - if (main_weak) { - Addr uidx = getUseAltIdx(btb_entry.pc); - use_alt = (useAlt[uidx] >= 0); - } else { - use_alt = false; +uint8_t +BTBTAGE::getBaseExitSlotEnc(const Addr &startPC, + const std::vector &btbEntries) const +{ + // Base: scan cond branches in PC order; choose the first predicted-taken cond. + for (auto &e : btbEntries) { + if (!(e.valid && e.isCond)) { + continue; + } + const bool pred_taken = e.alwaysTaken || (e.ctr >= 0); + if (pred_taken) { + unsigned slot = getBranchIndexInBlock(e.pc, startPC); + return static_cast(slot + 1); } } - bool taken = use_alt ? 
alt_pred : main_taken; - - DPRINTF(TAGE, "tage predict %#lx taken %d\n", btb_entry.pc, taken); - DPRINTF(TAGE, "tage use_alt %d ? (alt_provided %d ? alt_taken %d : base_taken %d) : main_taken %d\n", - use_alt, alt_provided, alt_taken, base_taken, main_taken); + return 0; +} - return TagePrediction(btb_entry.pc, main_info, alt_info, use_alt, taken, alt_pred); +Addr +BTBTAGE::mapExitSlotToCondPC(const Addr &startPC, + const std::vector &btbEntries, + uint8_t predEnc) const +{ + if (predEnc == 0 || predEnc > 32) { + return 0; + } + const unsigned pred_slot = predEnc - 1; + for (auto &e : btbEntries) { + if (!(e.valid && e.isCond)) { + continue; + } + if (getBranchIndexInBlock(e.pc, startPC) == pred_slot) { + return e.pc; + } + } + return 0; } -/** - * @brief Look up predictions in TAGE tables for a stream of instructions - * - * @param startPC The starting PC address for the instruction stream - * @param btbEntries Vector of BTB entries to make predictions for - * @return Map of branch PC addresses to their predicted outcomes - */ void BTBTAGE::lookupHelper(const Addr &startPC, const std::vector &btbEntries, - std::unordered_map &tageInfoForMgscs, CondTakens& results) + std::unordered_map &tageInfoForMgscs, + CondTakens &results) { - DPRINTF(TAGE, "lookupHelper startAddr: %#lx\n", startPC); - - // Process each BTB entry to make predictions - for (auto &btb_entry : btbEntries) { - // Only predict for valid conditional branches - if (btb_entry.isCond && btb_entry.valid) { - auto pred = generateSinglePrediction(btb_entry, startPC); - meta->preds[btb_entry.pc] = pred; - tageStats.updateStatsWithTagePrediction(pred, true); - results.push_back({btb_entry.pc, pred.taken || btb_entry.alwaysTaken}); - tageInfoForMgscs[btb_entry.pc].tage_pred_taken = pred.taken; - tageInfoForMgscs[btb_entry.pc].tage_main_taken = pred.mainInfo.found ? 
pred.mainInfo.taken() : false; - tageInfoForMgscs[btb_entry.pc].tage_pred_conf_high = pred.mainInfo.found && - abs(pred.mainInfo.entry.counter*2 + 1) == 7; // counter saturated, -4 or 3 - tageInfoForMgscs[btb_entry.pc].tage_pred_conf_mid = pred.mainInfo.found && - (abs(pred.mainInfo.entry.counter*2 + 1) < 7 && - abs(pred.mainInfo.entry.counter*2 + 1) > 1); // counter not saturated, -3, -2, 1, 2 - tageInfoForMgscs[btb_entry.pc].tage_pred_conf_low = !pred.mainInfo.found || - (abs(pred.mainInfo.entry.counter*2 + 1) <= 1); // counter initialized, -1 or 0 - // main predict is different from alt predict/base predict - tageInfoForMgscs[btb_entry.pc].tage_pred_alt_diff = pred.mainInfo.found && pred.mainInfo.taken() != pred.altPred; + DPRINTF(TAGE, "lookupHelper(startPC=%#lx)\n", startPC); + + tageInfoForMgscs.clear(); + + const uint8_t baseEnc = getBaseExitSlotEnc(startPC, btbEntries); + auto [main_info, alt_info] = lookupProviders(startPC); + + bool use_alt = false; + PredSource source = PredSource::Base; + uint8_t predEnc = baseEnc; + + if (main_info.found) { + const bool weak = isWeakConf(main_info.entry.conf); + if (weak) { + Addr uidx = getUseAltIdx(startPC); + // Exit-Slot v2: useAltOnNa acts as a conservative gate to fall back to Base + // when Provider is weak (instead of using Alt). 
+ use_alt = (useAlt[uidx] >= 0); // true => use Base, false => use Provider even if weak + } + + if (!weak) { + source = PredSource::Provider; + predEnc = main_info.entry.selectedEnc(); + } else if (use_alt) { + source = PredSource::Base; + predEnc = baseEnc; + } else { + source = PredSource::Provider; + predEnc = main_info.entry.selectedEnc(); } + } else { + use_alt = true; // consistent with old "no provider => consult base" + source = PredSource::Base; + predEnc = baseEnc; + } + + Addr predCondPC = mapExitSlotToCondPC(startPC, btbEntries, predEnc); + bool payloadMapped = (predEnc != 0) && (predCondPC != 0); + + // If payload cannot be mapped to current MBTB entries, fall back to base as PRD suggests. + if (source != PredSource::Base && predEnc != 0 && !payloadMapped) { + tageStats.predPayloadMapFail++; + source = PredSource::Base; + predEnc = baseEnc; + predCondPC = mapExitSlotToCondPC(startPC, btbEntries, predEnc); + payloadMapped = (predEnc != 0) && (predCondPC != 0); + } + + if (source == PredSource::Base) { + tageStats.predBaseFallback++; + } + if (predEnc == 0) { + tageStats.predNoCondExit++; + } + + TagePrediction pred(startPC, main_info, alt_info, + use_alt, source, predEnc, baseEnc, + payloadMapped, predCondPC); + meta->pred = pred; + meta->hasPred = true; + + tageStats.updateStatsWithTagePrediction(pred, true); + + // Fill per-branch TAGE info for MGSC, and condTakens for control-flow selection. + // - If source==Base: provide a direction prediction for each cond branch (like old behavior). + // - Else: only mark the predicted exit cond as taken; others are implicitly NT. + if (source == PredSource::Base) { + for (auto &e : btbEntries) { + if (!(e.valid && e.isCond)) { + continue; + } + const bool base_taken = (e.ctr >= 0); + results.push_back({e.pc, e.alwaysTaken || base_taken}); + } + } else if (predCondPC != 0) { + results.push_back({predCondPC, true}); + } + + // MGSC expects an entry for every cond BTB entry. 
+ const uint8_t altOrBaseEnc = baseEnc; // Alt unused in Exit-Slot v2 + const bool provider_alt_diff = main_info.found && (main_info.entry.selectedEnc() != altOrBaseEnc); + const int provider_conf_metric = main_info.found ? main_info.entry.conf : 0; + + for (auto &e : btbEntries) { + if (!(e.valid && e.isCond)) { + continue; + } + auto &info = tageInfoForMgscs[e.pc]; + + bool pred_taken_no_always = false; + if (source == PredSource::Base) { + pred_taken_no_always = (e.ctr >= 0); + } else { + pred_taken_no_always = (predCondPC != 0) && (e.pc == predCondPC); + } + + info.tage_pred_taken = pred_taken_no_always; + info.tage_main_taken = (source == PredSource::Provider) && pred_taken_no_always; + + if ((source == PredSource::Provider) && pred_taken_no_always && main_info.found) { + info.tage_pred_conf_high = provider_conf_metric >= 6; + info.tage_pred_conf_mid = (provider_conf_metric < 6) && (provider_conf_metric > 1); + info.tage_pred_conf_low = provider_conf_metric <= 1; + } else { + info.tage_pred_conf_high = false; + info.tage_pred_conf_mid = false; + info.tage_pred_conf_low = true; + } + + info.tage_pred_alt_diff = provider_alt_diff; } } @@ -384,181 +527,17 @@ BTBTAGE::getPredictionMeta() { } /** - * @brief Prepare BTB entries for update by filtering and processing - * - * @param stream The fetch stream containing update information - * @return Vector of BTB entries that need to be updated - */ -std::vector -BTBTAGE::prepareUpdateEntries(const FetchTarget &stream) { - auto all_entries = stream.updateBTBEntries; - - // Add potential new BTB entry if it's a btb miss during prediction - if (!stream.updateIsOldEntry) { - BTBEntry potential_new_entry = stream.updateNewBTBEntry; - bool new_entry_taken = stream.exeTaken && stream.getControlPC() == potential_new_entry.pc; - if (!new_entry_taken) { - potential_new_entry.alwaysTaken = false; - } - all_entries.push_back(potential_new_entry); - } - - // Filter: only keep conditional branches that are not always taken - if 
(getResolvedUpdate()) { - auto remove_it = std::remove_if(all_entries.begin(), all_entries.end(), - [](const BTBEntry &e) { return !(e.isCond && !e.alwaysTaken && e.resolved); }); - all_entries.erase(remove_it, all_entries.end()); - } else { - auto remove_it = std::remove_if(all_entries.begin(), all_entries.end(), - [](const BTBEntry &e) { return !(e.isCond && !e.alwaysTaken); }); - all_entries.erase(remove_it, all_entries.end()); - } - - return all_entries; -} - -/** - * @brief Update predictor state for a single entry - * - * @param entry The BTB entry being updated - * @param actual_taken The actual outcome of the branch - * @param pred The prediction made for this entry - * @param stream The fetch stream containing update information - * @return true if need to allocate new entry - */ -bool -BTBTAGE::updatePredictorStateAndCheckAllocation(const BTBEntry &entry, - bool actual_taken, - const TagePrediction &pred, - const FetchTarget &stream) { - tageStats.updateStatsWithTagePrediction(pred, false); - - auto &main_info = pred.mainInfo; - auto &alt_info = pred.altInfo; - bool used_alt = pred.useAlt; - // Use base table instead of entry.ctr for fallback prediction - Addr startPC = stream.getRealStartPC(); - bool base_taken = entry.ctr >= 0; - bool alt_taken = alt_info.found ? 
alt_info.taken() : base_taken; - - // Update use_alt_on_na when provider is weak (0 or -1) - if (main_info.found) { - bool main_weak = (main_info.entry.counter == 0 || main_info.entry.counter == -1); - if (main_weak) { - tageStats.updateProviderNa++; - Addr uidx = getUseAltIdx(entry.pc); - bool alt_correct = (alt_taken == actual_taken); - updateCounter(alt_correct, useAltOnNaWidth, useAlt[uidx]); - tageStats.updateUseAltOnNaUpdated++; - if (alt_correct) { - tageStats.updateUseAltOnNaCorrect++; - } else { - tageStats.updateUseAltOnNaWrong++; - } - } - } - - // Update main prediction provider - if (main_info.found) { - DPRINTF(TAGE, "prediction provided by table %d, idx %lu, way %u, updating corresponding entry\n", - main_info.table, main_info.index, main_info.way); - - auto &way = tageTable[main_info.table][main_info.index][main_info.way]; - - // Update prediction counter - updateCounter(actual_taken, 3, way.counter); - - // Update useful bit based on several conditions - bool main_is_correct = main_info.taken() == actual_taken; - bool alt_is_correct_and_strong = alt_info.found && - (alt_info.taken() == actual_taken) && - (abs(2 * alt_info.entry.counter + 1) == 7); - - // a. Special reset (humility mechanism) - if (alt_is_correct_and_strong && main_is_correct) { - way.useful = 0; - DPRINTF(TAGEUseful, "useful bit reset to 0 due to humility rule\n"); - } else if (main_info.taken() != alt_taken) { - // b. Original logic to set useful bit high - if (main_is_correct) { - way.useful = 1; - } - } - - // c. 
Reset u on counter sign flip (becomes weak) - if (way.counter == 0 || way.counter == -1) { - way.useful = 0; - DPRINTF(TAGEUseful, "useful bit reset to 0 due to weak counter\n"); - } - DPRINTF(TAGE, "useful bit is now %d\n", way.useful); - - // No LRU maintenance - } - - // Update alternative prediction provider - if (used_alt && alt_info.found) { - auto &way = tageTable[alt_info.table][alt_info.index][alt_info.way]; - updateCounter(actual_taken, 3, way.counter); - // No LRU maintenance - } - - // Update statistics - if (used_alt) { - bool alt_correct = alt_taken == actual_taken; - if (alt_correct) { - tageStats.updateUseAltCorrect++; - } else { - tageStats.updateUseAltWrong++; - } - if (main_info.found && main_info.taken() != alt_taken) { - tageStats.updateAltDiffers++; - } - } - - // Check if misprediction occurred - bool this_fb_mispred = stream.squashType == SquashType::SQUASH_CTRL && - stream.squashPC == entry.pc; - if (getDelay() == 2){ - if (this_fb_mispred) { - tageStats.updateMispred++; - if (!used_alt && main_info.found) { -#ifndef UNIT_TEST - tageStats.updateTableMispreds[main_info.table]++; -#endif - } - } - } - - // No allocation if no misprediction - if (!this_fb_mispred) { - return false; - } - - // Special case: provider is weak but direction is correct - // In this case, provider just needs more training, not a longer history table - // This avoids wasteful allocation and prevents ping-pong effects - if (used_alt && main_info.found && main_info.taken() == actual_taken) { - return false; - } - - // All other cases: allocate longer history table - return true; -} - -/** - * @brief Handle allocation of new entries - * + * @brief Handle allocation of new entries (block-level). 
+ * * @param startPC The starting PC address - * @param entry The BTB entry being updated - * @param actual_taken The actual outcome of the branch + * @param realEnc The actual ExitSlotEnc (0..32) * @param start_table The starting table for allocation * @param meta The metadata of the predictor * @return true if allocation is successful */ bool BTBTAGE::handleNewEntryAllocation(const Addr &startPC, - const BTBEntry &entry, - bool actual_taken, + uint8_t realEnc, unsigned start_table, std::shared_ptr meta, uint64_t &allocated_table, @@ -569,25 +548,24 @@ BTBTAGE::handleNewEntryAllocation(const Addr &startPC, // - Prefer invalid ways; else choose any way with useful==0 and weak counter. // - If none, apply a one-step age penalty to a strong, not-useful way (no allocation). - // Calculate branch position within the block (like RTL's cfiPosition) - unsigned position = getBranchIndexInBlock(entry.pc, startPC); - for (unsigned ti = start_table; ti < numPredictors; ++ti) { Addr newIndex = getTageIndex(startPC, ti, meta->indexFoldedHist[ti].get()); Addr newTag = getTageTag(startPC, ti, - meta->tagFoldedHist[ti].get(), meta->altTagFoldedHist[ti].get(), position); + meta->tagFoldedHist[ti].get(), meta->altTagFoldedHist[ti].get()); auto &set = tageTable[ti][newIndex]; // Allocate into invalid way or not-useful and weak way for (unsigned way = 0; way < numWays; ++way) { auto &cand = set[way]; - const bool weakish = std::abs(cand.counter * 2 + 1) <= 3; // -3,-2,-1,0,1,2 + const bool weakish = isWeakConf(cand.conf); if (!cand.valid || (!cand.useful && weakish)) { - short newCounter = actual_taken ? 
0 : -1; - DPRINTF(TAGE, "allocating entry in table %d[%lu][%u], tag %lu (with pos %u), counter %d, pc %#lx\n", - ti, newIndex, way, newTag, position, newCounter, entry.pc); - cand = TageEntry(newTag, newCounter, entry.pc); // u = 0 default + uint8_t newConf = 0; // weak init + DPRINTF(TAGE, + "allocating entry in table %d[%lu][%u], tag %lu, conf %d, exitEnc %u\n", + ti, newIndex, way, newTag, newConf, realEnc); + // Allocate with a single known candidate; the second candidate is empty (0). + cand = TageEntry(newTag, newConf, realEnc, 0, 0); // u = 0 default tageStats.updateAllocSuccess++; allocated_table = ti; allocated_index = newIndex; @@ -600,11 +578,11 @@ BTBTAGE::handleNewEntryAllocation(const Addr &startPC, // 3) Apply age penalty to one strong, not-useful way to make it replacable later for (unsigned way = 0; way < numWays; ++way) { auto &cand = set[way]; - const bool weakish = std::abs(cand.counter * 2 + 1) <= 3; + const bool weakish = isWeakConf(cand.conf); if (!cand.useful && !weakish) { - if (cand.counter > 0) cand.counter--; else cand.counter++; - DPRINTF(TAGE, "age penalty applied on table %d[%lu][%u], new ctr %d\n", - ti, newIndex, way, cand.counter); + satDecConf(cand.conf); + DPRINTF(TAGE, "age penalty applied on table %d[%lu][%u], new conf %u\n", + ti, newIndex, way, cand.conf); break; // one penalty per table per update } } @@ -685,119 +663,280 @@ BTBTAGE::update(const FetchTarget &stream) { DPRINTF(TAGE, "update startAddr: %#lx, bank: %u\n", startAddr, updateBank); - // ========== Normal Update Logic ========== - // Prepare BTB entries to update - auto entries_to_update = prepareUpdateEntries(stream); - - // Get prediction metadata snapshot and bind to member for helpers auto predMeta = std::static_pointer_cast(stream.predMetas[getComponentIdx()]); - if (!predMeta) { + if (!predMeta || !predMeta->hasPred) { DPRINTF(TAGE, "update: no prediction meta, skip\n"); return; } - // Process each BTB entry + const TagePrediction &pred_at_pred = 
predMeta->pred; + + // RealEnc is defined on cond dimension only. + uint8_t realEnc = 0; + if (stream.exeTaken && stream.exeBranchInfo.isCond) { + unsigned real_slot = getBranchIndexInBlock(stream.exeBranchInfo.pc, startAddr); + realEnc = static_cast(real_slot + 1); + } + + const bool correct = (pred_at_pred.predEnc == realEnc); + + // Recompute provider/alt for update-on-read, or use stored info. + TageTableInfo main_info, alt_info; + if (updateOnRead) { + std::tie(main_info, alt_info) = lookupProviders(startAddr, predMeta); + } else { + main_info = pred_at_pred.mainInfo; + alt_info = pred_at_pred.altInfo; + } + + // Track recomputed-vs-original differences (block-level). bool hasRecomputedVsActualDiff = false; bool hasRecomputedVsOriginalDiff = false; - for (auto &btb_entry : entries_to_update) { - bool actual_taken = stream.exeTaken && stream.exeBranchInfo == btb_entry; - TagePrediction recomputed; - if (updateOnRead) { // if update on read is enabled, re-read providers using snapshot - // Re-read providers using snapshot (do not rely on prediction-time main/alt) - recomputed = generateSinglePrediction(btb_entry, startAddr, predMeta); - // Track differences for statistics - auto it = predMeta->preds.find(btb_entry.pc); - if (it != predMeta->preds.end() && recomputed.taken != it->second.taken) { - hasRecomputedVsOriginalDiff = true; + if (updateOnRead) { + const uint8_t baseEnc = pred_at_pred.baseEnc; + bool use_alt = false; + PredSource src = PredSource::Base; + uint8_t recEnc = baseEnc; + if (main_info.found) { + const bool weak = isWeakConf(main_info.entry.conf); + if (weak) { + Addr uidx = getUseAltIdx(startAddr); + use_alt = (useAlt[uidx] >= 0); // true => use Base (conservative) } - } else { // otherwise, use the prediction from the prediction-time main/alt - recomputed = predMeta->preds[btb_entry.pc]; - } - if (recomputed.taken != actual_taken) { - hasRecomputedVsActualDiff = true; - } - - // Update predictor state and check if need to allocate new entry - 
bool need_allocate = updatePredictorStateAndCheckAllocation(btb_entry, actual_taken, recomputed, stream); - - // Handle new entry allocation if needed - bool alloc_success = false; - uint64_t allocated_table = 0; - uint64_t allocated_index = 0; - uint64_t allocated_way = 0; - if (need_allocate) { - - // Handle allocation of new entries - uint start_table = 0; - auto &main_info = recomputed.mainInfo; - if (main_info.found) { - start_table = main_info.table + 1; // start from the table after the main prediction table + if (!weak) { + src = PredSource::Provider; + recEnc = main_info.entry.selectedEnc(); + } else if (use_alt) { + src = PredSource::Base; + recEnc = baseEnc; + } else { + src = PredSource::Provider; + recEnc = main_info.entry.selectedEnc(); } - alloc_success = handleNewEntryAllocation(startAddr, btb_entry, actual_taken, - start_table, predMeta, allocated_table, allocated_index, allocated_way); + } else { + src = PredSource::Base; + recEnc = baseEnc; } - -#ifndef UNIT_TEST - if (enableDB) { - TageMissTrace t; - std::string history_str; - boost::dynamic_bitset<> history_low50 = predMeta->history; - if (history_low50.size() > 50) { - history_low50.resize(50); // get the lower 50 bits of history - } - boost::to_string(history_low50, history_str); - auto main_info = recomputed.mainInfo; - auto alt_info = recomputed.altInfo; - t.set(startAddr, btb_entry.pc, main_info.way, - main_info.found, main_info.entry.counter, main_info.entry.useful, - main_info.table, main_info.index, - alt_info.found, alt_info.entry.counter, alt_info.entry.useful, - alt_info.table, alt_info.index, - recomputed.useAlt, recomputed.taken, actual_taken, alloc_success, - allocated_table, allocated_index, allocated_way, - history_str, predMeta->indexFoldedHist[main_info.table].get()); - tageMissTrace->write_record(t); + // Use prediction-time BTB entries for payload mapping check. 
+ if (src != PredSource::Base && recEnc != 0 && + mapExitSlotToCondPC(startAddr, stream.predBTBEntries, recEnc) == 0) { + src = PredSource::Base; + recEnc = baseEnc; } -#endif + hasRecomputedVsOriginalDiff = (recEnc != pred_at_pred.predEnc); + hasRecomputedVsActualDiff = (recEnc != realEnc); + } else { + hasRecomputedVsActualDiff = (pred_at_pred.predEnc != realEnc); } - // Update recomputed difference statistics (per fetchBlock) + if (hasRecomputedVsActualDiff) { tageStats.recomputedVsActualDiff++; } if (hasRecomputedVsOriginalDiff) { tageStats.recomputedVsOriginalDiff++; } - if (getDelay() <2){ + + // Update basic hit/useAlt statistics on update. + { + TagePrediction updPred(startAddr, main_info, alt_info, + pred_at_pred.useAlt, pred_at_pred.source, + pred_at_pred.predEnc, pred_at_pred.baseEnc, + pred_at_pred.payloadMapped, pred_at_pred.predCondPC); + tageStats.updateStatsWithTagePrediction(updPred, false); + } + + // Update useAltOnNa (block-level): only when provider was weak at prediction time. + if (pred_at_pred.mainInfo.found && isWeakConf(pred_at_pred.mainInfo.entry.conf)) { + tageStats.updateProviderNa++; + const uint8_t providerEnc = pred_at_pred.mainInfo.entry.selectedEnc(); + const bool base_correct = (pred_at_pred.baseEnc == realEnc); + const bool provider_correct = (providerEnc == realEnc); + // Gate meaning in Exit-Slot v2: + // useAltOnNa[startPC] >= 0 => choose Base when Provider is weak. + // So we train it toward Base when Base is correct, otherwise toward Provider. 
+ if (base_correct != provider_correct) { + const bool prefer_base = base_correct && !provider_correct; + Addr uidx = getUseAltIdx(startAddr); + updateCounter(prefer_base, useAltOnNaWidth, useAlt[uidx]); + tageStats.updateUseAltOnNaUpdated++; + if (prefer_base) { + tageStats.updateUseAltOnNaCorrect++; + } else { + tageStats.updateUseAltOnNaWrong++; + } + } + } + + bool alloc_success = false; + uint64_t allocated_table = 0; + uint64_t allocated_index = 0; + uint64_t allocated_way = 0; + + // Provider update (always update provider entry when found, like old behavior). + if (main_info.found) { + auto &way = tageTable[main_info.table][main_info.index][main_info.way]; + const uint8_t old_conf = way.conf; + const uint8_t providerPredEnc = way.selectedEnc(); + const uint8_t providerOtherEnc = way.otherEnc(); + const bool providerSelCorrect = (providerPredEnc == realEnc); + const bool providerOtherHit = (providerOtherEnc == realEnc); + const bool providerAnyHit = providerSelCorrect || providerOtherHit; + + // Conf reflects *predictive* reliability under this history. + // + // For dual-candidate Exit-Slot entries, "otherHit" only means the correct label + // is present, but selector still failed. Treat selector-miss as incorrect to: + // - avoid conf sticking to strong and suppressing longer-history allocation + // - quickly expose cases where short history cannot separate patterns (e.g. 0/7 alternation) + updateConf(providerSelCorrect, way.conf); + + const uint8_t altOrBaseEnc = pred_at_pred.baseEnc; // Alt unused in Exit-Slot v2 + const bool provider_used = (pred_at_pred.source == PredSource::Provider); + + // Useful: provider provides gain only when provider is used and correct, and alt/base is wrong. 
+ if (provider_used && correct && (altOrBaseEnc != realEnc)) { + way.useful = 1; + } + if (!providerAnyHit && isWeakConf(way.conf)) { + way.useful = 0; + } + + if (providerSelCorrect) { + if (isWeakConf(way.conf)) { + tageStats.updateNoAllocWeakCorrect++; + } + } else if (providerOtherHit) { + // Selector miss: the other candidate is correct, so train selector toward it. + if (way.selCtr >= 2) { + // selected enc1 but real matches enc0 + satDecSel(way.selCtr); + } else { + // selected enc0 but real matches enc1 + satIncSel(way.selCtr); + } + + // If selector keeps missing under the same (short) history, it likely needs longer + // history separation rather than more selector training (classic "1-step lag" on + // alternating labels). Try allocating to longer tables when either: + // - the entry was already strong (we were confident but still wrong), or + // - conf has been trained down to weak (repeated selector misses). + if (isStrongConf(old_conf) || isWeakConf(way.conf)) { + unsigned start_table = main_info.table + 1; + alloc_success = handleNewEntryAllocation(startAddr, realEnc, start_table, + predMeta, allocated_table, + allocated_index, allocated_way); + } + } else { + // Weak-and-wrong is the typical ping-pong trigger in Exit-Slot mode: + // multiple exit patterns of the same startPC keep rewriting the same entry. + // Prefer allocating into longer history tables to separate patterns; fall back + // to rewrite only when allocation fails. + const bool provider_was_weak = isWeakConf(old_conf); + if (provider_was_weak) { + unsigned start_table = main_info.table + 1; + alloc_success = handleNewEntryAllocation(startAddr, realEnc, start_table, + predMeta, allocated_table, + allocated_index, allocated_way); + if (!alloc_success) { + // Replace the non-selected candidate with the new label, and steer selector to it. 
+ if (way.selCtr >= 2) { + // currently selects enc1 => replace enc0, then select enc0 strongly + way.exitSlotEnc0 = realEnc; + way.selCtr = 0; + } else { + // currently selects enc0 => replace enc1, then select enc1 strongly + way.exitSlotEnc1 = realEnc; + way.selCtr = 3; + } + way.conf = 0; // weak init + way.useful = 0; + tageStats.updateRewriteWeakWrong++; + } + } else if (isStrongConf(old_conf)) { + // strong-but-wrong => allocate longer history. + tageStats.updateAllocStrongWrong++; + unsigned start_table = main_info.table + 1; + alloc_success = handleNewEntryAllocation(startAddr, realEnc, start_table, + predMeta, allocated_table, + allocated_index, allocated_way); + } + } + } else { + // Provider miss: allocate only when incorrect (i.e., base can't cover this pattern). + if (!correct) { + tageStats.updateAllocOnMiss++; + alloc_success = handleNewEntryAllocation(startAddr, realEnc, 0, + predMeta, allocated_table, + allocated_index, allocated_way); + } + } + + // If alt was actually used, train alt entry as well. + if (pred_at_pred.source == PredSource::Alt && alt_info.found) { + auto &way = tageTable[alt_info.table][alt_info.index][alt_info.way]; + updateConf(correct, way.conf); + } + +#ifndef UNIT_TEST + if (enableDB) { + TageMissTrace t; + std::string history_str; + boost::dynamic_bitset<> history_low50 = predMeta->history; + if (history_low50.size() > 50) { + history_low50.resize(50); + } + boost::to_string(history_low50, history_str); + + const uint64_t branchPC = stream.exeBranchInfo.isCond ? stream.exeBranchInfo.pc : 0; + const uint64_t main_tag = main_info.found ? main_info.tag : 0; + const uint64_t alt_tag = alt_info.found ? alt_info.tag : 0; + const uint64_t main_payload = main_info.found ? main_info.entry.exitSlotEnc0 : 0; + const uint64_t alt_payload = alt_info.found ? alt_info.entry.exitSlotEnc0 : 0; + const uint64_t main_payload1 = main_info.found ? main_info.entry.exitSlotEnc1 : 0; + const uint64_t alt_payload1 = alt_info.found ? 
alt_info.entry.exitSlotEnc1 : 0; + const uint64_t main_sel = main_info.found ? main_info.entry.selCtr : 0; + const uint64_t alt_sel = alt_info.found ? alt_info.entry.selCtr : 0; + const uint64_t pred_source = static_cast(pred_at_pred.source); + t.set(startAddr, branchPC, main_info.way, + main_info.found, main_info.entry.conf, main_info.entry.useful, + main_info.table, main_info.index, + alt_info.found, alt_info.entry.conf, alt_info.entry.useful, + alt_info.table, alt_info.index, + pred_at_pred.useAlt, pred_at_pred.predEnc != 0, stream.exeTaken, alloc_success, + allocated_table, allocated_index, allocated_way, + history_str, + main_info.found ? predMeta->indexFoldedHist[main_info.table].get() : 0, + main_tag, alt_tag, + main_payload, alt_payload, + main_payload1, alt_payload1, + main_sel, alt_sel, + pred_at_pred.baseEnc, pred_at_pred.predEnc, realEnc, + pred_source); + tageMissTrace->write_record(t); + } +#endif + + if (getDelay() < 2) { checkUtageUpdateMisspred(stream); } - DPRINTF(TAGE, "end update\n"); + + DPRINTF(TAGE, "end update (PredEnc %u, RealEnc %u, correct %d)\n", + pred_at_pred.predEnc, realEnc, correct); } void BTBTAGE::checkUtageUpdateMisspred(const FetchTarget &stream) { auto predMeta = std::static_pointer_cast(stream.predMetas[getComponentIdx()]); - // use for microtage updatemispred counting - // sort microtage predictions by pc to find the first taken branch - std::vector> lastPreds; - lastPreds.reserve(predMeta->preds.size()); - for (auto &kv : predMeta->preds) { - lastPreds.emplace_back(kv.first, kv.second); - } - std::sort(lastPreds.begin(), lastPreds.end(), - [](const std::pair &a, - const std::pair &b) { - return a.first < b.first; - }); - Addr first_taken_pc = 0; - for (auto &entry_info : lastPreds) { - if (entry_info.second.taken) { - first_taken_pc = entry_info.first; - break; - } + if (!predMeta || !predMeta->hasPred) { + return; } - bool fallthrough_mispred = (first_taken_pc == 0 && stream.exeTaken) || - (first_taken_pc != 0 && 
!stream.exeTaken); - bool branch_mispred = stream.exeTaken && first_taken_pc != stream.exeBranchInfo.pc; + // MicroTAGE mispred counting: focus on cond-exit only. + const Addr first_taken_pc = predMeta->pred.predCondPC; + const bool actual_cond_taken = stream.exeTaken && stream.exeBranchInfo.isCond; + + bool fallthrough_mispred = (first_taken_pc == 0 && actual_cond_taken) || + (first_taken_pc != 0 && !actual_cond_taken); + bool branch_mispred = actual_cond_taken && first_taken_pc != stream.exeBranchInfo.pc; if (fallthrough_mispred || branch_mispred) { tageStats.updateMispred++; } @@ -817,7 +956,7 @@ BTBTAGE::updateCounter(bool taken, unsigned width, short &counter) { // Calculate TAGE tag with folded history - optimized version using bitwise operations Addr -BTBTAGE::getTageTag(Addr pc, int t, uint64_t foldedHist, uint64_t altFoldedHist, Addr position) +BTBTAGE::getTageTag(Addr pc, int t, uint64_t foldedHist, uint64_t altFoldedHist) { // Create mask for tableTagBits[t] to limit result size Addr mask = (1ULL << tableTagBits[t]) - 1; @@ -832,14 +971,14 @@ BTBTAGE::getTageTag(Addr pc, int t, uint64_t foldedHist, uint64_t altFoldedHist, // Extract alt tag bits and shift left by 1 Addr altTagBits = (altFoldedHist << 1) & mask; - // XOR all components together, including position (like RTL) - return pcBits ^ foldedBits ^ altTagBits ^ position; + // XOR all components together (Exit-Slot mode does not include position). 
+ return pcBits ^ foldedBits ^ altTagBits; } Addr -BTBTAGE::getTageTag(Addr pc, int t, Addr position) +BTBTAGE::getTageTag(Addr pc, int t) { - return getTageTag(pc, t, tagFoldedHist[t].get(), altTagFoldedHist[t].get(), position); + return getTageTag(pc, t, tagFoldedHist[t].get(), altTagFoldedHist[t].get()); } Addr @@ -892,7 +1031,7 @@ BTBTAGE::getUseAltIdx(Addr pc) { } unsigned -BTBTAGE::getBranchIndexInBlock(Addr branchPC, Addr startPC) { +BTBTAGE::getBranchIndexInBlock(Addr branchPC, Addr startPC) const { // Calculate branch position within the fetch block (0 .. maxBranchPositions-1) Addr alignedPC = startPC & ~(blockSize - 1); Addr offset = (branchPC - alignedPC) >> instShiftAmt; @@ -927,10 +1066,8 @@ BTBTAGE::doUpdateHist(const boost::dynamic_bitset<> &history, bool taken, Addr p boost::to_string(history, buf); DPRINTF(TAGEHistory, "in doUpdateHist, taken %d, pc %#lx, history %s\n", taken, pc, buf.c_str()); } - if (!taken) { - DPRINTF(TAGEHistory, "not updating folded history, since FB not taken\n"); - return; - } + // Strategy B: keep folded path history evolving even on fall-through by using a pseudo edge. + // (Callers are expected to pass a meaningful (pc,target) when taken==false.) for (int t = 0; t < numPredictors; t++) { for (int type = 0; type < 3; type++) { @@ -958,6 +1095,11 @@ void BTBTAGE::specUpdatePHist(const boost::dynamic_bitset<> &history, FullBTBPrediction &pred) { auto [pc, target, taken] = pred.getPHistInfo(); + if (!taken) { + // Pseudo edge for fall-through: startPC -> startPC + blockSize. 
+ pc = pred.bbStart; + target = pred.bbStart + blockSize; + } doUpdateHist(history, taken, pc, target); } @@ -984,7 +1126,13 @@ BTBTAGE::recoverPHist(const boost::dynamic_bitset<> &history, altTagFoldedHist[i].recover(predMeta->altTagFoldedHist[i]); indexFoldedHist[i].recover(predMeta->indexFoldedHist[i]); } - doUpdateHist(history, cond_taken, entry.getControlPC(), entry.getTakenTarget()); + Addr pc = entry.getControlPC(); + Addr target = entry.getTakenTarget(); + if (!cond_taken) { + pc = entry.startPC; + target = entry.startPC + blockSize; + } + doUpdateHist(history, cond_taken, pc, target); } // Check folded history after speculative update and recovery @@ -1028,6 +1176,13 @@ BTBTAGE::TageStats::TageStats(statistics::Group* parent, int numPredictors, int ADD_STAT(updateAllocSuccess, statistics::units::Count::get(), "alloc success when update"), ADD_STAT(updateMispred, statistics::units::Count::get(), "mispred when update"), ADD_STAT(updateResetU, statistics::units::Count::get(), "reset u when update"), + ADD_STAT(predNoCondExit, statistics::units::Count::get(), "predicted No-Cond-Exit (ExitSlotEnc==0) blocks"), + ADD_STAT(predBaseFallback, statistics::units::Count::get(), "blocks that fall back to base (provider miss/weak/ mapfail)"), + ADD_STAT(predPayloadMapFail, statistics::units::Count::get(), "non-base payload that cannot be mapped to a cond entry in btbEntries"), + ADD_STAT(updateAllocOnMiss, statistics::units::Count::get(), "allocate on provider miss when base is wrong"), + ADD_STAT(updateAllocStrongWrong, statistics::units::Count::get(), "allocate on strong-but-wrong provider"), + ADD_STAT(updateRewriteWeakWrong, statistics::units::Count::get(), "rewrite payload on weak-and-wrong provider"), + ADD_STAT(updateNoAllocWeakCorrect, statistics::units::Count::get(), "no-alloc on weak-but-correct provider"), ADD_STAT(recomputedVsActualDiff, statistics::units::Count::get(), "fetchBlocks where recomputed.taken != actual_taken"), ADD_STAT(recomputedVsOriginalDiff, 
statistics::units::Count::get(), "fetchBlocks where recomputed.taken != original pred.taken"), ADD_STAT(updateBankConflict, statistics::units::Count::get(), "number of bank conflicts detected"), @@ -1133,15 +1288,29 @@ BTBTAGE::commitBranch(const FetchTarget &stream, const DynInstPtr &inst) return; } auto meta = std::static_pointer_cast(stream.predMetas[getComponentIdx()]); - auto pc = inst->pcState().instAddr(); - auto it = meta->preds.find(pc); + const Addr pc = inst->pcState().instAddr(); + + // pred_hit: the branch must be present in the BTB entries of this stream. + const BTBEntry *btb_entry = nullptr; + for (auto &e : stream.predBTBEntries) { + if (e.valid && e.isCond && e.pc == pc) { + btb_entry = &e; + break; + } + } + const bool pred_hit = (btb_entry != nullptr) && meta && meta->hasPred; + bool pred_taken = false; - bool pred_hit = false; - if (it != meta->preds.end()) { - pred_taken = it->second.taken; - pred_hit = true; + if (pred_hit) { + if (meta->pred.source == PredSource::Base) { + pred_taken = (btb_entry->ctr >= 0); + } else { + pred_taken = (meta->pred.predCondPC == pc); + } } - bool this_cond_taken = stream.exeTaken && stream.exeBranchInfo.pc == pc; + + const bool this_cond_taken = stream.exeTaken && stream.exeBranchInfo.isCond && + stream.exeBranchInfo.pc == pc; bool predcorrect = (pred_taken == this_cond_taken); if (!predcorrect) { tageStats.condPredwrong++; diff --git a/src/cpu/pred/btb/btb_tage.hh b/src/cpu/pred/btb/btb_tage.hh index 5d104b856d..e3d2aa1f61 100644 --- a/src/cpu/pred/btb/btb_tage.hh +++ b/src/cpu/pred/btb/btb_tage.hh @@ -58,18 +58,34 @@ class BTBTAGE : public TimedBaseBTBPredictor public: bool valid; // Whether this entry is valid Addr tag; // Tag for matching - short counter; // Prediction counter (-4 to 3), 3bits, 0 and -1 are weak + // Exit-Slot v2: confidence is independent of label (multi-class). + // Use an unsigned saturating counter: 0..7 (weak..strong). 
+ uint8_t conf; bool useful; // 1-bit usefulness counter; true means useful - Addr pc; // branch pc, like branch position, for btb entry pc check + // Dual-candidate payloads to reduce multi-pattern ping-pong in Exit-Slot mode. + // 0=No-Cond-Exit, 1..32 => slot=enc-1. + uint8_t exitSlotEnc0; + uint8_t exitSlotEnc1; + // 2-bit selector counter: + // - value < 2 selects enc0 + // - value >= 2 selects enc1 + uint8_t selCtr; unsigned lruCounter; // Counter for LRU replacement policy - TageEntry() : valid(false), tag(0), counter(0), useful(false), pc(0), lruCounter(0) {} + TageEntry() + : valid(false), tag(0), conf(0), useful(false), + exitSlotEnc0(0), exitSlotEnc1(0), selCtr(0), + lruCounter(0) + {} - TageEntry(Addr tag, short counter, Addr pc) : - valid(true), tag(tag), counter(counter), useful(false), pc(pc), lruCounter(0) {} - bool taken() const { - return counter >= 0; - } + TageEntry(Addr tag, uint8_t conf, uint8_t exit0, uint8_t exit1, uint8_t selCtr) : + valid(true), tag(tag), conf(conf), useful(false), + exitSlotEnc0(exit0), exitSlotEnc1(exit1), selCtr(selCtr), + lruCounter(0) + {} + + uint8_t selectedEnc() const { return (selCtr >= 2) ? exitSlotEnc1 : exitSlotEnc0; } + uint8_t otherEnc() const { return (selCtr >= 2) ? 
exitSlotEnc0 : exitSlotEnc1; } }; // Contains information about a TAGE table lookup @@ -85,29 +101,41 @@ class BTBTAGE : public TimedBaseBTBPredictor TageTableInfo() : found(false), table(0), index(0), tag(0), way(0) {} TageTableInfo(bool found, TageEntry entry, unsigned table, Addr index, Addr tag, unsigned way) : found(found), entry(entry), table(table), index(index), tag(tag), way(way) {} - bool taken() const { - return entry.taken(); - } + }; + + enum class PredSource : uint8_t + { + Provider = 0, + Alt = 1, + Base = 2, }; // Contains the complete prediction result struct TagePrediction { public: - Addr btb_pc; // btb entry pc, same as tage entry pc - TageTableInfo mainInfo; // Main prediction info - TageTableInfo altInfo; // Alternative prediction info - bool useAlt; // Whether to use alternative prediction, true if main is weak or no main prediction - bool taken; // Final prediction (taken/not taken) = use_alt ? alt_provided ? alt_taken : base_taken : main_taken - bool altPred; // Alternative prediction = alt_provided ? 
alt_taken : base_taken; - - - TagePrediction() : btb_pc(0), useAlt(false), taken(false), altPred(false) {} - - TagePrediction(Addr btb_pc, TageTableInfo mainInfo, TageTableInfo altInfo, - bool useAlt, bool taken, bool altPred) : - btb_pc(btb_pc), mainInfo(mainInfo), altInfo(altInfo), - useAlt(useAlt), taken(taken), altPred(altPred){} + Addr startPC; // Fetch block start PC (aligned as used by MBTB/TAGE) + TageTableInfo mainInfo; // Provider info + TageTableInfo altInfo; // Alternative provider info + bool useAlt; // Whether weak-provider useAltOnNa gate selects base (conservative) + PredSource source; // Final decision source (Provider/Base; Alt is unused in Exit-Slot v2) + uint8_t predEnc; // Final ExitSlotEnc used by this component (0..32) + uint8_t baseEnc; // Base ExitSlotEnc (computed from MBTB ctr, 0..32) + bool payloadMapped; // predEnc!=0 and found matching cond entry in btbEntries + Addr predCondPC; // PC of predicted cond exit (0 if No-Cond-Exit or map fail) + + TagePrediction() + : startPC(0), useAlt(false), source(PredSource::Base), + predEnc(0), baseEnc(0), payloadMapped(false), predCondPC(0) {} + + TagePrediction(Addr startPC, TageTableInfo mainInfo, TageTableInfo altInfo, + bool useAlt, PredSource source, + uint8_t predEnc, uint8_t baseEnc, + bool payloadMapped, Addr predCondPC) + : startPC(startPC), mainInfo(mainInfo), altInfo(altInfo), + useAlt(useAlt), source(source), + predEnc(predEnc), baseEnc(baseEnc), + payloadMapped(payloadMapped), predCondPC(predCondPC) {} }; @@ -179,12 +207,10 @@ class BTBTAGE : public TimedBaseBTBPredictor Addr getTageIndex(Addr pc, int table, uint64_t foldedHist); // Calculate TAGE tag for a given PC and table - // position: branch position within the block (xored into tag like RTL) - Addr getTageTag(Addr pc, int table, Addr position = 0); + Addr getTageTag(Addr pc, int table); // Calculate TAGE tag with folded history (uint64_t version for performance) - // position: branch position within the block (xored into tag like 
RTL) - Addr getTageTag(Addr pc, int table, uint64_t foldedHist, uint64_t altFoldedHist, Addr position = 0); + Addr getTageTag(Addr pc, int table, uint64_t foldedHist, uint64_t altFoldedHist); // Get offset within a block for a given PC Addr getOffset(Addr pc) { @@ -192,7 +218,7 @@ class BTBTAGE : public TimedBaseBTBPredictor } // Get branch index within a prediction block - unsigned getBranchIndexInBlock(Addr branchPC, Addr startPC); + unsigned getBranchIndexInBlock(Addr branchPC, Addr startPC) const; // Get bank ID from PC (after removing instruction alignment bits) // Extract bits [bankBaseShift + bankIdWidth - 1 : bankBaseShift] @@ -343,6 +369,16 @@ class BTBTAGE : public TimedBaseBTBPredictor Scalar updateMispred; Scalar updateResetU; + // ===== Exit-Slot specific counters (block-level) ===== + Scalar predNoCondExit; + Scalar predBaseFallback; + Scalar predPayloadMapFail; + + Scalar updateAllocOnMiss; + Scalar updateAllocStrongWrong; + Scalar updateRewriteWeakWrong; + Scalar updateNoAllocWeakCorrect; + // Recomputed prediction difference statistics (per fetchBlock) Scalar recomputedVsActualDiff; // recomputed.taken != actual_taken Scalar recomputedVsOriginalDiff; // recomputed.taken != original pred.taken @@ -399,7 +435,8 @@ public: // Metadata for TAGE prediction typedef struct TageMeta { - std::unordered_map preds; + TagePrediction pred; + bool hasPred{false}; std::vector tagFoldedHist; std::vector altTagFoldedHist; std::vector indexFoldedHist; @@ -409,27 +446,25 @@ public: private: - // Helper method to generate prediction for a single BTB entry - // If predMeta is provided, use snapshot folded history for index/tag calculation (update path) - // If predMeta is nullptr, use current folded history (prediction path) - TagePrediction generateSinglePrediction(const BTBEntry &btb_entry, - const Addr &startPC, - const std::shared_ptr predMeta = nullptr); + // Lookup provider/alt in TAGE tables for this fetch block (startPC + PHR snapshot). 
+ // If predMeta is provided, use snapshot folded history for index/tag calculation (update path). + std::pair + lookupProviders(const Addr &startPC, + const std::shared_ptr predMeta = nullptr); - // Helper method to prepare BTB entries for update - std::vector prepareUpdateEntries(const FetchTarget &stream); + // Compute Base exit-slot encoding from MBTB entries (ctr/alwaysTaken), 0..32. + uint8_t getBaseExitSlotEnc(const Addr &startPC, + const std::vector &btbEntries) const; - // Helper method to update predictor state for a single entry - bool updatePredictorStateAndCheckAllocation(const BTBEntry &entry, - bool actual_taken, - const TagePrediction &pred, - const FetchTarget &stream); + // Map predicted exit slot to a cond BTB entry in this block. Returns 0 on failure. + Addr mapExitSlotToCondPC(const Addr &startPC, + const std::vector &btbEntries, + uint8_t predEnc) const; - // Helper method to handle new entry allocation + // Allocation helper for block-level entry (payload = RealEnc). 
bool handleNewEntryAllocation(const Addr &startPC, - const BTBEntry &entry, - bool actual_taken, - unsigned main_table, + uint8_t realEnc, + unsigned start_table, std::shared_ptr meta, uint64_t &allocated_table, uint64_t &allocated_index, diff --git a/src/cpu/pred/btb/common.hh b/src/cpu/pred/btb/common.hh index 5d54902069..c18d2770e6 100644 --- a/src/cpu/pred/btb/common.hh +++ b/src/cpu/pred/btb/common.hh @@ -667,7 +667,14 @@ struct TageMissTrace : public Record uint64_t altFound, uint64_t altCounter, uint64_t altUseful, uint64_t altTable, uint64_t altIndex, uint64_t useAlt, uint64_t predTaken, uint64_t actualTaken, uint64_t allocSuccess, uint64_t allocTable, uint64_t allocIndex, uint64_t allocWay, - std::string history, uint64_t indexFoldedHist) + std::string history, uint64_t indexFoldedHist, + // Exit-slot specific debug fields (block-level) + uint64_t mainTag, uint64_t altTag, + uint64_t mainPayload, uint64_t altPayload, + uint64_t mainPayload1, uint64_t altPayload1, + uint64_t mainSel, uint64_t altSel, + uint64_t baseEnc, uint64_t predEnc, uint64_t realEnc, + uint64_t predSource) { _tick = curTick(); _uint64_data["startPC"] = startPC; @@ -692,6 +699,19 @@ struct TageMissTrace : public Record _uint64_data["allocWay"] = allocWay; _text_data["history"] = history; _uint64_data["indexFoldedHist"] = indexFoldedHist; + + _uint64_data["mainTag"] = mainTag; + _uint64_data["altTag"] = altTag; + _uint64_data["mainPayload"] = mainPayload; + _uint64_data["altPayload"] = altPayload; + _uint64_data["mainPayload1"] = mainPayload1; + _uint64_data["altPayload1"] = altPayload1; + _uint64_data["mainSel"] = mainSel; + _uint64_data["altSel"] = altSel; + _uint64_data["baseEnc"] = baseEnc; + _uint64_data["predEnc"] = predEnc; + _uint64_data["realEnc"] = realEnc; + _uint64_data["predSource"] = predSource; } }; diff --git a/src/cpu/pred/btb/decoupled_bpred.cc b/src/cpu/pred/btb/decoupled_bpred.cc index 8869dec427..12b6eb2318 100644 --- a/src/cpu/pred/btb/decoupled_bpred.cc +++ 
b/src/cpu/pred/btb/decoupled_bpred.cc @@ -703,15 +703,16 @@ DecoupledBPUWithBTB::pHistShiftIn(int shamt, bool taken, boost::dynamic_bitset<> if (shamt == 0) { return; } - if(taken){ - // Calculate path hash - uint64_t hash = pathHash(pc, target); - - history <<= shamt; - for (auto i = 0; i < pathHashLength && i < history.size(); i++) { - history[i] = (hash & 1) ^ history[i]; - hash >>= 1; - } + // Exit-Slot predictors benefit from path history evolving even when the block falls through: + // - If PHR stops updating on predicted fall-through, patterns that differ mainly by "no-exit" + // become hard to separate (self-bootstrapping issue). + // Strategy B: always shift, and always inject a hashed (pc,target) event. + // The caller should pass a pseudo edge for fall-through (e.g., startPC -> startPC+blockSize). + uint64_t hash = pathHash(pc, target); + history <<= shamt; + for (auto i = 0; i < pathHashLength && i < history.size(); i++) { + history[i] = (hash & 1) ^ history[i]; + hash >>= 1; } } @@ -941,7 +942,11 @@ DecoupledBPUWithBTB::updateHistoryForPrediction(FetchTarget &entry) histShiftIn(bw_shamt, bw_taken, s0BwHistory); // Update path history - pHistShiftIn(2, p_taken, s0PHistory, p_pc, p_target); + // For fall-through, use a pseudo edge to keep PHR moving (Strategy B). + const Addr phrStride = tage ? tage->blockSize : 32; + const Addr phr_pc = p_taken ? p_pc : entry.startPC; + const Addr phr_target = p_taken ? 
p_target : (entry.startPC + phrStride); + pHistShiftIn(2, p_taken, s0PHistory, phr_pc, phr_target); // Update local history histShiftIn(shamt, taken, @@ -1017,7 +1022,12 @@ DecoupledBPUWithBTB::recoverHistoryForSquash( histShiftIn(real_shamt, real_taken, s0History); // Update path history with actual outcome - pHistShiftIn(2, real_taken, s0PHistory, squash_pc.instAddr(), redirect_pc); + // Strategy B: when the resolved outcome is fall-through, keep PHR consistent with + // predictors' folded PHR update by using the same pseudo edge (startPC -> startPC+blockSize). + const Addr phrStride = tage ? tage->blockSize : 32; + const Addr phr_pc = real_taken ? squash_pc.instAddr() : target.startPC; + const Addr phr_target = real_taken ? redirect_pc : (target.startPC + phrStride); + pHistShiftIn(2, real_taken, s0PHistory, phr_pc, phr_target); // Update global backward history with actual outcome histShiftIn(real_bw_shamt, real_bw_taken, s0BwHistory); diff --git a/src/cpu/pred/btb/folded_hist.cc b/src/cpu/pred/btb/folded_hist.cc index 7129945856..d1f18e4359 100644 --- a/src/cpu/pred/btb/folded_hist.cc +++ b/src/cpu/pred/btb/folded_hist.cc @@ -177,55 +177,57 @@ ImliFoldedHist::update(const boost::dynamic_bitset<> &ghr, int shamt, bool taken void PathFoldedHist::update(const boost::dynamic_bitset<> &ghr, int shamt, bool taken, Addr pc, Addr target) { - if (taken) { - // Calculate path hash - uint64_t hash = pathHash(pc, target); - - const uint64_t foldedMask = ((1ULL << foldedLen) - 1); - uint64_t temp = _folded; - - assert(shamt <= foldedLen); - assert(shamt <= histLen); - - // Case 1: When folded length >= history length - if (foldedLen >= histLen) { - // Simple shift and set case - temp <<= shamt; - temp ^= hash; - // Clear any bits beyond histLen - temp &= ((1ULL << histLen) - 1); + // Strategy B: also evolve path folded history on fall-through by injecting a pseudo edge. + // The caller is expected to provide a meaningful (pc,target) even when taken==false. 
+ // (If pc/target are 0, the update degenerates to a pure shift.) + // + // Calculate path hash + uint64_t hash = pathHash(pc, target); + + const uint64_t foldedMask = ((1ULL << foldedLen) - 1); + uint64_t temp = _folded; + + assert(shamt <= foldedLen); + assert(shamt <= histLen); + + // Case 1: When folded length >= history length + if (foldedLen >= histLen) { + // Simple shift and set case + temp <<= shamt; + temp ^= hash; + // Clear any bits beyond histLen + temp &= ((1ULL << histLen) - 1); + } + // Case 2: When folded length < history length + else { + assert(shamt <= maxShamt); + // Step 1: Handle the bits that would be lost in shift + for (int i = 0; i < shamt; i++) { + // XOR the highest bits from GHR with corresponding positions in folded history + temp ^= (ghr[posHighestBitsInGhr[i]] << posHighestBitsInOldFoldedHist[i]); } - // Case 2: When folded length < history length - else { - assert(shamt <= maxShamt); - // Step 1: Handle the bits that would be lost in shift - for (int i = 0; i < shamt; i++) { - // XOR the highest bits from GHR with corresponding positions in folded history - temp ^= (ghr[posHighestBitsInGhr[i]] << posHighestBitsInOldFoldedHist[i]); - } - - // Step 2: Perform the shift - temp <<= shamt; - - // Step 3: Copy the XORed bits back to lower positions - for (int i = 0; i < shamt; i++) { - uint64_t highBit = (temp >> (foldedLen + i)) & 1; - temp |= (highBit << i); - } - - // Step 4: Add new branch outcome - uint64_t effectiveHash = hash; - if (histLen < pathHashLength) { - const uint64_t mask = (1ULL << histLen) - 1; - effectiveHash &= mask; - } - temp ^= foldHash(effectiveHash, foldedLen); - - // Mask to folded length - temp &= foldedMask; + + // Step 2: Perform the shift + temp <<= shamt; + + // Step 3: Copy the XORed bits back to lower positions + for (int i = 0; i < shamt; i++) { + uint64_t highBit = (temp >> (foldedLen + i)) & 1; + temp |= (highBit << i); + } + + // Step 4: Add new branch outcome + uint64_t effectiveHash = hash; + if 
(histLen < pathHashLength) { + const uint64_t mask = (1ULL << histLen) - 1; + effectiveHash &= mask; } - _folded = temp; + temp ^= foldHash(effectiveHash, foldedLen); + + // Mask to folded length + temp &= foldedMask; } + _folded = temp; } } // namespace btb_pred diff --git a/src/cpu/pred/btb/test/btb_mgsc.test.cc b/src/cpu/pred/btb/test/btb_mgsc.test.cc index 9cb241912b..b656425642 100644 --- a/src/cpu/pred/btb/test/btb_mgsc.test.cc +++ b/src/cpu/pred/btb/test/btb_mgsc.test.cc @@ -112,13 +112,13 @@ pHistShiftIn(int shamt, bool taken, boost::dynamic_bitset<> &history, Addr pc, A if (shamt == 0) { return; } - if (taken) { - uint64_t hash = pathHash(pc, target); - history <<= shamt; - for (std::size_t i = 0; i < pathHashLength && i < history.size(); i++) { - history[i] = (hash & 1) ^ history[i]; - hash >>= 1; - } + // Keep path history evolving even on fall-through (Strategy B). + // The caller should provide a pseudo edge for fall-through (e.g., startPC -> startPC+blockSize). + uint64_t hash = pathHash(pc, target); + history <<= shamt; + for (std::size_t i = 0; i < pathHashLength && i < history.size(); i++) { + history[i] = (hash & 1) ^ history[i]; + hash >>= 1; } } @@ -277,6 +277,11 @@ struct MgscHarness histShiftIn(bw_shamt, bw_taken, bwhr); auto [p_pc, p_target, p_taken] = stage_preds[1].getPHistInfo(); + if (!p_taken) { + // Match DecoupledBPUWithBTB Strategy B pseudo edge. + p_pc = start_pc; + p_target = start_pc + 32; + } pHistShiftIn(2, p_taken, phr, p_pc, p_target); unsigned lhr_idx = @@ -313,7 +318,14 @@ struct MgscHarness // Apply correct external history update. histShiftIn(shamt, actual_taken, ghr); histShiftIn(bw_shamt, actual_bw_taken, bwhr); - pHistShiftIn(2, actual_taken, phr, entry.pc, entry.target); + Addr phr_pc = entry.pc; + Addr phr_target = entry.target; + if (!actual_taken) { + // Match DecoupledBPUWithBTB Strategy B pseudo edge. 
+ phr_pc = start_pc; + phr_target = start_pc + 32; + } + pHistShiftIn(2, actual_taken, phr, phr_pc, phr_target); histShiftIn(shamt, actual_taken, lhr[lhr_idx]); } diff --git a/src/cpu/pred/btb/test/btb_tage.test.cc b/src/cpu/pred/btb/test/btb_tage.test.cc index a6289ec4e9..c0be88cb3a 100644 --- a/src/cpu/pred/btb/test/btb_tage.test.cc +++ b/src/cpu/pred/btb/test/btb_tage.test.cc @@ -1,7 +1,9 @@ #include #include -#include +#include +#include +#include #include "base/types.hh" #include "cpu/pred/btb/btb_tage.hh" @@ -20,21 +22,10 @@ namespace btb_pred namespace test { -// Helper functions for TAGE testing - -/** - * @brief Create a BTB entry with specified parameters - * - * @param pc Branch instruction address - * @param isCond Whether the branch is conditional - * @param valid Whether the entry is valid - * @param alwaysTaken Whether the branch is always taken - * @param ctr Prediction counter value - * @param target Branch target address (defaults to sequential PC) - * @return BTBEntry Initialized branch entry - */ -BTBEntry createBTBEntry(Addr pc, bool isCond = true, bool valid = true, - bool alwaysTaken = false, int ctr = 0, Addr target = 0) { +static BTBEntry +createBTBEntry(Addr pc, bool isCond = true, bool valid = true, + bool alwaysTaken = false, int ctr = 0, Addr target = 0) +{ BTBEntry entry; entry.pc = pc; entry.target = target ? 
target : (pc + 4); @@ -42,44 +33,41 @@ BTBEntry createBTBEntry(Addr pc, bool isCond = true, bool valid = true, entry.valid = valid; entry.alwaysTaken = alwaysTaken; entry.ctr = ctr; - // Other fields are set to default return entry; } -/** - * @brief Create a stream for update or recovery - * - * @param startPC Starting PC for the stream - * @param entry Branch entry information - * @param taken Actual outcome (taken/not taken) - * @param meta Prediction metadata from prediction phase - * @param squashType Type of squash (control or non-control) - * @return FetchTarget Initialized stream for update or recovery - */ -FetchTarget createStream(Addr startPC, const BTBEntry& entry, bool taken, - std::shared_ptr meta) { +static FetchTarget +createStream(Addr startPC, + const std::vector &predEntries, + const BTBEntry *actual_taken_entry, + std::shared_ptr meta) +{ FetchTarget stream; stream.startPC = startPC; - stream.exeBranchInfo = entry; - stream.exeTaken = taken; - // Mark as resolved so recover paths use exe* info + stream.predBTBEntries = predEntries; + stream.updateBTBEntries = predEntries; stream.resolved = true; - stream.predBranchInfo = entry; // keep fields consistent - stream.updateBTBEntries = {entry}; stream.updateIsOldEntry = true; stream.predMetas[0] = meta; - return stream; -} -FetchTarget setMispredStream(FetchTarget stream) { - stream.squashType = SquashType::SQUASH_CTRL; - stream.squashPC = stream.exeBranchInfo.pc; + if (actual_taken_entry) { + stream.exeBranchInfo = *actual_taken_entry; + stream.exeTaken = true; + stream.squashType = SquashType::SQUASH_CTRL; + stream.squashPC = actual_taken_entry->pc; + } else { + stream.exeTaken = false; + stream.exeBranchInfo = BranchInfo(); + stream.squashType = SquashType::SQUASH_NONE; + stream.squashPC = 0; + } return stream; } -void applyPathHistoryTaken(boost::dynamic_bitset<>& history, Addr pc, Addr target, - int shamt = 2) { - boost::dynamic_bitset<> before = history; +static void 
+applyPathHistoryTaken(boost::dynamic_bitset<> &history, Addr pc, Addr target, + int shamt = 2) +{ history <<= shamt; uint64_t hash = pathHash(pc, target); for (std::size_t i = 0; i < pathHashLength && i < history.size(); ++i) { @@ -89,180 +77,59 @@ void applyPathHistoryTaken(boost::dynamic_bitset<>& history, Addr pc, Addr targe } } -/** - * @brief Helper function to find conditional taken prediction for a given PC - * - * @param condTakens Vector of conditional predictions - * @param pc Branch PC to search for - * @return Pair of (found, prediction) where found indicates if PC was found - */ -std::pair findCondTaken(const gem5::branch_prediction::btb_pred::CondTakens& condTakens, Addr pc) { - auto it = CondTakens_find(condTakens, pc); - if (it != condTakens.end()) { - return {true, it->second}; - } - return {false, false}; -} - -/** - * @brief Execute a complete TAGE prediction cycle - * - * @param tage The TAGE predictor - * @param startPC Starting PC for prediction - * @param entries Vector of BTB entries - * @param history Branch history register - * @param stagePreds Prediction results container - * @return bool Prediction result (taken/not taken) for the first entry - */ -bool predictTAGE(BTBTAGE* tage, Addr startPC, - const std::vector& entries, - boost::dynamic_bitset<>& history, - std::vector& stagePreds) { - // Setup stage predictions with BTB entries +static Addr +predictExitPC(BTBTAGE *tage, Addr startPC, + const std::vector &entries, + const boost::dynamic_bitset<> &history, + std::vector &stagePreds) +{ stagePreds[1].btbEntries = entries; - - // Make prediction - tage->putPCHistory(startPC, history, stagePreds); - - // Return prediction for first entry if exists - if (!entries.empty()) { - auto result = findCondTaken(stagePreds[1].condTakens, entries[0].pc); - bool found = result.first; - bool taken = result.second; - if (found) { - return taken; - } - } - return false; -} - -/** - * @brief Execute a complete prediction-update cycle - * - * @param 
tage The TAGE predictor - * @param startPC Starting PC for prediction - * @param entry BTB entry to predict - * @param actual_taken Actual outcome (taken/not taken) - * @param history Branch history register - * @param stagePreds Prediction results container - */ -bool predictUpdateCycle(BTBTAGE* tage, Addr startPC, - const BTBEntry& entry, - bool actual_taken, - boost::dynamic_bitset<>& history, - std::vector& stagePreds) { - // 1. Make prediction - stagePreds[1].btbEntries = {entry}; tage->putPCHistory(startPC, history, stagePreds); - // 2. Get predicted result - Addr branch_pc = entry.pc; - auto it = CondTakens_find(stagePreds[1].condTakens, branch_pc); - // ASSERT_TRUE(it != stagePreds[1].condTakens.end()) << "Prediction not found for PC " << std::hex << entry.pc; - bool predicted_taken = it->second; - - // 3. Speculatively update folded history - tage->specUpdateHist(history, stagePreds[1]); - auto meta = tage->getPredictionMeta(); - - // 4. Update path history register, see pHistShiftIn - bool history_updated = false; - auto [pred_pc, pred_target, pred_taken] = stagePreds[1].getPHistInfo(); - boost::dynamic_bitset<> pre_spec_history = history; - if (pred_taken) { - history_updated = true; - applyPathHistoryTaken(history, pred_pc, pred_target); - } - tage->checkFoldedHist(history, "speculative update"); - - // 5. Create update stream - FetchTarget stream = createStream(startPC, entry, actual_taken, meta); - - // 6. 
Handle possible misprediction - if (predicted_taken != actual_taken) { - stream = setMispredStream(stream); - // Update history with correct outcome - if (history_updated) { - history = pre_spec_history; + Addr pred_pc = 0; + for (auto &e : entries) { + if (!(e.valid && e.isCond)) { + continue; } - // Recover from misprediction - tage->recoverHist(history, stream, 1, actual_taken); - - if (actual_taken) { - applyPathHistoryTaken(history, stream.exeBranchInfo.pc, - stream.exeBranchInfo.target); + Addr branch_pc = e.pc; + auto it = CondTakens_find(stagePreds[1].condTakens, branch_pc); + if (it != stagePreds[1].condTakens.end() && it->second) { + pred_pc = e.pc; + break; } - tage->checkFoldedHist(history, "recover"); } - - // 7. Update predictor - tage->update(stream); - return predicted_taken; + return pred_pc; } -/** - * @brief Directly setup TAGE table entries for testing - * - * @param tage The TAGE predictor - * @param pc Branch PC - * @param table_idx Index of the table to set - * @param counter Counter value - * @param useful Useful bit value - */ -void setupTageEntry(BTBTAGE* tage, Addr pc, int table_idx, - short counter, bool useful = false, int way = 0) { - Addr index = tage->getTageIndex(pc, table_idx); - Addr tag = tage->getTageTag(pc, table_idx); - - auto& entry = tage->tageTable[table_idx][index][way]; +static void +setupTageEntry(BTBTAGE *tage, Addr startPC, int table_idx, + uint8_t conf, uint8_t exit0, uint8_t exit1 = 0, uint8_t sel = 0, + bool useful = false, int way = 0) +{ + Addr index = tage->getTageIndex(startPC, table_idx); + Addr tag = tage->getTageTag(startPC, table_idx); + auto &entry = tage->tageTable[table_idx][index][way]; entry.valid = true; entry.tag = tag; - entry.counter = counter; + entry.conf = conf; entry.useful = useful; - entry.pc = pc; + entry.exitSlotEnc0 = exit0; + entry.exitSlotEnc1 = exit1; + entry.selCtr = sel; } -/** - * @brief Verify TAGE table entries - * - * @param tage The TAGE predictor - * @param pc Branch instruction 
address to check - * @param expected_tables Vector of expected table indices to have valid entries - */ -void verifyTageEntries(BTBTAGE* tage, Addr pc, const std::vector& expected_tables) { - for (int t = 0; t < tage->numPredictors; t++) { - for (int way = 0; way < tage->numWays; way++) { - Addr index = tage->getTageIndex(pc, t); - auto &entry = tage->tageTable[t][index][way]; - - // Check if this table should have a valid entry - bool should_be_valid = std::find(expected_tables.begin(), - expected_tables.end(), t) != expected_tables.end(); - - if (should_be_valid) { - EXPECT_TRUE(entry.valid && entry.pc == pc) - << "Table " << t << " should have valid entry for PC " << std::hex << pc; - } - } - } -} - -/** - * @brief Find the table with a valid entry for a given fetch block and branch - * - * @param tage The TAGE predictor - * @param startPC Fetch-block start address used during prediction - * @param branchPC Branch instruction address being searched - * @return int Index of the table with valid entry (-1 if not found) - */ -int findTableWithEntry(BTBTAGE* tage, Addr startPC, Addr branchPC) { - auto meta = std::static_pointer_cast(tage->getPredictionMeta()); - // use meta to find the table, predicted info - for (int t = 0; t < tage->numPredictors; t++) { +static int +findTableWithEntryWithMeta(BTBTAGE *tage, Addr startPC, + const std::shared_ptr &meta) +{ + for (int t = 0; t < (int)tage->numPredictors; ++t) { Addr index = tage->getTageIndex(startPC, t, meta->indexFoldedHist[t].get()); - for (int way = 0; way < tage->numWays; way++) { + Addr tag = tage->getTageTag(startPC, t, + meta->tagFoldedHist[t].get(), + meta->altTagFoldedHist[t].get()); + for (int way = 0; way < (int)tage->numWays; ++way) { auto &entry = tage->tageTable[t][index][way]; - if (entry.valid && entry.pc == branchPC) { + if (entry.valid && entry.tag == tag) { return t; } } @@ -270,640 +137,207 @@ int findTableWithEntry(BTBTAGE* tage, Addr startPC, Addr branchPC) { return -1; } +static 
std::shared_ptr +predictUpdateCycleBlock(BTBTAGE *tage, Addr startPC, + const std::vector &entries, + const BTBEntry *actual_taken_entry, + boost::dynamic_bitset<> &history, + std::vector &stagePreds) +{ + stagePreds[1].btbEntries = entries; + tage->putPCHistory(startPC, history, stagePreds); + tage->specUpdateHist(history, stagePreds[1]); + + auto meta = std::static_pointer_cast(tage->getPredictionMeta()); + + // Mirror pHistShiftIn behavior to keep history consistent in the test. + auto [pred_pc, pred_target, pred_taken] = stagePreds[1].getPHistInfo(); + Addr phr_pc = pred_taken ? pred_pc : startPC; + Addr phr_target = pred_taken ? pred_target : (startPC + tage->blockSize); + applyPathHistoryTaken(history, phr_pc, phr_target); + + FetchTarget stream = createStream(startPC, entries, actual_taken_entry, + std::static_pointer_cast(meta)); + tage->update(stream); + return meta; +} + class BTBTAGETest : public ::testing::Test { -protected: - void SetUp() override { + protected: + void SetUp() override + { tage = new BTBTAGE(); - // memset tageStats to 0 - memset(&tage->tageStats, 0, sizeof(BTBTAGE::TageStats)); - history.resize(64, false); // 64-bit history initialized to 0 - stagePreds.resize(2); // 2 stages + std::memset(&tage->tageStats, 0, sizeof(BTBTAGE::TageStats)); + history.resize(64, false); + stagePreds.resize(2); } - BTBTAGE* tage; + BTBTAGE *tage; boost::dynamic_bitset<> history; std::vector stagePreds; }; -// Test basic prediction functionality -TEST_F(BTBTAGETest, BasicPrediction) { - // Create a conditional branch entry biased towards taken - BTBEntry entry = createBTBEntry(0x1000, true, true, false, 1); - - // Predict and verify - bool taken = predictTAGE(tage, 0x1000, {entry}, history, stagePreds); +TEST_F(BTBTAGETest, BasicPrediction) +{ + Addr startPC = 0x1000; + BTBEntry b0 = createBTBEntry(0x1000, true, true, false, -1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, -1); + std::vector entries = {b0, b1}; - // Should predict taken due 
to initial counter bias - EXPECT_TRUE(taken) << "Initial prediction should be taken"; + setupTageEntry(tage, startPC, /*table*/ 3, /*conf*/ 2, /*exit0*/ 2); - // Update predictor with actual outcome Not taken - predictUpdateCycle(tage, 0x1000, entry, false, history, stagePreds); + Addr pred_pc = predictExitPC(tage, startPC, entries, history, stagePreds); + EXPECT_EQ(pred_pc, 0x1002); - // Verify at least one table has an entry allocated - int table = findTableWithEntry(tage, 0x1000, 0x1000); - EXPECT_GE(table, 0) << "No TAGE table entry was allocated"; + auto meta = std::static_pointer_cast(tage->getPredictionMeta()); + EXPECT_TRUE(meta->hasPred); + EXPECT_EQ(meta->pred.predEnc, 2); + EXPECT_EQ(meta->pred.predCondPC, 0x1002); + EXPECT_EQ(meta->pred.source, BTBTAGE::PredSource::Provider); } -// Test basic history update functionality (PHR semantics) -TEST_F(BTBTAGETest, HistoryUpdate) { - // Use a fixed control PC to derive PHR bits +TEST_F(BTBTAGETest, HistoryUpdate) +{ Addr pc = 0x1000; Addr target = pc + 0x40; - // Test case 1: Update with taken branch (PHR shifts in 2 bits from PC hash) - // Correct order: first update folded histories with pre-update PHR, then mutate PHR tage->doUpdateHist(history, true, pc, target); applyPathHistoryTaken(history, pc, target); - - // Verify folded history matches the ideal fold of the updated PHR tage->checkFoldedHist(history, "taken update"); - // Test case 2: Update with not-taken branch (PHR unchanged, folded update is no-op) - boost::dynamic_bitset<> before_not_taken = history; tage->doUpdateHist(history, false, pc, target); - - // Verify folded history remains consistent + applyPathHistoryTaken(history, pc, target); tage->checkFoldedHist(history, "not-taken update"); - EXPECT_EQ(history, before_not_taken); } -// Test main and alternative prediction mechanism by direct setup -TEST_F(BTBTAGETest, MainAltPredictionBehavior) { - // Create a branch entry for testing - BTBEntry entry = createBTBEntry(0x1000); - - // Setup a 
strong main prediction (taken) in table 3 - setupTageEntry(tage, 0x1000, 3, 2); // Strong taken +TEST_F(BTBTAGETest, MainAltPredictionBehavior) +{ + Addr startPC = 0x1000; + // Make base prefer slot0. + BTBEntry b0 = createBTBEntry(0x1000, true, true, false, /*ctr*/ 1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, /*ctr*/ -1); + std::vector entries = {b0, b1}; - // Setup a weak alternative prediction (not taken) in table 1 - setupTageEntry(tage, 0x1000, 1, -1); // Weak not taken + // Provider predicts slot1. + setupTageEntry(tage, startPC, 3, /*conf*/ 2, /*exit0*/ 2); - // Predict with these entries - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); + Addr pred_pc = predictExitPC(tage, startPC, entries, history, stagePreds); + EXPECT_EQ(pred_pc, 0x1002); - // Check prediction metadata auto meta = std::static_pointer_cast(tage->getPredictionMeta()); - auto pred = meta->preds[0x1000]; + EXPECT_EQ(meta->pred.source, BTBTAGE::PredSource::Provider); + EXPECT_FALSE(meta->pred.useAlt); - // Should use main prediction (strong counter) - EXPECT_FALSE(pred.useAlt) << "Should use main prediction with strong counter"; - EXPECT_TRUE(pred.taken) << "Main prediction should be taken"; - EXPECT_EQ(pred.mainInfo.table, 3) << "Main prediction should come from table 3"; - EXPECT_EQ(pred.altInfo.table, 1) << "Alt prediction should come from table 1"; - - // Now set main prediction to weak - setupTageEntry(tage, 0x1000, 3, 0); // Weak taken - - // Predict again - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - - // Check prediction metadata again + // Make provider weak => default useAltOnNa is >= 0, so choose Base (conservative). 
+ setupTageEntry(tage, startPC, 3, /*conf*/ 0, /*exit0*/ 2); + pred_pc = predictExitPC(tage, startPC, entries, history, stagePreds); + EXPECT_EQ(pred_pc, 0x1000); meta = std::static_pointer_cast(tage->getPredictionMeta()); - pred = meta->preds[0x1000]; - - // Should use alt prediction (main is weak) - EXPECT_TRUE(pred.useAlt) << "Should use alt prediction with weak main counter"; - EXPECT_FALSE(pred.taken) << "Alt prediction should be not taken"; -} - -// Test useful bit update mechanism -TEST_F(BTBTAGETest, UsefulBitMechanism) { - // Setup a test branch - BTBEntry entry = createBTBEntry(0x1000); - - // Setup entries in main and alternative tables - setupTageEntry(tage, 0x1000, 3, 2, false); // Main: strong taken, useful=false - setupTageEntry(tage, 0x1000, 1, -2, false); // Alt: strong not taken, useful=false - - // Verify initial useful bit state - Addr mainIndex = tage->getTageIndex(0x1000, 3); - EXPECT_FALSE(tage->tageTable[3][mainIndex][0].useful) << "Useful bit should start as false"; - - // Predict - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - auto meta = tage->getPredictionMeta(); - - // Update with actual outcome matching main prediction (taken) - FetchTarget stream = createStream(0x1000, entry, true, meta); - tage->update(stream); - - // Verify useful bit is set (main prediction was correct and differed from alt) - EXPECT_TRUE(tage->tageTable[3][mainIndex][0].useful) - << "Useful bit should be set when main predicts correctly and differs from alt"; - - // Predict again - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - meta = tage->getPredictionMeta(); - - // Update with actual outcome opposite to main prediction (not taken) - stream = createStream(0x1000, entry, false, meta); - tage->update(stream); - - // Verify useful bit is NOT cleared (policy is ++ only, no --) - EXPECT_TRUE(tage->tageTable[3][mainIndex][0].useful) - << "Useful bit should remain set when main predicts incorrectly (no decrement)"; -} - -// Test entry 
allocation mechanism -TEST_F(BTBTAGETest, EntryAllocationAndReplacement) { - // Instead of creating two different PCs, we'll create two entries with the same PC - // This ensures they map to the same indices in the tables - BTBEntry entry1 = createBTBEntry(0x1000); - BTBEntry entry2 = createBTBEntry(0x1000); // Same PC to ensure same indices - - // Set all tables to have entries with useful=true - for (int t = 0; t < tage->numPredictors; t++) { - setupTageEntry(tage, 0x1000, t, 0, true); // Counter=0, useful=true - } - - // Force a misprediction to trigger allocation attempt - // First, make a prediction - predictTAGE(tage, 0x1000, {entry1}, history, stagePreds); - auto meta = tage->getPredictionMeta(); - bool predicted = false; - auto result_pred = findCondTaken(stagePreds[1].condTakens, 0x1000); - bool found_pred = result_pred.first; - bool pred_result = result_pred.second; - if (found_pred) { - predicted = pred_result; - } - - // Create a stream for entry2 with opposite outcome to force allocation - // Although it has the same PC, we'll treat it as a different branch context - // by setting a specific tag that doesn't match existing entries - FetchTarget stream = createStream(0x1000, entry2, !predicted, meta); - stream.squashType = SquashType::SQUASH_CTRL; // Mark as control misprediction - stream.squashPC = 0x1000; - - // Update the predictor (this should try to allocate but fail) - tage->update(stream); - - int alloc_failed_no_valid = tage->tageStats.updateAllocFailureNoValidTable; - EXPECT_GE(alloc_failed_no_valid, 1) << "Allocate failed due to no valid table to allocate (all useful)"; - -} - -// Test history recovery mechanism -TEST_F(BTBTAGETest, HistoryRecoveryCorrectness) { - BTBEntry entry = createBTBEntry(0x1000); - - // Record initial history state - boost::dynamic_bitset<> originalHistory = history; - - // Store original folded history state - std::vector originalTagFoldedHist; - std::vector originalAltTagFoldedHist; - std::vector 
originalIndexFoldedHist; - - for (int i = 0; i < tage->numPredictors; i++) { - originalTagFoldedHist.push_back(tage->tagFoldedHist[i]); - originalAltTagFoldedHist.push_back(tage->altTagFoldedHist[i]); - originalIndexFoldedHist.push_back(tage->indexFoldedHist[i]); - } - - // Make a prediction - bool predicted_taken = predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - - // Speculatively update history - tage->specUpdateHist(history, stagePreds[1]); - auto meta = tage->getPredictionMeta(); - - // Update PHR register (speculative) to mirror pHistShiftIn - if (predicted_taken) { - applyPathHistoryTaken(history, entry.pc, entry.target); - } - - // Create a recovery stream with opposite outcome - FetchTarget stream = createStream(0x1000, entry, !predicted_taken, meta); - stream = setMispredStream(stream); - - // Recover to pre-speculative state and update with correct outcome - boost::dynamic_bitset<> recoveryHistory = originalHistory; - tage->recoverHist(recoveryHistory, stream, 1, !predicted_taken); - - // Expected history should be original updated with PHR if actually taken - boost::dynamic_bitset<> expectedHistory = originalHistory; - if (!predicted_taken) { // actual_taken - applyPathHistoryTaken(expectedHistory, entry.pc, entry.target); - } - - // Verify recovery produced the expected history - for (int i = 0; i < tage->numPredictors; i++) { - tage->tagFoldedHist[i].check(expectedHistory); - tage->altTagFoldedHist[i].check(expectedHistory); - tage->indexFoldedHist[i].check(expectedHistory); - } -} - -// Simplified test for multiple branch sequence -TEST_F(BTBTAGETest, MultipleBranchSequence) { - // Create two branches - std::vector btbEntries = { - createBTBEntry(0x1000), - createBTBEntry(0x1004) - }; - - // Predict for both branches - predictTAGE(tage, 0x1000, btbEntries, history, stagePreds); - auto meta = tage->getPredictionMeta(); - - // Get predictions for both branches - bool first_pred = false, second_pred = false; - auto result1 = 
findCondTaken(stagePreds[1].condTakens, 0x1000); - if (result1.first) { - first_pred = result1.second; - } - auto result2 = findCondTaken(stagePreds[1].condTakens, 0x1004); - if (result2.first) { - second_pred = result2.second; - } - - // Update first branch (correct prediction), no allocation - FetchTarget stream1 = createStream(0x1000, btbEntries[0], first_pred, meta); - tage->update(stream1); - - // Update second branch (incorrect prediction), allocate 1 entry - FetchTarget stream2 = createStream(0x1000, btbEntries[1], !second_pred, meta); - stream2.squashType = SquashType::SQUASH_CTRL; - stream2.squashPC = 0x1004; - tage->update(stream2); - - // Verify both branches have entries allocated - EXPECT_EQ(findTableWithEntry(tage, 0x1000, 0x1000), -1) << "First branch should not have an entry"; - EXPECT_GE(findTableWithEntry(tage, 0x1000, 0x1004), 0) << "Second branch should have an entry"; + EXPECT_TRUE(meta->pred.useAlt); + EXPECT_EQ(meta->pred.source, BTBTAGE::PredSource::Base); + + // Disable useAltOnNa => weak provider should be used. 
+ Addr uidx = tage->getUseAltIdx(startPC); + tage->useAlt[uidx] = -1; + pred_pc = predictExitPC(tage, startPC, entries, history, stagePreds); + EXPECT_EQ(pred_pc, 0x1002); + meta = std::static_pointer_cast(tage->getPredictionMeta()); + EXPECT_EQ(meta->pred.source, BTBTAGE::PredSource::Provider); } -// Test counter update mechanism -TEST_F(BTBTAGETest, CounterUpdateMechanism) { - BTBEntry entry = createBTBEntry(0x1000); - - // Setup a TAGE entry with a neutral counter - int testTable = 3; - setupTageEntry(tage, 0x1000, testTable, 0); - - // Verify initial counter value - Addr index = tage->getTageIndex(0x1000, testTable); - EXPECT_EQ(tage->tageTable[testTable][index][0].counter, 0) << "Initial counter should be 0"; - - // Train with taken outcomes multiple times - for (int i = 0; i < 3; i++) { - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - auto meta = tage->getPredictionMeta(); - - FetchTarget stream = createStream(0x1000, entry, true, meta); - tage->update(stream); - } - - // Verify counter saturates at maximum - EXPECT_EQ(tage->tageTable[testTable][index][0].counter, 3) - << "Counter should saturate at maximum value"; +TEST_F(BTBTAGETest, UsefulBitMechanism) +{ + Addr startPC = 0x1000; + // Base prefers slot0, but actual is slot1. 
+ BTBEntry b0 = createBTBEntry(0x1000, true, true, false, /*ctr*/ 1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, /*ctr*/ -1); + std::vector entries = {b0, b1}; - // Train with not-taken outcomes multiple times - for (int i = 0; i < 7; i++) { - predictTAGE(tage, 0x1000, {entry}, history, stagePreds); - auto meta = tage->getPredictionMeta(); + setupTageEntry(tage, startPC, 3, /*conf*/ 2, /*exit0*/ 2, /*exit1*/ 0, /*sel*/ 0, /*useful*/ false); - FetchTarget stream = createStream(0x1000, entry, false, meta); - tage->update(stream); - } + Addr mainIndex = tage->getTageIndex(startPC, 3); + EXPECT_FALSE(tage->tageTable[3][mainIndex][0].useful); - // Verify counter saturates at minimum - EXPECT_EQ(tage->tageTable[testTable][index][0].counter, -4) - << "Counter should saturate at minimum value"; + predictUpdateCycleBlock(tage, startPC, entries, &b1, history, stagePreds); + EXPECT_TRUE(tage->tageTable[3][mainIndex][0].useful); } -/** - * @brief Test predictor consistency after multiple predictions - * - * This test verifies that: - * 1. The predictor learns a repeating pattern - * 2. The prediction accuracy improves over time - * 3. 
Predictor state is consistent after multiple predictions - */ -TEST_F(BTBTAGETest, UpdateConsistencyAfterMultiplePredictions) { - // Create a branch entry - BTBEntry entry = createBTBEntry(0x1000); - // outer loop always taken - BTBEntry entry2 = createBTBEntry(0x1010); // always taken - - // Step 1: Train predictor on a fixed pattern (alternating T/N) - const int TOTAL_ITERATIONS = 100; - const int WARMUP_ITERATIONS = 80; - - int correctly_predicted = 0; - - for (int i = 0; i < TOTAL_ITERATIONS; i++) { - bool actual_taken = (i % 2 == 0); // T,N,T,N pattern - bool predicted_taken = predictUpdateCycle(tage, 0x1000, entry, actual_taken, history, stagePreds); - predictUpdateCycle(tage, 0x1010, entry2, true, history, stagePreds); - - // Count correct predictions after warmup - if (i >= WARMUP_ITERATIONS) { - correctly_predicted += (predicted_taken == actual_taken) ? 1 : 0; - } - } - - // Calculate accuracy in final phase - double accuracy = static_cast(correctly_predicted) / - (TOTAL_ITERATIONS - WARMUP_ITERATIONS); - - // Verify predictor has learned the pattern with high accuracy - EXPECT_GT(accuracy, 0.9) - << "Predictor should learn alternating pattern with >90% accuracy"; - // print updateMispred: mispredictions times - std::cout << "updateMispred: " << tage->tageStats.updateMispred << std::endl; -} - -/** - * @brief Test combined prediction accuracy across different tables - * - * This test evaluates how different tables in the TAGE predictor - * contribute to prediction accuracy for various branch patterns. - */ -TEST_F(BTBTAGETest, CombinedPredictionAccuracyTesting) { - // Setup branch entry - BTBEntry entry = createBTBEntry(0x1000); - // outer loop always taken - BTBEntry entry2 = createBTBEntry(0x1010); // always taken - - // Define different branch patterns - struct PatternTest - { - std::string name; - std::function pattern; - }; - - std::vector patterns = { - {"Alternating", [](int i) { return i % 2 == 0; }}, // T,N,T,N... 
- {"ThreeCycle", [](int i) { return i % 3 == 0; }}, // T,N,N,T,N,N... - {"LongCycle", [](int i) { return (i / 10) % 2 == 0; }}, // 10 Ts, 10 Ns... - {"BiasedRandom", [](int i) { - // Use deterministic but complex pattern that appears somewhat random - return ((i * 7 + 3) % 11) > 5; - }} - }; - - const int TRAIN_ITERATIONS = 200; // it need more iterations to train! - const int WARMUP_ITERATIONS = 180; - - - // Test each pattern - for (const auto& pattern_test : patterns) { - // Reset predictor and history - tage = new BTBTAGE(); - // clear history - history.reset(); - stagePreds.resize(2); - - int correctly_predicted = 0; - // Training phase - for (int i = 0; i < TRAIN_ITERATIONS; i++) { - bool actual_taken = pattern_test.pattern(i); - bool predicted_taken = predictUpdateCycle(tage, 0x1000, entry, actual_taken, history, stagePreds); - predictUpdateCycle(tage, 0x1010, entry2, true, history, stagePreds); - - // Count correct predictions after warmup - if (i >= WARMUP_ITERATIONS) { - correctly_predicted += (predicted_taken == actual_taken) ? 1 : 0; - } - } - - // Calculate accuracy in final phase - double accuracy = static_cast(correctly_predicted) / - (TRAIN_ITERATIONS - WARMUP_ITERATIONS); - - - // Verify predictor has learned the pattern with high accuracy - EXPECT_GE(accuracy, 0.8) - << "Predictor should learn alternating pattern with >80% accuracy"; +TEST_F(BTBTAGETest, EntryAllocationOnMissWhenBaseWrong) +{ + Addr startPC = 0x1000; + // Base predicts slot0 taken, but actual is slot1 => miss/wrong should allocate. 
+ BTBEntry b0 = createBTBEntry(0x1000, true, true, false, /*ctr*/ 1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, /*ctr*/ -1); + std::vector entries = {b0, b1}; - // print updateMispred: mispredictions times - std::cout << "updateMispred: " << tage->tageStats.updateMispred << std::endl; - } -} + auto meta = predictUpdateCycleBlock(tage, startPC, entries, &b1, history, stagePreds); -/** - * @brief Create a TAGE table entry manually with specific properties - * - * This is particularly useful for set-associative testing when we need - * to control exact placement of entries - */ -void createManualTageEntry(BTBTAGE* tage, int table, Addr index, int way, - Addr tag, short counter, bool useful, Addr pc, - unsigned lruCounter = 0) { - auto &entry = tage->tageTable[table][index][way]; - entry.valid = true; - entry.tag = tag; - entry.counter = counter; - entry.useful = useful; - entry.pc = pc; - entry.lruCounter = lruCounter; + int table = findTableWithEntryWithMeta(tage, startPC, meta); + EXPECT_GE(table, 0); + EXPECT_EQ(tage->tageStats.updateAllocOnMiss, 1); + EXPECT_EQ(tage->tageStats.updateAllocSuccess, 1); } - -/** - * @brief Test set-associative conflict handling - * - * This test verifies that: - * 1. Multiple branches mapping to the same index can be predicted correctly - * 2. 
The LRU counters are updated properly when entries are accessed - */ -TEST_F(BTBTAGETest, SetAssociativeConflictHandling) { - // Create two branch entries with different PCs +TEST_F(BTBTAGETest, SelectorTrainingOnOtherCandidateHit) +{ Addr startPC = 0x1000; - BTBEntry entry1 = createBTBEntry(startPC); - BTBEntry entry2 = createBTBEntry(startPC + 4); - - // Use a specific table and index for testing - int testTable = 1; - Addr testIndex = tage->getTageIndex(startPC, testTable); - - // Calculate correct tags for each entry (tag includes position XOR) - // entry1: PC=0x1000, position=0 - Addr testTag1 = tage->getTageTag(startPC, testTable, 0); - // entry2: PC=0x1004, position=2 (calculated as (0x1004-0x1000)>>1) - Addr testTag2 = tage->getTageTag(startPC, testTable, 2); - - // Manually create entries with the same index but different tags (due to position) - createManualTageEntry(tage, testTable, testIndex, 0, testTag1, 2, false, 0x1000, 0); // Way 0: Strong taken - createManualTageEntry(tage, testTable, testIndex, 1, testTag2, -2, false, 0x1004, 1); // Way 1: Strong not taken - - // Make predictions and verify directly - // For entry1 (should predict taken) - stagePreds.clear(); - stagePreds.resize(2); - stagePreds[1].btbEntries = {entry1}; - tage->putPCHistory(startPC, history, stagePreds); + BTBEntry b0 = createBTBEntry(0x1000, true, true, false, /*ctr*/ -1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, /*ctr*/ -1); + std::vector entries = {b0, b1}; - // Get prediction for entry1 - bool pred1 = false; - auto result_entry1 = findCondTaken(stagePreds[1].condTakens, entry1.pc); - if (result_entry1.first) { - pred1 = result_entry1.second; - } - EXPECT_TRUE(pred1) << "Entry1 should predict taken"; + Addr uidx = tage->getUseAltIdx(startPC); + tage->useAlt[uidx] = -1; - // Check LRU counters after first access - EXPECT_EQ(tage->tageTable[testTable][testIndex][0].lruCounter, 0) - << "LRU counter for way 0 should be reset after access"; + // Dual-candidate 
entry: enc0 predicts slot0, enc1 predicts slot1, selector initially picks enc0. + setupTageEntry(tage, startPC, /*table*/ 3, /*conf*/ 0, /*exit0*/ 1, /*exit1*/ 2, /*sel*/ 0, /*useful*/ true); + Addr mainIndex = tage->getTageIndex(startPC, 3); - // For entry2 (should predict not taken) - stagePreds.clear(); - stagePreds.resize(2); - stagePreds[1].btbEntries = {entry2}; - tage->putPCHistory(startPC, history, stagePreds); + predictUpdateCycleBlock(tage, startPC, entries, &b1, history, stagePreds); - // Get prediction for entry2 - bool pred2 = false; - auto result_entry2 = findCondTaken(stagePreds[1].condTakens, entry2.pc); - if (result_entry2.first) { - pred2 = result_entry2.second; - } - EXPECT_FALSE(pred2) << "Entry2 should predict not taken"; + // Should not rewrite payload; should only steer selector toward the correct candidate. + EXPECT_EQ(tage->tageTable[3][mainIndex][0].exitSlotEnc0, 1); + EXPECT_EQ(tage->tageTable[3][mainIndex][0].exitSlotEnc1, 2); + EXPECT_EQ(tage->tageTable[3][mainIndex][0].selCtr, 1); } -/** - * @brief Test allocation behavior with multiple ways (new policy) - * - * New allocation policy highlights: - * - Allocation consults the selected way's usefulMask for each table. - * - Only invalid entries, or (useful==0 and weak counter) can be allocated. - * - No LRU-based replacement is performed when all considered entries are useful. - * - * This test verifies: - * 1. First mispredict allocates into an invalid way. - * 2. Subsequent allocations fail when the selected way's usefulMask marks the table useful. - * 3. No replacement occurs even after additional allocation attempts. 
- */ -TEST_F(BTBTAGETest, AllocationBehaviorWithMultipleWays) { - // Start with a fresh predictor - tage = new BTBTAGE(1, 2, 10); // only 1 predictor table, 2 ways - memset(&tage->tageStats, 0, sizeof(BTBTAGE::TageStats)); - history.resize(64, false); - stagePreds.resize(2); - - // Create a branch entry, base ctr=0, base taken - BTBEntry entry = createBTBEntry(0x1000); - - // Set up a test table and index - int testTable = 0; - Addr testIndex = tage->getTageIndex(0x1000, testTable); - - // Step 1: Verify allocation in an invalid way first - // Make first prediction, mispredict, allocate a new entry - bool predicted1 = predictUpdateCycle(tage, 0x1000, entry, false, history, stagePreds); - - // Check if allocation happened - int allocatedWay = -1; - for (unsigned way = 0; way < tage->numWays; way++) { - if (tage->tageTable[testTable][testIndex][way].valid && - tage->tageTable[testTable][testIndex][way].pc == 0x1000) { - allocatedWay = way; - break; - } - } - - EXPECT_GE(allocatedWay, 0) << "Entry should be allocated in one of the ways"; - - // Strengthen the first allocated entry to prevent it from being replaced - // This simulates that the first branch has been trained and should be protected - tage->tageTable[testTable][testIndex][allocatedWay].useful = true; - tage->tageTable[testTable][testIndex][allocatedWay].counter = 2; // Make it strong - - // Step 2: Attempt to fill remaining ways with different branches - for (unsigned way = 0; way < tage->numWays; way++) { - if (way == allocatedWay) continue; - - // Create a branch with different PC - BTBEntry newEntry = createBTBEntry(0x1004); - - // Make prediction and force allocation - bool predicted = predictUpdateCycle(tage, 0x1000, newEntry, false, history, stagePreds); - } - - // Verify now both ways can be filled under miss policy (consider any way's useful=0) - int filledWays = 0; - for (unsigned way = 0; way < tage->numWays; way++) { - if (tage->tageTable[testTable][testIndex][way].valid) { - filledWays++; - } 
- } - - EXPECT_EQ(filledWays, tage->numWays) << "All ways should be filled after multiple allocations under miss policy"; +TEST_F(BTBTAGETest, PayloadMapFailFallbackToBase) +{ + Addr startPC = 0x1000; + // Only two conds in this block => slot0(0x1000), slot1(0x1002). + BTBEntry b0 = createBTBEntry(0x1000, true, true, false, /*ctr*/ -1); + BTBEntry b1 = createBTBEntry(0x1002, true, true, false, /*ctr*/ 1); + std::vector entries = {b0, b1}; - // Strengthen all allocated entries to prevent replacement in Step 3 - for (unsigned way = 0; way < tage->numWays; way++) { - if (tage->tageTable[testTable][testIndex][way].valid) { - tage->tageTable[testTable][testIndex][way].useful = true; - tage->tageTable[testTable][testIndex][way].counter = 2; // Make it strong - } - } + // Provider predicts slot2 (enc=3) which cannot map => should fallback to base (slot1). + setupTageEntry(tage, startPC, /*table*/ 3, /*conf*/ 2, /*exit0*/ 3); - // Stats: first allocation succeeded, subsequent attempts failed - int alloc_success_after_step2 = tage->tageStats.updateAllocSuccess; - int alloc_failure_after_step2 = tage->tageStats.updateAllocFailure; - EXPECT_EQ(alloc_success_after_step2, 2) << "Two allocations should have succeeded (one per way)"; - EXPECT_GE(alloc_failure_after_step2, 0) << "Allocation failures may occur depending on mask selection"; - - // Step 3: One more allocation should still not replace existing entry (no LRU replacement) - BTBEntry newEntry = createBTBEntry(0x1008); - bool predicted = predictUpdateCycle(tage, 0x1000, newEntry, false, history, stagePreds); - - // Check if the new entry was allocated - bool found = false; - unsigned foundWay = 0; - for (unsigned way = 0; way < tage->numWays; way++) { - if (tage->tageTable[testTable][testIndex][way].valid && - tage->tageTable[testTable][testIndex][way].pc == 0x1008) { - found = true; - foundWay = way; - break; - } - } + Addr pred_pc = predictExitPC(tage, startPC, entries, history, stagePreds); + EXPECT_EQ(pred_pc, 
0x1002); - EXPECT_FALSE(found) << "New entry should not be allocated (no replacement without eligible slot)"; - - // Stats: failure count should increase further after another attempt - int alloc_failure_after_step3 = tage->tageStats.updateAllocFailure; - EXPECT_GE(alloc_failure_after_step3, alloc_failure_after_step2 + 1) - << "Allocation failures should increase after additional failed attempt"; + auto meta = std::static_pointer_cast(tage->getPredictionMeta()); + EXPECT_TRUE(meta->hasPred); + EXPECT_EQ(meta->pred.source, BTBTAGE::PredSource::Base); + EXPECT_EQ(meta->pred.baseEnc, 2); + EXPECT_EQ(tage->tageStats.predPayloadMapFail, 1); + EXPECT_EQ(tage->tageStats.predBaseFallback, 1); } -/** - * @brief Test bank conflict detection - * - * Verifies: - * 1. Same bank access causes conflict and drops update (when enabled) - * 2. Different bank access has no conflict - * 3. Disabled flag prevents conflict detection - */ -TEST_F(BTBTAGETest, BankConflict) { - // Create TAGE with 4 banks +TEST_F(BTBTAGETest, BankConflict) +{ BTBTAGE *bankTage = new BTBTAGE(4, 2, 1024, 4); - boost::dynamic_bitset<> testHistory(128); - std::vector testStagePreds(5); - - // Bank ID derives from bits [2:1] (pc >> 1) & 0x3 when instShiftAmt == 1. - // Bank 0: ..., 0x100, 0x108 ... Bank 1: ..., 0x102, 0x10A ... - // Bank 2: ..., 0x104, 0x10C ... Bank 3: ..., 0x106, 0x10E ... 
// Test 1: Same bank conflict (enabled) bankTage->enableBankConflict = true; { - // Predict on bank 1 (0x20), then update on bank 1 (0xa0) - testStagePreds[1].btbEntries = {createBTBEntry(0x20)}; - bankTage->putPCHistory(0x20, testHistory, testStagePreds); - EXPECT_TRUE(bankTage->predBankValid); + bankTage->lastPredBankId = bankTage->getBankId(0x20); + bankTage->predBankValid = true; - auto meta = bankTage->getPredictionMeta(); - FetchTarget stream = createStream(0xa0, createBTBEntry(0xa0), true, meta); - setupTageEntry(bankTage, 0xa0, 0, 1, false); + BTBEntry u = createBTBEntry(0xa0); + FetchTarget stream = createStream(0xa0, {u}, &u, nullptr); uint64_t conflicts_before = bankTage->tageStats.updateBankConflict; bool can_update = bankTage->canResolveUpdate(stream); - - // Should detect conflict and defer update EXPECT_EQ(bankTage->tageStats.updateBankConflict, conflicts_before + 1); EXPECT_FALSE(can_update); EXPECT_FALSE(bankTage->predBankValid); @@ -911,48 +345,40 @@ TEST_F(BTBTAGETest, BankConflict) { // Test 2: Different bank, no conflict { - // Predict on bank 0 (0x100), update on bank 2 (0x104) - testStagePreds[1].btbEntries = {createBTBEntry(0x100)}; - bankTage->putPCHistory(0x100, testHistory, testStagePreds); + bankTage->lastPredBankId = bankTage->getBankId(0x100); + bankTage->predBankValid = true; - auto meta = bankTage->getPredictionMeta(); - FetchTarget stream = createStream(0x104, createBTBEntry(0x104), true, meta); + BTBEntry u = createBTBEntry(0x104); + FetchTarget stream = createStream(0x104, {u}, &u, nullptr); uint64_t conflicts_before = bankTage->tageStats.updateBankConflict; bool can_update = bankTage->canResolveUpdate(stream); - ASSERT_TRUE(can_update); - bankTage->doResolveUpdate(stream); - - // Should not detect conflict + EXPECT_TRUE(can_update); EXPECT_EQ(bankTage->tageStats.updateBankConflict, conflicts_before); + EXPECT_TRUE(bankTage->predBankValid); } // Test 3: Disabled flag prevents conflict bankTage->enableBankConflict = false; { - // 
Same bank (0x20 and 0xa0), but conflict disabled - testStagePreds[1].btbEntries = {createBTBEntry(0x20)}; - bankTage->putPCHistory(0x20, testHistory, testStagePreds); + bankTage->lastPredBankId = bankTage->getBankId(0x20); + bankTage->predBankValid = true; - auto meta = bankTage->getPredictionMeta(); - FetchTarget stream = createStream(0xa0, createBTBEntry(0xa0), true, meta); - setupTageEntry(bankTage, 0xa0, 0, 1, false); + BTBEntry u = createBTBEntry(0xa0); + FetchTarget stream = createStream(0xa0, {u}, &u, nullptr); uint64_t conflicts_before = bankTage->tageStats.updateBankConflict; bool can_update = bankTage->canResolveUpdate(stream); - ASSERT_TRUE(can_update); - bankTage->doResolveUpdate(stream); - - // No conflict even with same bank + EXPECT_TRUE(can_update); EXPECT_EQ(bankTage->tageStats.updateBankConflict, conflicts_before); + EXPECT_TRUE(bankTage->predBankValid); } } +} // namespace test -} // namespace test - -} // namespace btb_pred +} // namespace btb_pred -} // namespace branch_prediction +} // namespace branch_prediction -} // namespace gem5 +} // namespace gem5 diff --git a/util/xs_scripts/bp_db_tage_pingpong.py b/util/xs_scripts/bp_db_tage_pingpong.py new file mode 100644 index 0000000000..ce28613dab --- /dev/null +++ b/util/xs_scripts/bp_db_tage_pingpong.py @@ -0,0 +1,270 @@ +#!/usr/bin/env python3 +""" +Analyze TAGEMISSTRACE in bp.db for Exit-Slot TAGE ping-pong / multi-pattern blocks. 
+
+Typical usage:
+    python3 util/xs_scripts/bp_db_tage_pingpong.py --db /tmp/debug/.../bp.db --top 20
+    python3 util/xs_scripts/bp_db_tage_pingpong.py --db .../bp.db --startpc 0x80000160 --top 50
+"""
+
+from __future__ import annotations
+
+import argparse
+import collections
+import sqlite3
+import sys
+from dataclasses import dataclass
+from typing import Dict, Iterable, List, Optional, Set, Tuple
+
+
+def parse_u64(x: str) -> int:
+    """Parse a decimal or 0x-prefixed hex string into an int."""
+    x = x.strip().lower()
+    if x.startswith("0x"):
+        return int(x, 16)
+    return int(x, 10)
+
+
+def hex0(x: int) -> str:
+    """Format an int as 0x-prefixed lowercase hex."""
+    return "0x%x" % x
+
+
+def get_cols(con: sqlite3.Connection, table: str) -> Set[str]:
+    """Return the set of column names of *table* (via pragma table_info)."""
+    cur = con.cursor()
+    cur.execute(f"pragma table_info({table});")
+    return {r[1] for r in cur.fetchall()}
+
+
+def require_table(con: sqlite3.Connection, table: str) -> None:
+    """Exit the script with an error if *table* does not exist in the db."""
+    cur = con.cursor()
+    cur.execute(
+        "select name from sqlite_master where type='table' and name=?;",
+        (table,),
+    )
+    if cur.fetchone() is None:
+        raise SystemExit(f"ERROR: table {table} not found in db")
+
+
+@dataclass(frozen=True)
+class EntryKey:
+    """Identity of one physical predictor entry: (table, set index, way, tag)."""
+
+    main_table: int
+    main_index: int
+    way: int
+    main_tag: int  # 0 if not present
+
+
+@dataclass
+class EntryAgg:
+    """Per-entry aggregates accumulated over all trace rows that hit the entry."""
+
+    # NOTE(review): the Optional-set-defaulting-to-None pattern below is a workaround
+    # for mutable dataclass defaults; dataclasses.field(default_factory=set) would
+    # express the same thing without the type: ignore noise — consider for a cleanup.
+    n: int = 0
+    real_encs: Set[int] = None  # type: ignore[assignment]
+    payload_pairs: Set[Tuple[int, int]] = None  # type: ignore[assignment]
+    pred_encs: Set[int] = None  # type: ignore[assignment]
+    startpcs: Set[int] = None  # type: ignore[assignment]
+    correct: int = 0
+    sels: Set[int] = None  # type: ignore[assignment]
+
+    def __post_init__(self) -> None:
+        # Replace the None placeholders with fresh, per-instance sets.
+        if self.real_encs is None:
+            self.real_encs = set()
+        if self.payload_pairs is None:
+            self.payload_pairs = set()
+        if self.pred_encs is None:
+            self.pred_encs = set()
+        if self.startpcs is None:
+            self.startpcs = set()
+        if self.sels is None:
+            self.sels = set()
+
+
+def iter_rows(
+    con: sqlite3.Connection,
+    cols: Set[str],
+    startpc: Optional[int],
+    limit: Optional[int],
+) -> Iterable[sqlite3.Row]:
+    """Yield TAGEMISSTRACE rows, selecting only columns present in *cols*.
+
+    Rows are ordered by TICK when that column exists; *startpc* filters on
+    startPC and *limit* caps the number of rows scanned (both optional).
+    """
+    con.row_factory = sqlite3.Row
+    cur = con.cursor()
+
+    want = [
+        "TICK",
+        "startPC",
+        "branchPC",
+        "actualTaken",
+        "mainFound",
+        "mainTable",
+        "mainIndex",
+        "wayIdx",
+        # Optional new fields
+        "mainTag",
+        "mainPayload",
+        "mainPayload1",
+        "mainSel",
+        "predEnc",
+        "realEnc",
+    ]
+
+    select = [c for c in want if c in cols]
+    if "TICK" not in select:
+        # Old schema: no explicit tick column in the trace table, but Record adds it.
+        # If missing, still proceed.
+        pass
+
+    q = "select %s from TAGEMISSTRACE" % (", ".join(select) if select else "*")
+    args: List[object] = []
+    if startpc is not None and "startPC" in cols:
+        q += " where startPC = ?"
+        args.append(startpc)
+    if "TICK" in cols:
+        q += " order by TICK asc"
+    if limit is not None:
+        q += " limit ?"
+        args.append(limit)
+
+    cur.execute(q, args)
+    for row in cur.fetchall():
+        yield row
+
+
+def main() -> int:
+    """Aggregate provider-hit trace rows per physical entry and report the
+    entries observing the most distinct realEnc labels (ping-pong suspects)."""
+    ap = argparse.ArgumentParser()
+    ap.add_argument("--db", required=True, help="path to bp.db")
+    ap.add_argument("--startpc", default=None, help="filter by startPC (hex or dec)")
+    ap.add_argument("--top", type=int, default=30, help="top N entries by entropy")
+    ap.add_argument("--min-samples", type=int, default=50, help="min samples per entry")
+    ap.add_argument("--limit", type=int, default=None, help="limit number of rows scanned")
+    args = ap.parse_args()
+
+    startpc = parse_u64(args.startpc) if args.startpc is not None else None
+
+    con = sqlite3.connect(args.db)
+    # In some sandboxed environments TMPDIR may point to a non-writable path.
+    # ORDER BY on a large-ish trace can force SQLite to spill to a temp file and fail with
+    # "unable to open database file". Keep temp in memory to make the script robust.
+    try:
+        con.execute("pragma temp_store=memory;")
+    except sqlite3.Error:
+        pass
+    require_table(con, "TAGEMISSTRACE")
+    cols = get_cols(con, "TAGEMISSTRACE")
+
+    # mainPayload1/mainSel are optional (Exit-Slot v2 dual-candidate debug fields).
+    missing = [c for c in ("mainPayload", "realEnc", "mainTag", "predEnc") if c not in cols]
+    if missing:
+        print(
+            "WARNING: TAGEMISSTRACE missing columns %s. "
+            "This db cannot fully prove ping-pong at entry level. "
+            "Re-run with updated gem5 to log payload/tag/realEnc."
+            % (missing,),
+            file=sys.stderr,
+        )
+
+    aggs: Dict[EntryKey, EntryAgg] = {}
+    realenc_missing = "realEnc" not in cols
+    predenc_missing = "predEnc" not in cols
+
+    for row in iter_rows(con, cols, startpc, args.limit):
+        # Only provider hits identify a concrete physical entry.
+        if "mainFound" in cols and int(row["mainFound"]) == 0:
+            continue
+        if "mainTable" not in row.keys() or "mainIndex" not in row.keys() or "wayIdx" not in row.keys():
+            continue
+        k = EntryKey(
+            main_table=int(row["mainTable"]),
+            main_index=int(row["mainIndex"]),
+            way=int(row["wayIdx"]),
+            main_tag=int(row["mainTag"]) if "mainTag" in row.keys() else 0,
+        )
+        a = aggs.get(k)
+        if a is None:
+            a = EntryAgg()
+            aggs[k] = a
+        a.n += 1
+        if "startPC" in row.keys():
+            a.startpcs.add(int(row["startPC"]))
+        if "mainPayload" in row.keys():
+            p0 = int(row["mainPayload"])
+            p1 = int(row["mainPayload1"]) if "mainPayload1" in row.keys() else -1
+            a.payload_pairs.add((p0, p1))
+        if "mainSel" in row.keys():
+            a.sels.add(int(row["mainSel"]))
+        if not realenc_missing and "realEnc" in row.keys():
+            real = int(row["realEnc"])
+            a.real_encs.add(real)
+            if not predenc_missing and "predEnc" in row.keys():
+                pred = int(row["predEnc"])
+                a.pred_encs.add(pred)
+                if pred == real:
+                    a.correct += 1
+        elif not predenc_missing and "predEnc" in row.keys():
+            a.pred_encs.add(int(row["predEnc"]))
+
+    # Histogram by distinct realEnc count (a proxy of multi-pattern pressure on one entry).
+    hist = collections.Counter()
+    for a in aggs.values():
+        if a.n < args.min_samples:
+            continue
+        hist[len(a.real_encs)] += 1
+
+    print("# TAGEMISSTRACE Entry Entropy (min_samples=%d)" % args.min_samples)
+    if startpc is not None:
+        print("- startPC filter: %s" % hex0(startpc))
+    print("- total provider-hit records scanned: %d" % sum(a.n for a in aggs.values()))
+    print("- unique entry keys: %d" % len(aggs))
+    if "realEnc" in cols:
+        print("\n## Distinct realEnc per (table,index,way,tag) histogram")
+        for k in sorted(hist.keys()):
+            print("- %d distinct realEnc: %d entries" % (k, hist[k]))
+    else:
+        print("\n## NOTE")
+        print("- realEnc not available in this db; histogram is skipped.")
+
+    # Top entries by entropy
+    items = []
+    if "realEnc" in cols:
+        for k, a in aggs.items():
+            if a.n < args.min_samples:
+                continue
+            items.append((len(a.real_encs), a.n, k, a))
+    # EntryKey is not orderable; provide an explicit key for deterministic sorting.
+    items.sort(
+        key=lambda x: (
+            x[0],  # distinct realEnc
+            x[1],  # samples
+            x[2].main_table,
+            x[2].main_index,
+            x[2].way,
+            x[2].main_tag,
+        ),
+        reverse=True,
+    )
+
+    print("\n## Top %d entries by distinct realEnc" % args.top)
+    if not items:
+        print(
+            "WARNING: TAGEMISSTRACE missing required columns (need at least realEnc/predEnc/mainTag/mainPayload). "
+            "This db cannot prove ping-pong at entry level; please re-run with an instrumented gem5.opt."
+        )
+        return 0
+    for ent_cnt, n, k, a in items[: args.top]:
+        acc = (a.correct / a.n) if ("realEnc" in cols and "predEnc" in cols and a.n) else None
+        print(
+            "- ent=%d n=%d table=%d index=%d way=%d tag=%s startPCs=%d acc=%s realEnc=%s predEnc=%s payloadPairs=%s sel=%s"
+            % (
+                ent_cnt,
+                n,
+                k.main_table,
+                k.main_index,
+                k.way,
+                hex0(k.main_tag) if k.main_tag else "0",
+                len(a.startpcs),
+                ("%.3f" % acc) if acc is not None else "NA",
+                sorted(a.real_encs),
+                sorted(a.pred_encs),
+                sorted(a.payload_pairs),
+                sorted(a.sels),
+            )
+        )
+
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())
diff --git a/util/xs_scripts/bp_db_upperbound.py b/util/xs_scripts/bp_db_upperbound.py
new file mode 100644
index 0000000000..2a90d2b6f9
--- /dev/null
+++ b/util/xs_scripts/bp_db_upperbound.py
@@ -0,0 +1,572 @@
+#!/usr/bin/env python3
+"""
+Compute simple *offline* upper bounds for Exit-Slot (block-based) TAGE using bp.db.
+
+Why this exists:
+  - We want a quick way to answer: "Is per-block exit-slot fundamentally limited, or is our
+    current implementation/training leaving accuracy on the table?"
+  - We estimate an upper bound under a *fixed feature set* by doing majority-vote per key.
+
+Upper bounds reported (all computed from TAGEMISSTRACE rows):
+  UB(startPC):
+      For each startPC, always predict the most frequent realEnc under that startPC.
+  UB(startPC, indexFoldedHist):
+      For each (startPC, indexFoldedHist), always predict the most frequent realEnc.
+
+Interpretation:
+  - If UB(startPC, hist) is high but actual acc is low -> implementation/training/aliasing issues.
+  - If UB(startPC, hist) itself is low -> the current history signature cannot separate modes;
+    need better features (history type/length/folding) or accept a lower ceiling.
+
+Typical usage:
+    python3 util/xs_scripts/bp_db_upperbound.py --root /tmp/debug/tage-new6
+    python3 util/xs_scripts/bp_db_upperbound.py --db /tmp/debug/tage-new6/xor_dependency_opt/bp.db
+"""
+
+from __future__ import annotations
+
+import argparse
+import os
+import sqlite3
+from dataclasses import dataclass
+from typing import Dict, List, Optional, Tuple
+
+
+def _connect(path: str) -> sqlite3.Connection:
+    """Open the db at *path* with SQLite temp storage forced into memory."""
+    con = sqlite3.connect(path)
+    # ORDER BY / GROUP BY can spill to temp; keep it in memory to avoid TMPDIR quirks.
+    try:
+        con.execute("pragma temp_store=memory;")
+    except sqlite3.Error:
+        pass
+    return con
+
+
+def _has_table(con: sqlite3.Connection, table: str) -> bool:
+    """Return True if *table* exists in the connected database."""
+    cur = con.cursor()
+    cur.execute(
+        "select 1 from sqlite_master where type='table' and name=?;",
+        (table,),
+    )
+    return cur.fetchone() is not None
+
+
+def _cols(con: sqlite3.Connection, table: str) -> List[str]:
+    """Return the column names of *table* in declaration order."""
+    return [r[1] for r in con.execute(f"pragma table_info({table});")]
+
+
+def _mispred_rate(con: sqlite3.Connection) -> Optional[Tuple[int, int, float]]:
+    """Return (total, mispredicted, rate) from BPTRACE, or None if the table is absent."""
+    if not _has_table(con, "BPTRACE"):
+        return None
+    cur = con.cursor()
+    n = cur.execute("select count(*) from BPTRACE;").fetchone()[0]
+    m = cur.execute("select sum(mispred) from BPTRACE;").fetchone()[0]
+    m = int(m or 0)
+    return int(n), m, (m / n if n else 0.0)
+
+
+@dataclass
+class UBRes:
+    """Exit-slot accuracy plus majority-vote upper bounds for one database."""
+
+    n: int
+    actual_acc: Optional[float]
+    provider_acc: Optional[float]
+    base_acc: Optional[float]
+    ub_startpc: Optional[float]
+    ub_startpc_hist: Optional[float]
+    ub_startpc_fullhist: Optional[float]
+
+
+def _tage_upperbounds(con: sqlite3.Connection) -> Optional[UBRes]:
+    """Compute measured exit-slot accuracy and per-key majority-vote upper
+    bounds from TAGEMISSTRACE; returns None when the table is missing."""
+    if not _has_table(con, "TAGEMISSTRACE"):
+        return None
+
+    cols = set(_cols(con, "TAGEMISSTRACE"))
+    if "realEnc" not in cols:
+        # Old per-branch schema doesn't carry block label; cannot compute UB.
+        n = con.execute("select count(*) from TAGEMISSTRACE;").fetchone()[0]
+        return UBRes(
+            n=int(n),
+            actual_acc=None,
+            provider_acc=None,
+            base_acc=None,
+            ub_startpc=None,
+            ub_startpc_hist=None,
+            ub_startpc_fullhist=None,
+        )
+
+    cur = con.cursor()
+    n = int(cur.execute("select count(*) from TAGEMISSTRACE;").fetchone()[0])
+
+    actual_acc = None
+    provider_acc = None
+    base_acc = None
+    if "predEnc" in cols:
+        actual_acc = float(
+            cur.execute(
+                "select 1.0*sum(case when predEnc=realEnc then 1 else 0 end)/count(*) "
+                "from TAGEMISSTRACE;"
+            ).fetchone()[0]
+        )
+        if "predSource" in cols:
+            # NOTE(review): predSource=0 is treated as "provider" and =2 as "base"
+            # here — confirm the enum encoding against the gem5 instrumentation.
+            v = cur.execute(
+                "select case when count(*)=0 then null else "
+                "1.0*sum(case when predEnc=realEnc then 1 else 0 end)/count(*) end "
+                "from TAGEMISSTRACE where predSource=0;"
+            ).fetchone()[0]
+            provider_acc = (None if v is None else float(v))
+            v = cur.execute(
+                "select case when count(*)=0 then null else "
+                "1.0*sum(case when predEnc=realEnc then 1 else 0 end)/count(*) end "
+                "from TAGEMISSTRACE where predSource=2;"
+            ).fetchone()[0]
+            base_acc = (None if v is None else float(v))
+
+    # UB(startPC)
+    ub_startpc = float(
+        cur.execute(
+            """
+            with per_label as (
+                select startPC, realEnc, count(*) as c
+                from TAGEMISSTRACE
+                group by startPC, realEnc
+            ),
+            per_startpc as (
+                select startPC, max(c) as mx
+                from per_label
+                group by startPC
+            )
+            select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+            from per_startpc;
+            """
+        ).fetchone()[0]
+    )
+
+    ub_startpc_hist = None
+    if "indexFoldedHist" in cols:
+        ub_startpc_hist = float(
+            cur.execute(
+                """
+                with per_label as (
+                    select startPC, indexFoldedHist, realEnc, count(*) as c
+                    from TAGEMISSTRACE
+                    group by startPC, indexFoldedHist, realEnc
+                ),
+                per_key as (
+                    select startPC, indexFoldedHist, max(c) as mx
+                    from per_label
+                    group by startPC, indexFoldedHist
+                )
+                select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                from per_key;
+                """
+            ).fetchone()[0]
+        )
+
+    # UB(startPC, full history bitstring) if available.
+    ub_startpc_fullhist = None
+    if "history" in cols:
+        ub_startpc_fullhist = float(
+            cur.execute(
+                """
+                with per_label as (
+                    select startPC, history, realEnc, count(*) as c
+                    from TAGEMISSTRACE
+                    group by startPC, history, realEnc
+                ),
+                per_key as (
+                    select startPC, history, max(c) as mx
+                    from per_label
+                    group by startPC, history
+                )
+                select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                from per_key;
+                """
+            ).fetchone()[0]
+        )
+
+    return UBRes(
+        n=n,
+        actual_acc=actual_acc,
+        provider_acc=provider_acc,
+        base_acc=base_acc,
+        ub_startpc=ub_startpc,
+        ub_startpc_hist=ub_startpc_hist,
+        ub_startpc_fullhist=ub_startpc_fullhist,
+    )
+
+
+@dataclass
+class DirUBRes:
+    """Offline separability upper bounds for per-branch direction prediction."""
+
+    n: int
+    taken_rate: Optional[float]
+    actual_acc: Optional[float]  # predTaken vs actualTaken, if predTaken exists
+    # Majority-vote UB under different identity/features.
+    ub_branchpc: Optional[float]
+    ub_branchpc_hist: Optional[float]
+    ub_branchpc_fullhist: Optional[float]
+    ub_startpc_slot: Optional[float]
+    ub_startpc_slot_hist: Optional[float]
+    ub_startpc_slot_fullhist: Optional[float]
+
+
+def _dir_upperbounds(con: sqlite3.Connection) -> Optional[DirUBRes]:
+    """Compute per-branch direction (taken/not-taken) majority-vote upper bounds
+    from TAGEMISSTRACE; returns None when the table or key columns are missing."""
+    if not _has_table(con, "TAGEMISSTRACE"):
+        return None
+    cols = set(_cols(con, "TAGEMISSTRACE"))
+    if "actualTaken" not in cols or "branchPC" not in cols:
+        return None
+
+    cur = con.cursor()
+    n = int(cur.execute("select count(*) from TAGEMISSTRACE;").fetchone()[0])
+    if n == 0:
+        return DirUBRes(
+            n=0,
+            taken_rate=None,
+            actual_acc=None,
+            ub_branchpc=None,
+            ub_branchpc_hist=None,
+            ub_branchpc_fullhist=None,
+            ub_startpc_slot=None,
+            ub_startpc_slot_hist=None,
+            ub_startpc_slot_fullhist=None,
+        )
+
+    taken_rate = float(cur.execute("select 1.0*sum(actualTaken)/count(*) from TAGEMISSTRACE;").fetchone()[0])
+
+    actual_acc = None
+    if "predTaken" in cols:
+        v = cur.execute(
+            "select 1.0*sum(case when predTaken=actualTaken then 1 else 0 end)/count(*) from TAGEMISSTRACE;"
+        ).fetchone()[0]
+        actual_acc = (None if v is None else float(v))
+
+    # UB(branchPC)
+    ub_branchpc = float(
+        cur.execute(
+            """
+            with per_label as (
+                select branchPC, actualTaken, count(*) as c
+                from TAGEMISSTRACE
+                group by branchPC, actualTaken
+            ),
+            per_key as (
+                select branchPC, max(c) as mx
+                from per_label
+                group by branchPC
+            )
+            select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+            from per_key;
+            """
+        ).fetchone()[0]
+    )
+
+    ub_branchpc_hist = None
+    if "indexFoldedHist" in cols:
+        ub_branchpc_hist = float(
+            cur.execute(
+                """
+                with per_label as (
+                    select branchPC, indexFoldedHist, actualTaken, count(*) as c
+                    from TAGEMISSTRACE
+                    group by branchPC, indexFoldedHist, actualTaken
+                ),
+                per_key as (
+                    select branchPC, indexFoldedHist, max(c) as mx
+                    from per_label
+                    group by branchPC, indexFoldedHist
+                )
+                select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                from per_key;
+                """
+            ).fetchone()[0]
+        )
+
+    ub_branchpc_fullhist = None
+    if "history" in cols:
+        ub_branchpc_fullhist = float(
+            cur.execute(
+                """
+                with per_label as (
+                    select branchPC, history, actualTaken, count(*) as c
+                    from TAGEMISSTRACE
+                    group by branchPC, history, actualTaken
+                ),
+                per_key as (
+                    select branchPC, history, max(c) as mx
+                    from per_label
+                    group by branchPC, history
+                )
+                select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                from per_key;
+                """
+            ).fetchone()[0]
+        )
+
+    # UB(startPC, slot): approximate the benefit of injecting "position" identity.
+    # Slot is computed at 2B granularity and masked to 5 bits (0..31) to match the typical
+    # in-block slot encoding.
+    ub_startpc_slot = None
+    ub_startpc_slot_hist = None
+    ub_startpc_slot_fullhist = None
+    if "startPC" in cols:
+        ub_startpc_slot = float(
+            cur.execute(
+                """
+                with per_label as (
+                    select startPC, ((branchPC - startPC) >> 1) & 31 as slot, actualTaken, count(*) as c
+                    from TAGEMISSTRACE
+                    group by startPC, slot, actualTaken
+                ),
+                per_key as (
+                    select startPC, slot, max(c) as mx
+                    from per_label
+                    group by startPC, slot
+                )
+                select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                from per_key;
+                """
+            ).fetchone()[0]
+        )
+        if "indexFoldedHist" in cols:
+            ub_startpc_slot_hist = float(
+                cur.execute(
+                    """
+                    with per_label as (
+                        select startPC, ((branchPC - startPC) >> 1) & 31 as slot,
+                               indexFoldedHist, actualTaken, count(*) as c
+                        from TAGEMISSTRACE
+                        group by startPC, slot, indexFoldedHist, actualTaken
+                    ),
+                    per_key as (
+                        select startPC, slot, indexFoldedHist, max(c) as mx
+                        from per_label
+                        group by startPC, slot, indexFoldedHist
+                    )
+                    select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                    from per_key;
+                    """
+                ).fetchone()[0]
+            )
+        if "history" in cols:
+            ub_startpc_slot_fullhist = float(
+                cur.execute(
+                    """
+                    with per_label as (
+                        select startPC, ((branchPC - startPC) >> 1) & 31 as slot,
+                               history, actualTaken, count(*) as c
+                        from TAGEMISSTRACE
+                        group by startPC, slot, history, actualTaken
+                    ),
+                    per_key as (
+                        select startPC, slot, history, max(c) as mx
+                        from per_label
+                        group by startPC, slot, history
+                    )
+                    select 1.0*sum(mx)/(select count(*) from TAGEMISSTRACE)
+                    from per_key;
+                    """
+                ).fetchone()[0]
+            )
+
+    return DirUBRes(
+        n=n,
+        taken_rate=taken_rate,
+        actual_acc=actual_acc,
+        ub_branchpc=ub_branchpc,
+        ub_branchpc_hist=ub_branchpc_hist,
+        ub_branchpc_fullhist=ub_branchpc_fullhist,
+        ub_startpc_slot=ub_startpc_slot,
+        ub_startpc_slot_hist=ub_startpc_slot_hist,
+        ub_startpc_slot_fullhist=ub_startpc_slot_fullhist,
+    )
+
+
+def _fmt_pct(x: Optional[float]) -> str:
+    """Format a ratio as a fixed-width percentage, or 'n/a' when missing."""
+    if x is None:
+        return "n/a"
+    return f"{x*100:5.1f}%"
+
+
+def _fmt_n(x: 
Optional[int]) -> str:
+    """Format an optional count compactly (e.g. 1.2k / 3.4M / 5.6G), or 'n/a'."""
+    if x is None:
+        return "n/a"
+    # Compact human-readable counts.
+    if x >= 1_000_000_000:
+        return f"{x/1_000_000_000:.1f}G"
+    if x >= 1_000_000:
+        return f"{x/1_000_000:.1f}M"
+    if x >= 1_000:
+        return f"{x/1_000:.1f}k"
+    return str(x)
+
+
+def _analyze_one(db: str) -> Dict[str, object]:
+    """Open one bp.db, compute all metrics, and return them keyed by name."""
+    con = _connect(db)
+    ub = _tage_upperbounds(con)
+    dir_ub = _dir_upperbounds(con)
+    bp = _mispred_rate(con)
+    con.close()
+    return {"db": db, "ub": ub, "dir_ub": dir_ub, "bp": bp}
+
+
+def main() -> int:
+    """Report upper bounds for a single db (--db) or a markdown comparison
+    table over paired *_opt / *_ref runs found under --root."""
+    ap = argparse.ArgumentParser()
+    g = ap.add_mutually_exclusive_group(required=True)
+    g.add_argument("--db", help="analyze one bp.db")
+    g.add_argument("--root", help="scan a /tmp/debug/tage-newX directory that contains */bp.db")
+    args = ap.parse_args()
+
+    if args.db:
+        # Single-db mode: print a flat list of metrics and exit.
+        r = _analyze_one(args.db)
+        ub: Optional[UBRes] = r["ub"]  # type: ignore[assignment]
+        dub: Optional[DirUBRes] = r["dir_ub"]  # type: ignore[assignment]
+        bp = r["bp"]
+        print(f"# {args.db}")
+        if bp is not None:
+            n, m, rate = bp
+            print(f"- BPTRACE mispred: {rate*100:.2f}% ({m}/{n})")
+        if ub is not None and ub.ub_startpc is not None:
+            print(f"- TAGEMISSTRACE samples: {ub.n}")
+            print(f"- actual acc: {_fmt_pct(ub.actual_acc)}")
+            print(f"- provider acc: {_fmt_pct(ub.provider_acc)}")
+            print(f"- base acc: {_fmt_pct(ub.base_acc)}")
+            print(f"- UB_exit(startPC): {_fmt_pct(ub.ub_startpc)}")
+            print(f"- UB_exit(startPC,hist): {_fmt_pct(ub.ub_startpc_hist)}")
+            print(f"- UB_exit(startPC,H): {_fmt_pct(ub.ub_startpc_fullhist)}")
+            if ub.actual_acc is not None and ub.ub_startpc_hist is not None:
+                print(f"- headroom (UB2-acc): {_fmt_pct(ub.ub_startpc_hist - ub.actual_acc)}")
+        if dub is not None:
+            print(f"- DIR samples: {dub.n}")
+            print(f"- DIR taken rate: {_fmt_pct(dub.taken_rate)}")
+            print(f"- DIR actual acc: {_fmt_pct(dub.actual_acc)}")
+            print(f"- UB_dir(branchPC): {_fmt_pct(dub.ub_branchpc)}")
+            print(f"- UB_dir(branchPC,hist): {_fmt_pct(dub.ub_branchpc_hist)}")
+            print(f"- UB_dir(branchPC,H): {_fmt_pct(dub.ub_branchpc_fullhist)}")
+            print(f"- UB_dir(startPC,slot): {_fmt_pct(dub.ub_startpc_slot)}")
+            print(f"- UB_dir(startPC,slot,hist): {_fmt_pct(dub.ub_startpc_slot_hist)}")
+            print(f"- UB_dir(startPC,slot,H): {_fmt_pct(dub.ub_startpc_slot_fullhist)}")
+        return 0
+
+    root: str = args.root
+    # Pair *_opt with *_ref.
+    benches: Dict[str, Dict[str, str]] = {}
+    for d in os.listdir(root):
+        if not d.endswith(("_opt", "_ref")):
+            continue
+        kind = "opt" if d.endswith("_opt") else "ref"
+        base = d[: -len("_opt")] if kind == "opt" else d[: -len("_ref")]
+        db = os.path.join(root, d, "bp.db")
+        if os.path.exists(db):
+            benches.setdefault(base, {})[kind] = db
+
+    rows = []
+    for base, mp in sorted(benches.items()):
+        opt = _analyze_one(mp["opt"]) if "opt" in mp else None
+        ref = _analyze_one(mp["ref"]) if "ref" in mp else None
+        rows.append((base, opt, ref))
+
+    # Print a compact table for quick comparison.
+    print(f"# Upperbound Report: {root}")
+    print("")
+    print("## What This Report Measures")
+    print("")
+    print("- This is an *offline separability upper bound* computed from `bp.db`.")
+    print("- For each chosen feature key (e.g., `(startPC, history)`), we compute the best possible")
+    print("  accuracy under 0/1 loss by always predicting the *most frequent label* for that key")
+    print("  (majority vote). This is Bayes-optimal given only that key.")
+    print("- It is **NOT** an oracle that peeks at the future; it quantifies whether the available")
+    print("  features contain enough information to separate patterns.")
+    print("")
+    print("### Exit-slot (per-block) label")
+    print("")
+    print("- Uses `TAGEMISSTRACE.realEnc` (0..32) as the true label for Exit-Slot multi-class classification.")
+    print("- `UB_exit(startPC,hist)`: key is `(startPC, indexFoldedHist)`.")
+    print("- `UB_exit(startPC,H)`: key is `(startPC, history_string)` (low 50 bits in current logging).")
+    print("")
+    print("### Direction (per-branch) label")
+    print("")
+    print("- Uses `TAGEMISSTRACE.actualTaken` (0/1) as the true label for direction prediction.")
+    print("- `acc_dir(ref)`: measured accuracy `predTaken==actualTaken` in ref trace (if `predTaken` exists).")
+    print("- `UB_dir(ref startPC,slot,hist)`: key is `(startPC, slot, indexFoldedHist)`, where")
+    print("  `slot = ((branchPC - startPC) >> 1) & 31` approximates in-block position identity.")
+    print("- `UB_dir(ref startPC,slot,H)`: key is `(startPC, slot, history_string)`.")
+    print("")
+    print("### About `n/a`")
+    print("")
+    print("- `n/a` means the db does not have usable samples for that metric (missing table/columns,")
+    print("  or `TAGEMISSTRACE` exists but has 0 rows for that run).")
+    print("")
+    header = (
+        "| bench | BP mispred opt | BP mispred ref | delta | "
+        "n_exit(opt) | acc_exit(opt) | UB_exit(startPC,hist) | UB_exit(startPC,H) | "
+        "n_dir(ref) | acc_dir(ref) | UB_dir(ref startPC,slot,hist) | UB_dir(ref startPC,slot,H) |"
+    )
+    sep = "|---|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|---:|"
+    print(header)
+    print(sep)
+
+    reg_items: List[Tuple[float, str]] = []
+
+    for base, opt, ref in rows:
+        opt_bp = opt["bp"] if opt else None  # type: ignore[index]
+        ref_bp = ref["bp"] if ref else None  # type: ignore[index]
+
+        def _bp_fmt(x: Optional[Tuple[int, int, float]]) -> str:
+            # Format a (total, mispred, rate) triple as a percentage cell.
+            if x is None:
+                return "n/a"
+            return f"{x[2]*100:5.2f}%"
+
+        opt_rate = opt_bp[2] if opt_bp else None
+        ref_rate = ref_bp[2] if ref_bp else None
+        delta = (opt_rate - ref_rate) if (opt_rate is not None and ref_rate is not None) else None
+
+        opt_ub: Optional[UBRes] = opt["ub"] if opt else None  # type: ignore[index]
+        ref_dir: Optional[DirUBRes] = (ref["dir_ub"] if ref else None)  # type: ignore[index]
+
+        n_exit = opt_ub.n if (opt_ub and opt_ub.ub_startpc is not None) else None
+        acc_exit = opt_ub.actual_acc if (opt_ub and opt_ub.actual_acc is not None) else None
+        ub_exit2 = opt_ub.ub_startpc_hist if opt_ub else None
+        ub_exit3 = opt_ub.ub_startpc_fullhist if opt_ub else None
+
+        n_dir = ref_dir.n if (ref_dir and ref_dir.n) else None
+        acc_dir = ref_dir.actual_acc if ref_dir else None
+        ub_dir2 = ref_dir.ub_startpc_slot_hist if ref_dir else None
+        ub_dir3 = ref_dir.ub_startpc_slot_fullhist if ref_dir else None
+        if delta is not None:
+            reg_items.append((delta, base))
+
+        def _pct(x: Optional[float]) -> str:
+            # Percentage cell formatter for ratio-or-missing values.
+            if x is None:
+                return "n/a"
+            return f"{x*100:5.1f}%"
+
+        print(
+            "| %s | %s | %s | %s | %s | %s | %s | %s | %s | %s | %s | %s |"
+            % (
+                base,
+                _bp_fmt(opt_bp),
+                _bp_fmt(ref_bp),
+                ("n/a" if delta is None else f"{delta*100:+.2f}%"),
+                _fmt_n(n_exit),
+                _pct(acc_exit),
+                _pct(ub_exit2),
+                _pct(ub_exit3),
+                _fmt_n(n_dir),
+                _pct(acc_dir),
+                _pct(ub_dir2),
+                _pct(ub_dir3),
+            )
+        )
+
+    reg_items.sort(reverse=True)
+    print("")
+    print("## Biggest BP mispred regressions (opt - ref)")
+    for d, b in reg_items[:10]:
+        print(f"- {b}: {d*100:+.2f}%")
+    return 0
+
+
+if __name__ == "__main__":
+    raise SystemExit(main())