Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion backend/Generator/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# Constructor for questgen
from __future__ import absolute_import
from Generator.main import MCQGenerator, BoolQGenerator, ShortQGenerator, AnswerPredictor, GoogleDocsService, FileProcessor, QuestionGenerator
from utils.file_processor import FileProcessor
from Generator.main import MCQGenerator, BoolQGenerator, ShortQGenerator, AnswerPredictor, GoogleDocsService, QuestionGenerator
35 changes: 1 addition & 34 deletions backend/Generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import os
import fitz
import mammoth
from pptx import Presentation

class MCQGenerator:

Expand Down Expand Up @@ -349,40 +350,6 @@ def get_document_content(self, document_url):
return text.strip()


class FileProcessor:
    """Extract plain text from uploaded ``.txt``, ``.pdf`` and ``.docx`` files.

    An upload is written to ``upload_folder`` only for the duration of the
    extraction and is removed again before ``process_file`` returns.
    """

    def __init__(self, upload_folder='uploads/'):
        """Store the scratch folder and create it when it does not exist.

        Args:
            upload_folder: Directory used to stage uploads temporarily.
        """
        self.upload_folder = upload_folder
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(self.upload_folder, exist_ok=True)

    def extract_text_from_pdf(self, file_path):
        """Return the concatenated text of every page of the PDF at *file_path*."""
        # Opening the document as a context manager releases the underlying
        # file handle even when a page fails to parse.
        with fitz.open(file_path) as doc:
            return "".join(page.get_text() for page in doc)

    def extract_text_from_docx(self, file_path):
        """Return the raw text mammoth extracts from the DOCX at *file_path*."""
        with open(file_path, "rb") as docx_file:
            result = mammoth.extract_raw_text(docx_file)
        return result.value

    def process_file(self, file):
        """Stage an uploaded file, extract its text, and delete the staged copy.

        Args:
            file: Upload object exposing ``filename`` and ``save(path)``.

        Returns:
            The extracted text, or ``""`` when the upload has no usable name
            or its extension is not ``.txt``, ``.pdf`` or ``.docx``.
        """
        # The filename is client-controlled: keep only its final component so
        # names such as "../../x.txt" cannot escape the upload folder.
        raw_name = file.filename or ""
        safe_name = os.path.basename(raw_name.replace("\\", "/"))
        if safe_name in ("", ".", ".."):
            return ""

        file_path = os.path.join(self.upload_folder, safe_name)
        # Compare extensions case-insensitively so "REPORT.PDF" is accepted.
        lowered = safe_name.lower()
        content = ""
        try:
            file.save(file_path)
            if lowered.endswith('.txt'):
                # Explicit encoding keeps results independent of the host
                # locale; undecodable bytes are replaced instead of raising.
                with open(file_path, 'r', encoding='utf-8',
                          errors='replace') as f:
                    content = f.read()
            elif lowered.endswith('.pdf'):
                content = self.extract_text_from_pdf(file_path)
            elif lowered.endswith('.docx'):
                content = self.extract_text_from_docx(file_path)
        finally:
            # Always clean up the staged copy, including when extraction fails.
            try:
                os.remove(file_path)
            except FileNotFoundError:
                # save() failed before the file was created; nothing to remove.
                pass
        return content

class QuestionGenerator:
"""A transformer-based NLP system for generating reading comprehension-style questions from
texts. It can generate full sentence questions, multiple choice questions, or a mix of the
Expand Down
223 changes: 223 additions & 0 deletions backend/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,223 @@
"""Shared pytest fixtures for the EduAid backend test suite.

All heavy ML models, NLP pipelines, and external services are mocked so that
tests run instantly without a GPU or network access.

