From 34b227579c936319ebecda79ee137268f0bc0fa3 Mon Sep 17 00:00:00 2001 From: Zohaib Shahid Date: Fri, 13 Feb 2026 23:16:33 +0500 Subject: [PATCH 1/5] fix(backend): handle Wikipedia/MediaWiki SSL failures gracefully (fixes #428) When use_mediawiki=1 is passed but the MediaWiki API call fails (SSL error, timeout, network unreachable, etc.), the entire request crashed with a 500 response. Users saw a generic error and lost their input. Changes: - Wrap mediawikiapi.summary() in try/except inside process_input_text() - On failure, log a warning and continue with the original input text - Return a "warning" field in the JSON response so the frontend can notify users that Wikipedia enrichment was skipped - All 7 endpoints that use MediaWiki are updated: /get_mcq, /get_boolq, /get_shortq, /get_problems, /get_shortq_hard, /get_mcq_hard, /get_boolq_hard - Add conftest.py with session-scoped fixture to prevent heavy ML model loading during tests - Add 30 new pytest tests covering SSL errors, connection errors, timeouts, successful calls, and text preservation Co-authored-by: Cursor --- backend/conftest.py | 224 +++++++++++++++++++++++ backend/server.py | 90 ++++++--- backend/test_wikipedia_fallback.py | 282 +++++++++++++++++++++++++++++ 3 files changed, 575 insertions(+), 21 deletions(-) create mode 100644 backend/conftest.py create mode 100644 backend/test_wikipedia_fallback.py diff --git a/backend/conftest.py b/backend/conftest.py new file mode 100644 index 00000000..c36a5f7e --- /dev/null +++ b/backend/conftest.py @@ -0,0 +1,224 @@ +"""Shared pytest fixtures for the EduAid backend test suite. + +All heavy ML models, NLP pipelines, and external services are mocked so that +tests run instantly without a GPU or network access. + +The session-scoped ``_prevent_model_initialization`` fixture patches heavy +constructors **before** ``server.py`` is first imported, preventing module-level +instantiations from triggering real model downloads or GPU usage. +""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Lightweight mock objects that replace the heavy ML classes +# --------------------------------------------------------------------------- + +def _make_mcq_gen_mock(): + mock = MagicMock() + mock.generate_mcq.return_value = { + "statement": "Mock statement", + "questions": [ + { + "question_statement": "What is AI?", + "question_type": "MCQ", + "answer": "Artificial Intelligence", + "id": 1, + "options": ["Machine Learning", "Deep Learning", "Robotics"], + "options_algorithm": "sense2vec", + "extra_options": ["Neural Networks", "NLP"], + "context": "AI is the simulation of human intelligence.", + } + ], + "time_taken": 0.01, + } + return mock + + +def _make_shortq_gen_mock(): + mock = MagicMock() + mock.generate_shortq.return_value = { + "statement": "Mock statement", + "questions": [ + { + "Question": "What is AI?", + "Answer": "Artificial Intelligence", + "id": 1, + "context": "AI is the simulation.", + } + ], + } + return mock + + +def _make_boolq_gen_mock(): + mock = MagicMock() + mock.generate_boolq.return_value = { + "Text": "Mock text", + "Count": 4, + "Boolean_Questions": [ + "Is AI a simulation of human intelligence?", + "Does machine learning use algorithms?", + ], + } + return mock + + +def _make_question_generator_mock(): + mock = MagicMock() + mock.generate.return_value = [ + {"question": "What is AI?", "answer": "Artificial Intelligence"}, + {"question": "What is ML?", "answer": "Machine Learning"}, + ] + return mock + + +def _make_answer_predictor_mock(): + mock = MagicMock() + mock.predict_boolean_answer.return_value = [True] + return mock + + +# --------------------------------------------------------------------------- +# Session-scoped fixture: prevent heavy model loading at import time +# --------------------------------------------------------------------------- + +@pytest.fixture(scope="session", autouse=True) +def _prevent_model_initialization(): + """Patch heavy constructors *before* ``server.py`` is imported. + + ``server.py`` instantiates heavy objects at module level:: + + MCQGen = main.MCQGenerator() # loads T5-large, spacy, sense2vec + answer = main.AnswerPredictor() # loads T5-large, NLI model + BoolQGen = main.BoolQGenerator() # loads T5-base + ShortQGen = main.ShortQGenerator() # loads T5-large, spacy + qg = main.QuestionGenerator() # loads t5-base-question-generator + qa_model = pipeline("question-answering") # loads QA model + + Without these patches, importing ``server.py`` would download several + gigabytes of model weights and require a GPU or significant RAM. + """ + import nltk + import transformers + import mediawikiapi as mwapi + import Generator.main as gen_main + + patches = [ + # Prevent NLTK data downloads + patch.object(nltk, "download", lambda *_a, **_kw: None), + # Prevent transformers QA pipeline loading + patch.object(transformers, "pipeline", return_value=MagicMock()), + # Prevent MediaWikiAPI instantiation + patch.object(mwapi, "MediaWikiAPI", return_value=MagicMock()), + # Prevent Generator class instantiation (each __init__ loads models) + patch.object(gen_main, "MCQGenerator", MagicMock), + patch.object(gen_main, "BoolQGenerator", MagicMock), + patch.object(gen_main, "ShortQGenerator", MagicMock), + patch.object(gen_main, "QuestionGenerator", MagicMock), + patch.object(gen_main, "AnswerPredictor", MagicMock), + patch.object(gen_main, "GoogleDocsService", MagicMock), + patch.object(gen_main, "FileProcessor", MagicMock), + ] + for p in patches: + p.start() + yield + for p in patches: + p.stop() + + +# --------------------------------------------------------------------------- +# Per-test fixtures: override module-level mocks with specific behaviour +# --------------------------------------------------------------------------- + +@pytest.fixture(autouse=True) +def _patch_heavy_imports(monkeypatch): + """Per-test NLTK download suppression (belt-and-suspenders).""" + monkeypatch.setattr("nltk.download", lambda *_a, **_kw: None) + + +@pytest.fixture() +def mock_mcq_gen(): + mock = _make_mcq_gen_mock() + with patch("server.MCQGen", mock): + yield mock + + +@pytest.fixture() +def mock_shortq_gen(): + mock = _make_shortq_gen_mock() + with patch("server.ShortQGen", mock): + yield mock + + +@pytest.fixture() +def mock_boolq_gen(): + mock = _make_boolq_gen_mock() + with patch("server.BoolQGen", mock): + yield mock + + +@pytest.fixture() +def mock_question_generator(): + mock = _make_question_generator_mock() + with patch("server.qg", mock): + yield mock + + +@pytest.fixture() +def mock_answer_predictor(): + mock = _make_answer_predictor_mock() + with patch("server.answer", mock): + yield mock + + +@pytest.fixture() +def mock_mediawiki(): + mock = MagicMock() + mock.summary.return_value = "Expanded text from MediaWiki." + with patch("server.mediawikiapi", mock): + yield mock + + +@pytest.fixture() +def mock_qa_pipeline(): + mock = MagicMock(return_value={"answer": "mocked answer", "score": 0.99}) + with patch("server.qa_model", mock): + yield mock + + +@pytest.fixture() +def mock_file_processor(): + mock = MagicMock() + mock.process_file.return_value = "Extracted text content" + with patch("server.file_processor", mock): + yield mock + + +@pytest.fixture() +def mock_docs_service(): + mock = MagicMock() + mock.get_document_content.return_value = "Document content" + with patch("server.docs_service", mock): + yield mock + + +@pytest.fixture() +def app(): + """Create the Flask app for testing.""" + sys.path.insert(0, ".") + from server import app as flask_app + flask_app.config["TESTING"] = True + return flask_app + + +@pytest.fixture() +def client(app, mock_mcq_gen, mock_shortq_gen, mock_boolq_gen, + mock_question_generator, mock_answer_predictor, mock_mediawiki, + mock_qa_pipeline, mock_file_processor, mock_docs_service): + """A Flask test client with all ML models mocked.""" + return app.test_client() diff --git a/backend/server.py b/backend/server.py index 1c9efaa4..11a967ca 100644 --- a/backend/server.py +++ b/backend/server.py @@ -5,6 +5,7 @@ import subprocess import os import glob +import logging from sklearn.metrics.pairwise import cosine_similarity from sklearn.feature_extraction.text import TfidfVectorizer @@ -28,7 +29,15 @@ app = Flask(__name__) CORS(app) -print("Starting Flask App...") + +# Logging +logging.basicConfig( + level=logging.INFO, + format="%(asctime)s [%(levelname)s] %(message)s", + handlers=[logging.StreamHandler()], +) +logger = logging.getLogger(__name__) +logger.info("Starting Flask App...") SERVICE_ACCOUNT_FILE = './service_account_key.json' SCOPES = ['https://www.googleapis.com/auth/documents.readonly'] @@ -45,9 +54,29 @@ def process_input_text(input_text, use_mediawiki): + """Optionally expand *input_text* via MediaWiki if the flag is set. + + Returns ``(processed_text, warning)`` where *warning* is ``None`` when + Wikipedia enrichment succeeded (or was not requested), and a human-readable + string when enrichment was skipped due to a network / SSL / API error. + + This ensures quiz generation continues even when the external MediaWiki + service is unreachable (see issue #428). + """ if use_mediawiki == 1: - input_text = mediawikiapi.summary(input_text,8) - return input_text + try: + input_text = mediawikiapi.summary(input_text, 8) + except Exception: + logger.warning( + "Wikipedia enrichment failed – continuing without it.", + exc_info=True, + ) + return input_text, ( + "Wikipedia enrichment was requested but could not be completed " + "due to a network error. Results were generated from the " + "original input text only." + ) + return input_text, None @app.route("/get_mcq", methods=["POST"]) @@ -56,12 +85,15 @@ def get_mcq(): input_text = data.get("input_text", "") use_mediawiki = data.get("use_mediawiki", 0) max_questions = data.get("max_questions", 4) - input_text = process_input_text(input_text, use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) output = MCQGen.generate_mcq( {"input_text": input_text, "max_questions": max_questions} ) questions = output["questions"] - return jsonify({"output": questions}) + result = {"output": questions} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_boolq", methods=["POST"]) @@ -70,12 +102,15 @@ def get_boolq(): input_text = data.get("input_text", "") use_mediawiki = data.get("use_mediawiki", 0) max_questions = data.get("max_questions", 4) - input_text = process_input_text(input_text, use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) output = BoolQGen.generate_boolq( {"input_text": input_text, "max_questions": max_questions} ) boolean_questions = output["Boolean_Questions"] - return jsonify({"output": boolean_questions}) + result = {"output": boolean_questions} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_shortq", methods=["POST"]) @@ -84,12 +119,15 @@ def get_shortq(): input_text = data.get("input_text", "") use_mediawiki = data.get("use_mediawiki", 0) max_questions = data.get("max_questions", 4) - input_text = process_input_text(input_text, use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) output = ShortQGen.generate_shortq( {"input_text": input_text, "max_questions": max_questions} ) questions = output["questions"] - return jsonify({"output": questions}) + result = {"output": questions} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_problems", methods=["POST"]) @@ -100,7 +138,7 @@ def get_problems(): max_questions_mcq = data.get("max_questions_mcq", 4) max_questions_boolq = data.get("max_questions_boolq", 4) max_questions_shortq = data.get("max_questions_shortq", 4) - input_text = process_input_text(input_text, use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) output1 = MCQGen.generate_mcq( {"input_text": input_text, "max_questions": max_questions_mcq} ) @@ -110,9 +148,10 @@ def get_problems(): output3 = ShortQGen.generate_shortq( {"input_text": input_text, "max_questions": max_questions_shortq} ) - return jsonify( - {"output_mcq": output1, "output_boolq": output2, "output_shortq": output3} - ) + result = {"output_mcq": output1, "output_boolq": output2, "output_shortq": output3} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_mcq_answer", methods=["POST"]) def get_mcq_answer(): @@ -369,7 +408,7 @@ def get_shortq_hard(): data = request.get_json() input_text = data.get("input_text", "") use_mediawiki = data.get("use_mediawiki", 0) - input_text = process_input_text(input_text,use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) input_questions = data.get("input_question", []) output = qg.generate( @@ -379,7 +418,10 @@ def get_shortq_hard(): for item in output: item["question"] = make_question_harder(item["question"]) - return jsonify({"output": output}) + result = {"output": output} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_mcq_hard", methods=["POST"]) @@ -387,16 +429,19 @@ def get_mcq_hard(): data = request.get_json() input_text = data.get("input_text", "") use_mediawiki = data.get("use_mediawiki", 0) - input_text = process_input_text(input_text,use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) input_questions = data.get("input_question", []) output = qg.generate( article=input_text, num_questions=input_questions, answer_style="multiple_choice" ) - + for q in output: q["question"] = make_question_harder(q["question"]) - - return jsonify({"output": output}) + + result = {"output": output} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route("/get_boolq_hard", methods=["POST"]) def get_boolq_hard(): @@ -405,7 +450,7 @@ def get_boolq_hard(): use_mediawiki = data.get("use_mediawiki", 0) input_questions = data.get("input_question", []) - input_text = process_input_text(input_text, use_mediawiki) + input_text, wiki_warning = process_input_text(input_text, use_mediawiki) # Generate questions using the same QG model generated = qg.generate( @@ -417,7 +462,10 @@ def get_boolq_hard(): # Apply transformation to make each question harder harder_questions = [make_question_harder(q) for q in generated] - return jsonify({"output": harder_questions}) + result = {"output": harder_questions} + if wiki_warning: + result["warning"] = wiki_warning + return jsonify(result) @app.route('/upload', methods=['POST']) def upload_file(): diff --git a/backend/test_wikipedia_fallback.py b/backend/test_wikipedia_fallback.py new file mode 100644 index 00000000..ff63b386 --- /dev/null +++ b/backend/test_wikipedia_fallback.py @@ -0,0 +1,282 @@ +"""Tests for graceful Wikipedia / MediaWiki fallback (issue #428). + +When ``use_mediawiki=1`` is passed but the MediaWiki API call fails +(SSL error, timeout, network unreachable, etc.), every endpoint should: + + 1. Still return HTTP 200 with valid quiz data. + 2. Include a ``warning`` key in the JSON response. + 3. Generate questions from the *original* input text (not crash). + +Uses **pytest** with Flask's built-in test client. All heavy ML models +are mocked in ``conftest.py`` – no running server or GPU required. +""" + +from unittest.mock import patch, MagicMock + +import pytest + +# --------------------------------------------------------------------------- +# Shared test data +# --------------------------------------------------------------------------- + +SAMPLE_TEXT = ( + "Artificial intelligence (AI) is the simulation of human intelligence " + "processes by machines, especially computer systems. These processes " + "include learning, reasoning, and self-correction. AI applications " + "include speech recognition, natural language processing, machine " + "vision, expert systems, and robotics. Machine learning is a subset " + "of AI that focuses on algorithms that learn from data." +) + + +# =========================================================================== +# SSL / network errors produce a warning, not a 500 +# =========================================================================== + + +class TestWikipediaFallbackBasicEndpoints: + """Test the three primary question-generation endpoints.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_ssl_error_returns_200_with_warning(self, client, endpoint): + """SSLError in MediaWiki should NOT crash the request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = Exception( + "SSLError(SSLEOFError(8, 'EOF occurred in violation of protocol'))" + ) + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + assert "Wikipedia" in data["warning"] or "network" in data["warning"] + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_connection_error_returns_200_with_warning(self, client, endpoint): + """ConnectionError in MediaWiki should NOT crash the request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = ConnectionError("Network unreachable") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_timeout_returns_200_with_warning(self, client, endpoint): + """TimeoutError in MediaWiki should NOT crash the request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = TimeoutError("Connection timed out") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +class TestWikipediaFallbackHardEndpoints: + """Test the three hard-question endpoints.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_shortq_hard", "/get_mcq_hard", "/get_boolq_hard", + ]) + def test_ssl_error_hard_endpoints_returns_200_with_warning(self, client, endpoint): + """SSL failure should NOT crash hard-question endpoints.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = ConnectionError("Network unreachable") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +class TestWikipediaFallbackProblems: + """Test the combined /get_problems endpoint.""" + + def test_ssl_error_get_problems_returns_200_with_warning(self, client): + """SSL failure should NOT crash the combined /get_problems endpoint.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = TimeoutError("Connection timed out") + resp = client.post( + "/get_problems", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output_mcq" in data + assert "output_boolq" in data + assert "output_shortq" in data + assert "warning" in data + + +# =========================================================================== +# Successful MediaWiki calls should NOT have a warning +# =========================================================================== + + +class TestWikipediaSuccessNoWarning: + """When MediaWiki succeeds, no warning should be present.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_successful_wiki_call_has_no_warning(self, client, endpoint): + """Successful MediaWiki call should NOT include a warning.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" not in data + + def test_successful_wiki_get_problems_no_warning(self, client): + """Successful MediaWiki call on /get_problems should NOT include a warning.""" + resp = client.post( + "/get_problems", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "warning" not in data + + +# =========================================================================== +# Without use_mediawiki, no warning should appear +# =========================================================================== + + +class TestNoMediawikiNoWarning: + """When use_mediawiki is not set (or 0), no warning should appear.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_no_mediawiki_no_warning(self, client, endpoint): + """Without use_mediawiki=1, no warning should appear.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 0}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "warning" not in data + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_default_no_mediawiki_no_warning(self, client, endpoint): + """When use_mediawiki is not sent, no warning should appear.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "warning" not in data + + +# =========================================================================== +# Various exception types are all caught +# =========================================================================== + + +class TestVariousExceptionTypes: + """All exception types from MediaWiki should be caught and handled.""" + + @pytest.mark.parametrize("exc_class,exc_msg", [ + (ConnectionError, "Network unreachable"), + (TimeoutError, "Connection timed out"), + (OSError, "SSL: CERTIFICATE_VERIFY_FAILED"), + (RuntimeError, "Unexpected mediawiki error"), + (Exception, "SSLEOFError EOF occurred in violation of protocol"), + ]) + def test_various_exceptions_handled_on_mcq(self, client, exc_class, exc_msg): + """All exception types from MediaWiki should be caught on /get_mcq.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = exc_class(exc_msg) + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +# =========================================================================== +# Original text is preserved when MediaWiki fails +# =========================================================================== + + +class TestTextPreservation: + """Verify the correct text is forwarded to the generator.""" + + def test_original_text_used_on_failure(self, client): + """When MediaWiki fails, generation should use the original text.""" + with patch("server.mediawikiapi") as wiki_mock, \ + patch("server.MCQGen") as mcq_mock: + wiki_mock.summary.side_effect = ConnectionError("fail") + mcq_mock.generate_mcq.return_value = { + "questions": [{ + "question_statement": "Test?", + "answer": "Yes", + "id": 1, + "options": ["No", "Maybe"], + "extra_options": [], + "context": "ctx", + }] + } + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + # Verify generate_mcq was called with the *original* text + call_args = mcq_mock.generate_mcq.call_args + assert call_args[0][0]["input_text"] == SAMPLE_TEXT + + def test_enriched_text_used_on_success(self, client): + """When MediaWiki succeeds, enriched text is forwarded to generation.""" + enriched = "Enriched text from Wikipedia about AI and machine learning." + with patch("server.mediawikiapi") as wiki_mock, \ + patch("server.MCQGen") as mcq_mock: + wiki_mock.summary.return_value = enriched + mcq_mock.generate_mcq.return_value = { + "questions": [{ + "question_statement": "Test?", + "answer": "Yes", + "id": 1, + "options": ["No", "Maybe"], + "extra_options": [], + "context": "ctx", + }] + } + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + call_args = mcq_mock.generate_mcq.call_args + assert call_args[0][0]["input_text"] == enriched From c0c207762148e97e8587bcd0a3d942adc0095bd3 Mon Sep 17 00:00:00 2001 From: Zohaib Shahid Date: Fri, 13 Feb 2026 23:29:04 +0500 Subject: [PATCH 2/5] fix: address CodeRabbitAI review feedback - Use defensive .get() with default [] on output["questions"] and output["Boolean_Questions"] to prevent KeyError when generator returns an empty dict (major) - Rename test_ssl_error_hard_endpoints to test_connection_error_hard_endpoints to match actual exception (minor) Co-authored-by: Cursor --- backend/server.py | 6 +++--- backend/test_wikipedia_fallback.py | 4 ++-- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/backend/server.py b/backend/server.py index 11a967ca..feffd6ef 100644 --- a/backend/server.py +++ b/backend/server.py @@ -89,7 +89,7 @@ def get_mcq(): output = MCQGen.generate_mcq( {"input_text": input_text, "max_questions": max_questions} ) - questions = output["questions"] + questions = output.get("questions", []) result = {"output": questions} if wiki_warning: result["warning"] = wiki_warning @@ -106,7 +106,7 @@ def get_boolq(): output = BoolQGen.generate_boolq( {"input_text": input_text, "max_questions": max_questions} ) - boolean_questions = output["Boolean_Questions"] + boolean_questions = output.get("Boolean_Questions", []) result = {"output": boolean_questions} if wiki_warning: result["warning"] = wiki_warning @@ -123,7 +123,7 @@ def get_shortq(): output = ShortQGen.generate_shortq( {"input_text": input_text, "max_questions": max_questions} ) - questions = output["questions"] + questions = output.get("questions", []) result = {"output": questions} if wiki_warning: result["warning"] = wiki_warning diff --git a/backend/test_wikipedia_fallback.py b/backend/test_wikipedia_fallback.py index ff63b386..94a9373b 100644 --- a/backend/test_wikipedia_fallback.py +++ b/backend/test_wikipedia_fallback.py @@ -95,8 +95,8 @@ class TestWikipediaFallbackHardEndpoints: @pytest.mark.parametrize("endpoint", [ "/get_shortq_hard", "/get_mcq_hard", "/get_boolq_hard", ]) - def test_ssl_error_hard_endpoints_returns_200_with_warning(self, client, endpoint): - """SSL failure should NOT crash hard-question endpoints.""" + def test_connection_error_hard_endpoints_returns_200_with_warning(self, client, endpoint): + """Connection failure should NOT crash hard-question endpoints.""" with patch("server.mediawikiapi") as wiki_mock: wiki_mock.summary.side_effect = ConnectionError("Network unreachable") resp = client.post( From 5c9f086e8b7d13b42c672987a57d215787ed80e2 Mon Sep 17 00:00:00 2001 From: Zohaib Shahid Date: Thu, 26 Feb 2026 22:16:52 +0500 Subject: [PATCH 3/5] feat: add PPTX file upload support for quiz generation (#361) - Add python-pptx dependency to requirements.txt - Implement extract_text_from_pptx() in FileProcessor class - Extracts text from text frames (titles, text boxes, placeholders) - Extracts text from table cells across all slides - Update process_file() to route .pptx files to the new extractor - Update frontend labels and file accept filters in: - eduaid_web (Text_Input.jsx) - extension (TextInput.jsx) - Add 6 unit tests covering: - Simple text extraction - Table text extraction - Empty presentations - process_file routing for .pptx - Unsupported .ppt rejection - Multi-slide extraction Closes #361 --- backend/Generator/main.py | 26 +++ backend/test_pptx_extraction.py | 160 +++++++++++++++++++ eduaid_web/src/pages/Text_Input.jsx | 6 +- extension/src/pages/text_input/TextInput.jsx | 19 +-- requirements.txt | 1 + 5 files changed, 200 insertions(+), 12 deletions(-) create mode 100644 backend/test_pptx_extraction.py diff --git a/backend/Generator/main.py b/backend/Generator/main.py index 04aed79f..9d6afdcd 100644 --- a/backend/Generator/main.py +++ b/backend/Generator/main.py @@ -22,6 +22,7 @@ import os import fitz import mammoth +from pptx import Presentation class MCQGenerator: @@ -367,6 +368,29 @@ def extract_text_from_docx(self, file_path): result = mammoth.extract_raw_text(docx_file) return result.value + def extract_text_from_pptx(self, file_path): + """Extract text from a .pptx PowerPoint file. + + Iterates over every slide and shape, pulling text from text-frames + (titles, body placeholders, free text-boxes) and table cells. + """ + prs = Presentation(file_path) + text_parts = [] + for slide in prs.slides: + for shape in slide.shapes: + if shape.has_text_frame: + for paragraph in shape.text_frame.paragraphs: + para_text = paragraph.text.strip() + if para_text: + text_parts.append(para_text) + if shape.has_table: + for row in shape.table.rows: + for cell in row.cells: + cell_text = cell.text.strip() + if cell_text: + text_parts.append(cell_text) + return "\n".join(text_parts) + def process_file(self, file): file_path = os.path.join(self.upload_folder, file.filename) file.save(file_path) @@ -379,6 +403,8 @@ def process_file(self, file): content = self.extract_text_from_pdf(file_path) elif file.filename.endswith('.docx'): content = self.extract_text_from_docx(file_path) + elif file.filename.endswith('.pptx'): + content = self.extract_text_from_pptx(file_path) os.remove(file_path) return content diff --git a/backend/test_pptx_extraction.py b/backend/test_pptx_extraction.py new file mode 100644 index 00000000..45adc95b --- /dev/null +++ b/backend/test_pptx_extraction.py @@ -0,0 +1,160 @@ +"""Tests for PPTX text extraction in FileProcessor. + +We test the FileProcessor class in isolation – without importing the full +Generator.main module which would trigger heavy ML model loading. Instead +we import only the minimal dependencies and re-create the class here. +""" +import os +import pytest +from pptx import Presentation +from pptx.util import Inches +from unittest.mock import MagicMock + + +# ── Lightweight re-creation of FileProcessor (extraction logic only) ───────── + +class FileProcessor: + """Mirror of Generator.main.FileProcessor – text extraction methods only.""" + + def __init__(self, upload_folder='uploads/'): + self.upload_folder = upload_folder + if not os.path.exists(self.upload_folder): + os.makedirs(self.upload_folder) + + def extract_text_from_pptx(self, file_path): + """Extract text from a .pptx PowerPoint file.""" + prs = Presentation(file_path) + text_parts = [] + for slide in prs.slides: + for shape in slide.shapes: + if shape.has_text_frame: + for paragraph in shape.text_frame.paragraphs: + para_text = paragraph.text.strip() + if para_text: + text_parts.append(para_text) + if shape.has_table: + for row in shape.table.rows: + for cell in row.cells: + cell_text = cell.text.strip() + if cell_text: + text_parts.append(cell_text) + return "\n".join(text_parts) + + def process_file(self, file): + file_path = os.path.join(self.upload_folder, file.filename) + file.save(file_path) + content = "" + + if file.filename.endswith('.txt'): + with open(file_path, 'r') as f: + content = f.read() + elif file.filename.endswith('.pptx'): + content = self.extract_text_from_pptx(file_path) + + os.remove(file_path) + return content + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _create_pptx(tmp_path, filename, texts): + """Create a minimal .pptx with one text-box per item in *texts*.""" + prs = Presentation() + slide_layout = prs.slide_layouts[6] # blank layout + slide = prs.slides.add_slide(slide_layout) + for t in texts: + txBox = slide.shapes.add_textbox(Inches(1), Inches(1), Inches(5), Inches(1)) + txBox.text_frame.text = t + path = os.path.join(tmp_path, filename) + prs.save(path) + return path + + +def _create_pptx_with_table(tmp_path, filename, rows_data): + """Create a .pptx containing a single table.""" + prs = Presentation() + slide = prs.slides.add_slide(prs.slide_layouts[6]) + cols = max(len(r) for r in rows_data) + table_shape = slide.shapes.add_table(len(rows_data), cols, Inches(1), Inches(1), Inches(6), Inches(2)) + table = table_shape.table + for ri, row in enumerate(rows_data): + for ci, cell_text in enumerate(row): + table.cell(ri, ci).text = cell_text + path = os.path.join(tmp_path, filename) + prs.save(path) + return path + + +# ── tests ──────────────────────────────────────────────────────────────────── + +class TestExtractTextFromPptx: + + def test_simple_text(self, tmp_path): + path = _create_pptx(tmp_path, "simple.pptx", ["Hello World", "Second box"]) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert "Hello World" in result + assert "Second box" in result + + def test_table_text(self, tmp_path): + rows = [["Name", "Score"], ["Alice", "95"], ["Bob", "88"]] + path = _create_pptx_with_table(tmp_path, "table.pptx", rows) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + for cell in ["Name", "Score", "Alice", "95", "Bob", "88"]: + assert cell in result + + def test_empty_presentation(self, tmp_path): + prs = Presentation() + prs.slides.add_slide(prs.slide_layouts[6]) # blank slide, no shapes + path = os.path.join(tmp_path, "empty.pptx") + prs.save(path) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert result == "" + + def test_process_file_routes_pptx(self, tmp_path): + """process_file() should call extract_text_from_pptx for .pptx files.""" + _create_pptx(tmp_path, "routed.pptx", ["Route test"]) + fp = FileProcessor(upload_folder=str(tmp_path)) + + mock_file = MagicMock() + mock_file.filename = "routed.pptx" + # save() just copies the file into the upload folder (it's already there) + mock_file.save = MagicMock(side_effect=lambda dest: None) + + # Place the file where process_file expects it + _create_pptx(tmp_path, "routed.pptx", ["Route test"]) + result = fp.process_file(mock_file) + assert "Route test" in result + + def test_unsupported_ppt_extension(self, tmp_path): + """Legacy .ppt files are not supported; process_file returns empty.""" + dummy = os.path.join(tmp_path, "old.ppt") + with open(dummy, "w") as f: + f.write("not a real ppt") + + fp = FileProcessor(upload_folder=str(tmp_path)) + mock_file = MagicMock() + mock_file.filename = "old.ppt" + # File is already in upload_folder (tmp_path), so save is a no-op + mock_file.save = MagicMock(side_effect=lambda dest: None) + + result = fp.process_file(mock_file) + assert result == "" + + def test_multiple_slides(self, tmp_path): + """Text from multiple slides should all be extracted.""" + prs = Presentation() + for text in ["Slide 1 content", "Slide 2 content", "Slide 3 content"]: + slide = prs.slides.add_slide(prs.slide_layouts[6]) + txBox = slide.shapes.add_textbox(Inches(1), Inches(1), Inches(5), Inches(1)) + txBox.text_frame.text = text + path = os.path.join(tmp_path, "multi.pptx") + prs.save(path) + + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert "Slide 1 content" in result + assert "Slide 2 content" in result + assert "Slide 3 content" in result diff --git a/eduaid_web/src/pages/Text_Input.jsx b/eduaid_web/src/pages/Text_Input.jsx index e341d331..acf0ca99 100644 --- a/eduaid_web/src/pages/Text_Input.jsx +++ b/eduaid_web/src/pages/Text_Input.jsx @@ -5,7 +5,7 @@ import stars from "../assets/stars.png"; import cloud from "../assets/cloud.png"; import { FaClipboard } from "react-icons/fa"; import Switch from "react-switch"; -import { Link,useNavigate } from "react-router-dom"; +import { Link, useNavigate } from "react-router-dom"; import apiClient from "../utils/apiClient"; const Text_Input = () => { @@ -185,9 +185,9 @@ const Text_Input = () => { {/* File Upload Section */}
cloud -

Choose a file (PDF, MP3 supported)

+

Choose a file (PDF, PPTX, MP3 supported)

- +
- + > + +
diff --git a/requirements.txt b/requirements.txt index 717390b7..77bd1f29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,7 @@ google-auth datasets==3.1.0 tokenizers mammoth +python-pptx mediawikiapi PyMuPDF textblob \ No newline at end of file From 28a237d7cdd1f0baf377d835807a251d5f8b9b99 Mon Sep 17 00:00:00 2001 From: Zohaib Shahid Date: Thu, 26 Feb 2026 22:50:57 +0500 Subject: [PATCH 4/5] fix: address CodeRabbit review feedback - Add explicit .ppt handling with warning log for unsupported legacy format - Rewrite tests to import the real FileProcessor from Generator.main (mocking heavy ML dependencies via sys.modules to avoid model loading) - Sync upload hint text with accept filter in both web and extension ('PDF, PPTX, TXT, DOCX, MP3 supported') --- backend/Generator/main.py | 6 ++ backend/test_pptx_extraction.py | 83 +++++++------------- eduaid_web/src/pages/Text_Input.jsx | 2 +- extension/src/pages/text_input/TextInput.jsx | 2 +- 4 files changed, 37 insertions(+), 56 deletions(-) diff --git a/backend/Generator/main.py b/backend/Generator/main.py index 9d6afdcd..fd6539ad 100644 --- a/backend/Generator/main.py +++ b/backend/Generator/main.py @@ -405,6 +405,12 @@ def process_file(self, file): content = self.extract_text_from_docx(file_path) elif file.filename.endswith('.pptx'): content = self.extract_text_from_pptx(file_path) + elif file.filename.endswith('.ppt'): + import logging + logging.warning( + "Legacy .ppt format is not supported. " + "Please convert to .pptx and try again." + ) os.remove(file_path) return content diff --git a/backend/test_pptx_extraction.py b/backend/test_pptx_extraction.py index 45adc95b..fe453733 100644 --- a/backend/test_pptx_extraction.py +++ b/backend/test_pptx_extraction.py @@ -1,58 +1,32 @@ """Tests for PPTX text extraction in FileProcessor. -We test the FileProcessor class in isolation – without importing the full -Generator.main module which would trigger heavy ML model loading. Instead -we import only the minimal dependencies and re-create the class here. +Imports the real FileProcessor from Generator.main. Heavy ML dependencies +(torch, transformers, sense2vec, etc.) are mocked at the sys.modules level +so they are not actually loaded during test collection. """ import os +import sys +from unittest.mock import MagicMock + import pytest from pptx import Presentation from pptx.util import Inches -from unittest.mock import MagicMock +# ── Mock heavy ML dependencies before importing Generator.main ─────────────── +# This prevents model loading while still testing the real FileProcessor code. -# ── Lightweight re-creation of FileProcessor (extraction logic only) ───────── - -class FileProcessor: - """Mirror of Generator.main.FileProcessor – text extraction methods only.""" - - def __init__(self, upload_folder='uploads/'): - self.upload_folder = upload_folder - if not os.path.exists(self.upload_folder): - os.makedirs(self.upload_folder) - - def extract_text_from_pptx(self, file_path): - """Extract text from a .pptx PowerPoint file.""" - prs = Presentation(file_path) - text_parts = [] - for slide in prs.slides: - for shape in slide.shapes: - if shape.has_text_frame: - for paragraph in shape.text_frame.paragraphs: - para_text = paragraph.text.strip() - if para_text: - text_parts.append(para_text) - if shape.has_table: - for row in shape.table.rows: - for cell in row.cells: - cell_text = cell.text.strip() - if cell_text: - text_parts.append(cell_text) - return "\n".join(text_parts) - - def process_file(self, file): - file_path = os.path.join(self.upload_folder, file.filename) - file.save(file_path) - content = "" - - if file.filename.endswith('.txt'): - with open(file_path, 'r') as f: - content = f.read() - elif file.filename.endswith('.pptx'): - content = self.extract_text_from_pptx(file_path) - - os.remove(file_path) - return content +_mock_torch = MagicMock() +_mock_torch.cuda.is_available.return_value = False +sys.modules.setdefault("torch", _mock_torch) +sys.modules.setdefault("transformers", MagicMock()) +sys.modules.setdefault("sense2vec", MagicMock()) +sys.modules.setdefault("google.oauth2", MagicMock()) +sys.modules.setdefault("google.oauth2.service_account", MagicMock()) +sys.modules.setdefault("googleapiclient", MagicMock()) +sys.modules.setdefault("googleapiclient.discovery", MagicMock()) +sys.modules.setdefault("en_core_web_sm", MagicMock()) + +from Generator.main import FileProcessor # noqa: E402 # ── helpers ────────────────────────────────────────────────────────────────── @@ -60,8 +34,7 @@ def process_file(self, file): def _create_pptx(tmp_path, filename, texts): """Create a minimal .pptx with one text-box per item in *texts*.""" prs = Presentation() - slide_layout = prs.slide_layouts[6] # blank layout - slide = prs.slides.add_slide(slide_layout) + slide = prs.slides.add_slide(prs.slide_layouts[6]) # blank layout for t in texts: txBox = slide.shapes.add_textbox(Inches(1), Inches(1), Inches(5), Inches(1)) txBox.text_frame.text = t @@ -75,7 +48,9 @@ def _create_pptx_with_table(tmp_path, filename, rows_data): prs = Presentation() slide = prs.slides.add_slide(prs.slide_layouts[6]) cols = max(len(r) for r in rows_data) - table_shape = slide.shapes.add_table(len(rows_data), cols, Inches(1), Inches(1), Inches(6), Inches(2)) + table_shape = slide.shapes.add_table( + len(rows_data), cols, Inches(1), Inches(1), Inches(6), Inches(2) + ) table = table_shape.table for ri, row in enumerate(rows_data): for ci, cell_text in enumerate(row): @@ -114,17 +89,15 @@ def test_empty_presentation(self, tmp_path): assert result == "" def test_process_file_routes_pptx(self, tmp_path): - """process_file() should call extract_text_from_pptx for .pptx files.""" + """process_file() should route .pptx files to extract_text_from_pptx.""" _create_pptx(tmp_path, "routed.pptx", ["Route test"]) fp = FileProcessor(upload_folder=str(tmp_path)) mock_file = MagicMock() mock_file.filename = "routed.pptx" - # save() just copies the file into the upload folder (it's already there) + # File is already in upload_folder (tmp_path), so save is a no-op mock_file.save = MagicMock(side_effect=lambda dest: None) - # Place the file where process_file expects it - _create_pptx(tmp_path, "routed.pptx", ["Route test"]) result = fp.process_file(mock_file) assert "Route test" in result @@ -148,7 +121,9 @@ def test_multiple_slides(self, tmp_path): prs = Presentation() for text in ["Slide 1 content", "Slide 2 content", "Slide 3 content"]: slide = prs.slides.add_slide(prs.slide_layouts[6]) - txBox = slide.shapes.add_textbox(Inches(1), Inches(1), Inches(5), Inches(1)) + txBox = slide.shapes.add_textbox( + Inches(1), Inches(1), Inches(5), Inches(1) + ) txBox.text_frame.text = text path = os.path.join(tmp_path, "multi.pptx") prs.save(path) diff --git a/eduaid_web/src/pages/Text_Input.jsx b/eduaid_web/src/pages/Text_Input.jsx index acf0ca99..7f96915f 100644 --- a/eduaid_web/src/pages/Text_Input.jsx +++ b/eduaid_web/src/pages/Text_Input.jsx @@ -185,7 +185,7 @@ const Text_Input = () => { {/* File Upload Section */}
cloud -

Choose a file (PDF, PPTX, MP3 supported)

+

Choose a file (PDF, PPTX, TXT, DOCX, MP3 supported)