def extract_text_from_pptx(self, file_path):
    """Extract all visible text from a .pptx PowerPoint file.

    Walks every slide and every shape, collecting text from text frames
    (titles, body placeholders, free text boxes) and from table cells.

    Args:
        file_path: Path to the .pptx file on disk.

    Returns:
        str: Extracted text, one paragraph / table cell per line.
        Empty string when the presentation contains no text.
    """
    prs = Presentation(file_path)
    text_parts = []
    for slide in prs.slides:
        for shape in slide.shapes:
            # Text frames cover titles, body placeholders and free
            # text boxes.
            if shape.has_text_frame:
                for paragraph in shape.text_frame.paragraphs:
                    para_text = paragraph.text.strip()
                    if para_text:
                        text_parts.append(para_text)
            # NOTE(review): `has_table` lives on GraphicFrame; some
            # python-pptx releases do not define it on every shape type
            # (e.g. pictures), which would raise AttributeError here.
            # getattr() keeps extraction robust either way — confirm
            # against the pinned python-pptx version.
            if getattr(shape, "has_table", False):
                for row in shape.table.rows:
                    for cell in row.cells:
                        cell_text = cell.text.strip()
                        if cell_text:
                            text_parts.append(cell_text)
    return "\n".join(text_parts)
def process_file_chunked(self, file, chunk_size=1000, chunk_overlap=200):
    """Process *file* and return its text split into overlapping chunks.

    Delegates extraction to ``process_file`` and chunking to
    ``TextProcessor.chunk_document``.

    Args:
        file: Uploaded file object (Werkzeug-style) exposing ``filename``.
        chunk_size: Target maximum number of characters per chunk.
        chunk_overlap: Characters of overlap between adjacent chunks.

    Returns:
        list[dict]: Chunk dicts as produced by ``chunk_document``; an
        empty list when the file type is unsupported or the extracted
        text is empty.
    """
    extracted = self.process_file(file)
    if not extracted:
        return []

    # Derive the source type ("pdf", "docx", ...) from the extension.
    suffix = os.path.splitext(file.filename)[1].lstrip('.').lower()
    source = suffix if suffix else "unknown"
    return self.text_processor.chunk_document(
        extracted,
        source_type=source,
        chunk_size=chunk_size,
        chunk_overlap=chunk_overlap,
    )