The session-scoped ``_prevent_model_initialization`` fixture patches heavy
constructors **before** ``server.py`` is first imported, preventing module-level
instantiations from triggering real model downloads or GPU usage.
"""

import sys
from unittest.mock import MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Lightweight mock objects that replace the heavy ML classes
# ---------------------------------------------------------------------------

def _make_mcq_gen_mock():
    """Return a ``MagicMock`` whose ``generate_mcq`` yields a canned payload.

    The payload contains a statement, a single multiple-choice question and
    a ``time_taken`` value.
    """
    sample_question = {
        "question_statement": "What is AI?",
        "question_type": "MCQ",
        "answer": "Artificial Intelligence",
        "id": 1,
        "options": ["Machine Learning", "Deep Learning", "Robotics"],
        "options_algorithm": "sense2vec",
        "extra_options": ["Neural Networks", "NLP"],
        "context": "AI is the simulation of human intelligence.",
    }
    canned_response = {
        "statement": "Mock statement",
        "questions": [sample_question],
        "time_taken": 0.01,
    }

    stub = MagicMock()
    stub.generate_mcq.return_value = canned_response
    return stub


def _make_shortq_gen_mock():
    """Return a ``MagicMock`` whose ``generate_shortq`` yields a canned payload.

    The payload contains a statement and a single short-answer question.
    """
    sample_question = {
        "Question": "What is AI?",
        "Answer": "Artificial Intelligence",
        "id": 1,
        "context": "AI is the simulation.",
    }
    canned_response = {
        "statement": "Mock statement",
        "questions": [sample_question],
    }

    stub = MagicMock()
    stub.generate_shortq.return_value = canned_response
    return stub


def _make_boolq_gen_mock():
    """Return a ``MagicMock`` whose ``generate_boolq`` yields a canned payload.

    The payload carries ``Text``, ``Count`` and a ``Boolean_Questions`` list.
    """
    boolean_questions = [
        "Is AI a simulation of human intelligence?",
        "Does machine learning use algorithms?",
    ]
    canned_response = {
        "Text": "Mock text",
        "Count": 4,
        "Boolean_Questions": boolean_questions,
    }

    stub = MagicMock()
    stub.generate_boolq.return_value = canned_response
    return stub


def _make_question_generator_mock():
    """Return a ``MagicMock`` whose ``generate`` yields two canned QA pairs."""
    canned_pairs = [
        {"question": "What is AI?", "answer": "Artificial Intelligence"},
        {"question": "What is ML?", "answer": "Machine Learning"},
    ]
    # Dotted keyword configuration sets generate.return_value at construction.
    return MagicMock(**{"generate.return_value": canned_pairs})


def _make_answer_predictor_mock():
    """Return a ``MagicMock`` whose ``predict_boolean_answer`` yields ``[True]``."""
    # Dotted keyword configuration sets the nested return value at construction.
    return MagicMock(**{"predict_boolean_answer.return_value": [True]})


# ---------------------------------------------------------------------------
# Session-scoped fixture: prevent heavy model loading at import time
# ---------------------------------------------------------------------------

@pytest.fixture(scope="session", autouse=True)
def _prevent_model_initialization():
    """Patch heavy constructors *before* ``server.py`` is imported.

    ``server.py`` instantiates heavy objects at module level::

        MCQGen = main.MCQGenerator()        # loads T5-large, spacy, sense2vec
        answer = main.AnswerPredictor()     # loads T5-large, NLI model
        BoolQGen = main.BoolQGenerator()    # loads T5-base
        ShortQGen = main.ShortQGenerator()  # loads T5-large, spacy
        qg = main.QuestionGenerator()       # loads t5-base-question-generator
        qa_model = pipeline("question-answering")  # loads QA model

    Without these patches, importing ``server.py`` would download several
    gigabytes of model weights and require a GPU or significant RAM.

    The patches stay active for the whole session and are undone in reverse
    order at teardown. If one patch fails to start, the patches that were
    already started are undone before the error propagates.
    """
    from contextlib import ExitStack

    import nltk
    import transformers
    import mediawikiapi as mwapi
    import Generator.main as gen_main

    patchers = [
        # Prevent NLTK data downloads
        patch.object(nltk, "download", lambda *_a, **_kw: None),
        # Prevent transformers QA pipeline loading
        patch.object(transformers, "pipeline", return_value=MagicMock()),
        # Prevent MediaWikiAPI instantiation
        patch.object(mwapi, "MediaWikiAPI", return_value=MagicMock()),
        # Prevent Generator class instantiation (each __init__ loads models)
        patch.object(gen_main, "MCQGenerator", MagicMock),
        patch.object(gen_main, "BoolQGenerator", MagicMock),
        patch.object(gen_main, "ShortQGenerator", MagicMock),
        patch.object(gen_main, "QuestionGenerator", MagicMock),
        patch.object(gen_main, "AnswerPredictor", MagicMock),
        patch.object(gen_main, "GoogleDocsService", MagicMock),
    ]
    # ExitStack guarantees every patch that was entered is exited, even when
    # a later patch raises while starting.
    with ExitStack() as stack:
        for patcher in patchers:
            stack.enter_context(patcher)
        yield


# ---------------------------------------------------------------------------
# Per-test fixtures: override module-level mocks with specific behaviour
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def _patch_heavy_imports(monkeypatch):
    """Replace ``nltk.download`` with a no-op for each test (belt-and-suspenders)."""

    def _skip_download(*_args, **_kwargs):
        # Accept any arguments and do nothing, mirroring the download call shape.
        return None

    monkeypatch.setattr("nltk.download", _skip_download)


@pytest.fixture()
def mock_mcq_gen():
    """Swap ``server.MCQGen`` for the canned MCQ mock and yield that mock."""
    stub = _make_mcq_gen_mock()
    patcher = patch("server.MCQGen", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_shortq_gen():
    """Swap ``server.ShortQGen`` for the canned short-question mock and yield it."""
    stub = _make_shortq_gen_mock()
    patcher = patch("server.ShortQGen", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_boolq_gen():
    """Swap ``server.BoolQGen`` for the canned boolean-question mock and yield it."""
    stub = _make_boolq_gen_mock()
    patcher = patch("server.BoolQGen", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_question_generator():
    """Swap ``server.qg`` for the canned question-generator mock and yield it."""
    stub = _make_question_generator_mock()
    patcher = patch("server.qg", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_answer_predictor():
    """Swap ``server.answer`` for the canned answer-predictor mock and yield it."""
    stub = _make_answer_predictor_mock()
    patcher = patch("server.answer", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_mediawiki():
    """Swap ``server.mediawikiapi`` for a mock whose ``summary`` returns fixed text."""
    stub = MagicMock(**{"summary.return_value": "Expanded text from MediaWiki."})
    patcher = patch("server.mediawikiapi", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_qa_pipeline():
    """Swap ``server.qa_model`` for a callable mock returning a fixed answer dict."""
    stub = MagicMock()
    # Calling the mock itself (not an attribute) produces the canned result.
    stub.return_value = {"answer": "mocked answer", "score": 0.99}
    patcher = patch("server.qa_model", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_file_processor():
    """Swap ``server.file_processor`` for a mock whose ``process_file`` returns fixed text."""
    stub = MagicMock(**{"process_file.return_value": "Extracted text content"})
    patcher = patch("server.file_processor", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def mock_docs_service():
    """Swap ``server.docs_service`` for a mock whose ``get_document_content`` returns fixed text."""
    stub = MagicMock(**{"get_document_content.return_value": "Document content"})
    patcher = patch("server.docs_service", new=stub)
    patcher.start()
    try:
        yield stub
    finally:
        # Restore the original attribute once the test is done.
        patcher.stop()


@pytest.fixture()
def app():
    """Create the Flask app for testing.

    Imports ``app`` from ``server`` and enables Flask's ``TESTING`` flag.

    Returns:
        The Flask application object exported by ``server``.
    """
    # This fixture runs for every test that uses it, so only add the current
    # directory once instead of pushing a duplicate sys.path entry each time.
    if "." not in sys.path:
        sys.path.insert(0, ".")
    from server import app as flask_app
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture()
def client(
    app,
    mock_mcq_gen,
    mock_shortq_gen,
    mock_boolq_gen,
    mock_question_generator,
    mock_answer_predictor,
    mock_mediawiki,
    mock_qa_pipeline,
    mock_file_processor,
    mock_docs_service,
):
    """Return a Flask test client with all ML models mocked.

    The ``mock_*`` parameters are requested only so their patches are active
    while the client is in use; their values are not read here.
    """
    test_client = app.test_client()
    return test_client
Loading