diff --git a/.claude/commands/fix-issue.md b/.claude/commands/fix-issue.md new file mode 100644 index 0000000..7f93749 --- /dev/null +++ b/.claude/commands/fix-issue.md @@ -0,0 +1,20 @@ +# Fix GitHub Issue + +Fix GitHub issue #$ARGUMENTS + +## Workflow + +1. **Fetch issue details** using `gh issue view $ARGUMENTS` +2. **Understand the root cause** by reading relevant code +3. **Create a plan** for the fix using extended thinking +4. **Implement the fix** with minimal changes +5. **Write or update tests** to cover the fix +6. **Run tests** to verify the fix works +7. **Create a commit** with message "Fixes #$ARGUMENTS: " + +## Guidelines + +- Focus on the specific issue - avoid unrelated changes +- Follow existing code patterns in the repository +- Ensure backward compatibility unless explicitly requested +- Update documentation if behavior changes diff --git a/.claude/commands/review-pr.md b/.claude/commands/review-pr.md new file mode 100644 index 0000000..173e6e7 --- /dev/null +++ b/.claude/commands/review-pr.md @@ -0,0 +1,34 @@ +# Review Pull Request + +Review PR #$ARGUMENTS + +## Workflow + +1. **Fetch PR details** using `gh pr view $ARGUMENTS` +2. **Get the diff** using `gh pr diff $ARGUMENTS` +3. **Understand the changes** - what problem does this PR solve? +4. **Review code quality**: + - Check for bugs or logic errors + - Verify error handling + - Look for security issues (SQL injection, XSS, etc.) + - Check naming conventions and code clarity +5. **Verify test coverage** - are changes adequately tested? +6. **Check for breaking changes** - is backward compatibility maintained? +7. 
**Provide constructive feedback** with specific suggestions + +## Review Checklist + +- [ ] Code follows project conventions +- [ ] No obvious security vulnerabilities +- [ ] Error cases are handled appropriately +- [ ] Tests cover the changes +- [ ] Documentation updated if needed +- [ ] No unrelated changes included + +## Output Format + +Provide a summary with: +- Overall assessment (approve/request changes/comment) +- Specific issues found (with line references) +- Suggestions for improvement +- Questions for the author diff --git a/.claude/settings.local.json b/.claude/settings.local.json new file mode 100644 index 0000000..b9ce24e --- /dev/null +++ b/.claude/settings.local.json @@ -0,0 +1,19 @@ +{ + "permissions": { + "allow": [ + "Bash(git add:*)", + "Bash(git commit:*)", + "Bash(git push:*)", + "Bash(git remote add:*)", + "Bash(where:*)", + "Bash(/c/Users/brunstof/AppData/Local/gh/bin/gh.exe:*)", + "mcp__github__fork_repository", + "mcp__github__create_pull_request", + "mcp__github__get_pull_request", + "mcp__MCP_DOCKER__mcp-find", + "mcp__MCP_DOCKER__mcp-config-set", + "mcp__MCP_DOCKER__mcp-add", + "mcp__MCP_DOCKER__browser_navigate" + ] + } +} diff --git a/.claude/skills/mariadb-debug/SKILL.md b/.claude/skills/mariadb-debug/SKILL.md new file mode 100644 index 0000000..fbf4789 --- /dev/null +++ b/.claude/skills/mariadb-debug/SKILL.md @@ -0,0 +1,119 @@ +--- +name: mariadb-debug +description: Debug MariaDB MCP server issues, analyze connection pool problems, troubleshoot embedding service failures, diagnose vector store operations. Use when working with database connectivity, embedding errors, or MCP tool failures. +--- + +# MariaDB MCP Server Debugging + +## Key Files to Check + +1. **src/server.py** - Main MCP server and tool definitions + - Connection pool initialization (`initialize_pool`) + - Tool registration (`register_tools`) + - Query execution (`_execute_query`) + +2. 
**src/config.py** - Configuration loading + - Environment variables validation + - Logging setup + - Embedding provider configuration + +3. **src/embeddings.py** - Embedding service + - Provider initialization (OpenAI, Gemini, HuggingFace) + - Model dimension lookup + - Embedding generation + +4. **logs/mcp_server.log** - Server logs + +## Common Issues & Solutions + +### Connection Pool Exhaustion +- **Symptom**: "Database connection pool not available" +- **Check**: `MCP_MAX_POOL_SIZE` in .env (default: 10) +- **Solution**: Increase pool size or check for connection leaks + +### Embedding Service Failures +- **Symptom**: "Embedding provider not configured" or API errors +- **Check**: `EMBEDDING_PROVIDER` must be: openai, gemini, or huggingface +- **Verify**: Corresponding API key is set (OPENAI_API_KEY, GEMINI_API_KEY, or HF_MODEL) + +### Read-Only Mode Violations +- **Symptom**: "Operation forbidden: Server is in read-only mode" +- **Check**: `MCP_READ_ONLY` environment variable +- **Note**: Only SELECT, SHOW, DESCRIBE queries allowed when true + +### Vector Store Creation Fails +- **Symptom**: "Failed to create vector store" +- **Check**: + - Database exists and user has CREATE TABLE permission + - Embedding dimension matches model (e.g., text-embedding-3-small = 1536) + - MariaDB version supports VECTOR type + +### Tool Not Registered +- **Symptom**: Tool not found errors +- **Check**: EMBEDDING_PROVIDER must be set for vector tools +- **Verify**: Pool initialized before tool registration + +### Connection Timeout +- **Symptom**: Queries hang or timeout errors +- **Check**: `DB_CONNECT_TIMEOUT`, `DB_READ_TIMEOUT`, `DB_WRITE_TIMEOUT` in .env +- **Defaults**: 10s connect, 30s read/write +- **Solution**: Increase timeout values or check database server load + +### Large Result Sets +- **Symptom**: Memory errors or slow responses +- **Check**: `MCP_MAX_RESULTS` in .env (default: 10000) +- **Solution**: Add LIMIT to queries or reduce MCP_MAX_RESULTS + +### 
Embedding Rate Limiting +- **Symptom**: API quota exceeded or 429 errors +- **Check**: `EMBEDDING_MAX_CONCURRENT` in .env (default: 5) +- **Solution**: Reduce concurrent limit or upgrade API plan + +### Health Check Failures (Docker) +- **Symptom**: Container marked unhealthy +- **Check**: `/health` endpoint returns 503 +- **Verify**: Database connection pool is initialized +- **Solution**: Check DB credentials and network connectivity + +### Multiple Database Config Mismatch +- **Symptom**: Warning about array length mismatch +- **Check**: `DB_HOSTS`, `DB_USERS`, `DB_PASSWORDS` must have same length +- **Solution**: Ensure comma-separated values align across all multi-DB env vars + +### Metadata JSON Parse Errors +- **Symptom**: Warning logs about metadata parsing +- **Check**: `logs/mcp_server.log` for JSON decode errors +- **Solution**: Verify metadata stored correctly in vector store + +## Debugging Commands + +```bash +# Check server logs +tail -f logs/mcp_server.log + +# Test database connection +uv run python -c "from config import *; print(f'DB: {DB_HOST}:{DB_PORT}')" + +# Verify environment +uv run python -c "from config import *; print(f'Provider: {EMBEDDING_PROVIDER}')" + +# Run tests +uv run -m pytest src/tests/ -v +``` + +## Environment Variables Reference + +| Variable | Required | Default | Description | +|----------|----------|---------|-------------| +| DB_HOST | Yes | localhost | MariaDB host | +| DB_PORT | No | 3306 | MariaDB port | +| DB_USER | Yes | - | Database username | +| DB_PASSWORD | Yes | - | Database password | +| DB_CONNECT_TIMEOUT | No | 10 | Connection timeout (seconds) | +| DB_READ_TIMEOUT | No | 30 | Read timeout (seconds) | +| DB_WRITE_TIMEOUT | No | 30 | Write timeout (seconds) | +| MCP_READ_ONLY | No | true | Enforce read-only | +| MCP_MAX_POOL_SIZE | No | 10 | Max connections in pool | +| MCP_MAX_RESULTS | No | 10000 | Max rows per query | +| EMBEDDING_PROVIDER | No | None | openai/gemini/huggingface | +| 
EMBEDDING_MAX_CONCURRENT | No | 5 | Max concurrent embedding calls | diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..44bd13d --- /dev/null +++ b/.dockerignore @@ -0,0 +1,26 @@ +# Virtual environments +.venv/ +venv/ +__pycache__/ + +# Git +.git/ +.gitignore + +# IDE +.vscode/ +.idea/ +.claude/ + +# Logs and caches +logs/ +*.log +*.pyc +*.pyo + +# Downloaded data +scripts/geography_data/ + +# Local env files (use .env in container) +.env.local +.env.*.local diff --git a/.github/copilot-instructions.md b/.github/copilot-instructions.md new file mode 100644 index 0000000..8229701 --- /dev/null +++ b/.github/copilot-instructions.md @@ -0,0 +1,78 @@ +## Quick orientation for AI coding agents + +This repository implements a Model Context Protocol (MCP) server that exposes MariaDB-focused tools and optional vector/embedding features. + +- Entry points & important files: + - `src/server.py` — main MCP server implementation and tool definitions (list_databases, list_tables, execute_sql, vector-store tools, etc.). Read this first to understand available tools and their contracts. + - `src/embeddings.py` — provider-agnostic EmbeddingService (OpenAI, Gemini, HuggingFace). Embedding clients are initialized at runtime based on env config. + - `src/config.py` — loads `.env` and environment variables; contains defaults (notably `MCP_READ_ONLY` default=true) and validation that can raise on missing keys. + - `src/tests/` — integration-style tests that demonstrate how the server is started and how the FastMCP client calls tools. Useful runnable examples. + - `README.md` — installation, run commands and example tool payloads (useful to replicate CLI behavior). + +## Big-picture architecture (short) + +- FastMCP-based server: `MariaDBServer` builds a `FastMCP` instance and registers tools. Tools are asynchronous methods on `MariaDBServer`. +- Database access: Uses `asyncmy` connection pool. 
Pool is created by `MariaDBServer.initialize_pool()` and used by `_execute_query()` for all SQL operations. +- Embeddings: Optional feature toggled by `EMBEDDING_PROVIDER` in env. `EmbeddingService` supports OpenAI, Gemini, and HuggingFace. When disabled, all vector-store tools should be treated as unavailable. +- Vector-store implementation: persisted in MariaDB tables (VECTOR column + VECTOR INDEX). The server uses information_schema queries to validate existence and structure of vector stores. + +Why certain choices matter for edits: +- `config.py` reads env at import time and will raise if required embedding keys are missing — set env before importing modules in tests or scripts. +- `MCP_READ_ONLY` influences `self.autocommit` and `_execute_query` enforcement: code blocks non-read-only queries when read-only mode is enabled. + +## Developer workflows and concrete commands + +- Python version: 3.11 (see `pyproject.toml`). +- Dependency & environment setup (as in README): + - Install `uv` and sync dependencies: + ```bash + pip install uv + uv lock + uv sync + ``` +- Run server (examples shown in README): + - stdio (default): `uv run server.py` + - SSE transport: `uv run server.py --transport sse --host 127.0.0.1 --port 9001` + - HTTP transport: `uv run server.py --transport http --host 127.0.0.1 --port 9001 --path /mcp` +- Tests: tests live in `src/tests/` and use `unittest.IsolatedAsyncioTestCase` with `anyio` and `fastmcp.client.Client`. They start the server in-process by calling `MariaDBServer.run_async_server('stdio')` and then call tools through `Client(self.server.mcp)`. Run them with your preferred runner, e.g.: + ```bash + # With unittest discovery + python -m unittest discover -s src/tests + ``` + +## Project-specific patterns & gotchas for agents + +- Environment-on-import: `config.py` performs validation and logs/raises if required env vars are not set (e.g., DB_USER/DB_PASSWORD, provider-specific API keys). 
Always ensure env is configured before importing modules. +- Read-only enforcement: `_execute_query()` strips comments and checks an allow-list of SQL prefixes (`SELECT`, `SHOW`, `DESC`, `DESCRIBE`, `USE`). Any mutation must either run with `MCP_READ_ONLY=false` or be explicitly implemented as a server tool that bypasses that check (not recommended). +- Validation via information_schema: many tools check existence and vector-store status using `information_schema` queries — prefer reproducing those queries when writing migrations or tools that manipulate schema. +- Embedding service lifecycle: `EmbeddingService` may try to import provider SDKs on init and raise if missing; tests and CI should supply the right env or mock the service. + +## Integration & external dependencies + +- DB: MariaDB reachable via `DB_HOST`, `DB_PORT`, `DB_USER`, `DB_PASSWORD`. `DB_NAME` is optional; many tools accept `database_name` parameter. +- Embedding providers: + - `openai` (requires `OPENAI_API_KEY`) — uses `openai` AsyncOpenAI client when available. + - `gemini` (requires `GEMINI_API_KEY`) — uses `google.genai` when present. + - `huggingface` (requires `HF_MODEL`) — uses `sentence-transformers` and may dynamically load models. +- Logs: default file at `logs/mcp_server.log` (configurable via env). Use this for debugging server startup or tool call failures. + +## Examples extracted from the codebase + +- How tests start the server (see `src/tests/test_mcp_server.py`): + - Instantiate server: `server = MariaDBServer(autocommit=False)` + - Start background server task: `tg.start_soon(server.run_async_server, 'stdio')` + - Create client: `client = fastmcp.client.Client(server.mcp)` and call `await client.call_tool('list_databases', {})`. + +- Tool payload example (from README): + ```json + {"tool":"execute_sql","parameters":{"database_name":"test_db","sql_query":"SELECT * FROM users WHERE id = %s","parameters":[123]}} + ``` + +## Short checklist for code changes + +1. 
Ensure required env vars are set before imports (or mock config/EmbeddingService in tests). +2. If adding SQL tools, follow `_execute_query()`'s comment-stripping + prefix checks; avoid enabling writes unless intended. +3. If changing embedding behavior, reference `src/embeddings.py` model lists and `pyproject.toml` dependencies — CI must install required SDKs. +4. Run unit/integration tests in `src/tests/` using unittest discovery or pytest if present. + +If anything in this document is unclear or you'd like more concrete examples (unit test scaffolds, CI matrix entries, or mock patterns for `EmbeddingService`), tell me which section to expand and I'll iterate. diff --git a/.gitignore b/.gitignore index 5bb1bcd..d4dcaa5 100644 --- a/.gitignore +++ b/.gitignore @@ -9,5 +9,6 @@ src/logs/* .env uv.lock .DS_Store -.env -.env + +# Downloaded data files +scripts/geography_data/ diff --git a/.python-version b/.python-version index 2c07333..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.13 diff --git a/CLAUDE.md b/CLAUDE.md new file mode 100644 index 0000000..e5d18ea --- /dev/null +++ b/CLAUDE.md @@ -0,0 +1,134 @@ +# CLAUDE.md + +This file provides guidance to Claude Code (claude.ai/code) when working with code in this repository. + +## Project Overview + +MariaDB MCP Server - A Model Context Protocol (MCP) server providing an interface for AI assistants to interact with MariaDB databases. Supports standard SQL operations and optional vector/embedding-based semantic search. 
+ +**Requirements:** Python 3.13+, MariaDB 11.7+ (for vector store features with `UUID_v7()`) + +## Development Commands + +```bash +# Setup (uses uv package manager) +pip install uv && uv sync + +# Run server (stdio - default for MCP clients) +uv run src/server.py + +# Run server with SSE/HTTP transport +uv run src/server.py --transport sse --host 127.0.0.1 --port 9001 +uv run src/server.py --transport http --host 127.0.0.1 --port 9001 --path /mcp + +# Run tests (requires live MariaDB with configured .env) +uv run -m pytest src/tests/ -v +uv run -m pytest src/tests/test_mcp_server.py::TestMariaDBMCPTools::test_list_databases + +# Docker +docker compose up --build +docker compose logs -f mariadb-mcp + +# Check server logs +tail -f logs/mcp_server.log +``` + +## Architecture + +### Core Components + +- **`src/server.py`**: `MariaDBServer` class using FastMCP. Contains all MCP tool definitions, connection pool management, and query execution. Entry point via `anyio.run()` with `functools.partial`. + +- **`src/config.py`**: Loads environment/.env configuration at import time. Sets up logging (console + rotating file at `logs/mcp_server.log`). Validates credentials and embedding provider, raising `ValueError` if required keys are missing. + +- **`src/embeddings.py`**: `EmbeddingService` class supporting OpenAI, Gemini, and HuggingFace providers. HuggingFace models are pre-loaded at init; Gemini uses `asyncio.to_thread()` since SDK lacks async. + +### Key Design Patterns + +1. **Connection Pooling**: Uses `asyncmy` pool. Supports multiple databases via comma-separated env vars: + - `DB_HOSTS`, `DB_PORTS`, `DB_USERS`, `DB_PASSWORDS`, `DB_NAMES`, `DB_CHARSETS` + - First connection becomes default pool; others stored in `self.pools` dict keyed by `host:port` + +2. **Read-Only Mode**: `MCP_READ_ONLY=true` (default) allows only SELECT/SHOW/DESCRIBE/USE. 
SQL comments (`--` and `/* */`) are stripped via regex in `_execute_query()` before checking query prefix to prevent bypass attempts. + +3. **Conditional Tool Registration**: Vector store tools only registered when `EMBEDDING_PROVIDER` is set. Check in `register_tools()` method (`if EMBEDDING_PROVIDER is not None`). + +4. **Singleton EmbeddingService**: Created at module load only when `EMBEDDING_PROVIDER` is configured. Used by all vector store tools. Embedding dimensions vary by model: OpenAI text-embedding-3-small=1536, large=3072; Gemini=768; HuggingFace varies by model (e.g., BGE-M3=1024). + +5. **Middleware Stack**: HTTP/SSE transports use Starlette middleware for CORS (`ALLOWED_ORIGINS`) and trusted host filtering (`ALLOWED_HOSTS`). + +### MCP Tools + +**Standard:** `list_databases`, `list_tables`, `get_table_schema`, `get_table_schema_with_relations`, `execute_sql`, `create_database` + +**Vector Store (requires EMBEDDING_PROVIDER):** `create_vector_store`, `delete_vector_store`, `list_vector_stores`, `insert_docs_vector_store`, `search_vector_store` + +### Vector Store Table Schema + +```sql +-- Requires MariaDB 11.7+ for UUID_v7() and VECTOR type +CREATE TABLE vector_store_name ( + id VARCHAR(36) NOT NULL DEFAULT UUID_v7() PRIMARY KEY, + document TEXT NOT NULL, + embedding VECTOR(dimension) NOT NULL, + metadata JSON NOT NULL, + VECTOR INDEX (embedding) DISTANCE=COSINE +); +``` + +## Configuration + +**Required:** `DB_USER`, `DB_PASSWORD` + +**Database:** `DB_HOST` (localhost), `DB_PORT` (3306), `DB_NAME`, `DB_CHARSET` + +**Timeouts:** `DB_CONNECT_TIMEOUT` (10s), `DB_READ_TIMEOUT` (30s), `DB_WRITE_TIMEOUT` (30s) + +**MCP:** `MCP_READ_ONLY` (true), `MCP_MAX_POOL_SIZE` (10), `MCP_MAX_RESULTS` (10000) + +**Security:** `ALLOWED_ORIGINS`, `ALLOWED_HOSTS` (for HTTP/SSE transports) + +**Embeddings:** `EMBEDDING_PROVIDER` (openai|gemini|huggingface), `EMBEDDING_MAX_CONCURRENT` (5), plus provider-specific key (`OPENAI_API_KEY`, `GEMINI_API_KEY`, `HF_MODEL`) + 
+**Logging:** `LOG_LEVEL` (INFO), `LOG_FILE` (logs/mcp_server.log), `LOG_MAX_BYTES` (10MB), `LOG_BACKUP_COUNT` (5) + +## Docker Health Checks + +Both containers have health checks configured in `docker-compose.yml`: +- **mariadb**: Uses `mariadb-admin ping` (note: MariaDB 11+ renamed `mysqladmin` to `mariadb-admin`) +- **mariadb-mcp**: Uses TCP socket connection check on port 9001 + +## Health Check & Metrics + +HTTP/SSE transports expose `/health` endpoint returning: +- `status`: "healthy" or "unhealthy" +- `uptime_seconds`: Server uptime +- `pool_status`: "connected" or "disconnected" +- `metrics`: Query counts, error rates, average query time, embedding counts + +## Code Quality Rules + +- **CRITICAL:** Always use parameterized queries (`%s` placeholders) - never concatenate SQL strings +- **CRITICAL:** Validate database/table names with `isidentifier()` before use in SQL +- All database operations must be `async` with `await` +- Log tool calls: `logger.info(f"TOOL START: ...")` / `logger.info(f"TOOL END: ...")` +- Catch `AsyncMyError` for database errors, `PermissionError` for read-only violations +- Vector store tests require `EMBEDDING_PROVIDER` configured +- Use backtick quoting for identifiers in SQL: `` `database_name`.`table_name` `` + +## Custom Commands + +- `/project:fix-issue ` - Fix GitHub issue with full workflow +- `/project:review-pr ` - Review a pull request + +## Skills + +- `mariadb-debug` - Debug database connectivity, embedding errors, MCP tool failures + +## Testing Notes + +- Tests in `src/tests/` use unittest framework with pytest runner +- Integration tests require live MariaDB with configured `.env` +- Tests start server with stdio transport using FastMCP test client +- Vector store tests require `EMBEDDING_PROVIDER` to be configured +- Run single test: `uv run -m pytest src/tests/test_mcp_server.py::TestClass::test_method -v` diff --git a/Dockerfile b/Dockerfile index 3c0eb92..04ae6df 100644 --- a/Dockerfile +++ b/Dockerfile @@ -1,23 
+1,25 @@ -FROM python:3.11-slim AS builder +FROM python:3.13-slim AS builder -# Install build dependencies and curl for uv installer +# Install build dependencies for compiling packages RUN apt-get update && apt-get install -y --no-install-recommends \ - build-essential curl ca-certificates \ + build-essential ca-certificates \ && rm -rf /var/lib/apt/lists/* -# Install uv -RUN curl -fsSL https://astral.sh/uv/install.sh | sh -ENV PATH="/root/.local/bin:${PATH}" - WORKDIR /app # Copy project files COPY . . -# Install project dependencies into a local venv +# Install uv and use it to install dependencies from pyproject.toml/uv.lock +RUN pip install --no-cache-dir uv +ENV PATH="/app/.venv/bin:${PATH}" RUN uv sync --no-dev -FROM python:3.11-slim +FROM python:3.13-slim + +# Install curl for healthcheck +RUN apt-get update && apt-get install -y --no-install-recommends curl && \ + rm -rf /var/lib/apt/lists/* WORKDIR /app ENV PATH="/app/.venv/bin:${PATH}" @@ -25,7 +27,12 @@ ENV PATH="/app/.venv/bin:${PATH}" # Copy venv and app from builder COPY --from=builder /app/.venv /app/.venv COPY --from=builder /app/src /app/src +COPY --from=builder /app/scripts /app/scripts EXPOSE 9001 -CMD ["python", "src/server.py", "--host", "0.0.0.0", "--transport", "sse"] +# Add healthcheck +HEALTHCHECK --interval=30s --timeout=10s --start-period=5s --retries=3 \ + CMD curl -f http://localhost:9001/health || exit 1 + +CMD ["python", "src/server.py", "--host", "0.0.0.0", "--port", "9001", "--transport", "sse"] diff --git a/README.md b/README.md index 3bfe4c7..a2b9d0d 100644 --- a/README.md +++ b/README.md @@ -200,7 +200,7 @@ export FASTMCP_SERVER_AUTH_GOOGLE_CLIENT_SECRET="GOCSPX-..." 
### Requirements -- **Python 3.11** (see `.python-version`) +- **Python 3.13+** (see `.python-version`) - **uv** (dependency manager; [install instructions](https://github.com/astral-sh/uv)) - MariaDB server (local or remote) diff --git a/docker-compose.yml b/docker-compose.yml index 6d5d8f3..0fe745f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -1,25 +1,59 @@ +name: mcp-mariadb + services: mariadb-server: - image: mariadb:11 + image: mariadb:11.7 container_name: mariadb environment: MARIADB_ROOT_PASSWORD: rootpassword123 MARIADB_DATABASE: demo MARIADB_USER: user MARIADB_PASSWORD: password123 + MARIADB_ALLOW_EMPTY_ROOT_PASSWORD: "no" ports: - "3306:3306" + volumes: + - mariadb_data:/var/lib/mysql healthcheck: - test: ["CMD", "mariadb-admin", "ping", "-h", "127.0.0.1", "-p'rootpassword123'"] + test: ["CMD-SHELL", "mariadb-admin ping -h 127.0.0.1 -u root -prootpassword123 || exit 1"] interval: 5s timeout: 3s retries: 10 + restart: unless-stopped + networks: + - mcp-network + mariadb-mcp: - build: . + build: + context: .
+ dockerfile: Dockerfile container_name: mariadb-mcp - env_file: .env + environment: + - DB_HOST=mariadb + - DB_USER=user + - DB_PASSWORD=password123 + - DB_PORT=3306 + - DB_NAME=demo + - MCP_READ_ONLY=true + - MCP_MAX_POOL_SIZE=10 ports: - "9001:9001" depends_on: mariadb-server: condition: service_healthy + healthcheck: + test: ["CMD-SHELL", "python -c \"import socket; s=socket.socket(); s.settimeout(3); s.connect(('localhost', 9001)); s.close()\""] + interval: 10s + timeout: 5s + retries: 5 + restart: unless-stopped + networks: + - mcp-network + +volumes: + mariadb_data: + driver: local + +networks: + mcp-network: + driver: bridge diff --git a/pyproject.toml b/pyproject.toml index 4c22b8b..df500f7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,9 +1,9 @@ [project] -name = "mariadb-server" +name = "mariadb-mcp" version = "0.2.2" description = "MariaDB MCP Server" readme = "README.md" -requires-python = ">=3.11" +requires-python = ">=3.13" dependencies = [ "asyncmy>=0.2.10", "fastmcp[websockets]==2.12.1", @@ -13,3 +13,13 @@ dependencies = [ "sentence-transformers>=4.1.0", "tokenizers==0.21.2", ] + +[dependency-groups] +dev = [ + "mypy>=1.19.1", + "pytest>=9.0.2", +] + +[tool.pytest.ini_options] +pythonpath = ["src"] +testpaths = ["src/tests"] diff --git a/scripts/populate_geography.py b/scripts/populate_geography.py new file mode 100644 index 0000000..9946828 --- /dev/null +++ b/scripts/populate_geography.py @@ -0,0 +1,250 @@ +#!/usr/bin/env python3 +""" +Download and populate MariaDB with geographical data. 
+Data source: https://github.com/dr5hn/countries-states-cities-database + +Tables: regions (continents), subregions, countries, states, cities +""" + +import os +import sys +import gzip +import tempfile +import urllib.request +from pathlib import Path + +# Add src to path for config import +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +from dotenv import load_dotenv +load_dotenv(Path(__file__).parent.parent / ".env") + +import asyncio +import asyncmy + +# Database config from environment +# Check if running inside Docker (container has /.dockerenv) +_inside_docker = os.path.exists("/.dockerenv") +_db_host = os.getenv("DB_HOST", "localhost") +# Only convert to localhost when running on host machine, not inside Docker +if not _inside_docker and _db_host in ("mariadb", "mysql", "db"): + DB_HOST = "127.0.0.1" +else: + DB_HOST = _db_host +DB_PORT = int(os.getenv("DB_PORT", "3306")) +# Use root credentials for database creation (can override with env vars) +DB_USER = os.getenv("DB_ROOT_USER", "root") +DB_PASSWORD = os.getenv("DB_ROOT_PASSWORD", "rootpassword123") +DATABASE_NAME = "geography" + +# GitHub raw URLs for SQL files +BASE_URL = "https://raw.githubusercontent.com/dr5hn/countries-states-cities-database/master/sql" +SQL_FILES = [ + ("schema.sql", False), + ("regions.sql", False), + ("subregions.sql", False), + ("countries.sql", False), + ("states.sql", False), + ("cities.sql.gz", True), # Compressed +] + + +def download_file(url: str, dest: Path, compressed: bool = False) -> Path: + """Download a file from URL to destination.""" + print(f"Downloading {url}...") + + if compressed: + # Download to temp file, decompress + with tempfile.NamedTemporaryFile(delete=False, suffix=".gz") as tmp: + urllib.request.urlretrieve(url, tmp.name) + print(f" Decompressing {dest.name}...") + with gzip.open(tmp.name, 'rb') as f_in: + with open(dest, 'wb') as f_out: + f_out.write(f_in.read()) + os.unlink(tmp.name) + else: + urllib.request.urlretrieve(url, dest) + + 
size_mb = dest.stat().st_size / (1024 * 1024) + print(f" Downloaded: {dest.name} ({size_mb:.2f} MB)") + return dest + + +def fix_sql_for_mariadb(content: str) -> str: + """Fix MySQL-specific syntax for MariaDB compatibility.""" + # Remove MySQL-specific SET statements that might cause issues + lines = content.split('\n') + filtered_lines = [] + + for line in lines: + # Skip problematic SET statements + if line.strip().startswith('SET ') and any(x in line for x in [ + 'GLOBAL', 'SESSION', 'sql_require_primary_key' + ]): + continue + filtered_lines.append(line) + + return '\n'.join(filtered_lines) + + +async def execute_sql_file(conn, filepath: Path, description: str): + """Execute SQL statements from a file.""" + print(f"\nImporting {description} from {filepath.name}...") + + content = filepath.read_text(encoding='utf-8') + content = fix_sql_for_mariadb(content) + + # Split into individual statements + statements = [] + current = [] + + for line in content.split('\n'): + stripped = line.strip() + + # Skip comments and empty lines + if not stripped or stripped.startswith('--') or stripped.startswith('/*'): + continue + + current.append(line) + + # Check for statement end + if stripped.endswith(';'): + stmt = '\n'.join(current).strip() + if stmt and not stmt.startswith('--'): + statements.append(stmt) + current = [] + + # Execute statements + async with conn.cursor() as cursor: + executed = 0 + errors = 0 + + for i, stmt in enumerate(statements): + try: + await cursor.execute(stmt) + executed += 1 + + # Progress indicator for large files + if (i + 1) % 1000 == 0: + print(f" Progress: {i + 1}/{len(statements)} statements...") + + except Exception as e: + errors += 1 + if errors <= 3: + print(f" Warning: {str(e)[:100]}") + elif errors == 4: + print(f" (suppressing further warnings...)") + + await conn.commit() + print(f" Completed: {executed} statements executed, {errors} errors") + + +async def main(): + """Main function to download and import geographical data.""" + + 
if not DB_USER or not DB_PASSWORD: + print("Error: DB_USER and DB_PASSWORD must be set in .env") + sys.exit(1) + + print("=" * 60) + print("Geography Database Population Script") + print("=" * 60) + print(f"Target: {DB_USER}@{DB_HOST}:{DB_PORT}/{DATABASE_NAME}") + print() + + # Create download directory + download_dir = Path(__file__).parent / "geography_data" + download_dir.mkdir(exist_ok=True) + + # Download all SQL files + print("Step 1: Downloading SQL files...") + downloaded_files = [] + for filename, compressed in SQL_FILES: + url = f"{BASE_URL}/{filename}" + dest_name = filename.replace('.gz', '') if compressed else filename + dest = download_dir / dest_name + + if dest.exists(): + print(f" Using cached: {dest_name}") + else: + download_file(url, dest, compressed) + + downloaded_files.append((dest, filename.replace('.sql.gz', '').replace('.sql', ''))) + + # Connect to MariaDB + print("\nStep 2: Connecting to MariaDB...") + try: + conn = await asyncmy.connect( + host=DB_HOST, + port=DB_PORT, + user=DB_USER, + password=DB_PASSWORD, + autocommit=False, + ) + print(f" Connected successfully") + except Exception as e: + print(f"Error connecting to MariaDB: {e}") + sys.exit(1) + + try: + async with conn.cursor() as cursor: + # Create database + print(f"\nStep 3: Creating database '{DATABASE_NAME}'...") + await cursor.execute(f"CREATE DATABASE IF NOT EXISTS `{DATABASE_NAME}` CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci") + await cursor.execute(f"USE `{DATABASE_NAME}`") + await conn.commit() + print(f" Database ready") + + # Disable foreign key checks during import + await cursor.execute("SET FOREIGN_KEY_CHECKS = 0") + await conn.commit() + + # Import files in order + print("\nStep 4: Importing data...") + + import_order = [ + ("schema", "Schema (table definitions)"), + ("regions", "Regions (continents)"), + ("subregions", "Subregions"), + ("countries", "Countries"), + ("states", "States/Provinces"), + ("cities", "Cities/Towns"), + ] + + for file_key,
description in import_order: + filepath = download_dir / f"{file_key}.sql" + if filepath.exists(): + await execute_sql_file(conn, filepath, description) + else: + print(f" Skipping {file_key}: file not found") + + # Re-enable foreign key checks + async with conn.cursor() as cursor: + await cursor.execute("SET FOREIGN_KEY_CHECKS = 1") + await conn.commit() + + # Print summary + print("\n" + "=" * 60) + print("Import Complete! Summary:") + print("=" * 60) + + async with conn.cursor() as cursor: + tables = ['regions', 'subregions', 'countries', 'states', 'cities'] + for table in tables: + try: + await cursor.execute(f"SELECT COUNT(*) as cnt FROM `{table}`") + result = await cursor.fetchone() + count = result[0] if result else 0 + print(f" {table}: {count:,} records") + except Exception as e: + print(f" {table}: Error - {e}") + + print("\nDatabase is ready to use!") + print(f"Connection: mysql -u {DB_USER} -p -h {DB_HOST} -P {DB_PORT} {DATABASE_NAME}") + + finally: + await conn.ensure_closed() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/test_mcp_client.py b/scripts/test_mcp_client.py new file mode 100644 index 0000000..ee039d8 --- /dev/null +++ b/scripts/test_mcp_client.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +""" +Test MCP server via SSE transport (connects to running container). +""" + +import asyncio +from mcp import ClientSession +from mcp.client.sse import sse_client + +MCP_URL = "http://localhost:9001/sse" + + +async def main(): + print("=" * 70) + print("MCP Client Test - Connecting to Container") + print("=" * 70) + print(f"URL: {MCP_URL}\n") + + async with sse_client(MCP_URL) as (read, write): + async with ClientSession(read, write) as session: + # Initialize the connection + await session.initialize() + print("✓ Connected to MCP server\n") + + # List available tools + print("1. 
Listing available tools...") + tools = await session.list_tools() + print(f" ✓ Found {len(tools.tools)} tools:") + for tool in tools.tools: + print(f" - {tool.name}") + + # Test list_databases + print("\n2. Calling list_databases()...") + result = await session.call_tool("list_databases", {}) + databases = result.content[0].text if result.content else "[]" + print(f" ✓ Result: {databases}") + + # Test list_tables on demo + print("\n3. Calling list_tables('demo')...") + result = await session.call_tool("list_tables", {"database_name": "demo"}) + tables = result.content[0].text if result.content else "[]" + print(f" ✓ Result: {tables}") + + # Test execute_sql - List continents + print("\n4. Querying continents from demo.regions...") + result = await session.call_tool("execute_sql", { + "sql_query": "SELECT name FROM regions ORDER BY name", + "database_name": "demo" + }) + print(f" ✓ Result: {result.content[0].text if result.content else 'empty'}") + + # Test execute_sql - Country count by continent + print("\n5. Querying country counts by continent...") + result = await session.call_tool("execute_sql", { + "sql_query": """ + SELECT r.name as continent, COUNT(c.id) as country_count + FROM regions r + LEFT JOIN countries c ON r.id = c.region_id + GROUP BY r.id, r.name + ORDER BY country_count DESC + """, + "database_name": "demo" + }) + print(f" ✓ Result: {result.content[0].text if result.content else 'empty'}") + + # Test execute_sql - Cities in a specific country + print("\n6. 
Querying cities in Switzerland...") + result = await session.call_tool("execute_sql", { + "sql_query": """ + SELECT ci.name as city, s.name as state, ci.population + FROM cities ci + JOIN states s ON ci.state_id = s.id + JOIN countries co ON ci.country_id = co.id + WHERE co.name = %s + ORDER BY ci.population DESC + LIMIT 10 + """, + "database_name": "demo", + "parameters": ["Switzerland"] + }) + print(f" ✓ Result: {result.content[0].text if result.content else 'empty'}") + + # Test get_table_schema + print("\n7. Getting schema for demo.countries...") + result = await session.call_tool("get_table_schema", { + "database_name": "demo", + "table_name": "countries" + }) + # Just show column count + import json + schema = json.loads(result.content[0].text) if result.content else {} + print(f" ✓ Found {len(schema)} columns") + print(f" ✓ Sample columns: {list(schema.keys())[:5]}") + + print("\n" + "=" * 70) + print("All MCP client tests passed!") + print("=" * 70) + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/scripts/test_mcp_geography.py b/scripts/test_mcp_geography.py new file mode 100644 index 0000000..ed8a894 --- /dev/null +++ b/scripts/test_mcp_geography.py @@ -0,0 +1,156 @@ +#!/usr/bin/env python3 +""" +Test script to query the geography database via MCP server. +Runs the MCP server tools directly (not via network). 
+""" + +import sys +import asyncio +from pathlib import Path + +# Add src to path +sys.path.insert(0, str(Path(__file__).parent.parent / "src")) + +# Patch environment for local testing +import os +# Detect if running inside Docker container +if not os.path.exists("/.dockerenv"): + # Running on host - use localhost to connect via mapped port + os.environ["DB_HOST"] = "127.0.0.1" +# Inside Docker, keep DB_HOST as-is (uses docker network hostname) + +from server import MariaDBServer + + +async def main(): + print("=" * 70) + print("MCP Geography Database Test") + print("=" * 70) + + # Create server instance + server = MariaDBServer("Geography_Test") + + try: + # Initialize connection pool + print("\n1. Initializing connection pool...") + await server.initialize_pool() + print(" ✓ Pool initialized") + + # Test list_databases + print("\n2. Testing list_databases()...") + databases = await server.list_databases() + print(f" ✓ Found {len(databases)} databases:") + for db in databases: + marker = " ← target" if db == "geography" else "" + print(f" - {db}{marker}") + + if "geography" not in databases: + print(" ✗ 'geography' database not found!") + return + + # Test list_tables + print("\n3. Testing list_tables('geography')...") + tables = await server.list_tables("geography") + print(f" ✓ Found {len(tables)} tables:") + for table in tables: + print(f" - {table}") + + # Test get_table_schema + print("\n4. Testing get_table_schema('geography', 'countries')...") + schema = await server.get_table_schema("geography", "countries") + print(f" ✓ Countries table has {len(schema)} columns:") + for col, info in list(schema.items())[:8]: + print(f" - {col}: {info['type']}") + print(f" ... and {len(schema) - 8} more columns") + + # Test execute_sql - List continents + print("\n5. 
Testing execute_sql() - List all continents...") + results = await server.execute_sql( + "SELECT id, name FROM regions ORDER BY name", + "geography" + ) + print(f" ✓ Found {len(results)} continents:") + for row in results: + print(f" - {row['name']}") + + # Test execute_sql - Countries by continent + print("\n6. Testing execute_sql() - Countries in Europe...") + results = await server.execute_sql( + """ + SELECT c.name as country, c.capital, c.population + FROM countries c + JOIN regions r ON c.region_id = r.id + WHERE r.name = %s + ORDER BY c.population DESC + LIMIT 10 + """, + "geography", + parameters=["Europe"] + ) + print(f" ✓ Top 10 European countries by population:") + for row in results: + pop = f"{row['population']:,}" if row['population'] else "N/A" + print(f" - {row['country']} (capital: {row['capital']}, pop: {pop})") + + # Test execute_sql - Cities search + print("\n7. Testing execute_sql() - Cities named 'Paris'...") + results = await server.execute_sql( + """ + SELECT ci.name as city, s.name as state, co.name as country + FROM cities ci + JOIN states s ON ci.state_id = s.id + JOIN countries co ON ci.country_id = co.id + WHERE ci.name = %s + ORDER BY co.name + """, + "geography", + parameters=["Paris"] + ) + print(f" ✓ Found {len(results)} cities named 'Paris':") + for row in results: + print(f" - {row['city']}, {row['state']}, {row['country']}") + + # Test execute_sql - Large cities + print("\n8. Testing execute_sql() - Largest cities worldwide...") + results = await server.execute_sql( + """ + SELECT ci.name as city, co.name as country, ci.population + FROM cities ci + JOIN countries co ON ci.country_id = co.id + WHERE ci.population IS NOT NULL + ORDER BY ci.population DESC + LIMIT 10 + """, + "geography" + ) + print(f" ✓ Top 10 largest cities:") + for row in results: + pop = f"{row['population']:,}" if row['population'] else "N/A" + print(f" - {row['city']}, {row['country']} (pop: {pop})") + + # Test get_table_schema_with_relations + print("\n9. 
Testing get_table_schema_with_relations('geography', 'cities')...") + schema = await server.get_table_schema_with_relations("geography", "cities") + fk_cols = [col for col, info in schema['columns'].items() if info.get('foreign_key')] + print(f" ✓ Cities table foreign keys:") + for col in fk_cols: + fk = schema['columns'][col]['foreign_key'] + print(f" - {col} → {fk['referenced_table']}.{fk['referenced_column']}") + + # Summary statistics + print("\n10. Database summary statistics...") + for table in ['regions', 'subregions', 'countries', 'states', 'cities']: + results = await server.execute_sql(f"SELECT COUNT(*) as cnt FROM {table}", "geography") + count = results[0]['cnt'] if results else 0 + print(f" - {table}: {count:,} records") + + print("\n" + "=" * 70) + print("All MCP tool tests passed successfully!") + print("=" * 70) + + finally: + await server.close_pool() + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/sql/geography_views.sql b/sql/geography_views.sql new file mode 100644 index 0000000..2f5801a --- /dev/null +++ b/sql/geography_views.sql @@ -0,0 +1,133 @@ +-- Geography Database Views +-- For querying cities by country, region, and subregion +-- Compatible with both 'geography' and 'demo' databases + +-- ============================================================================= +-- v_cities_full: Complete city information with full geographic hierarchy +-- ============================================================================= +-- Use: SELECT * FROM v_cities_full WHERE country_code = 'US' LIMIT 10; +CREATE OR REPLACE VIEW v_cities_full AS +SELECT + ci.id AS city_id, + ci.name AS city_name, + ci.population AS city_population, + ci.latitude AS city_lat, + ci.longitude AS city_lng, + ci.timezone, + s.id AS state_id, + s.name AS state_name, + s.iso2 AS state_code, + co.id AS country_id, + co.name AS country_name, + co.iso2 AS country_code, + co.iso3 AS country_iso3, + co.capital, + co.currency, + sr.id AS subregion_id, + sr.name 
AS subregion_name, + r.id AS region_id, + r.name AS region_name +FROM cities ci +JOIN states s ON ci.state_id = s.id +JOIN countries co ON ci.country_id = co.id +LEFT JOIN subregions sr ON co.subregion_id = sr.id +LEFT JOIN regions r ON co.region_id = r.id; + +-- ============================================================================= +-- v_cities_by_country: Simple view for querying cities by country +-- ============================================================================= +-- Use: SELECT * FROM v_cities_by_country WHERE country_code = 'CH'; +-- Use: SELECT * FROM v_cities_by_country WHERE country = 'Switzerland'; +CREATE OR REPLACE VIEW v_cities_by_country AS +SELECT + ci.name AS city, + ci.population, + s.name AS state, + co.name AS country, + co.iso2 AS country_code +FROM cities ci +JOIN states s ON ci.state_id = s.id +JOIN countries co ON ci.country_id = co.id +ORDER BY co.name, ci.population DESC; + +-- ============================================================================= +-- v_cities_by_region: View for querying cities by region/subregion +-- ============================================================================= +-- Use: SELECT * FROM v_cities_by_region WHERE region = 'Europe'; +-- Use: SELECT * FROM v_cities_by_region WHERE subregion = 'Western Europe'; +CREATE OR REPLACE VIEW v_cities_by_region AS +SELECT + r.name AS region, + sr.name AS subregion, + co.name AS country, + s.name AS state, + ci.name AS city, + ci.population +FROM cities ci +JOIN states s ON ci.state_id = s.id +JOIN countries co ON ci.country_id = co.id +LEFT JOIN subregions sr ON co.subregion_id = sr.id +LEFT JOIN regions r ON co.region_id = r.id +ORDER BY r.name, sr.name, co.name, ci.population DESC; + +-- ============================================================================= +-- v_country_stats: Country statistics with city counts +-- ============================================================================= +-- Use: SELECT * FROM v_country_stats 
WHERE region = 'Africa'; +-- Use: SELECT * FROM v_country_stats ORDER BY city_count DESC LIMIT 20; +CREATE OR REPLACE VIEW v_country_stats AS +SELECT + r.name AS region, + sr.name AS subregion, + co.name AS country, + co.iso2 AS country_code, + co.capital, + co.population AS country_population, + COUNT(ci.id) AS city_count, + SUM(ci.population) AS total_city_population +FROM countries co +LEFT JOIN cities ci ON ci.country_id = co.id +LEFT JOIN subregions sr ON co.subregion_id = sr.id +LEFT JOIN regions r ON co.region_id = r.id +GROUP BY r.name, sr.name, co.id, co.name, co.iso2, co.capital, co.population +ORDER BY r.name, sr.name, co.name; + +-- ============================================================================= +-- v_region_summary: Aggregated statistics by region and subregion +-- ============================================================================= +-- Use: SELECT * FROM v_region_summary; +-- Use: SELECT * FROM v_region_summary WHERE region = 'Asia'; +CREATE OR REPLACE VIEW v_region_summary AS +SELECT + r.name AS region, + sr.name AS subregion, + COUNT(DISTINCT co.id) AS country_count, + COUNT(DISTINCT s.id) AS state_count, + COUNT(ci.id) AS city_count, + SUM(ci.population) AS total_population +FROM regions r +LEFT JOIN subregions sr ON sr.region_id = r.id +LEFT JOIN countries co ON co.subregion_id = sr.id +LEFT JOIN states s ON s.country_id = co.id +LEFT JOIN cities ci ON ci.country_id = co.id +GROUP BY r.id, r.name, sr.id, sr.name +ORDER BY r.name, sr.name; + +-- ============================================================================= +-- Example Queries +-- ============================================================================= +-- +-- All cities in a specific country: +-- SELECT * FROM v_cities_by_country WHERE country_code = 'DE' ORDER BY population DESC LIMIT 20; +-- +-- All cities in a continent: +-- SELECT country, city, population FROM v_cities_by_region WHERE region = 'Africa' ORDER BY population DESC LIMIT 20; +-- +-- All 
cities in a subregion: +-- SELECT country, city, population FROM v_cities_by_region WHERE subregion = 'South America' ORDER BY population DESC; +-- +-- Country statistics for a region: +-- SELECT country, capital, city_count FROM v_country_stats WHERE region = 'Europe' ORDER BY city_count DESC; +-- +-- Full city details with coordinates: +-- SELECT city_name, state_name, country_name, city_lat, city_lng FROM v_cities_full WHERE country_code = 'JP' AND city_population > 1000000; diff --git a/src/config.py b/src/config.py index 85270c2..01ae797 100644 --- a/src/config.py +++ b/src/config.py @@ -14,15 +14,15 @@ LOG_MAX_BYTES = int(os.getenv("LOG_MAX_BYTES", 10 * 1024 * 1024)) LOG_BACKUP_COUNT = int(os.getenv("LOG_BACKUP_COUNT", 5)) -ALLOWED_ORIGINS = os.getenv("ALLOWED_ORIGINS") -if ALLOWED_ORIGINS: - ALLOWED_ORIGINS = ALLOWED_ORIGINS.split(",") +_allowed_origins_env = os.getenv("ALLOWED_ORIGINS") +if _allowed_origins_env: + ALLOWED_ORIGINS: list[str] = _allowed_origins_env.split(",") else: ALLOWED_ORIGINS = ["http://localhost", "http://127.0.0.1", "http://*", "https://localhost", "https://127.0.0.1", "vscode-file://vscode-app"] -ALLOWED_HOSTS = os.getenv("ALLOWED_HOSTS") -if ALLOWED_HOSTS: - ALLOWED_HOSTS = ALLOWED_HOSTS.split(",") +_allowed_hosts_env = os.getenv("ALLOWED_HOSTS") +if _allowed_hosts_env: + ALLOWED_HOSTS: list[str] = _allowed_hosts_env.split(",") else: ALLOWED_HOSTS = ["localhost", "127.0.0.1"] @@ -58,19 +58,46 @@ logger = logging.getLogger(__name__) # --- Database Configuration --- -DB_HOST = os.getenv("DB_HOST", "localhost") -DB_PORT = int(os.getenv("DB_PORT", 3306)) -DB_USER = os.getenv("DB_USER") -DB_PASSWORD = os.getenv("DB_PASSWORD") -DB_NAME = os.getenv("DB_NAME") -DB_CHARSET = os.getenv("DB_CHARSET") +# Support multiple databases via comma-separated values +DB_HOSTS = os.getenv("DB_HOSTS", os.getenv("DB_HOST", "localhost")).split(",") +DB_PORTS = [int(p) for p in os.getenv("DB_PORTS", os.getenv("DB_PORT", "3306")).split(",")] +DB_USERS = 
os.getenv("DB_USERS", os.getenv("DB_USER", "")).split(",") +DB_PASSWORDS = os.getenv("DB_PASSWORDS", os.getenv("DB_PASSWORD", "")).split(",") +DB_NAMES = os.getenv("DB_NAMES", os.getenv("DB_NAME", "")).split(",") +DB_CHARSETS = os.getenv("DB_CHARSETS", os.getenv("DB_CHARSET", "")).split(",") + +# Validate multiple database configuration - arrays should match in length +if len(DB_HOSTS) > 1: + _min_len = min(len(DB_HOSTS), len(DB_USERS), len(DB_PASSWORDS)) + if len(DB_HOSTS) != len(DB_USERS) or len(DB_HOSTS) != len(DB_PASSWORDS): + logger.warning( + f"Multiple database config length mismatch: " + f"DB_HOSTS={len(DB_HOSTS)}, DB_USERS={len(DB_USERS)}, DB_PASSWORDS={len(DB_PASSWORDS)}. " + f"Using first {_min_len} entries." + ) + DB_HOSTS = DB_HOSTS[:_min_len] + DB_USERS = DB_USERS[:_min_len] + DB_PASSWORDS = DB_PASSWORDS[:_min_len] + +# Legacy single database support +DB_HOST = DB_HOSTS[0] +DB_PORT = DB_PORTS[0] if DB_PORTS else 3306 +DB_USER = DB_USERS[0] if DB_USERS else None +DB_PASSWORD = DB_PASSWORDS[0] if DB_PASSWORDS else None +DB_NAME = DB_NAMES[0] if DB_NAMES else None +DB_CHARSET = DB_CHARSETS[0] if DB_CHARSETS and DB_CHARSETS[0] else None + +# --- Database Timeout Configuration --- +DB_CONNECT_TIMEOUT = int(os.getenv("DB_CONNECT_TIMEOUT", 10)) # seconds # --- MCP Server Configuration --- # Read-only mode MCP_READ_ONLY = os.getenv("MCP_READ_ONLY", "true").lower() == "true" MCP_MAX_POOL_SIZE = int(os.getenv("MCP_MAX_POOL_SIZE", 10)) +MCP_MAX_RESULTS = int(os.getenv("MCP_MAX_RESULTS", 10000)) # Max rows returned per query # --- Embedding Configuration --- +EMBEDDING_MAX_CONCURRENT = int(os.getenv("EMBEDDING_MAX_CONCURRENT", 5)) # Max concurrent embedding API calls # Provider selection ('openai' or 'gemini' or 'huggingface') EMBEDDING_PROVIDER = os.getenv("EMBEDDING_PROVIDER") EMBEDDING_PROVIDER = EMBEDDING_PROVIDER.lower() if EMBEDDING_PROVIDER else None diff --git a/src/embeddings.py b/src/embeddings.py index e16ea1c..6d81ae2 100644 --- a/src/embeddings.py 
+++ b/src/embeddings.py @@ -5,6 +5,9 @@ from typing import List, Optional, Dict, Any, Union, Awaitable import numpy as np +# Maximum number of HuggingFace models to keep in memory simultaneously +_HF_MODEL_CACHE_MAX_SIZE = 5 + # Import configuration variables and the logger instance from config import ( EMBEDDING_PROVIDER, @@ -86,7 +89,7 @@ def __init__(self): logger.info(f"Initializing EmbeddingService with provider: {self.provider}") if self.provider == "openai": - if not AsyncOpenAI: + if AsyncOpenAI is None: logger.error("OpenAI provider selected, but 'openai' library is not installed.") raise ImportError("OpenAI library not found. Please install it.") if not OPENAI_API_KEY: @@ -121,15 +124,19 @@ def __init__(self): raise ValueError("HuggingFace model (HF_MODEL) is required in config for the HuggingFace provider.") try: from sentence_transformers import SentenceTransformer - + # The primary model for this service instance will be HF_MODEL from config - self.default_model = HF_MODEL + self.default_model = HF_MODEL self.allowed_models = ALLOWED_HF_MODELS # These are other models that can be specified via embed() - + + # Model cache for dynamically loaded HuggingFace models + self._hf_model_cache: Dict[str, Any] = {} + # Pre-load the default model from config logger.info(f"Initializing SentenceTransformer with configured HF_MODEL: {self.default_model}") - self.huggingface_client = SentenceTransformer(self.default_model) - # self.huggingface_client now holds the loaded model instance for config.HF_MODEL + self.huggingface_client = SentenceTransformer(self.default_model) + # Cache the default model as well + self._hf_model_cache[self.default_model] = self.huggingface_client logger.info(f"HuggingFace provider initialized. Default model (from config.HF_MODEL): '{self.default_model}'. Client loaded. 
Allowed models for override: {self.allowed_models}") @@ -308,38 +315,52 @@ async def embed(self, text: Union[str, List[str]], model_name: Optional[str] = N raise RuntimeError("HuggingFace client (SentenceTransformer) not initialized. Check service setup.") # target_model is already determined: model_name if valid, else self.default_model (which is config.HF_MODEL) - + embeddings_np: np.ndarray effective_model_name = target_model - if target_model == self.default_model: - logger.debug(f"Using pre-loaded HuggingFace model '{self.default_model}' for embedding.") - embeddings_np = self.huggingface_client.encode(texts) + # Check cache first for the requested model + if target_model in self._hf_model_cache: + logger.debug(f"Using cached HuggingFace model '{target_model}' for embedding.") + model_instance = self._hf_model_cache[target_model] + embeddings_np = model_instance.encode(texts) else: # A different model was requested via model_name, and it's valid (already checked in pre-amble of embed) - logger.info(f"Dynamically loading HuggingFace model '{target_model}' for this embed call (different from pre-loaded '{self.default_model}').") + logger.info(f"Loading and caching HuggingFace model '{target_model}' (different from default '{self.default_model}').") try: - # Ensure sentence_transformers is available for dynamic loading too - from sentence_transformers import SentenceTransformer - dynamic_model_loader = SentenceTransformer(target_model) - embeddings_np = dynamic_model_loader.encode(texts) + from sentence_transformers import SentenceTransformer + model_instance = SentenceTransformer(target_model) + # Evict oldest entry if cache is at capacity + if len(self._hf_model_cache) >= _HF_MODEL_CACHE_MAX_SIZE: + oldest_key = next(iter(self._hf_model_cache)) + del self._hf_model_cache[oldest_key] + logger.info(f"HF model cache full ({_HF_MODEL_CACHE_MAX_SIZE}), evicted '{oldest_key}'") + self._hf_model_cache[target_model] = model_instance + embeddings_np = 
model_instance.encode(texts) + logger.info(f"HuggingFace model '{target_model}' loaded and cached successfully.") except Exception as e: - logger.error(f"Failed to load or use dynamically specified HuggingFace model '{target_model}': {e}", exc_info=True) + logger.error(f"Failed to load HuggingFace model '{target_model}': {e}", exc_info=True) raise RuntimeError(f"Error with HuggingFace model '{target_model}': {e}") # Convert numpy array to list of lists of floats (or list of floats) - embeddings_list: Union[List[float], List[List[float]]] + embeddings_list: List[List[float]] if isinstance(embeddings_np, np.ndarray): - embeddings_list = embeddings_np.tolist() + raw_list = embeddings_np.tolist() + # Ensure we have a list of lists for batch processing + if raw_list and not isinstance(raw_list[0], list): + # Single embedding case - wrap in list + embeddings_list = [raw_list] + else: + embeddings_list = raw_list else: # Should ideally not happen with sentence-transformers if encode ran logger.warning("HuggingFace encode did not return a numpy array as expected.") - embeddings_list = texts # Fallback, though likely incorrect + raise RuntimeError("HuggingFace encoding failed to return valid embeddings.") + + logger.debug(f"HuggingFace embedding(s) with model '{effective_model_name}' received. Count: {len(embeddings_list)}, Dimension: {len(embeddings_list[0]) if embeddings_list else 'N/A'}") - logger.debug(f"HuggingFace embedding(s) with model '{effective_model_name}' received. 
Count: {len(embeddings_list)}, Dimension: {len(embeddings_list[0]) if embeddings_list and isinstance(embeddings_list[0], list) and embeddings_list[0] else (len(embeddings_list) if embeddings_list and not isinstance(embeddings_list[0], list) else 'N/A')}") - # Adjust return for single_input if single_input: - return embeddings_list[0] if embeddings_list and isinstance(embeddings_list, list) and embeddings_list[0] else embeddings_list + return embeddings_list[0] else: return embeddings_list else: diff --git a/src/main.py b/src/main.py index c0e4f04..a0d7346 100644 --- a/src/main.py +++ b/src/main.py @@ -1,5 +1,5 @@ def main(): - print("Hello from mcp-mariadb-server!") + print("Hello from mariadb-mcp!") if __name__ == "__main__": diff --git a/src/server.py b/src/server.py index 7b18e66..a3be77e 100644 --- a/src/server.py +++ b/src/server.py @@ -3,24 +3,31 @@ # Import configuration settings from config import ( DB_HOST, DB_PORT, DB_USER, DB_PASSWORD, DB_NAME, DB_CHARSET, + DB_HOSTS, DB_PORTS, DB_USERS, DB_PASSWORDS, DB_NAMES, DB_CHARSETS, MCP_READ_ONLY, MCP_MAX_POOL_SIZE, EMBEDDING_PROVIDER, ALLOWED_ORIGINS, ALLOWED_HOSTS, + DB_CONNECT_TIMEOUT, + EMBEDDING_MAX_CONCURRENT, MCP_MAX_RESULTS, logger ) import asyncio import argparse +import json import re +import time from typing import List, Dict, Any, Optional -from functools import partial +from functools import partial import asyncmy -import anyio +import anyio from fastmcp import FastMCP, Context from starlette.middleware import Middleware from starlette.middleware.cors import CORSMiddleware from starlette.middleware.trustedhost import TrustedHostMiddleware +from starlette.responses import JSONResponse +from starlette.routing import Route # Import EmbeddingService for vector store creation from embeddings import EmbeddingService @@ -32,17 +39,31 @@ from asyncmy.errors import Error as AsyncMyError +# Semaphore for rate limiting embedding API calls +_embedding_semaphore: Optional[asyncio.Semaphore] = None + # --- MariaDB 
MCP Server Class --- class MariaDBServer: """ MCP Server exposing tools to interact with a MariaDB database. Manages the database connection pool. """ - def __init__(self, server_name="MariaDB_Server", autocommit=True): + def __init__(self, server_name="MariaDB_Server"): self.mcp = FastMCP(server_name) self.pool: Optional[asyncmy.Pool] = None + self.pools: Dict[str, asyncmy.Pool] = {} # Multiple pools by connection name self.autocommit = not MCP_READ_ONLY self.is_read_only = MCP_READ_ONLY + self._current_db_cache: Dict[int, str] = {} # Cache database context per connection + # Metrics tracking + self._metrics = { + "queries_executed": 0, + "query_errors": 0, + "total_query_time_ms": 0, + "embeddings_generated": 0, + "pool_acquisitions": 0, + } + self._start_time = time.time() logger.info(f"Initializing {server_name}...") if self.is_read_only: logger.warning("Server running in READ-ONLY mode. Write operations are disabled.") @@ -68,9 +89,21 @@ async def create_vector_store(self, database_name: str, vector_store_name: str, async def initialize_pool(self): """Initializes the asyncmy connection pool within the running event loop.""" + global _embedding_semaphore + + # Initialize embedding semaphore for rate limiting + if EMBEDDING_PROVIDER is not None and _embedding_semaphore is None: + _embedding_semaphore = asyncio.Semaphore(EMBEDDING_MAX_CONCURRENT) + logger.info(f"Embedding rate limiter initialized (max concurrent: {EMBEDDING_MAX_CONCURRENT})") + + # Initialize multiple pools if configured + if len(DB_HOSTS) > 1: + await self.initialize_multiple_pools() + return + if not all([DB_USER, DB_PASSWORD]): - logger.error("Cannot initialize pool due to missing database credentials.") - raise ConnectionError("Missing database credentials for pool initialization.") + logger.error("Cannot initialize pool due to missing database credentials.") + raise ConnectionError("Missing database credentials for pool initialization.") if self.pool is not None: logger.info("Connection pool 
already initialized.") @@ -86,17 +119,21 @@ async def initialize_pool(self): "minsize": 1, "maxsize": MCP_MAX_POOL_SIZE, "autocommit": self.autocommit, - "pool_recycle": 3600 + "pool_recycle": 3600, + "connect_timeout": DB_CONNECT_TIMEOUT, } - + if DB_CHARSET: pool_params["charset"] = DB_CHARSET logger.info(f"Creating connection pool for {DB_USER}@{DB_HOST}:{DB_PORT}/{DB_NAME} (max size: {MCP_MAX_POOL_SIZE}, charset: {DB_CHARSET})") else: logger.info(f"Creating connection pool for {DB_USER}@{DB_HOST}:{DB_PORT}/{DB_NAME} (max size: {MCP_MAX_POOL_SIZE})") - + self.pool = await asyncmy.create_pool(**pool_params) - logger.info("Connection pool initialized successfully.") + + # Pool warmup - verify connection works + await self._warmup_pool() + logger.info("Connection pool initialized and validated successfully.") except AsyncMyError as e: logger.error(f"Failed to initialize database connection pool: {e}", exc_info=True) self.pool = None @@ -106,8 +143,84 @@ async def initialize_pool(self): self.pool = None raise + async def _warmup_pool(self): + """Validates the connection pool by executing a simple query.""" + if self.pool is None: + return + try: + async with self.pool.acquire() as conn: + async with conn.cursor() as cursor: + await cursor.execute("SELECT 1") + await cursor.fetchone() + logger.debug("Pool warmup successful - connection validated.") + except Exception as e: + logger.warning(f"Pool warmup query failed: {e}") + + async def initialize_multiple_pools(self): + """Initialize multiple database connection pools.""" + global _embedding_semaphore + + # Initialize embedding semaphore for rate limiting + if EMBEDDING_PROVIDER is not None and _embedding_semaphore is None: + _embedding_semaphore = asyncio.Semaphore(EMBEDDING_MAX_CONCURRENT) + logger.info(f"Embedding rate limiter initialized (max concurrent: {EMBEDDING_MAX_CONCURRENT})") + + logger.info(f"Initializing {len(DB_HOSTS)} database connection pools...") + + for i, host in enumerate(DB_HOSTS): + port = 
DB_PORTS[i] if i < len(DB_PORTS) else 3306 + user = DB_USERS[i] if i < len(DB_USERS) else None + password = DB_PASSWORDS[i] if i < len(DB_PASSWORDS) else None + db_name = DB_NAMES[i] if i < len(DB_NAMES) else None + charset = DB_CHARSETS[i] if i < len(DB_CHARSETS) and DB_CHARSETS[i] else None + + if not all([user, password]): + logger.warning(f"Skipping pool {i}: missing credentials for {host}") + continue + + conn_name = f"{host}:{port}" + try: + pool_params = { + "host": host, + "port": port, + "user": user, + "password": password, + "db": db_name, + "minsize": 1, + "maxsize": MCP_MAX_POOL_SIZE, + "autocommit": self.autocommit, + "pool_recycle": 3600, + "connect_timeout": DB_CONNECT_TIMEOUT, + } + if charset: + pool_params["charset"] = charset + + self.pools[conn_name] = await asyncmy.create_pool(**pool_params) + logger.info(f"Pool '{conn_name}' initialized for {user}@{host}:{port}/{db_name}") + + # Set first successful pool as default + if self.pool is None: + self.pool = self.pools[conn_name] + await self._warmup_pool() + logger.info(f"Default pool set to '{conn_name}'") + except Exception as e: + logger.error(f"Failed to initialize pool for {conn_name}: {e}", exc_info=True) + async def close_pool(self): """Closes the connection pool gracefully.""" + # Close multiple pools + if self.pools: + logger.info(f"Closing {len(self.pools)} database connection pools...") + for conn_name, pool in self.pools.items(): + try: + pool.close() + await pool.wait_closed() + logger.info(f"Pool '{conn_name}' closed.") + except Exception as e: + logger.error(f"Error closing pool '{conn_name}': {e}", exc_info=True) + self.pools.clear() + self.pool = None # Prevent double-close; default pool was already closed above + if self.pool: logger.info("Closing database connection pool...") try: @@ -119,65 +232,83 @@ async def close_pool(self): finally: self.pool = None - async def _execute_query(self, sql: str, params: Optional[tuple] = None, database: Optional[str] = None) -> List[Dict[str, 
Any]]: - """Helper function to execute SELECT queries using the pool.""" + async def _execute_query(self, sql: str, params: Optional[tuple] = None, database: Optional[str] = None, limit_results: bool = True) -> List[Dict[str, Any]]: + """Helper function to execute SELECT queries using the pool. + + Args: + sql: The SQL query to execute + params: Optional tuple of parameters for parameterized queries + database: Optional database to switch to before executing + limit_results: If True, limits results to MCP_MAX_RESULTS (default True) + + Returns: + List of result dictionaries + + Raises: + RuntimeError: If pool not available or database error + PermissionError: If query blocked by read-only mode + """ if self.pool is None: logger.error("Connection pool is not initialized.") raise RuntimeError("Database connection pool not available.") allowed_prefixes = ('SELECT', 'SHOW', 'DESC', 'DESCRIBE', 'USE') - + # Strip SQL comments from query # Remove single-line comments (-- comment) sql_no_comments = re.sub(r'--.*?$', '', sql, flags=re.MULTILINE) # Remove multi-line comments (/* comment */) sql_no_comments = re.sub(r'/\*.*?\*/', '', sql_no_comments, flags=re.DOTALL) sql_no_comments = sql_no_comments.strip() - + query_upper = sql_no_comments.upper() is_allowed_read_query = any(query_upper.startswith(prefix) for prefix in allowed_prefixes) if self.is_read_only and not is_allowed_read_query: - logger.warning(f"Blocked potentially non-read-only query in read-only mode: {sql[:100]}...") - raise PermissionError("Operation forbidden: Server is in read-only mode.") + logger.warning(f"Blocked potentially non-read-only query in read-only mode: {sql[:100]}...") + raise PermissionError("Operation forbidden: Server is in read-only mode.") logger.info(f"Executing query (DB: {database or DB_NAME}): {sql[:100]}...") if params: logger.debug(f"Parameters: {params}") conn = None + start_time = time.time() try: + self._metrics["pool_acquisitions"] += 1 async with self.pool.acquire() as conn: 
async with conn.cursor(cursor=asyncmy.cursors.DictCursor) as cursor: - current_db_query = "SELECT DATABASE()" - await cursor.execute(current_db_query) - current_db_result = await cursor.fetchone() - current_db_name = current_db_result.get('DATABASE()') if current_db_result else None - pool_db_name = DB_NAME - actual_current_db = current_db_name or pool_db_name - - if database and database != actual_current_db: - logger.info(f"Switching database context from '{actual_current_db}' to '{database}'") + # Only switch database context if explicitly requested + # This avoids unnecessary SELECT DATABASE() calls + if database: await cursor.execute(f"USE `{database}`") await cursor.execute(sql, params or ()) results = await cursor.fetchall() - logger.info(f"Query executed successfully, {len(results)} rows returned.") + + # Apply result limit for safety (prevent memory issues with large results) + if limit_results and results and len(results) > MCP_MAX_RESULTS: + logger.warning(f"Query returned {len(results)} rows, limiting to {MCP_MAX_RESULTS}") + results = results[:MCP_MAX_RESULTS] + + elapsed_ms = (time.time() - start_time) * 1000 + self._metrics["queries_executed"] += 1 + self._metrics["total_query_time_ms"] += elapsed_ms + logger.info(f"Query executed successfully, {len(results)} rows returned in {elapsed_ms:.1f}ms.") return results if results else [] except AsyncMyError as e: + self._metrics["query_errors"] += 1 conn_state = f"Connection: {'acquired' if conn else 'not acquired'}" logger.error(f"Database error executing query ({conn_state}): {e}", exc_info=True) - # Check for specific connection-related errors if possible raise RuntimeError(f"Database error: {e}") from e except PermissionError as e: - logger.warning(f"Permission denied: {e}") - raise e + logger.warning(f"Permission denied: {e}") + raise e except Exception as e: - # Catch potential loop closed errors here too, although ideally fixed by structure change + self._metrics["query_errors"] += 1 if isinstance(e, 
RuntimeError) and 'Event loop is closed' in str(e): - logger.critical("Detected closed event loop during query execution!", exc_info=True) - # This indicates a fundamental problem with loop management still exists - raise RuntimeError("Event loop closed unexpectedly during query.") from e + logger.critical("Detected closed event loop during query execution!", exc_info=True) + raise RuntimeError("Event loop closed unexpectedly during query.") from e conn_state = f"Connection: {'acquired' if conn else 'not acquired'}" logger.error(f"Unexpected error during query execution ({conn_state}): {e}", exc_info=True) raise RuntimeError(f"An unexpected error occurred: {e}") from e @@ -684,13 +815,21 @@ async def delete_vector_store(self, "vector_store_name": vector_store_name } - async def insert_docs_vector_store(self, database_name: str, vector_store_name: str, documents: List[str], metadata: Optional[List[dict]] = None) -> dict: + async def insert_docs_vector_store(self, database_name: str, vector_store_name: str, documents: List[str], metadata: Optional[List[dict]] = None, batch_size: int = 100) -> dict: """ Insert a batch of documents (with optional metadata) into a vector store. Documents must be a non-empty list of strings. Metadata, if provided, must be a list of dicts of the same length as documents. If metadata is not provided, an empty dict will be used for each document. 
+ + Args: + database_name: Target database + vector_store_name: Target vector store table + documents: List of document strings to insert + metadata: Optional list of metadata dicts (same length as documents) + batch_size: Number of documents to insert per batch (default 100) """ - import json + logger.info(f"TOOL START: insert_docs_vector_store called for {database_name}.{vector_store_name} with {len(documents)} documents") + if not database_name or not database_name.isidentifier(): logger.error(f"Invalid database_name: '{database_name}'") raise ValueError(f"Invalid database_name: '{database_name}'") @@ -700,46 +839,78 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents): logger.error("'documents' must be a non-empty list of non-empty strings.") raise ValueError("'documents' must be a non-empty list of non-empty strings.") + # Handle metadata: optional if metadata is None: metadata = [{} for _ in documents] if not isinstance(metadata, list) or len(metadata) != len(documents): logger.error("'metadata' must be a list of dicts, same length as documents (or omitted).") raise ValueError("'metadata' must be a list of dicts, same length as documents (or omitted).") - # Generate embeddings - embeddings = await embedding_service.embed(documents) - # Prepare metadata JSON - metadata_json = [json.dumps(m) for m in metadata] - # Prepare values for batch insert - insert_query = f"INSERT INTO `{database_name}`.`{vector_store_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)" + inserted = 0 errors = [] - for doc, emb, meta in zip(documents, embeddings, metadata_json): - emb_str = json.dumps(emb) + + # Process in batches for better performance + for batch_start in range(0, len(documents), batch_size): + batch_end = min(batch_start + batch_size, len(documents)) + batch_docs = documents[batch_start:batch_end] + 
batch_meta = metadata[batch_start:batch_end] + try: - await self._execute_query(insert_query, params=(doc, emb_str, meta), database=database_name) - inserted += 1 + # Generate embeddings with rate limiting + if embedding_service is None: + raise RuntimeError("Embedding service not initialized. Ensure EMBEDDING_PROVIDER is configured.") + if _embedding_semaphore: + async with _embedding_semaphore: + embeddings = await embedding_service.embed(batch_docs) + self._metrics["embeddings_generated"] += len(batch_docs) + else: + embeddings = await embedding_service.embed(batch_docs) + self._metrics["embeddings_generated"] += len(batch_docs) + + # Prepare metadata JSON + metadata_json = [json.dumps(m) for m in batch_meta] + + # Build batch INSERT query for better performance + insert_query = f"INSERT INTO `{database_name}`.`{vector_store_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)" + + # Insert each document (MariaDB doesn't support batch vector inserts well) + for doc, emb, meta in zip(batch_docs, embeddings, metadata_json): + emb_str = json.dumps(emb) + try: + await self._execute_query(insert_query, params=(doc, emb_str, meta), database=database_name, limit_results=False) + inserted += 1 + except Exception as e: + logger.error(f"Failed to insert doc into {database_name}.{vector_store_name}: {e}") + errors.append(str(e)) + except Exception as e: - logger.error(f"Failed to insert doc into {database_name}.{vector_store_name}: {e}", exc_info=True) - errors.append(str(e)) - logger.info(f"Inserted {inserted} documents into {database_name}.{vector_store_name} (errors: {len(errors)})") - result = {"status": "success" if inserted == len(documents) else "partial", "inserted": inserted} + logger.error(f"Failed to process batch {batch_start}-{batch_end}: {e}", exc_info=True) + errors.append(f"Batch {batch_start}-{batch_end}: {str(e)}") + + logger.info(f"TOOL END: insert_docs_vector_store. 
Inserted {inserted}/{len(documents)} documents (errors: {len(errors)})") + result: Dict[str, Any] = {"status": "success" if inserted == len(documents) else "partial", "inserted": inserted, "total": len(documents)} if errors: - result["errors"] = errors + result["errors"] = errors[:10] # Limit error messages to avoid huge responses + if len(errors) > 10: + result["errors_truncated"] = len(errors) - 10 return result async def search_vector_store(self, user_query: str, database_name: str, vector_store_name: str, k: int = 7) -> list: """ Search a vector store for the most similar documents to a query using semantic search. - Parameters: - user_query (str): The search query string. - database_name (str): The database name. - vector_store_name (str): The vector store (table) name. - k (int, optional): Number of top results to retrieve (default 7). + + Args: + user_query: The search query string. + database_name: The database name. + vector_store_name: The vector store (table) name. + k: Number of top results to retrieve (default 7). + Returns: List of dicts with document, metadata, and distance. """ - import json + logger.info(f"TOOL START: search_vector_store called for {database_name}.{vector_store_name}") + # Input validation if not user_query or not isinstance(user_query, str): logger.error("user_query must be a non-empty string.") @@ -753,12 +924,23 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ if not isinstance(k, int) or k <= 0: logger.error("k must be a positive integer.") raise ValueError("k must be a positive integer.") - # Generate embedding for the query - embedding = await embedding_service.embed(user_query) + + # Generate embedding for the query with rate limiting + if embedding_service is None: + raise RuntimeError("Embedding service not initialized. 
Ensure EMBEDDING_PROVIDER is configured.") + if _embedding_semaphore: + async with _embedding_semaphore: + embedding = await embedding_service.embed(user_query) + self._metrics["embeddings_generated"] += 1 + else: + embedding = await embedding_service.embed(user_query) + self._metrics["embeddings_generated"] += 1 + emb_str = json.dumps(embedding) + # Prepare the search query search_query = f""" - SELECT + SELECT document, metadata, VEC_DISTANCE_COSINE(embedding, VEC_FromText(%s)) AS distance @@ -767,18 +949,20 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ LIMIT %s """ try: - results = await self._execute_query(search_query, params=(emb_str, k), database=database_name) + results = await self._execute_query(search_query, params=(emb_str, k), database=database_name, limit_results=False) for row in results: if isinstance(row.get('metadata'), str): try: row['metadata'] = json.loads(row['metadata']) - except Exception: - pass - logger.info(f"Semantic search in {database_name}.{vector_store_name} returned {len(results)} results.") + except json.JSONDecodeError as e: + raw_meta = row.get('metadata') or '' + logger.warning(f"Failed to parse metadata JSON for document: {e}. Raw value: {raw_meta[:100]}...") + # Keep raw string if parsing fails + logger.info(f"TOOL END: search_vector_store. 
Returned {len(results)} results.") return results except Exception as e: logger.error(f"Failed to search vector store {database_name}.{vector_store_name}: {e}", exc_info=True) - return [] + raise RuntimeError(f"Vector store search failed: {e}") from e # --- Tool Registration (Synchronous) --- def register_tools(self): @@ -845,6 +1029,41 @@ async def search_vector_store(user_query: str, database_name: str, vector_store_ logger.info("Registered MCP tools explicitly.") + # Register /health endpoint for HTTP/SSE transports + self.mcp.custom_route("/health", methods=["GET"])(self._health_endpoint) + logger.info("Registered /health endpoint.") + + def get_health(self) -> Dict[str, Any]: + """Returns health check information for the server.""" + uptime_seconds = time.time() - self._start_time + pool_status = "connected" if self.pool is not None else "disconnected" + + # Calculate average query time + avg_query_time = 0 + if self._metrics["queries_executed"] > 0: + avg_query_time = self._metrics["total_query_time_ms"] / self._metrics["queries_executed"] + + return { + "status": "healthy" if self.pool is not None else "unhealthy", + "uptime_seconds": round(uptime_seconds, 2), + "pool_status": pool_status, + "read_only_mode": self.is_read_only, + "embedding_provider": EMBEDDING_PROVIDER, + "metrics": { + "queries_executed": self._metrics["queries_executed"], + "query_errors": self._metrics["query_errors"], + "avg_query_time_ms": round(avg_query_time, 2), + "embeddings_generated": self._metrics["embeddings_generated"], + "pool_acquisitions": self._metrics["pool_acquisitions"], + } + } + + async def _health_endpoint(self, request): + """Starlette endpoint handler for /health.""" + health_data = self.get_health() + status_code = 200 if health_data["status"] == "healthy" else 503 + return JSONResponse(health_data, status_code=status_code) + # --- Async Main Server Logic --- async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001, path="/mcp"): """ @@ 
-861,16 +1080,22 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001, # 3. Prepare transport arguments transport_kwargs = {} if transport != "stdio": + # Broaden CORS and include OPTIONS and credentials to accommodate + # browser-based clients and websocket upgrade flows used by some + # agent UIs. Keep TrustedHostMiddleware to limit allowed hosts. middleware = [ Middleware( CORSMiddleware, allow_origins=ALLOWED_ORIGINS, - allow_methods=["GET", "POST"], + allow_methods=["GET", "POST", "OPTIONS"], allow_headers=["*"], + allow_credentials=True, + expose_headers=["*"], ), - Middleware(TrustedHostMiddleware, + Middleware(TrustedHostMiddleware, allowed_hosts=ALLOWED_HOSTS) ] + if transport == "sse": transport_kwargs = {"host": host, "port": port, "middleware": middleware} logger.info(f"Starting MCP server via {transport} on {host}:{port}...") @@ -878,10 +1103,10 @@ async def run_async_server(self, transport="stdio", host="127.0.0.1", port=9001, transport_kwargs = {"host": host, "port": port, "path": path, "middleware": middleware} logger.info(f"Starting MCP server via {transport} on {host}:{port}{path}...") elif transport == "stdio": - logger.info(f"Starting MCP server via {transport}...") + logger.info(f"Starting MCP server via {transport}...") else: - logger.error(f"Unsupported transport type: {transport}") - return + logger.error(f"Unsupported transport type: {transport}") + return # 4. Run the appropriate async listener from FastMCP await self.mcp.run_async(transport=transport, **transport_kwargs) diff --git a/src/tests/test_mariadb_mcp_tools.py b/src/tests/test_mariadb_mcp_tools.py index 4123cee..78c863e 100644 --- a/src/tests/test_mariadb_mcp_tools.py +++ b/src/tests/test_mariadb_mcp_tools.py @@ -63,6 +63,16 @@ async def test_step_1_list_databases(self): if __name__ == "__main__": unittest.main() +def test_step_1_list_databases(): + """ + Test: Call mcp0_list_databases. + Purpose: Verify it returns a list of database names. 
+ Expected Outcome: Success, returns a JSON list of strings (database names). + Result: PASSED (as documented in TestMariaDBMCPTools.test_step_1_list_databases) + """ + print("Executing: mcp0_list_databases()") + # Manual execution via Cascade passed. + def test_step_2_list_tables_valid_db(): """ Test: Call mcp0_list_tables with a valid database ('information_schema'). diff --git a/src/tests/test_mcp_server.py b/src/tests/test_mcp_server.py index 76f184e..b79d790 100644 --- a/src/tests/test_mcp_server.py +++ b/src/tests/test_mcp_server.py @@ -9,9 +9,12 @@ # It tests the server's tools using the FastMCP client class TestMariaDBMCPTools(unittest.IsolatedAsyncioTestCase): + # Test database name used for cleanup + TEST_DB_NAME = "test_database" + async def asyncSetUp(self): # Start the MariaDBServer in the background using stdio transport - self.server = MariaDBServer(autocommit=False) + self.server = MariaDBServer() async def task_group_helper(self, tg): # Start the server as a background task @@ -218,11 +221,11 @@ async def test_search_vector_store(self): await self.client.call_tool('create_database', {'database_name': 'test_database'}) await self.client.call_tool('create_vector_store', {'database_name': 'test_database', 'vector_store_name': 'test_vector_store'}) await self.client.call_tool('insert_docs_vector_store', {'database_name': 'test_database', 'vector_store_name': 'test_vector_store', 'documents': ['test_document'], 'metadata': [{'test': 'test'}]}) - result = await self.client.call_tool('search_vector_store', {'database_name': 'test_database', 'vector_store_name': 'test_vector_store', 'query': 'test_query'}) + result = await self.client.call_tool('search_vector_store', {'database_name': 'test_database', 'vector_store_name': 'test_vector_store', 'user_query': 'test_query'}) result = result[0].text result = json.loads(result) - self.assertIsInstance(result, dict) - self.assertTrue(result['status'] == 'success') + self.assertIsInstance(result, list) + 
self.assertGreater(len(result), 0) tg.cancel_scope.cancel() async def test_readonly_mode(self):