def is_valid_mariadb_identifier(name: str, max_length: int = 64) -> bool:
    """
    Validates a MariaDB/MySQL identifier (database name, table name, etc.).

    Accepted names are non-empty strings of at most ``max_length`` characters
    made up solely of ASCII letters, digits, underscores, dashes, and dollar
    signs. Everything else (backticks, whitespace, newlines, null bytes,
    quotes, ...) is rejected.

    Args:
        name: The identifier to validate
        max_length: Maximum allowed length (default 64, per MariaDB limit)

    Returns:
        True if valid, False otherwise
    """
    # Type check first so non-string inputs (None, ints, lists, ...) are
    # rejected before any string operation is attempted.
    if not isinstance(name, str) or not name:
        return False

    # Check length (MariaDB limit is 64 characters)
    if len(name) > max_length:
        return False

    # Allow: alphanumerics, underscore, dash, dollar sign
    # Disallow: backtick, null byte, and other special chars
    #
    # fullmatch() is used instead of match() with '^...$' because '$' also
    # matches just before a trailing newline, which would let "name\n" pass.
    #
    # NOTE(review): names containing '-' presumably need backtick-quoting when
    # interpolated into SQL -- confirm that callers quote identifiers.
    return re.fullmatch(r'[a-zA-Z0-9_\-$]+', name) is not None
called with invalid names: db='{database_name}', table='{table_name}'") return False @@ -272,8 +301,8 @@ async def _is_vector_store(self, database_name: str, table_name: str) -> bool: """ logger.debug(f"Checking if '{database_name}.{table_name}' is a vector store.") - if not database_name or not database_name.isidentifier() or \ - not table_name or not table_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name) or \ + not is_valid_mariadb_identifier(table_name): logger.warning(f"_is_vector_store called with invalid names: db='{database_name}', table='{table_name}'") return False @@ -321,7 +350,7 @@ async def list_databases(self) -> List[str]: async def list_tables(self, database_name: str) -> List[str]: """Lists all tables within the specified database.""" logger.info(f"TOOL START: list_tables called. database_name={database_name}") - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.warning(f"TOOL WARNING: list_tables called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") sql = "SHOW TABLES" @@ -340,10 +369,10 @@ async def get_table_schema(self, database_name: str, table_name: str) -> Dict[st for a specific table in a database. """ logger.info(f"TOOL START: get_table_schema called. 
database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.warning(f"TOOL WARNING: get_table_schema called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): + if not is_valid_mariadb_identifier(table_name): logger.warning(f"TOOL WARNING: get_table_schema called with invalid table_name: {table_name}") raise ValueError(f"Invalid table name provided: {table_name}") @@ -385,10 +414,10 @@ async def get_table_schema_with_relations(self, database_name: str, table_name: Includes all basic schema info plus foreign key relationships and referenced tables. """ logger.info(f"TOOL START: get_table_schema_with_relations called. database_name={database_name}, table_name={table_name}") - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") - if not table_name or not table_name.isidentifier(): + if not is_valid_mariadb_identifier(table_name): logger.warning(f"TOOL WARNING: get_table_schema_with_relations called with invalid table_name: {table_name}") raise ValueError(f"Invalid table name provided: {table_name}") @@ -456,7 +485,7 @@ async def execute_sql(self, sql_query: str, database_name: str, parameters: Opti Example `parameters`: ["value1", 123] corresponding to %s placeholders in `sql_query`. """ logger.info(f"TOOL START: execute_sql called. 
database_name={database_name}, sql_query={sql_query[:100]}, parameters={parameters}") - if database_name and not database_name.isidentifier(): + if database_name and not is_valid_mariadb_identifier(database_name): logger.warning(f"TOOL WARNING: execute_sql called with invalid database_name: {database_name}") raise ValueError(f"Invalid database name provided: {database_name}") param_tuple = tuple(parameters) if parameters is not None else None @@ -473,7 +502,7 @@ async def create_database(self, database_name: str) -> Dict[str, Any]: Creates a new database if it doesn't exist. """ logger.info(f"TOOL START: create_database called for database: '{database_name}'") - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name for creation: '{database_name}'. Must be a valid identifier.") @@ -522,10 +551,10 @@ async def create_vector_store_tool(self, logger.info(f"TOOL START: create_vector_store called. DB: '{database_name}', Store: '{vector_store_name}', Model: '{model_name}', Embedding_Length: {embedding_length}, Distance_Requested: '{distance_function}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): + if not is_valid_mariadb_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. 
Must be a valid identifier.") @@ -617,7 +646,7 @@ async def list_vector_stores(self, database_name: str) -> List[str]: logger.info(f"TOOL START: list_vector_stores called for database: '{database_name}'") # --- Input Validation --- - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") @@ -681,10 +710,10 @@ async def delete_vector_store(self, logger.info(f"TOOL START: delete_vector_store called for: '{database_name}.{vector_store_name}'") # --- Input Validation for names --- - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") raise ValueError(f"Invalid database_name: '{database_name}'. Must be a valid identifier.") - if not vector_store_name or not vector_store_name.isidentifier(): + if not is_valid_mariadb_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'. Must be a valid identifier.") @@ -737,10 +766,10 @@ async def insert_docs_vector_store(self, database_name: str, vector_store_name: If metadata is not provided, an empty dict will be used for each document. 
""" import json - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'") raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): + if not is_valid_mariadb_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") if not isinstance(documents, list) or not documents or not all(isinstance(doc, str) and doc for doc in documents): @@ -790,10 +819,10 @@ async def search_vector_store(self, user_query: str, database_name: str, vector_ if not user_query or not isinstance(user_query, str): logger.error("user_query must be a non-empty string.") raise ValueError("user_query must be a non-empty string.") - if not database_name or not database_name.isidentifier(): + if not is_valid_mariadb_identifier(database_name): logger.error(f"Invalid database_name: '{database_name}'") raise ValueError(f"Invalid database_name: '{database_name}'") - if not vector_store_name or not vector_store_name.isidentifier(): + if not is_valid_mariadb_identifier(vector_store_name): logger.error(f"Invalid vector_store_name: '{vector_store_name}'") raise ValueError(f"Invalid vector_store_name: '{vector_store_name}'") if not isinstance(k, int) or k <= 0: