Skip to content

Conversation

@aryarathoree
Copy link

Description

This PR adds centralized input validation for model forward() inputs to ensure that tensors follow the expected [batch, nodes, features] shape.

Previously, invalid input shapes could lead to cryptic PyTorch runtime errors or failures deeper in the computation graph. With this change, models fail fast with clear and informative error messages at the API boundary.

Unit tests are included to verify the validation logic without pulling in heavy optional dependencies.

Fixes #190

How Has This Been Tested?

  • Added focused unit tests for the input validation logic.
  • Ran the new tests locally using:
    pytest tests/test_input_validation.py
    

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Add explicit input validation and clearer error messages for model inputs

1 participant