Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 51 additions & 22 deletions src/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,38 @@
import ssl

import asyncmy
import anyio
import anyio
from fastmcp import FastMCP, Context

# --- Helper Functions ---

def is_valid_mariadb_identifier(name: str, max_length: int = 64) -> bool:
"""
Validates a MariaDB/MySQL identifier (database name, table name, etc.).

MariaDB identifiers can contain alphanumerics, underscores, dashes, and dollar signs.

Args:
name: The identifier to validate
max_length: Maximum allowed length (default 64, per MariaDB limit)

Returns:
True if valid, False otherwise
"""
if not name or not isinstance(name, str):
return False

# Check length (MariaDB limit is 64 characters)
if len(name) > max_length:
return False

# Allow: alphanumerics, underscore, dash, dollar sign
# Disallow: backtick, null byte, and other special chars
if not re.match(r'^[a-zA-Z0-9_\-$]+$', name):
return False

return True

# Import custom connection pool that disables MULTI_STATEMENTS
from custom_connection import create_safe_pool

Expand Down Expand Up @@ -230,7 +259,7 @@ async def _execute_query(self, sql: str, params: Optional[tuple] = None, databas

async def _database_exists(self, database_name: str) -> bool:
"""Checks if a database exists."""
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.warning(f"_database_exists called with invalid database_name: {database_name}")
return False

Expand All @@ -244,8 +273,8 @@ async def _database_exists(self, database_name: str) -> bool:

async def _table_exists(self, database_name: str, table_name: str) -> bool:
"""Checks if a table exists in the given database."""
if not database_name or not database_name.isidentifier() or \
not table_name or not table_name.isidentifier():
if not is_valid_mariadb_identifier(database_name) or \
not is_valid_mariadb_identifier(table_name):
logger.warning(f"_table_exists called with invalid names: db='{database_name}', table='{table_name}'")
return False

Expand All @@ -272,8 +301,8 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool:
"""
logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.")

if not database_name or not database_name.isidentifier() or \
not table_name or not table_name.isidentifier():
if not is_valid_mariadb_identifier(database_name) or \
not is_valid_mariadb_identifier(table_name):
logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'")
return False

Expand Down Expand Up @@ -321,7 +350,7 @@ async def list_databases(self) -> List[str]:
async def list_tables(self, database_name: str) -> List[str]:
"""Lists all tables within the specified database."""
logger.info(f"TOOL START: list_tables called. database_name={database_name}")
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
sql = "SHOW TABLES"
Expand All @@ -340,10 +369,10 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st
for a specific table in a database.
"""
logger.info(f"TOOL START: get_table_schema called. database_name={database_name}, table_name={table_name}")
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
if not table_name or not table_name.isidentifier():
if not is_valid_mariadb_identifier(table_name):
logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}")
raise ValueError(f"Invalid table name provided: {table_name}")

Expand Down Expand Up @@ -385,10 +414,10 @@ async def get_table_schema_with_relations(self, database_name: str, table_name:
Includes all basic schema info plus foreign key relationships and referenced tables.
"""
logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}")
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
if not table_name or not table_name.isidentifier():
if not is_valid_mariadb_identifier(table_name):
logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}")
raise ValueError(f"Invalid table name provided: {table_name}")

Expand Down Expand Up @@ -456,7 +485,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti
Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`.
"""
logger.info(f"TOOL START: execute_sql called. database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}")
if database_name and not database_name.isidentifier():
if database_name and not is_valid_mariadb_identifier(database_name):
logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}")
raise ValueError(f"Invalid database name provided: {database_name}")
param_tuple = tuple(parameters) if parameters is not None else None
Expand All @@ -473,7 +502,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]:
Creates a new database if it doesn't exist.
"""
logger.info(f"TOOL START: create_database called for database: '{database_name}'")
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.")

Expand Down Expand Up @@ -522,10 +551,10 @@ async def create_vector_store_tool(self,
logger.info(f"TOOL START: create_vector_store called. DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'")

# --- Input Validation ---
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
if not vector_store_name or not vector_store_name.isidentifier():
if not is_valid_mariadb_identifier(vector_store_name):
logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")

Expand Down Expand Up @@ -617,7 +646,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]:
logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'")

# --- Input Validation ---
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")

Expand Down Expand Up @@ -681,10 +710,10 @@ async def delete_vector_store(self,
logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'")

# --- Input Validation for names ---
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.")
if not vector_store_name or not vector_store_name.isidentifier():
if not is_valid_mariadb_identifier(vector_store_name):
logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.")

Expand Down Expand Up @@ -737,10 +766,10 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name:
If metadata is not provided, an empty dict will be used for each document.
"""
import json
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name: '{database_name}'")
raise ValueError(f"Invalid database_name: '{database_name}'")
if not vector_store_name or not vector_store_name.isidentifier():
if not is_valid_mariadb_identifier(vector_store_name):
logger.error(f"Invalid vector_store_name: '{vector_store_name}'")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'")
if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents):
Expand Down Expand Up @@ -790,10 +819,10 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_
if not user_query or not isinstance(user_query, str):
logger.error("user_query must be a non-empty string.")
raise ValueError("user_query must be a non-empty string.")
if not database_name or not database_name.isidentifier():
if not is_valid_mariadb_identifier(database_name):
logger.error(f"Invalid database_name: '{database_name}'")
raise ValueError(f"Invalid database_name: '{database_name}'")
if not vector_store_name or not vector_store_name.isidentifier():
if not is_valid_mariadb_identifier(vector_store_name):
logger.error(f"Invalid vector_store_name: '{vector_store_name}'")
raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'")
if not isinstance(k, int) or k <= 0:
Expand Down