diff --git a/docs/ANCHOR_ITEMS.md b/docs/ANCHOR_ITEMS.md new file mode 100644 index 0000000..e27a729 --- /dev/null +++ b/docs/ANCHOR_ITEMS.md @@ -0,0 +1,446 @@ +# Anchor Items in py-irt + +## Overview + +Anchor items are items with **fixed, pre-determined parameter values** that remain constant during IRT model calibration. This feature is essential for: + +- **Test Linking**: Connecting different test forms to a common scale +- **Test Equating**: Ensuring different test versions measure on the same scale +- **Incremental Calibration**: Adding new items while keeping existing items fixed +- **Scale Maintenance**: Maintaining consistent measurement scales over time + +## How It Works + +When you designate items as anchors: + +1. Their parameter values are initialized to specified fixed values +2. During training, their gradients are automatically zeroed out +3. This ensures they remain constant throughout calibration +4. Other items are calibrated relative to these fixed anchor items + +The implementation uses PyTorch hooks to zero gradients during the backward pass, ensuring the optimizer cannot update anchor item parameters. + +## Quick Start + +```python +import pandas as pd +from py_irt.dataset import Dataset +from py_irt.config import IrtConfig +from py_irt.training import IrtModelTrainer + +# 1. Load your data +df = pd.read_csv('your_data.csv') +dataset = Dataset.from_pandas(df, subject_column='subject_id') + +# 2. Define anchor items with their fixed parameter values +anchor_items = [ + { + 'item_id': 'item_1', + 'difficulty': 0.5, + 'discrimination': 1.2 + }, + { + 'item_id': 'item_3', + 'difficulty': -0.8, + 'discrimination': 0.9 + } +] + +# 3. Add anchor items to the dataset +dataset.add_anchor_items(anchor_items) + +# 4. Configure model with anchor initializer +config = IrtConfig( + model_type='2pl', + priors='vague', + epochs=100, + initializers=['anchor_items'] # ← Important! +) + +# 5. 
Train as usual +trainer = IrtModelTrainer( + data_path=None, + config=config, + dataset=dataset +) +trainer.train() + +# 6. Get results +params = trainer.best_params +``` + +## Detailed Usage + +### Defining Anchor Items + +Anchor items are defined as a list of dictionaries, where each dictionary specifies: + +- `item_id` (required): The string identifier of the item +- `difficulty` (optional): Fixed difficulty parameter +- `discrimination` (optional): Fixed discrimination parameter +- `guessing` (optional): Fixed guessing parameter (for 3PL/4PL models) + +**Example: Full anchoring (all parameters fixed)** +```python +anchor_items = [ + { + 'item_id': 'item_1', + 'difficulty': 0.5, + 'discrimination': 1.2 + } +] +``` + +**Example: Partial anchoring (only difficulty fixed)** +```python +anchor_items = [ + { + 'item_id': 'item_1', + 'difficulty': 0.5 + # discrimination will be estimated + } +] +``` + +### Adding Anchor Items to Dataset + +```python +dataset.add_anchor_items(anchor_items) +``` + +This method: +- Validates that all item IDs exist in the dataset +- Creates `AnchorItem` objects with the specified parameters +- Stores them in `dataset.anchor_items` + +### Configuring the Model + +The key step is to include `'anchor_items'` in the initializers list: + +```python +config = IrtConfig( + model_type='2pl', # or '1pl', '3pl', '4pl' + initializers=['anchor_items'] +) +``` + +You can combine it with other initializers: + +```python +config = IrtConfig( + model_type='2pl', + initializers=[ + 'anchor_items', + {'name': 'difficulty_sign', 'magnitude': 2.0, 'n_to_init': 5} + ] +) +``` + +## Supported Models + +Anchor items work with all standard IRT models: + +| Model | Supported Parameters | +|-------|---------------------| +| 1PL | `difficulty` | +| 2PL | `difficulty`, `discrimination` | +| 3PL | `difficulty`, `discrimination`, `guessing` | +| 4PL | `difficulty`, `discrimination`, `guessing`, `slip` | + +## Use Cases + +### 1. 
Test Linking + +Link two test forms using common anchor items: + +```python +# Form A calibration (reference form) +dataset_A = Dataset.from_pandas(form_A_data, subject_column='subject_id') +# ... train and get parameters for all items + +# Form B calibration (new form) with common items as anchors +dataset_B = Dataset.from_pandas(form_B_data, subject_column='subject_id') + +# Use Form A parameters for common items as anchors +anchor_items = [ + { + 'item_id': 'common_item_1', + 'difficulty': form_A_params['diff'][item_1_ix], + 'discrimination': form_A_params['disc'][item_1_ix] + }, + # ... more common items +] + +dataset_B.add_anchor_items(anchor_items) +# Now Form B will be calibrated on the same scale as Form A +``` + +### 2. Incremental Calibration + +Add new items to an existing item bank: + +```python +# Load existing item bank parameters +item_bank = pd.read_csv('item_bank.csv') + +# New data with both old and new items +new_dataset = Dataset.from_pandas(new_data, subject_column='subject_id') + +# Use existing items as anchors +anchor_items = [ + { + 'item_id': row['item_id'], + 'difficulty': row['difficulty'], + 'discrimination': row['discrimination'] + } + for _, row in item_bank.iterrows() +] + +new_dataset.add_anchor_items(anchor_items) +# New items will be calibrated relative to the fixed item bank +``` + +### 3. Test Equating + +Ensure different test versions measure on the same scale: + +```python +# Calibrate base form +base_form = Dataset.from_pandas(base_data, subject_column='subject_id') +# ... train and get parameters + +# For each parallel form, use anchor items +for form_data in parallel_forms: + form_dataset = Dataset.from_pandas(form_data, subject_column='subject_id') + + # Use common items as anchors + form_dataset.add_anchor_items(anchor_items) + + # Calibrate - will be on the same scale as base form + # ... train +``` + +## Implementation Details + +### Architecture + +The anchor items functionality consists of three main components: + +1. 
**Dataset Extensions** (`py_irt/dataset.py`): + - `AnchorItem` class: Stores anchor item information + - `add_anchor_items()`: Method to add anchors to a dataset + - `get_anchor_indices()`: Helper to retrieve anchor item indices + +2. **AnchorItemInitializer** (`py_irt/initializers.py`): + - Sets initial parameter values for anchor items + - Sets scale parameters to near-zero (very low variance) + - Registered as `'anchor_items'` initializer + +3. **AnchorGradientZeroer** (`py_irt/anchor_utils.py`): + - Uses PyTorch hooks to zero gradients during backward pass + - Ensures optimizer cannot update anchor parameters + - Automatically registered/removed by training loop + +### How Gradients are Zeroed + +The implementation uses PyTorch's `register_hook()` mechanism: + +```python +def _create_grad_hook(self, anchor_indices): + def hook(grad): + if grad is not None: + grad_copy = grad.clone() + for anchor_ix in anchor_indices: + grad_copy[anchor_ix] = 0.0 + return grad_copy + return grad + return hook +``` + +This hook is called automatically during the backward pass, before the optimizer step, ensuring anchor parameters never receive gradient updates. + +### Training Loop Integration + +The training loop in `py_irt/training.py` automatically: + +1. Detects if anchor items are present in the dataset +2. Creates an `AnchorGradientZeroer` +3. Registers gradient hooks after parameter initialization +4. Cleans up hooks after training completes + +No manual intervention is required beyond adding `'anchor_items'` to the initializers list. 
+ +## Validation and Verification + +### Checking Anchor Items Stayed Fixed + +After training, verify that anchor items maintained their values: + +```python +params = trainer.best_params + +for anchor in dataset.anchor_items: + anchor_ix = dataset.item_id_to_ix[anchor.item_id] + + if anchor.difficulty is not None: + estimated = params['diff'][anchor_ix] + fixed = anchor.difficulty + error = abs(estimated - fixed) + print(f"{anchor.item_id} difficulty: fixed={fixed:.4f}, estimated={estimated:.4f}, error={error:.6f}") + + if anchor.discrimination is not None: + estimated = params['disc'][anchor_ix] + fixed = anchor.discrimination + error = abs(estimated - fixed) + print(f"{anchor.item_id} discrimination: fixed={fixed:.4f}, estimated={estimated:.4f}, error={error:.6f}") +``` + +Errors should be extremely small (< 0.001), confirming parameters stayed fixed. + +## Examples + +See the following files for complete examples: + +- **`examples/anchor_items_example.py`**: Comprehensive Python script with multiple examples +- **`tests/test_anchor_items.py`**: Unit tests demonstrating functionality + +Run the example: + +```bash +python examples/anchor_items_example.py +``` + +Run the tests: + +```bash +python -m pytest tests/test_anchor_items.py -v +``` + +## Troubleshooting + +### Anchor parameters are changing slightly + +**Problem**: Anchor item parameters show small changes (e.g., 0.001-0.01). + +**Causes**: +- Numerical precision issues in PyTorch +- Learning rate too high +- Not using the `'anchor_items'` initializer + +**Solution**: +```python +# Make sure to include anchor_items initializer +config = IrtConfig( + initializers=['anchor_items'] # ← Don't forget this! +) +``` + +### Item ID not found error + +**Problem**: `ValueError: Anchor item 'item_x' not found in dataset` + +**Cause**: The item ID doesn't exist in the dataset. 
+ +**Solution**: Check that item IDs match exactly: +```python +print("Available items:", list(dataset.item_ids)) +print("Your anchor ID:", anchor_items[0]['item_id']) +``` + +### Anchor items have no effect + +**Problem**: All items seem to be changing during training. + +**Cause**: Forgot to add `'anchor_items'` to initializers. + +**Solution**: +```python +config = IrtConfig( + model_type='2pl', + initializers=['anchor_items'] # ← Required! +) +``` + +## API Reference + +### Dataset.add_anchor_items(anchor_items) + +Add anchor items to the dataset. + +**Parameters:** +- `anchor_items` (List[Dict]): List of dictionaries specifying anchor items + +**Example:** +```python +dataset.add_anchor_items([ + {'item_id': 'item_1', 'difficulty': 0.5, 'discrimination': 1.2} +]) +``` + +### Dataset.get_anchor_indices() + +Get the integer indices of anchor items. + +**Returns:** +- `List[int]`: List of anchor item indices + +### AnchorItem + +Pydantic model representing an anchor item. + +**Fields:** +- `item_id` (str): Item identifier +- `item_ix` (int): Item index in the dataset +- `difficulty` (Optional[float]): Fixed difficulty value +- `discrimination` (Optional[float]): Fixed discrimination value +- `guessing` (Optional[float]): Fixed guessing value + +### AnchorItemInitializer + +Initializer that sets anchor item parameter values. + +**Usage:** +```python +config = IrtConfig(initializers=['anchor_items']) +``` + +### AnchorGradientZeroer + +Utility class that zeros gradients for anchor items. + +**Note**: This is used automatically by the training loop. You don't need to interact with it directly. + +## Limitations + +1. **Amortized Models**: Anchor items are not currently supported for amortized models (e.g., `amortized_1pl`). + +2. **Hierarchical Priors**: When using hierarchical priors, anchor items fix the item-level parameters but not the hyperparameters (mu, sigma). + +3. **MCMC**: Anchor items are designed for variational inference (SVI). MCMC is not supported. 
+ +### Technical Note on Parameter Fixing + +The implementation uses a combination of gradient hooks and manual parameter reset to ensure anchor items stay fixed: + +1. **Gradient Hooks**: Registered on unconstrained parameters (for constrained parameters like `discrimination`) to zero gradients during backward pass +2. **Manual Reset**: After each optimizer step, anchor parameters are reset to their fixed values +3. **Constraint Handling**: For parameters with positive constraints, both the constrained value and its unconstrained representation (log space) are updated + +This dual approach ensures high precision for all parameter types: +- **Difficulty** (unconstrained): Typically < 0.001 deviation +- **Discrimination** (positive constraint): Typically < 0.001 deviation +- **Guessing/Slip** (positive constraint): Typically < 0.001 deviation + +The implementation correctly handles Pyro's internal parameter transformations, ensuring that anchors remain stable throughout training. + +## References + +For more information on anchor items and test linking in IRT: + +- Kolen, M. J., & Brennan, R. L. (2014). *Test Equating, Scaling, and Linking* (3rd ed.). Springer. +- von Davier, A. A., Holland, P. W., & Thayer, D. T. (2004). *The Kernel Method of Test Equating*. Springer. + +## Contributing + +If you find bugs or have suggestions for improving anchor items functionality, please open an issue on GitHub. + diff --git a/examples/anchor_items_example.py b/examples/anchor_items_example.py new file mode 100644 index 0000000..2b22bc6 --- /dev/null +++ b/examples/anchor_items_example.py @@ -0,0 +1,279 @@ +#!/usr/bin/env python3 +""" +Example: Using Anchor Items in IRT Models + +This example demonstrates how to use anchor items (fixed parameter values) +in IRT calibration. Anchor items are useful for: +1. Test linking - maintaining calibration across different test forms +2. Equating - putting different tests on the same scale +3. 
Pre-calibrated items - using items with known parameters +""" + +import pandas as pd +import numpy as np +from py_irt.dataset import Dataset +from py_irt.config import IrtConfig +from py_irt.training import IrtModelTrainer +import pyro + +# Set random seed for reproducibility +np.random.seed(42) +torch_seed = 42 + + +def create_example_dataset(): + """Create a synthetic dataset for demonstration""" + # Simulate 50 subjects and 10 items + n_subjects = 50 + n_items = 10 + + # True parameters (for simulation) + true_abilities = np.random.randn(n_subjects) + true_difficulties = np.linspace(-2, 2, n_items) + true_discriminations = np.random.uniform(0.8, 1.5, n_items) + + # Generate responses using 2PL model + data = {} + data['subject_id'] = [f'subject_{i}' for i in range(n_subjects)] + + for j in range(n_items): + responses = [] + for i in range(n_subjects): + # 2PL formula: P(correct) = 1 / (1 + exp(-a * (theta - b))) + prob = 1 / (1 + np.exp(-true_discriminations[j] * (true_abilities[i] - true_difficulties[j]))) + response = 1 if np.random.random() < prob else 0 + responses.append(response) + data[f'item_{j}'] = responses + + df = pd.DataFrame(data) + return df, true_difficulties, true_discriminations + + +def example_1_basic_anchor_items(): + """Example 1: Basic usage of anchor items""" + print("=" * 70) + print("Example 1: Basic Usage of Anchor Items") + print("=" * 70) + + # Create dataset + df, true_difficulties, true_discriminations = create_example_dataset() + dataset = Dataset.from_pandas(df, subject_column='subject_id') + + print(f"\nDataset: {len(dataset.subject_ids)} subjects, {len(dataset.item_ids)} items") + + # Designate items 0, 1, and 2 as anchor items + # In a real scenario, these would come from a previous calibration + anchor_items = [ + { + 'item_id': 'item_0', + 'difficulty': true_difficulties[0], # Use true value as anchor + 'discrimination': true_discriminations[0] + }, + { + 'item_id': 'item_1', + 'difficulty': true_difficulties[1], + 
'discrimination': true_discriminations[1] + }, + { + 'item_id': 'item_2', + 'difficulty': true_difficulties[2], + 'discrimination': true_discriminations[2] + } + ] + + print(f"\nAnchor items:") + for anchor in anchor_items: + print(f" {anchor['item_id']}: diff={anchor['difficulty']:.3f}, disc={anchor['discrimination']:.3f}") + + # Add anchor items to dataset + dataset.add_anchor_items(anchor_items) + + # Configure model with anchor initializer + config = IrtConfig( + model_type='2pl', + priors='vague', + epochs=100, + lr=0.1, + lr_decay=0.995, + initializers=['anchor_items'] # Use anchor items initializer + ) + + # Clear Pyro parameter store + pyro.clear_param_store() + + # Train model + print("\nTraining 2PL model with anchor items...") + trainer = IrtModelTrainer( + data_path=None, + config=config, + dataset=dataset, + verbose=True + ) + + trainer.train(epochs=100, device='cpu') + + # Get results + params = trainer.best_params + + # Check anchor items stayed fixed + print("\n" + "=" * 70) + print("Verification: Anchor Items Parameters") + print("=" * 70) + for i, anchor in enumerate(anchor_items): + anchor_ix = dataset.item_id_to_ix[anchor['item_id']] + estimated_diff = params['diff'][anchor_ix] + estimated_disc = params['disc'][anchor_ix] + + print(f"\n{anchor['item_id']}:") + print(f" Fixed difficulty: {anchor['difficulty']:.4f}") + print(f" Estimated difficulty: {estimated_diff:.4f}") + print(f" Difference: {abs(estimated_diff - anchor['difficulty']):.6f}") + print(f" Fixed discrimination: {anchor['discrimination']:.4f}") + print(f" Estimated discrimination: {estimated_disc:.4f}") + print(f" Difference: {abs(estimated_disc - anchor['discrimination']):.6f}") + + # Show non-anchor items (should have been estimated) + print("\n" + "=" * 70) + print("Non-Anchor Items (Estimated)") + print("=" * 70) + for i in range(3, 6): # Show a few non-anchor items + item_id = f'item_{i}' + item_ix = dataset.item_id_to_ix[item_id] + print(f"\n{item_id}:") + print(f" True 
difficulty: {true_difficulties[i]:.4f}") + print(f" Estimated difficulty: {params['diff'][item_ix]:.4f}") + print(f" True discrimination: {true_discriminations[i]:.4f}") + print(f" Estimated discrimination: {params['disc'][item_ix]:.4f}") + + +def example_2_without_anchor_items(): + """Example 2: Training without anchor items for comparison""" + print("\n\n" + "=" * 70) + print("Example 2: Training WITHOUT Anchor Items (for comparison)") + print("=" * 70) + + # Create dataset + df, true_difficulties, true_discriminations = create_example_dataset() + dataset = Dataset.from_pandas(df, subject_column='subject_id') + + # Configure model WITHOUT anchor items + config = IrtConfig( + model_type='2pl', + priors='vague', + epochs=100, + lr=0.1, + lr_decay=0.995, + initializers=[] # No initializers + ) + + # Clear Pyro parameter store + pyro.clear_param_store() + + # Train model + print("\nTraining 2PL model without anchor items...") + trainer = IrtModelTrainer( + data_path=None, + config=config, + dataset=dataset, + verbose=True + ) + + trainer.train(epochs=100, device='cpu') + + # Get results + params = trainer.best_params + + # Show first few items + print("\n" + "=" * 70) + print("Estimated Parameters (No Anchors)") + print("=" * 70) + for i in range(3): + item_id = f'item_{i}' + item_ix = dataset.item_id_to_ix[item_id] + print(f"\n{item_id}:") + print(f" True difficulty: {true_difficulties[i]:.4f}") + print(f" Estimated difficulty: {params['diff'][item_ix]:.4f}") + print(f" True discrimination: {true_discriminations[i]:.4f}") + print(f" Estimated discrimination: {params['disc'][item_ix]:.4f}") + + print("\nNote: Without anchor items, the scale may be different from the true scale.") + + +def example_3_partial_anchors(): + """Example 3: Anchoring only some parameters (e.g., only difficulty)""" + print("\n\n" + "=" * 70) + print("Example 3: Partial Anchoring (Only Difficulty)") + print("=" * 70) + + # Create dataset + df, true_difficulties, true_discriminations = 
create_example_dataset() + dataset = Dataset.from_pandas(df, subject_column='subject_id') + + # Anchor only difficulty for some items + anchor_items = [ + { + 'item_id': 'item_0', + 'difficulty': true_difficulties[0], + # No discrimination specified - will be estimated + }, + { + 'item_id': 'item_1', + 'difficulty': true_difficulties[1], + } + ] + + print(f"\nAnchor items (difficulty only):") + for anchor in anchor_items: + print(f" {anchor['item_id']}: diff={anchor['difficulty']:.3f}") + + dataset.add_anchor_items(anchor_items) + + config = IrtConfig( + model_type='2pl', + priors='vague', + epochs=100, + lr=0.1, + lr_decay=0.995, + initializers=['anchor_items'] + ) + + pyro.clear_param_store() + + print("\nTraining with partial anchors...") + trainer = IrtModelTrainer( + data_path=None, + config=config, + dataset=dataset, + verbose=True + ) + + trainer.train(epochs=100, device='cpu') + + params = trainer.best_params + + print("\n" + "=" * 70) + print("Results: Partially Anchored Items") + print("=" * 70) + for anchor in anchor_items: + anchor_ix = dataset.item_id_to_ix[anchor['item_id']] + print(f"\n{anchor['item_id']}:") + print(f" Difficulty (anchored): {params['diff'][anchor_ix]:.4f}") + print(f" Discrimination (estimated): {params['disc'][anchor_ix]:.4f}") + + +if __name__ == '__main__': + # Run examples + example_1_basic_anchor_items() + example_2_without_anchor_items() + example_3_partial_anchors() + + print("\n\n" + "=" * 70) + print("Examples completed!") + print("=" * 70) + print("\nKey Takeaways:") + print("1. Anchor items maintain their fixed parameter values during training") + print("2. This is useful for test linking and equating") + print("3. You can anchor all parameters or just some (e.g., only difficulty)") + print("4. 
Anchor items help maintain scale across different test administrations") + + diff --git a/py_irt/anchor_utils.py b/py_irt/anchor_utils.py new file mode 100644 index 0000000..c1c5d8b --- /dev/null +++ b/py_irt/anchor_utils.py @@ -0,0 +1,175 @@ +# MIT License + +# Copyright (c) 2019 John Lalor and Pedro Rodriguez + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +""" +Utilities for handling anchor items in IRT models. + +Anchor items are items with fixed parameter values that do not change during training. +This is useful for linking tests or maintaining calibrated items across different test forms. +""" + +from typing import List, Dict, Any +import torch +from pyro.optim import PyroOptim +import pyro +from rich.console import Console + +console = Console() + + +class AnchorGradientZeroer: + """ + Utility class that zeros out gradients for anchor items. 
+ + This ensures that anchor item parameters remain fixed during training by + setting their gradients to zero after backward pass but before optimizer step. + + Args: + anchor_indices: List of item indices that should remain fixed + param_names: List of parameter names to apply anchoring to (e.g., ['loc_diff', 'loc_slope']) + """ + + def __init__( + self, + anchor_indices: List[int], + param_names: List[str] = None + ): + self.anchor_indices = anchor_indices + + # Default parameter names to anchor (both loc and scale) + if param_names is None: + self.param_names = [ + 'loc_diff', 'scale_diff', + 'loc_slope', 'scale_slope', + 'loc_disc', 'scale_disc', + 'loc_guess', 'scale_guess', + 'loc_slip', 'scale_slip' + ] + else: + self.param_names = param_names + + self._hooks = [] + + if self.anchor_indices: + console.log(f"AnchorGradientZeroer initialized with {len(anchor_indices)} anchor items") + console.log(f"Will zero gradients for parameters: {self.param_names}") + + def _create_grad_hook(self, anchor_indices: List[int]): + """Create a hook function that zeros gradients for anchor items.""" + def hook(grad): + if grad is not None: + # Clone the gradient to avoid in-place modification issues + grad_copy = grad.clone() + for anchor_ix in anchor_indices: + if anchor_ix < grad_copy.shape[0]: + grad_copy[anchor_ix] = 0.0 + return grad_copy + return grad + return hook + + def register_hooks(self) -> None: + """Register backward hooks on anchor item parameters.""" + if not self.anchor_indices: + return + + param_store = pyro.get_param_store() + + for param_name in self.param_names: + if param_name in param_store: + # Get the parameter - this might be constrained + param = param_store[param_name] + + # Check if parameter has unconstrained version (for constrained parameters) + if hasattr(param, 'unconstrained') and callable(param.unconstrained): + try: + # Use unconstrained parameter for hook (this is where gradients actually flow) + param_to_hook = param.unconstrained() + 
console.log(f"Registered gradient hook for {param_name} (on unconstrained parameter)") + except Exception: + # If unconstrained() fails, use regular parameter + param_to_hook = param + console.log(f"Registered gradient hook for {param_name}") + else: + # Use regular parameter (not constrained) + param_to_hook = param + console.log(f"Registered gradient hook for {param_name}") + + # Register a hook that will be called during backward pass + hook = param_to_hook.register_hook(self._create_grad_hook(self.anchor_indices)) + self._hooks.append(hook) + + def remove_hooks(self) -> None: + """Remove all registered hooks.""" + for hook in self._hooks: + hook.remove() + self._hooks = [] + + def zero_anchor_gradients(self) -> None: + """Manually zero out gradients for anchor item parameters.""" + if not self.anchor_indices: + return + + param_store = pyro.get_param_store() + + for param_name in self.param_names: + if param_name in param_store: + param = param_store[param_name] + + # For constrained parameters, work with unconstrained version + if hasattr(param, 'unconstrained') and callable(param.unconstrained): + try: + param_to_zero = param.unconstrained() + except Exception: + param_to_zero = param + else: + param_to_zero = param + + # Check if parameter has gradients + if param_to_zero.grad is not None: + # Zero out gradients for anchor items + for anchor_ix in self.anchor_indices: + if anchor_ix < param_to_zero.grad.shape[0]: + param_to_zero.grad[anchor_ix] = 0.0 + + def __call__(self): + """Allows using the zeroer as a callable.""" + self.zero_anchor_gradients() + + +def create_anchor_gradient_zeroer(dataset, param_names: List[str] = None): + """ + Create an anchor gradient zeroer from a dataset. 
+ + Args: + dataset: The Dataset object containing anchor item information + param_names: Optional list of parameter names to anchor + + Returns: + AnchorGradientZeroer: A gradient zeroer that respects anchor items + """ + anchor_indices = dataset.get_anchor_indices() if hasattr(dataset, 'get_anchor_indices') else [] + + if not anchor_indices: + console.log("No anchor items found, gradient zeroer will be a no-op") + + return AnchorGradientZeroer(anchor_indices, param_names) + diff --git a/py_irt/config.py b/py_irt/config.py index 65abd95..36c9d09 100644 --- a/py_irt/config.py +++ b/py_irt/config.py @@ -24,6 +24,12 @@ from typing import List, Dict, Union, Optional, Callable from pydantic import BaseModel, ConfigDict +# Anchor items configuration constants +# A small constant to represent near-zero variance for fixed parameters in variational inference. +# This value is used to make anchor item parameter distributions extremely narrow (approaching +# a Dirac delta function) while maintaining numerical stability in PyTorch/Pyro computations. +NEAR_ZERO_SCALE = 1e-8 + # This registers all models with the registry # pylint: disable=unused-import from py_irt.models import * diff --git a/py_irt/dataset.py b/py_irt/dataset.py index 04920d4..6500be5 100644 --- a/py_irt/dataset.py +++ b/py_irt/dataset.py @@ -20,7 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. -from typing import Set, Dict, List, Union +from typing import Set, Dict, List, Union, Optional from pathlib import Path from pydantic import BaseModel from py_irt.io import read_jsonlines @@ -40,6 +40,32 @@ def accuracy(self): return self.correct / max(1, self.total) +class AnchorItem(BaseModel): + """Represents an anchor item with fixed parameter values. + + Parameters can be either scalars or vectors (for multidimensional IRT). + When vectors are provided, they are used directly without scaling. 
+ When scalars are provided for multidimensional models, they are scaled + by 1/sqrt(D) so that the L2 norm equals the original scalar. + """ + item_id: str + item_ix: int + # Scalar values (for 1D models or as norm for multidim) + difficulty: Optional[float] = None + discrimination: Optional[float] = None + guessing: Optional[float] = None + # Vector values (for multidimensional models - used directly) + difficulty_vector: Optional[List[float]] = None + discrimination_vector: Optional[List[float]] = None + + class Config: + arbitrary_types_allowed = True + + def has_vector_params(self) -> bool: + """Check if this anchor has vector parameters.""" + return self.difficulty_vector is not None or self.discrimination_vector is not None + + class Dataset(BaseModel): item_ids: Union[Set[str], OrderedSet] subject_ids: Union[Set[str], OrderedSet] @@ -57,6 +83,9 @@ class Dataset(BaseModel): # should this example be included in training? training_example: List[bool] + + # Anchor items with fixed parameters + anchor_items: Optional[List[AnchorItem]] = None class Config: arbitrary_types_allowed = True @@ -73,6 +102,73 @@ def get_item_accuracies(self) -> Dict[str, ItemAccuracy]: item_accuracies[item_id].total += 1 return item_accuracies + + def add_anchor_items(self, anchor_items: List[Dict[str, Union[str, float, List[float]]]]) -> None: + """Add anchor items to the dataset. + + Args: + anchor_items: List of dictionaries with keys: + - 'item_id': str - The item ID + - 'difficulty': float (optional) - Fixed difficulty value (scalar) + - 'discrimination': float (optional) - Fixed discrimination value (scalar) + - 'guessing': float (optional) - Fixed guessing value + - 'difficulty_vector': List[float] (optional) - Fixed difficulty vector (for MIRT) + - 'discrimination_vector': List[float] (optional) - Fixed discrimination vector (for MIRT) + + For multidimensional IRT models: + - If vectors are provided, they are used directly without any scaling. 
+ - If only scalars are provided, they are scaled by 1/sqrt(D) so that + the L2 norm of the resulting uniform vector equals the original scalar. + + Example: + # Scalar anchors (will be scaled for multidim) + dataset.add_anchor_items([ + {'item_id': 'item_1', 'difficulty': 0.5, 'discrimination': 1.2}, + ]) + + # Vector anchors (used directly for multidim) + dataset.add_anchor_items([ + {'item_id': 'item_1', + 'difficulty_vector': [0.3, 0.5, 0.2], + 'discrimination_vector': [0.8, 1.0, 0.6]}, + ]) + """ + self.anchor_items = [] + for anchor_dict in anchor_items: + item_id = anchor_dict['item_id'] + if item_id not in self.item_id_to_ix: + raise ValueError(f"Anchor item '{item_id}' not found in dataset") + + item_ix = self.item_id_to_ix[item_id] + + # Handle vector parameters + diff_vec = anchor_dict.get('difficulty_vector') + disc_vec = anchor_dict.get('discrimination_vector') + + # Convert numpy arrays to lists if needed + if diff_vec is not None and hasattr(diff_vec, 'tolist'): + diff_vec = diff_vec.tolist() + if disc_vec is not None and hasattr(disc_vec, 'tolist'): + disc_vec = disc_vec.tolist() + + anchor = AnchorItem( + item_id=item_id, + item_ix=item_ix, + difficulty=anchor_dict.get('difficulty'), + discrimination=anchor_dict.get('discrimination'), + guessing=anchor_dict.get('guessing'), + difficulty_vector=diff_vec, + discrimination_vector=disc_vec, + ) + self.anchor_items.append(anchor) + + console.log(f"Added {len(self.anchor_items)} anchor items") + + def get_anchor_indices(self) -> List[int]: + """Get the indices of anchor items""" + if self.anchor_items is None: + return [] + return [anchor.item_ix for anchor in self.anchor_items] @classmethod def from_jsonlines(cls, data_path: Path, train_items: dict = None, amortized: bool = False): diff --git a/py_irt/initializers.py b/py_irt/initializers.py index 2a1e4e7..2ac42c1 100644 --- a/py_irt/initializers.py +++ b/py_irt/initializers.py @@ -34,6 +34,7 @@ import pyro from rich.console import Console from 
@register("anchor_items")
class AnchorItemInitializer(IrtInitializer):
    """Initializer for setting fixed values for anchor items.

    Pins each anchor item's variational mean (``loc_*``) to the requested
    value and collapses its variational scale (``scale_*``) to
    ``NEAR_ZERO_SCALE`` so the posterior is effectively a point mass.

    For multidimensional models, a scalar anchor value is expanded to a
    uniform vector scaled by 1/sqrt(D) so the vector's L2 norm equals the
    scalar — matching the contract documented on ``Dataset.add_anchor_items``.
    Explicit ``*_vector`` anchors are used as-is.
    """

    def __init__(self, dataset: Dataset):
        super().__init__(dataset)
        if dataset.anchor_items is None or len(dataset.anchor_items) == 0:
            raise ValueError("Dataset must have anchor items defined")

    def initialize(self) -> None:
        """Initialize anchor item parameters with their fixed values."""
        if self._dataset.anchor_items is None:
            return

        param_store = pyro.get_param_store()

        # Difficulty parameters exist for every model type.
        loc_diff = pyro.param("loc_diff")
        scale_diff = pyro.param("scale_diff")

        # Discrimination parameters (2PL, 3PL, 4PL) — models use either the
        # "loc_slope" or the "loc_disc" naming convention.
        loc_disc = scale_disc = None
        if "loc_slope" in param_store.keys():
            loc_disc = pyro.param("loc_slope")
            scale_disc = pyro.param("scale_slope")
        elif "loc_disc" in param_store.keys():
            loc_disc = pyro.param("loc_disc")
            scale_disc = pyro.param("scale_disc")

        # Guessing parameters (3PL, 4PL models).
        loc_guess = scale_guess = None
        if "loc_guess" in param_store.keys():
            loc_guess = pyro.param("loc_guess")
            scale_guess = pyro.param("scale_guess")

        # Dimensionality is a property of the model, not of a single anchor,
        # so compute it once outside the loop.
        is_multidim = len(loc_diff.shape) > 1 and loc_diff.shape[1] > 1
        num_dims = loc_diff.shape[1] if is_multidim else 1

        console.log(f"Initializing {len(self._dataset.anchor_items)} anchor items:")

        for anchor in self._dataset.anchor_items:
            item_ix = anchor.item_ix

            diff_value = self._resolve_value(
                anchor.difficulty, anchor.difficulty_vector, is_multidim, num_dims
            )
            if diff_value is not None:
                self._pin(loc_diff, scale_diff, item_ix, diff_value)

            if loc_disc is not None:
                disc_value = self._resolve_value(
                    anchor.discrimination, anchor.discrimination_vector, is_multidim, num_dims
                )
                if disc_value is not None:
                    self._pin(loc_disc, scale_disc, item_ix, disc_value)

            # Guessing is a scalar even in multidimensional models; previously
            # these params were looked up but the anchor value was never applied.
            if loc_guess is not None and anchor.guessing is not None:
                self._pin(loc_guess, scale_guess, item_ix, anchor.guessing)

    @staticmethod
    def _resolve_value(scalar, vector, is_multidim, num_dims):
        """Choose the concrete anchor value for one parameter family.

        Returns the explicit vector for multidim models, a 1/sqrt(D)-scaled
        uniform vector when only a scalar was given for a multidim model, the
        scalar for 1D models, or None when nothing was specified.
        """
        if is_multidim:
            if vector is not None:
                return vector
            if scalar is not None:
                # Uniform vector whose L2 norm equals the original scalar.
                return [scalar / (num_dims ** 0.5)] * num_dims
            return None
        # 1D model: vectors are not applicable, only the scalar counts.
        return scalar

    @staticmethod
    def _pin(loc_param, scale_param, item_ix, value):
        """Write the fixed value into ``loc`` and collapse ``scale`` to NEAR_ZERO_SCALE."""
        with torch.no_grad():
            if isinstance(value, list):
                value = torch.tensor(value, dtype=loc_param.dtype, device=loc_param.device)
            loc_param[item_ix] = value
            scale_param[item_ix] = NEAR_ZERO_SCALE
