From f1c2e79812299f643448042498d98a521c0178be Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 16 Jul 2025 18:49:38 +0200 Subject: [PATCH 1/3] Add async client and collection, and tests --- pytest.ini | 1 + setup.py | 3 +- src/tests/conftest.py | 9 + src/tests/test_async_client.py | 161 +++++++ src/tests/test_async_collection.py | 554 ++++++++++++++++++++++ src/vecs/__init__.py | 11 + src/vecs/adapter/text.py | 1 + src/vecs/async_client.py | 233 ++++++++++ src/vecs/async_collection.py | 723 +++++++++++++++++++++++++++++ src/vecs/client.py | 17 +- src/vecs/collection.py | 27 +- 11 files changed, 1717 insertions(+), 23 deletions(-) create mode 100644 src/tests/test_async_client.py create mode 100644 src/tests/test_async_collection.py create mode 100644 src/vecs/async_client.py create mode 100644 src/vecs/async_collection.py diff --git a/pytest.ini b/pytest.ini index d888c9e..73657f7 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,2 +1,3 @@ [pytest] addopts = src/tests +asyncio_mode = auto diff --git a/setup.py b/setup.py index 7bce72c..263ebf5 100644 --- a/setup.py +++ b/setup.py @@ -43,6 +43,7 @@ def read_package_variable(key, filename="__init__.py"): "pgvector==0.3.*", "sqlalchemy==2.*", "psycopg2-binary==2.9.*", + "asyncpg==0.29.*", "flupy==1.*", "deprecated==1.2.*", ] @@ -75,7 +76,7 @@ def read_package_variable(key, filename="__init__.py"): ], install_requires=REQUIRES, extras_require={ - "dev": ["pytest", "parse", "numpy", "pytest-cov"], + "dev": ["pytest", "parse", "numpy", "pytest-cov", "pytest-asyncio"], "docs": [ "mkdocs", "pygments", diff --git a/src/tests/conftest.py b/src/tests/conftest.py index 23ce338..dfedc17 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -7,6 +7,8 @@ from typing import Generator import pytest +import pytest_asyncio + from parse import parse from sqlalchemy import create_engine, text @@ -103,3 +105,10 @@ def clean_db(maybe_start_pg: None) -> Generator[str, None, None]: def client(clean_db: str) -> Generator[vecs.Client, None, None]: client_ = vecs.create_client(clean_db) yield client_ + + +@pytest_asyncio.fixture +async def async_client(clean_db: str): + """Create an async client for testing""" + client_ = await vecs.create_async_client(clean_db) + yield client_ diff --git a/src/tests/test_async_client.py b/src/tests/test_async_client.py new file mode 100644 index 0000000..d149a6a --- /dev/null +++ b/src/tests/test_async_client.py @@ -0,0 +1,161 @@ +import pytest + +import vecs + + +@pytest.mark.asyncio +async def test_create_async_client(clean_db: str): + """Test creating an async client""" + # Convert regular connection to async + async_db = clean_db.replace("postgresql://", "postgresql+asyncpg://") + client = await vecs.create_async_client(async_db) + assert isinstance(client, vecs.AsyncClient) + await client.disconnect() + + +@pytest.mark.asyncio +async def test_async_client_context_manager(clean_db: str): + """Test async client as context manager""" + async_db = clean_db.replace("postgresql://", "postgresql+asyncpg://") + client = await vecs.create_async_client(async_db) + async with client: + assert isinstance(client, vecs.AsyncClient) + assert client.vector_version is not None + + +@pytest.mark.asyncio +async def test_async_collection_create_and_upsert(async_client: vecs.AsyncClient): + """Test async collection creation and upsert""" + # Create collection + collection = await async_client.get_or_create_collection( + "test_collection", dimension=3 + ) + assert collection.name == "test_collection" + assert collection.dimension 
== 3 + + # Test upsert + records = [ + ("id1", [1.0, 2.0, 3.0], {"type": "test"}), + ("id2", [4.0, 5.0, 6.0], {"type": "test"}), + ] + await collection.upsert(records) + + # Test collection length + length = await collection.__len__() + assert length == 2 + + +@pytest.mark.asyncio +async def test_async_collection_query(async_client: vecs.AsyncClient): + """Test async collection query""" + # Create collection and add data + collection = await async_client.get_or_create_collection("test_query", dimension=3) + + records = [ + ("id1", [1.0, 2.0, 3.0], {"type": "test"}), + ("id2", [4.0, 5.0, 6.0], {"type": "test"}), + ("id3", [7.0, 8.0, 9.0], {"type": "test"}), + ] + await collection.upsert(records) + + # Test query + results = await collection.query([1.0, 2.0, 3.0], limit=2) + assert len(results) == 2 + assert results[0] == "id1" # Should be closest match + + # Test query with metadata + results_with_meta = await collection.query( + [1.0, 2.0, 3.0], limit=2, include_metadata=True + ) + assert len(results_with_meta) == 2 + assert results_with_meta[0][0] == "id1" + assert results_with_meta[0][1] == {"type": "test"} + + +@pytest.mark.asyncio +async def test_async_collection_fetch_and_delete(async_client: vecs.AsyncClient): + """Test async collection fetch and delete""" + # Create collection and add data + collection = await async_client.get_or_create_collection("test_fetch", dimension=3) + + records = [ + ("id1", [1.0, 2.0, 3.0], {"type": "test"}), + ("id2", [4.0, 5.0, 6.0], {"type": "test"}), + ] + await collection.upsert(records) + + # Test fetch + fetched = await collection.fetch(["id1", "id2"]) + assert len(fetched) == 2 + + # Test delete + await collection.delete(["id1"]) + length = await collection.__len__() + assert length == 1 + + # Test fetch after delete + fetched_after_delete = await collection.fetch(["id1", "id2"]) + assert len(fetched_after_delete) == 1 + + +@pytest.mark.asyncio +async def test_async_list_collections(async_client: vecs.AsyncClient): + """Test async list collections""" + # Create multiple collections + collection1 = await async_client.get_or_create_collection( + "collection1", dimension=3 + ) + collection2 = await async_client.get_or_create_collection( + "collection2", dimension=4 + ) + + # List collections + collections = await async_client.list_collections() + collection_names = [c.name for c in collections] + + assert "collection1" in collection_names + assert "collection2" in collection_names + assert len(collections) >= 2 + + +@pytest.mark.asyncio +async def test_async_delete_collection(async_client: vecs.AsyncClient): + """Test async delete collection""" + # Create collection + collection = await async_client.get_or_create_collection("to_delete", dimension=3) + + # Add some data + records = [("id1", [1.0, 2.0, 3.0], {"type": "test"})] + await collection.upsert(records) + + # Delete collection + await async_client.delete_collection("to_delete") + + # Try to get deleted collection - should raise error + with pytest.raises(vecs.exc.CollectionNotFound): + await async_client.get_collection("to_delete") + + +@pytest.mark.asyncio +async def test_async_collection_create_index(async_client: vecs.AsyncClient): + """Test async collection index creation""" + # Create collection and add enough data for indexing + collection = await async_client.get_or_create_collection("test_index", dimension=3) + + # Add data (need enough for index to be created) + records = [ + (f"id{i}", [float(i), float(i + 1), float(i + 2)], {"i": i}) + for i in range(1100) + ] + await 
collection.upsert(records) + + # Create index + await collection.create_index() + + # Check if index was created + index_name = await collection.index() + assert index_name is not None + + # Test querying with index + results = await collection.query([1.0, 2.0, 3.0], limit=5) + assert len(results) == 5 diff --git a/src/tests/test_async_collection.py b/src/tests/test_async_collection.py new file mode 100644 index 0000000..91a8d53 --- /dev/null +++ b/src/tests/test_async_collection.py @@ -0,0 +1,554 @@ +import itertools +import random + +import numpy as np +import pytest + +import vecs +from vecs import IndexArgsHNSW, IndexArgsIVFFlat, IndexMethod + + +@pytest.mark.asyncio +async def test_async_upsert(async_client: vecs.AsyncClient) -> None: + n_records = 100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="ping", dimension=dim) + + # collection initially empty + assert await movies.__len__() == 0 + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + # insert works + await movies.upsert(records) + assert await movies.__len__() == n_records + + # upserting overwrites + new_record = ("vec0", np.zeros(384), {}) + await movies.upsert([new_record]) + db_record = await movies.__getitem__("vec0") + assert db_record[0] == new_record[0] + assert np.array_equal(db_record[1], new_record[1]) + assert db_record[2] == new_record[2] + + +@pytest.mark.asyncio +async def test_async_fetch(async_client: vecs.AsyncClient) -> None: + n_records = 100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="ping", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + # insert works + await movies.upsert(records) + + # test basic usage + fetch_ids = ["vec0", "vec15", "vec99"] + res = await movies.fetch(ids=fetch_ids) + assert len(res) == 3 + ids = set([x[0] for x in res]) + assert all([x in ids for x in fetch_ids]) + + # test one of the keys does not exist not an error + fetch_ids = ["vec0", "vec15", "does not exist"] + res = await movies.fetch(ids=fetch_ids) + assert len(res) == 2 + + # bad input + with pytest.raises(vecs.exc.ArgError): + await movies.fetch(ids="should_be_a_list") + + +@pytest.mark.asyncio +async def test_async_delete(async_client: vecs.AsyncClient) -> None: + n_records = 100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="ping", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": genre, + "year": int(50 * random.random()) + 1970, + }, + ) + for (ix, vec), genre in zip( + enumerate(np.random.random((n_records, dim))), + itertools.cycle(["action", "rom-com", "drama"]), + ) + ] + + # insert works + await movies.upsert(records) + + # delete by IDs. 
+ delete_ids = ["vec0", "vec15", "vec99"] + await movies.delete(ids=delete_ids) + assert await movies.__len__() == n_records - len(delete_ids) + + # insert works + await movies.upsert(records) + + # delete with filters + genre_to_delete = "action" + deleted_ids_by_genre = await movies.delete( + filters={"genre": {"$eq": genre_to_delete}} + ) + assert len(deleted_ids_by_genre) == 34 + + # bad input + with pytest.raises(vecs.exc.ArgError): + await movies.delete(ids="should_be_a_list") + + # bad input: neither ids nor filters provided. + with pytest.raises(vecs.exc.ArgError): + await movies.delete() + + # bad input: should only provide either ids or filters, not both + with pytest.raises(vecs.exc.ArgError): + await movies.delete(ids=["vec0"], filters={"genre": {"$eq": genre_to_delete}}) + + +@pytest.mark.asyncio +async def test_async_repr(async_client: vecs.AsyncClient) -> None: + movies = await async_client.get_or_create_collection(name="movies", dimension=99) + assert repr(movies) == 'vecs.AsyncCollection(name="movies", dimension=99)' + + +@pytest.mark.asyncio +async def test_async_getitem(async_client: vecs.AsyncClient) -> None: + movies = await async_client.get_or_create_collection(name="movies", dimension=3) + await movies.upsert(records=[("1", [1, 2, 3], {})]) + + result = await movies.__getitem__("1") + assert result is not None + assert len(result) == 3 + + with pytest.raises(KeyError): + await movies.__getitem__("2") + + with pytest.raises(vecs.exc.ArgError): + await movies.__getitem__(["only strings work not lists"]) + + +@pytest.mark.asyncio +@pytest.mark.filterwarnings("ignore:Query does") +async def test_async_query(async_client: vecs.AsyncClient) -> None: + n_records = 100 + dim = 64 + + bar = await async_client.get_or_create_collection(name="bar", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + await bar.upsert(records) + + _, query_vec, query_meta = await bar.__getitem__("vec5") + + top_k = 7 + + res = await bar.query( + data=query_vec, + limit=top_k, + filters=None, + measure="cosine_distance", + include_value=False, + include_metadata=False, + ) + + # correct number of results + assert len(res) == top_k + # most similar to self + assert res[0] == "vec5" + + with pytest.raises(vecs.exc.ArgError): + await bar.query( + data=query_vec, + limit=1001, + ) + + with pytest.raises(vecs.exc.ArgError): + await bar.query( + data=query_vec, + probes=0, + ) + + with pytest.raises(vecs.exc.ArgError): + await bar.query( + data=query_vec, + probes=-1, + ) + + with pytest.raises(vecs.exc.ArgError): + await bar.query( + data=query_vec, + probes="a", + ) + + with pytest.raises(vecs.exc.ArgError): + await bar.query(data=query_vec, limit=top_k, measure="invalid") + + # skip_adapter has no effect (no adapter present) + res = await bar.query(data=query_vec, limit=top_k, skip_adapter=True) + assert len(res) == top_k + + # include_value + res = await bar.query( + data=query_vec, + limit=top_k, + filters=None, + measure="cosine_distance", + include_value=True, + ) + assert len(res[0]) == 2 + assert res[0][0] == "vec5" + assert pytest.approx(res[0][1]) == 0 + + # include_metadata + res = await bar.query( + data=query_vec, + limit=top_k, + filters=None, + measure="cosine_distance", + include_metadata=True, + ) + assert len(res[0]) == 2 + assert res[0][0] == "vec5" + assert res[0][1] == query_meta + + # 
include_vector + res = await bar.query( + data=query_vec, + limit=top_k, + filters=None, + measure="cosine_distance", + include_vector=True, + ) + assert len(res[0]) == 2 + assert res[0][0] == "vec5" + assert all(res[0][1] == query_vec) + + # test for different numbers of probes + assert len(await bar.query(data=query_vec, limit=top_k, probes=10)) == top_k + + assert len(await bar.query(data=query_vec, limit=top_k, probes=5)) == top_k + + assert len(await bar.query(data=query_vec, limit=top_k, probes=1)) == top_k + + assert len(await bar.query(data=query_vec, limit=top_k, probes=999)) == top_k + + +@pytest.mark.asyncio +@pytest.mark.filterwarnings("ignore:Query does") +async def test_async_query_filters(async_client: vecs.AsyncClient) -> None: + n_records = 100 + dim = 4 + + bar = await async_client.get_or_create_collection(name="bar", dimension=dim) + + records = [ + (f"0", [0, 0, 0, 0], {"year": 1990}), + (f"1", [1, 0, 0, 0], {"year": 1995}), + (f"2", [1, 1, 0, 0], {"year": 2005}), + (f"3", [1, 1, 1, 0], {"year": 2001}), + (f"4", [1, 1, 1, 1], {"year": 1985}), + (f"5", [2, 1, 1, 1], {"year": 1863}), + (f"6", [2, 2, 1, 1], {"year": 2021}), + (f"7", [2, 2, 2, 1], {"year": 2019}), + (f"8", [2, 2, 2, 2], {"year": 2003}), + (f"9", [3, 2, 2, 2], {"year": 1997}), + ] + + await bar.upsert(records) + + query_rec = records[0] + + res = await bar.query( + data=query_rec[1], + limit=3, + filters={"year": {"$lt": 1990}}, + measure="cosine_distance", + include_value=False, + include_metadata=False, + ) + + # Only records with year < 1990 should be returned + assert len(res) == 2 # records 4 and 5 have years 1985 and 1863 + assert "4" in res or "5" in res + + # Test $eq filter + res = await bar.query( + data=query_rec[1], + limit=10, + filters={"year": {"$eq": 1995}}, + measure="cosine_distance", + include_value=False, + include_metadata=False, + ) + + assert len(res) == 1 + assert res[0] == "1" + + # Test $gt filter + res = await bar.query( + data=query_rec[1], + limit=10, + filters={"year": {"$gt": 2010}}, + measure="cosine_distance", + include_value=False, + include_metadata=False, + ) + + assert len(res) == 2 # records 6 and 7 have years 2021 and 2019 + assert "6" in res and "7" in res + + # Test $in filter + res = await bar.query( + data=query_rec[1], + limit=10, + filters={"year": {"$in": [1990, 1995, 2005]}}, + measure="cosine_distance", + include_value=False, + include_metadata=False, + ) + + assert len(res) == 3 + assert "0" in res and "1" in res and "2" in res + + +@pytest.mark.asyncio +async def test_async_create_index_ivfflat(async_client: vecs.AsyncClient) -> None: + n_records = 1100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="movies", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + await movies.upsert(records) + + # Test IVFFlat index creation + await movies.create_index(method=IndexMethod.ivfflat) + index_name = await movies.index() + assert index_name is not None + assert "ivfflat" in index_name + + # Test querying with index + query_vec = np.random.random(dim) + results = await movies.query(data=query_vec, limit=10) + assert len(results) == 10 + + # Test with custom index arguments + await movies.create_index( + method=IndexMethod.ivfflat, + index_arguments=IndexArgsIVFFlat(n_lists=50), + replace=True, + ) + index_name = await movies.index() + assert index_name is 
not None + assert "nl50" in index_name + + +@pytest.mark.asyncio +async def test_async_create_index_hnsw(async_client: vecs.AsyncClient) -> None: + n_records = 1100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="movies", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + await movies.upsert(records) + + # Test HNSW index creation (if supported) + if async_client._supports_hnsw(): + await movies.create_index(method=IndexMethod.hnsw) + index_name = await movies.index() + assert index_name is not None + assert "hnsw" in index_name + + # Test querying with index + query_vec = np.random.random(dim) + results = await movies.query(data=query_vec, limit=10) + assert len(results) == 10 + + # Test with custom index arguments + await movies.create_index( + method=IndexMethod.hnsw, + index_arguments=IndexArgsHNSW(m=32, ef_construction=100), + replace=True, + ) + index_name = await movies.index() + assert index_name is not None + assert "m32" in index_name and "efc100" in index_name + + +@pytest.mark.asyncio +async def test_async_create_index_auto(async_client: vecs.AsyncClient) -> None: + n_records = 1100 + dim = 384 + + movies = await async_client.get_or_create_collection(name="movies", dimension=dim) + + records = [ + ( + f"vec{ix}", + vec, + { + "genre": random.choice(["action", "rom-com", "drama"]), + "year": int(50 * random.random()) + 1970, + }, + ) + for ix, vec in enumerate(np.random.random((n_records, dim))) + ] + + await movies.upsert(records) + + # Test auto index creation + await movies.create_index(method=IndexMethod.auto) + index_name = await movies.index() + assert index_name is not None + + # Should create HNSW if supported, otherwise IVFFlat + if async_client._supports_hnsw(): + assert "hnsw" in index_name + else: + assert "ivfflat" in index_name + + +@pytest.mark.asyncio +async def test_async_index_error_cases(async_client: vecs.AsyncClient) -> None: + movies = await async_client.get_or_create_collection(name="movies", dimension=3) + + # Test index arguments validation + with pytest.raises(vecs.exc.ArgError): + await movies.create_index( + method=IndexMethod.auto, index_arguments=IndexArgsIVFFlat(n_lists=50) + ) + + with pytest.raises(vecs.exc.ArgError): + await movies.create_index( + method=IndexMethod.ivfflat, index_arguments=IndexArgsHNSW(m=32) + ) + + with pytest.raises(vecs.exc.ArgError): + await movies.create_index( + method=IndexMethod.hnsw, index_arguments=IndexArgsIVFFlat(n_lists=50) + ) + + +@pytest.mark.asyncio +async def test_async_list_collections(async_client: vecs.AsyncClient) -> None: + # Create multiple collections + await async_client.get_or_create_collection(name="test1", dimension=3) + await async_client.get_or_create_collection(name="test2", dimension=4) + await async_client.get_or_create_collection(name="test3", dimension=5) + + # List all collections + collections = await async_client.list_collections() + collection_names = [c.name for c in collections] + + assert "test1" in collection_names + assert "test2" in collection_names + assert "test3" in collection_names + assert len(collections) >= 3 + + +@pytest.mark.asyncio +async def test_async_delete_collection(async_client: vecs.AsyncClient) -> None: + # Create a collection + await async_client.get_or_create_collection(name="to_delete", dimension=3) + + # Add some data + collection = await 
async_client.get_or_create_collection( + name="to_delete", dimension=3 + ) + await collection.upsert([("test", [1, 2, 3], {})]) + + # Delete the collection + await async_client.delete_collection("to_delete") + + # Verify it's deleted + with pytest.raises(vecs.exc.CollectionNotFound): + await async_client.get_collection("to_delete") + + +@pytest.mark.asyncio +async def test_async_is_indexed_for_measure(async_client: vecs.AsyncClient) -> None: + collection = await async_client.get_or_create_collection( + name="test_indexed", dimension=384 + ) + + # Add some data + records = [(f"vec{i}", np.random.random(384), {}) for i in range(1100)] + await collection.upsert(records) + + # Initially not indexed + assert not await collection.is_indexed_for_measure( + vecs.IndexMeasure.cosine_distance + ) + + # Create cosine distance index + await collection.create_index(measure=vecs.IndexMeasure.cosine_distance) + + # Now should be indexed for cosine distance + assert await collection.is_indexed_for_measure(vecs.IndexMeasure.cosine_distance) + + # But not for other measures + assert not await collection.is_indexed_for_measure(vecs.IndexMeasure.l2_distance) diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index 6f1c7f3..3c4ffff 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -7,6 +7,8 @@ IndexMeasure, IndexMethod, ) +from vecs.async_client import AsyncClient +from vecs.async_collection import AsyncCollection __project__ = "vecs" __version__ = "0.4.5" @@ -19,6 +21,8 @@ "IndexMeasure", "Collection", "Client", + "AsyncCollection", + "AsyncClient", "exc", ] @@ -26,3 +30,10 @@ def create_client(connection_string: str) -> Client: """Creates a client from a Postgres connection string""" return Client(connection_string) + + +async def create_async_client(connection_string: str) -> AsyncClient: + """Creates an async client from a Postgres connection string""" + client = AsyncClient(connection_string) + await client._init_db() + return client diff --git a/src/vecs/adapter/text.py b/src/vecs/adapter/text.py index 08eb97f..6eb934a 100644 --- a/src/vecs/adapter/text.py +++ b/src/vecs/adapter/text.py @@ -4,6 +4,7 @@ All public classes, enums, and functions are re-exported by `vecs.adapters` module. """ + from typing import Any, Dict, Generator, Iterable, Literal, Optional, Tuple from flupy import flu diff --git a/src/vecs/async_client.py b/src/vecs/async_client.py new file mode 100644 index 0000000..384f1a6 --- /dev/null +++ b/src/vecs/async_client.py @@ -0,0 +1,233 @@ +""" +Defines the 'AsyncClient' class + +Importing from the `vecs.async_client` directly is not supported. +All public classes, enums, and functions are re-exported by the top level `vecs` module. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, List, Optional + +from deprecated import deprecated +from sqlalchemy import MetaData, text +from sqlalchemy.ext.asyncio import AsyncSession, async_sessionmaker, create_async_engine + +from vecs.adapter import Adapter +from vecs.exc import CollectionNotFound + +if TYPE_CHECKING: + from vecs.async_collection import AsyncCollection + + +class AsyncClient: + """ + The `vecs.AsyncClient` class serves as an async interface to a PostgreSQL database with pgvector support. + It facilitates the creation, retrieval, listing and deletion of vector collections, while managing + async connections to the database. + + An `AsyncClient` instance represents an async connection to a PostgreSQL database. 
This connection can + be used to create and manipulate vector collections, where each collection is a group of vector records + in a PostgreSQL table. + + The `vecs.AsyncClient` class can also be used as an async context manager to ensure the connection + to the database is properly closed after operations, or it can be used directly. + + Example usage: + + DB_CONNECTION = "postgresql+asyncpg://:@:/" + + async with await vecs.create_async_client(DB_CONNECTION) as vx: + # do some work + pass + + # OR + + vx = await vecs.create_async_client(DB_CONNECTION) + # do some work + await vx.disconnect() + """ + + def __init__(self, connection_string: str): + """ + Initialize an AsyncClient instance. + + Args: + connection_string (str): A string representing the database connection information. + Should use 'postgresql+asyncpg://' protocol for async connections. + + Returns: + None + """ + # Convert regular postgresql:// URLs to postgresql+asyncpg:// + if connection_string.startswith("postgresql://"): + connection_string = connection_string.replace( + "postgresql://", "postgresql+asyncpg://", 1 + ) + elif not connection_string.startswith("postgresql+asyncpg://"): + # If it's not already async, assume it should be + connection_string = ( + "postgresql+asyncpg://" + connection_string.split("://", 1)[-1] + ) + + self.engine = create_async_engine(connection_string) + self.meta = MetaData(schema="vecs") + self.AsyncSession = async_sessionmaker(self.engine, class_=AsyncSession) + self.vector_version: Optional[str] = None + + async def _init_db(self): + """Initialize the database schema and extensions.""" + async with self.AsyncSession() as sess: + async with sess.begin(): + await sess.execute(text("create schema if not exists vecs;")) + await sess.execute(text("create extension if not exists vector;")) + result = await sess.execute( + text( + "select installed_version from pg_available_extensions where name = 'vector' limit 1;" + ) + ) + self.vector_version = result.scalar_one() + + def _supports_hnsw(self): + return self.vector_version is not None and not self.vector_version.startswith( + ("0.0", "0.1", "0.2", "0.3", "0.4") + ) + + async def get_or_create_collection( + self, + name: str, + *, + dimension: Optional[int] = None, + adapter: Optional[Adapter] = None, + ) -> AsyncCollection: + """ + Get a vector collection by name, or create it if no collection with + *name* exists. + + Args: + name (str): The name of the collection. + + Keyword Args: + dimension (int): The dimensionality of the vectors in the collection. + adapter (Adapter): The adapter to use for the collection. + + Returns: + AsyncCollection: The created collection. + """ + from vecs.async_collection import AsyncCollection + + adapter_dimension = adapter.exported_dimension if adapter else None + + collection = AsyncCollection( + name=name, + dimension=dimension or adapter_dimension, # type: ignore + client=self, + adapter=adapter, + ) + + return await collection._create_if_not_exists() + + @deprecated("use AsyncClient.get_or_create_collection") + async def get_collection(self, name: str) -> AsyncCollection: + """ + Retrieve an existing vector collection (async version). + + Args: + name (str): The name of the collection. + + Returns: + AsyncCollection: The retrieved collection. + + Raises: + CollectionNotFound: If no collection with the given name exists. 
+ """ + query = text( + """ + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + and pc.relname = :name + """ + ).bindparams(name=name) + + async with self.AsyncSession() as sess: + result = await sess.execute(query) + query_result = result.fetchone() + + if query_result is None: + raise CollectionNotFound("No collection found with requested name") + + name, dimension = query_result + return AsyncCollection(name, dimension, self) + + async def list_collections(self) -> List["AsyncCollection"]: + """ + List all vector collections. + + Returns: + list[AsyncCollection]: A list of all collections. + """ + from vecs.async_collection import AsyncCollection + + return await AsyncCollection._list_collections(self) + + async def delete_collection(self, name: str) -> None: + """ + Delete a vector collection. + + If no collection with requested name exists, does nothing. + + Args: + name (str): The name of the collection. + + Returns: + None + """ + from vecs.async_collection import AsyncCollection + + await AsyncCollection(name, -1, self)._drop() + return + + async def disconnect(self) -> None: + """ + Disconnect the client from the database. + + Returns: + None + """ + await self.engine.dispose() + return + + async def __aenter__(self) -> "AsyncClient": + """ + Enable use of the 'async with' statement. + + Returns: + AsyncClient: The current instance of the AsyncClient. + """ + await self._init_db() + return self + + async def __aexit__(self, exc_type, exc_val, exc_tb): + """ + Disconnect the client on exiting the 'async with' statement context. + + Args: + exc_type: The exception type, if any. + exc_val: The exception value, if any. + exc_tb: The traceback, if any. + + Returns: + None + """ + await self.disconnect() + return diff --git a/src/vecs/async_collection.py b/src/vecs/async_collection.py new file mode 100644 index 0000000..f32717c --- /dev/null +++ b/src/vecs/async_collection.py @@ -0,0 +1,723 @@ +""" +Defines the 'AsyncCollection' class + +Importing from the `vecs.async_collection` directly is not supported. +All public classes, enums, and functions are re-exported by the top level `vecs` module. +""" + +from __future__ import annotations + +import math +import uuid +import warnings +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union + +from flupy import flu +from sqlalchemy import delete, func, select, text +from sqlalchemy.dialects import postgresql + +from vecs.adapter import Adapter, AdapterContext, NoOp +from vecs.collection import ( + INDEX_MEASURE_TO_OPS, + INDEX_MEASURE_TO_SQLA_ACC, + IndexArgsHNSW, + IndexArgsIVFFlat, + IndexMeasure, + IndexMethod, + Metadata, + Numeric, + Record, + build_filters, + build_table, +) +from vecs.exc import ( + ArgError, + CollectionAlreadyExists, + CollectionNotFound, + MismatchedDimension, +) + +if TYPE_CHECKING: + from vecs.async_client import AsyncClient + + +class AsyncCollection: + """ + The `vecs.AsyncCollection` class represents a collection of vectors within a PostgreSQL database with pgvector support. + It provides async methods to manage (create, delete, fetch, upsert), index, and perform similarity searches on these vector collections. + + The collections are stored in separate tables in the database, with each vector associated with an identifier and optional metadata. 
+ + Example usage: + + async with await vecs.create_async_client(DB_CONNECTION) as vx: + collection = await vx.get_or_create_collection(name="docs", dimension=3) + await collection.upsert([("id1", [1, 1, 1], {"key": "value"})]) + # Further operations on 'collection' + + Public Attributes: + name: The name of the vector collection. + dimension: The dimension of vectors in the collection. + + Note: Some methods of this class can raise exceptions from the `vecs.exc` module if errors occur. + """ + + def __init__( + self, + name: str, + dimension: int, + client: AsyncClient, + adapter: Optional[Adapter] = None, + ): + """ + Initializes a new instance of the `AsyncCollection` class. + + During expected use, developers initialize instances of `AsyncCollection` using the + `vecs.AsyncClient` with `vecs.AsyncClient.get_or_create_collection(...)` rather than directly. + + Args: + name (str): The name of the collection. + dimension (int): The dimension of the vectors in the collection. + client (AsyncClient): The async client to use for interacting with the database. + adapter (Adapter, optional): The adapter to use for the collection. + """ + self.client = client + self.name = name + self.dimension = dimension + self.table = build_table(name, client.meta, dimension) + self._index: Optional[str] = None + self.adapter = adapter or Adapter(steps=[NoOp(dimension=dimension)]) + + reported_dimensions = set( + [ + x + for x in [ + dimension, + adapter.exported_dimension if adapter else None, + ] + if x is not None + ] + ) + if len(reported_dimensions) == 0: + raise ArgError("Either dimension or adapter must provide a dimension.") + elif len(reported_dimensions) > 1: + raise MismatchedDimension( + "Dimensions reported by `dimension` argument and adapter do not match." + ) + + def __repr__(self): + """ + Returns a string representation of the `AsyncCollection` instance. + + Returns: + str: A string representation of the `AsyncCollection` instance. + """ + return f'vecs.AsyncCollection(name="{self.name}", dimension={self.dimension})' + + async def __len__(self) -> int: + """ + Returns the number of vectors in the collection. + + Returns: + int: The number of vectors in the collection. + """ + async with self.client.AsyncSession() as sess: + async with sess.begin(): + stmt = select(func.count()).select_from(self.table) + result = await sess.execute(stmt) + return result.scalar() or 0 + + async def _create_if_not_exists(self): + """ + Creates a new collection in the database if it doesn't already exist + + Returns: + AsyncCollection: The found or created collection. 
+ """ + query = text( + f""" + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + and pc.relname = :name + """ + ).bindparams(name=self.name) + async with self.client.AsyncSession() as sess: + result = await sess.execute(query) + query_result = result.fetchone() + + if query_result: + _, collection_dimension = query_result + else: + collection_dimension = None + + reported_dimensions = set( + [x for x in [self.dimension, collection_dimension] if x is not None] + ) + if len(reported_dimensions) > 1: + raise MismatchedDimension( + "Dimensions reported by `dimension` argument and existing collection do not match" + ) + + if not collection_dimension: + async with self.client.engine.begin() as conn: + await conn.run_sync(self.table.create) + + return self + + async def _create(self): + """ + Creates the collection. + + Returns: + AsyncCollection: The current instance of the AsyncCollection. + + Raises: + CollectionAlreadyExists: If a collection with the same name already exists. + """ + collection_exists = await self.__class__._does_collection_exist( + self.client, self.name + ) + if collection_exists: + raise CollectionAlreadyExists( + "Collection with requested name already exists" + ) + async with self.client.engine.begin() as conn: + await conn.run_sync(self.table.create) + + await self._create_gin_index() + return self + + async def _create_gin_index(self): + """ + Creates a GIN index on the metadata column for efficient filtering. + """ + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + async with self.client.AsyncSession() as sess: + await sess.execute( + text( + f""" + create index ix_meta_{unique_string} + on vecs."{self.table.name}" + using gin (metadata jsonb_path_ops); + """ + ) + ) + await sess.commit() + + async def _drop(self): + """ + Drops the collection from the database. + + Returns: + AsyncCollection: The current instance of the AsyncCollection. + """ + from sqlalchemy.schema import DropTable + + async with self.client.AsyncSession() as sess: + await sess.execute(DropTable(self.table, if_exists=True)) + await sess.commit() + + return self + + async def upsert( + self, records: Iterable[Tuple[str, Any, Metadata]], skip_adapter: bool = False + ) -> None: + """ + Inserts or updates *vectors* records in the collection. + + Args: + records (Iterable[Tuple[str, Any, Metadata]]): An iterable of content to upsert. + Each record is a tuple where: + - the first element is a unique string identifier + - the second element is an iterable of numeric values or relevant input type for the + adapter assigned to the collection + - the third element is metadata associated with the vector + + skip_adapter (bool): Should the adapter be skipped while upserting. i.e. 
if vectors are being + provided, rather than a media type that needs to be transformed + """ + + chunk_size = 500 + + if skip_adapter: + pipeline = flu(records).chunk(chunk_size) + else: + # Construct a lazy pipeline of steps to transform and chunk user input + pipeline = flu(self.adapter(records, AdapterContext("upsert"))).chunk( + chunk_size + ) + + async with self.client.AsyncSession() as sess: + async with sess.begin(): + for chunk in pipeline: + stmt = postgresql.insert(self.table).values(chunk) + stmt = stmt.on_conflict_do_update( + index_elements=[self.table.c.id], + set_=dict( + vec=stmt.excluded.vec, metadata=stmt.excluded.metadata + ), + ) + await sess.execute(stmt) + return None + + async def fetch(self, ids: Iterable[str]) -> List[Record]: + """ + Fetches vectors from the collection by their identifiers. + + Args: + ids (Iterable[str]): An iterable of vector identifiers. + + Returns: + List[Record]: A list of the fetched vectors. + """ + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + chunk_size = 12 + records = [] + async with self.client.AsyncSession() as sess: + async with sess.begin(): + for id_chunk in flu(ids).chunk(chunk_size): + stmt = select(self.table).where(self.table.c.id.in_(id_chunk)) + result = await sess.execute(stmt) + chunk_records = result.all() + records.extend(chunk_records) + return records + + async def delete( + self, ids: Optional[Iterable[str]] = None, filters: Optional[Metadata] = None + ) -> List[str]: + """ + Asynchronously deletes vectors from the collection by matching ids or filters. + + Args: + ids (Iterable[str], optional): An iterable of vector identifiers. + filters (Optional[Dict], optional): Metadata filters to match vectors for deletion. + + Returns: + List[str]: A list of the identifiers of the deleted vectors. + + Raises: + ArgError: If both or neither of `ids` and `filters` are provided. + """ + if ids is None and filters is None: + raise ArgError("Either ids or filters must be provided.") + + if ids is not None and filters is not None: + raise ArgError("Either ids or filters must be provided, not both.") + + if isinstance(ids, str): + raise ArgError("ids must be a list of strings") + + ids = ids or [] + filters = filters or {} + del_ids: List[str] = [] + + async with self.client.AsyncSession() as sess: + async with sess.begin(): + if ids: + for id_chunk in flu(ids).chunk(12): + stmt = ( + delete(self.table) + .where(self.table.c.id.in_(id_chunk)) + .returning(self.table.c.id) + ) + result = await sess.execute(stmt) + del_ids.extend(result.scalars().all()) + + if filters: + meta_filter = build_filters(self.table.c.metadata, filters) + stmt = ( + delete(self.table) + .where(meta_filter) + .returning(self.table.c.id) # type: ignore + ) + result = await sess.execute(stmt) + del_ids.extend([r for r in result.scalars()]) + + return del_ids + + async def __getitem__(self, items: str): + """ + Asynchronously fetches a vector from the collection by its identifier. + + Args: + items (str): The identifier of the vector. + + Returns: + Record: The fetched vector. + + Raises: + ArgError: If the input is not a string. + KeyError: If no vector is found with the given ID. 
+ """ + if not isinstance(items, str): + raise ArgError("items must be a string id") + + row = await self.fetch([items]) + + if not row: + raise KeyError("no item found with requested id") + + return row[0] + + async def query( + self, + data: Union[Iterable[Numeric], Any], + limit: int = 10, + filters: Optional[Dict] = None, + measure: Union[IndexMeasure, str] = IndexMeasure.cosine_distance, + include_value: bool = False, + include_metadata: bool = False, + include_vector: bool = False, + *, + probes: Optional[int] = None, + ef_search: Optional[int] = None, + skip_adapter: bool = False, + ) -> Union[List[Record], List[str]]: + """ + Executes a similarity search in the collection. + + The return type is dependent on arguments *include_value* and *include_metadata* + + Args: + data (Any): The vector to use as the query. + limit (int, optional): The maximum number of results to return. Defaults to 10. + filters (Optional[Dict], optional): Filters to apply to the search. Defaults to None. + measure (Union[IndexMeasure, str], optional): The distance measure to use for the search. Defaults to 'cosine_distance'. + include_value (bool, optional): Whether to include the distance value in the results. Defaults to False. + include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False. + include_vector (bool, optional): Whether to include the vector in the results. Defaults to False. + probes (Optional[Int], optional): Number of ivfflat index lists to query. Higher increases accuracy but decreases speed + ef_search (Optional[Int], optional): Size of the dynamic candidate list for HNSW index search. Higher increases accuracy but decreases speed + skip_adapter (bool, optional): When True, skips any associated adapter and queries using a literal vector provided to *data* + + Returns: + Union[List[Record], List[str]]: The result of the similarity search. + """ + + if probes is None: + probes = 10 + + if ef_search is None: + ef_search = 40 + + if not isinstance(probes, int): + raise ArgError("probes must be an integer") + + if probes < 1: + raise ArgError("probes must be >= 1") + + if limit > 1000: + raise ArgError("limit must be <= 1000") + + # ValueError on bad input + try: + imeasure = IndexMeasure(measure) + except ValueError: + raise ArgError("Invalid index measure") + + if not await self.is_indexed_for_measure(imeasure): + warnings.warn( + UserWarning( + f"Query does not have a covering index for {imeasure}. 
See Collection.create_index" + ) + ) + + if skip_adapter: + adapted_query = [("", data, {})] + else: + # Adapt the query using the pipeline + adapted_query = [ + x + for x in self.adapter( + records=[("", data, {})], adapter_context=AdapterContext("query") + ) + ] + + if len(adapted_query) != 1: + raise ArgError("Failed to produce exactly one query vector from input") + + _, vec, _ = adapted_query[0] + + distance_lambda = INDEX_MEASURE_TO_SQLA_ACC.get(imeasure) + if distance_lambda is None: + # unreachable + raise ArgError("invalid distance_measure") # pragma: no cover + + distance_clause = distance_lambda(self.table.c.vec)(vec) + + cols = [self.table.c.id] + + if include_value: + cols.append(distance_clause) + + if include_vector: + cols.append(self.table.c.vec) + + if include_metadata: + cols.append(self.table.c.metadata) + + stmt = select(*cols) + if filters: + stmt = stmt.filter( + build_filters(self.table.c.metadata, filters) # type: ignore + ) + + stmt = stmt.order_by(distance_clause) + stmt = stmt.limit(limit) + + async with self.client.AsyncSession() as sess: + async with sess.begin(): + # index ignored if greater than n_lists + await sess.execute(text(f"set local ivfflat.probes = {int(probes)}")) + if self.client._supports_hnsw(): + await sess.execute( + text(f"set local hnsw.ef_search = {int(ef_search)}") + ) + result = await sess.execute(stmt) + if len(cols) == 1: + return [str(x) for x in result.scalars().all()] + return result.fetchall() + + @classmethod + async def _list_collections(cls, client: AsyncClient) -> List[AsyncCollection]: + """ + Lists all collections in the database. + + Args: + client (AsyncClient): The async client to use for the database connection. + + Returns: + List[AsyncCollection]: A list of all collections. + """ + query = text( + f""" + select + relname as table_name, + atttypmod as embedding_dim + from + pg_class pc + join pg_attribute pa + on pc.oid = pa.attrelid + where + pc.relnamespace = 'vecs'::regnamespace + and pc.relkind = 'r' + and pa.attname = 'vec' + and not pc.relname ^@ '_' + """ + ) + xc = [] + async with client.AsyncSession() as sess: + result = await sess.execute(query) + for name, dimension in result.all(): + existing_collection = cls(name, dimension, client) + xc.append(existing_collection) + return xc + + @classmethod + async def _does_collection_exist(cls, client: "AsyncClient", name: str) -> bool: + """ + PRIVATE + + Checks if a collection with a given name exists within the database + + Args: + client (AsyncClient): The database client. + name (str): The name of the collection + + Returns: + Exists: Whether the collection exists or not + """ + try: + await client.get_collection(name) + return True + except CollectionNotFound: + return False + + async def index(self) -> Optional[str]: + """ + Returns the name of the index for the collection, if one exists. + + Returns: + Optional[str]: The name of the index, or None if no index exists. 
+ """ + if self._index is None: + query = text( + """ + select + pi.relname as index_name + from + pg_class pi -- index info + join pg_index i -- extend index info + on pi.oid = i.indexrelid + join pg_class pt -- owning table info + on pt.oid = i.indrelid + where + pi.relnamespace = 'vecs'::regnamespace + and pi.relname ilike 'ix_vector%' + and pi.relkind = 'i' + and pt.relname = :table_name + """ + ) + async with self.client.AsyncSession() as sess: + result = await sess.execute(query, {"table_name": self.name}) + ix_name = result.scalar() + self._index = ix_name + return self._index + + async def is_indexed_for_measure(self, measure: IndexMeasure): + """ + Checks if the collection is indexed for the given measure. + + Args: + measure (IndexMeasure): The measure to check for. + + Returns: + bool: True if the collection is indexed for the measure, False otherwise. + """ + index_name = await self.index() + if index_name is None: + return False + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + return False + + return ops in index_name + + async def create_index( + self, + measure: IndexMeasure = IndexMeasure.cosine_distance, + method: IndexMethod = IndexMethod.auto, + index_arguments: Optional[Union[IndexArgsIVFFlat, IndexArgsHNSW]] = None, + replace: bool = True, + ) -> None: + """ + Asynchronously creates an index for the collection. + + Note: + When `vecs` creates an index on a pgvector column in PostgreSQL, it uses a multi-step + process that enables performant indexes to be built for large collections with low end + database hardware. + + Those steps are: + + - Creates a new table with a different name + - Randomly selects records from the existing table + - Inserts the random records from the existing table into the new table + - Creates the requested vector index on the new table + - Upserts all data from the existing table into the new table + - Drops the existing table + - Renames the new table to the existing tables name + + If you create dependencies (like views) on the table that underpins + a `vecs.Collection` the `create_index` step may require you to drop those dependencies before + it will succeed. + + Args: + measure (IndexMeasure, optional): The measure to index for. Defaults to 'cosine_distance'. + method (IndexMethod, optional): The indexing method to use. Defaults to 'auto'. + index_arguments: (IndexArgsIVFFlat | IndexArgsHNSW, optional): Index type specific arguments + replace (bool, optional): Whether to replace the existing index. Defaults to True. + + Raises: + ArgError: + """ + if index_arguments: + if method == IndexMethod.auto: + raise ArgError( + "Index build parameters are not allowed when using the IndexMethod.auto index." + ) + if ( + isinstance(index_arguments, IndexArgsHNSW) + and method != IndexMethod.hnsw + ) or ( + isinstance(index_arguments, IndexArgsIVFFlat) + and method != IndexMethod.ivfflat + ): + raise ArgError( + f"{index_arguments.__class__.__name__} build parameters were supplied but {method} index was specified." + ) + + # Auto-detect method if needed + if method == IndexMethod.auto: + if self.client._supports_hnsw(): + method = IndexMethod.hnsw + else: + method = IndexMethod.ivfflat + + if method == IndexMethod.hnsw and not self.client._supports_hnsw(): + raise ArgError( + "HNSW Unavailable. 
Upgrade your pgvector installation to > 0.5.0 to enable HNSW support" + ) + + ops = INDEX_MEASURE_TO_OPS.get(measure) + if ops is None: + raise ArgError("Unknown index measure") + + unique_string = str(uuid.uuid4()).replace("-", "_")[0:7] + + async with self.client.AsyncSession() as sess: + async with sess.begin(): + current_index = await self.index() + if current_index is not None: + if replace: + await sess.execute( + text(f'DROP INDEX IF EXISTS vecs."{current_index}";') + ) + self._index = None + else: + raise ArgError("replace is set to False but an index exists") + + if method == IndexMethod.ivfflat: + if not index_arguments: + result = await sess.execute( + select(func.count()).select_from(self.table) + ) + n_records: int = result.scalar_one() + n_lists = ( + int(max(n_records / 1000, 30)) + if n_records < 1_000_000 + else int(math.sqrt(n_records)) + ) + else: + n_lists = index_arguments.n_lists # type: ignore + + await sess.execute( + text( + f""" + CREATE INDEX ix_{ops}_ivfflat_nl{n_lists}_{unique_string} + ON vecs."{self.table.name}" + USING ivfflat (vec {ops}) WITH (lists = {n_lists}) + """ + ) + ) + + elif method == IndexMethod.hnsw: + if not index_arguments: + index_arguments = IndexArgsHNSW() + + m = index_arguments.m # type: ignore + ef_construction = index_arguments.ef_construction # type: ignore + + await sess.execute( + text( + f""" + CREATE INDEX ix_{ops}_hnsw_m{m}_efc{ef_construction}_{unique_string} + ON vecs."{self.table.name}" + USING hnsw (vec {ops}) WITH (m = {m}, ef_construction = {ef_construction}) + """ + ) + ) + + return None diff --git a/src/vecs/client.py b/src/vecs/client.py index 89bb3e3..aa2ae36 100644 --- a/src/vecs/client.py +++ b/src/vecs/client.py @@ -29,8 +29,8 @@ class Client: A `Client` instance represents a connection to a PostgreSQL database. This connection can be used to create and manipulate vector collections, where each collection is a group of vector records in a PostgreSQL table. - The `vecs.Client` class can be also supports usage as a context manager to ensure the connection to the database - is properly closed after operations, or used directly. + The `vecs.Client` class can also be used as a context manager to ensure the connection + to the database is properly closed after operations, or it can be used directly. Example usage: @@ -72,13 +72,7 @@ def __init__(self, connection_string: str): ).scalar_one() def _supports_hnsw(self): - return ( - not self.vector_version.startswith("0.4") - and not self.vector_version.startswith("0.3") - and not self.vector_version.startswith("0.2") - and not self.vector_version.startswith("0.1") - and not self.vector_version.startswith("0.0") - ) + return not self.vector_version.startswith(("0.0", "0.1", "0.2", "0.3", "0.4")) def get_or_create_collection( self, @@ -96,13 +90,10 @@ def get_or_create_collection( Keyword Args: dimension (int): The dimensionality of the vectors in the collection. - pipeline (int): The dimensionality of the vectors in the collection. + adapter (Adapter): The adapter to use for the collection. Returns: Collection: The created collection. - - Raises: - CollectionAlreadyExists: If a collection with the same name already exists """ from vecs.collection import Collection diff --git a/src/vecs/collection.py b/src/vecs/collection.py index eab8f93..9f071af 100644 --- a/src/vecs/collection.py +++ b/src/vecs/collection.py @@ -55,9 +55,6 @@ class IndexMethod(str, Enum): """ An enum representing the index methods available. 
- This class currently only supports the 'ivfflat' method but may - expand in the future. - Attributes: auto (str): Automatically choose the best available index method. ivfflat (str): The ivfflat index method. @@ -77,6 +74,7 @@ class IndexMeasure(str, Enum): cosine_distance (str): The cosine distance measure for indexing. l2_distance (str): The Euclidean (L2) distance measure for indexing. max_inner_product (str): The maximum inner product measure for indexing. + l1_distance (str): The L1 distance measure for indexing. """ cosine_distance = "cosine_distance" @@ -92,7 +90,7 @@ class IndexArgsIVFFlat: method when building an IVFFlat type index. Attributes: - nlist (int): The number of IVF centroids that the index should use + n_lists (int): The number of IVF centroids that the index should use """ n_lists: int @@ -120,7 +118,7 @@ class IndexArgsHNSW: INDEX_MEASURE_TO_OPS = { - # Maps the IndexMeasure enum options to the SQL ops string required by + # Maps the IndexMeasure enum options to SQL ops strings required by # the pgvector `create index` statement IndexMeasure.cosine_distance: "vector_cosine_ops", IndexMeasure.l2_distance: "vector_l2_ops", @@ -174,6 +172,7 @@ def __init__( name (str): The name of the collection. dimension (int): The dimension of the vectors in the collection. client (Client): The client to use for interacting with the database. + adapter (Adapter): The adapter to use for the collection. """ self.client = client self.name = name @@ -193,10 +192,12 @@ def __init__( ] ) if len(reported_dimensions) == 0: - raise ArgError("One of dimension or adapter must provide a dimension") + raise ArgError( + "Dimension must be provided by either `dimension` argument or adapter." + ) elif len(reported_dimensions) > 1: raise MismatchedDimension( - "Dimensions reported by adapter, dimension, and collection do not match" + "Dimensions reported by `dimension` argument and adapter do not match." ) def __repr__(self): @@ -259,7 +260,7 @@ def _create_if_not_exists(self): ) if len(reported_dimensions) > 1: raise MismatchedDimension( - "Dimensions reported by adapter, dimension, and existing collection do not match" + "Dimensions reported by `dimension` argument and existing collection do not match." ) if not collection_dimension: @@ -394,6 +395,9 @@ def delete( Returns: List[str]: A list of the identifiers of the deleted vectors. + + Raises: + ArgError: If both or neither of `ids` and `filters` are provided. """ if ids is None and filters is None: raise ArgError("Either ids or filters must be provided.") @@ -438,6 +442,10 @@ def __getitem__(self, items): Returns: Record: The fetched vector. + + Raises: + ArgError: If the input is not a string. + KeyError: If no vector is found with the given ID. """ if not isinstance(items, str): raise ArgError("items must be a string id") @@ -475,6 +483,7 @@ def query( include_value (bool, optional): Whether to include the distance value in the results. Defaults to False. include_metadata (bool, optional): Whether to include the metadata in the results. Defaults to False. probes (Optional[Int], optional): Number of ivfflat index lists to query. Higher increases accuracy but decreases speed + include_vector (bool, optional): Whether to include the vector in the results. Defaults to False. ef_search (Optional[Int], optional): Size of the dynamic candidate list for HNSW index search. 
Higher increases accuracy but decreases speed skip_adapter (bool, optional): When True, skips any associated adapter and queries using a literal vector provided to *data* @@ -960,7 +969,7 @@ def build_table(name: str, meta: MetaData, dimension: int) -> Table: """ PRIVATE - Builds a SQLAlchemy model underpinning a `vecs.Collection`. + Builds an SQLAlchemy model underpinning a `vecs.Collection`. Args: name (str): The name of the table. From 5de44e2d2291cbd23d8841da25d9e908f6ef1684 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Wed, 16 Jul 2025 19:22:15 +0200 Subject: [PATCH 2/3] Add greenlet to setup.py requirements. --- setup.py | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.py b/setup.py index 263ebf5..4c99110 100644 --- a/setup.py +++ b/setup.py @@ -44,6 +44,7 @@ def read_package_variable(key, filename="__init__.py"): "sqlalchemy==2.*", "psycopg2-binary==2.9.*", "asyncpg==0.29.*", + "greenlet>=1.0.0", "flupy==1.*", "deprecated==1.2.*", ] From 726ebf736199992d44ccdbe04280a0e85e8ba153 Mon Sep 17 00:00:00 2001 From: Dijana Pavlovic Date: Thu, 17 Jul 2025 22:41:42 +0200 Subject: [PATCH 3/3] update autoflake version --- .pre-commit-config.yaml | 46 +++++++++++++++++++++++------------------ src/tests/conftest.py | 1 - src/vecs/__init__.py | 4 ++-- 3 files changed, 28 insertions(+), 23 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 1fcf653..15e7818 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,28 +1,34 @@ repos: -- repo: https://github.com/pre-commit/mirrors-isort - rev: v5.6.4 + - repo: https://github.com/pre-commit/mirrors-isort + rev: v5.10.1 hooks: - - id: isort - args: ['--multi-line=3', '--trailing-comma', '--force-grid-wrap=0', '--use-parentheses', '--line-width=88'] + - id: isort + args: + [ + "--multi-line=3", + "--trailing-comma", + "--force-grid-wrap=0", + "--use-parentheses", + "--line-width=88", + ] - -- repo: https://github.com/pre-commit/pre-commit-hooks - rev: v3.3.0 + - repo: https://github.com/pre-commit/pre-commit-hooks + rev: v5.0.0 hooks: - - id: trailing-whitespace - - id: check-added-large-files - - id: check-yaml - - id: mixed-line-ending - args: ['--fix=lf'] + - id: trailing-whitespace + - id: check-added-large-files + - id: check-yaml + - id: mixed-line-ending + args: ["--fix=lf"] -- repo: https://github.com/humitos/mirrors-autoflake.git - rev: v1.1 + - repo: https://github.com/PyCQA/autoflake + rev: v2.3.1 hooks: - - id: autoflake - args: ['--in-place', '--remove-all-unused-imports'] + - id: autoflake + args: ["--in-place", "--remove-all-unused-imports"] -- repo: https://github.com/ambv/black - rev: 22.10.0 + - repo: https://github.com/ambv/black + rev: 25.1.0 hooks: - - id: black - language_version: python3.9 + - id: black + language_version: python3.9 diff --git a/src/tests/conftest.py b/src/tests/conftest.py index dfedc17..44f4b4f 100644 --- a/src/tests/conftest.py +++ b/src/tests/conftest.py @@ -8,7 +8,6 @@ import pytest import pytest_asyncio - from parse import parse from sqlalchemy import create_engine, text diff --git a/src/vecs/__init__.py b/src/vecs/__init__.py index 3c4ffff..535aa2c 100644 --- a/src/vecs/__init__.py +++ b/src/vecs/__init__.py @@ -1,4 +1,6 @@ from vecs import exc +from vecs.async_client import AsyncClient +from vecs.async_collection import AsyncCollection from vecs.client import Client from vecs.collection import ( Collection, @@ -7,8 +9,6 @@ IndexMeasure, IndexMethod, ) -from vecs.async_client import AsyncClient -from vecs.async_collection import AsyncCollection 
__project__ = "vecs" __version__ = "0.4.5"
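
--
Usage sketch (illustrative, not part of the diff): the snippet below exercises the async API this series adds, end to end — create_async_client, get_or_create_collection, upsert, create_index, query, and disconnect. The connection string is a placeholder, and the printed results assume only the two seeded vectors below.

import asyncio

import vecs

# Placeholder DSN; AsyncClient rewrites postgresql:// to postgresql+asyncpg:// automatically.
DB_CONNECTION = "postgresql://user:password@localhost:5432/postgres"


async def main() -> None:
    # create_async_client connects, creates the `vecs` schema, and enables the pgvector extension.
    vx = await vecs.create_async_client(DB_CONNECTION)
    try:
        docs = await vx.get_or_create_collection(name="docs", dimension=3)

        # Records are (id, vector, metadata) tuples; upsert overwrites existing ids.
        await docs.upsert(
            [
                ("vec0", [0.1, 0.2, 0.3], {"year": 2020}),
                ("vec1", [0.9, 0.8, 0.7], {"year": 2021}),
            ]
        )

        # Defaults: cosine_distance measure, IndexMethod.auto (HNSW when pgvector supports it).
        await docs.create_index()

        # With no include_* flags, query returns a list of ids, nearest first.
        ids = await docs.query(data=[0.1, 0.2, 0.3], limit=1)
        print(ids)  # ["vec0"]

        # Metadata filters use the same operators as the sync client, e.g. $eq.
        ids = await docs.query(data=[0.1, 0.2, 0.3], limit=1, filters={"year": {"$eq": 2021}})
        print(ids)  # ["vec1"]
    finally:
        await vx.disconnect()


asyncio.run(main())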