From cbefdd54f54e0256b4b2dfc21e7046e00832dfd7 Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 9 Apr 2025 13:30:52 -0700 Subject: [PATCH 1/2] refactor(google): use new google-genai SDK --- docs/recipes/src/talk_to_a_video_1.py | 12 +- .../drivers/prompt/google_prompt_driver.py | 199 +++++++++--------- griptape/tokenizers/google_tokenizer.py | 25 +-- griptape/utils/dict_utils.py | 2 + pyproject.toml | 6 +- .../prompt/test_google_prompt_driver.py | 29 +-- uv.lock | 152 ++----------- 7 files changed, 145 insertions(+), 280 deletions(-) diff --git a/docs/recipes/src/talk_to_a_video_1.py b/docs/recipes/src/talk_to_a_video_1.py index f44243c4d0..205a6cff92 100644 --- a/docs/recipes/src/talk_to_a_video_1.py +++ b/docs/recipes/src/talk_to_a_video_1.py @@ -1,20 +1,20 @@ import time -from google.generativeai.files import get_file, upload_file - from griptape.artifacts import GenericArtifact, TextArtifact from griptape.configs import Defaults from griptape.configs.drivers import GoogleDriversConfig from griptape.structures import Agent Defaults.drivers_config = GoogleDriversConfig() +client = Defaults.drivers_config.prompt_driver.client -video_file = upload_file(path="tests/resources/griptape-comfyui.mp4") -while video_file.state.name == "PROCESSING": +video_file = client.files.upload(file="tests/resources/griptape-comfyui.mp4") +while video_file.state and video_file.state.name == "PROCESSING": time.sleep(2) - video_file = get_file(video_file.name) + if video_file.name: + video_file = client.files.get(name=video_file.name) -if video_file.state.name == "FAILED": +if video_file.state and video_file.state.name == "FAILED": raise ValueError(video_file.state.name) agent = Agent( diff --git a/griptape/drivers/prompt/google_prompt_driver.py b/griptape/drivers/prompt/google_prompt_driver.py index d01b35e34e..e388f7d95c 100644 --- a/griptape/drivers/prompt/google_prompt_driver.py +++ b/griptape/drivers/prompt/google_prompt_driver.py @@ -4,7 +4,8 @@ import logging from typing import TYPE_CHECKING, Optional -from attrs import Attribute, Factory, define, field +from attrs import Factory, define, field +from pydantic import BaseModel from griptape.artifacts import ActionArtifact, TextArtifact from griptape.common import ( @@ -32,9 +33,8 @@ if TYPE_CHECKING: from collections.abc import Iterator - from google.generativeai.generative_models import GenerativeModel - from google.generativeai.protos import Part - from google.generativeai.types import ContentDict, ContentsType, GenerateContentResponse + from google.genai import Client + from google.genai.types import Content, ContentDict, GenerateContentResponse, Part, Tool from griptape.drivers.prompt.base_prompt_driver import StructuredOutputStrategy from griptape.tools import BaseTool @@ -64,123 +64,114 @@ class GooglePromptDriver(BasePromptDriver): top_k: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) use_native_tools: bool = field(default=True, kw_only=True, metadata={"serializable": True}) structured_output_strategy: StructuredOutputStrategy = field( - default="tool", kw_only=True, metadata={"serializable": True} + default="native", kw_only=True, metadata={"serializable": True} ) tool_choice: str = field(default="auto", kw_only=True, metadata={"serializable": True}) - _client: Optional[GenerativeModel] = field( - default=None, kw_only=True, alias="client", metadata={"serializable": False} - ) - - @structured_output_strategy.validator # pyright: ignore[reportAttributeAccessIssue, reportOptionalMemberAccess] - def 
validate_structured_output_strategy(self, _: Attribute, value: str) -> str:
-        if value == "native":
-            raise ValueError(f"{__class__.__name__} does not support `{value}` structured output strategy.")
-        return value
+    _client: Client = field(default=None, kw_only=True, alias="client", metadata={"serializable": False})
 
     @lazy_property()
-    def client(self) -> GenerativeModel:
-        genai = import_optional_dependency("google.generativeai")
-        genai.configure(api_key=self.api_key)
-
-        return genai.GenerativeModel(self.model)
+    def client(self) -> Client:
+        genai = import_optional_dependency("google.genai")
+        return genai.Client(api_key=self.api_key)
 
     @observable
     def try_run(self, prompt_stack: PromptStack) -> Message:
-        messages = self.__to_google_messages(prompt_stack)
         params = self._base_params(prompt_stack)
-        logging.debug((messages, params))
-        response: GenerateContentResponse = self.client.generate_content(messages, **params)
-        logging.debug(response.to_dict())
+        logging.debug(params)
+        response = self.client.models.generate_content(**params)
+        logging.debug(response.model_dump())
 
         usage_metadata = response.usage_metadata
 
         return Message(
-            content=[self.__to_prompt_stack_message_content(part) for part in response.parts],
+            content=self.__to_prompt_stack_message_content(response),
             role=Message.ASSISTANT_ROLE,
             usage=Message.Usage(
-                input_tokens=usage_metadata.prompt_token_count,
-                output_tokens=usage_metadata.candidates_token_count,
+                input_tokens=usage_metadata.prompt_token_count if usage_metadata else None,
+                output_tokens=usage_metadata.candidates_token_count if usage_metadata else None,
             ),
         )
 
     @observable
     def try_stream(self, prompt_stack: PromptStack) -> Iterator[DeltaMessage]:
-        messages = self.__to_google_messages(prompt_stack)
-        params = {**self._base_params(prompt_stack), "stream": True}
-        logging.debug((messages, params))
-        response: GenerateContentResponse = self.client.generate_content(
-            messages,
-            **params,
-        )
+        params = self._base_params(prompt_stack)
+        logging.debug(params)
+        response = self.client.models.generate_content_stream(**params)
 
         prompt_token_count = None
         for chunk in response:
-            logger.debug(chunk.to_dict())
+            logger.debug(chunk.model_dump())
             usage_metadata = chunk.usage_metadata
 
-            content = self.__to_prompt_stack_delta_message_content(chunk.parts[0]) if chunk.parts else None
+            content = self.__to_prompt_stack_delta_message_content(chunk)
             # Only want to output the prompt token count once since it is static each chunk
             if prompt_token_count is None:
-                prompt_token_count = usage_metadata.prompt_token_count
+                prompt_token_count = usage_metadata.prompt_token_count if usage_metadata else None
                 yield DeltaMessage(
                     content=content,
                     usage=DeltaMessage.Usage(
-                        input_tokens=usage_metadata.prompt_token_count,
-                        output_tokens=usage_metadata.candidates_token_count,
+                        input_tokens=usage_metadata.prompt_token_count if usage_metadata else None,
+                        output_tokens=usage_metadata.candidates_token_count if usage_metadata else None,
                     ),
                 )
             else:
                 yield DeltaMessage(
                     content=content,
-                    usage=DeltaMessage.Usage(output_tokens=usage_metadata.candidates_token_count),
+                    usage=DeltaMessage.Usage(
+                        output_tokens=usage_metadata.candidates_token_count if usage_metadata else None
+                    ),
                 )
 
     def _base_params(self, prompt_stack: PromptStack) -> dict:
-        types = import_optional_dependency("google.generativeai.types")
-        protos = import_optional_dependency("google.generativeai.protos")
+        types = import_optional_dependency("google.genai.types")
 
         system_messages = prompt_stack.system_messages
+        system_instruction = None
         if system_messages:
-            self.client._system_instruction = types.ContentDict(
-                
role="system", - parts=[protos.Part(text=system_message.to_text()) for system_message in system_messages], - ) + system_instruction = "".join([system_message.to_text() for system_message in system_messages]) params = { - "generation_config": types.GenerationConfig( - **{ - # For some reason, providing stop sequences when streaming breaks native functions - # https://github.com/google-gemini/generative-ai-python/issues/446 - "stop_sequences": [] if self.stream and self.use_native_tools else self.tokenizer.stop_sequences, - "max_output_tokens": self.max_tokens, - "temperature": self.temperature, - "top_p": self.top_p, - "top_k": self.top_k, - **self.extra_params, - }, - ), + "model": self.model, + "contents": self.__to_google_messages(prompt_stack), + } + + config = { + "stop_sequences": [] if self.use_native_tools else self.tokenizer.stop_sequences, + "max_output_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "system_instruction": system_instruction, + "automatic_function_calling": types.AutomaticFunctionCallingConfig(disable=True), + **self.extra_params, } + if ( + self.structured_output_strategy == "native" + and isinstance(prompt_stack.output_schema, type) + and issubclass(prompt_stack.output_schema, BaseModel) + ): + config["response_schema"] = prompt_stack.output_schema + if prompt_stack.tools and self.use_native_tools: - params["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} + config["tool_config"] = {"function_calling_config": {"mode": self.tool_choice}} if prompt_stack.output_schema is not None and self.structured_output_strategy == "tool": - params["tool_config"]["function_calling_config"]["mode"] = "auto" + config["tool_config"]["function_calling_config"]["mode"] = "auto" + + config["tools"] = self.__to_google_tools(prompt_stack.tools) - params["tools"] = self.__to_google_tools(prompt_stack.tools) + params["config"] = types.GenerateContentConfig(**config) return params - def __to_google_messages(self, prompt_stack: PromptStack) -> ContentsType: - types = import_optional_dependency("google.generativeai.types") + def __to_google_messages(self, prompt_stack: PromptStack) -> list[Content]: + types = import_optional_dependency("google.genai.types") return [ - types.ContentDict( - { - "role": self.__to_google_role(message), - "parts": [self.__to_google_message_content(content) for content in message.content], - }, + types.Content( + role=self.__to_google_role(message), + parts=[self.__to_google_message_content(content) for content in message.content], ) for message in prompt_stack.messages if not message.is_system() @@ -191,8 +182,8 @@ def __to_google_role(self, message: Message) -> str: return "model" return "user" - def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]: - types = import_optional_dependency("google.generativeai.types") + def __to_google_tools(self, tools: list[BaseTool]) -> list[Tool]: + types = import_optional_dependency("google.genai.types") tool_declarations = [] for tool in tools: @@ -203,7 +194,7 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]: schema = schema["properties"]["values"] schema = remove_key_in_dict_recursively(schema, "additionalProperties") - tool_declaration = types.FunctionDeclaration( + function_declaration = types.FunctionDeclaration( name=tool.to_native_tool_name(activity), description=tool.activity_description(activity), **( @@ -218,67 +209,69 @@ def __to_google_tools(self, tools: list[BaseTool]) -> list[dict]: else {} ), ) + 
google_tool = types.Tool(function_declarations=[function_declaration]) - tool_declarations.append(tool_declaration) + tool_declarations.append(google_tool) return tool_declarations def __to_google_message_content(self, content: BaseMessageContent) -> ContentDict | Part | str: - types = import_optional_dependency("google.generativeai.types") - protos = import_optional_dependency("google.generativeai.protos") + types = import_optional_dependency("google.genai.types") if isinstance(content, TextMessageContent): - return content.artifact.to_text() + return types.Part.from_text(text=content.artifact.to_text()) if isinstance(content, ImageMessageContent): - return types.ContentDict(mime_type=content.artifact.mime_type, data=content.artifact.value) + return types.Part.from_bytes(mime_type=content.artifact.mime_type, data=content.artifact.value) if isinstance(content, ActionCallMessageContent): action = content.artifact.value - return protos.Part(function_call=protos.FunctionCall(name=action.tag, args=action.input)) + return types.Part(function_call=types.FunctionCall(name=action.tag, args=action.input)) if isinstance(content, ActionResultMessageContent): artifact = content.artifact - return protos.Part( - function_response=protos.FunctionResponse( + return types.Part( + function_response=types.FunctionResponse( name=content.action.to_native_tool_name(), response=artifact.to_dict(), ), ) if isinstance(content, GenericMessageContent): - return content.artifact.value + file = content.artifact.value + return types.Part.from_uri(file_uri=file.uri, mime_type=file.mime_type) raise ValueError(f"Unsupported prompt stack content type: {type(content)}") - def __to_prompt_stack_message_content(self, content: Part) -> BaseMessageContent: - json_format = import_optional_dependency("google.protobuf.json_format") - + def __to_prompt_stack_message_content(self, content: GenerateContentResponse) -> list[BaseMessageContent]: if content.text: - return TextMessageContent(TextArtifact(content.text)) - if content.function_call: - function_call = content.function_call - - name, path = ToolAction.from_native_tool_name(function_call.name) - - args = json_format.MessageToDict(function_call._pb).get("args", {}) - return ActionCallMessageContent( - artifact=ActionArtifact(value=ToolAction(tag=function_call.name, name=name, path=path, input=args)), - ) - raise ValueError(f"Unsupported message content type {content}") + return [TextMessageContent(TextArtifact(content.text))] + if content.function_calls: + return [ + ActionCallMessageContent( + ActionArtifact( + ToolAction( + tag=function_call.name, + name=ToolAction.from_native_tool_name(function_call.name)[0], + path=ToolAction.from_native_tool_name(function_call.name)[1], + input=function_call.args or {}, + ), + ), + ) + for function_call in content.function_calls + if function_call.name + ] - def __to_prompt_stack_delta_message_content(self, content: Part) -> BaseDeltaMessageContent: - json_format = import_optional_dependency("google.protobuf.json_format") + return [] + def __to_prompt_stack_delta_message_content(self, content: GenerateContentResponse) -> BaseDeltaMessageContent: if content.text: return TextDeltaMessageContent(content.text) - if content.function_call: - function_call = content.function_call - - name, path = ToolAction.from_native_tool_name(function_call.name) + if content.function_calls: + function_call = content.function_calls[0] - args = json_format.MessageToDict(function_call._pb).get("args", {}) + args = function_call.args return 
ActionCallDeltaMessageContent( tag=function_call.name, - name=name, - path=path, - partial_input=json.dumps(args), + name=ToolAction.from_native_tool_name(function_call.name)[0] if function_call.name else None, + path=ToolAction.from_native_tool_name(function_call.name)[1] if function_call.name else None, + partial_input=json.dumps(args) if args else None, ) raise ValueError(f"Unsupported message content type {content}") diff --git a/griptape/tokenizers/google_tokenizer.py b/griptape/tokenizers/google_tokenizer.py index 400edb6cfb..a3d02909ff 100644 --- a/griptape/tokenizers/google_tokenizer.py +++ b/griptape/tokenizers/google_tokenizer.py @@ -9,25 +9,26 @@ from griptape.utils.decorators import lazy_property if TYPE_CHECKING: - from google.generativeai.generative_models import GenerativeModel + from google.genai import Client @define() class GoogleTokenizer(BaseTokenizer): - MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = {"gemini-1.5-pro": 2097152, "gemini": 1048576} - MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"gemini": 8192} + MODEL_PREFIXES_TO_MAX_INPUT_TOKENS = { + "gemini-2.5-pro-preview": 1048576, + "gemini-2.0-flash": 1048576, + "gemini-1.5-pro": 2097152, + "gemini": 1048576, + } + MODEL_PREFIXES_TO_MAX_OUTPUT_TOKENS = {"gemini-2.5-pro-preview": 65536, "gemini": 8192} api_key: str = field(kw_only=True, metadata={"serializable": True}) - _client: Optional[GenerativeModel] = field( - default=None, kw_only=True, alias="client", metadata={"serializable": False} - ) + _client: Optional[Client] = field(default=None, kw_only=True, alias="client", metadata={"serializable": False}) @lazy_property() - def client(self) -> GenerativeModel: - genai = import_optional_dependency("google.generativeai") - genai.configure(api_key=self.api_key) - - return genai.GenerativeModel(self.model) + def client(self) -> Client: + genai = import_optional_dependency("google.genai") + return genai.Client(api_key=self.api_key) def count_tokens(self, text: str) -> int: - return self.client.count_tokens(text).total_tokens + return self.client.models.count_tokens(model=self.model, contents=text).total_tokens or 0 diff --git a/griptape/utils/dict_utils.py b/griptape/utils/dict_utils.py index 50261d2ab3..946d3bcf3e 100644 --- a/griptape/utils/dict_utils.py +++ b/griptape/utils/dict_utils.py @@ -12,6 +12,8 @@ def remove_null_values_in_dict_recursively(d: dict) -> dict: def remove_key_in_dict_recursively(d: dict, key: str) -> dict: if isinstance(d, dict): return {k: remove_key_in_dict_recursively(v, key) for k, v in d.items() if k != key} + if isinstance(d, list): + return [remove_key_in_dict_recursively(v, key) for v in d] return d diff --git a/pyproject.toml b/pyproject.toml index 061a2d9ea9..30080f53f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -36,7 +36,7 @@ drivers-prompt-huggingface-hub = [ drivers-prompt-huggingface-pipeline = ["transformers>=4.41.1"] drivers-prompt-amazon-bedrock = ["boto3>=1.34.119", "anthropic>=0.45.1"] drivers-prompt-amazon-sagemaker = ["boto3>=1.34.119", "transformers>=4.41.1"] -drivers-prompt-google = ["google-generativeai>=0.8.2"] +drivers-prompt-google = ["google-genai>=1.10.0"] drivers-prompt-ollama = ["ollama>=0.4.1"] drivers-sql = ["sqlalchemy>=2.0.31"] drivers-sql-amazon-redshift = ["boto3>=1.34.119"] @@ -67,7 +67,7 @@ drivers-embedding-huggingface = [ "transformers>=4.41.1", ] drivers-embedding-voyageai = ["voyageai>=0.2.1"] -drivers-embedding-google = ["google-generativeai>=0.8.2"] +drivers-embedding-google = ["google-genai>=1.10.0"] drivers-embedding-cohere = ["cohere>=5.11.2"] 
drivers-embedding-ollama = ["ollama>=0.4.1"]
 drivers-web-scraper-trafilatura = ["trafilatura>=2.0"]
@@ -129,7 +129,7 @@ all = [
   "opensearch-py>=2.3.1",
   "pgvector>=0.3.4",
   "psycopg2-binary>=2.9.9",
-  "google-generativeai>=0.8.2",
+  "google-genai>=1.10.0",
   "trafilatura>=2.0",
   "playwright>=1.42",
   "beautifulsoup4>=4.12.3",
diff --git a/tests/unit/drivers/prompt/test_google_prompt_driver.py b/tests/unit/drivers/prompt/test_google_prompt_driver.py
index d5f66ca4cb..6464aceab4 100644
--- a/tests/unit/drivers/prompt/test_google_prompt_driver.py
+++ b/tests/unit/drivers/prompt/test_google_prompt_driver.py
@@ -1,8 +1,7 @@
 from unittest.mock import MagicMock, Mock
 
 import pytest
-from google.generativeai.protos import FunctionCall, FunctionResponse, Part
-from google.generativeai.types import ContentDict, GenerationConfig
+from google.genai.types import ContentDict, FunctionCall, FunctionResponse, GenerationConfig, Part
 from google.protobuf.json_format import MessageToDict
 from schema import Schema
 
@@ -51,7 +50,7 @@ class TestGooglePromptDriver:
     @pytest.fixture()
     def mock_generative_model(self, mocker):
-        mock_generative_model = mocker.patch("google.generativeai.GenerativeModel")
+        mock_generative_model = mocker.patch("google.genai.Client")
         mocker.patch("google.protobuf.json_format.MessageToDict").return_value = {
             "args": {"foo": "bar"},
         }
@@ -60,10 +59,8 @@ def mock_generative_model(self, mocker):
         )
         mock_function_call.name = "MockTool_test"
-        mock_generative_model.return_value.generate_content.return_value = Mock(
-            parts=[
-                Mock(text="model-output", function_call=None),
-                MagicMock(name="foo", text=None, function_call=mock_function_call),
-            ],
+        mock_generative_model.return_value.models.generate_content.return_value = Mock(
+            text="model-output",
+            function_calls=[mock_function_call],
             usage_metadata=MagicMock(prompt_token_count=5, candidates_token_count=10),
         )
 
@@ -71,7 +68,7 @@ def mock_generative_model(self, mocker):
     @pytest.fixture()
     def mock_stream_generative_model(self, mocker):
-        mock_generative_model = mocker.patch("google.generativeai.GenerativeModel")
+        mock_generative_model = mocker.patch("google.genai.Client")
         mocker.patch("google.protobuf.json_format.MessageToDict").return_value = {
             "args": {"foo": "bar"},
         }
@@ -82,15 +79,17 @@ def mock_stream_generative_model(self, mocker):
-        mock_generative_model.return_value.generate_content.return_value = iter(
+        mock_generative_model.return_value.models.generate_content_stream.return_value = iter(
             [
                 MagicMock(
-                    parts=[MagicMock(text="model-output")],
+                    text="model-output",
+                    function_calls=[],
                     usage_metadata=MagicMock(prompt_token_count=5, candidates_token_count=5),
                 ),
                 MagicMock(
-                    parts=[MagicMock(text=None, function_call=mock_function_call_delta)],
+                    text=None,
+                    function_calls=[mock_function_call_delta],
                     usage_metadata=MagicMock(prompt_token_count=5, candidates_token_count=5),
                 ),
                 MagicMock(
-                    parts=[MagicMock(text="model-output", id="3")],
+                    text="model-output",
                     usage_metadata=MagicMock(prompt_token_count=5, candidates_token_count=5),
                 ),
             ]
         )
@@ -271,11 +270,3 @@ def test_try_stream(
         event = next(stream)
 
         assert event.usage.output_tokens == 5
-
-    def test_verify_structured_output_strategy(self):
-        assert GooglePromptDriver(model="foo", structured_output_strategy="tool")
-
-        with pytest.raises(
-            ValueError, match="GooglePromptDriver does not support `native` structured output strategy."
- ): - GooglePromptDriver(model="foo", structured_output_strategy="native") diff --git a/uv.lock b/uv.lock index d6ea8c448d..dcd88d23f2 100644 --- a/uv.lock +++ b/uv.lock @@ -1057,59 +1057,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034 }, ] -[[package]] -name = "google-ai-generativelanguage" -version = "0.6.15" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core", extra = ["grpc"] }, - { name = "google-auth" }, - { name = "proto-plus" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/11/d1/48fe5d7a43d278e9f6b5ada810b0a3530bbeac7ed7fcbcd366f932f05316/google_ai_generativelanguage-0.6.15.tar.gz", hash = "sha256:8f6d9dc4c12b065fe2d0289026171acea5183ebf2d0b11cefe12f3821e159ec3", size = 1375443 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/7c/a3/67b8a6ff5001a1d8864922f2d6488dc2a14367ceb651bc3f09a947f2f306/google_ai_generativelanguage-0.6.15-py3-none-any.whl", hash = "sha256:5a03ef86377aa184ffef3662ca28f19eeee158733e45d7947982eb953c6ebb6c", size = 1327356 }, -] - -[[package]] -name = "google-api-core" -version = "2.24.2" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "googleapis-common-protos" }, - { name = "proto-plus" }, - { name = "protobuf" }, - { name = "requests" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/09/5c/085bcb872556934bb119e5e09de54daa07873f6866b8f0303c49e72287f7/google_api_core-2.24.2.tar.gz", hash = "sha256:81718493daf06d96d6bc76a91c23874dbf2fac0adbbf542831b805ee6e974696", size = 163516 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/46/95/f472d85adab6e538da2025dfca9e976a0d125cc0af2301f190e77b76e51c/google_api_core-2.24.2-py3-none-any.whl", hash = "sha256:810a63ac95f3c441b7c0e43d344e372887f62ce9071ba972eacf32672e072de9", size = 160061 }, -] - -[package.optional-dependencies] -grpc = [ - { name = "grpcio" }, - { name = "grpcio-status" }, -] - -[[package]] -name = "google-api-python-client" -version = "2.164.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-api-core" }, - { name = "google-auth" }, - { name = "google-auth-httplib2" }, - { name = "httplib2" }, - { name = "uritemplate" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/32/5b/4ed16fac5ef6928d0c1ca0fba42f27e73938f04729ef97e63d7a7bb5fd6d/google_api_python_client-2.164.0.tar.gz", hash = "sha256:116f5a05dfb95ed7f7ea0d0f561fc5464146709c583226cc814690f9bb221492", size = 12595711 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/94/0d/4eacf5bff40a42e6be3086b85164f0624fee9724c11bb2c79305fbc2f355/google_api_python_client-2.164.0-py2.py3-none-any.whl", hash = "sha256:b2037c3d280793c8d5180b04317b16be4acd5f77af5dfa7213ace32d140a9ffe", size = 13106781 }, -] - [[package]] name = "google-auth" version = "2.38.0" @@ -1125,34 +1072,21 @@ wheels = [ ] [[package]] -name = "google-auth-httplib2" -version = "0.2.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "google-auth" }, - { name = "httplib2" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/56/be/217a598a818567b28e859ff087f347475c807a5649296fb5a817c58dacef/google-auth-httplib2-0.2.0.tar.gz", hash = 
"sha256:38aa7badf48f974f1eb9861794e9c0cb2a0511a4ec0679b1f886d108f5640e05", size = 10842 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/be/8a/fe34d2f3f9470a27b01c9e76226965863f153d5fbe276f83608562e49c04/google_auth_httplib2-0.2.0-py2.py3-none-any.whl", hash = "sha256:b65a0a2123300dd71281a7bf6e64d65a0759287df52729bdd1ae2e47dc311a3d", size = 9253 }, -] - -[[package]] -name = "google-generativeai" -version = "0.8.4" +name = "google-genai" +version = "1.10.0" source = { registry = "https://pypi.org/simple" } dependencies = [ - { name = "google-ai-generativelanguage" }, - { name = "google-api-core" }, - { name = "google-api-python-client" }, + { name = "anyio" }, { name = "google-auth" }, - { name = "protobuf" }, + { name = "httpx" }, { name = "pydantic" }, - { name = "tqdm" }, + { name = "requests" }, { name = "typing-extensions" }, + { name = "websockets" }, ] +sdist = { url = "https://files.pythonhosted.org/packages/0e/7a/224e2f70c835202042969685ee3da00a6475508d1b64f0f1e90144f96beb/google_genai-1.10.0.tar.gz", hash = "sha256:f59423e0f155dc66b7792c8a0e6724c75c72dc699d1eb7907d4d0006d4f6186f", size = 156355 } wheels = [ - { url = "https://files.pythonhosted.org/packages/9b/b0/6c6af327a8a6ef3be6fe79be1d6f1e2914d6c363aa6b081b93396f4460a7/google_generativeai-0.8.4-py3-none-any.whl", hash = "sha256:e987b33ea6decde1e69191ddcaec6ef974458864d243de7191db50c21a7c5b82", size = 175409 }, + { url = "https://files.pythonhosted.org/packages/ba/a0/56839a2e202d79c773edd1c1db124da8eb2a7b657267a888080b678d0369/google_genai-1.10.0-py3-none-any.whl", hash = "sha256:41b105a2fcf8a027fc45cc16694cd559b8cd1272eab7345ad58cfa2c353bf34f", size = 154705 }, ] [[package]] @@ -1242,7 +1176,7 @@ wheels = [ [[package]] name = "griptape" -version = "1.5.0" +version = "1.6.0" source = { editable = "." 
} dependencies = [ { name = "attrs" }, @@ -1277,7 +1211,7 @@ all = [ { name = "duckduckgo-search" }, { name = "elevenlabs" }, { name = "exa-py" }, - { name = "google-generativeai" }, + { name = "google-genai" }, { name = "huggingface-hub" }, { name = "mail-parser" }, { name = "markdownify" }, @@ -1318,7 +1252,7 @@ drivers-embedding-cohere = [ { name = "cohere" }, ] drivers-embedding-google = [ - { name = "google-generativeai" }, + { name = "google-genai" }, ] drivers-embedding-huggingface = [ { name = "huggingface-hub" }, @@ -1388,7 +1322,7 @@ drivers-prompt-cohere = [ { name = "cohere" }, ] drivers-prompt-google = [ - { name = "google-generativeai" }, + { name = "google-genai" }, ] drivers-prompt-huggingface-hub = [ { name = "huggingface-hub" }, @@ -1549,9 +1483,9 @@ requires-dist = [ { name = "exa-py", marker = "extra == 'all'", specifier = ">=1.1.4" }, { name = "exa-py", marker = "extra == 'drivers-web-search-exa'", specifier = ">=1.1.4" }, { name = "filetype", specifier = ">=1.2" }, - { name = "google-generativeai", marker = "extra == 'all'", specifier = ">=0.8.2" }, - { name = "google-generativeai", marker = "extra == 'drivers-embedding-google'", specifier = ">=0.8.2" }, - { name = "google-generativeai", marker = "extra == 'drivers-prompt-google'", specifier = ">=0.8.2" }, + { name = "google-genai", marker = "extra == 'all'", specifier = ">=1.10.0" }, + { name = "google-genai", marker = "extra == 'drivers-embedding-google'", specifier = ">=1.10.0" }, + { name = "google-genai", marker = "extra == 'drivers-prompt-google'", specifier = ">=1.10.0" }, { name = "huggingface-hub", marker = "extra == 'all'", specifier = ">=0.28.1" }, { name = "huggingface-hub", marker = "extra == 'drivers-embedding-huggingface'", specifier = ">=0.28.1" }, { name = "huggingface-hub", marker = "extra == 'drivers-prompt-huggingface-hub'", specifier = ">=0.28.1" }, @@ -1744,20 +1678,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/97/22/b1535291aaa9c046c79a9dc4db125f6b9974d41de154221b72da4e8a005c/grpcio-1.71.0-cp39-cp39-win_amd64.whl", hash = "sha256:63e41b91032f298b3e973b3fa4093cbbc620c875e2da7b93e249d4728b54559a", size = 4280941 }, ] -[[package]] -name = "grpcio-status" -version = "1.71.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "googleapis-common-protos" }, - { name = "grpcio" }, - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/d7/53/a911467bece076020456401f55a27415d2d70d3bc2c37af06b44ea41fc5c/grpcio_status-1.71.0.tar.gz", hash = "sha256:11405fed67b68f406b3f3c7c5ae5104a79d2d309666d10d61b152e91d28fb968", size = 13669 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/ad/d6/31fbc43ff097d8c4c9fc3df741431b8018f67bf8dfbe6553a555f6e5f675/grpcio_status-1.71.0-py3-none-any.whl", hash = "sha256:843934ef8c09e3e858952887467f8256aac3910c55f077a359a65b2b3cde3e68", size = 14424 }, -] - [[package]] name = "grpcio-tools" version = "1.71.0" @@ -1882,18 +1802,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/87/f5/72347bc88306acb359581ac4d52f23c0ef445b57157adedb9aee0cd689d2/httpcore-1.0.7-py3-none-any.whl", hash = "sha256:a3fff8f43dc260d5bd363d9f9cf1830fa3a458b332856f34282de498ed420edd", size = 78551 }, ] -[[package]] -name = "httplib2" -version = "0.22.0" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "pyparsing" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/3d/ad/2371116b22d616c194aa25ec410c9c6c37f23599dcd590502b74db197584/httplib2-0.22.0.tar.gz", hash = 
"sha256:d7a10bc5ef5ab08322488bde8c726eeee5c8618723fdb399597ec58f3d82df81", size = 351116 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/a8/6c/d2fbdaaa5959339d53ba38e94c123e4e84b8fbc4b84beb0e70d7c1608486/httplib2-0.22.0-py3-none-any.whl", hash = "sha256:14ae0a53c1ba8f3d37e9e27cf37eabb0fb9980f435ba405d546948b009dd64dc", size = 96854 }, -] - [[package]] name = "httpx" version = "0.28.1" @@ -3877,18 +3785,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b5/35/6c4c6fc8774a9e3629cd750dc24a7a4fb090a25ccd5c3246d127b70f9e22/propcache-0.3.0-py3-none-any.whl", hash = "sha256:67dda3c7325691c2081510e92c561f465ba61b975f481735aefdfc845d2cd043", size = 12101 }, ] -[[package]] -name = "proto-plus" -version = "1.26.1" -source = { registry = "https://pypi.org/simple" } -dependencies = [ - { name = "protobuf" }, -] -sdist = { url = "https://files.pythonhosted.org/packages/f4/ac/87285f15f7cce6d4a008f33f1757fb5a13611ea8914eb58c3d0d26243468/proto_plus-1.26.1.tar.gz", hash = "sha256:21a515a4c4c0088a773899e23c7bbade3d18f9c66c73edd4c7ee3816bc96a012", size = 56142 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/4e/6d/280c4c2ce28b1593a19ad5239c8b826871fc6ec275c21afc8e1820108039/proto_plus-1.26.1-py3-none-any.whl", hash = "sha256:13285478c2dcf2abb829db158e1047e2f1e8d63a077d94263c2b88b043c75a66", size = 50163 }, -] - [[package]] name = "protobuf" version = "5.29.3" @@ -4283,15 +4179,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/ca/d7/eb76863d2060dcbe7c7e6cccfd95ac02ea0b9acc37745a0d99ff6457aefb/pyOpenSSL-25.0.0-py3-none-any.whl", hash = "sha256:424c247065e46e76a37411b9ab1782541c23bb658bf003772c3405fbaa128e90", size = 56453 }, ] -[[package]] -name = "pyparsing" -version = "3.2.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/8b/1a/3544f4f299a47911c2ab3710f534e52fea62a633c96806995da5d25be4b2/pyparsing-3.2.1.tar.gz", hash = "sha256:61980854fd66de3a90028d679a954d5f2623e83144b5afe5ee86f43d762e5f0a", size = 1067694 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/1c/a7/c8a2d361bf89c0d9577c934ebb7421b25dc84bf3a8e3ac0a40aed9acc547/pyparsing-3.2.1-py3-none-any.whl", hash = "sha256:506ff4f4386c4cec0590ec19e6302d3aedb992fdc02c761e90416f158dacf8e1", size = 107716 }, -] - [[package]] name = "pypdf" version = "5.3.1" @@ -5485,15 +5372,6 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c2/14/e2a54fabd4f08cd7af1c07030603c3356b74da07f7cc056e600436edfa17/tzlocal-5.3.1-py3-none-any.whl", hash = "sha256:eb1a66c3ef5847adf7a834f1be0800581b683b5608e74f86ecbcef8ab91bb85d", size = 18026 }, ] -[[package]] -name = "uritemplate" -version = "4.1.1" -source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/d2/5a/4742fdba39cd02a56226815abfa72fe0aa81c33bed16ed045647d6000eba/uritemplate-4.1.1.tar.gz", hash = "sha256:4346edfc5c3b79f694bccd6d6099a322bbeb628dbf2cd86eea55a456ce5124f0", size = 273898 } -wheels = [ - { url = "https://files.pythonhosted.org/packages/81/c0/7461b49cd25aeece13766f02ee576d1db528f1c37ce69aee300e075b485b/uritemplate-4.1.1-py2.py3-none-any.whl", hash = "sha256:830c08b8d99bdd312ea4ead05994a38e8936266f84b9a7878232db50b044e02e", size = 10356 }, -] - [[package]] name = "urllib3" version = "1.26.20" From 976fa3a862792900922b8133b347fe834fd5e0dd Mon Sep 17 00:00:00 2001 From: Collin Dutter Date: Wed, 9 Apr 2025 14:44:55 -0700 Subject: [PATCH 2/2] Squash merge fix/lazy into refactor/google-prompt-driver --- 
.../drivers/src/vector_store_drivers_10.py | 4 ++- .../assistant/openai_assistant_driver.py | 23 ++++++++++----- .../amazon_bedrock_cohere_embedding_driver.py | 6 ++-- .../amazon_bedrock_titan_embedding_driver.py | 6 ++-- ...on_sagemaker_jumpstart_embedding_driver.py | 6 ++-- .../huggingface_hub_embedding_driver.py | 2 +- .../amazon_sqs_event_listener_driver.py | 5 +++- .../amazon_s3_file_manager_driver.py | 7 +++-- .../amazon_bedrock_image_generation_driver.py | 6 ++-- ...zon_dynamodb_conversation_memory_driver.py | 2 +- .../drivers/sql/amazon_redshift_sql_driver.py | 8 +++--- .../vector/marqo_vector_store_driver.py | 2 +- .../vector/pinecone_vector_store_driver.py | 14 +++++----- .../vector/qdrant_vector_store_driver.py | 16 +++++++---- .../vector/redis_vector_store_driver.py | 8 +++--- .../web_search/exa_web_search_driver.py | 4 +-- griptape/schemas/base_schema.py | 4 +-- griptape/utils/decorators.py | 21 ++++++++------ pyproject.toml | 2 +- .../assistant/test_openai_assistant_driver.py | 14 ++++++++-- uv.lock | 28 +++++++++---------- 21 files changed, 111 insertions(+), 77 deletions(-) diff --git a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py index 9a27868b5f..99580bbbbe 100644 --- a/docs/griptape-framework/drivers/src/vector_store_drivers_10.py +++ b/docs/griptape-framework/drivers/src/vector_store_drivers_10.py @@ -1,5 +1,7 @@ import os +from qdrant_client.http.models import Distance, VectorParams + from griptape.chunkers import TextChunker from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver from griptape.drivers.vector.qdrant import QdrantVectorStoreDriver @@ -27,7 +29,7 @@ # Recreate Qdrant collection vector_store_driver.client.recreate_collection( collection_name=vector_store_driver.collection_name, - vectors_config={"size": 1536, "distance": vector_store_driver.distance}, + vectors_config=VectorParams(size=1536, distance=Distance.COSINE), ) # Upsert Artifacts into the Vector Store Driver diff --git a/griptape/drivers/assistant/openai_assistant_driver.py b/griptape/drivers/assistant/openai_assistant_driver.py index c2a519eef6..3c35bc82b6 100644 --- a/griptape/drivers/assistant/openai_assistant_driver.py +++ b/griptape/drivers/assistant/openai_assistant_driver.py @@ -59,19 +59,26 @@ def client(self) -> openai.OpenAI: ) def try_run(self, *args: BaseArtifact) -> TextArtifact: - if self.thread_id is None and self.auto_create_thread: - self.thread_id = self.client.beta.threads.create().id - response = self._create_run(*args) + if self.thread_id is None: + if self.auto_create_thread: + thread_id = self.client.beta.threads.create().id + self.thread_id = thread_id + else: + raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.") + else: + thread_id = self.thread_id + + response = self._create_run(thread_id, *args) response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id}) return response - def _create_run(self, *args: BaseArtifact) -> TextArtifact: + def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact: content = "\n".join(arg.value for arg in args) - message_id = self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=content) + message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content) with self.client.beta.threads.runs.stream( - thread_id=self.thread_id, + thread_id=thread_id, assistant_id=self.assistant_id, 
event_handler=self.event_handler,
         ) as stream:
@@ -80,7 +87,9 @@ def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
 
         message_contents = []
         for message in last_messages:
-            message_contents.append("".join(content.text.value for content in message.content))
+            message_contents.append(
+                "".join(content.text.value for content in message.content if content.type == "text")
+            )
 
         message_text = "\n".join(message_contents)
         response = TextArtifact(message_text)
diff --git a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
index 5ddc115066..b775dd98d8 100644
--- a/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
+++ b/griptape/drivers/embedding/amazon_bedrock_cohere_embedding_driver.py
@@ -12,7 +12,7 @@
 if TYPE_CHECKING:
     import boto3
-    from mypy_boto3_bedrock import BedrockClient
+    from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
 
     from griptape.tokenizers.base_tokenizer import BaseTokenizer
 
@@ -40,12 +40,12 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
         default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
         kw_only=True,
     )
-    _client: Optional[BedrockClient] = field(
+    _client: Optional[BedrockRuntimeClient] = field(
         default=None, kw_only=True, alias="client", metadata={"serializable": False}
     )
 
     @lazy_property()
-    def client(self) -> BedrockClient:
+    def client(self) -> BedrockRuntimeClient:
         return self.session.client("bedrock-runtime")
 
     def try_embed_chunk(self, chunk: str) -> list[float]:
diff --git a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
index 7d58e7d45c..495eb7b762 100644
--- a/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
+++ b/griptape/drivers/embedding/amazon_bedrock_titan_embedding_driver.py
@@ -14,7 +14,7 @@
 if TYPE_CHECKING:
     import boto3
-    from mypy_boto3_bedrock import BedrockClient
+    from mypy_boto3_bedrock_runtime import BedrockRuntimeClient
 
     from griptape.tokenizers.base_tokenizer import BaseTokenizer
 
@@ -38,12 +38,12 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
         default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
         kw_only=True,
     )
-    _client: Optional[BedrockClient] = field(
+    _client: Optional[BedrockRuntimeClient] = field(
         default=None, kw_only=True, alias="client", metadata={"serializable": False}
     )
 
     @lazy_property()
-    def client(self) -> BedrockClient:
+    def client(self) -> BedrockRuntimeClient:
         return self.session.client("bedrock-runtime")
 
     def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact) -> list[float]:
diff --git a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
index 878f167cb5..05cbee9e72 100644
--- a/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
+++ b/griptape/drivers/embedding/amazon_sagemaker_jumpstart_embedding_driver.py
@@ -11,7 +11,7 @@
 if TYPE_CHECKING:
     import boto3
-    from mypy_boto3_sagemaker import SageMakerClient
+    from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient
 
 
 @define
@@ -20,12 +20,12 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
     endpoint: str = field(kw_only=True, metadata={"serializable": True})
     custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True}) - _client: Optional[SageMakerClient] = field( + _client: Optional[SageMakerRuntimeClient] = field( default=None, kw_only=True, alias="client", metadata={"serializable": False} ) @lazy_property() - def client(self) -> SageMakerClient: + def client(self) -> SageMakerRuntimeClient: return self.session.client("sagemaker-runtime") def try_embed_chunk(self, chunk: str) -> list[float]: diff --git a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py index 29c635e32c..e4fff5f37b 100644 --- a/griptape/drivers/embedding/huggingface_hub_embedding_driver.py +++ b/griptape/drivers/embedding/huggingface_hub_embedding_driver.py @@ -37,4 +37,4 @@ def client(self) -> InferenceClient: def try_embed_chunk(self, chunk: str) -> list[float]: response = self.client.feature_extraction(chunk) - return response.flatten().tolist() + return [float(val) for val in response.flatten().tolist()] diff --git a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py index d94264a503..ff92aa2996 100644 --- a/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py +++ b/griptape/drivers/event_listener/amazon_sqs_event_listener_driver.py @@ -10,8 +10,11 @@ from griptape.utils.decorators import lazy_property if TYPE_CHECKING: + from collections.abc import Sequence + import boto3 from mypy_boto3_sqs import SQSClient + from mypy_boto3_sqs.type_defs import SendMessageBatchRequestEntryTypeDef @define @@ -28,7 +31,7 @@ def try_publish_event_payload(self, event_payload: dict) -> None: self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload)) def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None: - entries = [ + entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [ {"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)} for event_payload in event_payload_batch ] diff --git a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py index e913842234..b06bf9b250 100644 --- a/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py +++ b/griptape/drivers/file_manager/amazon_s3_file_manager_driver.py @@ -12,6 +12,7 @@ if TYPE_CHECKING: import boto3 from mypy_boto3_s3 import S3Client + from mypy_boto3_s3.type_defs import PaginatorConfigTypeDef @define @@ -98,7 +99,7 @@ def _to_dir_full_key(self, path: str) -> str: def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]: max_items = kwargs.get("max_items") - pagination_config = {} + pagination_config: PaginatorConfigTypeDef = {} if max_items is not None: pagination_config["MaxItems"] = max_items @@ -112,12 +113,12 @@ def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]: files_and_dirs = [] for page in pages: for obj in page.get("CommonPrefixes", []): - prefix = obj.get("Prefix") + prefix = obj.get("Prefix", "") directory = prefix[len(full_key) :].rstrip("/") files_and_dirs.append(directory) for obj in page.get("Contents", []): - key = obj.get("Key") + key = obj.get("Key", "") file = key[len(full_key) :] files_and_dirs.append(file) return files_and_dirs diff --git a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py index 
24a966cb9b..3d99ff0109 100644 --- a/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py +++ b/griptape/drivers/image_generation/amazon_bedrock_image_generation_driver.py @@ -12,7 +12,7 @@ if TYPE_CHECKING: import boto3 - from mypy_boto3_bedrock import BedrockClient + from mypy_boto3_bedrock_runtime import BedrockRuntimeClient @define @@ -32,12 +32,12 @@ class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver): image_width: int = field(default=512, kw_only=True, metadata={"serializable": True}) image_height: int = field(default=512, kw_only=True, metadata={"serializable": True}) seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True}) - _client: Optional[BedrockClient] = field( + _client: Optional[BedrockRuntimeClient] = field( default=None, kw_only=True, alias="client", metadata={"serializable": False} ) @lazy_property() - def client(self) -> BedrockClient: + def client(self) -> BedrockRuntimeClient: return self.session.client("bedrock-runtime") def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact: diff --git a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py index e0c9683011..754f1ecc51 100644 --- a/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py +++ b/griptape/drivers/memory/conversation/amazon_dynamodb_conversation_memory_driver.py @@ -45,7 +45,7 @@ def load(self) -> tuple[list[Run], dict[str, Any]]: response = self.table.get_item(Key=self._get_key()) if "Item" in response and self.value_attribute_key in response["Item"]: - memory_dict = json.loads(response["Item"][self.value_attribute_key]) + memory_dict = json.loads(str(response["Item"][self.value_attribute_key])) return self._from_params_dict(memory_dict) return [], {} diff --git a/griptape/drivers/sql/amazon_redshift_sql_driver.py b/griptape/drivers/sql/amazon_redshift_sql_driver.py index 102326a061..27bce06c81 100644 --- a/griptape/drivers/sql/amazon_redshift_sql_driver.py +++ b/griptape/drivers/sql/amazon_redshift_sql_driver.py @@ -72,7 +72,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any] if self.database_credentials_secret_arn: function_kwargs["SecretArn"] = self.database_credentials_secret_arn - response = self.client.execute_statement(**function_kwargs) + response = self.client.execute_statement(**function_kwargs) # pyright: ignore[reportArgumentType] response_id = response["Id"] statement = self.client.describe_statement(Id=response_id) @@ -92,7 +92,7 @@ def execute_query_raw(self, query: str) -> Optional[list[dict[str, Optional[Any] ) results = results + response.get("Records", []) - return self._post_process(statement_result["ColumnMetadata"], results) + return self._post_process(statement_result["ColumnMetadata"], results) # pyright: ignore[reportArgumentType] if statement["Status"] in ["FAILED", "ABORTED"]: return None @@ -110,5 +110,5 @@ def get_table_schema(self, table_name: str, schema: Optional[str] = None) -> Opt function_kwargs["DbUser"] = self.db_user if self.database_credentials_secret_arn: function_kwargs["SecretArn"] = self.database_credentials_secret_arn - response = self.client.describe_table(**function_kwargs) - return str([col["name"] for col in response["ColumnList"]]) + response = self.client.describe_table(**function_kwargs) # pyright: ignore[reportArgumentType] + return 
str([col["name"] for col in response["ColumnList"] if "name" in col]) diff --git a/griptape/drivers/vector/marqo_vector_store_driver.py b/griptape/drivers/vector/marqo_vector_store_driver.py index 0559c8937b..b166b6106b 100644 --- a/griptape/drivers/vector/marqo_vector_store_driver.py +++ b/griptape/drivers/vector/marqo_vector_store_driver.py @@ -201,7 +201,7 @@ def query( "filter_string": f"namespace:{namespace}" if namespace else None, } | kwargs - results = self.client.index(self.index).search(query, **params) + results = self.client.index(self.index).search(str(query), **params) return self.__process_results(results, include_vectors=include_vectors) def delete_index(self, name: str) -> dict[str, Any]: diff --git a/griptape/drivers/vector/pinecone_vector_store_driver.py b/griptape/drivers/vector/pinecone_vector_store_driver.py index 0bbf390807..04d9993116 100644 --- a/griptape/drivers/vector/pinecone_vector_store_driver.py +++ b/griptape/drivers/vector/pinecone_vector_store_driver.py @@ -34,7 +34,7 @@ def client(self) -> pinecone.Pinecone: ) @lazy_property() - def index(self) -> pinecone.Index: + def index(self) -> pinecone.data.index.Index: return self.client.Index(self.index_name) def upsert_vector( @@ -49,12 +49,12 @@ def upsert_vector( params: dict[str, Any] = {"namespace": namespace} | kwargs - self.index.upsert(vectors=[(vector_id, vector, meta)], **params) + self.index.upsert(vectors=[(vector_id, vector, meta)], **params) # pyright: ignore[reportArgumentType] return vector_id def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Optional[BaseVectorStoreDriver.Entry]: - result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() + result = self.index.fetch(ids=[vector_id], namespace=namespace).to_dict() # pyright: ignore[reportAttributeAccessIssue] vectors = list(result["vectors"].values()) if len(vectors) > 0: @@ -85,9 +85,9 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto id=r["id"], vector=r["values"], meta=r["metadata"], - namespace=results["namespace"], + namespace=results["namespace"], # pyright: ignore[reportIndexIssue] ) - for r in results["matches"] + for r in results["matches"] # pyright: ignore[reportIndexIssue] ] def query_vector( @@ -115,9 +115,9 @@ def query_vector( vector=r["values"], score=r["score"], meta=r["metadata"], - namespace=results["namespace"], + namespace=results["namespace"], # pyright: ignore[reportIndexIssue] ) - for r in results["matches"] + for r in results["matches"] # pyright: ignore[reportIndexIssue] ] def delete_vector(self, vector_id: str) -> NoReturn: diff --git a/griptape/drivers/vector/qdrant_vector_store_driver.py b/griptape/drivers/vector/qdrant_vector_store_driver.py index 274912e05b..8114e1420c 100644 --- a/griptape/drivers/vector/qdrant_vector_store_driver.py +++ b/griptape/drivers/vector/qdrant_vector_store_driver.py @@ -123,12 +123,13 @@ def query_vector( # Convert results to QueryResult objects return [ BaseVectorStoreDriver.Entry( - id=result.id, - vector=result.vector if include_vectors else [], + id=str(result.id), + vector=result.vector if include_vectors else [], # pyright: ignore[reportArgumentType] score=result.score, meta={k: v for k, v in result.payload.items() if k not in ["_score", "_tensor_facets"]}, ) for result in results + if result.payload is not None ] def upsert_vector( @@ -184,9 +185,11 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti results = self.client.retrieve(collection_name=self.collection_name, 
ids=[vector_id]) if results: entry = results[0] + if entry.payload is None: + entry.payload = {} return BaseVectorStoreDriver.Entry( - id=entry.id, - vector=entry.vector, + id=str(entry.id), + vector=entry.vector if entry.vector is not None else [], # pyright: ignore[reportArgumentType] meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]}, ) return None @@ -209,9 +212,10 @@ def load_entries(self, *, namespace: Optional[str] = None, **kwargs) -> list[Bas return [ BaseVectorStoreDriver.Entry( - id=entry.id, - vector=entry.vector if kwargs.get("with_vectors", True) else [], + id=str(entry.id), + vector=entry.vector if kwargs.get("with_vectors", True) else [], # pyright: ignore[reportArgumentType] meta={k: v for k, v in entry.payload.items() if k not in ["_score", "_tensor_facets"]}, ) for entry in results + if entry.payload is not None ] diff --git a/griptape/drivers/vector/redis_vector_store_driver.py b/griptape/drivers/vector/redis_vector_store_driver.py index d80f228314..8674d5e0bd 100644 --- a/griptape/drivers/vector/redis_vector_store_driver.py +++ b/griptape/drivers/vector/redis_vector_store_driver.py @@ -85,8 +85,8 @@ def load_entry(self, vector_id: str, *, namespace: Optional[str] = None) -> Opti """ key = self._generate_key(vector_id, namespace) result = self.client.hgetall(key) - vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist() - meta = json.loads(result[b"metadata"]) if b"metadata" in result else None + vector = np.frombuffer(result[b"vector"], dtype=np.float32).tolist() # pyright: ignore[reportIndexIssue] https://github.com/redis/redis-py/issues/2399 + meta = json.loads(result[b"metadata"]) if b"metadata" in result else None # pyright: ignore[reportIndexIssue, reportOperatorIssue] return BaseVectorStoreDriver.Entry(id=vector_id, meta=meta, vector=vector, namespace=namespace) @@ -100,7 +100,7 @@ def load_entries(self, *, namespace: Optional[str] = None) -> list[BaseVectorSto keys = self.client.keys(pattern) entries = [] - for key in keys: + for key in keys: # pyright: ignore[reportGeneralTypeIssues] https://github.com/redis/redis-py/issues/2399 entry = self.load_entry(key.decode("utf-8"), namespace=namespace) if entry: entries.append(entry) @@ -136,7 +136,7 @@ def query_vector( query_params = {"vector": np.array(vector, dtype=np.float32).tobytes()} - results = self.client.ft(self.index).search(query_expression, query_params).docs # pyright: ignore[reportArgumentType] + results = self.client.ft(self.index).search(query_expression, query_params).docs # pyright: ignore[reportArgumentType, reportAttributeAccessIssue] query_results = [] for document in results: diff --git a/griptape/drivers/web_search/exa_web_search_driver.py b/griptape/drivers/web_search/exa_web_search_driver.py index 60b0aa59f9..1cbac87955 100644 --- a/griptape/drivers/web_search/exa_web_search_driver.py +++ b/griptape/drivers/web_search/exa_web_search_driver.py @@ -26,8 +26,8 @@ def client(self) -> Exa: return import_optional_dependency("exa_py").Exa(api_key=self.api_key) def search(self, query: str, **kwargs) -> ListArtifact[JsonArtifact]: - response = self.client.search_and_contents( - highlights=self.highlights, + response = self.client.search_and_contents( # pyright: ignore[reportCallIssue] + highlights=self.highlights, # pyright: ignore[reportArgumentType] use_autoprompt=self.use_autoprompt, query=query, num_results=self.results_count, diff --git a/griptape/schemas/base_schema.py b/griptape/schemas/base_schema.py index 33e3983535..07bd981c94 100644 --- 
a/griptape/schemas/base_schema.py
+++ b/griptape/schemas/base_schema.py
@@ -335,8 +335,8 @@ def _resolve_types(cls, attrs_cls: type, types_override: Optional[dict[str, type
             "Anthropic": import_optional_dependency("anthropic").Anthropic
             if is_dependency_installed("anthropic")
             else Any,
-            "BedrockClient": import_optional_dependency("mypy_boto3_bedrock").BedrockClient
-            if is_dependency_installed("mypy_boto3_bedrock")
+            "BedrockRuntimeClient": import_optional_dependency("mypy_boto3_bedrock_runtime").BedrockRuntimeClient
+            if is_dependency_installed("mypy_boto3_bedrock_runtime")
             else Any,
             "voyageai": import_optional_dependency("voyageai") if is_dependency_installed("voyageai") else Any,
             "Schema": Schema,
diff --git a/griptape/utils/decorators.py b/griptape/utils/decorators.py
index c7677eb207..bfc939500a 100644
--- a/griptape/utils/decorators.py
+++ b/griptape/utils/decorators.py
@@ -2,11 +2,13 @@
 
 import functools
 import inspect
-from typing import TYPE_CHECKING, Any, Callable, Optional, cast
+from collections import OrderedDict
+from typing import TYPE_CHECKING, Any, Callable, Optional, TypeVar, cast
 
 import schema
 from pydantic import BaseModel
 from schema import Schema
+from typing_extensions import ParamSpec
 
 if TYPE_CHECKING:
     from collections import OrderedDict
@@ -20,6 +22,9 @@
     }
 )
 
+P = ParamSpec("P")
+R = TypeVar("R")
+
 
 def activity(config: dict) -> Any:
     validated_config = CONFIG_SCHEMA.validate(config)
@@ -43,22 +48,22 @@ def wrapper(self: Any, params: dict) -> Any:
     return decorator
 
 
-def lazy_property(attr_name: Optional[str] = None) -> Callable[[Callable[[Any], Any]], property]:
-    def decorator(func: Callable[[Any], Any]) -> property:
+def lazy_property(attr_name: Optional[str] = None) -> Callable[[Callable[P, R]], R]:
+    def decorator(func: Callable[P, R]) -> R:
         actual_attr_name = f"_{func.__name__}" if attr_name is None else attr_name
 
         @property
         @functools.wraps(func)
-        def lazy_attr(self: Any) -> Any:
+        def wrapper(self: Any) -> R:
             if getattr(self, actual_attr_name) is None:
-                setattr(self, actual_attr_name, func(self))
+                setattr(self, actual_attr_name, func(self))  # pyright: ignore[reportCallIssue]
             return getattr(self, actual_attr_name)
 
-        @lazy_attr.setter
-        def lazy_attr(self: Any, value: Any) -> None:
+        @wrapper.setter
+        def wrapper(self: Any, value: Any) -> None:
             setattr(self, actual_attr_name, value)
 
-        return lazy_attr
+        return cast("R", wrapper)
 
     return decorator
diff --git a/pyproject.toml b/pyproject.toml
index 30080f53f2..1c6c90aff8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -174,7 +174,7 @@ dev = [
     "ruff>=0.9.1",
     "pyright>=1.1.376",
     "pre-commit>=4.0.0",
-    "boto3-stubs[bedrock, iam, opensearch, s3, sagemaker, sqs, iot-data, dynamodb, redshift-data]>=1.34.105",
+    "boto3-stubs[bedrock-runtime, iam, opensearch, s3, sagemaker-runtime, sqs, iot-data, dynamodb, redshift-data]>=1.34.105",
     "typos>=1.22.9",
     "mdformat>=0.7.17",
     "mdformat-gfm>=0.4.1",
diff --git a/tests/unit/drivers/assistant/test_openai_assistant_driver.py b/tests/unit/drivers/assistant/test_openai_assistant_driver.py
index 2bcc103686..94f53a6a07 100644
--- a/tests/unit/drivers/assistant/test_openai_assistant_driver.py
+++ b/tests/unit/drivers/assistant/test_openai_assistant_driver.py
@@ -16,8 +16,18 @@ def mock_event_handler(self, mocker):
         event_handler,
         "get_final_messages",
         return_value=[
-            Mock(content=[Mock(text=Mock(value="foo")), Mock(text=Mock(value=" bar"))]),
-            Mock(content=[Mock(text=Mock(value="foo")), Mock(text=Mock(value=" bar"))]),
+            Mock(
+                content=[
+                    Mock(type="text", 
text=Mock(value="foo")),
+                    Mock(type="text", text=Mock(value=" bar")),
+                ]
+            ),
+            Mock(
+                content=[
+                    Mock(type="text", text=Mock(value="foo")),
+                    Mock(type="text", text=Mock(value=" bar")),
+                ]
+            ),
         ],
     )
     mocker.patch.object(event_handler, "until_done")
diff --git a/uv.lock b/uv.lock
index dcd88d23f2..b3ecf55abb 100644
--- a/uv.lock
+++ b/uv.lock
@@ -304,8 +304,8 @@ wheels = [
 ]
 
 [package.optional-dependencies]
-bedrock = [
-    { name = "mypy-boto3-bedrock" },
+bedrock-runtime = [
+    { name = "mypy-boto3-bedrock-runtime" },
 ]
 dynamodb = [
     { name = "mypy-boto3-dynamodb" },
@@ -325,8 +325,8 @@ redshift-data = [
 s3 = [
     { name = "mypy-boto3-s3" },
 ]
-sagemaker = [
-    { name = "mypy-boto3-sagemaker" },
+sagemaker-runtime = [
+    { name = "mypy-boto3-sagemaker-runtime" },
 ]
 sqs = [
     { name = "mypy-boto3-sqs" },
@@ -1413,7 +1413,7 @@ loaders-sql = [
 
 [package.dev-dependencies]
 dev = [
-    { name = "boto3-stubs", extra = ["bedrock", "dynamodb", "iam", "iot-data", "opensearch", "redshift-data", "s3", "sagemaker", "sqs"] },
+    { name = "boto3-stubs", extra = ["bedrock-runtime", "dynamodb", "iam", "iot-data", "opensearch", "redshift-data", "s3", "sagemaker-runtime", "sqs"] },
     { name = "mdformat" },
     { name = "mdformat-footnote" },
     { name = "mdformat-frontmatter" },
@@ -1584,7 +1584,7 @@ provides-extras = ["drivers-prompt-cohere", "drivers-prompt-anthropic", "drivers
 
 [package.metadata.requires-dev]
 dev = [
-    { name = "boto3-stubs", extras = ["bedrock", "iam", "opensearch", "s3", "sagemaker", "sqs", "iot-data", "dynamodb", "redshift-data"], specifier = ">=1.34.105" },
+    { name = "boto3-stubs", extras = ["bedrock-runtime", "iam", "opensearch", "s3", "sagemaker-runtime", "sqs", "iot-data", "dynamodb", "redshift-data"], specifier = ">=1.34.105" },
     { name = "mdformat", specifier = ">=0.7.17" },
     { name = "mdformat-footnote", specifier = ">=0.1.1" },
     { name = "mdformat-frontmatter", specifier = ">=2.0.8" },
@@ -2794,15 +2794,15 @@ wheels = [
 ]
 
 [[package]]
-name = "mypy-boto3-bedrock"
-version = "1.37.8"
+name = "mypy-boto3-bedrock-runtime"
+version = "1.37.30"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name = "typing-extensions", marker = "python_full_version < '3.12'" },
 ]
-sdist = { url = "https://files.pythonhosted.org/packages/d3/17/d6ba56ca7155b78c6367b3974ee71265245908f12e7d473436d1efaf26e6/mypy_boto3_bedrock-1.37.8.tar.gz", hash = "sha256:52e154288bed4b01ec47ac3c608fcaa71397b10887865e567e13a576d9b89649", size = 42083 }
+sdist = { url = "https://files.pythonhosted.org/packages/df/29/9c23ceb80af9d028ddf2bbaea3469017049f1693a620e1ee472482f3ee39/mypy_boto3_bedrock_runtime-1.37.30.tar.gz", hash = "sha256:0dfc1d9910eb14900ca4ae88f09e37dd57ec56a95d355b56f4da3139823dec99", size = 26176 }
 wheels = [
-    { url = "https://files.pythonhosted.org/packages/16/c9/7da6d3074395adf3a3be41518aacb63a7098e180705bc2fa8cc39e8bba99/mypy_boto3_bedrock-1.37.8-py3-none-any.whl", hash = "sha256:38c7572a15ed7fa2b9496d23408a2a985f16231f6fcc918633424104410490f9", size = 47614 },
+    { url = "https://files.pythonhosted.org/packages/89/7d/a278edd8880263e3a8617348c69a3712dc6e6c2de77ba4381506161c0aaf/mypy_boto3_bedrock_runtime-1.37.30-py3-none-any.whl", hash = "sha256:c55836500da1938e40cf579aef9528769855655cd4fb0ee4e3db34c3b44a720f", size = 31870 },
 ]
 
 [[package]]
-name = "mypy-boto3-sagemaker"
-version = "1.37.5"
+name = "mypy-boto3-sagemaker-runtime"
+version = "1.37.0"
 source = { registry = "https://pypi.org/simple" }
 dependencies = [
     { name 
= "typing-extensions", marker = "python_full_version < '3.12'" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/c7/cb/abdd237e4e6876674a4d548821cd37e33ee0a802ee2234dd86c9f76f10a9/mypy_boto3_sagemaker-1.37.5.tar.gz", hash = "sha256:099981df3990c2c5b2f39bf32bcf751d7b917e680d060d985c46ac8a13ab6d46", size = 208907 } +sdist = { url = "https://files.pythonhosted.org/packages/c4/92/6b14ddbd41de8893a119214f4fef27ccbc591132daae0edd8303d03b2147/mypy_boto3_sagemaker_runtime-1.37.0.tar.gz", hash = "sha256:1503eeae9a1f131c0c4d6e309845f90c083c8eadb80f80ad4444876815efd477", size = 15764 } wheels = [ - { url = "https://files.pythonhosted.org/packages/3b/b9/2c8008863a57f0f66602199dda84484e83332b46fc5e46c7f69b68af3366/mypy_boto3_sagemaker-1.37.5-py3-none-any.whl", hash = "sha256:6ac0d5754bdf7a8bd7c339d455c0a02a61991beeae50d239f433a6f96d4bdc9a", size = 212405 }, + { url = "https://files.pythonhosted.org/packages/44/a2/f9232ad55ad60bd877bde5bce0526d6d2db1afc1d2827402f2342eb5dd89/mypy_boto3_sagemaker_runtime-1.37.0-py3-none-any.whl", hash = "sha256:5044642147dc49c2dc81e390c93ae422f29eec9cf530e71f3674503d99b90b78", size = 19211 }, ] [[package]]