Implementation of the Chefer et al. (2021) gradient-weighted attention explainability method for the MedGemma 1.5 4B vision-language model.
This library generates relevancy maps that highlight which image regions and text tokens contribute to MedGemma's predictions. It implements the method from:
Chefer, H., Gur, S., & Wolf, L. (2021). Transformer Interpretability Beyond Attention Visualization. CVPR 2021.
Read the detailed blog post here
- Gradient-weighted attention: Combines attention patterns with gradient information
- Token-specific explanations: Explain why specific words were generated
- Keyword search: Automatically find and explain keywords in responses
- Answer span explanation: Explain the entire generated response
- Medical imaging support: Correct anatomical orientation for chest X-rays
```bash
# Clone the repository
git clone https://github.com/thedatasense/medgemma-explainer.git
cd medgemma-explainer
# Install dependencies
pip install -r requirements.txt
```

Or in Google Colab:

```python
!git clone https://github.com/thedatasense/medgemma-explainer.git
import sys
sys.path.insert(0, '/content/medgemma-explainer')
from medgemma_explainability import MedGemmaExplainer, load_medgemma
```

```python
from medgemma_explainability import MedGemmaExplainer, load_medgemma
# Load model (must use eager attention)
model, processor, device = load_medgemma(
"google/medgemma-1.5-4b-it",
attn_implementation="eager" # Required for attention output
)
# Create explainer
explainer = MedGemmaExplainer(model, processor, device=device)
# Generate explanation
result = explainer.explain(image, "What do you see in this image?")
# Or explain a specific keyword
result = explainer.explain_keyword(image, prompt, keyword="pneumonia")
# Or explain the entire answer
result = explainer.explain_answer_span(image, prompt)
```

The token-position offset is the most important aspect of the implementation.
For causal language models like MedGemma:

- The logit at position `i` predicts the token at position `i+1`
- To explain why the token at position `p` was generated:
  - Backprop from the logit at position `p-1`
  - Use the actual token id at position `p` (not the argmax)
  - Extract relevancy from row `p-1` of the R matrix
A common mistake is to backprop from the last position using argmax:
```python
# WRONG - explains "what comes after the last token"
target_logit = logits[0, -1, logits[0, -1].argmax()]
```

The correct approach:

```python
# CORRECT - explains why token at position p was generated
logit_position = target_token_position - 1
target_token_id = input_ids[0, target_token_position] # Actual token
target_logit = logits[0, logit_position, target_token_id]
# Extract from R[logit_position, :]
```

- Keep the model in eval mode: use a `torch.enable_grad()` context instead of `model.train()`
- Retain gradients: call `attn.requires_grad_(True)` and `attn.retain_grad()` on the attention tensors
- Use eager attention: MedGemma's default SDPA implementation doesn't support `output_attentions=True`
- Convert to float32: attention tensors are bfloat16; convert them for stable gradient computation
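A minimal sketch of how these constraints fit together in a single backward pass (illustrative only; `inputs`, `logit_position`, and `target_token_id` follow the earlier snippet, and the model parameters are assumed to still require grad so the attention tensors sit in the autograd graph):

```python
import torch

model.eval()                           # keep dropout/normalization in eval mode
with torch.enable_grad():              # gradients without calling model.train()
    outputs = model(**inputs, output_attentions=True)  # requires eager attention
    for attn in outputs.attentions:
        attn.retain_grad()             # keep .grad on these non-leaf tensors

    logits = outputs.logits
    target_logit = logits[0, logit_position, target_token_id]
    target_logit.backward()

# Each outputs.attentions[i].grad now holds dL/dA for layer i (in bfloat16);
# convert to float32 before plugging it into Eq. 5 below.
```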
Equation 5: Gradient-Weighted Attention
```
Ā = E_h[(∇A ⊙ A)⁺]
```
Where:
- `A` = attention weights
- `∇A` = gradient of the loss w.r.t. the attention
- `⊙` = element-wise multiplication
- `(·)⁺` = keep only positive values
- `E_h` = average over the attention heads
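In code, Eq. 5 reduces to an element-wise product, a clamp, and a mean over the heads dimension. A sketch, assuming `attn` is a `(batch, heads, seq, seq)` attention tensor whose `.grad` was populated by the backward pass above:

```python
import torch

def gradient_weighted_attention(attn: torch.Tensor) -> torch.Tensor:
    """Eq. 5: Ā = E_h[(∇A ⊙ A)⁺], computed in float32."""
    grad = attn.grad.float()                        # ∇A, same shape as A
    weighted = (grad * attn.float()).clamp(min=0)   # ⊙, then keep positives
    return weighted.mean(dim=1)                     # E_h: average over heads
```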
Equation 6: Relevancy Propagation
```
R = R + Ā @ R
```
Starting with R = I (identity), propagate through each layer.
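A sketch of the propagation loop, assuming `layer_maps` holds one gradient-weighted attention matrix per layer (Eq. 5 output, batch dimension dropped for clarity):

```python
import torch
from typing import List

def propagate_relevancy(layer_maps: List[torch.Tensor]) -> torch.Tensor:
    """Eq. 6: start from R = I, then R <- R + Ā @ R at every layer."""
    seq_len = layer_maps[0].shape[-1]
    R = torch.eye(seq_len)                # identity: every token explains itself
    for A_bar in layer_maps:              # first layer to last
        R = R + A_bar @ R
    return R
```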
- Language Model: 34 transformer layers
- Attention: 8 query heads, 4 KV heads (GQA)
- Image Tokens: 256 tokens (16×16 grid) at positions 6-261
- Global Attention Layers: 5, 11, 17, 23, 29 (5:1 local:global ratio)
- Local Window: 1024 tokens
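These numbers can be sanity-checked against the published model configuration (a sketch; the field names assume the standard Gemma 3-style `text_config` exposed by transformers):

```python
from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("google/medgemma-1.5-4b-it")
text_cfg = cfg.text_config
print(text_cfg.num_hidden_layers)      # expect 34 transformer layers
print(text_cfg.num_attention_heads)    # expect 8 query heads
print(text_cfg.num_key_value_heads)    # expect 4 KV heads (GQA)
print(text_cfg.sliding_window)         # expect a 1024-token local window
```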
```
Position 0:      <bos>
Position 1:      <start_of_turn>
Position 2:      user
Position 3-5:    prefix tokens
Position 6-261:  256 IMAGE TOKENS (16×16 grid)
Position 262:    <end_of_image>
Position 263+:   Text prompt and generated response
```
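Given the relevancy matrix `R` and the `logit_position` from the snippets above, the image explanation is one slice of the target row, reshaped onto the grid (a sketch using the positions listed above):

```python
IMAGE_TOKEN_START = 6      # first image token position (see layout above)
IMAGE_TOKEN_COUNT = 256    # 16×16 grid

row = R[logit_position]    # relevancy of every input token for the explained token
image_rel = row[IMAGE_TOKEN_START:IMAGE_TOKEN_START + IMAGE_TOKEN_COUNT]
image_rel = image_rel.reshape(16, 16)              # token index -> grid cell
image_rel = image_rel / (image_rel.max() + 1e-8)   # normalize for display
```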
For chest X-rays (PA view):
- Left side of image = Patient's RIGHT side
- Right side of image = Patient's LEFT side
The 16×16 relevancy grid maps to anatomical regions accordingly.
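When plotting, it helps to label both sides explicitly (a matplotlib sketch; `result` is the output of `explainer.explain(...)` from the quick start):

```python
import matplotlib.pyplot as plt

rel = result.image_relevancy            # the 16×16 relevancy map
plt.imshow(rel, cmap="hot")
plt.text(0.5, 8, "R", color="white")    # image left  = patient's RIGHT
plt.text(14.5, 8, "L", color="white")   # image right = patient's LEFT
plt.title("Relevancy map (radiographic orientation)")
plt.axis("off")
plt.show()
```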
```python
class MedGemmaExplainer:
def __init__(self, model, processor, device=None, ...):
"""Initialize explainer with MedGemma model."""
def explain(self, image, prompt, target_token_position=None, ...):
"""Generate explanation for a specific token position."""
def explain_keyword(self, image, prompt, keyword, ...):
"""Find and explain a keyword in the response."""
def explain_answer_span(self, image, prompt, ...):
"""Explain the entire generated answer."""@dataclass
class ExplanationResult:
image_relevancy: np.ndarray # 16x16 relevancy map
text_relevancy: np.ndarray # Text token relevancy scores
token_labels: List[str] # Decoded token strings
generated_text: str # Full generated response
raw_image_relevancy: np.ndarray # Raw attention baseline (optional)
raw_text_relevancy: np.ndarray # Raw attention baseline (optional)
    metadata: dict                     # Additional info
```
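For example, the text fields can be combined to rank tokens by relevancy (an illustrative sketch; `explainer` and `image` come from the quick start above):

```python
result = explainer.explain(image, "What do you see in this image?")

# Pair each decoded token with its relevancy score and print the top 5
ranked = sorted(zip(result.token_labels, result.text_relevancy),
                key=lambda pair: pair[1], reverse=True)
for token, score in ranked[:5]:
    print(f"{token!r}: {score:.3f}")
```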
```
medgemma-explainer/
├── medgemma_explainability/      # Main package
│ ├── __init__.py # Package exports
│ ├── explainer.py # Main MedGemmaExplainer class
│ ├── relevancy.py # Chefer method implementation
│ ├── visualization.py # Plotting utilities
│ ├── attention_hooks.py # Attention capture hooks
│ └── utils.py # Helper functions
├── scripts/ # Example scripts
├── tutorial_mem_optimized.ipynb  # Tutorial notebook
├── tests/ # Unit tests
├── requirements.txt
└── README.md
```
When asked "Where is the remote control?", the model generates a response mentioning the remote. Explaining the "remote" token shows the highest relevancy at the bottom-center of the image, where the remote is located.
When analyzing a chest X-ray with right middle lobe pneumonia, the relevancy map correctly highlights the patient's right lung field (left side of image).
The X-ray image used in the example was downloaded from: Radiopaedia
The cat image used in the remote example was downloaded from: Wikimedia Commons
- Python 3.8+
- PyTorch 2.0+
- transformers
- Pillow
- matplotlib
- numpy
If you use this work, please cite the following:
- Chefer, H., Gur, S., & Wolf, L. (2021). Transformer Interpretability Beyond Attention Visualization. CVPR 2021.
- MedGemma Explainer GitHub Repository
- Detailed Blog Post
This implementation is provided for research and educational purposes.