From fa9056b1bfa129409c11720e27e1fe305bd698a6 Mon Sep 17 00:00:00 2001 From: Konstantinos Ziovas Date: Mon, 27 Oct 2025 12:37:29 +0200 Subject: [PATCH] feat: Add support for passing a pre-created SQLAlchemy engine to Client --- src/vecs/__init__.py | 6 ++++-- src/vecs/client.py | 21 ++++++++++++++++----- 2 files changed, 20 insertions(+), 7 deletions(-) diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index 6f1c7f3..dcae534 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -1,3 +1,5 @@ +from sqlalchemy import Engine + from vecs import exc from vecs.client import Client from vecs.collection import ( @@ -23,6 +25,6 @@ ] -def create_client(connection_string: str) -> Client: +def create_client(connection_string: str = None, engine: Engine = None) -> Client: """Creates a client from a Postgres connection string""" - return Client(connection_string) + return Client(connection_string=connection_string, engine=engine) diff --git a/src/vecs/client.py b/src/vecs/client.py index 89bb3e3..29c42fe 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -10,7 +10,7 @@ from typing import TYPE_CHECKING, List, Optional from deprecated import deprecated -from sqlalchemy import MetaData, create_engine, text +from sqlalchemy import Engine, MetaData, create_engine, text from sqlalchemy.orm import sessionmaker from vecs.adapter import Adapter @@ -47,17 +47,28 @@ class Client: vx.disconnect() """ - def __init__(self, connection_string: str): + def __init__(self, connection_string: str = None, engine: Engine = None): """ Initialize a Client instance. Args: - connection_string (str): A string representing the database connection information. + connection_string (str, optional): Database connection string. Required if engine is not provided. + engine (Engine, optional): Pre-created SQLAlchemy engine. If provided, connection_string is ignored. Returns: None + Raises: + ValueError: If neither connection_string nor engine is provided. """ - self.engine = create_engine(connection_string) + if engine is not None: + self.engine = engine + elif connection_string is not None: + self.engine = create_engine(connection_string) + else: + raise ValueError( + "Either a connection_string or an engine must be provided." + ) + self.meta = MetaData(schema="vecs") self.Session = sessionmaker(self.engine) @@ -153,7 +164,7 @@ def get_collection(self, name: str) -> Collection: from vecs.collection import Collection query = text( - f""" + """ select relname as table_name, atttypmod as embedding_dim