Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions backend/Generator/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
import os
import fitz
import mammoth
from pptx import Presentation
from utils.text_processor import TextProcessor

class MCQGenerator:

Expand Down Expand Up @@ -352,6 +354,7 @@ def get_document_content(self, document_url):
class FileProcessor:
def __init__(self, upload_folder='uploads/'):
    """Set up the file-extraction helper.

    Args:
        upload_folder: Directory where uploaded files are temporarily
            saved before text extraction (created if missing).
    """
    self.upload_folder = upload_folder
    # Used by process_file_chunked to split extracted text into chunks.
    self.text_processor = TextProcessor()
    # exist_ok=True avoids the check-then-create race of a separate
    # os.path.exists() test and is safe if another worker already
    # created the directory.
    os.makedirs(self.upload_folder, exist_ok=True)

Expand All @@ -367,6 +370,29 @@ def extract_text_from_docx(self, file_path):
result = mammoth.extract_raw_text(docx_file)
return result.value

def extract_text_from_pptx(self, file_path):
    """Extract text from a .pptx PowerPoint file.

    Walks every shape on every slide, collecting text from text frames
    (titles, body placeholders, free text boxes) and from table cells.
    Returns the non-empty fragments joined with newlines.
    """
    presentation = Presentation(file_path)
    collected = []
    for slide in presentation.slides:
        for shape in slide.shapes:
            if shape.has_text_frame:
                # Paragraph-level text covers titles, placeholders and
                # free-standing text boxes alike.
                stripped = (p.text.strip() for p in shape.text_frame.paragraphs)
                collected.extend(t for t in stripped if t)
            if shape.has_table:
                cells = (
                    cell.text.strip()
                    for row in shape.table.rows
                    for cell in row.cells
                )
                collected.extend(t for t in cells if t)
    return "\n".join(collected)

def process_file(self, file):
file_path = os.path.join(self.upload_folder, file.filename)
file.save(file_path)
Expand All @@ -379,10 +405,36 @@ def process_file(self, file):
content = self.extract_text_from_pdf(file_path)
elif file.filename.endswith('.docx'):
content = self.extract_text_from_docx(file_path)
elif file.filename.endswith('.pptx'):
content = self.extract_text_from_pptx(file_path)
elif file.filename.endswith('.ppt'):
import logging
logging.warning(
"Legacy .ppt format is not supported. "
"Please convert to .pptx and try again."
)

os.remove(file_path)
return content

def process_file_chunked(self, file, chunk_size=1000, chunk_overlap=200):
    """Process file and return chunked text for large documents.

    Returns a list of chunk dicts (see TextProcessor.chunk_document).
    Falls back to an empty list when the file type is unsupported or
    the extracted text is empty.
    """
    text = self.process_file(file)
    if not text:
        return []

    # Derive the source type from the filename's extension.
    _, dot_ext = os.path.splitext(file.filename)
    source = dot_ext.lstrip('.').lower() or "unknown"
    return self.text_processor.chunk_document(
        text,
        source_type=source,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )

class QuestionGenerator:
"""A transformer-based NLP system for generating reading comprehension-style questions from
texts. It can generate full sentence questions, multiple choice questions, or a mix of the
Expand Down
224 changes: 224 additions & 0 deletions backend/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,224 @@
"""Shared pytest fixtures for the EduAid backend test suite.

All heavy ML models, NLP pipelines, and external services are mocked so that
tests run instantly without a GPU or network access.

