From cf7ddca35122833ad862ffe9ee5da377b9c438d2 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Thu, 28 Aug 2025 19:23:43 +0000 Subject: [PATCH 01/25] Surprisal intervention and config --- delphi/__main__.py | 59 ++- delphi/config.py | 7 +- delphi/latents/latents.py | 7 + delphi/scorers/__init__.py | 7 + delphi/scorers/intervention/__init__.py | 0 .../output_based_intervention_scorer.py | 141 ++++++ .../surprisal_intervention_scorer.py | 449 ++++++++++++++++++ 7 files changed, 663 insertions(+), 7 deletions(-) create mode 100644 delphi/scorers/intervention/__init__.py create mode 100644 delphi/scorers/intervention/output_based_intervention_scorer.py create mode 100644 delphi/scorers/intervention/surprisal_intervention_scorer.py diff --git a/delphi/__main__.py b/delphi/__main__.py index 16f0c557..46111bc1 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -5,12 +5,15 @@ from pathlib import Path from typing import Callable +from dataclasses import asdict + import orjson import torch from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( AutoModel, + AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, PreTrainedModel, @@ -27,7 +30,7 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator +from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, InterventionScorer, LogProbInterventionScorer, SurprisalInterventionScorer from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -40,7 +43,7 @@ def load_artifacts(run_cfg: RunConfig): else: dtype = "auto" - model = AutoModel.from_pretrained( + model = AutoModelForCausalLM.from_pretrained( run_cfg.model, device_map={"": "cuda"}, quantization_config=( @@ -118,6 +121,8 @@ async def process_cache( hookpoints: list[str], tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | None, + model, + hookpoint_to_sparse_encode ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -218,6 +223,12 @@ def none_postprocessor(result): postprocess=none_postprocessor, ) ) + + def custom_serializer(obj): + """A custom serializer for orjson to handle specific types.""" + if isinstance(obj, Tensor): + return obj.tolist() + raise TypeError # Builds the record from result returned by the pipeline def scorer_preprocess(result): @@ -230,12 +241,22 @@ def scorer_preprocess(result): return record # Saves the score to a file - def scorer_postprocess(result, score_dir): + # In your __main__.py file + + def scorer_postprocess(result, score_dir, scorer_name=None): + if isinstance(result, list): + if not result: + return + result = result[0] + safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - f.write(orjson.dumps(result.score)) + # This line now works universally. For other scorers, it saves their simple + # score. For surprisal_intervention, it saves the rich 'final_payload'. 
+ f.write(orjson.dumps(result.score, default=custom_serializer)) + scorers = [] for scorer_name in run_cfg.scorers: scorer_path = scores_path / scorer_name @@ -257,6 +278,29 @@ def scorer_postprocess(result, score_dir): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) + elif scorer_name == "intervention": + scorer = InterventionScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) + elif scorer_name == "logprob_intervention": + scorer = LogProbInterventionScorer( + llm_client, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) + elif scorer_name == "surprisal_intervention": + scorer = SurprisalInterventionScorer( + model, + hookpoint_to_sparse_encode, + hookpoints = run_cfg.hookpoints, + n_examples_shown=run_cfg.num_examples_per_scorer_prompt, + verbose=run_cfg.verbose, + log_prob=run_cfg.log_probs, + ) else: raise ValueError(f"Scorer {scorer_name} not supported") @@ -396,6 +440,8 @@ async def run( hookpoints, hookpoint_to_sparse_encode, model, transcode = load_artifacts(run_cfg) tokenizer = AutoTokenizer.from_pretrained(run_cfg.model, token=run_cfg.hf_token) + model.tokenizer = tokenizer + nrh = assert_type( dict, non_redundant_hookpoints( @@ -412,7 +458,6 @@ async def run( transcode, ) - del model, hookpoint_to_sparse_encode if run_cfg.constructor_cfg.non_activating_source == "neighbours": nrh = assert_type( list, @@ -445,8 +490,12 @@ async def run( nrh, tokenizer, latent_range, + model, + hookpoint_to_sparse_encode ) + del model, hookpoint_to_sparse_encode + if run_cfg.verbose: log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) diff --git a/delphi/config.py b/delphi/config.py index 6e49b09d..0d2193c5 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -152,14 +152,17 @@ class RunConfig(Serializable): "fuzz", "detection", "simulation", + "intervention", + "logprob_intervention", + "surprisal_intervention" ], default=[ "fuzz", "detection", ], ) - """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', and - 'simulation'.""" + """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', + 'simulation' and 'intervention'.""" name: str = "" """The name of the run. Results are saved in a directory with this name.""" diff --git a/delphi/latents/latents.py b/delphi/latents/latents.py index 0f4ff94d..ca08ffaa 100644 --- a/delphi/latents/latents.py +++ b/delphi/latents/latents.py @@ -157,6 +157,13 @@ class LatentRecord: """Frequency of the latent. Number of activations in a context per total number of contexts.""" + @property + def feature_id(self) -> int: + """ + Returns the unique feature index for this latent. 
+ """ + return self.latent.latent_index + @property def max_activation(self) -> float: """ diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 747db837..ad84c15f 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -6,6 +6,9 @@ from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer +from .intervention.intervention_scorer import InterventionScorer +from .intervention.logprob_intervention_scorer import LogProbInterventionScorer +from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ "FuzzingScorer", @@ -16,4 +19,8 @@ "EmbeddingScorer", "IntruderScorer", "ExampleEmbeddingScorer", + "SurprisalInterventionScorer", + "InterventionScorer", + "LogProbInterventionScorer", + ] diff --git a/delphi/scorers/intervention/__init__.py b/delphi/scorers/intervention/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/delphi/scorers/intervention/output_based_intervention_scorer.py b/delphi/scorers/intervention/output_based_intervention_scorer.py new file mode 100644 index 00000000..9c706962 --- /dev/null +++ b/delphi/scorers/intervention/output_based_intervention_scorer.py @@ -0,0 +1,141 @@ +# Output-based intervention scorer (Gur-Arieh et al. 2025) +from dataclasses import dataclass +import torch +import torch.nn.functional as F +import random +from ...scorer import Scorer, ScorerResult +from ...latents import LatentRecord, ActivatingExample +from transformers import PreTrainedModel + +@dataclass +class OutputInterventionResult: + """Result of output-based intervention evaluation.""" + score: int # +1 if target set chosen, -1 otherwise + explanation: str + example_text: str + +class OutputInterventionScorer(Scorer): + """ + Output-based evaluation by steering (clamping) the feature and using a judge LLM + to pick which outputs best match the description:contentReference[oaicite:5]{index=5}. + We generate texts for the target feature and for a few random features, + then ask the judge to choose the matching set. 
+ """ + name = "output_intervention" + + def __init__(self, subject_model: PreTrainedModel, explainer_model, **kwargs): + self.subject_model = subject_model + self.explainer_model = explainer_model + self.steering_strength = kwargs.get("strength", 5.0) + self.num_prompts = kwargs.get("num_prompts", 3) + self.num_random = kwargs.get("num_random_features", 2) + self.hookpoint = kwargs.get("hookpoint", "transformer.h.6.mlp") + self.tokenizer = getattr(subject_model, "tokenizer", None) + + async def __call__(self, record: LatentRecord) -> ScorerResult: + # Prepare activating prompts + examples = [ex for ex in record.test if isinstance(ex, ActivatingExample)] + random.shuffle(examples) + prompts = ["".join(str(t) for t in ex.str_tokens) for ex in examples[:self.num_prompts]] + + # Generate text for the target feature + target_texts = [] + for p in prompts: + text, _ = await self._generate(p, record.feature_id, self.steering_strength) + target_texts.append(text) + + # Pick a few random feature IDs (avoid the target) + random_ids = [] + while len(random_ids) < self.num_random: + rid = random.randint(0, 999) + if rid != record.feature_id: + random_ids.append(rid) + + # Generate texts for random features + random_sets = [] + for fid in random_ids: + rand_texts = [] + for p in prompts: + text, _ = await self._generate(p, fid, self.steering_strength) + rand_texts.append(text) + random_sets.append(rand_texts) + + # Create prompt for judge LLM + judge_prompt = self._format_judge_prompt(record.explanation, target_texts, random_sets) + judge_response = await self._ask_judge(judge_prompt) + + # Parse judge response: check if target set was chosen + resp_lower = judge_response.lower() + if "target" in resp_lower or "set 1" in resp_lower: + score = 1 + elif "set 2" in resp_lower or "set 3" in resp_lower or "random" in resp_lower: + score = -1 + else: + score = 0 + + example_text = prompts[0] if prompts else "" + detailed = OutputInterventionResult( + score=score, + explanation=record.explanation, + example_text=example_text + ) + return ScorerResult(record=record, score=detailed) + + async def _generate(self, prompt: str, feature_id: int, strength: float): + """ + Generates text with the feature clamped (added to hidden state). + Returns the (partial) generated text and logits. + """ + tokenizer = self.tokenizer or __import__("transformers").AutoTokenizer.from_pretrained("gpt2") + input_ids = tokenizer(prompt, return_tensors="pt").input_ids + + # Forward hook to clamp feature activation + direction = self.explainer_model.get_feature_vector(feature_id) + def hook_fn(module, inp, out): + out[:, -1, :] = out[:, -1, :] + strength * direction.to(out.device) + return out + layer = self._find_layer(self.subject_model, self.hookpoint) + handle = layer.register_forward_hook(hook_fn) + + with torch.no_grad(): + outputs = self.subject_model(input_ids) + logits = outputs.logits[0, -1, :] + log_probs = F.log_softmax(logits, dim=-1) + handle.remove() + + text = tokenizer.decode(input_ids[0]) + return text, log_probs + + def _format_judge_prompt(self, explanation: str, target_texts: list, other_sets: list): + """ + Constructs a prompt for the judge LLM listing each set of texts + under the target feature and random features. 
+ """ + prompt = f"Feature description: \"{explanation}\"\n" + prompt += "Which of the following sets of generated texts best matches this description?\n\n" + prompt += "Set 1 (target feature):\n" + for txt in target_texts: + prompt += f"- {txt}\n" + for i, rand_set in enumerate(other_sets, start=2): + prompt += f"\nSet {i} (random feature):\n" + for txt in rand_set: + prompt += f"- {txt}\n" + prompt += "\nAnswer (mention the set number or 'target'/'random'): " + return prompt + + async def _ask_judge(self, prompt: str) -> str: + """ + Queries a judge LLM (e.g., GPT-4) with the prompt. Stubbed here. + """ + # TODO: Implement actual LLM call to get response + return "" + + def _find_layer(self, model, name: str): + """Locate a module by its dotted name.""" + current = model + for attr in name.split('.'): + if attr.isdigit(): + current = current[int(attr)] + else: + current = getattr(current, attr) + return current diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py new file mode 100644 index 00000000..f3678c9d --- /dev/null +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -0,0 +1,449 @@ +# surprisal_intervention_scorer.py +import functools +import random +import copy +from dataclasses import dataclass +from typing import Any, List, Dict, Tuple + +import torch +import torch.nn.functional as F +from transformers import AutoTokenizer + +# Assuming 'delphi' is your project structure. +# If not, you may need to adjust these relative imports. +from ..scorer import Scorer, ScorerResult +from ...latents import LatentRecord, ActivatingExample + +@dataclass +class SurprisalInterventionResult: + """ + Detailed results from the SurprisalInterventionScorer. + + Attributes: + score: The final computed score. + avg_kl: The average KL divergence between the clean and intervened next-token distributions. + explanation: The explanation string that was scored. + """ + score: float + avg_kl: float + explanation: str + + +class SurprisalInterventionScorer(Scorer): + """ + Implements the Surprisal / Log-Probability Intervention Scorer. + + This scorer evaluates an explanation for a model's latent feature by measuring + how much an intervention in the feature's direction increases the model's belief + (log-probability) in the explanation. The change in log-probability is normalized + by the intervention's strength, measured by the KL divergence between the clean + and intervened next-token distributions. + + Reference: Paulo et al., "Automatically Interpreting Millions of Features in Large Language Models" + (https://arxiv.org/pdf/2410.13928), Section 3.3.5[cite: 206, 207]. + + Pipeline: + 1. For a small set of activating prompts: + a. Generate a continuation and get the next-token distribution ("clean"). + b. Add a directional vector for the feature to the activations and repeat ("intervened"). + 2. Compute the log-probability of the explanation conditioned on both the clean + and intervened generated texts: log P(explanation | text)[cite: 209]. + 3. Compute the KL divergence between the clean and intervened next-token distributions[cite: 216]. + 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: + score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. 
+ """ + name = "surprisal_intervention" + + def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): + """ + Args: + subject_model: The language model to generate from and score with. + explainer_model: An optional model (e.g., an SAE) used to get feature directions. + **kwargs: Configuration options. + strength (float): The magnitude of the intervention. Default: 5.0. + num_prompts (int): Number of activating examples to test. Default: 3. + max_new_tokens (int): Max tokens to generate for continuations. Default: 20. + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. + """ + self.subject_model = subject_model + self.explainer_model = explainer_model + self.strength = float(kwargs.get("strength", 5.0)) + self.num_prompts = int(kwargs.get("num_prompts", 3)) + self.max_new_tokens = int(kwargs.get("max_new_tokens", 20)) + self.hookpoints = kwargs.get("hookpoints") + + if len(self.hookpoints): + self.hookpoint_str = self.hookpoints[0] + + # Ensure tokenizer is available + if hasattr(subject_model, "tokenizer"): + self.tokenizer = subject_model.tokenizer + else: + # Fallback to a standard tokenizer if not attached to the model + self.tokenizer = AutoTokenizer.from_pretrained("gpt2") + + if self.tokenizer.pad_token is None: + self.tokenizer.pad_token = self.tokenizer.eos_token + self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id + + def _get_device(self) -> torch.device: + """Safely gets the device of the subject model.""" + try: + return next(self.subject_model.parameters()).device + except StopIteration: + return torch.device("cuda" if torch.cuda.is_available() else "cpu") + + def _find_layer(self, model: Any, name: str) -> torch.nn.Module: + """Resolves a module by its dotted path name.""" + if name is None: + raise ValueError("Hookpoint name is not configured.") + current = model + for part in name.split("."): + if part.isdigit(): + current = current[int(part)] + else: + current = getattr(current, part) + return current + + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: + """ + Dynamically finds the correct model prefix and resolves the full hookpoint path. + + This makes the scorer agnostic to different transformer architectures. + """ + parts = hookpoint_str.split('.') + + # 1. Validate the string format. + is_valid_format = ( + len(parts) == 3 and + parts[0] in ['layers', 'h'] and + parts[1].isdigit() and + parts[2] in ['mlp', 'attention', 'attn'] + ) + + if not is_valid_format: + # Fallback for simple block types at the top level, e.g. 'embed_in' + if len(parts) == 1 and hasattr(model, hookpoint_str): + return getattr(model, hookpoint_str) + raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + # --- End of changes --- + + # 2. Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(model, p): + candidate_body = getattr(model, p) + # Use parts[0] to get the layer block name ('layers' or 'h') + if hasattr(candidate_body, parts[0]): + prefix = p + break + + full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + + # 3. Use the simple path finder to get the module. + try: + return self._find_layer(model, full_path) + except AttributeError as e: + raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. 
Original error: {e}") + + + + + # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + # """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'.""" + # sanitized = [] + # for ex in examples: + # if isinstance(ex, dict) and "str_tokens" in ex: + # sanitized.append(ex) + # elif hasattr(ex, "str_tokens"): + # sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]}) + # elif isinstance(ex, str): + # sanitized.append({"str_tokens": [ex]}) + # elif isinstance(ex, (list, tuple)): + # sanitized.append({"str_tokens": [str(t) for t in ex]}) + # else: + # sanitized.append({"str_tokens": [str(ex)]}) + # return sanitized + + + def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + sanitized = [] + for ex in examples: + # --- NEW, MORE ROBUST LOGIC --- + # 1. Prioritize handling objects that have the data we need (like ActivatingExample) + if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: + # This correctly handles ActivatingExample objects and similar structures. + # It extracts the string tokens instead of converting the whole object to a string. + sanitized.append({'str_tokens': ex.str_tokens}) + + # 2. Handle cases where the item is already a correct dictionary + elif isinstance(ex, dict) and "str_tokens" in ex: + sanitized.append(ex) + + # 3. Handle plain strings + elif isinstance(ex, str): + sanitized.append({"str_tokens": [ex]}) + + # 4. Handle lists/tuples of strings as a fallback + elif isinstance(ex, (list, tuple)): + sanitized.append({"str_tokens": [str(t) for t in ex]}) + + # 5. Handle any other unexpected type as a last resort + else: + sanitized.append({"str_tokens": [str(ex)]}) + + return sanitized + + + # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + + # sanitized = [] + # for i, ex in enumerate(examples): + + + # if isinstance(ex, dict) and "str_tokens" in ex: + # sanitized.append(ex) + + + # elif isinstance(ex, str): + # # This is the key conversion + # converted_ex = {"str_tokens": [ex]} + # sanitized.append(converted_ex) + + + # elif isinstance(ex, (list, tuple)): + # converted_ex = {"str_tokens": [str(t) for t in ex]} + # sanitized.append(converted_ex) + + # else: + # converted_ex = {"str_tokens": [str(ex)]} + # sanitized.append(converted_ex) + + # print("fin this") + # return sanitized + + async def __call__(self, record: LatentRecord) -> ScorerResult: + # --- MODIFICATION START --- + # 1. Create a deep copy to work on, ensuring we don't interfere + # with other parts of the pipeline that might use the original record. + record_copy = copy.deepcopy(record) + + # 2. Read the raw examples from our copy. + raw_examples = getattr(record_copy, "test", []) or [] + + if not raw_examples: + result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) + # Return the result with the original record since no changes were made. + return ScorerResult(record=record, score=result) + + # 3. Sanitize the examples. + examples = self._sanitize_examples(raw_examples) + + # 4. Overwrite the attributes on the copy with the clean data. + record_copy.test = examples + record_copy.examples = examples + record_copy.train = examples + + # Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations. + prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] + + total_diff = 0.0 + total_kl = 0.0 + n = 0 + + for prompt in prompts: + # Pass the clean record_copy to the generation methods. 
+ clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) + int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) + + logp_clean = await self._score_explanation(clean_text, record_copy.explanation) + logp_int = await self._score_explanation(int_text, record_copy.explanation) + + p_clean = torch.exp(clean_logp_dist) + kl_div = F.kl_div(int_logp_dist, p_clean, reduction='sum', log_target=False).item() + + total_diff += logp_int - logp_clean + total_kl += kl_div + n += 1 + + avg_diff = total_diff / n if n > 0 else 0.0 + avg_kl = total_kl / n if n > 0 else 0.0 + final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 + + final_output_list = [] + for ex in examples[:self.num_prompts]: + final_output_list.append({ + "str_tokens": ex["str_tokens"], + # Add the final scores. These will be duplicated for each example. + "final_score": final_score, + "avg_kl_divergence": avg_kl, + # Add placeholder keys that the parser expects, with default values. + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + }) + return ScorerResult(record=record_copy, score=final_output_list) + + async def _generate_with_and_without_intervention( + self, prompt: str, record: LatentRecord, intervene: bool + ) -> Tuple[str, torch.Tensor]: + """ + Generates a text continuation and returns the next-token log-probabilities. + + If `intervene` is True, it adds a feature direction to the activations at the + specified hookpoint before generation. + + Returns: + A tuple containing: + - The generated text (string). + - The log-probability distribution for the token immediately following the prompt (Tensor). + """ + device = self._get_device() + enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) + input_ids = enc["input_ids"].to(device) + + hooks = [] + if intervene: + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + if hookpoint_str is None: + raise ValueError("No hookpoint string specified for intervention.") + + # Resolve the string into the actual layer module. + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + + direction = self._get_intervention_direction(record).to(device) + direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] + + def hook_fn(module, inp, out): + # Gracefully handle both tuple and tensor outputs + hidden_states = out[0] if isinstance(out, tuple) else out + + # Apply intervention to the last token's hidden state + hidden_states[:, -1:, :] += self.strength * direction + + # Return the modified activations in their original format + if isinstance(out, tuple): + return (hidden_states,) + out[1:] + return hidden_states + + hooks.append(layer_to_hook.register_forward_hook(hook_fn)) + + try: + with torch.no_grad(): + # 1. Get next-token logits for KL divergence calculation + outputs = self.subject_model(input_ids) + next_token_logits = outputs.logits[0, -1, :] + log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) + + # 2. 
Generate the full text continuation + gen_ids = self.subject_model.generate( + input_ids, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, + ) + generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) + finally: + for h in hooks: + h.remove() + + return generated_text, log_probs_next_token.cpu() + + async def _score_explanation(self, generated_text: str, explanation: str) -> float: + """Computes log P(explanation | generated_text) under the subject model.""" + device = self._get_device() + + # Create the full input sequence: context + explanation + context_enc = self.tokenizer(generated_text, return_tensors="pt") + explanation_enc = self.tokenizer(explanation, return_tensors="pt") + + full_input_ids = torch.cat([context_enc.input_ids, explanation_enc.input_ids], dim=1).to(device) + + with torch.no_grad(): + outputs = self.subject_model(full_input_ids) + logits = outputs.logits + + # We only need to score the explanation part + context_len = context_enc.input_ids.shape[1] + # Get logits for positions that predict the explanation tokens + explanation_logits = logits[:, context_len - 1:-1, :] + + # Get the target token IDs for the explanation + target_ids = explanation_enc.input_ids.to(device) + + log_probs = F.log_softmax(explanation_logits, dim=-1) + + # Gather the log-probabilities of the actual explanation tokens + token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) + + return token_log_probs.sum().item() + + def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + """ + Gets the feature direction vector, preferring an SAE if available, + otherwise falling back to estimating it from activations. + """ + # --- Fast Path: Try to get vector from an SAE-like explainer model --- + if self.explainer_model: + sae = None + candidate = self.explainer_model + if isinstance(self.explainer_model, dict): + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + candidate = self.explainer_model.get(hookpoint_str) + + if hasattr(candidate, 'get_feature_vector'): + sae = candidate + elif hasattr(candidate, 'sae') and hasattr(candidate.sae, 'get_feature_vector'): + sae = candidate.sae + + if sae: + direction = sae.get_feature_vector(record.feature_id) + if not isinstance(direction, torch.Tensor): + direction = torch.tensor(direction, dtype=torch.float32) + direction = direction.squeeze() + return F.normalize(direction, p=2, dim=0) + + # --- Fallback: Estimate direction from activating examples --- + return self._estimate_direction_from_examples(record) + + def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: + """Estimates an intervention direction by averaging activations.""" + device = self._get_device() + + examples = self._sanitize_examples(getattr(record, "test", []) or []) + if not examples: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + captured_activations = [] + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + + # Now, hidden_states is guaranteed to be the 3D activation tensor + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) + + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + for ex in examples[:min(8, self.num_prompts)]: + prompt = 
"".join(ex["str_tokens"]) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_activations: + hidden_dim = self.subject_model.config.hidden_size + return torch.zeros(hidden_dim, device=device) + + activations = torch.cat(captured_activations, dim=0).to(device) + direction = activations.mean(dim=0) + + return F.normalize(direction, p=2, dim=0) \ No newline at end of file From 0ad4424ef88ec5b17c21bb424f0c6937c02d3556 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Thu, 28 Aug 2025 20:44:05 +0000 Subject: [PATCH 02/25] Add metrics for surprisal_intervention --- delphi/log/result_analysis.py | 326 ++++++++++++---------------------- 1 file changed, 116 insertions(+), 210 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9937bd96..4af7030a 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -14,7 +14,17 @@ def plot_firing_vs_f1( ) -> None: out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): + + if 'firing_count' not in module_df.columns: + print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + continue + module_df = module_df.copy() + # Filter out rows where f1_score is NaN to avoid errors in plotting + module_df = module_df[module_df['f1_score'].notna()] + if module_df.empty: + continue + module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) fig.update_layout( @@ -26,30 +36,32 @@ def plot_firing_vs_f1( def import_plotly(): """Import plotly with mitigiation for MathJax bug.""" try: - import plotly.express as px # type: ignore - import plotly.io as pio # type: ignore + import plotly.express as px + import plotly.io as pio except ImportError: raise ImportError( "Plotly is not installed.\n" "Please install it using `pip install plotly`, " "or install the `[visualize]` extra." 
) - pio.kaleido.scope.mathjax = None # https://github.com/plotly/plotly.py/issues/3469 + pio.kaleido.scope.mathjax = None return px def compute_auc(df: pd.DataFrame) -> float | None: - if not df.probability.nunique(): - return None - + # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] - - return roc_auc_score(valid_df.activating, valid_df.probability) # type: ignore + if valid_df.probability.nunique() <= 1: + return None + return roc_auc_score(valid_df.activating, valid_df.probability) def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): + # Filter out surprisal_intervention as 'accuracy' is not relevant for it + if label == 'surprisal_intervention': + continue fig = px.histogram( df[df["score_type"] == label], x="accuracy", @@ -60,11 +72,10 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - if not df.probability.nunique(): - return - - # filter out NANs + # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] + if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: + return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) @@ -85,67 +96,41 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: df_valid = df[df["prediction"].notna()] - act = df_valid["activating"].astype(bool) + if df_valid.empty: + return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, + total_examples=0, total_positives=0, total_negatives=0, failed_count=len(df)) + act = df_valid["activating"].astype(bool) total = len(df_valid) pos = act.sum() neg = total - pos - tp = ((df_valid.prediction >= threshold) & act).sum() tn = ((df_valid.prediction < threshold) & ~act).sum() fp = ((df_valid.prediction >= threshold) & ~act).sum() fn = ((df_valid.prediction < threshold) & act).sum() - assert fp <= neg and tn <= neg and tp <= pos and fn <= pos - return dict( - true_positives=tp, - true_negatives=tn, - false_positives=fp, - false_negatives=fn, - total_examples=total, - total_positives=pos, - total_negatives=neg, - failed_count=len(df_valid) - total, + true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, + total_examples=total, total_positives=pos, total_negatives=neg, + failed_count=len(df) - len(df_valid), ) def compute_classification_metrics(conf: dict) -> dict: - tp = conf["true_positives"] - tn = conf["true_negatives"] - fp = conf["false_positives"] - fn = conf["false_negatives"] - total = conf["total_examples"] - pos = conf["total_positives"] - neg = conf["total_negatives"] - - assert pos + neg == total, "pos + neg must equal total" - - # accuracy = (tp + tn) / total if total > 0 else 0 - balanced_accuracy = ( - (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) - ) / 2 - + tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + pos, neg = conf["total_positives"], conf["total_negatives"] + + balanced_accuracy = ((tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0)) / 2 precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / pos if pos > 0 else 0 - f1 = ( - 2 * (precision * recall) / (precision + recall) if 
precision + recall > 0 else 0 - ) + f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 return dict( - precision=precision, - recall=recall, - f1_score=f1, - accuracy=balanced_accuracy, + precision=precision, recall=recall, f1_score=f1, accuracy=balanced_accuracy, true_positive_rate=tp / pos if pos > 0 else 0, true_negative_rate=tn / neg if neg > 0 else 0, false_positive_rate=fp / neg if neg > 0 else 0, false_negative_rate=fn / pos if pos > 0 else 0, - total_examples=total, - total_positives=pos, - total_negatives=neg, - positive_class_ratio=pos / total if total > 0 else 0, - negative_class_ratio=neg / total if total > 0 else 0, ) @@ -153,27 +138,32 @@ def load_data(scores_path: Path, modules: list[str]): """Load all on-disk data into a single DataFrame.""" def parse_score_file(path: Path) -> pd.DataFrame: - """ - Load a score file and return a raw DataFrame - """ try: data = orjson.loads(path.read_bytes()) except orjson.JSONDecodeError: print(f"Error decoding JSON from {path}. Skipping file.") return pd.DataFrame() + + if not isinstance(data, list): + print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) + # --- MODIFICATION 1: PARSE THE NEW METRICS --- + # Updated to extract all possible keys safely using .get() return pd.DataFrame( [ { - "text": "".join(ex["str_tokens"]), - "distance": ex["distance"], - "activating": ex["activating"], - "prediction": ex["prediction"], - "probability": ex["probability"], - "correct": ex["correct"], - "activations": ex["activations"], + "text": "".join(ex.get("str_tokens", [])), + "distance": ex.get("distance"), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "correct": ex.get("correct"), + "activations": ex.get("activations"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), "latent_idx": latent_idx, } for ex in data @@ -187,197 +177,113 @@ def parse_score_file(path: Path) -> pd.DataFrame: print(f"Missing modules: {[m for m in modules if m not in counts]}") counts = None - # Collect per-latent data latent_dfs = [] for score_type_dir in scores_path.iterdir(): if not score_type_dir.is_dir(): continue for module in modules: for file in score_type_dir.glob(f"*{module}*"): - latent_idx = int(file.stem.split("latent")[-1]) - latent_df = parse_score_file(file) + if latent_df.empty: + continue latent_df["score_type"] = score_type_dir.name latent_df["module"] = module - latent_df["latent_idx"] = latent_idx if counts: + latent_idx = latent_df["latent_idx"].iloc[0] latent_df["firing_count"] = ( counts[module][latent_idx].item() - if latent_idx in counts[module] + if module in counts and latent_idx in counts[module] else None ) - latent_dfs.append(latent_df) + if not latent_dfs: + return pd.DataFrame(), counts + return pd.concat(latent_dfs, ignore_index=True), counts -def frequency_weighted_f1( - df: pd.DataFrame, counts: dict[str, torch.Tensor] -) -> float | None: - rows = [] - for (module, latent_idx), grp in df.groupby(["module", "latent_idx"]): - f1 = compute_classification_metrics(compute_confusion(grp))["f1_score"] - fire = counts[module][latent_idx].item() - rows.append( - { - "module": module, - "latent_idx": latent_idx, - "f1_score": f1, - "firing_count": fire, - } - ) - - latent_df = pd.DataFrame(rows) - - per_module_f1 = [] - for module in latent_df["module"].unique(): - module_df = 
latent_df[latent_df["module"] == module] - - firing_weights = counts[module][module_df["latent_idx"]].float() - total_weight = firing_weights.sum() - if total_weight == 0: - continue - - f1_tensor = torch.as_tensor(module_df["f1_score"].values, dtype=torch.float32) - module_f1 = (f1_tensor * firing_weights).sum() / firing_weights.sum() - per_module_f1.append(module_f1) - - overall_frequency_weighted_f1 = torch.stack(per_module_f1).mean() - return ( - overall_frequency_weighted_f1.item() - if not overall_frequency_weighted_f1.isnan() - else None - ) - - def get_agg_metrics( latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] ) -> pd.DataFrame: processed_rows = [] for score_type, group_df in latent_df.groupby("score_type"): + # For surprisal_intervention, we don't compute classification metrics + if score_type == 'surprisal_intervention': + continue + conf = compute_confusion(group_df) class_m = compute_classification_metrics(conf) auc = compute_auc(group_df) f1_w = frequency_weighted_f1(group_df, counts) if counts else None - + row = { "score_type": score_type, - **conf, - **class_m, - "auc": auc, - "weighted_f1": f1_w, + **conf, **class_m, "auc": auc, "weighted_f1": f1_w } processed_rows.append(row) return pd.DataFrame(processed_rows) -def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: - f1s = ( - latent_df.groupby(["module", "latent_idx"]) - .apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ) - .reset_index(name="f1_score") # <- naive (un-weighted) F1 - ) - return latent_df.merge(f1s, on=["module", "latent_idx"]) - - def log_results( scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] ): import_plotly() latent_df, counts = load_data(scores_path, modules) - latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - latent_df = add_latent_f1(latent_df) - - plot_firing_vs_f1( - latent_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name - ) - if latent_df.empty: - print("No data found") + print("No data to analyze.") return - - dead = sum((counts[m] == 0).sum().item() for m in modules) - print(f"Number of dead features: {dead}") - print(f"Number of interpreted live features: {len(latent_df)}") - - # Load constructor config for run - with open(scores_path.parent / "run_config.json", "r") as f: - run_cfg = orjson.loads(f.read()) - constructor_cfg = run_cfg.get("constructor_cfg", {}) - min_examples = constructor_cfg.get("min_examples", None) - print("min examples", min_examples) - - if min_examples is not None: - uninterpretable_features = sum( - [(counts[m] < min_examples).sum() for m in modules] - ) - print( - f"Number of features below the interpretation firing" - f" count threshold: {uninterpretable_features}" - ) - - plot_roc_curve(latent_df, viz_path) - - processed_df = get_agg_metrics(latent_df, counts) - - plot_accuracy_hist(processed_df, viz_path) - - for score_type in processed_df.score_type.unique(): - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") - print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - print( - "Note: the frequency-weighted F1 score is computed over each" - " hookpoint and averaged" - ) - print(f"Precision: {score_type_summary['precision']:.3f}") - print(f"Recall: {score_type_summary['recall']:.3f}") - # Only 
print AUC if unbalanced AUC is not -1. - if score_type_summary["auc"] is not None: - print(f"AUC: {score_type_summary['auc']:.3f}") + + latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] + + # Separate the dataframes for different processing + classification_df = latent_df[latent_df['score_type'] != 'surprisal_intervention'] + surprisal_df = latent_df[latent_df['score_type'] == 'surprisal_intervention'] + + if not classification_df.empty: + classification_df = add_latent_f1(classification_df) + if counts: + plot_firing_vs_f1(classification_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name) + plot_roc_curve(classification_df, viz_path) + processed_df = get_agg_metrics(classification_df, counts) + plot_accuracy_hist(processed_df, viz_path) + + if counts: + dead = sum((counts[m] == 0).sum().item() for m in modules) + print(f"Number of dead features: {dead}") + + # --- MODIFICATION 2: ADD CONDITIONAL REPORTING --- + # Loop through all scorer types found in the data + for score_type in latent_df["score_type"].unique(): + + # Handle the new scorer with its specific metrics + if score_type == 'surprisal_intervention': + # Drop duplicates since score is per-latent, not per-example + unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) + avg_score = unique_latents['final_score'].mean() + avg_kl = unique_latents['avg_kl_divergence'].mean() + + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Average Normalized Score: {avg_score:.3f}") + print(f"Average KL Divergence: {avg_kl:.3f}") + + # Handle all other scorers with the original classification metrics else: - print("Logits not available.") - - fractions_failed = [ - score_type_summary["failed_count"] - / ( - ( - score_type_summary["total_examples"] - + score_type_summary["failed_count"] - ) - ) - ] - print( - f"""Average fraction of failed examples: \ -{sum(fractions_failed) / len(fractions_failed)}""" - ) - - print("\nConfusion Matrix:") - print( - f"True Positive Rate: {score_type_summary['true_positive_rate']:.3f} " - f"({score_type_summary['true_positives'].sum()})" - ) - print( - f"True Negative Rate: {score_type_summary['true_negative_rate']:.3f} " - f"({score_type_summary['true_negatives'].sum()})" - ) - print( - f"False Positive Rate: {score_type_summary['false_positive_rate']:.3f} " - f"({score_type_summary['false_positives'].sum()})" - ) - print( - f"False Negative Rate: {score_type_summary['false_negative_rate']:.3f} " - f"({score_type_summary['false_negatives'].sum()})" - ) - - print("\nClass Distribution:") - print(f"""Positives: {score_type_summary['total_positives'].sum():.0f}""") - print(f"""Negatives: {score_type_summary['total_negatives'].sum():.0f}""") - print(f"Total: {score_type_summary['total_examples'].sum():.0f}") + if not classification_df.empty: + score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] + print(f"\n--- {score_type.title()} Metrics ---") + print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") + print(f"F1 Score: {score_type_summary['f1_score']:.3f}") + + if counts and score_type_summary['weighted_f1'] is not None: + print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") + + print(f"Precision: {score_type_summary['precision']:.3f}") + print(f"Recall: {score_type_summary['recall']:.3f}") + + if score_type_summary["auc"] is not None: + print(f"AUC: {score_type_summary['auc']:.3f}") + else: + print("AUC not available.") \ No newline at end of file From 
aa12cf23c3e09e7c9534ee4475c75b008fb15a1d Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 14:15:20 +0000 Subject: [PATCH 03/25] Code cleaning --- delphi/__main__.py | 19 +---- delphi/config.py | 4 +- delphi/log/result_analysis.py | 6 +- .../surprisal_intervention_scorer.py | 82 ++----------------- 4 files changed, 9 insertions(+), 102 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 46111bc1..7a8fd399 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -30,7 +30,7 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, InterventionScorer, LogProbInterventionScorer, SurprisalInterventionScorer +from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, SurprisalInterventionScorer from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -252,8 +252,6 @@ def scorer_postprocess(result, score_dir, scorer_name=None): safe_latent_name = str(result.record.latent).replace("/", "--") with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: - # This line now works universally. For other scorers, it saves their simple - # score. For surprisal_intervention, it saves the rich 'final_payload'. f.write(orjson.dumps(result.score, default=custom_serializer)) @@ -278,20 +276,7 @@ def scorer_postprocess(result, score_dir, scorer_name=None): verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, ) - elif scorer_name == "intervention": - scorer = InterventionScorer( - llm_client, - n_examples_shown=run_cfg.num_examples_per_scorer_prompt, - verbose=run_cfg.verbose, - log_prob=run_cfg.log_probs, - ) - elif scorer_name == "logprob_intervention": - scorer = LogProbInterventionScorer( - llm_client, - n_examples_shown=run_cfg.num_examples_per_scorer_prompt, - verbose=run_cfg.verbose, - log_prob=run_cfg.log_probs, - ) + elif scorer_name == "surprisal_intervention": scorer = SurprisalInterventionScorer( model, diff --git a/delphi/config.py b/delphi/config.py index 0d2193c5..9d54c26e 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -152,8 +152,6 @@ class RunConfig(Serializable): "fuzz", "detection", "simulation", - "intervention", - "logprob_intervention", "surprisal_intervention" ], default=[ @@ -162,7 +160,7 @@ class RunConfig(Serializable): ], ) """Scorer methods to score latent explanations. Options are 'fuzz', 'detection', - 'simulation' and 'intervention'.""" + 'simulation' and 'surprisal_intervention'.""" name: str = "" """The name of the run. 
Results are saved in a directory with this name.""" diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 4af7030a..99666acb 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -150,7 +150,6 @@ def parse_score_file(path: Path) -> pd.DataFrame: latent_idx = int(path.stem.split("latent")[-1]) - # --- MODIFICATION 1: PARSE THE NEW METRICS --- # Updated to extract all possible keys safely using .get() return pd.DataFrame( [ @@ -254,11 +253,9 @@ def log_results( dead = sum((counts[m] == 0).sum().item() for m in modules) print(f"Number of dead features: {dead}") - # --- MODIFICATION 2: ADD CONDITIONAL REPORTING --- - # Loop through all scorer types found in the data + for score_type in latent_df["score_type"].unique(): - # Handle the new scorer with its specific metrics if score_type == 'surprisal_intervention': # Drop duplicates since score is per-latent, not per-example unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) @@ -269,7 +266,6 @@ def log_results( print(f"Average Normalized Score: {avg_score:.3f}") print(f"Average KL Divergence: {avg_kl:.3f}") - # Handle all other scorers with the original classification metrics else: if not classification_df.empty: score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index f3678c9d..c683cca9 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,4 +1,3 @@ -# surprisal_intervention_scorer.py import functools import random import copy @@ -9,8 +8,6 @@ import torch.nn.functional as F from transformers import AutoTokenizer -# Assuming 'delphi' is your project structure. -# If not, you may need to adjust these relative imports. from ..scorer import Scorer, ScorerResult from ...latents import LatentRecord, ActivatingExample @@ -75,11 +72,9 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): if len(self.hookpoints): self.hookpoint_str = self.hookpoints[0] - # Ensure tokenizer is available if hasattr(subject_model, "tokenizer"): self.tokenizer = subject_model.tokenizer else: - # Fallback to a standard tokenizer if not attached to the model self.tokenizer = AutoTokenizer.from_pretrained("gpt2") if self.tokenizer.pad_token is None: @@ -113,7 +108,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: """ parts = hookpoint_str.split('.') - # 1. Validate the string format. is_valid_format = ( len(parts) == 3 and parts[0] in ['layers', 'h'] and @@ -122,129 +116,68 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: ) if not is_valid_format: - # Fallback for simple block types at the top level, e.g. 'embed_in' if len(parts) == 1 and hasattr(model, hookpoint_str): return getattr(model, hookpoint_str) raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") - # --- End of changes --- - # 2. Heuristically find the model prefix. + #Heuristically find the model prefix. prefix = None for p in ["gpt_neox", "transformer", "model"]: if hasattr(model, p): candidate_body = getattr(model, p) - # Use parts[0] to get the layer block name ('layers' or 'h') if hasattr(candidate_body, parts[0]): prefix = p break full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str - # 3. 
Use the simple path finder to get the module. try: return self._find_layer(model, full_path) except AttributeError as e: raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - - - - # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: - # """Ensures examples are in a consistent format: a list of dictionaries with 'str_tokens'.""" - # sanitized = [] - # for ex in examples: - # if isinstance(ex, dict) and "str_tokens" in ex: - # sanitized.append(ex) - # elif hasattr(ex, "str_tokens"): - # sanitized.append({"str_tokens": [str(t) for t in ex.str_tokens]}) - # elif isinstance(ex, str): - # sanitized.append({"str_tokens": [ex]}) - # elif isinstance(ex, (list, tuple)): - # sanitized.append({"str_tokens": [str(t) for t in ex]}) - # else: - # sanitized.append({"str_tokens": [str(ex)]}) - # return sanitized - def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: + """ + Function used for formatting results to run smoothly in the delphi pipeline + """ sanitized = [] for ex in examples: - # --- NEW, MORE ROBUST LOGIC --- - # 1. Prioritize handling objects that have the data we need (like ActivatingExample) if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: - # This correctly handles ActivatingExample objects and similar structures. - # It extracts the string tokens instead of converting the whole object to a string. sanitized.append({'str_tokens': ex.str_tokens}) - # 2. Handle cases where the item is already a correct dictionary elif isinstance(ex, dict) and "str_tokens" in ex: sanitized.append(ex) - # 3. Handle plain strings elif isinstance(ex, str): sanitized.append({"str_tokens": [ex]}) - # 4. Handle lists/tuples of strings as a fallback elif isinstance(ex, (list, tuple)): sanitized.append({"str_tokens": [str(t) for t in ex]}) - # 5. Handle any other unexpected type as a last resort else: sanitized.append({"str_tokens": [str(ex)]}) return sanitized - # def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: - - # sanitized = [] - # for i, ex in enumerate(examples): - - - # if isinstance(ex, dict) and "str_tokens" in ex: - # sanitized.append(ex) - - - # elif isinstance(ex, str): - # # This is the key conversion - # converted_ex = {"str_tokens": [ex]} - # sanitized.append(converted_ex) - - - # elif isinstance(ex, (list, tuple)): - # converted_ex = {"str_tokens": [str(t) for t in ex]} - # sanitized.append(converted_ex) - - # else: - # converted_ex = {"str_tokens": [str(ex)]} - # sanitized.append(converted_ex) - - # print("fin this") - # return sanitized async def __call__(self, record: LatentRecord) -> ScorerResult: - # --- MODIFICATION START --- - # 1. Create a deep copy to work on, ensuring we don't interfere - # with other parts of the pipeline that might use the original record. + record_copy = copy.deepcopy(record) - # 2. Read the raw examples from our copy. raw_examples = getattr(record_copy, "test", []) or [] if not raw_examples: result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) - # Return the result with the original record since no changes were made. return ScorerResult(record=record, score=result) - # 3. Sanitize the examples. examples = self._sanitize_examples(raw_examples) - # 4. Overwrite the attributes on the copy with the clean data. 
record_copy.test = examples record_copy.examples = examples record_copy.train = examples - # Now, use the sanitized 'examples' and the 'record_copy' for all subsequent operations. prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] total_diff = 0.0 @@ -252,7 +185,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: n = 0 for prompt in prompts: - # Pass the clean record_copy to the generation methods. clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) @@ -274,7 +206,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: for ex in examples[:self.num_prompts]: final_output_list.append({ "str_tokens": ex["str_tokens"], - # Add the final scores. These will be duplicated for each example. "final_score": final_score, "avg_kl_divergence": avg_kl, # Add placeholder keys that the parser expects, with default values. @@ -312,14 +243,12 @@ async def _generate_with_and_without_intervention( if hookpoint_str is None: raise ValueError("No hookpoint string specified for intervention.") - # Resolve the string into the actual layer module. layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) direction = self._get_intervention_direction(record).to(device) direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] def hook_fn(module, inp, out): - # Gracefully handle both tuple and tensor outputs hidden_states = out[0] if isinstance(out, tuple) else out # Apply intervention to the last token's hidden state @@ -423,7 +352,6 @@ def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tenso def capture_hook(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - # Now, hidden_states is guaranteed to be the 3D activation tensor captured_activations.append(hidden_states[:, -1, :].detach().cpu()) hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From bbf915a005d5367a7063b96ff05b0e452e60d2d9 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:01:03 +0000 Subject: [PATCH 04/25] Fix results --- delphi/log/result_analysis.py | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 99666acb..bffa7f6b 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -225,6 +225,17 @@ def get_agg_metrics( return pd.DataFrame(processed_rows) +def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: + f1s = ( + latent_df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="f1_score") # <- naive (un-weighted) F1 + ) + return latent_df.merge(f1s, on=["module", "latent_idx"]) + + def log_results( scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] ): From 90e8a34daf75acffb3f1d2c096b362e3087728d4 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:19:11 +0000 Subject: [PATCH 05/25] Remove output-based intervention --- .../output_based_intervention_scorer.py | 141 ------------------ 1 file changed, 141 deletions(-) delete mode 100644 delphi/scorers/intervention/output_based_intervention_scorer.py diff --git a/delphi/scorers/intervention/output_based_intervention_scorer.py b/delphi/scorers/intervention/output_based_intervention_scorer.py deleted 
file mode 100644 index 9c706962..00000000 --- a/delphi/scorers/intervention/output_based_intervention_scorer.py +++ /dev/null @@ -1,141 +0,0 @@ -# Output-based intervention scorer (Gur-Arieh et al. 2025) -from dataclasses import dataclass -import torch -import torch.nn.functional as F -import random -from ...scorer import Scorer, ScorerResult -from ...latents import LatentRecord, ActivatingExample -from transformers import PreTrainedModel - -@dataclass -class OutputInterventionResult: - """Result of output-based intervention evaluation.""" - score: int # +1 if target set chosen, -1 otherwise - explanation: str - example_text: str - -class OutputInterventionScorer(Scorer): - """ - Output-based evaluation by steering (clamping) the feature and using a judge LLM - to pick which outputs best match the description:contentReference[oaicite:5]{index=5}. - We generate texts for the target feature and for a few random features, - then ask the judge to choose the matching set. - """ - name = "output_intervention" - - def __init__(self, subject_model: PreTrainedModel, explainer_model, **kwargs): - self.subject_model = subject_model - self.explainer_model = explainer_model - self.steering_strength = kwargs.get("strength", 5.0) - self.num_prompts = kwargs.get("num_prompts", 3) - self.num_random = kwargs.get("num_random_features", 2) - self.hookpoint = kwargs.get("hookpoint", "transformer.h.6.mlp") - self.tokenizer = getattr(subject_model, "tokenizer", None) - - async def __call__(self, record: LatentRecord) -> ScorerResult: - # Prepare activating prompts - examples = [ex for ex in record.test if isinstance(ex, ActivatingExample)] - random.shuffle(examples) - prompts = ["".join(str(t) for t in ex.str_tokens) for ex in examples[:self.num_prompts]] - - # Generate text for the target feature - target_texts = [] - for p in prompts: - text, _ = await self._generate(p, record.feature_id, self.steering_strength) - target_texts.append(text) - - # Pick a few random feature IDs (avoid the target) - random_ids = [] - while len(random_ids) < self.num_random: - rid = random.randint(0, 999) - if rid != record.feature_id: - random_ids.append(rid) - - # Generate texts for random features - random_sets = [] - for fid in random_ids: - rand_texts = [] - for p in prompts: - text, _ = await self._generate(p, fid, self.steering_strength) - rand_texts.append(text) - random_sets.append(rand_texts) - - # Create prompt for judge LLM - judge_prompt = self._format_judge_prompt(record.explanation, target_texts, random_sets) - judge_response = await self._ask_judge(judge_prompt) - - # Parse judge response: check if target set was chosen - resp_lower = judge_response.lower() - if "target" in resp_lower or "set 1" in resp_lower: - score = 1 - elif "set 2" in resp_lower or "set 3" in resp_lower or "random" in resp_lower: - score = -1 - else: - score = 0 - - example_text = prompts[0] if prompts else "" - detailed = OutputInterventionResult( - score=score, - explanation=record.explanation, - example_text=example_text - ) - return ScorerResult(record=record, score=detailed) - - async def _generate(self, prompt: str, feature_id: int, strength: float): - """ - Generates text with the feature clamped (added to hidden state). - Returns the (partial) generated text and logits. 
- """ - tokenizer = self.tokenizer or __import__("transformers").AutoTokenizer.from_pretrained("gpt2") - input_ids = tokenizer(prompt, return_tensors="pt").input_ids - - # Forward hook to clamp feature activation - direction = self.explainer_model.get_feature_vector(feature_id) - def hook_fn(module, inp, out): - out[:, -1, :] = out[:, -1, :] + strength * direction.to(out.device) - return out - layer = self._find_layer(self.subject_model, self.hookpoint) - handle = layer.register_forward_hook(hook_fn) - - with torch.no_grad(): - outputs = self.subject_model(input_ids) - logits = outputs.logits[0, -1, :] - log_probs = F.log_softmax(logits, dim=-1) - handle.remove() - - text = tokenizer.decode(input_ids[0]) - return text, log_probs - - def _format_judge_prompt(self, explanation: str, target_texts: list, other_sets: list): - """ - Constructs a prompt for the judge LLM listing each set of texts - under the target feature and random features. - """ - prompt = f"Feature description: \"{explanation}\"\n" - prompt += "Which of the following sets of generated texts best matches this description?\n\n" - prompt += "Set 1 (target feature):\n" - for txt in target_texts: - prompt += f"- {txt}\n" - for i, rand_set in enumerate(other_sets, start=2): - prompt += f"\nSet {i} (random feature):\n" - for txt in rand_set: - prompt += f"- {txt}\n" - prompt += "\nAnswer (mention the set number or 'target'/'random'): " - return prompt - - async def _ask_judge(self, prompt: str) -> str: - """ - Queries a judge LLM (e.g., GPT-4) with the prompt. Stubbed here. - """ - # TODO: Implement actual LLM call to get response - return "" - - def _find_layer(self, model, name: str): - """Locate a module by its dotted name.""" - current = model - for attr in name.split('.'): - if attr.isdigit(): - current = current[int(attr)] - else: - current = getattr(current, attr) - return current From 9fca7db9f5d922d62af9869338158ded361756b8 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 15:20:56 +0000 Subject: [PATCH 06/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/__main__.py | 19 +- delphi/config.py | 7 +- delphi/log/result_analysis.py | 124 ++++++++---- delphi/scorers/__init__.py | 7 +- .../surprisal_intervention_scorer.py | 177 ++++++++++-------- 5 files changed, 201 insertions(+), 133 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 7a8fd399..bf24a557 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -5,14 +5,11 @@ from pathlib import Path from typing import Callable -from dataclasses import asdict - import orjson import torch from simple_parsing import ArgumentParser from torch import Tensor from transformers import ( - AutoModel, AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, @@ -30,7 +27,12 @@ from delphi.latents.neighbours import NeighbourCalculator from delphi.log.result_analysis import log_results from delphi.pipeline import Pipe, Pipeline, process_wrapper -from delphi.scorers import DetectionScorer, FuzzingScorer, OpenAISimulator, SurprisalInterventionScorer +from delphi.scorers import ( + DetectionScorer, + FuzzingScorer, + OpenAISimulator, + SurprisalInterventionScorer, +) from delphi.sparse_coders import load_hooks_sparse_coders, load_sparse_coders from delphi.utils import assert_type, load_tokenized_data @@ -122,7 +124,7 @@ async def process_cache( tokenizer: PreTrainedTokenizer | PreTrainedTokenizerFast, latent_range: Tensor | 
None, model, - hookpoint_to_sparse_encode + hookpoint_to_sparse_encode, ): """ Converts SAE latent activations in on-disk cache in the `latents_path` directory @@ -223,7 +225,7 @@ def none_postprocessor(result): postprocess=none_postprocessor, ) ) - + def custom_serializer(obj): """A custom serializer for orjson to handle specific types.""" if isinstance(obj, Tensor): @@ -254,7 +256,6 @@ def scorer_postprocess(result, score_dir, scorer_name=None): with open(score_dir / f"{safe_latent_name}.txt", "wb") as f: f.write(orjson.dumps(result.score, default=custom_serializer)) - scorers = [] for scorer_name in run_cfg.scorers: scorer_path = scores_path / scorer_name @@ -281,7 +282,7 @@ def scorer_postprocess(result, score_dir, scorer_name=None): scorer = SurprisalInterventionScorer( model, hookpoint_to_sparse_encode, - hookpoints = run_cfg.hookpoints, + hookpoints=run_cfg.hookpoints, n_examples_shown=run_cfg.num_examples_per_scorer_prompt, verbose=run_cfg.verbose, log_prob=run_cfg.log_probs, @@ -476,7 +477,7 @@ async def run( tokenizer, latent_range, model, - hookpoint_to_sparse_encode + hookpoint_to_sparse_encode, ) del model, hookpoint_to_sparse_encode diff --git a/delphi/config.py b/delphi/config.py index 9d54c26e..b20c3324 100644 --- a/delphi/config.py +++ b/delphi/config.py @@ -148,12 +148,7 @@ class RunConfig(Serializable): the default single token explainer, and 'none' for no explanation generation.""" scorers: list[str] = list_field( - choices=[ - "fuzz", - "detection", - "simulation", - "surprisal_intervention" - ], + choices=["fuzz", "detection", "simulation", "surprisal_intervention"], default=[ "fuzz", "detection", diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index bffa7f6b..bf53da22 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -15,13 +15,15 @@ def plot_firing_vs_f1( out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + if "firing_count" not in module_df.columns: + print( + f"WARNING: 'firing_count' column not found for module {module}. Skipping plot." 
+ ) continue module_df = module_df.copy() # Filter out rows where f1_score is NaN to avoid errors in plotting - module_df = module_df[module_df['f1_score'].notna()] + module_df = module_df[module_df["f1_score"].notna()] if module_df.empty: continue @@ -60,7 +62,7 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): # Filter out surprisal_intervention as 'accuracy' is not relevant for it - if label == 'surprisal_intervention': + if label == "surprisal_intervention": continue fig = px.histogram( df[df["score_type"] == label], @@ -74,7 +76,11 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): # Filter for rows where probability is not None and there's more than one unique value valid_df = df[df.probability.notna()] - if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: + if ( + valid_df.empty + or valid_df.activating.nunique() <= 1 + or valid_df.probability.nunique() <= 1 + ): return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) @@ -97,8 +103,16 @@ def plot_roc_curve(df: pd.DataFrame, out_dir: Path): def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: df_valid = df[df["prediction"].notna()] if df_valid.empty: - return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, - total_examples=0, total_positives=0, total_negatives=0, failed_count=len(df)) + return dict( + true_positives=0, + true_negatives=0, + false_positives=0, + false_negatives=0, + total_examples=0, + total_positives=0, + total_negatives=0, + failed_count=len(df), + ) act = df_valid["activating"].astype(bool) total = len(df_valid) @@ -110,23 +124,40 @@ def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: fn = ((df_valid.prediction < threshold) & act).sum() return dict( - true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, - total_examples=total, total_positives=pos, total_negatives=neg, + true_positives=tp, + true_negatives=tn, + false_positives=fp, + false_negatives=fn, + total_examples=total, + total_positives=pos, + total_negatives=neg, failed_count=len(df) - len(df_valid), ) def compute_classification_metrics(conf: dict) -> dict: - tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + tp, tn, fp, fn = ( + conf["true_positives"], + conf["true_negatives"], + conf["false_positives"], + conf["false_negatives"], + ) pos, neg = conf["total_positives"], conf["total_negatives"] - - balanced_accuracy = ((tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0)) / 2 + + balanced_accuracy = ( + (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) + ) / 2 precision = tp / (tp + fp) if tp + fp > 0 else 0 recall = tp / pos if pos > 0 else 0 - f1 = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 + f1 = ( + 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 + ) return dict( - precision=precision, recall=recall, f1_score=f1, accuracy=balanced_accuracy, + precision=precision, + recall=recall, + f1_score=f1, + accuracy=balanced_accuracy, true_positive_rate=tp / pos if pos > 0 else 0, true_negative_rate=tn / neg if neg > 0 else 0, false_positive_rate=fp / neg if neg > 0 else 0, @@ -143,9 +174,11 @@ def parse_score_file(path: Path) -> pd.DataFrame: except orjson.JSONDecodeError: print(f"Error decoding JSON from 
{path}. Skipping file.") return pd.DataFrame() - + if not isinstance(data, list): - print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + print( + f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file." + ) return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -198,7 +231,7 @@ def parse_score_file(path: Path) -> pd.DataFrame: if not latent_dfs: return pd.DataFrame(), counts - + return pd.concat(latent_dfs, ignore_index=True), counts @@ -208,17 +241,20 @@ def get_agg_metrics( processed_rows = [] for score_type, group_df in latent_df.groupby("score_type"): # For surprisal_intervention, we don't compute classification metrics - if score_type == 'surprisal_intervention': + if score_type == "surprisal_intervention": continue - + conf = compute_confusion(group_df) class_m = compute_classification_metrics(conf) auc = compute_auc(group_df) f1_w = frequency_weighted_f1(group_df, counts) if counts else None - + row = { "score_type": score_type, - **conf, **class_m, "auc": auc, "weighted_f1": f1_w + **conf, + **class_m, + "auc": auc, + "weighted_f1": f1_w, } processed_rows.append(row) @@ -245,17 +281,22 @@ def log_results( if latent_df.empty: print("No data to analyze.") return - + latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - + # Separate the dataframes for different processing - classification_df = latent_df[latent_df['score_type'] != 'surprisal_intervention'] - surprisal_df = latent_df[latent_df['score_type'] == 'surprisal_intervention'] + classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] + surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] if not classification_df.empty: classification_df = add_latent_f1(classification_df) if counts: - plot_firing_vs_f1(classification_df, num_tokens=10_000_000, out_dir=viz_path, run_label=scores_path.name) + plot_firing_vs_f1( + classification_df, + num_tokens=10_000_000, + out_dir=viz_path, + run_label=scores_path.name, + ) plot_roc_curve(classification_df, viz_path) processed_df = get_agg_metrics(classification_df, counts) plot_accuracy_hist(processed_df, viz_path) @@ -263,34 +304,39 @@ def log_results( if counts: dead = sum((counts[m] == 0).sum().item() for m in modules) print(f"Number of dead features: {dead}") - for score_type in latent_df["score_type"].unique(): - - if score_type == 'surprisal_intervention': + + if score_type == "surprisal_intervention": # Drop duplicates since score is per-latent, not per-example - unique_latents = surprisal_df.drop_duplicates(subset=['module', 'latent_idx']) - avg_score = unique_latents['final_score'].mean() - avg_kl = unique_latents['avg_kl_divergence'].mean() - + unique_latents = surprisal_df.drop_duplicates( + subset=["module", "latent_idx"] + ) + avg_score = unique_latents["final_score"].mean() + avg_kl = unique_latents["avg_kl_divergence"].mean() + print(f"\n--- {score_type.title()} Metrics ---") print(f"Average Normalized Score: {avg_score:.3f}") print(f"Average KL Divergence: {avg_kl:.3f}") else: if not classification_df.empty: - score_type_summary = processed_df[processed_df.score_type == score_type].iloc[0] + score_type_summary = processed_df[ + processed_df.score_type == score_type + ].iloc[0] print(f"\n--- {score_type.title()} Metrics ---") print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - if counts and score_type_summary['weighted_f1'] is not None: - 
print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") - + if counts and score_type_summary["weighted_f1"] is not None: + print( + f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}" + ) + print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") - + if score_type_summary["auc"] is not None: print(f"AUC: {score_type_summary['auc']:.3f}") else: - print("AUC not available.") \ No newline at end of file + print("AUC not available.") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index ad84c15f..6eeed35b 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -3,12 +3,12 @@ from .classifier.intruder import IntruderScorer from .embedding.embedding import EmbeddingScorer from .embedding.example_embedding import ExampleEmbeddingScorer -from .scorer import Scorer -from .simulator.oai_simulator import OpenAISimulator -from .surprisal.surprisal import SurprisalScorer from .intervention.intervention_scorer import InterventionScorer from .intervention.logprob_intervention_scorer import LogProbInterventionScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer +from .scorer import Scorer +from .simulator.oai_simulator import OpenAISimulator +from .surprisal.surprisal import SurprisalScorer __all__ = [ "FuzzingScorer", @@ -22,5 +22,4 @@ "SurprisalInterventionScorer", "InterventionScorer", "LogProbInterventionScorer", - ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index c683cca9..88c8497c 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,15 +1,14 @@ -import functools -import random import copy from dataclasses import dataclass -from typing import Any, List, Dict, Tuple +from typing import Any, Dict, List, Tuple import torch import torch.nn.functional as F from transformers import AutoTokenizer +from ...latents import LatentRecord from ..scorer import Scorer, ScorerResult -from ...latents import LatentRecord, ActivatingExample + @dataclass class SurprisalInterventionResult: @@ -21,6 +20,7 @@ class SurprisalInterventionResult: avg_kl: The average KL divergence between the clean and intervened next-token distributions. explanation: The explanation string that was scored. """ + score: float avg_kl: float explanation: str @@ -49,6 +49,7 @@ class SurprisalInterventionScorer(Scorer): 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. """ + name = "surprisal_intervention" def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): @@ -76,7 +77,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.tokenizer = subject_model.tokenizer else: self.tokenizer = AutoTokenizer.from_pretrained("gpt2") - + if self.tokenizer.pad_token is None: self.tokenizer.pad_token = self.tokenizer.eos_token self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id @@ -99,28 +100,30 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: else: current = getattr(current, part) return current - + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: """ Dynamically finds the correct model prefix and resolves the full hookpoint path. 
- + This makes the scorer agnostic to different transformer architectures. """ - parts = hookpoint_str.split('.') - + parts = hookpoint_str.split(".") + is_valid_format = ( - len(parts) == 3 and - parts[0] in ['layers', 'h'] and - parts[1].isdigit() and - parts[2] in ['mlp', 'attention', 'attn'] + len(parts) == 3 + and parts[0] in ["layers", "h"] + and parts[1].isdigit() + and parts[2] in ["mlp", "attention", "attn"] ) if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): - return getattr(model, hookpoint_str) - raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + return getattr(model, hookpoint_str) + raise ValueError( + f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'." + ) - #Heuristically find the model prefix. + # Heuristically find the model prefix. prefix = None for p in ["gpt_neox", "transformer", "model"]: if hasattr(model, p): @@ -128,14 +131,15 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if hasattr(candidate_body, parts[0]): prefix = p break - + full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - + raise AttributeError( + f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ @@ -143,57 +147,69 @@ def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ sanitized = [] for ex in examples: - if hasattr(ex, 'str_tokens') and ex.str_tokens is not None: - sanitized.append({'str_tokens': ex.str_tokens}) - + if hasattr(ex, "str_tokens") and ex.str_tokens is not None: + sanitized.append({"str_tokens": ex.str_tokens}) + elif isinstance(ex, dict) and "str_tokens" in ex: sanitized.append(ex) - + elif isinstance(ex, str): sanitized.append({"str_tokens": [ex]}) - + elif isinstance(ex, (list, tuple)): sanitized.append({"str_tokens": [str(t) for t in ex]}) - + else: sanitized.append({"str_tokens": [str(ex)]}) - - return sanitized - + return sanitized async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) raw_examples = getattr(record_copy, "test", []) or [] - + if not raw_examples: - result = SurprisalInterventionResult(score=0.0, avg_kl=0.0, explanation=record_copy.explanation) + result = SurprisalInterventionResult( + score=0.0, avg_kl=0.0, explanation=record_copy.explanation + ) return ScorerResult(record=record, score=result) examples = self._sanitize_examples(raw_examples) - + record_copy.test = examples record_copy.examples = examples record_copy.train = examples - - prompts = ["".join(ex["str_tokens"]) for ex in examples[:self.num_prompts]] - + + prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] + total_diff = 0.0 total_kl = 0.0 n = 0 for prompt in prompts: - clean_text, clean_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=False) - int_text, int_logp_dist = await self._generate_with_and_without_intervention(prompt, record_copy, intervene=True) - - logp_clean = await self._score_explanation(clean_text, record_copy.explanation) + clean_text, clean_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=False + ) + ) + int_text, 
int_logp_dist = ( + await self._generate_with_and_without_intervention( + prompt, record_copy, intervene=True + ) + ) + + logp_clean = await self._score_explanation( + clean_text, record_copy.explanation + ) logp_int = await self._score_explanation(int_text, record_copy.explanation) - + p_clean = torch.exp(clean_logp_dist) - kl_div = F.kl_div(int_logp_dist, p_clean, reduction='sum', log_target=False).item() - + kl_div = F.kl_div( + int_logp_dist, p_clean, reduction="sum", log_target=False + ).item() + total_diff += logp_int - logp_clean total_kl += kl_div n += 1 @@ -203,19 +219,21 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 final_output_list = [] - for ex in examples[:self.num_prompts]: - final_output_list.append({ - "str_tokens": ex["str_tokens"], - "final_score": final_score, - "avg_kl_divergence": avg_kl, - # Add placeholder keys that the parser expects, with default values. - "distance": None, - "activating": None, - "prediction": None, - "correct": None, - "probability": None, - "activations": None, - }) + for ex in examples[: self.num_prompts]: + final_output_list.append( + { + "str_tokens": ex["str_tokens"], + "final_score": final_score, + "avg_kl_divergence": avg_kl, + # Add placeholder keys that the parser expects, with default values. + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, + } + ) return ScorerResult(record=record_copy, score=final_output_list) async def _generate_with_and_without_intervention( @@ -235,7 +253,7 @@ async def _generate_with_and_without_intervention( device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) input_ids = enc["input_ids"].to(device) - + hooks = [] if intervene: @@ -246,14 +264,16 @@ async def _generate_with_and_without_intervention( layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) direction = self._get_intervention_direction(record).to(device) - direction = direction.unsqueeze(0).unsqueeze(0) # Shape for broadcasting: [1, 1, D] + direction = direction.unsqueeze(0).unsqueeze( + 0 + ) # Shape for broadcasting: [1, 1, D] def hook_fn(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - + # Apply intervention to the last token's hidden state hidden_states[:, -1:, :] += self.strength * direction - + # Return the modified activations in their original format if isinstance(out, tuple): return (hidden_states,) + out[1:] @@ -285,13 +305,15 @@ def hook_fn(module, inp, out): async def _score_explanation(self, generated_text: str, explanation: str) -> float: """Computes log P(explanation | generated_text) under the subject model.""" device = self._get_device() - + # Create the full input sequence: context + explanation context_enc = self.tokenizer(generated_text, return_tensors="pt") explanation_enc = self.tokenizer(explanation, return_tensors="pt") - - full_input_ids = torch.cat([context_enc.input_ids, explanation_enc.input_ids], dim=1).to(device) - + + full_input_ids = torch.cat( + [context_enc.input_ids, explanation_enc.input_ids], dim=1 + ).to(device) + with torch.no_grad(): outputs = self.subject_model(full_input_ids) logits = outputs.logits @@ -299,16 +321,16 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo # We only need to score the explanation part context_len = context_enc.input_ids.shape[1] # Get logits for positions that predict the explanation 
tokens - explanation_logits = logits[:, context_len - 1:-1, :] - + explanation_logits = logits[:, context_len - 1 : -1, :] + # Get the target token IDs for the explanation target_ids = explanation_enc.input_ids.to(device) - + log_probs = F.log_softmax(explanation_logits, dim=-1) - + # Gather the log-probabilities of the actual explanation tokens token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) - + return token_log_probs.sum().item() def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: @@ -324,9 +346,11 @@ def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) candidate = self.explainer_model.get(hookpoint_str) - if hasattr(candidate, 'get_feature_vector'): + if hasattr(candidate, "get_feature_vector"): sae = candidate - elif hasattr(candidate, 'sae') and hasattr(candidate.sae, 'get_feature_vector'): + elif hasattr(candidate, "sae") and hasattr( + candidate.sae, "get_feature_vector" + ): sae = candidate.sae if sae: @@ -342,16 +366,17 @@ def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: """Estimates an intervention direction by averaging activations.""" device = self._get_device() - + examples = self._sanitize_examples(getattr(record, "test", []) or []) if not examples: hidden_dim = self.subject_model.config.hidden_size return torch.zeros(hidden_dim, device=device) captured_activations = [] + def capture_hook(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out - + captured_activations.append(hidden_states[:, -1, :].detach().cpu()) hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) @@ -359,9 +384,11 @@ def capture_hook(module, inp, out): handle = layer_to_hook.register_forward_hook(capture_hook) try: - for ex in examples[:min(8, self.num_prompts)]: + for ex in examples[: min(8, self.num_prompts)]: prompt = "".join(ex["str_tokens"]) - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to( + device + ) with torch.no_grad(): self.subject_model(input_ids) finally: @@ -373,5 +400,5 @@ def capture_hook(module, inp, out): activations = torch.cat(captured_activations, dim=0).to(device) direction = activations.mean(dim=0) - - return F.normalize(direction, p=2, dim=0) \ No newline at end of file + + return F.normalize(direction, p=2, dim=0) From d3d269e53d678643ea8bb61a75de51e7003f3cbe Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:51:10 +0000 Subject: [PATCH 07/25] Fix pre-commit --- delphi/log/result_analysis.py | 13 +++++---- delphi/scorers/__init__.py | 4 --- .../surprisal_intervention_scorer.py | 29 +++++++++++-------- 3 files changed, 25 insertions(+), 21 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index bffa7f6b..9f8c4e65 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,8 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") + print(f"""WARNING: 'firing_count' column not found for module {module}. 
+ Skipping plot.""") continue module_df = module_df.copy() @@ -49,7 +50,7 @@ def import_plotly(): def compute_auc(df: pd.DataFrame) -> float | None: - # Filter for rows where probability is not None and there's more than one unique value + valid_df = df[df.probability.notna()] if valid_df.probability.nunique() <= 1: return None @@ -72,7 +73,7 @@ def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - # Filter for rows where probability is not None and there's more than one unique value + valid_df = df[df.probability.notna()] if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: return @@ -145,7 +146,8 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"Warning: Expected a list of results in {path}, but found {type(data)}. Skipping file.") + print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. + Skipping file.""") return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -285,7 +287,8 @@ def log_results( print(f"F1 Score: {score_type_summary['f1_score']:.3f}") if counts and score_type_summary['weighted_f1'] is not None: - print(f"Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}") + print(f"""Frequency-Weighted F1 Score: + {score_type_summary['weighted_f1']:.3f}""") print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index ad84c15f..3bdd9d4d 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -6,8 +6,6 @@ from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer -from .intervention.intervention_scorer import InterventionScorer -from .intervention.logprob_intervention_scorer import LogProbInterventionScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ @@ -20,7 +18,5 @@ "IntruderScorer", "ExampleEmbeddingScorer", "SurprisalInterventionScorer", - "InterventionScorer", - "LogProbInterventionScorer", ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index c683cca9..a1fe6618 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -18,7 +18,7 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL divergence between the clean and intervened next-token distributions. + avg_kl: The average KL-D between clean & intervened next-token distributions. explanation: The explanation string that was scored. """ score: float @@ -36,18 +36,19 @@ class SurprisalInterventionScorer(Scorer): by the intervention's strength, measured by the KL divergence between the clean and intervened next-token distributions. - Reference: Paulo et al., "Automatically Interpreting Millions of Features in Large Language Models" + Reference: Paulo et al., "Automatically Interpreting Millions of Features in LLMs" (https://arxiv.org/pdf/2410.13928), Section 3.3.5[cite: 206, 207]. Pipeline: 1. For a small set of activating prompts: a. Generate a continuation and get the next-token distribution ("clean"). - b. Add a directional vector for the feature to the activations and repeat ("intervened"). + b. 
Add directional vector for the feature to the activations ("intervened"). 2. Compute the log-probability of the explanation conditioned on both the clean and intervened generated texts: log P(explanation | text)[cite: 209]. - 3. Compute the KL divergence between the clean and intervened next-token distributions[cite: 216]. - 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: - score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε)[cite: 209]. + 3. Compute KL divergence between the clean & intervened next-token distributions. + 4. The final score is the mean change in explanation log-prob, divided by the + mean KL divergence: + score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε). """ name = "surprisal_intervention" @@ -55,12 +56,13 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): """ Args: subject_model: The language model to generate from and score with. - explainer_model: An optional model (e.g., an SAE) used to get feature directions. + explainer_model: A model (e.g., an SAE) used to get feature directions. **kwargs: Configuration options. strength (float): The magnitude of the intervention. Default: 5.0. num_prompts (int): Number of activating examples to test. Default: 3. - max_new_tokens (int): Max tokens to generate for continuations. Default: 20. - hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. + max_new_tokens (int): Max tokens to generate for continuations. + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') + for the intervention. """ self.subject_model = subject_model self.explainer_model = explainer_model @@ -118,7 +120,8 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): return getattr(model, hookpoint_str) - raise ValueError(f"Hookpoint string '{hookpoint_str}' is not in a recognized format like 'layers.6.mlp'.") + raise ValueError(f"""Hookpoint string '{hookpoint_str}' is not in a recognized format + like 'layers.6.mlp'.""") #Heuristically find the model prefix. prefix = None @@ -134,7 +137,8 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") + raise AttributeError(f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. Original error: {e}""") def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: @@ -230,7 +234,8 @@ async def _generate_with_and_without_intervention( Returns: A tuple containing: - The generated text (string). - - The log-probability distribution for the token immediately following the prompt (Tensor). + - The log-probability distribution for the token immediately following + the prompt (Tensor). 
""" device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) From 126b97814f93f440a32ff11ef980110fda4a5dce Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 15:55:24 +0000 Subject: [PATCH 08/25] Fix pre-commit --- delphi/log/result_analysis.py | 2 +- delphi/scorers/__init__.py | 1 + delphi/scorers/intervention/surprisal_intervention_scorer.py | 3 ++- 3 files changed, 4 insertions(+), 2 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 9f8c4e65..52f4b70f 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,7 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"""WARNING: 'firing_count' column not found for module {module}. + print(f"""WARNING:'firing_count' column not found for module {module}. Skipping plot.""") continue diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 3bdd9d4d..84a98012 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -8,6 +8,7 @@ from .surprisal.surprisal import SurprisalScorer from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer + __all__ = [ "FuzzingScorer", "OpenAISimulator", diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index a1fe6618..387ed0eb 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -18,7 +18,8 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL-D between clean & intervened next-token distributions. + avg_kl: The average KL divergence between clean & intervened + next-token distributions. explanation: The explanation string that was scored. """ score: float From 8e893e063899ac485444fc366849dec33b623839 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:01:34 +0000 Subject: [PATCH 09/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/log/result_analysis.py | 8 +++---- delphi/scorers/__init__.py | 3 +-- .../surprisal_intervention_scorer.py | 21 +++++++++++-------- 3 files changed, 17 insertions(+), 15 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ae1aca3b..65334e52 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,7 +16,7 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. + print(f"WARNING: 'firing_count' column not found for module {module}. Skipping plot.") continue @@ -175,7 +175,7 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. + print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. 
Skipping file.""") return pd.DataFrame() @@ -327,9 +327,9 @@ def log_results( print(f"F1 Score: {score_type_summary['f1_score']:.3f}") if counts and score_type_summary['weighted_f1'] is not None: - print(f"""Frequency-Weighted F1 Score: + print(f"""Frequency-Weighted F1 Score: {score_type_summary['weighted_f1']:.3f}""") - + print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/__init__.py b/delphi/scorers/__init__.py index 3bdd9d4d..1191688c 100644 --- a/delphi/scorers/__init__.py +++ b/delphi/scorers/__init__.py @@ -3,10 +3,10 @@ from .classifier.intruder import IntruderScorer from .embedding.embedding import EmbeddingScorer from .embedding.example_embedding import ExampleEmbeddingScorer +from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer from .scorer import Scorer from .simulator.oai_simulator import OpenAISimulator from .surprisal.surprisal import SurprisalScorer -from .intervention.surprisal_intervention_scorer import SurprisalInterventionScorer __all__ = [ "FuzzingScorer", @@ -18,5 +18,4 @@ "IntruderScorer", "ExampleEmbeddingScorer", "SurprisalInterventionScorer", - ] diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 06eb2300..191d4c6f 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -17,7 +17,7 @@ class SurprisalInterventionResult: Attributes: score: The final computed score. - avg_kl: The average KL divergence between clean & intervened + avg_kl: The average KL divergence between clean & intervened next-token distributions. explanation: The explanation string that was scored. """ @@ -47,7 +47,7 @@ class SurprisalInterventionScorer(Scorer): 2. Compute the log-probability of the explanation conditioned on both the clean and intervened generated texts: log P(explanation | text)[cite: 209]. 3. Compute KL divergence between the clean & intervened next-token distributions. - 4. The final score is the mean change in explanation log-prob, divided by the + 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: score = mean(log_prob_intervened - log_prob_clean) / (mean_KL + ε). """ @@ -63,7 +63,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): strength (float): The magnitude of the intervention. Default: 5.0. num_prompts (int): Number of activating examples to test. Default: 3. max_new_tokens (int): Max tokens to generate for continuations. - hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') + hookpoint (str): The module name (e.g., 'transformer.h.10.mlp') for the intervention. """ self.subject_model = subject_model @@ -121,9 +121,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: if not is_valid_format: if len(parts) == 1 and hasattr(model, hookpoint_str): - return getattr(model, hookpoint_str) - raise ValueError(f"""Hookpoint string '{hookpoint_str}' is not in a recognized format - like 'layers.6.mlp'.""") + return getattr(model, hookpoint_str) + raise ValueError( + f"""Hookpoint string '{hookpoint_str}' is not in a recognized format + like 'layers.6.mlp'.""" + ) # Heuristically find the model prefix. 
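        # This assumes HuggingFace causal LMs expose their decoder stack under one of a
        # few common attribute names, e.g. GPTNeoXForCausalLM -> `gpt_neox`,
        # GPT2LMHeadModel -> `transformer`, Llama-style models -> `model`, so a
        # hookpoint like 'layers.6.mlp' resolves to 'model.layers.6.mlp'. Architectures
        # with a different prefix fall through to the AttributeError raised below.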
prefix = None @@ -139,8 +141,9 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") - + raise AttributeError( + f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ @@ -249,7 +252,7 @@ async def _generate_with_and_without_intervention( Returns: A tuple containing: - The generated text (string). - - The log-probability distribution for the token immediately following + - The log-probability distribution for the token immediately following the prompt (Tensor). """ device = self._get_device() From 2a546e09ce3bc59a6ec0fffb11162972cdb34052 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Fri, 29 Aug 2025 16:07:31 +0000 Subject: [PATCH 10/25] Fix EOFs --- delphi/log/result_analysis.py | 9 +++++---- .../intervention/surprisal_intervention_scorer.py | 4 +++- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ae1aca3b..c8c8a48e 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -16,8 +16,8 @@ def plot_firing_vs_f1( for module, module_df in latent_df.groupby("module"): if 'firing_count' not in module_df.columns: - print(f"WARNING: 'firing_count' column not found for module {module}. - Skipping plot.") + print(f"""WARNING: 'firing_count' column not found for module {module}. + Skipping plot.""") continue module_df = module_df.copy() @@ -175,8 +175,9 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, but found {type(data)}. - Skipping file.""") + print(f"""Warning: Expected a list of results in {path}, + but found {type(data)}. + Skipping file.""") return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 06eb2300..66421eef 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -139,7 +139,9 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"Could not resolve path '{full_path}'. Model structure might be unexpected. Original error: {e}") + raise AttributeError(f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. 
+ Original error: {e}""") def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: From 6a6368c559d26be30d3f1dbe416388d9aa605d40 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 29 Aug 2025 16:09:54 +0000 Subject: [PATCH 11/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/log/result_analysis.py | 24 ++++++++++++------- .../surprisal_intervention_scorer.py | 9 +++---- 2 files changed, 20 insertions(+), 13 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 07fec861..4852e8d6 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -15,9 +15,11 @@ def plot_firing_vs_f1( out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - if 'firing_count' not in module_df.columns: - print(f"""WARNING: 'firing_count' column not found for module {module}. - Skipping plot.""") + if "firing_count" not in module_df.columns: + print( + f"""WARNING: 'firing_count' column not found for module {module}. + Skipping plot.""" + ) continue module_df = module_df.copy() @@ -175,9 +177,11 @@ def parse_score_file(path: Path) -> pd.DataFrame: return pd.DataFrame() if not isinstance(data, list): - print(f"""Warning: Expected a list of results in {path}, - but found {type(data)}. - Skipping file.""") + print( + f"""Warning: Expected a list of results in {path}, + but found {type(data)}. + Skipping file.""" + ) return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) @@ -327,9 +331,11 @@ def log_results( print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - if counts and score_type_summary['weighted_f1'] is not None: - print(f"""Frequency-Weighted F1 Score: - {score_type_summary['weighted_f1']:.3f}""") + if counts and score_type_summary["weighted_f1"] is not None: + print( + f"""Frequency-Weighted F1 Score: + {score_type_summary['weighted_f1']:.3f}""" + ) print(f"Precision: {score_type_summary['precision']:.3f}") print(f"Recall: {score_type_summary['recall']:.3f}") diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 1fd5ef14..b47d61a7 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -141,10 +141,11 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: try: return self._find_layer(model, full_path) except AttributeError as e: - raise AttributeError(f"""Could not resolve path '{full_path}'. - Model structure might be unexpected. - Original error: {e}""") - + raise AttributeError( + f"""Could not resolve path '{full_path}'. + Model structure might be unexpected. 
+ Original error: {e}""" + ) def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ From 88f1b355e195bf09608c863f867a855b8ee0f727 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Wed, 10 Sep 2025 22:35:42 +0000 Subject: [PATCH 12/25] Tuned Kl divergence --- .../surprisal_intervention_scorer.py | 399 +++++++++++++----- 1 file changed, 283 insertions(+), 116 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 1fd5ef14..2aa8a8ad 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,3 +1,4 @@ +import functools import copy from dataclasses import dataclass from typing import Any, Dict, List, Tuple @@ -25,6 +26,7 @@ class SurprisalInterventionResult: score: float avg_kl: float explanation: str + tuned_strength: float class SurprisalInterventionScorer(Scorer): @@ -45,7 +47,7 @@ class SurprisalInterventionScorer(Scorer): a. Generate a continuation and get the next-token distribution ("clean"). b. Add directional vector for the feature to the activations ("intervened"). 2. Compute the log-probability of the explanation conditioned on both the clean - and intervened generated texts: log P(explanation | text)[cite: 209]. + and intervened generated texts: log P(explanation | text). 3. Compute KL divergence between the clean & intervened next-token distributions. 4. The final score is the mean change in explanation log-prob, divided by the mean KL divergence: @@ -73,6 +75,10 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.max_new_tokens = int(kwargs.get("max_new_tokens", 20)) self.hookpoints = kwargs.get("hookpoints") + self.target_kl = float(kwargs.get("target_kl", 1.0)) + self.kl_tolerance = float(kwargs.get("kl_tolerance", 0.1)) + self.max_search_steps = int(kwargs.get("max_search_steps", 15)) + if len(self.hookpoints): self.hookpoint_str = self.hookpoints[0] @@ -85,6 +91,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.tokenizer.pad_token = self.tokenizer.eos_token self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id + def _get_device(self) -> torch.device: """Safely gets the device of the subject model.""" try: @@ -92,6 +99,7 @@ def _get_device(self) -> torch.device: except StopIteration: return torch.device("cuda" if torch.cuda.is_available() else "cpu") + def _find_layer(self, model: Any, name: str) -> torch.nn.Module: """Resolves a module by its dotted path name.""" if name is None: @@ -104,40 +112,29 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: current = getattr(current, part) return current - def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: - """ - Dynamically finds the correct model prefix and resolves the full hookpoint path. - - This makes the scorer agnostic to different transformer architectures. - """ - parts = hookpoint_str.split(".") - - is_valid_format = ( - len(parts) == 3 - and parts[0] in ["layers", "h"] - and parts[1].isdigit() - and parts[2] in ["mlp", "attention", "attn"] - ) - - if not is_valid_format: - if len(parts) == 1 and hasattr(model, hookpoint_str): - return getattr(model, hookpoint_str) - raise ValueError( - f"""Hookpoint string '{hookpoint_str}' is not in a recognized format - like 'layers.6.mlp'.""" - ) - # Heuristically find the model prefix. 
- prefix = None - for p in ["gpt_neox", "transformer", "model"]: - if hasattr(model, p): - candidate_body = getattr(model, p) - if hasattr(candidate_body, parts[0]): - prefix = p - break + def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: + """ + Heuristically finds the model's prefix and constructs the full hookpoint path string. + e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' + """ + # Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(self.subject_model, p): + candidate_body = getattr(self.subject_model, p) + if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): + prefix = p + break + + return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str - full_path = f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: + """ + Finds and returns the actual module object for a given hookpoint string. + """ + full_path = self._get_full_hookpoint_path(hookpoint_str) try: return self._find_layer(model, full_path) except AttributeError as e: @@ -169,6 +166,7 @@ def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: return sanitized + async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) @@ -179,41 +177,37 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: result = SurprisalInterventionResult( score=0.0, avg_kl=0.0, explanation=record_copy.explanation ) - return ScorerResult(record=record, score=result) + return ScorerResult(record=record, score=[result.__dict__]) examples = self._sanitize_examples(raw_examples) - record_copy.test = examples - record_copy.examples = examples - record_copy.train = examples - prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] + #Step 1 - Truncate prompts before tuning or scoring. + truncated_prompts = [ + await self._truncate_prompt(p, record_copy) for p in prompts + ] + + #Step 2 - Tune intervention strength to match target KL. 
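+        # Tuning to a fixed KL budget (self.target_kl) keeps the intervention's overall
+        # effect on the next-token distribution comparable across latents, so the change
+        # in explanation log-prob can be compared directly rather than being normalized
+        # by KL after the fact. `initial_kl` is the average KL actually reached at the
+        # tuned strength.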
+ tuned_strength, initial_kl = await self._tune_strength(truncated_prompts, record_copy) + total_diff = 0.0 total_kl = 0.0 n = 0 - for prompt in prompts: - clean_text, clean_logp_dist = ( - await self._generate_with_and_without_intervention( - prompt, record_copy, intervene=False - ) + for prompt in truncated_prompts: + clean_text, clean_logp_dist = await self._generate_with_intervention( + prompt, record_copy, strength=0.0, get_logp_dist=True ) - int_text, int_logp_dist = ( - await self._generate_with_and_without_intervention( - prompt, record_copy, intervene=True - ) - ) - - logp_clean = await self._score_explanation( - clean_text, record_copy.explanation + int_text, int_logp_dist = await self._generate_with_intervention( + prompt, record_copy, strength=tuned_strength, get_logp_dist=True ) + + logp_clean = await self._score_explanation(clean_text, record_copy.explanation) logp_int = await self._score_explanation(int_text, record_copy.explanation) - + p_clean = torch.exp(clean_logp_dist) - kl_div = F.kl_div( - int_logp_dist, p_clean, reduction="sum", log_target=False - ).item() + kl_div = F.kl_div(int_logp_dist, p_clean, reduction="sum", log_target=False).item() total_diff += logp_int - logp_clean total_kl += kl_div @@ -221,92 +215,209 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: avg_diff = total_diff / n if n > 0 else 0.0 avg_kl = total_kl / n if n > 0 else 0.0 - final_score = avg_diff / (avg_kl + 1e-9) if n > 0 else 0.0 + + #Final score is the average difference, not normalized by KL. + final_score = avg_diff final_output_list = [] - for ex in examples[: self.num_prompts]: + for i, ex in enumerate(examples[: self.num_prompts]): final_output_list.append( { "str_tokens": ex["str_tokens"], + "truncated_prompt": truncated_prompts[i], "final_score": final_score, "avg_kl_divergence": avg_kl, - # Add placeholder keys that the parser expects, with default values. - "distance": None, - "activating": None, - "prediction": None, - "correct": None, - "probability": None, - "activations": None, + "tuned_strength": tuned_strength, + "target_kl": self.target_kl, + # Placeholder keys + "distance": None, "activating": None, "prediction": None, + "correct": None, "probability": None, "activations": None, } ) return ScorerResult(record=record_copy, score=final_output_list) - async def _generate_with_and_without_intervention( - self, prompt: str, record: LatentRecord, intervene: bool - ) -> Tuple[str, torch.Tensor]: + + async def _get_latent_activations(self, prompt: str, record: LatentRecord) -> torch.Tensor: + """ + Runs a forward pass to get the SAE's latent activations for a prompt. """ - Generates a text continuation and returns the next-token log-probabilities. 
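+        # The code below assumes the attached SAE's `encode` returns a tuple whose
+        # third element is the dense latent activation tensor of shape
+        # [batch, seq_len, d_latent], and that `record.feature_id` indexes the latent
+        # dimension of that tensor.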
+ device = self._get_device() + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + if not sae: + return torch.empty(0) # Return empty tensor if no SAE to encode with + + captured_hidden_states = [] + def capture_hook(module, inp, out): + hidden_states = out[0] if isinstance(out, tuple) else out + captured_hidden_states.append(hidden_states.detach().cpu()) + + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + handle = layer_to_hook.register_forward_hook(capture_hook) + + try: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) + with torch.no_grad(): + self.subject_model(input_ids) + finally: + handle.remove() + + if not captured_hidden_states: + return torch.empty(0) - If `intervene` is True, it adds a feature direction to the activations at the - specified hookpoint before generation. + hidden_states = captured_hidden_states[0].to(device) - Returns: - A tuple containing: - - The generated text (string). - - The log-probability distribution for the token immediately following - the prompt (Tensor). + encoding_result = sae.encode(hidden_states) + feature_acts = encoding_result[2] + + return feature_acts[0, :, record.feature_id].cpu() + + + async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: + """ + Truncates a prompt to end just before the first token where the latent activates. + """ + activations = await self._get_latent_activations(prompt, record) + if activations.numel() == 0: + return prompt # Cannot truncate if no activations found + + # Find the index of the first token with non-zero activation + first_activation_idx = (activations > 1e-6).nonzero(as_tuple=True)[0] + + if first_activation_idx.numel() > 0: + truncation_point = first_activation_idx[0].item() + if truncation_point > 0: + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] + truncated_ids = input_ids[:truncation_point] + return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) + + return prompt + + + async def _tune_strength( + self, prompts: List[str], record: LatentRecord + ) -> Tuple[float, float]: """ + Performs a binary search to find the intervention strength that matches `target_kl`. 
+ """ + low_strength, high_strength = 0.0, 40.0 # Heuristic search range + best_strength = self.target_kl # Default to target_kl if search fails + + for _ in range(self.max_search_steps): + mid_strength = (low_strength + high_strength) / 2 + + # Estimate KL at mid_strength + total_kl = 0.0 + n = 0 + for prompt in prompts: + _, clean_logp = await self._generate_with_intervention(prompt, record, 0.0, True) + _, int_logp = await self._generate_with_intervention(prompt, record, mid_strength, True) + + p_clean = torch.exp(clean_logp) + kl_div = F.kl_div(int_logp, p_clean, reduction="sum", log_target=False).item() + total_kl += kl_div + n += 1 + + current_kl = total_kl / n if n > 0 else 0.0 + + if abs(current_kl - self.target_kl) < self.kl_tolerance: + return mid_strength, current_kl + + if current_kl < self.target_kl: + low_strength = mid_strength + else: + high_strength = mid_strength + + best_strength = mid_strength + + # Return the best found strength and the corresponding KL + final_kl = await self._calculate_avg_kl(prompts, record, best_strength) + return best_strength, final_kl + + + async def _calculate_avg_kl(self, prompts: List[str], record: LatentRecord, strength: float) -> float: + total_kl = 0.0 + n = 0 + for prompt in prompts: + _, clean_logp = await self._generate_with_intervention(prompt, record, 0.0, True) + _, int_logp = await self._generate_with_intervention(prompt, record, strength, True) + p_clean = torch.exp(clean_logp) + kl_div = F.kl_div(int_logp, p_clean, reduction="sum", log_target=False).item() + total_kl += kl_div + n += 1 + return total_kl / n if n > 0 else 0.0 + + + async def _generate_with_intervention( + self, prompt: str, record: LatentRecord, strength: float, get_logp_dist: bool = False + ) -> Tuple[str, torch.Tensor]: device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) input_ids = enc["input_ids"].to(device) hooks = [] - if intervene: - + if strength > 0: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) if hookpoint_str is None: raise ValueError("No hookpoint string specified for intervention.") - + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + if not sae: + raise ValueError(f"Could not find a valid SAE for hookpoint {hookpoint_str}") - direction = self._get_intervention_direction(record).to(device) - direction = direction.unsqueeze(0).unsqueeze( - 0 - ) # Shape for broadcasting: [1, 1, D] def hook_fn(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out + original_dtype = hidden_states.dtype + + # Get the latent dimension from the SAE's encoder + d_latent = sae.encoder.out_features + sae_device = sae.encoder.weight.device + + # --- Compute the decoder vector for the target feature --- + # 1. Create a one-hot activation for our single feature. + one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) + one_hot_activation[0, 0, record.feature_id] = 1.0 + + # 2. Create the corresponding indices needed for the decode method. + indices = torch.tensor([[[record.feature_id]]], device=sae_device, dtype=torch.long) + + # 3. Decode this one-hot vector to get the feature's direction in the hidden space. + # We subtract the decoded zero vector to remove any decoder bias. 
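The decode step just below recovers the feature direction by decoding a one-hot latent and subtracting the decoded all-zeros latent, which cancels the decoder bias. A toy sketch of why this yields the feature's decoder column, using a plain nn.Linear as a stand-in decoder (the real SAE's `decode(values, indices)` signature differs, and the sizes here are made up):

    import torch
    import torch.nn as nn

    d_latent, d_model, feature_id = 16, 8, 3

    # Stand-in for an SAE decoder: W_dec @ z + b_dec.
    decoder = nn.Linear(d_latent, d_model, bias=True)

    one_hot = torch.zeros(d_latent)
    one_hot[feature_id] = 1.0

    with torch.no_grad():
        decoded_zero = decoder(torch.zeros(d_latent))   # equals b_dec
        direction = decoder(one_hot) - decoded_zero     # equals W_dec[:, feature_id]
        assert torch.allclose(direction, decoder.weight[:, feature_id])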
+ decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) + decoder_vector = sae.decode(one_hot_activation, indices) - decoded_zero + decoder_vector = decoder_vector.squeeze() # Remove batch & seq dims + # --- End vector computation --- - # Apply intervention to the last token's hidden state - hidden_states[:, -1:, :] += self.strength * direction + # Calculate the change we want to apply. + delta = strength * decoder_vector + + new_hiddens = hidden_states.clone() + new_hiddens[:, -1, :] += delta.to(original_dtype) + + return (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens - # Return the modified activations in their original format - if isinstance(out, tuple): - return (hidden_states,) + out[1:] - return hidden_states hooks.append(layer_to_hook.register_forward_hook(hook_fn)) try: with torch.no_grad(): - # 1. Get next-token logits for KL divergence calculation outputs = self.subject_model(input_ids) next_token_logits = outputs.logits[0, -1, :] - log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) + log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None - # 2. Generate the full text continuation gen_ids = self.subject_model.generate( - input_ids, - max_new_tokens=self.max_new_tokens, - do_sample=False, - pad_token_id=self.tokenizer.pad_token_id, + input_ids, max_new_tokens=self.max_new_tokens, + do_sample=False, pad_token_id=self.tokenizer.pad_token_id ) generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) finally: for h in hooks: h.remove() + + return generated_text, log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) - return generated_text, log_probs_next_token.cpu() async def _score_explanation(self, generated_text: str, explanation: str) -> float: """Computes log P(explanation | generated_text) under the subject model.""" @@ -339,35 +450,91 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo return token_log_probs.sum().item() - def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + + """ + Retrieves the correct SAE model, handling cases where the framework + provides a functools.partial wrapper. """ - Gets the feature direction vector, preferring an SAE if available, - otherwise falling back to estimating it from activations. + candidate = None + + # 1. Try to get the SAE from the record object first. + if hasattr(record, "sae") and record.sae: + candidate = record.sae + # 2. If not on the record, look it up in the explainer_model dictionary. + elif self.explainer_model and isinstance(self.explainer_model, dict): + full_key = self._get_full_hookpoint_path(hookpoint_str) + for key in [hookpoint_str, full_key]: + if self.explainer_model.get(key) is not None: + candidate = self.explainer_model.get(key) + break + + if candidate is not None: + # 3. Check if we need to unwrap a partial object. + if isinstance(candidate, functools.partial): + # Case A: The instance is in a bound method's __self__. + instance = getattr(candidate.func, '__self__', None) + if instance is not None: + return instance # Unwrapped successfully. + + # Case B: The instance is the first argument to the partial. + if candidate.args and len(candidate.args) > 0: + instance = candidate.args[0] + # A sanity check to make sure it looks like an SAE model. + if hasattr(instance, 'encode') and hasattr(instance, 'decode'): + return instance # Unwrapped successfully. + + # If we found a partial but failed to unwrap it, we cannot proceed. 
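The unwrapping logic above covers two common `functools.partial` layouts: a bound method (instance reachable via `func.__self__`) and the instance passed as the first positional argument. A stand-in sketch of both cases; `DummySae` and `hook` are invented names used only for illustration:

    import functools

    class DummySae:
        def encode(self, x):   # minimal stand-in API
            return x
        def decode(self, x):
            return x

    sae = DummySae()

    # Case A: partial over a bound method; the instance lives on func.__self__.
    wrapped_a = functools.partial(sae.encode)
    assert getattr(wrapped_a.func, "__self__", None) is sae

    # Case B: the instance is the first positional argument of the partial.
    def hook(module_sae, x):
        return module_sae.encode(x)

    wrapped_b = functools.partial(hook, sae)
    assert wrapped_b.args and wrapped_b.args[0] is sae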
+ print(f"ERROR: Found a partial for {hookpoint_str} but could not unwrap the SAE instance.") + return None + + # If it's not a partial, it's the model itself. + return candidate + + print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + return None + + + def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: + """ + Retrieves the correct SAE model, handling the specific functools.partial + wrapper provided by the Delphi framework. """ - # --- Fast Path: Try to get vector from an SAE-like explainer model --- - if self.explainer_model: - sae = None - candidate = self.explainer_model - if isinstance(self.explainer_model, dict): - hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) - candidate = self.explainer_model.get(hookpoint_str) - - if hasattr(candidate, "get_feature_vector"): - sae = candidate - elif hasattr(candidate, "sae") and hasattr( - candidate.sae, "get_feature_vector" - ): - sae = candidate.sae - - if sae: + candidate = None + + if hasattr(record, "sae") and record.sae: + candidate = record.sae + elif self.explainer_model and isinstance(self.explainer_model, dict): + full_key = self._get_full_hookpoint_path(hookpoint_str) + for key in [hookpoint_str, full_key]: + if self.explainer_model.get(key) is not None: + candidate = self.explainer_model.get(key) + break + + if candidate is not None: + if isinstance(candidate, functools.partial): + if candidate.keywords and 'sae' in candidate.keywords: + return candidate.keywords['sae'] + + return candidate + + print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + return None + + + def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + + sae = self._get_sae_for_hookpoint(hookpoint_str, record) + + if sae and hasattr(sae, "get_feature_vector"): direction = sae.get_feature_vector(record.feature_id) if not isinstance(direction, torch.Tensor): direction = torch.tensor(direction, dtype=torch.float32) direction = direction.squeeze() return F.normalize(direction, p=2, dim=0) - # --- Fallback: Estimate direction from activating examples --- - return self._estimate_direction_from_examples(record) + return self._estimate_direction_from_examples(record) + def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: """Estimates an intervention direction by averaging activations.""" From 6db120f02f2d91a25f79555ad7cdfc60ded10bc5 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 10 Sep 2025 22:46:22 +0000 Subject: [PATCH 13/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../surprisal_intervention_scorer.py | 232 ++++++++++-------- 1 file changed, 132 insertions(+), 100 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 3cb4b6cb..9fd47e3e 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -1,5 +1,5 @@ -import functools import copy +import functools from dataclasses import dataclass from typing import Any, Dict, List, Tuple @@ -91,7 +91,6 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.tokenizer.pad_token = self.tokenizer.eos_token 
self.subject_model.config.pad_token_id = self.tokenizer.eos_token_id - def _get_device(self) -> torch.device: """Safely gets the device of the subject model.""" try: @@ -99,7 +98,6 @@ def _get_device(self) -> torch.device: except StopIteration: return torch.device("cuda" if torch.cuda.is_available() else "cpu") - def _find_layer(self, model: Any, name: str) -> torch.nn.Module: """Resolves a module by its dotted path name.""" if name is None: @@ -112,23 +110,21 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: current = getattr(current, part) return current - def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: - """ - Heuristically finds the model's prefix and constructs the full hookpoint path string. - e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' - """ - # Heuristically find the model prefix. - prefix = None - for p in ["gpt_neox", "transformer", "model"]: - if hasattr(self.subject_model, p): - candidate_body = getattr(self.subject_model, p) - if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): - prefix = p - break - - return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + """ + Heuristically finds the model's prefix and constructs the full hookpoint path string. + e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' + """ + # Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(self.subject_model, p): + candidate_body = getattr(self.subject_model, p) + if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): + prefix = p + break + return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: """ @@ -167,7 +163,6 @@ def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: return sanitized - async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) @@ -184,14 +179,16 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: prompts = ["".join(ex["str_tokens"]) for ex in examples[: self.num_prompts]] - #Step 1 - Truncate prompts before tuning or scoring. + # Step 1 - Truncate prompts before tuning or scoring. truncated_prompts = [ await self._truncate_prompt(p, record_copy) for p in prompts ] - #Step 2 - Tune intervention strength to match target KL. - tuned_strength, initial_kl = await self._tune_strength(truncated_prompts, record_copy) - + # Step 2 - Tune intervention strength to match target KL. 
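The strength tuning invoked here is an ordinary bisection over a strength-to-KL curve that is assumed to be roughly monotone. A minimal pure-Python sketch; `kl_at` is a made-up callable standing in for the measured KL at a given strength, and the tolerance value is illustrative:

    def tune_strength(kl_at, target_kl=1.0, tol=0.1, low=0.0, high=40.0, max_steps=10):
        """Bisect for a strength whose KL is within `tol` of `target_kl`."""
        best = high
        for _ in range(max_steps):
            mid = (low + high) / 2
            kl = kl_at(mid)
            if abs(kl - target_kl) < tol:
                return mid, kl
            if kl < target_kl:
                low = mid
            else:
                high = mid
            best = mid
        return best, kl_at(best)

    # Illustrative stand-in for the measured strength -> KL relationship.
    strength, kl = tune_strength(lambda s: 0.05 * s, target_kl=1.0)
    assert abs(kl - 1.0) < 0.1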
+ tuned_strength, initial_kl = await self._tune_strength( + truncated_prompts, record_copy + ) + total_diff = 0.0 total_kl = 0.0 n = 0 @@ -203,12 +200,16 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: int_text, int_logp_dist = await self._generate_with_intervention( prompt, record_copy, strength=tuned_strength, get_logp_dist=True ) - - logp_clean = await self._score_explanation(clean_text, record_copy.explanation) + + logp_clean = await self._score_explanation( + clean_text, record_copy.explanation + ) logp_int = await self._score_explanation(int_text, record_copy.explanation) - + p_clean = torch.exp(clean_logp_dist) - kl_div = F.kl_div(int_logp_dist, p_clean, reduction="sum", log_target=False).item() + kl_div = F.kl_div( + int_logp_dist, p_clean, reduction="sum", log_target=False + ).item() total_diff += logp_int - logp_clean total_kl += kl_div @@ -216,8 +217,8 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: avg_diff = total_diff / n if n > 0 else 0.0 avg_kl = total_kl / n if n > 0 else 0.0 - - #Final score is the average difference, not normalized by KL. + + # Final score is the average difference, not normalized by KL. final_score = avg_diff final_output_list = [] @@ -231,14 +232,19 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: "tuned_strength": tuned_strength, "target_kl": self.target_kl, # Placeholder keys - "distance": None, "activating": None, "prediction": None, - "correct": None, "probability": None, "activations": None, + "distance": None, + "activating": None, + "prediction": None, + "correct": None, + "probability": None, + "activations": None, } ) return ScorerResult(record=record_copy, score=final_output_list) - - async def _get_latent_activations(self, prompt: str, record: LatentRecord) -> torch.Tensor: + async def _get_latent_activations( + self, prompt: str, record: LatentRecord + ) -> torch.Tensor: """ Runs a forward pass to get the SAE's latent activations for a prompt. """ @@ -246,16 +252,17 @@ async def _get_latent_activations(self, prompt: str, record: LatentRecord) -> to hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) sae = self._get_sae_for_hookpoint(hookpoint_str, record) if not sae: - return torch.empty(0) # Return empty tensor if no SAE to encode with + return torch.empty(0) # Return empty tensor if no SAE to encode with captured_hidden_states = [] + def capture_hook(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out captured_hidden_states.append(hidden_states.detach().cpu()) layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) handle = layer_to_hook.register_forward_hook(capture_hook) - + try: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids.to(device) with torch.no_grad(): @@ -273,27 +280,25 @@ def capture_hook(module, inp, out): return feature_acts[0, :, record.feature_id].cpu() - async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: """ Truncates a prompt to end just before the first token where the latent activates. 
""" activations = await self._get_latent_activations(prompt, record) if activations.numel() == 0: - return prompt # Cannot truncate if no activations found + return prompt # Cannot truncate if no activations found # Find the index of the first token with non-zero activation first_activation_idx = (activations > 1e-6).nonzero(as_tuple=True)[0] - + if first_activation_idx.numel() > 0: truncation_point = first_activation_idx[0].item() if truncation_point > 0: input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] truncated_ids = input_ids[:truncation_point] return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) - - return prompt + return prompt async def _tune_strength( self, prompts: List[str], record: LatentRecord @@ -301,26 +306,32 @@ async def _tune_strength( """ Performs a binary search to find the intervention strength that matches `target_kl`. """ - low_strength, high_strength = 0.0, 40.0 # Heuristic search range - best_strength = self.target_kl # Default to target_kl if search fails - + low_strength, high_strength = 0.0, 40.0 # Heuristic search range + best_strength = self.target_kl # Default to target_kl if search fails + for _ in range(self.max_search_steps): mid_strength = (low_strength + high_strength) / 2 - + # Estimate KL at mid_strength total_kl = 0.0 n = 0 for prompt in prompts: - _, clean_logp = await self._generate_with_intervention(prompt, record, 0.0, True) - _, int_logp = await self._generate_with_intervention(prompt, record, mid_strength, True) + _, clean_logp = await self._generate_with_intervention( + prompt, record, 0.0, True + ) + _, int_logp = await self._generate_with_intervention( + prompt, record, mid_strength, True + ) p_clean = torch.exp(clean_logp) - kl_div = F.kl_div(int_logp, p_clean, reduction="sum", log_target=False).item() + kl_div = F.kl_div( + int_logp, p_clean, reduction="sum", log_target=False + ).item() total_kl += kl_div n += 1 - + current_kl = total_kl / n if n > 0 else 0.0 - + if abs(current_kl - self.target_kl) < self.kl_tolerance: return mid_strength, current_kl @@ -328,29 +339,39 @@ async def _tune_strength( low_strength = mid_strength else: high_strength = mid_strength - + best_strength = mid_strength # Return the best found strength and the corresponding KL final_kl = await self._calculate_avg_kl(prompts, record, best_strength) return best_strength, final_kl - - async def _calculate_avg_kl(self, prompts: List[str], record: LatentRecord, strength: float) -> float: + async def _calculate_avg_kl( + self, prompts: List[str], record: LatentRecord, strength: float + ) -> float: total_kl = 0.0 n = 0 for prompt in prompts: - _, clean_logp = await self._generate_with_intervention(prompt, record, 0.0, True) - _, int_logp = await self._generate_with_intervention(prompt, record, strength, True) + _, clean_logp = await self._generate_with_intervention( + prompt, record, 0.0, True + ) + _, int_logp = await self._generate_with_intervention( + prompt, record, strength, True + ) p_clean = torch.exp(clean_logp) - kl_div = F.kl_div(int_logp, p_clean, reduction="sum", log_target=False).item() + kl_div = F.kl_div( + int_logp, p_clean, reduction="sum", log_target=False + ).item() total_kl += kl_div n += 1 return total_kl / n if n > 0 else 0.0 - async def _generate_with_intervention( - self, prompt: str, record: LatentRecord, strength: float, get_logp_dist: bool = False + self, + prompt: str, + record: LatentRecord, + strength: float, + get_logp_dist: bool = False, ) -> Tuple[str, torch.Tensor]: device = self._get_device() enc = 
self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) @@ -361,17 +382,18 @@ async def _generate_with_intervention( hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) if hookpoint_str is None: raise ValueError("No hookpoint string specified for intervention.") - + layer_to_hook = self._resolve_hookpoint(self.subject_model, hookpoint_str) sae = self._get_sae_for_hookpoint(hookpoint_str, record) if not sae: - raise ValueError(f"Could not find a valid SAE for hookpoint {hookpoint_str}") - + raise ValueError( + f"Could not find a valid SAE for hookpoint {hookpoint_str}" + ) def hook_fn(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out original_dtype = hidden_states.dtype - + # Get the latent dimension from the SAE's encoder d_latent = sae.encoder.out_features sae_device = sae.encoder.weight.device @@ -382,23 +404,26 @@ def hook_fn(module, inp, out): one_hot_activation[0, 0, record.feature_id] = 1.0 # 2. Create the corresponding indices needed for the decode method. - indices = torch.tensor([[[record.feature_id]]], device=sae_device, dtype=torch.long) + indices = torch.tensor( + [[[record.feature_id]]], device=sae_device, dtype=torch.long + ) # 3. Decode this one-hot vector to get the feature's direction in the hidden space. # We subtract the decoded zero vector to remove any decoder bias. decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) decoder_vector = sae.decode(one_hot_activation, indices) - decoded_zero - decoder_vector = decoder_vector.squeeze() # Remove batch & seq dims + decoder_vector = decoder_vector.squeeze() # Remove batch & seq dims # --- End vector computation --- # Calculate the change we want to apply. delta = strength * decoder_vector - + new_hiddens = hidden_states.clone() new_hiddens[:, -1, :] += delta.to(original_dtype) - return (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens - + return ( + (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens + ) hooks.append(layer_to_hook.register_forward_hook(hook_fn)) @@ -406,19 +431,24 @@ def hook_fn(module, inp, out): with torch.no_grad(): outputs = self.subject_model(input_ids) next_token_logits = outputs.logits[0, -1, :] - log_probs_next_token = F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None + log_probs_next_token = ( + F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None + ) gen_ids = self.subject_model.generate( - input_ids, max_new_tokens=self.max_new_tokens, - do_sample=False, pad_token_id=self.tokenizer.pad_token_id + input_ids, + max_new_tokens=self.max_new_tokens, + do_sample=False, + pad_token_id=self.tokenizer.pad_token_id, ) generated_text = self.tokenizer.decode(gen_ids[0], skip_special_tokens=True) finally: for h in hooks: h.remove() - - return generated_text, log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) + return generated_text, ( + log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) + ) async def _score_explanation(self, generated_text: str, explanation: str) -> float: """Computes log P(explanation | generated_text) under the subject model.""" @@ -451,13 +481,12 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo return token_log_probs.sum().item() - """ Retrieves the correct SAE model, handling cases where the framework provides a functools.partial wrapper. """ candidate = None - + # 1. Try to get the SAE from the record object first. 
if hasattr(record, "sae") and record.sae: candidate = record.sae @@ -468,40 +497,43 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo if self.explainer_model.get(key) is not None: candidate = self.explainer_model.get(key) break - + if candidate is not None: # 3. Check if we need to unwrap a partial object. if isinstance(candidate, functools.partial): # Case A: The instance is in a bound method's __self__. - instance = getattr(candidate.func, '__self__', None) + instance = getattr(candidate.func, "__self__", None) if instance is not None: return instance # Unwrapped successfully. - + # Case B: The instance is the first argument to the partial. if candidate.args and len(candidate.args) > 0: instance = candidate.args[0] # A sanity check to make sure it looks like an SAE model. - if hasattr(instance, 'encode') and hasattr(instance, 'decode'): + if hasattr(instance, "encode") and hasattr(instance, "decode"): return instance # Unwrapped successfully. - + # If we found a partial but failed to unwrap it, we cannot proceed. - print(f"ERROR: Found a partial for {hookpoint_str} but could not unwrap the SAE instance.") + print( + f"ERROR: Found a partial for {hookpoint_str} but could not unwrap the SAE instance." + ) return None - + # If it's not a partial, it's the model itself. return candidate - print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + print( + f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'" + ) return None - def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: """ Retrieves the correct SAE model, handling the specific functools.partial wrapper provided by the Delphi framework. """ candidate = None - + if hasattr(record, "sae") and record.sae: candidate = record.sae elif self.explainer_model and isinstance(self.explainer_model, dict): @@ -510,32 +542,32 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An if self.explainer_model.get(key) is not None: candidate = self.explainer_model.get(key) break - + if candidate is not None: if isinstance(candidate, functools.partial): - if candidate.keywords and 'sae' in candidate.keywords: - return candidate.keywords['sae'] - + if candidate.keywords and "sae" in candidate.keywords: + return candidate.keywords["sae"] + return candidate - print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + print( + f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'" + ) return None - def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: - hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) - - sae = self._get_sae_for_hookpoint(hookpoint_str, record) - - if sae and hasattr(sae, "get_feature_vector"): - direction = sae.get_feature_vector(record.feature_id) - if not isinstance(direction, torch.Tensor): - direction = torch.tensor(direction, dtype=torch.float32) - direction = direction.squeeze() - return F.normalize(direction, p=2, dim=0) + hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) + + sae = self._get_sae_for_hookpoint(hookpoint_str, record) - return self._estimate_direction_from_examples(record) + if sae and hasattr(sae, "get_feature_vector"): + direction = sae.get_feature_vector(record.feature_id) + if not isinstance(direction, torch.Tensor): + direction = torch.tensor(direction, dtype=torch.float32) + direction = direction.squeeze() + return F.normalize(direction, p=2, 
dim=0) + return self._estimate_direction_from_examples(record) def _estimate_direction_from_examples(self, record: LatentRecord) -> torch.Tensor: """Estimates an intervention direction by averaging activations.""" From 1a6fa0c3decb06e0a007b193dac70d9fdbd79f58 Mon Sep 17 00:00:00 2001 From: saireddythfc Date: Thu, 11 Sep 2025 13:48:39 +0000 Subject: [PATCH 14/25] Pre-commit clears --- .../surprisal_intervention_scorer.py | 47 +++++++++++-------- 1 file changed, 27 insertions(+), 20 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 3cb4b6cb..d778aafe 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -114,20 +114,21 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: - """ - Heuristically finds the model's prefix and constructs the full hookpoint path string. - e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' - """ - # Heuristically find the model prefix. - prefix = None - for p in ["gpt_neox", "transformer", "model"]: - if hasattr(self.subject_model, p): - candidate_body = getattr(self.subject_model, p) - if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): - prefix = p - break - - return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str + """ + Heuristically finds the model's prefix and constructs the full hookpoint + path string. + e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' + """ + # Heuristically find the model prefix. + prefix = None + for p in ["gpt_neox", "transformer", "model"]: + if hasattr(self.subject_model, p): + candidate_body = getattr(self.subject_model, p) + if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): + prefix = p + break + + return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: @@ -144,6 +145,7 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: Original error: {e}""" ) + def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ Function used for formatting results to run smoothly in the delphi pipeline @@ -276,7 +278,7 @@ def capture_hook(module, inp, out): async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: """ - Truncates a prompt to end just before the first token where the latent activates. + Truncates prompt to end just before the first token where latent activates. """ activations = await self._get_latent_activations(prompt, record) if activations.numel() == 0: @@ -299,7 +301,7 @@ async def _tune_strength( self, prompts: List[str], record: LatentRecord ) -> Tuple[float, float]: """ - Performs a binary search to find the intervention strength that matches `target_kl`. + Performs a binary search to find intervention strength that matches target_kl. """ low_strength, high_strength = 0.0, 40.0 # Heuristic search range best_strength = self.target_kl # Default to target_kl if search fails @@ -384,7 +386,7 @@ def hook_fn(module, inp, out): # 2. Create the corresponding indices needed for the decode method. indices = torch.tensor([[[record.feature_id]]], device=sae_device, dtype=torch.long) - # 3. Decode this one-hot vector to get the feature's direction in the hidden space. + # 3. Decode one-hot vector to get feature's direction in hidden space. # We subtract the decoded zero vector to remove any decoder bias. 
decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) decoder_vector = sae.decode(one_hot_activation, indices) - decoded_zero @@ -485,13 +487,17 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo return instance # Unwrapped successfully. # If we found a partial but failed to unwrap it, we cannot proceed. - print(f"ERROR: Found a partial for {hookpoint_str} but could not unwrap the SAE instance.") + print( + f"""ERROR: Found a partial for {hookpoint_str} but could not + unwrap the SAE instance.""") return None # If it's not a partial, it's the model itself. return candidate - print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + print( + f"""ERROR: Surprisal scorer could not find + an SAE for hookpoint '{hookpoint_str}'""") return None @@ -518,7 +524,8 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An return candidate - print(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}'") + print(f"""ERROR: Surprisal scorer could not find + an SAE for hookpoint '{hookpoint_str}'""") return None From 6e18bbaf8058374186f55c4d52a7d2b3be366bef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 11 Sep 2025 13:52:40 +0000 Subject: [PATCH 15/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../surprisal_intervention_scorer.py | 21 +++++++++++-------- 1 file changed, 12 insertions(+), 9 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 24731735..b4a2c098 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -112,7 +112,7 @@ def _find_layer(self, model: Any, name: str) -> torch.nn.Module: def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: """ - Heuristically finds the model's prefix and constructs the full hookpoint + Heuristically finds the model's prefix and constructs the full hookpoint path string. e.g., 'layers.6.mlp' -> 'model.layers.6.mlp' """ @@ -124,7 +124,7 @@ def _get_full_hookpoint_path(self, hookpoint_str: str) -> str: if hasattr(candidate_body, "h") or hasattr(candidate_body, "layers"): prefix = p break - + return f"{prefix}.{hookpoint_str}" if prefix else hookpoint_str def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: @@ -141,7 +141,6 @@ def _resolve_hookpoint(self, model: Any, hookpoint_str: str) -> Any: Original error: {e}""" ) - def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: """ Function used for formatting results to run smoothly in the delphi pipeline @@ -517,16 +516,18 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo # If we found a partial but failed to unwrap it, we cannot proceed. print( - f"""ERROR: Found a partial for {hookpoint_str} but could not - unwrap the SAE instance.""") + f"""ERROR: Found a partial for {hookpoint_str} but could not + unwrap the SAE instance.""" + ) return None # If it's not a partial, it's the model itself. 
return candidate print( - f"""ERROR: Surprisal scorer could not find - an SAE for hookpoint '{hookpoint_str}'""") + f"""ERROR: Surprisal scorer could not find + an SAE for hookpoint '{hookpoint_str}'""" + ) return None def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: @@ -552,8 +553,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An return candidate - print(f"""ERROR: Surprisal scorer could not find - an SAE for hookpoint '{hookpoint_str}'""") + print( + f"""ERROR: Surprisal scorer could not find + an SAE for hookpoint '{hookpoint_str}'""" + ) return None def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: From ba6533b16be98f610ec7798626c5dc6e55d53509 Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Mon, 17 Nov 2025 22:09:04 +0000 Subject: [PATCH 16/25] Fix intervention point --- .../surprisal_intervention_scorer.py | 248 ++++++++++-------- 1 file changed, 141 insertions(+), 107 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index b4a2c098..09363442 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -72,7 +72,7 @@ def __init__(self, subject_model: Any, explainer_model: Any = None, **kwargs): self.explainer_model = explainer_model self.strength = float(kwargs.get("strength", 5.0)) self.num_prompts = int(kwargs.get("num_prompts", 3)) - self.max_new_tokens = int(kwargs.get("max_new_tokens", 20)) + self.max_new_tokens = int(kwargs.get("max_new_tokens", 8)) self.hookpoints = kwargs.get("hookpoints") self.target_kl = float(kwargs.get("target_kl", 1.0)) @@ -164,6 +164,56 @@ def _sanitize_examples(self, examples: List[Any]) -> List[Dict[str, Any]]: return sanitized + def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: + """ + Calculates the feature's decoder vector, subtracting the decoder bias. + """ + + + d_latent = sae.encoder.out_features + sae_device = sae.encoder.weight.device + + # Create a one-hot activation for our single feature. + one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) + + if feature_id >= d_latent: + print(f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}") + return torch.zeros(1) + + one_hot_activation[0, 0, feature_id] = 1.0 + + # Create the corresponding indices needed for the decode method. 
+ indices = torch.tensor( + [[[feature_id]]], device=sae_device, dtype=torch.long + ) + + with torch.no_grad(): + try: + decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) + vector_before_sub = sae.decode(one_hot_activation, indices) + except Exception as e: + print(f"DEBUG: ERROR during sae.decode: {e}") + return torch.zeros(1) + + decoder_vector = vector_before_sub - decoded_zero + + final_norm = decoder_vector.norm().item() + + # --- MODIFIED DEBUG BLOCK --- + # Only print if the feature is "decoder-live" + if final_norm > 1e-6: + print(f"\n--- DEBUG: 'Decoder-Live' Feature Found: {feature_id} ---") + print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}") + print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}") + print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}") + print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}") + print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}") + print("--- END DEBUG ---\n") + # --- END MODIFIED BLOCK --- + + return decoder_vector.squeeze() + + async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) @@ -186,8 +236,15 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: ] # Step 2 - Tune intervention strength to match target KL. + hookpoint_str = self.hookpoint_str or getattr(record_copy, "hookpoint", None) + sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy) + if not sae: + raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}") + + intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id) + tuned_strength, initial_kl = await self._tune_strength( - truncated_prompts, record_copy + truncated_prompts, record_copy, intervention_vector ) total_diff = 0.0 @@ -196,10 +253,10 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: for prompt in truncated_prompts: clean_text, clean_logp_dist = await self._generate_with_intervention( - prompt, record_copy, strength=0.0, get_logp_dist=True + prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True ) int_text, int_logp_dist = await self._generate_with_intervention( - prompt, record_copy, strength=tuned_strength, get_logp_dist=True + prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True ) logp_clean = await self._score_explanation( @@ -243,6 +300,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: ) return ScorerResult(record=record_copy, score=final_output_list) + async def _get_latent_activations( self, prompt: str, record: LatentRecord ) -> torch.Tensor: @@ -281,6 +339,7 @@ def capture_hook(module, inp, out): return feature_acts[0, :, record.feature_id].cpu() + async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: """ Truncates prompt to end just before the first token where latent activates. 
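The next hunk adjusts `_truncate_prompt`: activations at position 0 (typically the BOS token) are ignored, and the prompt now keeps tokens up to and including the first activating position. A small sketch of just the index logic, with a made-up activation vector:

    import torch

    # Per-token activations of one latent over a 6-token prompt (illustrative).
    activations = torch.tensor([0.9, 0.0, 0.0, 0.4, 0.0, 0.7])

    nonzero = (activations > 1e-6).nonzero(as_tuple=True)[0]
    nonzero = nonzero[nonzero > 0]        # ignore the BOS position

    if nonzero.numel() > 0:
        keep = nonzero[0].item() + 1      # keep up to and including that token
    else:
        keep = activations.numel()        # no usable activation: keep everything

    assert keep == 4                      # token positions 0..3 survive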
@@ -290,19 +349,24 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: return prompt # Cannot truncate if no activations found # Find the index of the first token with non-zero activation - first_activation_idx = (activations > 1e-6).nonzero(as_tuple=True)[0] + # Get ALL non-zero indices first + all_activation_indices = (activations > 1e-6).nonzero(as_tuple=True)[0] + + # Filter out activations at position 0 (BOS) + first_activation_idx = all_activation_indices[all_activation_indices > 0] if first_activation_idx.numel() > 0: - truncation_point = first_activation_idx[0].item() - if truncation_point > 0: - input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] - truncated_ids = input_ids[:truncation_point] - return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) + truncation_point = first_activation_idx[0].item() + input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] + truncated_ids = input_ids[:truncation_point + 1] + return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) return prompt + async def _tune_strength( - self, prompts: List[str], record: LatentRecord + self, prompts: List[str], record: LatentRecord, + intervention_vector: torch.Tensor ) -> Tuple[float, float]: """ Performs a binary search to find intervention strength that matches target_kl. @@ -318,10 +382,10 @@ async def _tune_strength( n = 0 for prompt in prompts: _, clean_logp = await self._generate_with_intervention( - prompt, record, 0.0, True + prompt, record, 0.0, intervention_vector, True ) _, int_logp = await self._generate_with_intervention( - prompt, record, mid_strength, True + prompt, record, mid_strength, intervention_vector, True ) p_clean = torch.exp(clean_logp) @@ -344,20 +408,22 @@ async def _tune_strength( best_strength = mid_strength # Return the best found strength and the corresponding KL - final_kl = await self._calculate_avg_kl(prompts, record, best_strength) + final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector) return best_strength, final_kl + async def _calculate_avg_kl( - self, prompts: List[str], record: LatentRecord, strength: float + self, prompts: List[str], record: LatentRecord, strength: float, + intervention_vector: torch.Tensor ) -> float: total_kl = 0.0 n = 0 for prompt in prompts: _, clean_logp = await self._generate_with_intervention( - prompt, record, 0.0, True + prompt, record, 0.0, intervention_vector,True ) _, int_logp = await self._generate_with_intervention( - prompt, record, strength, True + prompt, record, strength, intervention_vector,True ) p_clean = torch.exp(clean_logp) kl_div = F.kl_div( @@ -367,16 +433,22 @@ async def _calculate_avg_kl( n += 1 return total_kl / n if n > 0 else 0.0 + async def _generate_with_intervention( self, prompt: str, record: LatentRecord, strength: float, + intervention_vector: torch.Tensor, get_logp_dist: bool = False, ) -> Tuple[str, torch.Tensor]: device = self._get_device() enc = self.tokenizer(prompt, return_tensors="pt", truncation=True, padding=True) input_ids = enc["input_ids"].to(device) + attention_mask = enc["attention_mask"].to(device) + + prompt_length = input_ids.shape[1] + delta = strength * intervention_vector hooks = [] if strength > 0: @@ -395,32 +467,14 @@ def hook_fn(module, inp, out): hidden_states = out[0] if isinstance(out, tuple) else out original_dtype = hidden_states.dtype - # Get the latent dimension from the SAE's encoder - d_latent = sae.encoder.out_features - sae_device = sae.encoder.weight.device - 
- # --- Compute the decoder vector for the target feature --- - # 1. Create a one-hot activation for our single feature. - one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) - one_hot_activation[0, 0, record.feature_id] = 1.0 + current_seq_len = hidden_states.shape[1] + new_hiddens = hidden_states.detach().clone() - # 2. Create the corresponding indices needed for the decode method. - indices = torch.tensor( - [[[record.feature_id]]], device=sae_device, dtype=torch.long - ) - - # 3. Decode one-hot vector to get feature's direction in hidden space. - # We subtract the decoded zero vector to remove any decoder bias. - decoded_zero = sae.decode(torch.zeros_like(one_hot_activation), indices) - decoder_vector = sae.decode(one_hot_activation, indices) - decoded_zero - decoder_vector = decoder_vector.squeeze() # Remove batch & seq dims - # --- End vector computation --- + intervention_start_index = prompt_length - 1 - # Calculate the change we want to apply. - delta = strength * decoder_vector + if current_seq_len >= prompt_length: + new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype) - new_hiddens = hidden_states.clone() - new_hiddens[:, -1, :] += delta.to(original_dtype) return ( (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens @@ -430,7 +484,7 @@ def hook_fn(module, inp, out): try: with torch.no_grad(): - outputs = self.subject_model(input_ids) + outputs =self.subject_model(input_ids, attention_mask=attention_mask) next_token_logits = outputs.logits[0, -1, :] log_probs_next_token = ( F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None @@ -438,6 +492,7 @@ def hook_fn(module, inp, out): gen_ids = self.subject_model.generate( input_ids, + attention_mask=attention_mask, max_new_tokens=self.max_new_tokens, do_sample=False, pad_token_id=self.tokenizer.pad_token_id, @@ -451,13 +506,25 @@ def hook_fn(module, inp, out): log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) ) + async def _score_explanation(self, generated_text: str, explanation: str) -> float: - """Computes log P(explanation | generated_text) under the subject model.""" + """ + Computes log P(explanation | generated_text) using the paper's + prompt format. 
+ """ device = self._get_device() - # Create the full input sequence: context + explanation - context_enc = self.tokenizer(generated_text, return_tensors="pt") - explanation_enc = self.tokenizer(explanation, return_tensors="pt") + # Build the prompt from Appendix G.1 + prompt_template = ( + "\n" + f"{generated_text}\n" + "The above passage contains an amplified amount of \"" + ) + explanation_suffix = f"{explanation}\"" + + # Tokenize the parts + context_enc = self.tokenizer(prompt_template, return_tensors="pt") + explanation_enc = self.tokenizer(explanation_suffix, return_tensors="pt") full_input_ids = torch.cat( [context_enc.input_ids, explanation_enc.input_ids], dim=1 @@ -469,66 +536,25 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo # We only need to score the explanation part context_len = context_enc.input_ids.shape[1] + # Get logits for positions that predict the explanation tokens + # Shape: [batch_size, explanation_len, vocab_size] explanation_logits = logits[:, context_len - 1 : -1, :] # Get the target token IDs for the explanation + # Shape: [batch_size, explanation_len] target_ids = explanation_enc.input_ids.to(device) log_probs = F.log_softmax(explanation_logits, dim=-1) # Gather the log-probabilities of the actual explanation tokens - token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) + token_log_probs = log_probs.gather( + 2, target_ids.unsqueeze(-1) + ).squeeze(-1) + # Return the sum of log-probs for the explanation return token_log_probs.sum().item() - """ - Retrieves the correct SAE model, handling cases where the framework - provides a functools.partial wrapper. - """ - candidate = None - - # 1. Try to get the SAE from the record object first. - if hasattr(record, "sae") and record.sae: - candidate = record.sae - # 2. If not on the record, look it up in the explainer_model dictionary. - elif self.explainer_model and isinstance(self.explainer_model, dict): - full_key = self._get_full_hookpoint_path(hookpoint_str) - for key in [hookpoint_str, full_key]: - if self.explainer_model.get(key) is not None: - candidate = self.explainer_model.get(key) - break - - if candidate is not None: - # 3. Check if we need to unwrap a partial object. - if isinstance(candidate, functools.partial): - # Case A: The instance is in a bound method's __self__. - instance = getattr(candidate.func, "__self__", None) - if instance is not None: - return instance # Unwrapped successfully. - - # Case B: The instance is the first argument to the partial. - if candidate.args and len(candidate.args) > 0: - instance = candidate.args[0] - # A sanity check to make sure it looks like an SAE model. - if hasattr(instance, "encode") and hasattr(instance, "decode"): - return instance # Unwrapped successfully. - - # If we found a partial but failed to unwrap it, we cannot proceed. - print( - f"""ERROR: Found a partial for {hookpoint_str} but could not - unwrap the SAE instance.""" - ) - return None - - # If it's not a partial, it's the model itself. 
- return candidate - - print( - f"""ERROR: Surprisal scorer could not find - an SAE for hookpoint '{hookpoint_str}'""" - ) - return None def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: """ @@ -541,23 +567,31 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An candidate = record.sae elif self.explainer_model and isinstance(self.explainer_model, dict): full_key = self._get_full_hookpoint_path(hookpoint_str) - for key in [hookpoint_str, full_key]: + short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp" + + for key in [hookpoint_str, full_key, short_key]: if self.explainer_model.get(key) is not None: candidate = self.explainer_model.get(key) break - - if candidate is not None: - if isinstance(candidate, functools.partial): - if candidate.keywords and "sae" in candidate.keywords: - return candidate.keywords["sae"] - - return candidate - - print( - f"""ERROR: Surprisal scorer could not find - an SAE for hookpoint '{hookpoint_str}'""" - ) - return None + + if candidate is None: + # This will raise an error if the key isn't found + raise ValueError(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model") + + if isinstance(candidate, functools.partial): + # As shown in load_sparsify.py, the SAE is in the 'sae' keyword. + if candidate.keywords and "sae" in candidate.keywords: + return candidate.keywords["sae"] # Unwrapped successfully + else: + # This will raise an error if the partial is missing the keyword + raise ValueError(f"""ERROR: Found a partial for {hookpoint_str} but could not + find the 'sae' keyword. + func: {candidate.func} + args: {candidate.args} + keywords: {candidate.keywords}""") + + # This will raise an error if the candidate isn't a partial + raise ValueError(f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}") def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From 68d6c6302e80bfc5653ef5bf1fbb936561b5a1a9 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Nov 2025 22:23:14 +0000 Subject: [PATCH 17/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../surprisal_intervention_scorer.py | 109 ++++++++++-------- 1 file changed, 61 insertions(+), 48 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 09363442..cc4bba87 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -168,24 +168,23 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: """ Calculates the feature's decoder vector, subtracting the decoder bias. """ - - + d_latent = sae.encoder.out_features sae_device = sae.encoder.weight.device # Create a one-hot activation for our single feature. 
one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) - + if feature_id >= d_latent: - print(f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}") + print( + f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}" + ) return torch.zeros(1) - + one_hot_activation[0, 0, feature_id] = 1.0 # Create the corresponding indices needed for the decode method. - indices = torch.tensor( - [[[feature_id]]], device=sae_device, dtype=torch.long - ) + indices = torch.tensor([[[feature_id]]], device=sae_device, dtype=torch.long) with torch.no_grad(): try: @@ -196,9 +195,9 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: return torch.zeros(1) decoder_vector = vector_before_sub - decoded_zero - + final_norm = decoder_vector.norm().item() - + # --- MODIFIED DEBUG BLOCK --- # Only print if the feature is "decoder-live" if final_norm > 1e-6: @@ -206,14 +205,15 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: print(f"DEBUG: sae.encoder.out_features (d_latent): {d_latent}") print(f"DEBUG: sae.encoder.weight.device (sae_device): {sae_device}") print(f"DEBUG: Norm of decoded_zero: {decoded_zero.norm().item()}") - print(f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}") + print( + f"DEBUG: Norm of vector_before_sub: {vector_before_sub.norm().item()}" + ) print(f"DEBUG: Feature {feature_id}, FINAL Vector Norm: {final_norm}") print("--- END DEBUG ---\n") # --- END MODIFIED BLOCK --- return decoder_vector.squeeze() - async def __call__(self, record: LatentRecord) -> ScorerResult: record_copy = copy.deepcopy(record) @@ -240,7 +240,7 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: sae = self._get_sae_for_hookpoint(hookpoint_str, record_copy) if not sae: raise ValueError(f"Could not find SAE for hookpoint {hookpoint_str}") - + intervention_vector = self._get_intervention_vector(sae, record_copy.feature_id) tuned_strength, initial_kl = await self._tune_strength( @@ -253,10 +253,18 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: for prompt in truncated_prompts: clean_text, clean_logp_dist = await self._generate_with_intervention( - prompt, record_copy, strength=0.0, intervention_vector=intervention_vector, get_logp_dist=True + prompt, + record_copy, + strength=0.0, + intervention_vector=intervention_vector, + get_logp_dist=True, ) int_text, int_logp_dist = await self._generate_with_intervention( - prompt, record_copy, strength=tuned_strength, intervention_vector=intervention_vector, get_logp_dist=True + prompt, + record_copy, + strength=tuned_strength, + intervention_vector=intervention_vector, + get_logp_dist=True, ) logp_clean = await self._score_explanation( @@ -300,7 +308,6 @@ async def __call__(self, record: LatentRecord) -> ScorerResult: ) return ScorerResult(record=record_copy, score=final_output_list) - async def _get_latent_activations( self, prompt: str, record: LatentRecord ) -> torch.Tensor: @@ -339,7 +346,6 @@ def capture_hook(module, inp, out): return feature_acts[0, :, record.feature_id].cpu() - async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: """ Truncates prompt to end just before the first token where latent activates. 
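The intervention applied by `_generate_with_intervention` (introduced in patch 16 above) is a standard forward hook that adds `strength * decoder_vector` to the hidden states from the final prompt position onward. A self-contained sketch of just the hook mechanics, using an nn.Identity stand-in for the hooked block and made-up sizes rather than the actual model and SAE wiring:

    import torch
    import torch.nn as nn

    d_model, prompt_length = 4, 3
    layer = nn.Identity()                  # stand-in for the hooked block
    delta = 0.5 * torch.ones(d_model)      # strength * decoder_vector

    def hook_fn(module, inputs, output):
        hidden = output[0] if isinstance(output, tuple) else output
        steered = hidden.detach().clone()
        # Steer the last prompt token and everything generated after it.
        steered[:, prompt_length - 1:, :] += delta.to(steered.dtype)
        if isinstance(output, tuple):
            return (steered,) + output[1:]
        return steered

    handle = layer.register_forward_hook(hook_fn)
    try:
        hidden_states = torch.zeros(1, 5, d_model)   # [batch, seq, d_model]
        out = layer(hidden_states)
    finally:
        handle.remove()

    assert torch.all(out[:, :prompt_length - 1, :] == 0)
    assert torch.all(out[:, prompt_length - 1:, :] == 0.5)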
@@ -356,17 +362,18 @@ async def _truncate_prompt(self, prompt: str, record: LatentRecord) -> str: first_activation_idx = all_activation_indices[all_activation_indices > 0] if first_activation_idx.numel() > 0: - truncation_point = first_activation_idx[0].item() + truncation_point = first_activation_idx[0].item() input_ids = self.tokenizer(prompt, return_tensors="pt").input_ids[0] - truncated_ids = input_ids[:truncation_point + 1] + truncated_ids = input_ids[: truncation_point + 1] return self.tokenizer.decode(truncated_ids, skip_special_tokens=True) return prompt - async def _tune_strength( - self, prompts: List[str], record: LatentRecord, - intervention_vector: torch.Tensor + self, + prompts: List[str], + record: LatentRecord, + intervention_vector: torch.Tensor, ) -> Tuple[float, float]: """ Performs a binary search to find intervention strength that matches target_kl. @@ -408,22 +415,26 @@ async def _tune_strength( best_strength = mid_strength # Return the best found strength and the corresponding KL - final_kl = await self._calculate_avg_kl(prompts, record, best_strength, intervention_vector) + final_kl = await self._calculate_avg_kl( + prompts, record, best_strength, intervention_vector + ) return best_strength, final_kl - async def _calculate_avg_kl( - self, prompts: List[str], record: LatentRecord, strength: float, - intervention_vector: torch.Tensor + self, + prompts: List[str], + record: LatentRecord, + strength: float, + intervention_vector: torch.Tensor, ) -> float: total_kl = 0.0 n = 0 for prompt in prompts: _, clean_logp = await self._generate_with_intervention( - prompt, record, 0.0, intervention_vector,True + prompt, record, 0.0, intervention_vector, True ) _, int_logp = await self._generate_with_intervention( - prompt, record, strength, intervention_vector,True + prompt, record, strength, intervention_vector, True ) p_clean = torch.exp(clean_logp) kl_div = F.kl_div( @@ -433,7 +444,6 @@ async def _calculate_avg_kl( n += 1 return total_kl / n if n > 0 else 0.0 - async def _generate_with_intervention( self, prompt: str, @@ -473,8 +483,9 @@ def hook_fn(module, inp, out): intervention_start_index = prompt_length - 1 if current_seq_len >= prompt_length: - new_hiddens[:, intervention_start_index:, :] += delta.to(original_dtype) - + new_hiddens[:, intervention_start_index:, :] += delta.to( + original_dtype + ) return ( (new_hiddens,) + out[1:] if isinstance(out, tuple) else new_hiddens @@ -484,7 +495,7 @@ def hook_fn(module, inp, out): try: with torch.no_grad(): - outputs =self.subject_model(input_ids, attention_mask=attention_mask) + outputs = self.subject_model(input_ids, attention_mask=attention_mask) next_token_logits = outputs.logits[0, -1, :] log_probs_next_token = ( F.log_softmax(next_token_logits, dim=-1) if get_logp_dist else None @@ -506,10 +517,9 @@ def hook_fn(module, inp, out): log_probs_next_token.cpu() if get_logp_dist else torch.empty(0) ) - async def _score_explanation(self, generated_text: str, explanation: str) -> float: """ - Computes log P(explanation | generated_text) using the paper's + Computes log P(explanation | generated_text) using the paper's prompt format. 
""" device = self._get_device() @@ -518,9 +528,9 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo prompt_template = ( "\n" f"{generated_text}\n" - "The above passage contains an amplified amount of \"" + 'The above passage contains an amplified amount of "' ) - explanation_suffix = f"{explanation}\"" + explanation_suffix = f'{explanation}"' # Tokenize the parts context_enc = self.tokenizer(prompt_template, return_tensors="pt") @@ -536,7 +546,7 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo # We only need to score the explanation part context_len = context_enc.input_ids.shape[1] - + # Get logits for positions that predict the explanation tokens # Shape: [batch_size, explanation_len, vocab_size] explanation_logits = logits[:, context_len - 1 : -1, :] @@ -548,14 +558,11 @@ async def _score_explanation(self, generated_text: str, explanation: str) -> flo log_probs = F.log_softmax(explanation_logits, dim=-1) # Gather the log-probabilities of the actual explanation tokens - token_log_probs = log_probs.gather( - 2, target_ids.unsqueeze(-1) - ).squeeze(-1) + token_log_probs = log_probs.gather(2, target_ids.unsqueeze(-1)).squeeze(-1) # Return the sum of log-probs for the explanation return token_log_probs.sum().item() - def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> Any: """ Retrieves the correct SAE model, handling the specific functools.partial @@ -567,16 +574,18 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An candidate = record.sae elif self.explainer_model and isinstance(self.explainer_model, dict): full_key = self._get_full_hookpoint_path(hookpoint_str) - short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp" + short_key = ".".join(hookpoint_str.split(".")[-2:]) # e.g., "layers.6.mlp" for key in [hookpoint_str, full_key, short_key]: if self.explainer_model.get(key) is not None: candidate = self.explainer_model.get(key) break - + if candidate is None: # This will raise an error if the key isn't found - raise ValueError(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model") + raise ValueError( + f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model" + ) if isinstance(candidate, functools.partial): # As shown in load_sparsify.py, the SAE is in the 'sae' keyword. @@ -584,14 +593,18 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An return candidate.keywords["sae"] # Unwrapped successfully else: # This will raise an error if the partial is missing the keyword - raise ValueError(f"""ERROR: Found a partial for {hookpoint_str} but could not + raise ValueError( + f"""ERROR: Found a partial for {hookpoint_str} but could not find the 'sae' keyword. func: {candidate.func} args: {candidate.args} - keywords: {candidate.keywords}""") - + keywords: {candidate.keywords}""" + ) + # This will raise an error if the candidate isn't a partial - raise ValueError(f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}") + raise ValueError( + f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. 
Type: {type(candidate)}" + ) def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From 101257ef7f8bdafcc3102c53ac3afbf3d0216315 Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Mon, 17 Nov 2025 22:27:23 +0000 Subject: [PATCH 18/25] Line-spacing fix --- .../intervention/surprisal_intervention_scorer.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 09363442..e33e98f0 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -177,7 +177,8 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) if feature_id >= d_latent: - print(f"DEBUG: ERROR - Feature ID {feature_id} is out of bounds for d_latent {d_latent}") + print(f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds + for d_latent {d_latent}""") return torch.zeros(1) one_hot_activation[0, 0, feature_id] = 1.0 @@ -576,7 +577,8 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An if candidate is None: # This will raise an error if the key isn't found - raise ValueError(f"ERROR: Surprisal scorer could not find an SAE for hookpoint '{hookpoint_str}' in self.explainer_model") + raise ValueError(f"ERROR: Surprisal scorer could not find an SAE " + f"for hookpoint '{hookpoint_str}' in self.explainer_model") if isinstance(candidate, functools.partial): # As shown in load_sparsify.py, the SAE is in the 'sae' keyword. @@ -584,14 +586,16 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An return candidate.keywords["sae"] # Unwrapped successfully else: # This will raise an error if the partial is missing the keyword - raise ValueError(f"""ERROR: Found a partial for {hookpoint_str} but could not + raise ValueError(f"""ERROR: Found a partial for + {hookpoint_str} but could not find the 'sae' keyword. func: {candidate.func} args: {candidate.args} keywords: {candidate.keywords}""") # This will raise an error if the candidate isn't a partial - raise ValueError(f"ERROR: Candidate for {hookpoint_str} was not a partial object, which was not expected. Type: {type(candidate)}") + raise ValueError(f"""ERROR: Candidate for {hookpoint_str} was not a partial + object, which was not expected. 
Type: {type(candidate)}""") def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From 8c65dbe90eced4524d77bed583bc286ddf420d85 Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Mon, 17 Nov 2025 22:37:08 +0000 Subject: [PATCH 19/25] Fix line-spacing issues --- delphi/scorers/intervention/surprisal_intervention_scorer.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index e33e98f0..c416c5c2 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -461,7 +461,7 @@ async def _generate_with_intervention( sae = self._get_sae_for_hookpoint(hookpoint_str, record) if not sae: raise ValueError( - f"Could not find a valid SAE for hookpoint {hookpoint_str}" + f"Couldn't find a valid SAE for hookpoint {hookpoint_str}" ) def hook_fn(module, inp, out): From cabb1511f5b75053a7878b9ffc27ac001dc16879 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 17 Nov 2025 22:51:29 +0000 Subject: [PATCH 20/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../surprisal_intervention_scorer.py | 21 ++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/delphi/scorers/intervention/surprisal_intervention_scorer.py b/delphi/scorers/intervention/surprisal_intervention_scorer.py index 4a72c940..f2ae6179 100644 --- a/delphi/scorers/intervention/surprisal_intervention_scorer.py +++ b/delphi/scorers/intervention/surprisal_intervention_scorer.py @@ -176,8 +176,10 @@ def _get_intervention_vector(self, sae: Any, feature_id: int) -> torch.Tensor: one_hot_activation = torch.zeros(1, 1, d_latent, device=sae_device) if feature_id >= d_latent: - print(f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds - for d_latent {d_latent}""") + print( + f"""DEBUG: ERROR - Feature ID {feature_id} is out of bounds + for d_latent {d_latent}""" + ) return torch.zeros(1) one_hot_activation[0, 0, feature_id] = 1.0 @@ -582,8 +584,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An if candidate is None: # This will raise an error if the key isn't found - raise ValueError(f"ERROR: Surprisal scorer could not find an SAE " - f"for hookpoint '{hookpoint_str}' in self.explainer_model") + raise ValueError( + f"ERROR: Surprisal scorer could not find an SAE " + f"for hookpoint '{hookpoint_str}' in self.explainer_model" + ) if isinstance(candidate, functools.partial): # As shown in load_sparsify.py, the SAE is in the 'sae' keyword. @@ -591,7 +595,8 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An return candidate.keywords["sae"] # Unwrapped successfully else: # This will raise an error if the partial is missing the keyword - raise ValueError(f"""ERROR: Found a partial for + raise ValueError( + f"""ERROR: Found a partial for {hookpoint_str} but could not find the 'sae' keyword. func: {candidate.func} @@ -600,8 +605,10 @@ def _get_sae_for_hookpoint(self, hookpoint_str: str, record: LatentRecord) -> An ) # This will raise an error if the candidate isn't a partial - raise ValueError(f"""ERROR: Candidate for {hookpoint_str} was not a partial - object, which was not expected. 
Type: {type(candidate)}""") + raise ValueError( + f"""ERROR: Candidate for {hookpoint_str} was not a partial + object, which was not expected. Type: {type(candidate)}""" + ) def _get_intervention_direction(self, record: LatentRecord) -> torch.Tensor: hookpoint_str = self.hookpoint_str or getattr(record, "hookpoint", None) From 1cb61400884f4931be3f61567e31a8ae8451d58c Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Fri, 21 Nov 2025 15:20:10 +0000 Subject: [PATCH 21/25] Create plots for intervention scoring --- delphi/__main__.py | 6 +- delphi/log/result_analysis.py | 1214 +++++++++++++++++++++++++-------- delphi/temp.py | 11 + 3 files changed, 937 insertions(+), 294 deletions(-) create mode 100644 delphi/temp.py diff --git a/delphi/__main__.py b/delphi/__main__.py index 0553ff5a..1e5aced7 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -491,7 +491,11 @@ async def run( del model, hookpoint_to_sparse_encode if run_cfg.verbose: - log_results(scores_path, visualize_path, run_cfg.hookpoints, run_cfg.scorers) + log_results(scores_path, + visualize_path, + run_cfg.hookpoints, + run_cfg.scorers, + model_name=run_cfg.model) if __name__ == "__main__": diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 4852e8d6..6c26757a 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -1,3 +1,673 @@ +# from pathlib import Path +# from typing import Optional + +# import orjson +# import pandas as pd +# import plotly.express as px +# import plotly.graph_objects as go +# import torch +# from sklearn.metrics import roc_auc_score, roc_curve + + +# def plot_firing_vs_f1( +# latent_df: pd.DataFrame, num_tokens: int, out_dir: Path, run_label: str +# ) -> None: +# out_dir.mkdir(parents=True, exist_ok=True) +# for module, module_df in latent_df.groupby("module"): + +# if "firing_count" not in module_df.columns: +# print( +# f"""WARNING: 'firing_count' column not found for module {module}. +# Skipping plot.""" +# ) +# continue + +# module_df = module_df.copy() +# # Filter out rows where f1_score is NaN to avoid errors in plotting +# module_df = module_df[module_df["f1_score"].notna()] +# if module_df.empty: +# continue + +# module_df["firing_rate"] = module_df["firing_count"] / num_tokens +# fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) +# fig.update_layout( +# xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0] +# ) +# fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") + + +# def import_plotly(): +# """Import plotly with mitigiation for MathJax bug.""" +# try: +# import plotly.express as px +# import plotly.io as pio +# except ImportError: +# raise ImportError( +# "Plotly is not installed.\n" +# "Please install it using `pip install plotly`, " +# "or install the `[visualize]` extra." 
+# ) +# pio.kaleido.scope.mathjax = None +# return px + + +# def compute_auc(df: pd.DataFrame) -> float | None: + +# valid_df = df[df.probability.notna()] +# if valid_df.probability.nunique() <= 1: +# return None +# return roc_auc_score(valid_df.activating, valid_df.probability) + + +# def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): +# out_dir.mkdir(exist_ok=True, parents=True) +# for label in df["score_type"].unique(): +# # Filter out surprisal_intervention as 'accuracy' is not relevant for it +# if label == "surprisal_intervention": +# continue +# fig = px.histogram( +# df[df["score_type"] == label], +# x="accuracy", +# nbins=100, +# title=f"Accuracy distribution: {label}", +# ) +# fig.write_image(out_dir / f"{label}_accuracy.pdf") + + +# def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): +# """ +# Plots statistics for the surprisal_intervention scorer: +# 1. A histogram of the KL Divergence scores. +# 2. A bar chart of 'Decoder-Live' vs 'Decoder-Dead' features. +# """ +# out_dir.mkdir(exist_ok=True, parents=True) + +# display_name = model_name.split("/")[-1] if "/" in model_name else model_name + +# # 1. KL Divergence Histogram +# # This shows the distribution of "causal impact" across all features +# fig_hist = px.histogram( +# df, +# x="avg_kl_divergence", +# nbins=50, +# title="Distribution of Intervention KL Divergence ({display_name})", +# labels={"avg_kl_divergence": "Average KL Divergence (Causal Effect)"}, +# log_y=True # Log scale helps visualize the 'long tail' if many are 0 +# ) +# fig_hist.update_layout(showlegend=False) +# fig_hist.write_image(out_dir / "intervention_kl_histogram.pdf") + +# # 2. Live vs Dead Bar Chart +# # We define "Live" as having a KL > 0.01 (non-zero effect) +# threshold = 0.01 +# df["status"] = df["avg_kl_divergence"].apply( +# lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead" +# ) + +# counts = df["status"].value_counts().reset_index() +# counts.columns = ["Status", "Count"] + +# # Calculate percentage for the title +# total = counts["Count"].sum() +# live_count = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() +# live_pct = (live_count / total) * 100 if total > 0 else 0 + +# fig_bar = px.bar( +# counts, +# x="Status", +# y="Count", +# color="Status", +# title=f"Causal Relevance: {live_pct:.1f}% Live Features", +# text="Count", +# color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} +# ) +# fig_bar.update_traces(textposition='auto') +# fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") + + +# def plot_roc_curve(df: pd.DataFrame, out_dir: Path): + +# valid_df = df[df.probability.notna()] +# if ( +# valid_df.empty +# or valid_df.activating.nunique() <= 1 +# or valid_df.probability.nunique() <= 1 +# ): +# return + +# fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) +# auc = roc_auc_score(valid_df.activating, valid_df.probability) +# fig = go.Figure( +# data=[ +# go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), +# go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), +# ] +# ) +# fig.update_layout( +# title="ROC Curve", +# xaxis_title="FPR", +# yaxis_title="TPR", +# ) +# out_dir.mkdir(exist_ok=True, parents=True) +# fig.write_image(out_dir / "roc_curve.pdf") + + +# def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: +# df_valid = df[df["prediction"].notna()] +# if df_valid.empty: +# return dict( +# true_positives=0, +# true_negatives=0, +# false_positives=0, +# false_negatives=0, +# 
total_examples=0, +# total_positives=0, +# total_negatives=0, +# failed_count=len(df), +# ) + +# act = df_valid["activating"].astype(bool) +# total = len(df_valid) +# pos = act.sum() +# neg = total - pos +# tp = ((df_valid.prediction >= threshold) & act).sum() +# tn = ((df_valid.prediction < threshold) & ~act).sum() +# fp = ((df_valid.prediction >= threshold) & ~act).sum() +# fn = ((df_valid.prediction < threshold) & act).sum() + +# return dict( +# true_positives=tp, +# true_negatives=tn, +# false_positives=fp, +# false_negatives=fn, +# total_examples=total, +# total_positives=pos, +# total_negatives=neg, +# failed_count=len(df) - len(df_valid), +# ) + + +# def compute_classification_metrics(conf: dict) -> dict: +# tp, tn, fp, fn = ( +# conf["true_positives"], +# conf["true_negatives"], +# conf["false_positives"], +# conf["false_negatives"], +# ) +# pos, neg = conf["total_positives"], conf["total_negatives"] + +# balanced_accuracy = ( +# (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) +# ) / 2 +# precision = tp / (tp + fp) if tp + fp > 0 else 0 +# recall = tp / pos if pos > 0 else 0 +# f1 = ( +# 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 +# ) + +# return dict( +# precision=precision, +# recall=recall, +# f1_score=f1, +# accuracy=balanced_accuracy, +# true_positive_rate=tp / pos if pos > 0 else 0, +# true_negative_rate=tn / neg if neg > 0 else 0, +# false_positive_rate=fp / neg if neg > 0 else 0, +# false_negative_rate=fn / pos if pos > 0 else 0, +# ) + + +# def load_data(scores_path: Path, modules: list[str]): +# """Load all on-disk data into a single DataFrame.""" + +# def parse_score_file(path: Path) -> pd.DataFrame: +# try: +# data = orjson.loads(path.read_bytes()) +# except orjson.JSONDecodeError: +# print(f"Error decoding JSON from {path}. Skipping file.") +# return pd.DataFrame() + +# if not isinstance(data, list): +# print( +# f"""Warning: Expected a list of results in {path}, +# but found {type(data)}. 
+# Skipping file.""" +# ) +# return pd.DataFrame() + +# latent_idx = int(path.stem.split("latent")[-1]) + +# # Updated to extract all possible keys safely using .get() +# return pd.DataFrame( +# [ +# { +# "text": "".join(ex.get("str_tokens", [])), +# "distance": ex.get("distance"), +# "activating": ex.get("activating"), +# "prediction": ex.get("prediction"), +# "probability": ex.get("probability"), +# "correct": ex.get("correct"), +# "activations": ex.get("activations"), +# "final_score": ex.get("final_score"), +# "avg_kl_divergence": ex.get("avg_kl_divergence"), +# "latent_idx": latent_idx, +# } +# for ex in data +# ] +# ) + +# counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" +# counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} +# if not all(module in counts for module in modules): +# print("Missing firing counts for some modules, setting counts to None.") +# print(f"Missing modules: {[m for m in modules if m not in counts]}") +# counts = None + +# latent_dfs = [] +# for score_type_dir in scores_path.iterdir(): +# if not score_type_dir.is_dir(): +# continue +# for module in modules: +# for file in score_type_dir.glob(f"*{module}*"): +# latent_df = parse_score_file(file) +# if latent_df.empty: +# continue +# latent_df["score_type"] = score_type_dir.name +# latent_df["module"] = module +# if counts: +# latent_idx = latent_df["latent_idx"].iloc[0] +# latent_df["firing_count"] = ( +# counts[module][latent_idx].item() +# if module in counts and latent_idx in counts[module] +# else None +# ) +# latent_dfs.append(latent_df) + +# if not latent_dfs: +# return pd.DataFrame(), counts + +# return pd.concat(latent_dfs, ignore_index=True), counts + + +# def get_agg_metrics( +# latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] +# ) -> pd.DataFrame: +# processed_rows = [] +# for score_type, group_df in latent_df.groupby("score_type"): +# # For surprisal_intervention, we don't compute classification metrics +# if score_type == "surprisal_intervention": +# continue + +# conf = compute_confusion(group_df) +# class_m = compute_classification_metrics(conf) +# auc = compute_auc(group_df) +# f1_w = frequency_weighted_f1(group_df, counts) if counts else None + +# row = { +# "score_type": score_type, +# **conf, +# **class_m, +# "auc": auc, +# "weighted_f1": f1_w, +# } +# processed_rows.append(row) + +# return pd.DataFrame(processed_rows) + + +# def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: +# f1s = ( +# latent_df.groupby(["module", "latent_idx"]) +# .apply( +# lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] +# ) +# .reset_index(name="f1_score") # <- naive (un-weighted) F1 +# ) +# return latent_df.merge(f1s, on=["module", "latent_idx"]) + + +# def log_results( +# scores_path: Path, +# viz_path: Path, +# modules: list[str], +# scorer_names: list[str], +# model_name: str = "Unknown Model" +# ): +# import_plotly() + +# latent_df, counts = load_data(scores_path, modules) +# if latent_df.empty: +# print("No data to analyze.") +# return + +# latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] + +# # Separate the dataframes for different processing +# classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] +# surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] + +# if not classification_df.empty: +# classification_df = add_latent_f1(classification_df) +# if counts: +# plot_firing_vs_f1( +# classification_df, +# num_tokens=10_000_000, +# 
out_dir=viz_path, +# run_label=scores_path.name, +# ) +# plot_roc_curve(classification_df, viz_path) +# processed_df = get_agg_metrics(classification_df, counts) +# plot_accuracy_hist(processed_df, viz_path) + +# if counts: +# dead = sum((counts[m] == 0).sum().item() for m in modules) +# print(f"Number of dead features: {dead}") + +# for score_type in latent_df["score_type"].unique(): + +# if score_type == "surprisal_intervention": +# # Drop duplicates since score is per-latent, not per-example +# unique_latents = surprisal_df.drop_duplicates( +# subset=["module", "latent_idx"] +# ).copy() + +# avg_score = unique_latents["final_score"].mean() +# avg_kl = unique_latents["avg_kl_divergence"].mean() + +# # We define "Decoder-Live" as having a KL > 0.01 (non-zero effect) +# threshold = 0.01 +# n_total = len(unique_latents) +# n_live = len(unique_latents[unique_latents["avg_kl_divergence"] > threshold]) +# live_pct = (n_live / n_total) * 100 if n_total > 0 else 0.0 + +# print(f"\n--- {score_type.title()} Metrics ---") +# print(f"Average Normalized Score: {avg_score:.3f}") +# print(f"Average KL Divergence: {avg_kl:.3f}") +# print(f"Decoder-Live Percentage: {live_pct:.2f}%") + + +# plot_intervention_stats(unique_latents, viz_path, model_name) + +# else: +# if not classification_df.empty: +# score_type_summary = processed_df[ +# processed_df.score_type == score_type +# ].iloc[0] +# print(f"\n--- {score_type.title()} Metrics ---") +# print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") +# print(f"F1 Score: {score_type_summary['f1_score']:.3f}") + +# if counts and score_type_summary["weighted_f1"] is not None: +# print( +# f"""Frequency-Weighted F1 Score: +# {score_type_summary['weighted_f1']:.3f}""" +# ) + +# print(f"Precision: {score_type_summary['precision']:.3f}") +# print(f"Recall: {score_type_summary['recall']:.3f}") + +# if score_type_summary["auc"] is not None: +# print(f"AUC: {score_type_summary['auc']:.3f}") +# else: +# print("AUC not available.") + +# from pathlib import Path +# from typing import Optional + +# import orjson +# import pandas as pd +# import plotly.express as px +# import plotly.graph_objects as go +# import torch +# from sklearn.metrics import roc_auc_score, roc_curve + + +# # --- PLOTTING HELPERS --- + +# def import_plotly(): +# """Import plotly with mitigation for MathJax bug.""" +# try: +# import plotly.express as px +# import plotly.io as pio +# except ImportError: +# raise ImportError( +# "Plotly is not installed. Please install it via `pip install plotly`." 
+# ) +# pio.kaleido.scope.mathjax = None +# return px + +# def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): +# out_dir.mkdir(parents=True, exist_ok=True) +# for module, module_df in latent_df.groupby("module"): +# if "firing_count" not in module_df.columns: +# continue +# module_df = module_df[module_df["f1_score"].notna()] +# if module_df.empty: continue + +# module_df["firing_rate"] = module_df["firing_count"] / num_tokens +# fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) +# fig.update_layout(xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0]) +# fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") + +# def plot_roc_curve(df, out_dir): +# valid_df = df[df.probability.notna()] +# if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: +# return + +# fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) +# auc = roc_auc_score(valid_df.activating, valid_df.probability) +# fig = go.Figure(data=[ +# go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), +# go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), +# ]) +# fig.update_layout(title="ROC Curve", xaxis_title="FPR", yaxis_title="TPR") +# out_dir.mkdir(exist_ok=True, parents=True) +# fig.write_image(out_dir / "roc_curve.pdf") + +# def plot_accuracy_hist(df, out_dir): +# out_dir.mkdir(exist_ok=True, parents=True) +# for label in df["score_type"].unique(): +# fig = px.histogram(df[df["score_type"] == label], x="accuracy", nbins=100, title=f"Accuracy: {label}") +# fig.write_image(out_dir / f"{label}_accuracy.pdf") + +# def plot_intervention_stats(df, out_dir, model_name): +# """Specific plots for Intervention scoring.""" +# out_dir.mkdir(exist_ok=True, parents=True) +# display_name = model_name.split("/")[-1] if "/" in model_name else model_name + +# # 1. KL Histogram +# fig_hist = px.histogram( +# df, x="avg_kl_divergence", nbins=50, log_y=True, +# title=f"KL Divergence ({display_name})", +# labels={"avg_kl_divergence": "Avg KL Divergence (Causal Effect)"} +# ) +# fig_hist.write_image(out_dir / "intervention_kl_dist.pdf") + +# # 2. 
Live/Dead Split +# threshold = 0.01 +# df["status"] = df["avg_kl_divergence"].apply(lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead") +# counts = df["status"].value_counts().reset_index() +# counts.columns = ["Status", "Count"] + +# total = counts["Count"].sum() +# live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0 +# pct = (live / total * 100) if total > 0 else 0 + +# fig_bar = px.bar( +# counts, x="Status", y="Count", color="Status", text="Count", +# title=f"Causal Relevance: {pct:.1f}% Live ({display_name})", +# color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} +# ) +# fig_bar.write_image(out_dir / "intervention_live_dead.pdf") + + +# # --- METRIC COMPUTATION --- + +# def compute_confusion(df, threshold=0.5): +# df_valid = df[df["prediction"].notna()] +# if df_valid.empty: return dict(tp=0, tn=0, fp=0, fn=0, pos=0, neg=0) + +# act = df_valid["activating"].astype(bool) +# pred = df_valid["prediction"] >= threshold + +# tp = (pred & act).sum() +# tn = (~pred & ~act).sum() +# fp = (pred & ~act).sum() +# fn = (~pred & act).sum() + +# return dict( +# true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, +# total_positives=act.sum(), total_negatives=(~act).sum() +# ) + +# def compute_classification_metrics(conf): +# tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] +# pos, neg = conf["total_positives"], conf["total_negatives"] + +# acc = ((tp / pos if pos else 0) + (tn / neg if neg else 0)) / 2 +# prec = tp / (tp + fp) if (tp + fp) else 0 +# rec = tp / pos if pos else 0 +# f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0 + +# return dict(accuracy=acc, precision=prec, recall=rec, f1_score=f1) + +# def compute_auc(df): +# valid = df[df.probability.notna()] +# if valid.probability.nunique() <= 1: return None +# return roc_auc_score(valid.activating, valid.probability) + +# def get_agg_metrics(df): +# rows = [] +# for scorer, group in df.groupby("score_type"): +# conf = compute_confusion(group) +# metrics = compute_classification_metrics(conf) +# rows.append({"score_type": scorer, **conf, **metrics, "auc": compute_auc(group)}) +# return pd.DataFrame(rows) + +# def add_latent_f1(df): +# # Calculate F1 per latent for plotting +# f1s = df.groupby(["module", "latent_idx"]).apply( +# lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] +# ).reset_index(name="f1_score") +# return df.merge(f1s, on=["module", "latent_idx"]) + + +# # --- DATA LOADING --- + +# def load_data(scores_path, modules): +# def parse_file(path): +# try: +# data = orjson.loads(path.read_bytes()) +# if not isinstance(data, list): return pd.DataFrame() +# latent_idx = int(path.stem.split("latent")[-1]) +# return pd.DataFrame([{ +# "text": "".join(ex.get("str_tokens", [])), +# "activating": ex.get("activating"), +# "prediction": ex.get("prediction"), +# "probability": ex.get("probability"), +# "final_score": ex.get("final_score"), +# "avg_kl_divergence": ex.get("avg_kl_divergence"), +# "latent_idx": latent_idx +# } for ex in data]) +# except Exception: +# return pd.DataFrame() + +# counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" +# counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} + +# dfs = [] +# for scorer_dir in scores_path.iterdir(): +# if not scorer_dir.is_dir(): continue +# for module in modules: +# for f in scorer_dir.glob(f"*{module}*"): +# df = 
parse_file(f) +# if df.empty: continue +# df["score_type"] = scorer_dir.name +# df["module"] = module +# if module in counts: +# idx = df["latent_idx"].iloc[0] +# if idx < len(counts[module]): +# df["firing_count"] = counts[module][idx].item() +# dfs.append(df) + +# return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts + + +# # --- MAIN HANDLERS --- + +# def handle_classification_results(df, counts, viz_path, run_label): +# """Handles Fuzz, Detection, Simulation.""" +# print(f"\n--- Classification Analysis ({len(df)} examples) ---") + +# # Add per-latent F1 for plotting +# df = add_latent_f1(df) + +# # Plots +# if counts: +# plot_firing_vs_f1(df, 10_000_000, viz_path, run_label) +# plot_roc_curve(df, viz_path) + +# # Aggregated Metrics (Accuracy, F1, etc.) +# agg_df = get_agg_metrics(df) +# plot_accuracy_hist(agg_df, viz_path) + +# # Console Output +# for _, row in agg_df.iterrows(): +# print(f"\n[ {row['score_type'].title()} ]") +# print(f"Accuracy: {row['accuracy']:.3f}") +# print(f"F1 Score: {row['f1_score']:.3f}") +# print(f"Precision: {row['precision']:.3f}") +# print(f"Recall: {row['recall']:.3f}") +# if row['auc']: print(f"AUC: {row['auc']:.3f}") + + +# def handle_intervention_results(df, viz_path, model_name): +# """Handles Surprisal Intervention.""" +# # Deduplicate: we only need one row per latent per module +# unique_latents = df.drop_duplicates(subset=["module", "latent_idx"]).copy() + +# avg_score = unique_latents["final_score"].mean() +# avg_kl = unique_latents["avg_kl_divergence"].mean() + +# # Calculate Decoder-Live % +# total = len(unique_latents) +# live = len(unique_latents[unique_latents["avg_kl_divergence"] > 0.01]) +# pct = (live / total * 100) if total > 0 else 0 + +# print(f"\n--- Surprisal Intervention Analysis ({total} latents) ---") +# print(f"Avg Normalized Score: {avg_score:.3f}") +# print(f"Avg KL Divergence: {avg_kl:.3f}") +# print(f"Decoder-Live %: {pct:.2f}%") + +# plot_intervention_stats(unique_latents, viz_path, model_name) + + +# # --- ENTRY POINT --- + +# def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"): +# import_plotly() + +# # 1. Load ALL data (Global Reporting) +# latent_df, counts = load_data(scores_path, modules) +# if latent_df.empty: +# print("No data found to analyze.") +# return + +# print(f"Generating report for scorers found: {latent_df['score_type'].unique()}") + +# # 2. Split Data +# classification_mask = latent_df["score_type"] != "surprisal_intervention" +# classification_df = latent_df[classification_mask] +# intervention_df = latent_df[~classification_mask] + +# # 3. Dispatch to Handlers +# if not classification_df.empty: +# handle_classification_results(classification_df, counts, viz_path, scores_path.name) + +# if not intervention_df.empty: +# handle_intervention_results(intervention_df, viz_path, model_name) + + from pathlib import Path from typing import Optional @@ -9,338 +679,296 @@ from sklearn.metrics import roc_auc_score, roc_curve -def plot_firing_vs_f1( - latent_df: pd.DataFrame, num_tokens: int, out_dir: Path, run_label: str -) -> None: - out_dir.mkdir(parents=True, exist_ok=True) - for module, module_df in latent_df.groupby("module"): +# --- 1. NEW PLOTTING FUNCTIONS --- + +def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: str): + """ + Replicates the Scatter Plot from the paper (Figure 3/Appendix G). + Plots Fuzz Score vs. Intervention Score for the same latents. 
+ """ + # We need to merge the rows for 'fuzz' and 'surprisal_intervention' + # 1. Pivot the table so we have columns: 'latent_idx', 'fuzz_score', 'intervention_score' + + # Extract Fuzz Scores (using F1 or Accuracy as the metric) + fuzz_df = latent_df[latent_df["score_type"] == "fuzz"].copy() + if fuzz_df.empty: return + + # Calculate per-latent F1 for fuzzing + fuzz_metrics = fuzz_df.groupby(["module", "latent_idx"]).apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ).reset_index(name="fuzz_score") + + # Extract Intervention Scores + int_df = latent_df[latent_df["score_type"] == "surprisal_intervention"].copy() + if int_df.empty: return + + # Deduplicate intervention scores + int_metrics = int_df.drop_duplicates(subset=["module", "latent_idx"])[ + ["module", "latent_idx", "avg_kl_divergence", "final_score"] + ] + + # Merge them + merged = pd.merge(fuzz_metrics, int_metrics, on=["module", "latent_idx"]) + + if merged.empty: + print("Could not merge Fuzz and Intervention scores (no matching latents).") + return + + # Plot 1: KL vs Fuzz (Causal Impact vs Correlational Quality) + fig_kl = px.scatter( + merged, + x="fuzz_score", + y="avg_kl_divergence", + hover_data=["latent_idx"], + title=f"Correlation vs. Causation (KL) - {run_label}", + labels={"fuzz_score": "Fuzzing Score (Correlation)", "avg_kl_divergence": "Intervention KL (Causation)"}, + trendline="ols" # Adds a regression line to show the negative/zero correlation + ) + fig_kl.write_image(out_dir / "scatter_fuzz_vs_kl.pdf") + + # Plot 2: Score vs Fuzz (Original Paper Metric) + fig_score = px.scatter( + merged, + x="fuzz_score", + y="final_score", + hover_data=["latent_idx"], + title=f"Correlation vs. Causation (Score) - {run_label}", + labels={"fuzz_score": "Fuzzing Score (Correlation)", "final_score": "Intervention Score (Surprisal)"}, + trendline="ols" + ) + fig_score.write_image(out_dir / "scatter_fuzz_vs_score.pdf") + print("Generated Fuzz vs. Intervention scatter plots.") + + +def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): + """ + Improved histograms. Plots two versions: + 1. All Features (Log Scale) - to show the dead features. + 2. Live Features Only - to show the distribution of the ones that work. + """ + out_dir.mkdir(exist_ok=True, parents=True) + display_name = model_name.split("/")[-1] if "/" in model_name else model_name + + # 1. Live/Dead Split Bar Chart + threshold = 0.01 + df["status"] = df["avg_kl_divergence"].apply( + lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead" + ) + counts = df["status"].value_counts().reset_index() + counts.columns = ["Status", "Count"] + + # Get percentage + total = counts["Count"].sum() + live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0 + pct = (live / total * 100) if total > 0 else 0 + + fig_bar = px.bar( + counts, x="Status", y="Count", color="Status", text="Count", + title=f"Causal Relevance: {pct:.1f}% Live ({display_name})", + color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} + ) + fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") + + # 2. 
"Live Features Only" Histogram (The "Pretty" one) + live_df = df[df["avg_kl_divergence"] > threshold] + if not live_df.empty: + fig_live = px.histogram( + live_df, + x="avg_kl_divergence", + nbins=20, + title=f"Distribution of LIVE Features Only ({display_name})", + labels={"avg_kl_divergence": "KL Divergence (Causal Effect)"} + ) + fig_live.update_layout(showlegend=False) + fig_live.write_image(out_dir / "intervention_kl_dist_LIVE_ONLY.pdf") + + # 3. All Features Histogram (Log Scale) + fig_all = px.histogram( + df, + x="avg_kl_divergence", + nbins=50, + title=f"Distribution of All Features ({display_name})", + labels={"avg_kl_divergence": "KL Divergence"}, + log_y=True # Log scale to handle the massive spike at 0 + ) + fig_all.write_image(out_dir / "intervention_kl_dist_log_scale.pdf") - if "firing_count" not in module_df.columns: - print( - f"""WARNING: 'firing_count' column not found for module {module}. - Skipping plot.""" - ) - continue - module_df = module_df.copy() - # Filter out rows where f1_score is NaN to avoid errors in plotting +# --- 2. STANDARD PLOTTING HELPERS --- + +def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): + out_dir.mkdir(parents=True, exist_ok=True) + for module, module_df in latent_df.groupby("module"): + if "firing_count" not in module_df.columns: continue module_df = module_df[module_df["f1_score"].notna()] - if module_df.empty: - continue + if module_df.empty: continue module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) - fig.update_layout( - xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0] - ) + fig.update_layout(xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0]) fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") - def import_plotly(): - """Import plotly with mitigiation for MathJax bug.""" try: import plotly.express as px import plotly.io as pio except ImportError: - raise ImportError( - "Plotly is not installed.\n" - "Please install it using `pip install plotly`, " - "or install the `[visualize]` extra." 
- ) + raise ImportError("Install plotly: pip install plotly") pio.kaleido.scope.mathjax = None return px - -def compute_auc(df: pd.DataFrame) -> float | None: - - valid_df = df[df.probability.notna()] - if valid_df.probability.nunique() <= 1: - return None - return roc_auc_score(valid_df.activating, valid_df.probability) - - -def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): +def plot_accuracy_hist(df, out_dir): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): - # Filter out surprisal_intervention as 'accuracy' is not relevant for it - if label == "surprisal_intervention": - continue - fig = px.histogram( - df[df["score_type"] == label], - x="accuracy", - nbins=100, - title=f"Accuracy distribution: {label}", - ) + if label == "surprisal_intervention": continue + fig = px.histogram(df[df["score_type"] == label], x="accuracy", nbins=100, title=f"Accuracy: {label}") fig.write_image(out_dir / f"{label}_accuracy.pdf") - -def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - +def plot_roc_curve(df, out_dir): valid_df = df[df.probability.notna()] - if ( - valid_df.empty - or valid_df.activating.nunique() <= 1 - or valid_df.probability.nunique() <= 1 - ): - return - + if valid_df.empty or valid_df.activating.nunique() <= 1: return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) - fig = go.Figure( - data=[ - go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), - go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), - ] - ) - fig.update_layout( - title="ROC Curve", - xaxis_title="FPR", - yaxis_title="TPR", - ) + fig = go.Figure(data=[ + go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), + go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), + ]) + fig.update_layout(title="ROC Curve", xaxis_title="FPR", yaxis_title="TPR") out_dir.mkdir(exist_ok=True, parents=True) fig.write_image(out_dir / "roc_curve.pdf") -def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: - df_valid = df[df["prediction"].notna()] - if df_valid.empty: - return dict( - true_positives=0, - true_negatives=0, - false_positives=0, - false_negatives=0, - total_examples=0, - total_positives=0, - total_negatives=0, - failed_count=len(df), - ) +# --- 3. 
METRIC COMPUTATION --- +def compute_confusion(df, threshold=0.5): + df_valid = df[df["prediction"].notna()] + if df_valid.empty: return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, total_positives=0, total_negatives=0) act = df_valid["activating"].astype(bool) - total = len(df_valid) - pos = act.sum() - neg = total - pos - tp = ((df_valid.prediction >= threshold) & act).sum() - tn = ((df_valid.prediction < threshold) & ~act).sum() - fp = ((df_valid.prediction >= threshold) & ~act).sum() - fn = ((df_valid.prediction < threshold) & act).sum() - - return dict( - true_positives=tp, - true_negatives=tn, - false_positives=fp, - false_negatives=fn, - total_examples=total, - total_positives=pos, - total_negatives=neg, - failed_count=len(df) - len(df_valid), - ) - + pred = df_valid["prediction"] >= threshold + tp, tn = (pred & act).sum(), (~pred & ~act).sum() + fp, fn = (pred & ~act).sum(), (~pred & act).sum() + return dict(true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, total_positives=act.sum(), total_negatives=(~act).sum()) -def compute_classification_metrics(conf: dict) -> dict: - tp, tn, fp, fn = ( - conf["true_positives"], - conf["true_negatives"], - conf["false_positives"], - conf["false_negatives"], - ) +def compute_classification_metrics(conf): + tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] pos, neg = conf["total_positives"], conf["total_negatives"] - - balanced_accuracy = ( - (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) - ) / 2 - precision = tp / (tp + fp) if tp + fp > 0 else 0 - recall = tp / pos if pos > 0 else 0 - f1 = ( - 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 - ) - - return dict( - precision=precision, - recall=recall, - f1_score=f1, - accuracy=balanced_accuracy, - true_positive_rate=tp / pos if pos > 0 else 0, - true_negative_rate=tn / neg if neg > 0 else 0, - false_positive_rate=fp / neg if neg > 0 else 0, - false_negative_rate=fn / pos if pos > 0 else 0, - ) - - -def load_data(scores_path: Path, modules: list[str]): - """Load all on-disk data into a single DataFrame.""" - - def parse_score_file(path: Path) -> pd.DataFrame: + acc = ((tp/pos if pos else 0) + (tn/neg if neg else 0)) / 2 + prec = tp/(tp+fp) if (tp+fp) else 0 + rec = tp/pos if pos else 0 + f1 = 2*(prec*rec)/(prec+rec) if (prec+rec) else 0 + return dict(accuracy=acc, precision=prec, recall=rec, f1_score=f1) + +def compute_auc(df): + valid = df[df.probability.notna()] + if valid.probability.nunique() <= 1: return None + return roc_auc_score(valid.activating, valid.probability) + +def get_agg_metrics(df): + rows = [] + for scorer, group in df.groupby("score_type"): + if scorer == "surprisal_intervention": continue + conf = compute_confusion(group) + rows.append({"score_type": scorer, **conf, **compute_classification_metrics(conf), "auc": compute_auc(group)}) + return pd.DataFrame(rows) + +def add_latent_f1(df): + f1s = df.groupby(["module", "latent_idx"]).apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ).reset_index(name="f1_score") + return df.merge(f1s, on=["module", "latent_idx"]) + + +# --- 4. DATA LOADING --- + +def load_data(scores_path, modules): + def parse_file(path): try: data = orjson.loads(path.read_bytes()) - except orjson.JSONDecodeError: - print(f"Error decoding JSON from {path}. 
Skipping file.") - return pd.DataFrame() - - if not isinstance(data, list): - print( - f"""Warning: Expected a list of results in {path}, - but found {type(data)}. - Skipping file.""" - ) - return pd.DataFrame() - - latent_idx = int(path.stem.split("latent")[-1]) - - # Updated to extract all possible keys safely using .get() - return pd.DataFrame( - [ - { - "text": "".join(ex.get("str_tokens", [])), - "distance": ex.get("distance"), - "activating": ex.get("activating"), - "prediction": ex.get("prediction"), - "probability": ex.get("probability"), - "correct": ex.get("correct"), - "activations": ex.get("activations"), - "final_score": ex.get("final_score"), - "avg_kl_divergence": ex.get("avg_kl_divergence"), - "latent_idx": latent_idx, - } - for ex in data - ] - ) + if not isinstance(data, list): return pd.DataFrame() + latent_idx = int(path.stem.split("latent")[-1]) + return pd.DataFrame([{ + "text": "".join(ex.get("str_tokens", [])), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), + "latent_idx": latent_idx + } for ex in data]) + except Exception: return pd.DataFrame() counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} - if not all(module in counts for module in modules): - print("Missing firing counts for some modules, setting counts to None.") - print(f"Missing modules: {[m for m in modules if m not in counts]}") - counts = None - - latent_dfs = [] - for score_type_dir in scores_path.iterdir(): - if not score_type_dir.is_dir(): - continue + + dfs = [] + for scorer_dir in scores_path.iterdir(): + if not scorer_dir.is_dir(): continue for module in modules: - for file in score_type_dir.glob(f"*{module}*"): - latent_df = parse_score_file(file) - if latent_df.empty: - continue - latent_df["score_type"] = score_type_dir.name - latent_df["module"] = module - if counts: - latent_idx = latent_df["latent_idx"].iloc[0] - latent_df["firing_count"] = ( - counts[module][latent_idx].item() - if module in counts and latent_idx in counts[module] - else None - ) - latent_dfs.append(latent_df) - - if not latent_dfs: - return pd.DataFrame(), counts - - return pd.concat(latent_dfs, ignore_index=True), counts - - -def get_agg_metrics( - latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] -) -> pd.DataFrame: - processed_rows = [] - for score_type, group_df in latent_df.groupby("score_type"): - # For surprisal_intervention, we don't compute classification metrics - if score_type == "surprisal_intervention": - continue - - conf = compute_confusion(group_df) - class_m = compute_classification_metrics(conf) - auc = compute_auc(group_df) - f1_w = frequency_weighted_f1(group_df, counts) if counts else None - - row = { - "score_type": score_type, - **conf, - **class_m, - "auc": auc, - "weighted_f1": f1_w, - } - processed_rows.append(row) - - return pd.DataFrame(processed_rows) - - -def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: - f1s = ( - latent_df.groupby(["module", "latent_idx"]) - .apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ) - .reset_index(name="f1_score") # <- naive (un-weighted) F1 - ) - return latent_df.merge(f1s, on=["module", "latent_idx"]) - - -def log_results( - scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str] -): + for f in 
scorer_dir.glob(f"*{module}*"): + df = parse_file(f) + if df.empty: continue + df["score_type"] = scorer_dir.name + df["module"] = module + if module in counts: + idx = df["latent_idx"].iloc[0] + if idx < len(counts[module]): + df["firing_count"] = counts[module][idx].item() + dfs.append(df) + + return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts + + +# --- 5. MAIN LOGIC --- + +def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"): import_plotly() - + latent_df, counts = load_data(scores_path, modules) if latent_df.empty: - print("No data to analyze.") + print("No data found.") return - latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - - # Separate the dataframes for different processing - classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] - surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] - - if not classification_df.empty: - classification_df = add_latent_f1(classification_df) - if counts: - plot_firing_vs_f1( - classification_df, - num_tokens=10_000_000, - out_dir=viz_path, - run_label=scores_path.name, - ) - plot_roc_curve(classification_df, viz_path) - processed_df = get_agg_metrics(classification_df, counts) - plot_accuracy_hist(processed_df, viz_path) - - if counts: - dead = sum((counts[m] == 0).sum().item() for m in modules) - print(f"Number of dead features: {dead}") - - for score_type in latent_df["score_type"].unique(): - - if score_type == "surprisal_intervention": - # Drop duplicates since score is per-latent, not per-example - unique_latents = surprisal_df.drop_duplicates( - subset=["module", "latent_idx"] - ) - avg_score = unique_latents["final_score"].mean() - avg_kl = unique_latents["avg_kl_divergence"].mean() - - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Average Normalized Score: {avg_score:.3f}") - print(f"Average KL Divergence: {avg_kl:.3f}") - - else: - if not classification_df.empty: - score_type_summary = processed_df[ - processed_df.score_type == score_type - ].iloc[0] - print(f"\n--- {score_type.title()} Metrics ---") - print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") - print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - - if counts and score_type_summary["weighted_f1"] is not None: - print( - f"""Frequency-Weighted F1 Score: - {score_type_summary['weighted_f1']:.3f}""" - ) - - print(f"Precision: {score_type_summary['precision']:.3f}") - print(f"Recall: {score_type_summary['recall']:.3f}") - - if score_type_summary["auc"] is not None: - print(f"AUC: {score_type_summary['auc']:.3f}") - else: - print("AUC not available.") + print(f"Generating report for: {latent_df['score_type'].unique()}") + + # Split Data + class_mask = latent_df["score_type"] != "surprisal_intervention" + class_df = latent_df[class_mask] + int_df = latent_df[~class_mask] + + # 1. Handle Classification (Fuzz/Detection) + if not class_df.empty: + class_df = add_latent_f1(class_df) + if counts: plot_firing_vs_f1(class_df, 10_000_000, viz_path, scores_path.name) + plot_roc_curve(class_df, viz_path) + + agg_df = get_agg_metrics(class_df) + plot_accuracy_hist(agg_df, viz_path) + + for _, row in agg_df.iterrows(): + print(f"\n[ {row['score_type'].title()} ]") + print(f"Accuracy: {row['accuracy']:.3f}") + print(f"F1 Score: {row['f1_score']:.3f}") + + # 2. 
Handle Intervention + if not int_df.empty: + unique_latents = int_df.drop_duplicates(subset=["module", "latent_idx"]).copy() + + avg_score = unique_latents["final_score"].mean() + avg_kl = unique_latents["avg_kl_divergence"].mean() + + threshold = 0.01 + n_total = len(unique_latents) + n_live = len(unique_latents[unique_latents["avg_kl_divergence"] > threshold]) + pct = (n_live / n_total * 100) if n_total > 0 else 0 + + print(f"\n--- Surprisal Intervention Analysis ---") + print(f"Avg Normalized Score: {avg_score:.3f}") + print(f"Avg KL Divergence: {avg_kl:.3f}") + print(f"Decoder-Live %: {pct:.2f}%") + + plot_intervention_stats(unique_latents, viz_path, model_name) + + # 3. Generate Scatter Plot (Fuzz vs. Intervention) + # Only works if we have BOTH types of data + if not class_df.empty and not int_df.empty: + plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name) \ No newline at end of file diff --git a/delphi/temp.py b/delphi/temp.py new file mode 100644 index 00000000..4572510c --- /dev/null +++ b/delphi/temp.py @@ -0,0 +1,11 @@ +# Create a file named run_analysis.py with these contents +from delphi.log.result_analysis import log_results +from pathlib import Path + +# Adjust the path to your results folder +scores_path = Path("results/pythia_100_test/scores") +viz_path = Path("results/pythia_100_test/visualize") +modules = ["layers.6.mlp"] +scorer_names = ["fuzz", "detection", "surprisal_intervention"] + +log_results(scores_path, viz_path, modules, scorer_names, model_name="EleutherAI/pythia-160m") \ No newline at end of file From c0ef075aea76bbc973a64fef92c60f85ba41182d Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Fri, 21 Nov 2025 15:24:05 +0000 Subject: [PATCH 22/25] intervention scoring plots --- delphi/log/result_analysis.py | 670 ---------------------------------- 1 file changed, 670 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 6c26757a..ef2b16ca 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -1,673 +1,3 @@ -# from pathlib import Path -# from typing import Optional - -# import orjson -# import pandas as pd -# import plotly.express as px -# import plotly.graph_objects as go -# import torch -# from sklearn.metrics import roc_auc_score, roc_curve - - -# def plot_firing_vs_f1( -# latent_df: pd.DataFrame, num_tokens: int, out_dir: Path, run_label: str -# ) -> None: -# out_dir.mkdir(parents=True, exist_ok=True) -# for module, module_df in latent_df.groupby("module"): - -# if "firing_count" not in module_df.columns: -# print( -# f"""WARNING: 'firing_count' column not found for module {module}. -# Skipping plot.""" -# ) -# continue - -# module_df = module_df.copy() -# # Filter out rows where f1_score is NaN to avoid errors in plotting -# module_df = module_df[module_df["f1_score"].notna()] -# if module_df.empty: -# continue - -# module_df["firing_rate"] = module_df["firing_count"] / num_tokens -# fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) -# fig.update_layout( -# xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0] -# ) -# fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") - - -# def import_plotly(): -# """Import plotly with mitigiation for MathJax bug.""" -# try: -# import plotly.express as px -# import plotly.io as pio -# except ImportError: -# raise ImportError( -# "Plotly is not installed.\n" -# "Please install it using `pip install plotly`, " -# "or install the `[visualize]` extra." 
-# ) -# pio.kaleido.scope.mathjax = None -# return px - - -# def compute_auc(df: pd.DataFrame) -> float | None: - -# valid_df = df[df.probability.notna()] -# if valid_df.probability.nunique() <= 1: -# return None -# return roc_auc_score(valid_df.activating, valid_df.probability) - - -# def plot_accuracy_hist(df: pd.DataFrame, out_dir: Path): -# out_dir.mkdir(exist_ok=True, parents=True) -# for label in df["score_type"].unique(): -# # Filter out surprisal_intervention as 'accuracy' is not relevant for it -# if label == "surprisal_intervention": -# continue -# fig = px.histogram( -# df[df["score_type"] == label], -# x="accuracy", -# nbins=100, -# title=f"Accuracy distribution: {label}", -# ) -# fig.write_image(out_dir / f"{label}_accuracy.pdf") - - -# def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): -# """ -# Plots statistics for the surprisal_intervention scorer: -# 1. A histogram of the KL Divergence scores. -# 2. A bar chart of 'Decoder-Live' vs 'Decoder-Dead' features. -# """ -# out_dir.mkdir(exist_ok=True, parents=True) - -# display_name = model_name.split("/")[-1] if "/" in model_name else model_name - -# # 1. KL Divergence Histogram -# # This shows the distribution of "causal impact" across all features -# fig_hist = px.histogram( -# df, -# x="avg_kl_divergence", -# nbins=50, -# title="Distribution of Intervention KL Divergence ({display_name})", -# labels={"avg_kl_divergence": "Average KL Divergence (Causal Effect)"}, -# log_y=True # Log scale helps visualize the 'long tail' if many are 0 -# ) -# fig_hist.update_layout(showlegend=False) -# fig_hist.write_image(out_dir / "intervention_kl_histogram.pdf") - -# # 2. Live vs Dead Bar Chart -# # We define "Live" as having a KL > 0.01 (non-zero effect) -# threshold = 0.01 -# df["status"] = df["avg_kl_divergence"].apply( -# lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead" -# ) - -# counts = df["status"].value_counts().reset_index() -# counts.columns = ["Status", "Count"] - -# # Calculate percentage for the title -# total = counts["Count"].sum() -# live_count = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() -# live_pct = (live_count / total) * 100 if total > 0 else 0 - -# fig_bar = px.bar( -# counts, -# x="Status", -# y="Count", -# color="Status", -# title=f"Causal Relevance: {live_pct:.1f}% Live Features", -# text="Count", -# color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} -# ) -# fig_bar.update_traces(textposition='auto') -# fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") - - -# def plot_roc_curve(df: pd.DataFrame, out_dir: Path): - -# valid_df = df[df.probability.notna()] -# if ( -# valid_df.empty -# or valid_df.activating.nunique() <= 1 -# or valid_df.probability.nunique() <= 1 -# ): -# return - -# fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) -# auc = roc_auc_score(valid_df.activating, valid_df.probability) -# fig = go.Figure( -# data=[ -# go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), -# go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), -# ] -# ) -# fig.update_layout( -# title="ROC Curve", -# xaxis_title="FPR", -# yaxis_title="TPR", -# ) -# out_dir.mkdir(exist_ok=True, parents=True) -# fig.write_image(out_dir / "roc_curve.pdf") - - -# def compute_confusion(df: pd.DataFrame, threshold: float = 0.5) -> dict: -# df_valid = df[df["prediction"].notna()] -# if df_valid.empty: -# return dict( -# true_positives=0, -# true_negatives=0, -# false_positives=0, -# false_negatives=0, -# 
total_examples=0, -# total_positives=0, -# total_negatives=0, -# failed_count=len(df), -# ) - -# act = df_valid["activating"].astype(bool) -# total = len(df_valid) -# pos = act.sum() -# neg = total - pos -# tp = ((df_valid.prediction >= threshold) & act).sum() -# tn = ((df_valid.prediction < threshold) & ~act).sum() -# fp = ((df_valid.prediction >= threshold) & ~act).sum() -# fn = ((df_valid.prediction < threshold) & act).sum() - -# return dict( -# true_positives=tp, -# true_negatives=tn, -# false_positives=fp, -# false_negatives=fn, -# total_examples=total, -# total_positives=pos, -# total_negatives=neg, -# failed_count=len(df) - len(df_valid), -# ) - - -# def compute_classification_metrics(conf: dict) -> dict: -# tp, tn, fp, fn = ( -# conf["true_positives"], -# conf["true_negatives"], -# conf["false_positives"], -# conf["false_negatives"], -# ) -# pos, neg = conf["total_positives"], conf["total_negatives"] - -# balanced_accuracy = ( -# (tp / pos if pos > 0 else 0) + (tn / neg if neg > 0 else 0) -# ) / 2 -# precision = tp / (tp + fp) if tp + fp > 0 else 0 -# recall = tp / pos if pos > 0 else 0 -# f1 = ( -# 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0 -# ) - -# return dict( -# precision=precision, -# recall=recall, -# f1_score=f1, -# accuracy=balanced_accuracy, -# true_positive_rate=tp / pos if pos > 0 else 0, -# true_negative_rate=tn / neg if neg > 0 else 0, -# false_positive_rate=fp / neg if neg > 0 else 0, -# false_negative_rate=fn / pos if pos > 0 else 0, -# ) - - -# def load_data(scores_path: Path, modules: list[str]): -# """Load all on-disk data into a single DataFrame.""" - -# def parse_score_file(path: Path) -> pd.DataFrame: -# try: -# data = orjson.loads(path.read_bytes()) -# except orjson.JSONDecodeError: -# print(f"Error decoding JSON from {path}. Skipping file.") -# return pd.DataFrame() - -# if not isinstance(data, list): -# print( -# f"""Warning: Expected a list of results in {path}, -# but found {type(data)}. 
-# Skipping file.""" -# ) -# return pd.DataFrame() - -# latent_idx = int(path.stem.split("latent")[-1]) - -# # Updated to extract all possible keys safely using .get() -# return pd.DataFrame( -# [ -# { -# "text": "".join(ex.get("str_tokens", [])), -# "distance": ex.get("distance"), -# "activating": ex.get("activating"), -# "prediction": ex.get("prediction"), -# "probability": ex.get("probability"), -# "correct": ex.get("correct"), -# "activations": ex.get("activations"), -# "final_score": ex.get("final_score"), -# "avg_kl_divergence": ex.get("avg_kl_divergence"), -# "latent_idx": latent_idx, -# } -# for ex in data -# ] -# ) - -# counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" -# counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} -# if not all(module in counts for module in modules): -# print("Missing firing counts for some modules, setting counts to None.") -# print(f"Missing modules: {[m for m in modules if m not in counts]}") -# counts = None - -# latent_dfs = [] -# for score_type_dir in scores_path.iterdir(): -# if not score_type_dir.is_dir(): -# continue -# for module in modules: -# for file in score_type_dir.glob(f"*{module}*"): -# latent_df = parse_score_file(file) -# if latent_df.empty: -# continue -# latent_df["score_type"] = score_type_dir.name -# latent_df["module"] = module -# if counts: -# latent_idx = latent_df["latent_idx"].iloc[0] -# latent_df["firing_count"] = ( -# counts[module][latent_idx].item() -# if module in counts and latent_idx in counts[module] -# else None -# ) -# latent_dfs.append(latent_df) - -# if not latent_dfs: -# return pd.DataFrame(), counts - -# return pd.concat(latent_dfs, ignore_index=True), counts - - -# def get_agg_metrics( -# latent_df: pd.DataFrame, counts: Optional[dict[str, torch.Tensor]] -# ) -> pd.DataFrame: -# processed_rows = [] -# for score_type, group_df in latent_df.groupby("score_type"): -# # For surprisal_intervention, we don't compute classification metrics -# if score_type == "surprisal_intervention": -# continue - -# conf = compute_confusion(group_df) -# class_m = compute_classification_metrics(conf) -# auc = compute_auc(group_df) -# f1_w = frequency_weighted_f1(group_df, counts) if counts else None - -# row = { -# "score_type": score_type, -# **conf, -# **class_m, -# "auc": auc, -# "weighted_f1": f1_w, -# } -# processed_rows.append(row) - -# return pd.DataFrame(processed_rows) - - -# def add_latent_f1(latent_df: pd.DataFrame) -> pd.DataFrame: -# f1s = ( -# latent_df.groupby(["module", "latent_idx"]) -# .apply( -# lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] -# ) -# .reset_index(name="f1_score") # <- naive (un-weighted) F1 -# ) -# return latent_df.merge(f1s, on=["module", "latent_idx"]) - - -# def log_results( -# scores_path: Path, -# viz_path: Path, -# modules: list[str], -# scorer_names: list[str], -# model_name: str = "Unknown Model" -# ): -# import_plotly() - -# latent_df, counts = load_data(scores_path, modules) -# if latent_df.empty: -# print("No data to analyze.") -# return - -# latent_df = latent_df[latent_df["score_type"].isin(scorer_names)] - -# # Separate the dataframes for different processing -# classification_df = latent_df[latent_df["score_type"] != "surprisal_intervention"] -# surprisal_df = latent_df[latent_df["score_type"] == "surprisal_intervention"] - -# if not classification_df.empty: -# classification_df = add_latent_f1(classification_df) -# if counts: -# plot_firing_vs_f1( -# classification_df, -# num_tokens=10_000_000, -# 
out_dir=viz_path, -# run_label=scores_path.name, -# ) -# plot_roc_curve(classification_df, viz_path) -# processed_df = get_agg_metrics(classification_df, counts) -# plot_accuracy_hist(processed_df, viz_path) - -# if counts: -# dead = sum((counts[m] == 0).sum().item() for m in modules) -# print(f"Number of dead features: {dead}") - -# for score_type in latent_df["score_type"].unique(): - -# if score_type == "surprisal_intervention": -# # Drop duplicates since score is per-latent, not per-example -# unique_latents = surprisal_df.drop_duplicates( -# subset=["module", "latent_idx"] -# ).copy() - -# avg_score = unique_latents["final_score"].mean() -# avg_kl = unique_latents["avg_kl_divergence"].mean() - -# # We define "Decoder-Live" as having a KL > 0.01 (non-zero effect) -# threshold = 0.01 -# n_total = len(unique_latents) -# n_live = len(unique_latents[unique_latents["avg_kl_divergence"] > threshold]) -# live_pct = (n_live / n_total) * 100 if n_total > 0 else 0.0 - -# print(f"\n--- {score_type.title()} Metrics ---") -# print(f"Average Normalized Score: {avg_score:.3f}") -# print(f"Average KL Divergence: {avg_kl:.3f}") -# print(f"Decoder-Live Percentage: {live_pct:.2f}%") - - -# plot_intervention_stats(unique_latents, viz_path, model_name) - -# else: -# if not classification_df.empty: -# score_type_summary = processed_df[ -# processed_df.score_type == score_type -# ].iloc[0] -# print(f"\n--- {score_type.title()} Metrics ---") -# print(f"Class-Balanced Accuracy: {score_type_summary['accuracy']:.3f}") -# print(f"F1 Score: {score_type_summary['f1_score']:.3f}") - -# if counts and score_type_summary["weighted_f1"] is not None: -# print( -# f"""Frequency-Weighted F1 Score: -# {score_type_summary['weighted_f1']:.3f}""" -# ) - -# print(f"Precision: {score_type_summary['precision']:.3f}") -# print(f"Recall: {score_type_summary['recall']:.3f}") - -# if score_type_summary["auc"] is not None: -# print(f"AUC: {score_type_summary['auc']:.3f}") -# else: -# print("AUC not available.") - -# from pathlib import Path -# from typing import Optional - -# import orjson -# import pandas as pd -# import plotly.express as px -# import plotly.graph_objects as go -# import torch -# from sklearn.metrics import roc_auc_score, roc_curve - - -# # --- PLOTTING HELPERS --- - -# def import_plotly(): -# """Import plotly with mitigation for MathJax bug.""" -# try: -# import plotly.express as px -# import plotly.io as pio -# except ImportError: -# raise ImportError( -# "Plotly is not installed. Please install it via `pip install plotly`." 
-# ) -# pio.kaleido.scope.mathjax = None -# return px - -# def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): -# out_dir.mkdir(parents=True, exist_ok=True) -# for module, module_df in latent_df.groupby("module"): -# if "firing_count" not in module_df.columns: -# continue -# module_df = module_df[module_df["f1_score"].notna()] -# if module_df.empty: continue - -# module_df["firing_rate"] = module_df["firing_count"] / num_tokens -# fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) -# fig.update_layout(xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0]) -# fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") - -# def plot_roc_curve(df, out_dir): -# valid_df = df[df.probability.notna()] -# if valid_df.empty or valid_df.activating.nunique() <= 1 or valid_df.probability.nunique() <= 1: -# return - -# fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) -# auc = roc_auc_score(valid_df.activating, valid_df.probability) -# fig = go.Figure(data=[ -# go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), -# go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), -# ]) -# fig.update_layout(title="ROC Curve", xaxis_title="FPR", yaxis_title="TPR") -# out_dir.mkdir(exist_ok=True, parents=True) -# fig.write_image(out_dir / "roc_curve.pdf") - -# def plot_accuracy_hist(df, out_dir): -# out_dir.mkdir(exist_ok=True, parents=True) -# for label in df["score_type"].unique(): -# fig = px.histogram(df[df["score_type"] == label], x="accuracy", nbins=100, title=f"Accuracy: {label}") -# fig.write_image(out_dir / f"{label}_accuracy.pdf") - -# def plot_intervention_stats(df, out_dir, model_name): -# """Specific plots for Intervention scoring.""" -# out_dir.mkdir(exist_ok=True, parents=True) -# display_name = model_name.split("/")[-1] if "/" in model_name else model_name - -# # 1. KL Histogram -# fig_hist = px.histogram( -# df, x="avg_kl_divergence", nbins=50, log_y=True, -# title=f"KL Divergence ({display_name})", -# labels={"avg_kl_divergence": "Avg KL Divergence (Causal Effect)"} -# ) -# fig_hist.write_image(out_dir / "intervention_kl_dist.pdf") - -# # 2. 
Live/Dead Split -# threshold = 0.01 -# df["status"] = df["avg_kl_divergence"].apply(lambda x: "Decoder-Live" if x > threshold else "Decoder-Dead") -# counts = df["status"].value_counts().reset_index() -# counts.columns = ["Status", "Count"] - -# total = counts["Count"].sum() -# live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0 -# pct = (live / total * 100) if total > 0 else 0 - -# fig_bar = px.bar( -# counts, x="Status", y="Count", color="Status", text="Count", -# title=f"Causal Relevance: {pct:.1f}% Live ({display_name})", -# color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} -# ) -# fig_bar.write_image(out_dir / "intervention_live_dead.pdf") - - -# # --- METRIC COMPUTATION --- - -# def compute_confusion(df, threshold=0.5): -# df_valid = df[df["prediction"].notna()] -# if df_valid.empty: return dict(tp=0, tn=0, fp=0, fn=0, pos=0, neg=0) - -# act = df_valid["activating"].astype(bool) -# pred = df_valid["prediction"] >= threshold - -# tp = (pred & act).sum() -# tn = (~pred & ~act).sum() -# fp = (pred & ~act).sum() -# fn = (~pred & act).sum() - -# return dict( -# true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, -# total_positives=act.sum(), total_negatives=(~act).sum() -# ) - -# def compute_classification_metrics(conf): -# tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] -# pos, neg = conf["total_positives"], conf["total_negatives"] - -# acc = ((tp / pos if pos else 0) + (tn / neg if neg else 0)) / 2 -# prec = tp / (tp + fp) if (tp + fp) else 0 -# rec = tp / pos if pos else 0 -# f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0 - -# return dict(accuracy=acc, precision=prec, recall=rec, f1_score=f1) - -# def compute_auc(df): -# valid = df[df.probability.notna()] -# if valid.probability.nunique() <= 1: return None -# return roc_auc_score(valid.activating, valid.probability) - -# def get_agg_metrics(df): -# rows = [] -# for scorer, group in df.groupby("score_type"): -# conf = compute_confusion(group) -# metrics = compute_classification_metrics(conf) -# rows.append({"score_type": scorer, **conf, **metrics, "auc": compute_auc(group)}) -# return pd.DataFrame(rows) - -# def add_latent_f1(df): -# # Calculate F1 per latent for plotting -# f1s = df.groupby(["module", "latent_idx"]).apply( -# lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] -# ).reset_index(name="f1_score") -# return df.merge(f1s, on=["module", "latent_idx"]) - - -# # --- DATA LOADING --- - -# def load_data(scores_path, modules): -# def parse_file(path): -# try: -# data = orjson.loads(path.read_bytes()) -# if not isinstance(data, list): return pd.DataFrame() -# latent_idx = int(path.stem.split("latent")[-1]) -# return pd.DataFrame([{ -# "text": "".join(ex.get("str_tokens", [])), -# "activating": ex.get("activating"), -# "prediction": ex.get("prediction"), -# "probability": ex.get("probability"), -# "final_score": ex.get("final_score"), -# "avg_kl_divergence": ex.get("avg_kl_divergence"), -# "latent_idx": latent_idx -# } for ex in data]) -# except Exception: -# return pd.DataFrame() - -# counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" -# counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} - -# dfs = [] -# for scorer_dir in scores_path.iterdir(): -# if not scorer_dir.is_dir(): continue -# for module in modules: -# for f in scorer_dir.glob(f"*{module}*"): -# df = 
parse_file(f) -# if df.empty: continue -# df["score_type"] = scorer_dir.name -# df["module"] = module -# if module in counts: -# idx = df["latent_idx"].iloc[0] -# if idx < len(counts[module]): -# df["firing_count"] = counts[module][idx].item() -# dfs.append(df) - -# return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts - - -# # --- MAIN HANDLERS --- - -# def handle_classification_results(df, counts, viz_path, run_label): -# """Handles Fuzz, Detection, Simulation.""" -# print(f"\n--- Classification Analysis ({len(df)} examples) ---") - -# # Add per-latent F1 for plotting -# df = add_latent_f1(df) - -# # Plots -# if counts: -# plot_firing_vs_f1(df, 10_000_000, viz_path, run_label) -# plot_roc_curve(df, viz_path) - -# # Aggregated Metrics (Accuracy, F1, etc.) -# agg_df = get_agg_metrics(df) -# plot_accuracy_hist(agg_df, viz_path) - -# # Console Output -# for _, row in agg_df.iterrows(): -# print(f"\n[ {row['score_type'].title()} ]") -# print(f"Accuracy: {row['accuracy']:.3f}") -# print(f"F1 Score: {row['f1_score']:.3f}") -# print(f"Precision: {row['precision']:.3f}") -# print(f"Recall: {row['recall']:.3f}") -# if row['auc']: print(f"AUC: {row['auc']:.3f}") - - -# def handle_intervention_results(df, viz_path, model_name): -# """Handles Surprisal Intervention.""" -# # Deduplicate: we only need one row per latent per module -# unique_latents = df.drop_duplicates(subset=["module", "latent_idx"]).copy() - -# avg_score = unique_latents["final_score"].mean() -# avg_kl = unique_latents["avg_kl_divergence"].mean() - -# # Calculate Decoder-Live % -# total = len(unique_latents) -# live = len(unique_latents[unique_latents["avg_kl_divergence"] > 0.01]) -# pct = (live / total * 100) if total > 0 else 0 - -# print(f"\n--- Surprisal Intervention Analysis ({total} latents) ---") -# print(f"Avg Normalized Score: {avg_score:.3f}") -# print(f"Avg KL Divergence: {avg_kl:.3f}") -# print(f"Decoder-Live %: {pct:.2f}%") - -# plot_intervention_stats(unique_latents, viz_path, model_name) - - -# # --- ENTRY POINT --- - -# def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"): -# import_plotly() - -# # 1. Load ALL data (Global Reporting) -# latent_df, counts = load_data(scores_path, modules) -# if latent_df.empty: -# print("No data found to analyze.") -# return - -# print(f"Generating report for scorers found: {latent_df['score_type'].unique()}") - -# # 2. Split Data -# classification_mask = latent_df["score_type"] != "surprisal_intervention" -# classification_df = latent_df[classification_mask] -# intervention_df = latent_df[~classification_mask] - -# # 3. Dispatch to Handlers -# if not classification_df.empty: -# handle_classification_results(classification_df, counts, viz_path, scores_path.name) - -# if not intervention_df.empty: -# handle_intervention_results(intervention_df, viz_path, model_name) - - from pathlib import Path from typing import Optional From dd3956dbd5b89fe39ac735db593bc00dbfd4b143 Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Fri, 21 Nov 2025 15:33:39 +0000 Subject: [PATCH 23/25] Clean plots and results file --- delphi/log/result_analysis.py | 16 +++------------- 1 file changed, 3 insertions(+), 13 deletions(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index ef2b16ca..99339e50 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -9,15 +9,12 @@ from sklearn.metrics import roc_auc_score, roc_curve -# --- 1. 
NEW PLOTTING FUNCTIONS --- def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: str): """ Replicates the Scatter Plot from the paper (Figure 3/Appendix G). Plots Fuzz Score vs. Intervention Score for the same latents. """ - # We need to merge the rows for 'fuzz' and 'surprisal_intervention' - # 1. Pivot the table so we have columns: 'latent_idx', 'fuzz_score', 'intervention_score' # Extract Fuzz Scores (using F1 or Accuracy as the metric) fuzz_df = latent_df[latent_df["score_type"] == "fuzz"].copy() @@ -32,12 +29,10 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: int_df = latent_df[latent_df["score_type"] == "surprisal_intervention"].copy() if int_df.empty: return - # Deduplicate intervention scores int_metrics = int_df.drop_duplicates(subset=["module", "latent_idx"])[ ["module", "latent_idx", "avg_kl_divergence", "final_score"] ] - # Merge them merged = pd.merge(fuzz_metrics, int_metrics, on=["module", "latent_idx"]) if merged.empty: @@ -63,7 +58,8 @@ def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: y="final_score", hover_data=["latent_idx"], title=f"Correlation vs. Causation (Score) - {run_label}", - labels={"fuzz_score": "Fuzzing Score (Correlation)", "final_score": "Intervention Score (Surprisal)"}, + labels={"fuzz_score": "Fuzzing Score (Correlation)", + "final_score": "Intervention Score (Surprisal)"}, trendline="ols" ) fig_score.write_image(out_dir / "scatter_fuzz_vs_score.pdf") @@ -87,7 +83,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): counts = df["status"].value_counts().reset_index() counts.columns = ["Status", "Count"] - # Get percentage total = counts["Count"].sum() live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0 pct = (live / total * 100) if total > 0 else 0 @@ -99,7 +94,7 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): ) fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") - # 2. "Live Features Only" Histogram (The "Pretty" one) + # 2. "Live Features Only" Histogram live_df = df[df["avg_kl_divergence"] > threshold] if not live_df.empty: fig_live = px.histogram( @@ -124,7 +119,6 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): fig_all.write_image(out_dir / "intervention_kl_dist_log_scale.pdf") -# --- 2. STANDARD PLOTTING HELPERS --- def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): out_dir.mkdir(parents=True, exist_ok=True) @@ -168,7 +162,6 @@ def plot_roc_curve(df, out_dir): fig.write_image(out_dir / "roc_curve.pdf") -# --- 3. METRIC COMPUTATION --- def compute_confusion(df, threshold=0.5): df_valid = df[df["prediction"].notna()] @@ -208,7 +201,6 @@ def add_latent_f1(df): return df.merge(f1s, on=["module", "latent_idx"]) -# --- 4. DATA LOADING --- def load_data(scores_path, modules): def parse_file(path): @@ -248,7 +240,6 @@ def parse_file(path): return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts -# --- 5. MAIN LOGIC --- def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"): import_plotly() @@ -299,6 +290,5 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_na plot_intervention_stats(unique_latents, viz_path, model_name) # 3. Generate Scatter Plot (Fuzz vs. 
Intervention) - # Only works if we have BOTH types of data if not class_df.empty and not int_df.empty: plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name) \ No newline at end of file From dc672340234e1a8f61991e1cbf05cc9e1a9e941d Mon Sep 17 00:00:00 2001 From: Sai Reddy Date: Fri, 21 Nov 2025 15:37:27 +0000 Subject: [PATCH 24/25] minor fixes cli --- delphi/log/result_analysis.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 99339e50..808d733f 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -173,7 +173,7 @@ def compute_confusion(df, threshold=0.5): return dict(true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, total_positives=act.sum(), total_negatives=(~act).sum()) def compute_classification_metrics(conf): - tp, tn, fp, fn = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + tp, tn, fp, _ = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] pos, neg = conf["total_positives"], conf["total_negatives"] acc = ((tp/pos if pos else 0) + (tn/neg if neg else 0)) / 2 prec = tp/(tp+fp) if (tp+fp) else 0 From 9ccf205ffa2b4aec965603b3852b88afdf9779f1 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 21 Nov 2025 15:37:40 +0000 Subject: [PATCH 25/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- delphi/__main__.py | 12 +- delphi/log/result_analysis.py | 246 ++++++++++++++++++++++------------ delphi/temp.py | 7 +- 3 files changed, 176 insertions(+), 89 deletions(-) diff --git a/delphi/__main__.py b/delphi/__main__.py index 1e5aced7..60c6855c 100644 --- a/delphi/__main__.py +++ b/delphi/__main__.py @@ -491,11 +491,13 @@ async def run( del model, hookpoint_to_sparse_encode if run_cfg.verbose: - log_results(scores_path, - visualize_path, - run_cfg.hookpoints, - run_cfg.scorers, - model_name=run_cfg.model) + log_results( + scores_path, + visualize_path, + run_cfg.hookpoints, + run_cfg.scorers, + model_name=run_cfg.model, + ) if __name__ == "__main__": diff --git a/delphi/log/result_analysis.py b/delphi/log/result_analysis.py index 808d733f..8afd4388 100644 --- a/delphi/log/result_analysis.py +++ b/delphi/log/result_analysis.py @@ -1,5 +1,4 @@ from pathlib import Path -from typing import Optional import orjson import pandas as pd @@ -9,58 +8,68 @@ from sklearn.metrics import roc_auc_score, roc_curve - def plot_fuzz_vs_intervention(latent_df: pd.DataFrame, out_dir: Path, run_label: str): """ Replicates the Scatter Plot from the paper (Figure 3/Appendix G). Plots Fuzz Score vs. Intervention Score for the same latents. 
""" - + # Extract Fuzz Scores (using F1 or Accuracy as the metric) fuzz_df = latent_df[latent_df["score_type"] == "fuzz"].copy() - if fuzz_df.empty: return - + if fuzz_df.empty: + return + # Calculate per-latent F1 for fuzzing - fuzz_metrics = fuzz_df.groupby(["module", "latent_idx"]).apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ).reset_index(name="fuzz_score") + fuzz_metrics = ( + fuzz_df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="fuzz_score") + ) # Extract Intervention Scores int_df = latent_df[latent_df["score_type"] == "surprisal_intervention"].copy() - if int_df.empty: return - + if int_df.empty: + return + int_metrics = int_df.drop_duplicates(subset=["module", "latent_idx"])[ ["module", "latent_idx", "avg_kl_divergence", "final_score"] ] merged = pd.merge(fuzz_metrics, int_metrics, on=["module", "latent_idx"]) - + if merged.empty: print("Could not merge Fuzz and Intervention scores (no matching latents).") return # Plot 1: KL vs Fuzz (Causal Impact vs Correlational Quality) fig_kl = px.scatter( - merged, - x="fuzz_score", + merged, + x="fuzz_score", y="avg_kl_divergence", hover_data=["latent_idx"], title=f"Correlation vs. Causation (KL) - {run_label}", - labels={"fuzz_score": "Fuzzing Score (Correlation)", "avg_kl_divergence": "Intervention KL (Causation)"}, - trendline="ols" # Adds a regression line to show the negative/zero correlation + labels={ + "fuzz_score": "Fuzzing Score (Correlation)", + "avg_kl_divergence": "Intervention KL (Causation)", + }, + trendline="ols", # Adds a regression line to show the negative/zero correlation ) fig_kl.write_image(out_dir / "scatter_fuzz_vs_kl.pdf") # Plot 2: Score vs Fuzz (Original Paper Metric) fig_score = px.scatter( - merged, - x="fuzz_score", + merged, + x="fuzz_score", y="final_score", hover_data=["latent_idx"], title=f"Correlation vs. Causation (Score) - {run_label}", - labels={"fuzz_score": "Fuzzing Score (Correlation)", - "final_score": "Intervention Score (Surprisal)"}, - trendline="ols" + labels={ + "fuzz_score": "Fuzzing Score (Correlation)", + "final_score": "Intervention Score (Surprisal)", + }, + trendline="ols", ) fig_score.write_image(out_dir / "scatter_fuzz_vs_score.pdf") print("Generated Fuzz vs. Intervention scatter plots.") @@ -74,7 +83,7 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): """ out_dir.mkdir(exist_ok=True, parents=True) display_name = model_name.split("/")[-1] if "/" in model_name else model_name - + # 1. 
Live/Dead Split Bar Chart threshold = 0.01 df["status"] = df["avg_kl_divergence"].apply( @@ -82,15 +91,23 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): ) counts = df["status"].value_counts().reset_index() counts.columns = ["Status", "Count"] - + total = counts["Count"].sum() - live = counts[counts["Status"] == "Decoder-Live"]["Count"].sum() if "Decoder-Live" in counts["Status"].values else 0 + live = ( + counts[counts["Status"] == "Decoder-Live"]["Count"].sum() + if "Decoder-Live" in counts["Status"].values + else 0 + ) pct = (live / total * 100) if total > 0 else 0 fig_bar = px.bar( - counts, x="Status", y="Count", color="Status", text="Count", + counts, + x="Status", + y="Count", + color="Status", + text="Count", title=f"Causal Relevance: {pct:.1f}% Live ({display_name})", - color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"} + color_discrete_map={"Decoder-Live": "green", "Decoder-Dead": "red"}, ) fig_bar.write_image(out_dir / "intervention_live_dead_split.pdf") @@ -98,11 +115,11 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): live_df = df[df["avg_kl_divergence"] > threshold] if not live_df.empty: fig_live = px.histogram( - live_df, - x="avg_kl_divergence", + live_df, + x="avg_kl_divergence", nbins=20, title=f"Distribution of LIVE Features Only ({display_name})", - labels={"avg_kl_divergence": "KL Divergence (Causal Effect)"} + labels={"avg_kl_divergence": "KL Divergence (Causal Effect)"}, ) fig_live.update_layout(showlegend=False) fig_live.write_image(out_dir / "intervention_kl_dist_LIVE_ONLY.pdf") @@ -114,24 +131,28 @@ def plot_intervention_stats(df: pd.DataFrame, out_dir: Path, model_name: str): nbins=50, title=f"Distribution of All Features ({display_name})", labels={"avg_kl_divergence": "KL Divergence"}, - log_y=True # Log scale to handle the massive spike at 0 + log_y=True, # Log scale to handle the massive spike at 0 ) fig_all.write_image(out_dir / "intervention_kl_dist_log_scale.pdf") - def plot_firing_vs_f1(latent_df, num_tokens, out_dir, run_label): out_dir.mkdir(parents=True, exist_ok=True) for module, module_df in latent_df.groupby("module"): - if "firing_count" not in module_df.columns: continue + if "firing_count" not in module_df.columns: + continue module_df = module_df[module_df["f1_score"].notna()] - if module_df.empty: continue + if module_df.empty: + continue module_df["firing_rate"] = module_df["firing_count"] / num_tokens fig = px.scatter(module_df, x="firing_rate", y="f1_score", log_x=True) - fig.update_layout(xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0]) + fig.update_layout( + xaxis_title="Firing rate", yaxis_title="F1 score", xaxis_range=[-5.4, 0] + ) fig.write_image(out_dir / f"{run_label}_{module}_firing_rates.pdf") + def import_plotly(): try: import plotly.express as px @@ -141,94 +162,149 @@ def import_plotly(): pio.kaleido.scope.mathjax = None return px + def plot_accuracy_hist(df, out_dir): out_dir.mkdir(exist_ok=True, parents=True) for label in df["score_type"].unique(): - if label == "surprisal_intervention": continue - fig = px.histogram(df[df["score_type"] == label], x="accuracy", nbins=100, title=f"Accuracy: {label}") + if label == "surprisal_intervention": + continue + fig = px.histogram( + df[df["score_type"] == label], + x="accuracy", + nbins=100, + title=f"Accuracy: {label}", + ) fig.write_image(out_dir / f"{label}_accuracy.pdf") + def plot_roc_curve(df, out_dir): valid_df = df[df.probability.notna()] - if valid_df.empty or 
valid_df.activating.nunique() <= 1: return + if valid_df.empty or valid_df.activating.nunique() <= 1: + return fpr, tpr, _ = roc_curve(valid_df.activating, valid_df.probability) auc = roc_auc_score(valid_df.activating, valid_df.probability) - fig = go.Figure(data=[ - go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), - go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), - ]) + fig = go.Figure( + data=[ + go.Scatter(x=fpr, y=tpr, mode="lines", name=f"ROC (AUC={auc:.3f})"), + go.Scatter(x=[0, 1], y=[0, 1], mode="lines", line=dict(dash="dash")), + ] + ) fig.update_layout(title="ROC Curve", xaxis_title="FPR", yaxis_title="TPR") out_dir.mkdir(exist_ok=True, parents=True) fig.write_image(out_dir / "roc_curve.pdf") - def compute_confusion(df, threshold=0.5): df_valid = df[df["prediction"].notna()] - if df_valid.empty: return dict(true_positives=0, true_negatives=0, false_positives=0, false_negatives=0, total_positives=0, total_negatives=0) + if df_valid.empty: + return dict( + true_positives=0, + true_negatives=0, + false_positives=0, + false_negatives=0, + total_positives=0, + total_negatives=0, + ) act = df_valid["activating"].astype(bool) pred = df_valid["prediction"] >= threshold tp, tn = (pred & act).sum(), (~pred & ~act).sum() fp, fn = (pred & ~act).sum(), (~pred & act).sum() - return dict(true_positives=tp, true_negatives=tn, false_positives=fp, false_negatives=fn, total_positives=act.sum(), total_negatives=(~act).sum()) + return dict( + true_positives=tp, + true_negatives=tn, + false_positives=fp, + false_negatives=fn, + total_positives=act.sum(), + total_negatives=(~act).sum(), + ) + def compute_classification_metrics(conf): - tp, tn, fp, _ = conf["true_positives"], conf["true_negatives"], conf["false_positives"], conf["false_negatives"] + tp, tn, fp, _ = ( + conf["true_positives"], + conf["true_negatives"], + conf["false_positives"], + conf["false_negatives"], + ) pos, neg = conf["total_positives"], conf["total_negatives"] - acc = ((tp/pos if pos else 0) + (tn/neg if neg else 0)) / 2 - prec = tp/(tp+fp) if (tp+fp) else 0 - rec = tp/pos if pos else 0 - f1 = 2*(prec*rec)/(prec+rec) if (prec+rec) else 0 + acc = ((tp / pos if pos else 0) + (tn / neg if neg else 0)) / 2 + prec = tp / (tp + fp) if (tp + fp) else 0 + rec = tp / pos if pos else 0 + f1 = 2 * (prec * rec) / (prec + rec) if (prec + rec) else 0 return dict(accuracy=acc, precision=prec, recall=rec, f1_score=f1) + def compute_auc(df): valid = df[df.probability.notna()] - if valid.probability.nunique() <= 1: return None + if valid.probability.nunique() <= 1: + return None return roc_auc_score(valid.activating, valid.probability) + def get_agg_metrics(df): rows = [] for scorer, group in df.groupby("score_type"): - if scorer == "surprisal_intervention": continue + if scorer == "surprisal_intervention": + continue conf = compute_confusion(group) - rows.append({"score_type": scorer, **conf, **compute_classification_metrics(conf), "auc": compute_auc(group)}) + rows.append( + { + "score_type": scorer, + **conf, + **compute_classification_metrics(conf), + "auc": compute_auc(group), + } + ) return pd.DataFrame(rows) + def add_latent_f1(df): - f1s = df.groupby(["module", "latent_idx"]).apply( - lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] - ).reset_index(name="f1_score") + f1s = ( + df.groupby(["module", "latent_idx"]) + .apply( + lambda g: compute_classification_metrics(compute_confusion(g))["f1_score"] + ) + .reset_index(name="f1_score") + ) return df.merge(f1s, on=["module", 
"latent_idx"]) - def load_data(scores_path, modules): def parse_file(path): try: data = orjson.loads(path.read_bytes()) - if not isinstance(data, list): return pd.DataFrame() + if not isinstance(data, list): + return pd.DataFrame() latent_idx = int(path.stem.split("latent")[-1]) - return pd.DataFrame([{ - "text": "".join(ex.get("str_tokens", [])), - "activating": ex.get("activating"), - "prediction": ex.get("prediction"), - "probability": ex.get("probability"), - "final_score": ex.get("final_score"), - "avg_kl_divergence": ex.get("avg_kl_divergence"), - "latent_idx": latent_idx - } for ex in data]) - except Exception: return pd.DataFrame() + return pd.DataFrame( + [ + { + "text": "".join(ex.get("str_tokens", [])), + "activating": ex.get("activating"), + "prediction": ex.get("prediction"), + "probability": ex.get("probability"), + "final_score": ex.get("final_score"), + "avg_kl_divergence": ex.get("avg_kl_divergence"), + "latent_idx": latent_idx, + } + for ex in data + ] + ) + except Exception: + return pd.DataFrame() counts_file = scores_path.parent / "log" / "hookpoint_firing_counts.pt" counts = torch.load(counts_file, weights_only=True) if counts_file.exists() else {} - + dfs = [] for scorer_dir in scores_path.iterdir(): - if not scorer_dir.is_dir(): continue + if not scorer_dir.is_dir(): + continue for module in modules: for f in scorer_dir.glob(f"*{module}*"): df = parse_file(f) - if df.empty: continue + if df.empty: + continue df["score_type"] = scorer_dir.name df["module"] = module if module in counts: @@ -236,14 +312,19 @@ def parse_file(path): if idx < len(counts[module]): df["firing_count"] = counts[module][idx].item() dfs.append(df) - - return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts + return (pd.concat(dfs, ignore_index=True) if dfs else pd.DataFrame()), counts -def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_names: list[str], model_name: str = "Unknown"): +def log_results( + scores_path: Path, + viz_path: Path, + modules: list[str], + scorer_names: list[str], + model_name: str = "Unknown", +): import_plotly() - + latent_df, counts = load_data(scores_path, modules) if latent_df.empty: print("No data found.") @@ -259,12 +340,13 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_na # 1. Handle Classification (Fuzz/Detection) if not class_df.empty: class_df = add_latent_f1(class_df) - if counts: plot_firing_vs_f1(class_df, 10_000_000, viz_path, scores_path.name) + if counts: + plot_firing_vs_f1(class_df, 10_000_000, viz_path, scores_path.name) plot_roc_curve(class_df, viz_path) - + agg_df = get_agg_metrics(class_df) plot_accuracy_hist(agg_df, viz_path) - + for _, row in agg_df.iterrows(): print(f"\n[ {row['score_type'].title()} ]") print(f"Accuracy: {row['accuracy']:.3f}") @@ -273,22 +355,22 @@ def log_results(scores_path: Path, viz_path: Path, modules: list[str], scorer_na # 2. 
Handle Intervention if not int_df.empty: unique_latents = int_df.drop_duplicates(subset=["module", "latent_idx"]).copy() - + avg_score = unique_latents["final_score"].mean() avg_kl = unique_latents["avg_kl_divergence"].mean() - + threshold = 0.01 n_total = len(unique_latents) n_live = len(unique_latents[unique_latents["avg_kl_divergence"] > threshold]) pct = (n_live / n_total * 100) if n_total > 0 else 0 - - print(f"\n--- Surprisal Intervention Analysis ---") + + print("\n--- Surprisal Intervention Analysis ---") print(f"Avg Normalized Score: {avg_score:.3f}") print(f"Avg KL Divergence: {avg_kl:.3f}") print(f"Decoder-Live %: {pct:.2f}%") - + plot_intervention_stats(unique_latents, viz_path, model_name) # 3. Generate Scatter Plot (Fuzz vs. Intervention) if not class_df.empty and not int_df.empty: - plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name) \ No newline at end of file + plot_fuzz_vs_intervention(latent_df, viz_path, scores_path.name) diff --git a/delphi/temp.py b/delphi/temp.py index 4572510c..19c29522 100644 --- a/delphi/temp.py +++ b/delphi/temp.py @@ -1,11 +1,14 @@ # Create a file named run_analysis.py with these contents -from delphi.log.result_analysis import log_results from pathlib import Path +from delphi.log.result_analysis import log_results + # Adjust the path to your results folder scores_path = Path("results/pythia_100_test/scores") viz_path = Path("results/pythia_100_test/visualize") modules = ["layers.6.mlp"] scorer_names = ["fuzz", "detection", "surprisal_intervention"] -log_results(scores_path, viz_path, modules, scorer_names, model_name="EleutherAI/pythia-160m") \ No newline at end of file +log_results( + scores_path, viz_path, modules, scorer_names, model_name="EleutherAI/pythia-160m" +)