Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .env.example
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
# Databricks API configuration
# Databricks API configuration (REQUIRED)
DATABRICKS_HOST=https://adb-xxxxxxxxxxxx.xx.azuredatabricks.net
DATABRICKS_TOKEN=your_databricks_token_here
DATABRICKS_TOKEN=dapi_your_real_token_here

# Server configuration
SERVER_HOST=0.0.0.0
SERVER_HOST=127.0.0.1
SERVER_PORT=8000
DEBUG=False

Expand Down
54 changes: 31 additions & 23 deletions src/core/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,22 +2,23 @@
Configuration settings for the Databricks MCP server.
"""

import logging
import os
import sys
from typing import Any, Dict, Optional

# Import dotenv if available, but don't require it
try:
from dotenv import load_dotenv
# Load .env file if it exists
load_dotenv()
print("Successfully loaded dotenv")
except ImportError:
print("WARNING: python-dotenv not found, environment variables must be set manually")
# We'll just rely on OS environment variables being set manually
print("WARNING: python-dotenv not found, environment variables must be set manually", file=sys.stderr)

from pydantic import field_validator
from pydantic import SecretStr, field_validator
from pydantic_settings import BaseSettings

logger = logging.getLogger(__name__)

# Version
VERSION = "0.1.0"

Expand All @@ -26,25 +27,36 @@ class Settings(BaseSettings):
"""Base settings for the application."""

# Databricks API configuration
DATABRICKS_HOST: str = os.environ.get("DATABRICKS_HOST", "https://example.databricks.net")
DATABRICKS_TOKEN: str = os.environ.get("DATABRICKS_TOKEN", "dapi_token_placeholder")
DATABRICKS_HOST: str
DATABRICKS_TOKEN: SecretStr

# Server configuration
SERVER_HOST: str = os.environ.get("SERVER_HOST", "0.0.0.0")
SERVER_PORT: int = int(os.environ.get("SERVER_PORT", "8000"))
DEBUG: bool = os.environ.get("DEBUG", "False").lower() == "true"
SERVER_HOST: str = "127.0.0.1"
SERVER_PORT: int = 8000
DEBUG: bool = False

# Logging
LOG_LEVEL: str = os.environ.get("LOG_LEVEL", "INFO")
LOG_LEVEL: str = "INFO"

# Version
VERSION: str = VERSION

@field_validator("DATABRICKS_HOST")
def validate_databricks_host(cls, v: str) -> str:
"""Validate Databricks host URL."""
if not v.startswith(("https://", "http://")):
raise ValueError("DATABRICKS_HOST must start with http:// or https://")
"""Validate Databricks host URL. Only HTTPS is allowed."""
if not v.startswith("https://"):
raise ValueError("DATABRICKS_HOST must start with https://")
return v.rstrip("/")

@field_validator("DATABRICKS_TOKEN")
def validate_databricks_token(cls, v: SecretStr) -> SecretStr:
"""Validate that the token is not a placeholder."""
token_value = v.get_secret_value()
if not token_value or token_value in ("dapi_token_placeholder", "your_databricks_token_here"):
raise ValueError(
"DATABRICKS_TOKEN must be set to a valid token. "
"Check your .env file or environment variables."
)
return v

class Config:
Expand All @@ -61,26 +73,22 @@ class Config:
def get_api_headers() -> Dict[str, str]:
"""Get headers for Databricks API requests."""
return {
"Authorization": f"Bearer {settings.DATABRICKS_TOKEN}",
"Authorization": f"Bearer {settings.DATABRICKS_TOKEN.get_secret_value()}",
"Content-Type": "application/json",
}


def get_databricks_api_url(endpoint: str) -> str:
"""
Construct the full Databricks API URL.

Args:
endpoint: The API endpoint path, e.g., "/api/2.0/clusters/list"

