diff --git a/backend/Generator/main.py b/backend/Generator/main.py
index 04aed79f..1f944d2c 100644
--- a/backend/Generator/main.py
+++ b/backend/Generator/main.py
@@ -22,18 +22,57 @@
 import os
 import fitz
 import mammoth
+import threading
+
+
+
+
+class ModelManager:
+    """Singleton class to load and share massive ML models across generators."""
+    _instance = None
+    _is_initialized = False
+    _lock = threading.Lock()
+
+    def __new__(cls):
+        if cls._instance is None:
+            with cls._lock:
+                if cls._instance is None:
+                    cls._instance = super(ModelManager, cls).__new__(cls)
+        return cls._instance
+
+    def __init__(self):
+        if self._is_initialized:
+            return
+
+        with self._lock:
+            if not self._is_initialized:
+                print("Initializing Shared ModelManager... Loading massive models into memory ONCE.")
+                self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+                self.qg_tokenizer = T5Tokenizer.from_pretrained('t5-large')
+                self.qg_model = T5ForConditionalGeneration.from_pretrained('Roasters/Question-Generator')
+                self.qg_model.to(self.device)
+                self.qg_model.eval()
+
+                self.nlp = spacy.load('en_core_web_sm')
+                self.s2v = Sense2Vec().from_disk('s2v_old')
+                self.fdist = FreqDist(brown.words())
+                self.normalized_levenshtein = NormalizedLevenshtein()
+
+                self._is_initialized = True
+
 
 
 class MCQGenerator:
     def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained('t5-large')
-        self.model = T5ForConditionalGeneration.from_pretrained('Roasters/Question-Generator')
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-        self.nlp = spacy.load('en_core_web_sm')
-        self.s2v = Sense2Vec().from_disk('s2v_old')
-        self.fdist = FreqDist(brown.words())
-        self.normalized_levenshtein = NormalizedLevenshtein()
+        manager = ModelManager()
+        self.tokenizer = manager.qg_tokenizer
+        self.model = manager.qg_model
+        self.device = manager.device
+        self.nlp = manager.nlp
+        self.s2v = manager.s2v
+        self.fdist = manager.fdist
+        self.normalized_levenshtein = manager.normalized_levenshtein
         self.set_seed(42)
 
     def set_seed(self, seed):
@@ -84,14 +123,14 @@ def generate_mcq(self, payload):
 
 class ShortQGenerator:
     def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained('t5-large')
-        self.model = T5ForConditionalGeneration.from_pretrained('Roasters/Question-Generator')
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
-        self.nlp = spacy.load('en_core_web_sm')
-        self.s2v = Sense2Vec().from_disk('s2v_old')
-        self.fdist = FreqDist(brown.words())
-        self.normalized_levenshtein = NormalizedLevenshtein()
+        manager = ModelManager()
+        self.tokenizer = manager.qg_tokenizer
+        self.model = manager.qg_model
+        self.device = manager.device
+        self.nlp = manager.nlp
+        self.s2v = manager.s2v
+        self.fdist = manager.fdist
+        self.normalized_levenshtein = manager.normalized_levenshtein
         self.set_seed(42)
 
     def set_seed(self, seed):
@@ -135,10 +174,10 @@ def generate_shortq(self, payload):
 
 class ParaphraseGenerator:
     def __init__(self):
-        self.tokenizer = T5Tokenizer.from_pretrained('t5-large')
-        self.model = T5ForConditionalGeneration.from_pretrained('Roasters/Question-Generator')
-        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        self.model.to(self.device)
+        manager = ModelManager()
+        self.tokenizer = manager.qg_tokenizer
+        self.model = manager.qg_model
+        self.device = manager.device
         self.set_seed(42)
 
     def set_seed(self, seed):
@@ -251,6 +290,10 @@ def __init__(self):
 
         self.nli_tokenizer = AutoTokenizer.from_pretrained(self.nli_model_name)
         self.nli_model = AutoModelForSequenceClassification.from_pretrained(self.nli_model_name)
+        # Explicitly push the NLI model to the detected hardware (GPU or CPU)
+        self.nli_model.to(self.device)
+        self.nli_model.eval()
+
         self.set_seed(42)
 
     def set_seed(self, seed):
@@ -286,7 +329,8 @@ def predict_answer(self, payload):
             torch.cuda.empty_cache()
         return answers
-
+
+    @torch.no_grad()
     def predict_boolean_answer(self, payload):
         input_text = payload.get("input_text", "")
         input_questions = payload.get("input_question", [])
 
@@ -296,6 +340,10 @@ def predict_boolean_answer(self, payload):
         for question in input_questions:
             hypothesis = question
            inputs = self.nli_tokenizer.encode_plus(input_text, hypothesis, return_tensors="pt")
+
+            # Push the input tensors to the same device as the model
+            inputs = {key: value.to(self.device) for key, value in inputs.items()}
+
             outputs = self.nli_model(**inputs)
             logits = outputs.logits
             probabilities = torch.softmax(logits, dim=1)
diff --git a/backend/test_server.py b/backend/test_server.py
index 7a4bd38f..d7ed3962 100644
--- a/backend/test_server.py
+++ b/backend/test_server.py
@@ -73,8 +73,8 @@ def test_root():
     print(f'Root Endpoint Response: {response.text}')
     assert response.status_code == 200
 
-def test_get_answer():
-    endpoint = '/get_answer'
+def test_get_shortq_answer():
+    endpoint = '/get_shortq_answer'
     data = {
         'input_text': input_text,
         'input_question': [
@@ -85,7 +85,7 @@
         ]
     }
     response = make_post_request(endpoint, data)
-    print(f'/get_answer Response: {response}')
+    print(f"{endpoint} Response: {response}")
     assert 'output' in response
 
 def test_get_boolean_answer():
@@ -114,5 +114,5 @@ def make_post_request(endpoint, data):
     test_get_shortq()
     test_get_problems()
     test_root()
-    test_get_answer()
+    test_get_shortq_answer()
     test_get_boolean_answer()
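
For reference, a minimal self-contained sketch of the double-checked-locking singleton pattern that ModelManager applies above. ExpensiveResource and SharedManager are hypothetical stand-ins, so the sketch runs without any ML dependencies:

import threading

class ExpensiveResource:
    """Hypothetical stand-in for the heavyweight model loads in ModelManager."""
    def __init__(self):
        print("Loading expensive resource... (should print exactly once)")

class SharedManager:
    _instance = None
    _is_initialized = False
    _lock = threading.Lock()

    def __new__(cls):
        # The cheap check outside the lock avoids contention on the common
        # path; the re-check inside the lock closes the race between two
        # threads that both saw _instance as None.
        if cls._instance is None:
            with cls._lock:
                if cls._instance is None:
                    cls._instance = super().__new__(cls)
        return cls._instance

    def __init__(self):
        # Every construction returns the same object, so once this flag is
        # set, the expensive loading below is skipped on all later calls.
        if self._is_initialized:
            return
        with self._lock:
            if not self._is_initialized:
                self.resource = ExpensiveResource()
                self._is_initialized = True

if __name__ == "__main__":
    a, b = SharedManager(), SharedManager()
    assert a is b and a.resource is b.resource  # one instance, one load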
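
Likewise, a small sketch of the inference pattern behind the NLI fixes above: keep the model and its input tensors on the same device, and disable autograd for prediction. The Linear layer is a hypothetical stand-in for the real NLI model:

import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Stand-in for the NLI model: any nn.Module, moved to the device once at startup.
model = torch.nn.Linear(8, 3).to(device)
model.eval()  # inference behavior for layers like dropout/batch-norm

@torch.no_grad()  # no autograd graph is built, cutting memory use per request
def predict(inputs):
    # Inputs must sit on the same device as the model's weights; a CPU
    # tensor fed to a CUDA model raises a device-mismatch RuntimeError.
    inputs = {key: value.to(device) for key, value in inputs.items()}
    logits = model(inputs["features"])
    return torch.softmax(logits, dim=1)

print(predict({"features": torch.randn(2, 8)}))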