diff --git a/backend/Generator/main.py b/backend/Generator/main.py index 04aed79f..d048f196 100644 --- a/backend/Generator/main.py +++ b/backend/Generator/main.py @@ -1,8 +1,13 @@ import time import torch import random -from transformers import T5ForConditionalGeneration, T5Tokenizer -from transformers import AutoModelForSequenceClassification, AutoTokenizer,AutoModelForSeq2SeqLM, T5ForConditionalGeneration, T5Tokenizer +from transformers import ( + AutoModelForSequenceClassification, + AutoModelForSeq2SeqLM, + AutoTokenizer, + T5ForConditionalGeneration, + T5Tokenizer, +) import numpy as np import spacy from sense2vec import Sense2Vec @@ -14,14 +19,16 @@ from Generator.encoding import beam_search_decoding from google.oauth2 import service_account from googleapiclient.discovery import build +from werkzeug.utils import secure_filename import en_core_web_sm import json import re from typing import Any, List, Mapping, Tuple -import re import os import fitz import mammoth +import uuid + class MCQGenerator: @@ -31,7 +38,7 @@ def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.nlp = spacy.load('en_core_web_sm') - self.s2v = Sense2Vec().from_disk('s2v_old') + self.s2v = None self.fdist = FreqDist(brown.words()) self.normalized_levenshtein = NormalizedLevenshtein() self.set_seed(42) @@ -53,7 +60,15 @@ def generate_mcq(self, payload): sentences = tokenize_into_sentences(text) modified_text = " ".join(sentences) - keywords = identify_keywords(self.nlp, modified_text, inp['max_questions'], self.s2v, self.fdist, self.normalized_levenshtein, len(sentences)) + keywords = identify_keywords( + self.nlp, + modified_text, + inp['max_questions'], + None, # disable sense2vec + self.fdist, + self.normalized_levenshtein, + len(sentences) + ) keyword_sentence_mapping = find_sentences_with_keywords(keywords, sentences) for k in keyword_sentence_mapping.keys(): @@ -89,7 +104,7 @@ def __init__(self): self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu") self.model.to(self.device) self.nlp = spacy.load('en_core_web_sm') - self.s2v = Sense2Vec().from_disk('s2v_old') + self.s2v = None self.fdist = FreqDist(brown.words()) self.normalized_levenshtein = NormalizedLevenshtein() self.set_seed(42) @@ -110,7 +125,15 @@ def generate_shortq(self, payload): sentences = tokenize_into_sentences(text) modified_text = " ".join(sentences) - keywords = identify_keywords(self.nlp, modified_text, inp['max_questions'], self.s2v, self.fdist, self.normalized_levenshtein, len(sentences)) + keywords = identify_keywords( + self.nlp, + modified_text, + inp['max_questions'], + None, # disable sense2vec + self.fdist, + self.normalized_levenshtein, + len(sentences) + ) keyword_sentence_mapping = find_sentences_with_keywords(keywords, sentences) for k in keyword_sentence_mapping.keys(): @@ -160,7 +183,12 @@ def generate_paraphrase(self, payload): sentence = text text_to_paraphrase = "paraphrase: " + sentence + " " - encoding = self.tokenizer.encode_plus(text_to_paraphrase, pad_to_max_length=True, return_tensors="pt") + encoding = self.tokenizer.encode_plus( + text_to_paraphrase, + padding="max_length", + truncation=True, + return_tensors="pt" + ) input_ids, attention_masks = encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) beam_outputs = self.model.generate( @@ -171,7 +199,7 @@ def generate_paraphrase(self, payload): num_return_sequences=num, no_repeat_ngram_size=2, early_stopping=True - ) + ) final_outputs =[] for beam_output in beam_outputs: @@ -208,7 +236,6 @@ def random_choice(self): a = random.choice([0,1]) return bool(a) - def generate_boolq(self, payload): start_time = time.time() inp = { @@ -226,7 +253,7 @@ def generate_boolq(self, payload): encoding = self.tokenizer.encode_plus(form, return_tensors="pt") input_ids, attention_masks = encoding["input_ids"].to(self.device), encoding["attention_mask"].to(self.device) - output = beam_search_decoding (input_ids, attention_masks, self.model, self.tokenizer,num) + output = beam_search_decoding(input_ids, attention_masks, self.model, self.tokenizer, num) if self.device.type == 'cuda': torch.cuda.empty_cache() @@ -237,7 +264,6 @@ def generate_boolq(self, payload): return final - class AnswerPredictor: def __init__(self): @@ -267,11 +293,10 @@ def greedy_decoding(self, inp_ids, attn_mask): def predict_answer(self, payload): answers = [] inp = { - "input_text": payload.get("input_text"), - "input_question" : payload.get("input_question") - } + "input_text": payload.get("input_text"), + "input_question": payload.get("input_question") + } for ques in payload.get("input_question"): - context = inp["input_text"] question = ques input_text = "question: %s context: %s " % (question, context) @@ -348,7 +373,6 @@ def get_document_content(self, document_url): return text.strip() - class FileProcessor: def __init__(self, upload_folder='uploads/'): self.upload_folder = upload_folder @@ -367,21 +391,82 @@ def extract_text_from_docx(self, file_path): result = mammoth.extract_raw_text(docx_file) return result.value + def extract_text_from_image(self, file_path): + try: + import cv2 + import pytesseract + import shutil + except ImportError as e: + raise RuntimeError( + "OCR requires opencv-python and pytesseract installed." + ) from e + + image = cv2.imread(file_path) + if image is None: + return "" + + gray = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY) + thresh = cv2.adaptiveThreshold( + gray, + 255, + cv2.ADAPTIVE_THRESH_GAUSSIAN_C, + cv2.THRESH_BINARY, + 11, + 2 + ) + + # Cross-platform Tesseract discovery + tesseract_cmd = os.getenv("TESSERACT_CMD") + if tesseract_cmd: + pytesseract.pytesseract.tesseract_cmd = tesseract_cmd + else: + detected = shutil.which("tesseract") + if detected: + pytesseract.pytesseract.tesseract_cmd = detected + + text = pytesseract.image_to_string(thresh) + return text.strip() + def process_file(self, file): - file_path = os.path.join(self.upload_folder, file.filename) - file.save(file_path) - content = "" + safe_name = secure_filename(file.filename or "") + if not safe_name: + return "" + + unique_name = f"{uuid.uuid4().hex}_{safe_name}" + file_path = os.path.join(self.upload_folder, unique_name) + # Extra safety check (prevents ../ traversal) + abs_upload = os.path.abspath(self.upload_folder) + abs_path = os.path.abspath(file_path) - if file.filename.endswith('.txt'): - with open(file_path, 'r') as f: - content = f.read() - elif file.filename.endswith('.pdf'): - content = self.extract_text_from_pdf(file_path) - elif file.filename.endswith('.docx'): - content = self.extract_text_from_docx(file_path) + if not abs_path.startswith(abs_upload): + return "" - os.remove(file_path) - return content + file.save(file_path) + content = "" + filename = safe_name.lower() + + try: + if filename.endswith('.txt'): + with open(file_path, 'r', encoding='utf-8') as f: + content = f.read() + elif filename.endswith('.pdf'): + content = self.extract_text_from_pdf(file_path) + elif filename.endswith('.docx'): + content = self.extract_text_from_docx(file_path) + elif filename.endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff')): + content = self.extract_text_from_image(file_path) + return content + + except Exception: + return "" + + + finally: + try: + if os.path.exists(file_path): + os.remove(file_path) + except OSError: + pass class QuestionGenerator: """A transformer-based NLP system for generating reading comprehension-style questions from @@ -803,4 +888,4 @@ def print_qa(qa_list: List[Mapping[str, str]], show_answers: bool = True) -> Non # print full sentence answers else: if show_answers: - print(f"{space}A: {answer}\n") + print(f"{space}A: {answer}\n") \ No newline at end of file diff --git a/backend/Generator/mcq.py b/backend/Generator/mcq.py index e2c82954..a1432ae1 100644 --- a/backend/Generator/mcq.py +++ b/backend/Generator/mcq.py @@ -12,14 +12,25 @@ nltk.download('stopwords') nltk.download('popular') -def is_word_available(word, s2v_model): - word = word.replace(" ", "_") - sense = s2v_model.get_best_sense(word) - if sense is not None: +def is_word_available(word, s2v_model, fdist, normalized_levenshtein): + """ + Checks if a word is valid for question generation. + Safely handles s2v_model=None. + """ + + # If sense2vec disabled, skip sense check + if s2v_model is None: return True - else: + + try: + sense = s2v_model.get_best_sense(word) + if sense is None: + return False + except Exception: return False + return True + def generate_word_variations(word): letters = 'abcdefghijklmnopqrstuvwxyz ' + string.punctuation splits = [(word[:i], word[i:]) for i in range(len(word) + 1)] @@ -55,6 +66,8 @@ def find_similar_words(word, s2v_model): return out def get_answer_choices(answer, s2v_model): + if s2v_model is None: + return [], "None" choices = [] try: @@ -177,8 +190,12 @@ def generate_multiple_choice_questions(keyword_sent_mapping, device, tokenizer, text = context + " " + "answer: " + answer + " " batch_text.append(text) - encoding = tokenizer.batch_encode_plus(batch_text, pad_to_max_length=True, return_tensors="pt") - + encoding = tokenizer.batch_encode_plus( + batch_text, + padding="max_length", + truncation=True, + return_tensors="pt" + ) print("Generating questions using the model...") input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device) @@ -223,8 +240,12 @@ def generate_normal_questions(keyword_sent_mapping, device, tokenizer, model): text = context + " " + "answer: " + answer + " " batch_text.append(text) - encoding = tokenizer.batch_encode_plus(batch_text, pad_to_max_length=True, return_tensors="pt") - + encoding = tokenizer.batch_encode_plus( + batch_text, + padding="max_length", + truncation=True, + return_tensors="pt" + ) print("Running model for generation...") input_ids, attention_masks = encoding["input_ids"].to(device), encoding["attention_mask"].to(device) diff --git a/backend/Generator/rag.py b/backend/Generator/rag.py new file mode 100644 index 00000000..53753310 --- /dev/null +++ b/backend/Generator/rag.py @@ -0,0 +1,186 @@ +import faiss +import numpy as np +from sentence_transformers import SentenceTransformer +from transformers import T5Tokenizer, T5ForConditionalGeneration +import torch + + +class RAGService: + + def __init__(self): + self.current_text = None + + # Embedding model + self.embedder = SentenceTransformer( + "sentence-transformers/all-MiniLM-L6-v2" + ) + + # Generator model + self.tokenizer = T5Tokenizer.from_pretrained("t5-base") + self.generator = T5ForConditionalGeneration.from_pretrained("t5-base") + + self.device = torch.device( + "cuda" if torch.cuda.is_available() else "cpu" + ) + self.generator.to(self.device) + + # FAISS state + self.index = None + self.text_chunks = [] + self.dimension = None + + # --------------------------------------------------- + # TEXT CHUNKING + # --------------------------------------------------- + def chunk_text(self, text, chunk_size=400, overlap=50): + + if chunk_size <= 0: + raise ValueError("chunk_size must be > 0") + + if overlap < 0 or overlap >= chunk_size: + raise ValueError( + "overlap must be >= 0 and < chunk_size" + ) + + words = text.split() + chunks = [] + + step = chunk_size - overlap + + for i in range(0, len(words), step): + chunk = words[i:i + chunk_size] + chunks.append(" ".join(chunk)) + + return chunks + + # --------------------------------------------------- + # SAFE ATOMIC INDEXING + # --------------------------------------------------- + def index_text(self, text): + + # Skip only if state is already valid + if ( + self.current_text == text + and self.index is not None + and self.text_chunks + ): + return + + try: + # Build temporary chunks + temp_chunks = self.chunk_text(text) + + if not temp_chunks: + self.current_text = text + self.text_chunks = [] + self.index = None + self.dimension = None + return + + # Generate embeddings + temp_embeddings = self.embedder.encode( + temp_chunks, + convert_to_numpy=True + ) + + # Normalize for cosine similarity + faiss.normalize_L2(temp_embeddings) + temp_embeddings = temp_embeddings.astype("float32") + + temp_dimension = temp_embeddings.shape[1] + temp_index = faiss.IndexFlatIP(temp_dimension) + temp_index.add(temp_embeddings) + + # 🔐 Commit state only after success + self.current_text = text + self.text_chunks = temp_chunks + self.dimension = temp_dimension + self.index = temp_index + + except Exception: + # Prevent corrupted state + self.index = None + self.text_chunks = [] + self.dimension = None + raise + + # --------------------------------------------------- + # QUERY WITH MEMORY + # --------------------------------------------------- + def query(self, question, chat_history=None, top_k=3): + + if self.index is None: + return "No document indexed." + + # Embed question + question_embedding = self.embedder.encode( + [question], + convert_to_numpy=True + ) + + faiss.normalize_L2(question_embedding) + question_embedding = question_embedding.astype("float32") + + top_k = min(top_k, len(self.text_chunks)) + + distances, indices = self.index.search( + question_embedding, + top_k + ) + + retrieved_chunks = [ + self.text_chunks[i] + for i in indices[0] + if i < len(self.text_chunks) + ] + + context = " ".join(retrieved_chunks) + + # Build conversation history + history_text = "" + + if chat_history: + for turn in chat_history: + role = turn.get("role") + message = turn.get("message") + + if role == "user": + history_text += f"User: {message}\n" + elif role == "assistant": + history_text += f"Assistant: {message}\n" + + # Final prompt + input_text = f""" +You are a helpful educational assistant. + +Use the provided context to answer the question. +If the answer is not found in the context, say you don't know. + +Context: +{context} + +Conversation History: +{history_text} + +Current Question: +{question} +""" + + encoding = self.tokenizer( + input_text, + return_tensors="pt", + truncation=True, + max_length=512 + ).to(self.device) + + output = self.generator.generate( + **encoding, + max_length=150 + ) + + answer = self.tokenizer.decode( + output[0], + skip_special_tokens=True + ) + + return answer \ No newline at end of file diff --git a/backend/server.py b/backend/server.py index 1c9efaa4..3d58b0c4 100644 --- a/backend/server.py +++ b/backend/server.py @@ -12,6 +12,7 @@ nltk.download('punkt_tab') from Generator import main from Generator.question_filters import make_question_harder +from threading import Lock import re import json import spacy @@ -25,10 +26,14 @@ from httplib2 import Http from oauth2client import client, file, tools from mediawikiapi import MediaWikiAPI +from Generator.rag import RAGService + app = Flask(__name__) CORS(app) -print("Starting Flask App...") +rag_service = None +rag_lock = Lock() +print("RAG SERVICE INITIALIZED SUCCESSFULLY") SERVICE_ACCOUNT_FILE = './service_account_key.json' SCOPES = ['https://www.googleapis.com/auth/documents.readonly'] @@ -39,8 +44,10 @@ ShortQGen = main.ShortQGenerator() qg = main.QuestionGenerator() docs_service = main.GoogleDocsService(SERVICE_ACCOUNT_FILE, SCOPES) + file_processor = main.FileProcessor() mediawikiapi = MediaWikiAPI() + qa_model = pipeline("question-answering") @@ -191,8 +198,9 @@ def get_content(): return jsonify(text) except ValueError as e: return jsonify({'error': str(e)}), 400 - except Exception as e: - return jsonify({'error': str(e)}), 500 + except Exception: + app.logger.exception("Google Docs content error") + return jsonify({"error": "Internal server error"}), 500 @app.route("/generate_gform", methods=["POST"]) @@ -490,6 +498,40 @@ def get_transcript(): return jsonify({"transcript": transcript_text}) + +@app.route("/chat", methods=["POST"]) +def chat(): + global rag_service + + try: + data = request.get_json() + + document_text = data.get("document_text") + question = data.get("question") + chat_history = data.get("chat_history", []) + + # Input validation + if not document_text or not document_text.strip(): + return jsonify({"error": "Document text is required"}), 400 + + if not question or not question.strip(): + return jsonify({"error": "Question is required"}), 400 + + with rag_lock: + if rag_service is None: + rag_service = RAGService() + + rag_service.index_text(document_text) + answer = rag_service.query(question, chat_history) + + return jsonify({ + "answer": answer, + "status": "success" + }) + + except Exception: + app.logger.exception("Chat endpoint error") + return jsonify({"error": "Internal server error"}), 500 + if __name__ == "__main__": - os.makedirs("subtitles", exist_ok=True) - app.run() + app.run(host="0.0.0.0", port=5000, debug=False) \ No newline at end of file