Returns:
Full URL to the Databricks API endpoint
"""
# Ensure endpoint starts with a slash
if not endpoint.startswith("/"):
endpoint = f"/{endpoint}"

# Remove trailing slash from host if present
host = settings.DATABRICKS_HOST.rstrip("/")

return f"{host}{endpoint}"
return f"{settings.DATABRICKS_HOST}{endpoint}"
48 changes: 25 additions & 23 deletions src/core/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,40 +29,43 @@ def __init__(self, message: str, status_code: Optional[int] = None, response: Op
super().__init__(self.message)


REQUEST_TIMEOUT_SECONDS = 30


def make_api_request(
method: str,
endpoint: str,
data: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
files: Optional[Dict[str, Any]] = None,
timeout: int = REQUEST_TIMEOUT_SECONDS,
) -> Dict[str, Any]:
"""
Make a request to the Databricks API.

Args:
method: HTTP method ("GET", "POST", "PUT", "DELETE")
endpoint: API endpoint path
data: Request body data
params: Query parameters
files: Files to upload

timeout: Request timeout in seconds

Returns:
Response data as a dictionary

Raises:
DatabricksAPIError: If the API request fails
"""
url = get_databricks_api_url(endpoint)
headers = get_api_headers()

try:
# Log the request (omit sensitive information)
safe_data = "**REDACTED**" if data else None
logger.debug(f"API Request: {method} {url} Params: {params} Data: {safe_data}")

logger.debug(f"API Request: {method} {url} Params: {params} Data: **REDACTED**")

# Convert data to JSON string if provided
json_data = json.dumps(data) if data and not files else data

# Make the request
response = requests.request(
method=method,
Expand All @@ -71,35 +74,34 @@ def make_api_request(
params=params,
data=json_data if not files else data,
files=files,
timeout=timeout,
verify=True,
)

# Check for HTTP errors
response.raise_for_status()

# Parse response
if response.content:
return response.json()
return {}

except RequestException as e:
# Handle request exceptions
status_code = getattr(e.response, "status_code", None) if hasattr(e, "response") else None
error_msg = f"API request failed: {str(e)}"

# Try to extract error details from response

# Sanitize error message — avoid leaking internal details
error_response = None
if hasattr(e, "response") and e.response is not None:
try:
error_response = e.response.json()
error_msg = f"{error_msg} - {error_response.get('error', '')}"
except ValueError:
error_response = e.response.text

# Log the error
logger.error(f"API Error: {error_msg}", exc_info=True)

# Raise custom exception
raise DatabricksAPIError(error_msg, status_code, error_response) from e
error_response = None

safe_msg = f"Databricks API request to {endpoint} failed (status={status_code})"
logger.error(f"API Error: {safe_msg} — {str(e)}", exc_info=True)

raise DatabricksAPIError(safe_msg, status_code, error_response) from e


def format_response(
Expand Down
1 change: 0 additions & 1 deletion src/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ async def main():
# Log startup information
logger = logging.getLogger(__name__)
logger.info(f"Starting Databricks MCP server v{settings.VERSION}")
logger.info(f"Databricks host: {settings.DATABRICKS_HOST}")

# Start the MCP server
await start_mcp_server()
Expand Down
4 changes: 2 additions & 2 deletions src/server/databricks_mcp_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def __init__(self):
version="1.0.0",
instructions="Use this server to manage Databricks resources")
logger.info("Initializing Databricks MCP server")
logger.info(f"Databricks host: {settings.DATABRICKS_HOST}")
logger.info("Databricks MCP server configured successfully")

# Register tools
self._register_tools()
Expand Down Expand Up @@ -202,7 +202,7 @@ async def execute_sql(params: Dict[str, Any]) -> List[TextContent]:
catalog = params.get("catalog")
schema = params.get("schema")

result = await sql.execute_sql(statement, warehouse_id, catalog, schema)
result = await sql.execute_statement(statement, warehouse_id, catalog, schema)
return [{"text": json.dumps(result)}]
except Exception as e:
logger.error(f"Error executing SQL: {str(e)}")
Expand Down