[Backend] Fix device mismatch for NLI model in AnswerPredictor #441
Siddhazntx wants to merge 2 commits into AOSSIE-Org:main
Conversation
📝 Walkthrough
Explicitly move NLI models to the detected device and set them to eval(); ensure input tensors are moved to the model device and inference runs under torch.no_grad(), preventing device mismatches during prediction.
🚥 Pre-merge checks: ✅ 4 passed | ❌ 1 failed (warning)
Caution
Some comments are outside the diff and can’t be posted inline due to platform limitations.
⚠️ Outside diff range comments (2)
backend/Generator/main.py (2)
253-260: ⚠️ Potential issue | 🟠 Major

Add .eval() to the NLI model after moving it to the device.

The .to(self.device) fix is correct. However, self.nli_model is never put into eval mode. Every other inference model in this file calls .eval() immediately after .to(self.device) (see self.qg_model.eval() at line 418 and self.qae_model.eval() at line 726). Without it, dropout layers remain active during predict_boolean_answer, producing non-deterministic NLI results.

🛠️ Proposed fix

```diff
 self.nli_model = AutoModelForSequenceClassification.from_pretrained(self.nli_model_name)
 # Explicitly push the NLI model to the detected hardware (GPU or CPU)
 self.nli_model.to(self.device)
+self.nli_model.eval()
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@backend/Generator/main.py` around lines 253-260: the NLI model is moved to the device but not set to eval mode, causing nondeterministic behavior (dropout active) during predict_boolean_answer; after the existing self.nli_model.to(self.device) call, call self.nli_model.eval() to mirror how other inference models (self.qg_model.eval(), self.qae_model.eval()) are handled so the NLI model runs deterministically in inference.
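The device-plus-eval pattern the review asks for can be sketched with a small stand-in module. TinyClassifier below is hypothetical; the real code loads AutoModelForSequenceClassification.from_pretrained(...) instead, but the effect of .eval() on dropout is the same:

```python
import torch
from torch import nn

# Hypothetical stand-in for the NLI classifier, so the effect of
# .eval() on dropout is easy to observe without downloading weights.
class TinyClassifier(nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = nn.Dropout(p=0.5)
        self.linear = nn.Linear(4, 3)

    def forward(self, x):
        return self.linear(self.dropout(x))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TinyClassifier()
model.to(device)   # push weights to the detected hardware
model.eval()       # disable dropout so inference is deterministic

x = torch.ones(1, 4, device=device)
with torch.no_grad():
    out1 = model(x)
    out2 = model(x)
# In eval mode, repeated passes over the same input are identical.
assert torch.equal(out1, out2)
```

In train mode the two passes would generally differ whenever dropout zeroes different elements, which is exactly the non-determinism the review flags.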
296-323: 🛠️ Refactor suggestion | 🟠 Major

Missing @torch.no_grad() on predict_boolean_answer.

Every other inference method in this file uses @torch.no_grad() (see _generate_question at line 646 and _evaluate_qa at line 774). The NLI forward pass at line 309 will unnecessarily compute and retain gradient tensors, wasting GPU memory, which is especially relevant now that GPU execution is the target of this fix.

♻️ Proposed fix

```diff
+@torch.no_grad()
 def predict_boolean_answer(self, payload):
```

🤖 Prompt for AI Agents
Verify each finding against the current code and only fix it if needed. In `@backend/Generator/main.py` around lines 296-323: the predict_boolean_answer method is missing the `@torch.no_grad()` decorator, so the NLI forward pass (nli_model(**inputs)) will compute and retain gradients; add `@torch.no_grad()` above def predict_boolean_answer(self, payload) to disable gradient tracking for the entire method, keeping the existing device handling (inputs = {key: value.to(self.device) ...} and torch.cuda.empty_cache()) unchanged so inference runs without allocating unnecessary GPU gradients.
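The effect of `@torch.no_grad()` is visible directly on the output tensors; a minimal sketch:

```python
import torch
from torch import nn

model = nn.Linear(4, 3)
model.eval()
x = torch.ones(2, 4)

# Without no_grad, the forward pass records an autograd graph and the
# output carries a grad_fn, keeping intermediate tensors alive.
y_tracked = model(x)
assert y_tracked.requires_grad

# Under no_grad (the @torch.no_grad() decorator is equivalent for a
# whole method), no graph is built, so inference allocates no
# gradient bookkeeping.
with torch.no_grad():
    y_free = model(x)
assert not y_free.requires_grad
```

On a GPU, the retained graph of the untracked version is what wastes memory during pure inference.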
@coderabbitai I've applied the fixes in main.py.
Verified locally with test_server.py; boolean predictions are working as expected.
✅ Line 254: inputs = {key: value.to(self.device) for key, value in inputs.items()}
These are all solid PyTorch practices for device-safe inference. Great job testing locally with test_server.py.
Addressed Issues:
Closes #442
The NLI model in AnswerPredictor was not explicitly moved to the detected device (CPU/GPU), and input tensors were not aligned with the model device.
Screenshots/Recordings:
N/A - This is a backend architectural fix.
Additional Notes:
The Issue:
While reviewing the backend model loading, I noticed that the distilbert-base-uncased-mnli model in the AnswerPredictor class wasn't being pushed to the hardware device during initialization. Additionally, its input tensors were defaulting to the CPU during prediction.
The fix:
Added .to(self.device) to both the NLI model initialization and the input tensors. This ensures the model actually utilizes the GPU when available and prevents potential PyTorch tensor mismatch crashes (RuntimeError: Expected all tensors to be on the same device).
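A minimal sketch of the mismatch and the fix described above, with nn.Linear standing in for the NLI model (tokenizer outputs land on CPU by default):

```python
import torch
from torch import nn

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = nn.Linear(4, 3).to(device)

# Tokenizers return CPU tensors; on a GPU machine, passing them to a
# CUDA model raises "Expected all tensors to be on the same device".
inputs = {"input_ids": torch.ones(1, 4)}

# The fix: move every tensor in the batch dict onto the model's device.
inputs = {key: value.to(device) for key, value in inputs.items()}
logits = model(inputs["input_ids"])  # devices now agree
assert logits.device.type == device.type
```

On a CPU-only machine both sides are already on `cpu`, so the comprehension is a harmless no-op; the fix only changes behavior when a GPU is detected.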
Note on Testing:
I successfully tested the device synchronization locally and ran the official test_server.py suite. All generation endpoints pass successfully with my fix. During testing, I observed a pre-existing failure in the test_server.py suite on the current main branch that is unrelated to this change. I will open a separate Issue/PR to address that independently.
Summary by CodeRabbit: Refactor, Chores