diff --git a/pyproject.toml b/pyproject.toml index 62f7016..02ce7dc 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -37,6 +37,9 @@ namespaces = false [tool.setuptools.package-data] classifai = [] +[tool.deptry.package_module_name_map] +google-genai = "google" + [project.optional-dependencies] huggingface = [ "transformers>=4.52.4", diff --git a/src/classifai/indexers/main.py b/src/classifai/indexers/main.py index d35dc2c..c3d4dcf 100644 --- a/src/classifai/indexers/main.py +++ b/src/classifai/indexers/main.py @@ -318,7 +318,9 @@ def reverse_search(self, query: VectorStoreReverseSearchInput, n_results=100) -> ) from e # polars conversion - paired_query = pl.DataFrame({"id": query.id, "doc_id": query.doc_id}) + paired_query = pl.DataFrame( + {"id": query.id.astype(str).to_list(), "doc_id": query.doc_id.astype(str).to_list()} + ) # join query with vdb to get matches joined_table = paired_query.join(self.vectors.rename({"id": "doc_id"}), on="doc_id", how="inner") @@ -333,7 +335,7 @@ def reverse_search(self, query: VectorStoreReverseSearchInput, n_results=100) -> ] ) - result_df = VectorStoreReverseSearchOutput.from_data(final_table.to_pandas()) + result_df = VectorStoreReverseSearchOutput.from_data(final_table.to_dict(as_series=False)) # Check if there is a user defined postprocess hook for the VectorStore reverse search method if "reverse_search_postprocess" in self.hooks: @@ -444,9 +446,8 @@ def search(self, query: VectorStoreSearchInput, n_results=10, batch_size=8) -> V *self.meta_data.keys(), ] ) - - # Now that polars has been used for processing convert back to pandas for user familiarity - result_df = VectorStoreSearchOutput.from_data(reordered_df.to_pandas()) + # Now that polars has been used for processing convert back to VectorStoreSearchOutput dataclass for output + result_df = VectorStoreSearchOutput.from_data(reordered_df.to_dict(as_series=False)) # Check if there is a user defined postprocess hook for the VectorStore search method if "search_postprocess" in self.hooks: diff --git a/uv.lock b/uv.lock index 8082439..758d0b8 100644 --- a/uv.lock +++ b/uv.lock @@ -302,7 +302,7 @@ wheels = [ [[package]] name = "classifai" -version = "0.1.0" +version = "0.2.0" source = { editable = "." } dependencies = [ { name = "fastapi", extra = ["standard"] },