+""" + +import sys +from unittest.mock import MagicMock, patch + +import pytest + + +# --------------------------------------------------------------------------- +# Lightweight mock objects that replace the heavy ML classes +# --------------------------------------------------------------------------- + +def _make_mcq_gen_mock(): + mock = MagicMock() + mock.generate_mcq.return_value = { + "statement": "Mock statement", + "questions": [ + { + "question_statement": "What is AI?", + "question_type": "MCQ", + "answer": "Artificial Intelligence", + "id": 1, + "options": ["Machine Learning", "Deep Learning", "Robotics"], + "options_algorithm": "sense2vec", + "extra_options": ["Neural Networks", "NLP"], + "context": "AI is the simulation of human intelligence.", + } + ], + "time_taken": 0.01, + } + return mock + + +def _make_shortq_gen_mock(): + mock = MagicMock() + mock.generate_shortq.return_value = { + "statement": "Mock statement", + "questions": [ + { + "Question": "What is AI?", + "Answer": "Artificial Intelligence", + "id": 1, + "context": "AI is the simulation.", + } + ], + } + return mock + + +def _make_boolq_gen_mock(): + mock = MagicMock() + mock.generate_boolq.return_value = { + "Text": "Mock text", + "Count": 4, + "Boolean_Questions": [ + "Is AI a simulation of human intelligence?", + "Does machine learning use algorithms?", + ], + } + return mock + + +def _make_question_generator_mock(): + mock = MagicMock() + mock.generate.return_value = [ + {"question": "What is AI?", "answer": "Artificial Intelligence"}, + {"question": "What is ML?", "answer": "Machine Learning"}, + ] + return mock + + +def _make_answer_predictor_mock(): + mock = MagicMock() + mock.predict_boolean_answer.return_value = [True] + return mock + + +# --------------------------------------------------------------------------- +# Session-scoped fixture: prevent heavy model loading at import time +# --------------------------------------------------------------------------- + 
@pytest.fixture(scope="session", autouse=True)
def _prevent_model_initialization():
    """Patch heavy constructors *before* ``server.py`` is imported.

    ``server.py`` instantiates heavy objects at module level::

        MCQGen = main.MCQGenerator()  # loads T5-large, spacy, sense2vec
        answer = main.AnswerPredictor()  # loads T5-large, NLI model
        BoolQGen = main.BoolQGenerator()  # loads T5-base
        ShortQGen = main.ShortQGenerator()  # loads T5-large, spacy
        qg = main.QuestionGenerator()  # loads t5-base-question-generator
        qa_model = pipeline("question-answering")  # loads QA model

    Without these patches, importing ``server.py`` would download several
    gigabytes of model weights and require a GPU or significant RAM.
    """
    # Imported lazily so the patches are registered before server.py
    # (imported later, inside the ``app`` fixture) touches these modules.
    import nltk
    import transformers
    import mediawikiapi as mwapi
    import Generator.main as gen_main

    patches = [
        # Prevent NLTK data downloads
        patch.object(nltk, "download", lambda *_a, **_kw: None),
        # Prevent transformers QA pipeline loading
        patch.object(transformers, "pipeline", return_value=MagicMock()),
        # Prevent MediaWikiAPI instantiation
        patch.object(mwapi, "MediaWikiAPI", return_value=MagicMock()),
        # Prevent Generator class instantiation (each __init__ loads models)
        patch.object(gen_main, "MCQGenerator", MagicMock),
        patch.object(gen_main, "BoolQGenerator", MagicMock),
        patch.object(gen_main, "ShortQGenerator", MagicMock),
        patch.object(gen_main, "QuestionGenerator", MagicMock),
        patch.object(gen_main, "AnswerPredictor", MagicMock),
        patch.object(gen_main, "GoogleDocsService", MagicMock),
        patch.object(gen_main, "FileProcessor", MagicMock),
    ]
    for p in patches:
        p.start()
    yield
    # Session teardown: each started patcher is stopped independently.
    for p in patches:
        p.stop()


# ---------------------------------------------------------------------------
# Per-test fixtures: override module-level mocks with specific behaviour
# ---------------------------------------------------------------------------

@pytest.fixture(autouse=True)
def _patch_heavy_imports(monkeypatch):
    """Per-test NLTK download suppression (belt-and-suspenders)."""
    monkeypatch.setattr("nltk.download", lambda *_a, **_kw: None)


@pytest.fixture()
def mock_mcq_gen():
    # Replaces the module-level ``MCQGen`` object in server.py for one test.
    mock = _make_mcq_gen_mock()
    with patch("server.MCQGen", mock):
        yield mock


@pytest.fixture()
def mock_shortq_gen():
    # Replaces ``server.ShortQGen`` for the duration of one test.
    mock = _make_shortq_gen_mock()
    with patch("server.ShortQGen", mock):
        yield mock


@pytest.fixture()
def mock_boolq_gen():
    # Replaces ``server.BoolQGen`` for the duration of one test.
    mock = _make_boolq_gen_mock()
    with patch("server.BoolQGen", mock):
        yield mock


@pytest.fixture()
def mock_question_generator():
    # Replaces ``server.qg`` (the hard-question generator) for one test.
    mock = _make_question_generator_mock()
    with patch("server.qg", mock):
        yield mock


@pytest.fixture()
def mock_answer_predictor():
    # Replaces ``server.answer`` (boolean answer predictor) for one test.
    mock = _make_answer_predictor_mock()
    with patch("server.answer", mock):
        yield mock


@pytest.fixture()
def mock_mediawiki():
    # Replaces ``server.mediawikiapi`` so no network calls are made.
    mock = MagicMock()
    mock.summary.return_value = "Expanded text from MediaWiki."
    with patch("server.mediawikiapi", mock):
        yield mock


@pytest.fixture()
def mock_qa_pipeline():
    # Replaces ``server.qa_model`` (transformers QA pipeline) for one test.
    mock = MagicMock(return_value={"answer": "mocked answer", "score": 0.99})
    with patch("server.qa_model", mock):
        yield mock


@pytest.fixture()
def mock_file_processor():
    # Replaces ``server.file_processor`` so no real files are parsed.
    mock = MagicMock()
    mock.process_file.return_value = "Extracted text content"
    with patch("server.file_processor", mock):
        yield mock


@pytest.fixture()
def mock_docs_service():
    # Replaces ``server.docs_service`` so no Google API calls are made.
    mock = MagicMock()
    mock.get_document_content.return_value = "Document content"
    with patch("server.docs_service", mock):
        yield mock


@pytest.fixture()
def app():
    """Create the Flask app for testing."""
    # server.py lives in the backend/ working directory; make it importable.
    sys.path.insert(0, ".")
    from server import app as flask_app
    flask_app.config["TESTING"] = True
    return flask_app


@pytest.fixture()
def client(app, mock_mcq_gen, mock_shortq_gen, mock_boolq_gen,
           mock_question_generator, mock_answer_predictor, mock_mediawiki,
           mock_qa_pipeline, mock_file_processor, mock_docs_service):
    """A Flask test client with all ML models mocked."""
    return app.test_client()
def process_input_text(input_text, use_mediawiki):
    """Optionally expand *input_text* through the MediaWiki API.

    Returns a ``(processed_text, warning)`` tuple. *warning* is ``None``
    when enrichment was not requested or succeeded; otherwise it is a
    human-readable message explaining that enrichment was skipped due to
    a network / SSL / API error (see issue #428) and generation proceeded
    with the caller's original text.
    """
    if use_mediawiki != 1:
        return input_text, None
    try:
        return mediawikiapi.summary(input_text, 8), None
    except Exception:
        # Degrade gracefully: log the failure and fall back to the
        # original text instead of aborting the whole request.
        logger.warning(
            "Wikipedia enrichment failed – continuing without it.",
            exc_info=True,
        )
        return input_text, (
            "Wikipedia enrichment was requested but could not be completed "
            "due to a network error. Results were generated from the "
            "original input text only."
        )


@app.route("/get_mcq", methods=["POST"])
def get_mcq():
    """Generate multiple-choice questions from the posted input text."""
    data = request.get_json()
    text = data.get("input_text", "")
    use_wiki = data.get("use_mediawiki", 0)
    limit = data.get("max_questions", 4)
    text, wiki_warning = process_input_text(text, use_wiki)
    generated = MCQGen.generate_mcq(
        {"input_text": text, "max_questions": limit}
    )
    # Surface the enrichment warning (if any) alongside the questions.
    response = {"output": generated.get("questions", [])}
    if wiki_warning:
        response["warning"] = wiki_warning
    return jsonify(response)


@app.route("/get_boolq", methods=["POST"])
def get_boolq():
    """Generate boolean (yes/no) questions from the posted input text."""
    data = request.get_json()
    text = data.get("input_text", "")
    use_wiki = data.get("use_mediawiki", 0)
    limit = data.get("max_questions", 4)
    text, wiki_warning = process_input_text(text, use_wiki)
    generated = BoolQGen.generate_boolq(
        {"input_text": text, "max_questions": limit}
    )
    response = {"output": generated.get("Boolean_Questions", [])}
    if wiki_warning:
        response["warning"] = wiki_warning
    return jsonify(response)
@app.route("/get_shortq", methods=["POST"])
def get_shortq():
    """Generate short-answer questions from the posted input text."""
    data = request.get_json()
    text = data.get("input_text", "")
    use_wiki = data.get("use_mediawiki", 0)
    limit = data.get("max_questions", 4)
    text, wiki_warning = process_input_text(text, use_wiki)
    generated = ShortQGen.generate_shortq(
        {"input_text": text, "max_questions": limit}
    )
    response = {"output": generated.get("questions", [])}
    if wiki_warning:
        response["warning"] = wiki_warning
    return jsonify(response)


@app.route("/get_problems", methods=["POST"])
def get_problems():
    """Generate MCQ, boolean and short-answer questions in one request."""
    data = request.get_json()
    text = data.get("input_text", "")
    use_wiki = data.get("use_mediawiki", 0)
    counts = {
        "mcq": data.get("max_questions_mcq", 4),
        "boolq": data.get("max_questions_boolq", 4),
        "shortq": data.get("max_questions_shortq", 4),
    }
    text, wiki_warning = process_input_text(text, use_wiki)
    # Dict literals evaluate in order, so generators run MCQ → BoolQ →
    # ShortQ, matching the previous sequential calls.
    result = {
        "output_mcq": MCQGen.generate_mcq(
            {"input_text": text, "max_questions": counts["mcq"]}
        ),
        "output_boolq": BoolQGen.generate_boolq(
            {"input_text": text, "max_questions": counts["boolq"]}
        ),
        "output_shortq": ShortQGen.generate_shortq(
            {"input_text": text, "max_questions": counts["shortq"]}
        ),
    }
    if wiki_warning:
        result["warning"] = wiki_warning
    return jsonify(result)


@app.route("/get_mcq_hard", methods=["POST"])
def get_mcq_hard():
    """Generate harder multiple-choice questions for the given text."""
    data = request.get_json()
    text = data.get("input_text", "")
    use_wiki = data.get("use_mediawiki", 0)
    text, wiki_warning = process_input_text(text, use_wiki)
    wanted = data.get("input_question", [])

    generated = qg.generate(
        article=text, num_questions=wanted, answer_style="multiple_choice"
    )

    # Rephrase each question in place to raise its difficulty.
    for item in generated:
        item["question"] = make_question_harder(item["question"])

    result = {"output": generated}
    if wiki_warning:
        result["warning"] = wiki_warning
    return jsonify(result)
processing file"}), 400 diff --git a/backend/test_pptx_extraction.py b/backend/test_pptx_extraction.py new file mode 100644 index 00000000..fe453733 --- /dev/null +++ b/backend/test_pptx_extraction.py @@ -0,0 +1,135 @@ +"""Tests for PPTX text extraction in FileProcessor. + +Imports the real FileProcessor from Generator.main. Heavy ML dependencies +(torch, transformers, sense2vec, etc.) are mocked at the sys.modules level +so they are not actually loaded during test collection. +""" +import os +import sys +from unittest.mock import MagicMock + +import pytest +from pptx import Presentation +from pptx.util import Inches + +# ── Mock heavy ML dependencies before importing Generator.main ─────────────── +# This prevents model loading while still testing the real FileProcessor code. + +_mock_torch = MagicMock() +_mock_torch.cuda.is_available.return_value = False +sys.modules.setdefault("torch", _mock_torch) +sys.modules.setdefault("transformers", MagicMock()) +sys.modules.setdefault("sense2vec", MagicMock()) +sys.modules.setdefault("google.oauth2", MagicMock()) +sys.modules.setdefault("google.oauth2.service_account", MagicMock()) +sys.modules.setdefault("googleapiclient", MagicMock()) +sys.modules.setdefault("googleapiclient.discovery", MagicMock()) +sys.modules.setdefault("en_core_web_sm", MagicMock()) + +from Generator.main import FileProcessor # noqa: E402 + + +# ── helpers ────────────────────────────────────────────────────────────────── + +def _create_pptx(tmp_path, filename, texts): + """Create a minimal .pptx with one text-box per item in *texts*.""" + prs = Presentation() + slide = prs.slides.add_slide(prs.slide_layouts[6]) # blank layout + for t in texts: + txBox = slide.shapes.add_textbox(Inches(1), Inches(1), Inches(5), Inches(1)) + txBox.text_frame.text = t + path = os.path.join(tmp_path, filename) + prs.save(path) + return path + + +def _create_pptx_with_table(tmp_path, filename, rows_data): + """Create a .pptx containing a single table.""" + prs = 
Presentation() + slide = prs.slides.add_slide(prs.slide_layouts[6]) + cols = max(len(r) for r in rows_data) + table_shape = slide.shapes.add_table( + len(rows_data), cols, Inches(1), Inches(1), Inches(6), Inches(2) + ) + table = table_shape.table + for ri, row in enumerate(rows_data): + for ci, cell_text in enumerate(row): + table.cell(ri, ci).text = cell_text + path = os.path.join(tmp_path, filename) + prs.save(path) + return path + + +# ── tests ──────────────────────────────────────────────────────────────────── + +class TestExtractTextFromPptx: + + def test_simple_text(self, tmp_path): + path = _create_pptx(tmp_path, "simple.pptx", ["Hello World", "Second box"]) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert "Hello World" in result + assert "Second box" in result + + def test_table_text(self, tmp_path): + rows = [["Name", "Score"], ["Alice", "95"], ["Bob", "88"]] + path = _create_pptx_with_table(tmp_path, "table.pptx", rows) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + for cell in ["Name", "Score", "Alice", "95", "Bob", "88"]: + assert cell in result + + def test_empty_presentation(self, tmp_path): + prs = Presentation() + prs.slides.add_slide(prs.slide_layouts[6]) # blank slide, no shapes + path = os.path.join(tmp_path, "empty.pptx") + prs.save(path) + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert result == "" + + def test_process_file_routes_pptx(self, tmp_path): + """process_file() should route .pptx files to extract_text_from_pptx.""" + _create_pptx(tmp_path, "routed.pptx", ["Route test"]) + fp = FileProcessor(upload_folder=str(tmp_path)) + + mock_file = MagicMock() + mock_file.filename = "routed.pptx" + # File is already in upload_folder (tmp_path), so save is a no-op + mock_file.save = MagicMock(side_effect=lambda dest: None) + + result = fp.process_file(mock_file) + assert "Route test" in result + 
+ def test_unsupported_ppt_extension(self, tmp_path): + """Legacy .ppt files are not supported; process_file returns empty.""" + dummy = os.path.join(tmp_path, "old.ppt") + with open(dummy, "w") as f: + f.write("not a real ppt") + + fp = FileProcessor(upload_folder=str(tmp_path)) + mock_file = MagicMock() + mock_file.filename = "old.ppt" + # File is already in upload_folder (tmp_path), so save is a no-op + mock_file.save = MagicMock(side_effect=lambda dest: None) + + result = fp.process_file(mock_file) + assert result == "" + + def test_multiple_slides(self, tmp_path): + """Text from multiple slides should all be extracted.""" + prs = Presentation() + for text in ["Slide 1 content", "Slide 2 content", "Slide 3 content"]: + slide = prs.slides.add_slide(prs.slide_layouts[6]) + txBox = slide.shapes.add_textbox( + Inches(1), Inches(1), Inches(5), Inches(1) + ) + txBox.text_frame.text = text + path = os.path.join(tmp_path, "multi.pptx") + prs.save(path) + + fp = FileProcessor(upload_folder=str(tmp_path)) + result = fp.extract_text_from_pptx(path) + assert "Slide 1 content" in result + assert "Slide 2 content" in result + assert "Slide 3 content" in result diff --git a/backend/test_text_processor.py b/backend/test_text_processor.py new file mode 100644 index 00000000..eecda073 --- /dev/null +++ b/backend/test_text_processor.py @@ -0,0 +1,375 @@ +"""Tests for the TextProcessor chunking utility. + +These tests verify that TextProcessor correctly splits text into +overlapping chunks on natural boundaries (paragraphs, sentences, words) +without any external dependencies beyond Python's stdlib. +""" + +import pytest + +from utils.text_processor import TextProcessor + + +# --------------------------------------------------------------------------- +# Helper data +# --------------------------------------------------------------------------- + +SHORT_TEXT = "This is a short sentence." 
+ +PARAGRAPH_TEXT = ( + "Artificial intelligence is the simulation of human intelligence.\n\n" + "Machine learning is a subset of AI that focuses on algorithms.\n\n" + "Deep learning involves neural networks with many layers." +) + +LONG_PARAGRAPH = ( + "Artificial intelligence (AI) is the simulation of human intelligence " + "processes by machines, especially computer systems. These processes " + "include learning, reasoning, and self-correction. AI applications include " + "speech recognition, natural language processing, machine vision, expert " + "systems, and robotics. Machine learning, a subset of AI, focuses on the " + "development of algorithms that can learn from and make predictions or " + "decisions based on data. Deep learning, a technique within machine " + "learning, involves neural networks with many layers. It has revolutionized " + "AI by enabling complex pattern recognition and data processing tasks. " + "Ethical considerations in AI include issues of bias in algorithms, privacy " + "concerns with data collection, and the impact of AI on jobs and society." +) + + +def _make_large_text(num_paragraphs=20): + """Create a synthetic large document with multiple paragraphs.""" + paragraphs = [] + for i in range(num_paragraphs): + paragraphs.append( + f"This is paragraph {i + 1} of the document. " + f"It contains several sentences about topic {i + 1}. " + f"The content is rich and varied for testing purposes. " + f"We want to ensure that chunking works correctly across " + f"paragraph boundaries and preserves context." 
+ ) + return "\n\n".join(paragraphs) + + +# =========================================================================== +# Constructor validation +# =========================================================================== + + +class TestTextProcessorInit: + + def test_default_parameters(self): + tp = TextProcessor() + assert tp.chunk_size == 1000 + assert tp.chunk_overlap == 200 + + def test_custom_parameters(self): + tp = TextProcessor(chunk_size=500, chunk_overlap=50) + assert tp.chunk_size == 500 + assert tp.chunk_overlap == 50 + + def test_invalid_chunk_size_zero(self): + with pytest.raises(ValueError, match="chunk_size must be a positive"): + TextProcessor(chunk_size=0) + + def test_invalid_chunk_size_negative(self): + with pytest.raises(ValueError, match="chunk_size must be a positive"): + TextProcessor(chunk_size=-10) + + def test_invalid_chunk_overlap_negative(self): + with pytest.raises(ValueError, match="chunk_overlap must be non-negative"): + TextProcessor(chunk_overlap=-1) + + def test_overlap_greater_than_size(self): + with pytest.raises(ValueError, match="chunk_overlap must be smaller"): + TextProcessor(chunk_size=100, chunk_overlap=100) + + def test_overlap_exceeds_size(self): + with pytest.raises(ValueError, match="chunk_overlap must be smaller"): + TextProcessor(chunk_size=100, chunk_overlap=150) + + +# =========================================================================== +# chunk_text — basic behaviour +# =========================================================================== + + +class TestChunkTextBasic: + + def test_short_text_single_chunk(self): + tp = TextProcessor(chunk_size=1000) + chunks = tp.chunk_text(SHORT_TEXT) + assert len(chunks) == 1 + assert chunks[0] == SHORT_TEXT + + def test_empty_string_returns_empty(self): + tp = TextProcessor() + assert tp.chunk_text("") == [] + + def test_whitespace_only_returns_empty(self): + tp = TextProcessor() + assert tp.chunk_text(" \n\n \t ") == [] + + def test_none_returns_empty(self): + 
tp = TextProcessor() + assert tp.chunk_text(None) == [] + + def test_single_character(self): + tp = TextProcessor() + chunks = tp.chunk_text("A") + assert len(chunks) == 1 + assert chunks[0] == "A" + + +# =========================================================================== +# chunk_text — paragraph splitting +# =========================================================================== + + +class TestChunkTextParagraphs: + + def test_paragraph_splitting(self): + """Text with \\n\\n delimiters should produce multiple chunks when + the full text exceeds chunk_size.""" + tp = TextProcessor(chunk_size=100, chunk_overlap=20) + chunks = tp.chunk_text(PARAGRAPH_TEXT) + assert len(chunks) > 1 + # Each chunk should be non-empty + for chunk in chunks: + assert len(chunk.strip()) > 0 + + def test_all_content_preserved(self): + """Joining all chunks should cover the original text content.""" + tp = TextProcessor(chunk_size=100, chunk_overlap=0) + chunks = tp.chunk_text(PARAGRAPH_TEXT) + joined = " ".join(chunks) + # Every sentence from the original should appear in the joined output + assert "Artificial intelligence" in joined + assert "Machine learning" in joined + assert "Deep learning" in joined + + +# =========================================================================== +# chunk_text — sentence splitting +# =========================================================================== + + +class TestChunkTextSentences: + + def test_sentence_splitting_for_long_paragraph(self): + """A long paragraph without \\n\\n should still split on '. 
'.""" + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + chunks = tp.chunk_text(LONG_PARAGRAPH) + assert len(chunks) > 1 + + def test_chunks_respect_size_limit(self): + """Each chunk should be at most chunk_size (with small tolerance + for edge cases with indivisible tokens).""" + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + chunks = tp.chunk_text(LONG_PARAGRAPH) + for chunk in chunks: + # Allow a small tolerance — the final chunk may overshoot slightly + # when a sentence cannot be split further. + assert len(chunk) <= 250, f"Chunk too large: {len(chunk)} chars" + + +# =========================================================================== +# chunk_text — overlap +# =========================================================================== + + +class TestChunkTextOverlap: + + def test_overlap_present(self): + """Adjacent chunks should share some overlapping text.""" + tp = TextProcessor(chunk_size=200, chunk_overlap=50) + chunks = tp.chunk_text(LONG_PARAGRAPH) + if len(chunks) < 2: + pytest.skip("Not enough chunks to test overlap") + + # Check at least one pair has overlap + found_overlap = False + for i in range(len(chunks) - 1): + # The tail of chunk i should appear at the start of chunk i+1 + tail = chunks[i][-50:] + if tail in chunks[i + 1]: + found_overlap = True + break + # Overlap may not always be exact substring match due to stripping, + # so also check for shared words + if not found_overlap: + words_a = set(chunks[0].split()) + words_b = set(chunks[1].split()) + assert len(words_a & words_b) > 0, "No overlap found between chunks" + + def test_zero_overlap(self): + """With overlap=0, chunks should have minimal shared content.""" + tp = TextProcessor(chunk_size=200, chunk_overlap=0) + chunks = tp.chunk_text(LONG_PARAGRAPH) + assert len(chunks) > 1 + + +# =========================================================================== +# chunk_text — large documents +# 
=========================================================================== + + +class TestChunkTextLargeDoc: + + def test_large_document_produces_many_chunks(self): + text = _make_large_text(30) + tp = TextProcessor(chunk_size=500, chunk_overlap=100) + chunks = tp.chunk_text(text) + assert len(chunks) >= 5 + + def test_no_empty_chunks(self): + text = _make_large_text(20) + tp = TextProcessor(chunk_size=300, chunk_overlap=50) + chunks = tp.chunk_text(text) + for chunk in chunks: + assert len(chunk.strip()) > 0 + + def test_all_paragraphs_represented(self): + """Every paragraph's content should appear in at least one chunk.""" + text = _make_large_text(10) + tp = TextProcessor(chunk_size=500, chunk_overlap=100) + chunks = tp.chunk_text(text) + joined = " ".join(chunks) + for i in range(1, 11): + assert f"paragraph {i}" in joined + + +# =========================================================================== +# chunk_text — custom parameters per call +# =========================================================================== + + +class TestChunkTextCustomParams: + + def test_override_chunk_size(self): + tp = TextProcessor(chunk_size=1000) + chunks_default = tp.chunk_text(LONG_PARAGRAPH) + chunks_small = tp.chunk_text(LONG_PARAGRAPH, chunk_size=100) + assert len(chunks_small) > len(chunks_default) + + def test_override_chunk_overlap(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=0) + chunks_no_overlap = tp.chunk_text(LONG_PARAGRAPH) + chunks_overlap = tp.chunk_text(LONG_PARAGRAPH, chunk_overlap=50) + # With overlap, we should get at least as many chunks + assert len(chunks_overlap) >= len(chunks_no_overlap) + + +# =========================================================================== +# chunk_document — metadata +# =========================================================================== + + +class TestChunkDocument: + + def test_returns_list_of_dicts(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + docs = 
tp.chunk_document(LONG_PARAGRAPH, source_type="pdf") + assert isinstance(docs, list) + assert all(isinstance(d, dict) for d in docs) + + def test_dict_keys(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + docs = tp.chunk_document(LONG_PARAGRAPH) + for doc in docs: + assert "index" in doc + assert "text" in doc + assert "char_count" in doc + assert "preview" in doc + assert "source_type" in doc + + def test_index_sequential(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + docs = tp.chunk_document(LONG_PARAGRAPH) + for i, doc in enumerate(docs): + assert doc["index"] == i + + def test_char_count_correct(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + docs = tp.chunk_document(LONG_PARAGRAPH) + for doc in docs: + assert doc["char_count"] == len(doc["text"]) + + def test_preview_max_length(self): + tp = TextProcessor(chunk_size=200, chunk_overlap=30) + docs = tp.chunk_document(LONG_PARAGRAPH) + for doc in docs: + assert len(doc["preview"]) <= 100 + + def test_source_type_passed_through(self): + tp = TextProcessor() + docs = tp.chunk_document(SHORT_TEXT, source_type="pdf") + assert docs[0]["source_type"] == "pdf" + + def test_default_source_type(self): + tp = TextProcessor() + docs = tp.chunk_document(SHORT_TEXT) + assert docs[0]["source_type"] == "unknown" + + def test_empty_input_returns_empty_list(self): + tp = TextProcessor() + assert tp.chunk_document("") == [] + assert tp.chunk_document(" ") == [] + + def test_short_text_one_chunk(self): + tp = TextProcessor(chunk_size=1000) + docs = tp.chunk_document(SHORT_TEXT) + assert len(docs) == 1 + assert docs[0]["text"] == SHORT_TEXT + assert docs[0]["index"] == 0 + + +# =========================================================================== +# Edge cases +# =========================================================================== + + +class TestEdgeCases: + + def test_text_exactly_chunk_size(self): + """Text exactly equal to chunk_size should produce one chunk.""" + tp = 
TextProcessor(chunk_size=50, chunk_overlap=10) + text = "A" * 50 + chunks = tp.chunk_text(text) + assert len(chunks) == 1 + + def test_text_one_char_over_chunk_size(self): + """Text slightly over chunk_size (55 chars vs. 50) should still + be chunked without error at the boundary.""" + tp = TextProcessor(chunk_size=50, chunk_overlap=10) + text = "word " * 11 # 55 chars + chunks = tp.chunk_text(text) + assert len(chunks) >= 1 + + def test_repeated_separators(self): + """Text with many consecutive newlines should not produce empty chunks.""" + tp = TextProcessor(chunk_size=100, chunk_overlap=10) + text = "Hello\n\n\n\n\n\nWorld\n\n\n\nFoo" + chunks = tp.chunk_text(text) + for chunk in chunks: + assert len(chunk.strip()) > 0 + + def test_no_natural_boundaries(self): + """A single long string with no spaces should still be split.""" + tp = TextProcessor(chunk_size=50, chunk_overlap=10) + text = "a" * 200 + chunks = tp.chunk_text(text) + assert len(chunks) >= 2 + + def test_only_newlines(self): + tp = TextProcessor() + assert tp.chunk_text("\n\n\n\n") == [] + + def test_unicode_text(self): + """Unicode content should be handled correctly.""" + tp = TextProcessor(chunk_size=100, chunk_overlap=10) + text = "人工智能是计算机科学的一个分支。" * 10 + chunks = tp.chunk_text(text) + assert len(chunks) >= 1 + joined = "".join(chunks) + assert "人工智能" in joined diff --git a/backend/test_wikipedia_fallback.py b/backend/test_wikipedia_fallback.py new file mode 100644 index 00000000..94a9373b --- /dev/null +++ b/backend/test_wikipedia_fallback.py @@ -0,0 +1,282 @@ +"""Tests for graceful Wikipedia / MediaWiki fallback (issue #428). + +When ``use_mediawiki=1`` is passed but the MediaWiki API call fails +(SSL error, timeout, network unreachable, etc.), every endpoint should: + + 1. Still return HTTP 200 with valid quiz data. + 2. Include a ``warning`` key in the JSON response. + 3. Generate questions from the *original* input text (not crash). + +Uses **pytest** with Flask's built-in test client. 
All heavy ML models +are mocked in ``conftest.py`` – no running server or GPU required. +""" + +from unittest.mock import patch, MagicMock + +import pytest + +# --------------------------------------------------------------------------- +# Shared test data +# --------------------------------------------------------------------------- + +SAMPLE_TEXT = ( + "Artificial intelligence (AI) is the simulation of human intelligence " + "processes by machines, especially computer systems. These processes " + "include learning, reasoning, and self-correction. AI applications " + "include speech recognition, natural language processing, machine " + "vision, expert systems, and robotics. Machine learning is a subset " + "of AI that focuses on algorithms that learn from data." +) + + +# =========================================================================== +# SSL / network errors produce a warning, not a 500 +# =========================================================================== + + +class TestWikipediaFallbackBasicEndpoints: + """Test the three primary question-generation endpoints.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_ssl_error_returns_200_with_warning(self, client, endpoint): + """SSLError in MediaWiki should NOT crash the request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = Exception( + "SSLError(SSLEOFError(8, 'EOF occurred in violation of protocol'))" + ) + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + assert "Wikipedia" in data["warning"] or "network" in data["warning"] + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_connection_error_returns_200_with_warning(self, client, endpoint): + """ConnectionError in MediaWiki should NOT crash the 
request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = ConnectionError("Network unreachable") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_timeout_returns_200_with_warning(self, client, endpoint): + """TimeoutError in MediaWiki should NOT crash the request.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = TimeoutError("Connection timed out") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +class TestWikipediaFallbackHardEndpoints: + """Test the three hard-question endpoints.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_shortq_hard", "/get_mcq_hard", "/get_boolq_hard", + ]) + def test_connection_error_hard_endpoints_returns_200_with_warning(self, client, endpoint): + """Connection failure should NOT crash hard-question endpoints.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = ConnectionError("Network unreachable") + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +class TestWikipediaFallbackProblems: + """Test the combined /get_problems endpoint.""" + + def test_ssl_error_get_problems_returns_200_with_warning(self, client): + """SSL failure should NOT crash the combined /get_problems endpoint.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = TimeoutError("Connection timed out") + resp = client.post( + "/get_problems", + 
json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output_mcq" in data + assert "output_boolq" in data + assert "output_shortq" in data + assert "warning" in data + + +# =========================================================================== +# Successful MediaWiki calls should NOT have a warning +# =========================================================================== + + +class TestWikipediaSuccessNoWarning: + """When MediaWiki succeeds, no warning should be present.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_successful_wiki_call_has_no_warning(self, client, endpoint): + """Successful MediaWiki call should NOT include a warning.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" not in data + + def test_successful_wiki_get_problems_no_warning(self, client): + """Successful MediaWiki call on /get_problems should NOT include a warning.""" + resp = client.post( + "/get_problems", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "warning" not in data + + +# =========================================================================== +# Without use_mediawiki, no warning should appear +# =========================================================================== + + +class TestNoMediawikiNoWarning: + """When use_mediawiki is not set (or 0), no warning should appear.""" + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_no_mediawiki_no_warning(self, client, endpoint): + """Without use_mediawiki=1, no warning should appear.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 0}, + ) + assert resp.status_code == 200 + 
data = resp.get_json() + assert "warning" not in data + + @pytest.mark.parametrize("endpoint", [ + "/get_mcq", "/get_boolq", "/get_shortq", + ]) + def test_default_no_mediawiki_no_warning(self, client, endpoint): + """When use_mediawiki is not sent, no warning should appear.""" + resp = client.post( + endpoint, + json={"input_text": SAMPLE_TEXT}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "warning" not in data + + +# =========================================================================== +# Various exception types are all caught +# =========================================================================== + + +class TestVariousExceptionTypes: + """All exception types from MediaWiki should be caught and handled.""" + + @pytest.mark.parametrize("exc_class,exc_msg", [ + (ConnectionError, "Network unreachable"), + (TimeoutError, "Connection timed out"), + (OSError, "SSL: CERTIFICATE_VERIFY_FAILED"), + (RuntimeError, "Unexpected mediawiki error"), + (Exception, "SSLEOFError EOF occurred in violation of protocol"), + ]) + def test_various_exceptions_handled_on_mcq(self, client, exc_class, exc_msg): + """All exception types from MediaWiki should be caught on /get_mcq.""" + with patch("server.mediawikiapi") as wiki_mock: + wiki_mock.summary.side_effect = exc_class(exc_msg) + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + data = resp.get_json() + assert "output" in data + assert "warning" in data + + +# =========================================================================== +# Original text is preserved when MediaWiki fails +# =========================================================================== + + +class TestTextPreservation: + """Verify the correct text is forwarded to the generator.""" + + def test_original_text_used_on_failure(self, client): + """When MediaWiki fails, generation should use the original text.""" + with patch("server.mediawikiapi") 
as wiki_mock, \ + patch("server.MCQGen") as mcq_mock: + wiki_mock.summary.side_effect = ConnectionError("fail") + mcq_mock.generate_mcq.return_value = { + "questions": [{ + "question_statement": "Test?", + "answer": "Yes", + "id": 1, + "options": ["No", "Maybe"], + "extra_options": [], + "context": "ctx", + }] + } + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + # Verify generate_mcq was called with the *original* text + call_args = mcq_mock.generate_mcq.call_args + assert call_args[0][0]["input_text"] == SAMPLE_TEXT + + def test_enriched_text_used_on_success(self, client): + """When MediaWiki succeeds, enriched text is forwarded to generation.""" + enriched = "Enriched text from Wikipedia about AI and machine learning." + with patch("server.mediawikiapi") as wiki_mock, \ + patch("server.MCQGen") as mcq_mock: + wiki_mock.summary.return_value = enriched + mcq_mock.generate_mcq.return_value = { + "questions": [{ + "question_statement": "Test?", + "answer": "Yes", + "id": 1, + "options": ["No", "Maybe"], + "extra_options": [], + "context": "ctx", + }] + } + resp = client.post( + "/get_mcq", + json={"input_text": SAMPLE_TEXT, "use_mediawiki": 1}, + ) + assert resp.status_code == 200 + call_args = mcq_mock.generate_mcq.call_args + assert call_args[0][0]["input_text"] == enriched diff --git a/backend/utils/__init__.py b/backend/utils/__init__.py new file mode 100644 index 00000000..f402344b --- /dev/null +++ b/backend/utils/__init__.py @@ -0,0 +1 @@ +# Utils package for EduAid backend diff --git a/backend/utils/text_processor.py b/backend/utils/text_processor.py new file mode 100644 index 00000000..51e7c4d7 --- /dev/null +++ b/backend/utils/text_processor.py @@ -0,0 +1,226 @@ +"""Text chunking utilities for RAG preparation. + +Splits large documents into manageable "Context Blocks" before they reach +the question-generation pipeline. 
This prevents OOM errors and pipeline +hangs when processing large PDFs (50+ pages). + +The splitting strategy mirrors LangChain's ``RecursiveCharacterTextSplitter`` +but is implemented from scratch with **zero external dependencies** — only +Python's standard library is used. +""" + +from __future__ import annotations + +import logging +from typing import Dict, List, Optional + +logger = logging.getLogger(__name__) + +# Default separators ordered from coarsest to finest granularity. +DEFAULT_SEPARATORS: List[str] = ["\n\n", "\n", ". ", " ", ""] + + +class TextProcessor: + """Recursively splits text into overlapping chunks on natural boundaries. + + Parameters + ---------- + chunk_size : int + Target maximum number of characters per chunk (default 1000, + roughly 200–250 tokens for English text). + chunk_overlap : int + Number of characters that adjacent chunks share so that no + context is lost at boundaries (default 200). + separators : list[str] | None + Ordered list of boundary strings to try when splitting. + Falls back to character-level splitting when none match. + """ + + def __init__( + self, + chunk_size: int = 1000, + chunk_overlap: int = 200, + separators: Optional[List[str]] = None, + ) -> None: + if chunk_size <= 0: + raise ValueError("chunk_size must be a positive integer") + if chunk_overlap < 0: + raise ValueError("chunk_overlap must be non-negative") + if chunk_overlap >= chunk_size: + raise ValueError("chunk_overlap must be smaller than chunk_size") + + self.chunk_size = chunk_size + self.chunk_overlap = chunk_overlap + self.separators = separators or DEFAULT_SEPARATORS + + # ------------------------------------------------------------------ + # Public API + # ------------------------------------------------------------------ + + def chunk_text( + self, + text: str, + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + ) -> List[str]: + """Split *text* into a list of overlapping string chunks. 
+ + Parameters + ---------- + text : str + The input document text. + chunk_size : int | None + Override the instance-level chunk size for this call. + chunk_overlap : int | None + Override the instance-level chunk overlap for this call. + + Returns + ------- + list[str] + Ordered list of text chunks. Each chunk is at most + *chunk_size* characters long (except when a single + indivisible token exceeds that limit). + """ + if not text or not text.strip(): + return [] + + size = chunk_size if chunk_size is not None else self.chunk_size + overlap = chunk_overlap if chunk_overlap is not None else self.chunk_overlap + + return self._recursive_split(text.strip(), self.separators, size, overlap) + + def chunk_document( + self, + text: str, + source_type: str = "unknown", + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + ) -> List[Dict]: + """Split *text* and return chunk metadata dicts. + + Each dict contains: + - ``index`` – zero-based chunk position + - ``text`` – the chunk content + - ``char_count`` – length of the chunk in characters + - ``preview`` – first 100 characters (for logging / UI) + - ``source_type``– provenance hint (``"pdf"``, ``"docx"``, …) + + Returns an empty list for blank / whitespace-only input. 
+ """ + raw_chunks = self.chunk_text(text, chunk_size, chunk_overlap) + + chunks: List[Dict] = [] + for idx, chunk in enumerate(raw_chunks): + chunks.append( + { + "index": idx, + "text": chunk, + "char_count": len(chunk), + "preview": chunk[:100].replace("\n", " "), + "source_type": source_type, + } + ) + + if chunks: + logger.info( + "Chunked %s document: %d chars → %d chunks (size=%d, overlap=%d)", + source_type, + len(text), + len(chunks), + chunk_size or self.chunk_size, + chunk_overlap or self.chunk_overlap, + ) + + return chunks + + # ------------------------------------------------------------------ + # Internal helpers + # ------------------------------------------------------------------ + + def _recursive_split( + self, + text: str, + separators: List[str], + chunk_size: int, + chunk_overlap: int, + ) -> List[str]: + """Core recursive splitting algorithm.""" + + # Base case: text already fits in one chunk. + if len(text) <= chunk_size: + return [text] if text.strip() else [] + + # Pick the best separator — the first one that actually appears. + separator = "" + remaining_separators = [] + for i, sep in enumerate(separators): + if sep == "": + separator = sep + remaining_separators = [] + break + if sep in text: + separator = sep + remaining_separators = separators[i + 1 :] + break + + # Split text on the chosen separator. + if separator: + pieces = text.split(separator) + else: + # Character-level fallback (separator == ""). + pieces = list(text) + + # Merge small pieces back together up to chunk_size. + chunks: List[str] = [] + current_chunk: List[str] = [] + current_length = 0 + + for piece in pieces: + piece_len = len(piece) + (len(separator) if current_chunk else 0) + + if current_length + piece_len > chunk_size and current_chunk: + # Flush current chunk. 
+ merged = separator.join(current_chunk) + if merged.strip(): + chunks.append(merged.strip()) + + # Keep overlap: walk backwards through pieces to retain + # approximately `chunk_overlap` characters. + overlap_chunks: List[str] = [] + overlap_len = 0 + for prev_piece in reversed(current_chunk): + if overlap_len + len(prev_piece) > chunk_overlap: + break + overlap_chunks.insert(0, prev_piece) + overlap_len += len(prev_piece) + len(separator) + + current_chunk = overlap_chunks + current_length = sum(len(p) for p in current_chunk) + max( + 0, (len(current_chunk) - 1) + ) * len(separator) + + current_chunk.append(piece) + current_length += piece_len + + # Flush remaining. + if current_chunk: + merged = separator.join(current_chunk) + if merged.strip(): + chunks.append(merged.strip()) + + # If any chunk is still too large and we have finer separators, + # recurse with the next separator level. + if remaining_separators: + final_chunks: List[str] = [] + for chunk in chunks: + if len(chunk) > chunk_size: + final_chunks.extend( + self._recursive_split( + chunk, remaining_separators, chunk_size, chunk_overlap + ) + ) + else: + final_chunks.append(chunk) + return final_chunks + + return chunks diff --git a/eduaid_web/src/pages/Text_Input.jsx b/eduaid_web/src/pages/Text_Input.jsx index e341d331..7f96915f 100644 --- a/eduaid_web/src/pages/Text_Input.jsx +++ b/eduaid_web/src/pages/Text_Input.jsx @@ -5,7 +5,7 @@ import stars from "../assets/stars.png"; import cloud from "../assets/cloud.png"; import { FaClipboard } from "react-icons/fa"; import Switch from "react-switch"; -import { Link,useNavigate } from "react-router-dom"; +import { Link, useNavigate } from "react-router-dom"; import apiClient from "../utils/apiClient"; const Text_Input = () => { @@ -185,9 +185,9 @@ const Text_Input = () => { {/* File Upload Section */}
cloud -

Choose a file (PDF, MP3 supported)

+

Choose a file (PDF, PPTX, TXT, DOCX, MP3 supported)

- +
- + > + +
diff --git a/requirements.txt b/requirements.txt index 717390b7..77bd1f29 100644 --- a/requirements.txt +++ b/requirements.txt @@ -29,6 +29,7 @@ google-auth datasets==3.1.0 tokenizers mammoth +python-pptx mediawikiapi PyMuPDF textblob \ No newline at end of file