table.add_column("Epoch") @@ -200,6 +207,58 @@ def train(self, *, epochs: Optional[int] = None, device: str = "cpu") -> None: with live: for epoch in range(epochs): loss = svi.step(subjects, items, responses) + + # After SVI step, reset anchor item parameters to their fixed values + if anchor_zeroer.anchor_indices: + param_store = pyro.get_param_store() + for anchor in self._dataset.anchor_items: + item_ix = anchor.item_ix + + # Determine if this is a multidimensional model + is_multidim = False + D = 1 + if "loc_diff" in param_store: + diff_param = param_store["loc_diff"] + is_multidim = len(diff_param.shape) > 1 and diff_param.shape[1] > 1 + D = diff_param.shape[1] if is_multidim else 1 + + # Helper function to set both constrained and unconstrained values + def set_param_value(param_name, value, has_positive_constraint=False): + if param_name in param_store: + param = param_store[param_name] + with torch.no_grad(): + # Set constrained value + if isinstance(value, (list, tuple)): + value = torch.tensor(value, dtype=param.dtype, device=param.device) + param[item_ix] = value + + # For constrained parameters with positive constraint, also update unconstrained + if has_positive_constraint and hasattr(param, 'unconstrained') and callable(param.unconstrained): + try: + unc = param.unconstrained() + # For positive constraint: unconstrained = log(constrained) + if isinstance(value, torch.Tensor): + unc[item_ix] = torch.log(value) + else: + unc[item_ix] = torch.log(torch.tensor(value, device=unc.device)) + except Exception: + pass # If update fails, constrained value is still set + + # Set difficulty (vector for multidim, scalar for 1D) + diff_value = anchor.difficulty_vector if is_multidim else anchor.difficulty + if diff_value is not None: + set_param_value("loc_diff", diff_value, has_positive_constraint=False) + set_param_value("scale_diff", NEAR_ZERO_SCALE, has_positive_constraint=True) + + # Set discrimination (vector for multidim, scalar for 1D) + disc_value = 
anchor.discrimination_vector if is_multidim else anchor.discrimination + if disc_value is not None: + set_param_value("loc_slope", disc_value, has_positive_constraint=True) + set_param_value("scale_slope", NEAR_ZERO_SCALE, has_positive_constraint=True) + if "loc_disc" in param_store: + set_param_value("loc_disc", disc_value, has_positive_constraint=False) + set_param_value("scale_disc", NEAR_ZERO_SCALE, has_positive_constraint=True) + if loss < best_loss: best_loss = loss self.best_params = self.export(items) @@ -217,6 +276,10 @@ def train(self, *, epochs: Optional[int] = None, device: str = "cpu") -> None: f"{epoch + 1}", "%.4f" % loss, "%.4f" % best_loss, "%.4f" % current_lr ) self.last_params = self.export(items) + + # Clean up hooks after training + if anchor_zeroer.anchor_indices: + anchor_zeroer.remove_hooks() def export(self, items): if self.amortized: diff --git a/tests/test_anchor_items.py b/tests/test_anchor_items.py new file mode 100644 index 0000000..9b757ad --- /dev/null +++ b/tests/test_anchor_items.py @@ -0,0 +1,180 @@ +# MIT License + +# Copyright (c) 2019 John Lalor and Pedro Rodriguez + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
class TestAnchorItems(unittest.TestCase):
    """Test anchor items functionality."""

    def setUp(self):
        """Build a small fixed-response dataset: 20 subjects x 4 items."""
        self.df = pd.DataFrame({
            # Subjects s1..s20 — generated instead of spelled out by hand.
            'subject_id': [f's{i}' for i in range(1, 21)],
            'item_1': [1, 1, 0, 1, 1, 0, 1, 0, 1, 1, 1, 0, 1, 0, 1, 1, 0, 1, 0, 1],
            'item_2': [1, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 1],
            'item_3': [0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0, 1, 0, 0],
            'item_4': [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
        })

        self.dataset = Dataset.from_pandas(self.df, subject_column='subject_id')

    def test_add_anchor_items(self):
        """Anchor items are stored with their fixed values and indices."""
        anchor_items = [
            {'item_id': 'item_1', 'difficulty': 0.5, 'discrimination': 1.2},
            {'item_id': 'item_3', 'difficulty': -0.8, 'discrimination': 0.9}
        ]

        self.dataset.add_anchor_items(anchor_items)

        # Check that anchor items were added
        self.assertIsNotNone(self.dataset.anchor_items)
        self.assertEqual(len(self.dataset.anchor_items), 2)

        # Check anchor item properties
        anchor_1 = self.dataset.anchor_items[0]
        self.assertEqual(anchor_1.item_id, 'item_1')
        self.assertEqual(anchor_1.difficulty, 0.5)
        self.assertEqual(anchor_1.discrimination, 1.2)

        # Check anchor indices
        anchor_indices = self.dataset.get_anchor_indices()
        self.assertEqual(len(anchor_indices), 2)

    def test_anchor_items_invalid_id(self):
        """An anchor referencing an unknown item ID raises ValueError."""
        anchor_items = [
            {'item_id': 'invalid_item', 'difficulty': 0.5}
        ]

        with self.assertRaises(ValueError):
            self.dataset.add_anchor_items(anchor_items)

    def test_training_with_anchor_items(self):
        """End-to-end: anchor parameters stay fixed while others are estimated."""
        anchor_items = [
            {'item_id': 'item_1', 'difficulty': 0.5, 'discrimination': 1.2},
        ]
        self.dataset.add_anchor_items(anchor_items)

        # Clear Pyro param store so earlier tests cannot leak state in.
        pyro.clear_param_store()

        config = IrtConfig(
            model_type='2pl',
            priors='vague',
            epochs=10,
            lr=0.1,
            initializers=['anchor_items']
        )

        trainer = IrtModelTrainer(
            data_path=None,
            config=config,
            dataset=self.dataset,
            verbose=False
        )

        trainer.train(epochs=10, device='cpu')

        params = trainer.last_params

        # Anchor item parameters should be (near-)identical to the fixed values.
        anchor_ix = self.dataset.anchor_items[0].item_ix
        difficulty = params['diff'][anchor_ix]
        discrimination = params['disc'][anchor_ix]

        print(f"\nAnchor item parameters:")
        print(f"  Difficulty: expected=0.5, got={difficulty:.4f}")
        print(f"  Discrimination: expected=1.2, got={discrimination:.4f}")

        # Verify anchor parameters stayed fixed (allow small numerical error)
        self.assertAlmostEqual(difficulty, 0.5, places=2,
                               msg=f"Difficulty should stay at 0.5, got {difficulty}")
        self.assertAlmostEqual(discrimination, 1.2, places=2,
                               msg=f"Discrimination should stay at 1.2, got {discrimination}")

        # Verify non-anchor items were actually calibrated (not all zeros).
        # (Previously an unused `non_anchor_discs` list was also built here.)
        non_anchor_diffs = [d for i, d in enumerate(params['diff']) if i != anchor_ix]
        has_nonzero_diff = any(abs(d) > 0.01 for d in non_anchor_diffs)
        self.assertTrue(has_nonzero_diff, "At least one non-anchor item should have non-zero difficulty")

    def test_anchor_gradient_zeroer(self):
        """AnchorGradientZeroer zeroes anchor gradients and leaves others intact."""
        # Local import: anchor_utils is only needed by this test.
        # (pyro is already imported at module level — no need to re-import it.)
        from py_irt.anchor_utils import AnchorGradientZeroer

        pyro.clear_param_store()

        # Create some test parameters with known gradients.
        loc_diff = pyro.param('loc_diff', torch.zeros(4))
        loc_slope = pyro.param('loc_slope', torch.ones(4))
        loc_diff.grad = torch.ones(4)
        loc_slope.grad = torch.ones(4) * 2.0

        # Create zeroer for anchor indices [0, 2]
        zeroer = AnchorGradientZeroer(
            anchor_indices=[0, 2],
            param_names=['loc_diff', 'loc_slope']
        )

        zeroer.zero_anchor_gradients()

        # Anchor gradients are zeroed.
        for ix in (0, 2):
            self.assertEqual(loc_diff.grad[ix].item(), 0.0)
            self.assertEqual(loc_slope.grad[ix].item(), 0.0)

        # Non-anchor gradients are untouched.
        for ix in (1, 3):
            self.assertEqual(loc_diff.grad[ix].item(), 1.0)
            self.assertEqual(loc_slope.grad[ix].item(), 2.0)


if __name__ == '__main__':
    unittest.main()