Skip to content

Commit 96d5e9e

Browse files
committed
refactor: create_engine helper
1 parent c51f6e4 commit 96d5e9e

File tree

4 files changed

+21
-39
lines changed

4 files changed

+21
-39
lines changed

openagent/core/database/engine.py

Lines changed: 10 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,38 +1,22 @@
11
import os
2-
from typing import Optional
2+
from typing import Optional, Literal
33
from urllib.parse import urlparse
44

55
from sqlalchemy import create_engine as sa_create_engine, text as sa_text, Engine
66
from loguru import logger
77

88

9-
def _create_sqlite_engine(db_url: Optional[str] = None, db_name: Optional[str] = None, storage_dir: str = "storage") -> Engine:
9+
def _create_sqlite_engine(db_url: str) -> Engine:
1010
"""
11-
Create a SQLite engine from a URL or create a default URL if not provided.
11+
Create a SQLite engine from a URL.
1212
1313
Args:
1414
db_url: SQLite database URL (sqlite:///path/to/file.db)
15-
db_name: Database name to use if db_url is not provided
16-
storage_dir: Directory to store SQLite databases (default: 'storage')
1715
1816
Returns:
1917
SQLAlchemy engine instance
20-
21-
Raises:
22-
ValueError: If neither db_url nor db_name is provided
2318
"""
24-
if db_url:
25-
return sa_create_engine(db_url)
26-
27-
if not db_name:
28-
raise ValueError("Either db_url or db_name must be provided for SQLite")
29-
30-
# Create default SQLite path
31-
db_path = os.path.join(os.getcwd(), storage_dir, f"{db_name}.db")
32-
if not os.path.exists(os.path.dirname(db_path)):
33-
os.makedirs(os.path.dirname(db_path))
34-
35-
return sa_create_engine(f"sqlite:///{db_path}")
19+
return sa_create_engine(db_url)
3620

3721

3822
def _ensure_postgres_database_exists(db_url: str) -> None:
@@ -86,28 +70,27 @@ def _create_postgres_engine(db_url: str) -> Engine:
8670
return sa_create_engine(db_url)
8771

8872

89-
def create_engine(db_type: str = "sqlite", db_url: Optional[str] = None, db_name: Optional[str] = None, storage_dir: str = "storage") -> Engine:
73+
def create_engine(db_type: Literal["sqlite", "postgres"] = "sqlite", db_url: str = None) -> Engine:
9074
"""
9175
Create a database engine based on the provided configuration.
9276
9377
Args:
9478
db_type: Type of database ('sqlite' or 'postgres')
9579
db_url: Database URL. For postgres: postgresql://user:password@host:port/database,
9680
for sqlite: sqlite:///path/to/file.db
97-
db_name: Database name to use if db_url is not provided (SQLite only)
98-
storage_dir: Directory to store SQLite databases (default: 'storage')
9981
10082
Returns:
10183
SQLAlchemy engine instance
10284
10385
Raises:
104-
ValueError: If an unsupported database type is specified or if required parameters are missing
86+
ValueError: If an unsupported database type is specified or if db_url is missing
10587
"""
88+
if not db_url:
89+
raise ValueError("Database URL is required")
90+
10691
if db_type == "sqlite":
107-
return _create_sqlite_engine(db_url, db_name, storage_dir)
92+
return _create_sqlite_engine(db_url)
10893
elif db_type == "postgres":
109-
if not db_url:
110-
raise ValueError("Database URL is required for PostgreSQL")
11194
return _create_postgres_engine(db_url)
11295
else:
11396
raise ValueError(f"Unsupported database type: {db_type}")

openagent/tools/pendle/market_analysis.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ def _init_database(self, config: Optional[DatabaseConfig] = None) -> None:
9191
"""Initialize database connection based on configuration"""
9292
# Set default configuration if not provided
9393
db_type = "sqlite"
94-
db_url = None
94+
db_url = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
9595

9696
if config:
9797
db_type = config.type
@@ -101,7 +101,6 @@ def _init_database(self, config: Optional[DatabaseConfig] = None) -> None:
101101
engine = create_engine(
102102
db_type=db_type,
103103
db_url=db_url,
104-
db_name=self.name
105104
)
106105

107106
# Create tables and initialize session factory
@@ -266,10 +265,10 @@ def get_top_markets(markets: list[PendleMarketData], key_attr: str, n: int = 3)
266265

267266
# Combine all relevant symbols
268267
relevant_symbols = (
269-
liquidity_increase_top_symbols
270-
| new_market_liquidity_increase_symbols
271-
| apy_increase_symbols
272-
| new_market_apy_increase_symbols
268+
liquidity_increase_top_symbols
269+
| new_market_liquidity_increase_symbols
270+
| apy_increase_symbols
271+
| new_market_apy_increase_symbols
273272
)
274273

275274
# Filter markets to include only those in the relevant symbols set

openagent/tools/pendle/voter_apy_analysis.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
from sqlalchemy.orm import sessionmaker
1515

1616
from openagent.agent.config import ModelConfig
17-
from openagent.core.database import sqlite
17+
from openagent.core.database.engine import create_engine
1818
from openagent.core.tool import Tool
1919
from openagent.core.utils.fetch_json import fetch_json
2020
from openagent.core.utils.json_equal import json_equal
@@ -47,8 +47,8 @@ def __init__(self, core_model=None):
4747
self.core_model = core_model
4848
self.tool_model = None
4949
self.tool_prompt = None
50-
db_path = os.path.join(os.getcwd(), "storage", f"{self.name}.db")
51-
self.engine = sqlite.create_engine(db_path)
50+
db_path = 'sqlite:///' + os.path.join(os.getcwd(), "storage", f"{self.name}.db")
51+
self.engine = create_engine("sqlite", db_path)
5252
Base.metadata.create_all(self.engine)
5353
session = sessionmaker(bind=self.engine)
5454
self.session = session()

openagent/tools/twitter/feed.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010
from sqlalchemy import Column, Integer, String, DateTime
1111
from sqlalchemy.ext.declarative import declarative_base
1212
from sqlalchemy.orm import sessionmaker
13-
from openagent.core.database import sqlite
13+
from openagent.core.database.engine import create_engine
1414
from openagent.core.tool import Tool
1515

1616
Base = declarative_base()
@@ -80,8 +80,8 @@ def __init__(self):
8080
self.retry_delay = 1
8181

8282
# Initialize database
83-
db_path = os.path.join(os.getcwd(), "storage", f"{self.name}.db")
84-
self.engine = sqlite.create_engine(db_path)
83+
db_path = 'sqlite:///'+os.path.join(os.getcwd(), "storage", f"{self.name}.db")
84+
self.engine = create_engine('sqlite',db_path)
8585
Base.metadata.create_all(self.engine)
8686
Session = sessionmaker(bind=self.engine)
8787
self.session = Session()

0 commit comments

Comments
 (0)