
Conversation

@SteveAndCow
Owner

This pull request introduces significant improvements to the local training pipeline for chess models: a new training script, enhancements to dataset processing, and updates to the transformer model architecture.

Training Pipeline Enhancements:

  • Added a new script local_training.py that sets up and runs a PyTorch training loop for chess models, including dataloader setup, evaluation, and model saving. It supports local training with either the ChessCNN or the ChessTransformer model.
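
    A minimal sketch of the shape of such a loop, assuming the model takes 12x8x8 bitboard tensors and StockfishDataset yields (bitboards, evaluation) pairs; the hyperparameters, checkpoint names, and loss choice here are illustrative, not necessarily what the script uses:

    ```python
    # Sketch of a local training loop for a chess evaluation model.
    # Assumes a model mapping (N, 12, 8, 8) bitboards to scalar evaluations
    # and a dataset yielding (bitboards, evaluation) pairs.
    import torch
    from torch.utils.data import DataLoader

    def train(model, dataset, epochs=10, batch_size=256, lr=1e-3, device="cuda"):
        model = model.to(device)
        loader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=4)
        optimizer = torch.optim.AdamW(model.parameters(), lr=lr)
        criterion = torch.nn.MSELoss()

        for epoch in range(epochs):
            model.train()
            running_loss = 0.0
            for boards, evals in loader:
                boards, evals = boards.to(device), evals.to(device)
                optimizer.zero_grad()
                preds = model(boards).squeeze(-1)   # predicted evaluation per position
                loss = criterion(preds, evals)
                loss.backward()
                optimizer.step()
                running_loss += loss.item()
            print(f"epoch {epoch}: loss={running_loss / len(loader):.4f}")
            torch.save(model.state_dict(), f"checkpoint_epoch{epoch}.pt")
    ```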

Dataset Handling Improvements:

  • Refactored fen_to_bitboards in stockfishdataset.py to output a 12x8x8 tensor for each board position, supporting both white and black to move, and added a StockfishDataset class compatible with PyTorch’s Dataset API (see the sketch after this list). This makes data loading and preprocessing more robust and efficient for model training.
  • Updated the dataset script to process input files more efficiently, handle mate scores, normalize evaluation values based on the side to move, and save processed tensors for both bitboards and evaluations.
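
    As a rough illustration of the encoding and dataset described above, assuming python-chess for FEN parsing; the plane ordering, file layout, and normalization details in stockfishdataset.py may differ:

    ```python
    # Sketch of a 12x8x8 bitboard encoding (6 piece types x 2 colors) and a
    # PyTorch-compatible dataset serving precomputed tensors. Plane order and
    # the saved-file layout are assumptions, not the script's actual format.
    import chess
    import torch
    from torch.utils.data import Dataset

    def fen_to_bitboards(fen: str) -> torch.Tensor:
        board = chess.Board(fen)
        planes = torch.zeros(12, 8, 8)
        for square, piece in board.piece_map().items():
            plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
            planes[plane, square // 8, square % 8] = 1.0
        return planes

    class StockfishDataset(Dataset):
        """Serves (bitboards, evaluation) pairs saved by the dataset script."""
        def __init__(self, bitboards_path: str, evals_path: str):
            self.bitboards = torch.load(bitboards_path)  # shape (N, 12, 8, 8)
            self.evals = torch.load(evals_path)          # shape (N,), normalized for side to move

        def __len__(self):
            return len(self.evals)

        def __getitem__(self, idx):
            return self.bitboards[idx], self.evals[idx]
    ```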

Transformer Model Architecture Updates:

  • Introduced a custom MultiheadAttention module that uses Flash Attention for improved efficiency, although it is not yet integrated into the encoder (see the sketch after this list).
  • Increased the number of layers in the ChessTransformer model from 4 to 6, potentially improving model capacity and performance.
  • Commented out the use of relative positional bias and its associated attention mask in the transformer encoder layer, simplifying the attention mechanism for now.
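
    For the custom attention module, a minimal sketch of how such a block can route through PyTorch's fused/Flash kernels via torch.nn.functional.scaled_dot_product_attention; the actual module in this PR may be structured differently:

    ```python
    # Sketch of a multi-head attention block that delegates to
    # F.scaled_dot_product_attention, which dispatches to Flash Attention
    # kernels when available. Names and layout are illustrative.
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class MultiheadAttention(nn.Module):
        def __init__(self, embed_dim: int, num_heads: int, dropout: float = 0.0):
            super().__init__()
            assert embed_dim % num_heads == 0
            self.num_heads = num_heads
            self.head_dim = embed_dim // num_heads
            self.qkv = nn.Linear(embed_dim, 3 * embed_dim)
            self.out = nn.Linear(embed_dim, embed_dim)
            self.dropout = dropout

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            B, T, C = x.shape
            q, k, v = self.qkv(x).chunk(3, dim=-1)
            # reshape to (B, num_heads, T, head_dim) for the fused kernel
            q, k, v = (t.view(B, T, self.num_heads, self.head_dim).transpose(1, 2)
                       for t in (q, k, v))
            attn = F.scaled_dot_product_attention(
                q, k, v, dropout_p=self.dropout if self.training else 0.0)
            return self.out(attn.transpose(1, 2).reshape(B, T, C))
    ```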

SteveAndCow pushed a commit that referenced this pull request Nov 16, 2025
This commit addresses all critical and recommended fixes identified
during comprehensive code review of the batch inference optimizations.

CRITICAL FIXES:

1. Fix virtual loss underflow (Issue #1):
   - Virtual loss can no longer go negative
   - Changed node.virtual_loss -= 1 to node.virtual_loss = max(0, node.virtual_loss - 1)
   - Prevents UCT score corruption if an exception occurs
   - Location: src/main.py:194
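
In code form, the clamped revert looks roughly like this (node and evaluate are placeholder names, not the actual identifiers in src/main.py):

```python
# Illustrative shape of the fix: the revert of virtual loss is clamped at zero,
# so an exception on the evaluation path cannot drive it negative.
try:
    node.virtual_loss += 1          # discourage other workers from picking this path
    value = evaluate(node)          # may raise
finally:
    node.virtual_loss = max(0, node.virtual_loss - 1)
```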

2. Add FP16 support for batch tensors (Issue #2):
   - Batch tensors now match model dtype (FP16 if enabled)
   - Checks model parameters and converts batch accordingly
   - Fixes potential dtype mismatch in batched inference
   - Location: src/models/lc0_inference.py:345-347
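
A sketch of the dtype handling described above (model, board_tensors, and device are placeholders):

```python
# Sketch: convert the stacked batch to the model's parameter dtype
# (e.g. torch.float16 when FP16 inference is enabled) before the forward pass.
import torch

model_dtype = next(model.parameters()).dtype
batch = torch.stack(board_tensors).to(device=device, dtype=model_dtype)
with torch.no_grad():
    values = model(batch)
```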

3. Fix cache eviction logic (Issue #3):
   - Changed if to while loop for safe eviction
   - Added StopIteration and RuntimeError handling
   - Prevents crash if cache is empty or modified during iteration
   - Location: src/main.py:459-465
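
Roughly, the eviction now has this shape (eval_cache is a placeholder name; the 10000-entry limit comes from the testing notes below):

```python
# Sketch: evict oldest entries until the cache is back under its size limit,
# guarding against an empty cache or one mutated during iteration.
MAX_CACHE_SIZE = 10000
while len(eval_cache) > MAX_CACHE_SIZE:
    try:
        eval_cache.pop(next(iter(eval_cache)))  # dicts preserve insertion order
    except (StopIteration, RuntimeError):
        break  # cache emptied or modified during iteration; stop evicting
```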

MEDIUM PRIORITY FIXES:

4. Validate batch evaluator results (Issue #5):
   - Added length check: len(values) must equal len(boards)
   - Raises RuntimeError if mismatch detected
   - Prevents silent truncation via zip()
   - Location: src/main.py:335-339
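
Sketched, the check looks like this (batch_evaluator and the node update are placeholders):

```python
# Sketch: fail loudly if the evaluator returns a different number of values
# than boards submitted, instead of letting zip() silently truncate.
values = batch_evaluator(boards)
if len(values) != len(boards):
    raise RuntimeError(
        f"Batch evaluator returned {len(values)} values for {len(boards)} boards"
    )
for board, value in zip(boards, values):   # safe: lengths verified above
    update_node(board, value)              # placeholder for the search update
```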

5. Fix entropy calculation edge cases (Issue #4):
   - Handle empty move_probs dict (default to 0.5 entropy)
   - Filter out zero probabilities before log()
   - Use 1e-10 instead of 1e-8 for better numerical stability
   - Avoid log(0) with max(1, len(move_probs))
   - Location: src/main.py:645-651
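
Putting the bullets together, the hardened calculation has roughly this shape (the function name and the normalization by maximum entropy are assumptions):

```python
import math

# Sketch: empty move_probs defaults to 0.5, zero probabilities are filtered
# before log(), 1e-10 is the numerical floor, and max(1, n) avoids log(0).
def normalized_entropy(move_probs: dict) -> float:
    if not move_probs:
        return 0.5
    probs = [p for p in move_probs.values() if p > 0]
    entropy = -sum(p * math.log(max(p, 1e-10)) for p in probs)
    max_entropy = math.log(max(1, len(move_probs)))
    return entropy / max_entropy if max_entropy > 0 else 0.5
```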

6. Remove dead code:
   - Deleted unused _cached_position_evaluation function
   - Removed lru_cache import that wasn't being used
   - Simplified to manual cache implementation only
   - Location: src/main.py:432

7. Add ultimate fallback to batch evaluator:
   - Two-level fallback: batch → sequential → neutral (0.0)
   - Prevents uncaught exceptions from crashing bot
   - Returns neutral evaluation if all else fails
   - Location: src/main.py:681-687
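
The fallback chain, sketched with placeholder names:

```python
# Sketch: try batched inference, fall back to sequential evaluation, and
# finally return a neutral score so an evaluation error cannot crash the bot.
def evaluate_boards(boards):
    try:
        return batch_evaluate(boards)
    except Exception:
        try:
            return [evaluate_single(board) for board in boards]
        except Exception:
            return [0.0] * len(boards)  # neutral evaluation as last resort
```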

TESTING RECOMMENDATIONS:
- Test with FP16 model on GPU
- Verify cache stays under 10000 entries
- Test error handling with corrupted model
- Verify virtual loss never goes negative
- Test entropy calculation with edge cases

All fixes are backward compatible and maintain existing behavior
while adding safety checks and edge case handling.