The session-scoped ``_prevent_model_initialization`` fixture patches heavy
constructors **before** ``server.py`` is first imported, preventing module-level
instantiations from triggering real model downloads or GPU usage.
"""

import sys
from unittest.mock import MagicMock, patch

import pytest


# ---------------------------------------------------------------------------
# Lightweight mock objects that replace the heavy ML classes
# ---------------------------------------------------------------------------

def _make_mcq_gen_mock():
mock = MagicMock()
mock.generate_mcq.return_value = {
"statement": "Mock statement",
"questions": [
{
"question_statement": "What is AI?",
"question_type": "MCQ",
"answer": "Artificial Intelligence",
"id": 1,
"options": ["Machine Learning", "Deep Learning", "Robotics"],
"options_algorithm": "sense2vec",
"extra_options": ["Neural Networks", "NLP"],
"context": "AI is the simulation of human intelligence.",
}
],
"time_taken": 0.01,
}
return mock


def _make_shortq_gen_mock():
mock = MagicMock()
mock.generate_shortq.return_value = {
"statement": "Mock statement",
"questions": [
{
"Question": "What is AI?",
"Answer": "Artificial Intelligence",
"id": 1,
"context": "AI is the simulation.",
}
],
}
return mock


def _make_boolq_gen_mock():
mock = MagicMock()
mock.generate_boolq.return_value = {
"Text": "Mock text",
"Count": 4,
"Boolean_Questions": [
"Is AI a simulation of human intelligence?",
"Does machine learning use algorithms?",
],
}
return mock


def _make_question_generator_mock():
mock = MagicMock()
mock.generate.return_value = [
{"question": "What is AI?", "answer": "Artificial Intelligence"},
{"question": "What is ML?", "answer": "Machine Learning"},
]
return mock


def _make_answer_predictor_mock():
mock = MagicMock()
mock.predict_boolean_answer.return_value = [True]
return mock


# ---------------------------------------------------------------------------
# Session-scoped fixture: prevent heavy model loading at import time
# ---------------------------------------------------------------------------

@pytest.fixture(scope="session", autouse=True)
def _prevent_model_initialization():
    """Patch heavy constructors *before* ``server.py`` is imported.

    ``server.py`` instantiates heavy objects at module level::

        MCQGen = main.MCQGenerator()  # loads T5-large, spacy, sense2vec
        answer = main.AnswerPredictor()  # loads T5-large, NLI model
        BoolQGen = main.BoolQGenerator()  # loads T5-base
        ShortQGen = main.ShortQGenerator()  # loads T5-large, spacy
        qg = main.QuestionGenerator()  # loads t5-base-question-generator
        qa_model = pipeline("question-answering")  # loads QA model

    Without these patches, importing ``server.py`` would download several
    gigabytes of model weights and require a GPU or significant RAM.
    """
    import nltk
    import transformers
    import mediawikiapi as mwapi
    import Generator.main as gen_main

    patches = [
        # Prevent NLTK data downloads
        patch.object(nltk, "download", lambda *_a, **_kw: None),
        # Prevent transformers QA pipeline loading
        patch.object(transformers, "pipeline", return_value=MagicMock()),
        # Prevent MediaWikiAPI instantiation
        patch.object(mwapi, "MediaWikiAPI", return_value=MagicMock()),
        # Prevent Generator class instantiation (each __init__ loads models)
        patch.object(gen_main, "MCQGenerator", MagicMock),
        patch.object(gen_main, "BoolQGenerator", MagicMock),
        patch.object(gen_main, "ShortQGenerator", MagicMock),
        patch.object(gen_main, "QuestionGenerator", MagicMock),
        patch.object(gen_main, "AnswerPredictor", MagicMock),
        patch.object(gen_main, "GoogleDocsService", MagicMock),
        patch.object(gen_main, "FileProcessor", MagicMock),
    ]
    for p in patches:
        p.start()
    try:
        yield
    finally:
        # pytest throws errors *into* the generator at the yield point;
        # without the finally, a failure would skip the stop() loop and
        # leave the patches active for the rest of the interpreter session.
        for p in patches:
            p.stop()


# ---------------------------------------------------------------------------
# Per-test fixtures: override module-level mocks with specific behaviour
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def _patch_heavy_imports(monkeypatch):
    """Per-test NLTK download suppression (belt-and-suspenders)."""
    def _noop_download(*_args, **_kwargs):
        return None

    monkeypatch.setattr("nltk.download", _noop_download)


@pytest.fixture()
def mock_mcq_gen():
    """Swap ``server.MCQGen`` for a stub that returns canned MCQ output."""
    stub = _make_mcq_gen_mock()
    with patch("server.MCQGen", stub):
        yield stub


@pytest.fixture()
def mock_shortq_gen():
    """Swap ``server.ShortQGen`` for a stub with canned short-question data."""
    stub = _make_shortq_gen_mock()
    with patch("server.ShortQGen", stub):
        yield stub


@pytest.fixture()
def mock_boolq_gen():
    """Swap ``server.BoolQGen`` for a stub with canned boolean questions."""
    stub = _make_boolq_gen_mock()
    with patch("server.BoolQGen", stub):
        yield stub


@pytest.fixture()
def mock_question_generator():
    """Swap ``server.qg`` for a stub returning two canned Q/A pairs."""
    stub = _make_question_generator_mock()
    with patch("server.qg", stub):
        yield stub


@pytest.fixture()
def mock_answer_predictor():
    """Swap ``server.answer`` for a stub that always predicts True."""
    stub = _make_answer_predictor_mock()
    with patch("server.answer", stub):
        yield stub


@pytest.fixture()
def mock_mediawiki():
    """Swap ``server.mediawikiapi`` for a stub returning a canned summary."""
    stub = MagicMock(**{"summary.return_value": "Expanded text from MediaWiki."})
    with patch("server.mediawikiapi", stub):
        yield stub


@pytest.fixture()
def mock_qa_pipeline():
    """Swap ``server.qa_model`` for a callable stub with a canned answer."""
    stub = MagicMock(return_value={"answer": "mocked answer", "score": 0.99})
    with patch("server.qa_model", stub):
        yield stub


@pytest.fixture()
def mock_file_processor():
    """Swap ``server.file_processor`` for a stub returning canned text."""
    stub = MagicMock(**{"process_file.return_value": "Extracted text content"})
    with patch("server.file_processor", stub):
        yield stub


@pytest.fixture()
def mock_docs_service():
    """Swap ``server.docs_service`` for a stub returning canned doc content."""
    stub = MagicMock(**{"get_document_content.return_value": "Document content"})
    with patch("server.docs_service", stub):
        yield stub


@pytest.fixture()
def app():
    """Create the Flask app for testing.

    Makes the backend directory importable, imports ``server`` (heavy
    constructors are already patched by the session-scoped
    ``_prevent_model_initialization`` fixture), and enables Flask's
    TESTING mode.
    """
    # This fixture runs once per test; guard the insert so sys.path does
    # not accumulate one duplicate "." entry per test in the session.
    if "." not in sys.path:
        sys.path.insert(0, ".")
    from server import app as flask_app
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture()
def client(
    app,
    mock_mcq_gen,
    mock_shortq_gen,
    mock_boolq_gen,
    mock_question_generator,
    mock_answer_predictor,
    mock_mediawiki,
    mock_qa_pipeline,
    mock_file_processor,
    mock_docs_service,
):
    """A Flask test client with all ML models mocked.

    Depending on every mock fixture guarantees each patch is active for
    the lifetime of the client.
    """
    return app.test_client()
Loading