Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import os

from qdrant_client.http.models import Distance, VectorParams

from griptape.chunkers import TextChunker
from griptape.drivers.embedding.openai import OpenAiEmbeddingDriver
from griptape.drivers.vector.qdrant import QdrantVectorStoreDriver
Expand Down Expand Up @@ -27,7 +29,7 @@
# Recreate Qdrant collection
vector_store_driver.client.recreate_collection(
collection_name=vector_store_driver.collection_name,
vectors_config={"size": 1536, "distance": vector_store_driver.distance},
vectors_config=VectorParams(size=1536, distance=Distance.COSINE),
)

# Upsert Artifacts into the Vector Store Driver
Expand Down
12 changes: 6 additions & 6 deletions docs/recipes/src/talk_to_a_video_1.py
Original file line number Diff line number Diff line change
@@ -1,20 +1,20 @@
import time

from google.generativeai.files import get_file, upload_file

from griptape.artifacts import GenericArtifact, TextArtifact
from griptape.configs import Defaults
from griptape.configs.drivers import GoogleDriversConfig
from griptape.structures import Agent

Defaults.drivers_config = GoogleDriversConfig()
client = Defaults.drivers_config.prompt_driver.client

video_file = upload_file(path="tests/resources/griptape-comfyui.mp4")
while video_file.state.name == "PROCESSING":
video_file = client.files.upload(file="tests/resources/griptape-comfyui.mp4")
while video_file.state and video_file.state.name == "PROCESSING":
time.sleep(2)
video_file = get_file(video_file.name)
if video_file.name:
video_file = client.files.get(name=video_file.name)

if video_file.state.name == "FAILED":
if video_file.state and video_file.state.name == "FAILED":
raise ValueError(video_file.state.name)

agent = Agent(
Expand Down
23 changes: 16 additions & 7 deletions griptape/drivers/assistant/openai_assistant_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,26 @@ def client(self) -> openai.OpenAI:
)

def try_run(self, *args: BaseArtifact) -> TextArtifact:
if self.thread_id is None and self.auto_create_thread:
self.thread_id = self.client.beta.threads.create().id
response = self._create_run(*args)
if self.thread_id is None:
if self.auto_create_thread:
thread_id = self.client.beta.threads.create().id
self.thread_id = thread_id
else:
raise ValueError("Thread ID is required but not provided and auto_create_thread is disabled.")
else:
thread_id = self.thread_id

response = self._create_run(thread_id, *args)

response.meta.update({"assistant_id": self.assistant_id, "thread_id": self.thread_id})

return response

def _create_run(self, *args: BaseArtifact) -> TextArtifact:
def _create_run(self, thread_id: str, *args: BaseArtifact) -> TextArtifact:
content = "\n".join(arg.value for arg in args)
message_id = self.client.beta.threads.messages.create(thread_id=self.thread_id, role="user", content=content)
message_id = self.client.beta.threads.messages.create(thread_id=thread_id, role="user", content=content)
with self.client.beta.threads.runs.stream(
thread_id=self.thread_id,
thread_id=thread_id,
assistant_id=self.assistant_id,
event_handler=self.event_handler,
) as stream:
Expand All @@ -80,7 +87,9 @@ def _create_run(self, *args: BaseArtifact) -> TextArtifact:

message_contents = []
for message in last_messages:
message_contents.append("".join(content.text.value for content in message.content))
message_contents.append(
"".join(content.text.value for content in message.content if content.type == "TextContentBlock")
)
message_text = "\n".join(message_contents)

response = TextArtifact(message_text)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

Expand Down Expand Up @@ -40,12 +40,12 @@ class AmazonBedrockCohereEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
_client: Optional[BedrockClient] = field(
_client: Optional[BedrockRuntimeClient] = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)

@lazy_property()
def client(self) -> BedrockClient:
def client(self) -> BedrockRuntimeClient:
return self.session.client("bedrock-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient

from griptape.tokenizers.base_tokenizer import BaseTokenizer

Expand All @@ -38,12 +38,12 @@ class AmazonBedrockTitanEmbeddingDriver(BaseEmbeddingDriver):
default=Factory(lambda self: AmazonBedrockTokenizer(model=self.model), takes_self=True),
kw_only=True,
)
_client: Optional[BedrockClient] = field(
_client: Optional[BedrockRuntimeClient] = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)

@lazy_property()
def client(self) -> BedrockClient:
def client(self) -> BedrockRuntimeClient:
return self.session.client("bedrock-runtime")

def try_embed_artifact(self, artifact: TextArtifact | ImageArtifact) -> list[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

if TYPE_CHECKING:
import boto3
from mypy_boto3_sagemaker import SageMakerClient
from mypy_boto3_sagemaker_runtime import SageMakerRuntimeClient


@define
Expand All @@ -20,12 +20,12 @@ class AmazonSageMakerJumpstartEmbeddingDriver(BaseEmbeddingDriver):
endpoint: str = field(kw_only=True, metadata={"serializable": True})
custom_attributes: str = field(default="accept_eula=true", kw_only=True, metadata={"serializable": True})
inference_component_name: Optional[str] = field(default=None, kw_only=True, metadata={"serializable": True})
_client: Optional[SageMakerClient] = field(
_client: Optional[SageMakerRuntimeClient] = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)

@lazy_property()
def client(self) -> SageMakerClient:
def client(self) -> SageMakerRuntimeClient:
return self.session.client("sagemaker-runtime")

def try_embed_chunk(self, chunk: str) -> list[float]:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@ def client(self) -> InferenceClient:
def try_embed_chunk(self, chunk: str) -> list[float]:
response = self.client.feature_extraction(chunk)

return response.flatten().tolist()
return [float(val) for val in response.flatten().tolist()]
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,11 @@
from griptape.utils.decorators import lazy_property

if TYPE_CHECKING:
from collections.abc import Sequence

import boto3
from mypy_boto3_sqs import SQSClient
from mypy_boto3_sqs.type_defs import SendMessageBatchRequestEntryTypeDef


@define
Expand All @@ -28,7 +31,7 @@ def try_publish_event_payload(self, event_payload: dict) -> None:
self.client.send_message(QueueUrl=self.queue_url, MessageBody=json.dumps(event_payload))

def try_publish_event_payload_batch(self, event_payload_batch: list[dict]) -> None:
entries = [
entries: Sequence[SendMessageBatchRequestEntryTypeDef] = [
{"Id": str(event_payload["id"]), "MessageBody": json.dumps(event_payload)}
for event_payload in event_payload_batch
]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
if TYPE_CHECKING:
import boto3
from mypy_boto3_s3 import S3Client
from mypy_boto3_s3.type_defs import PaginatorConfigTypeDef


@define
Expand Down Expand Up @@ -98,7 +99,7 @@ def _to_dir_full_key(self, path: str) -> str:

def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
max_items = kwargs.get("max_items")
pagination_config = {}
pagination_config: PaginatorConfigTypeDef = {}
if max_items is not None:
pagination_config["MaxItems"] = max_items

Expand All @@ -112,12 +113,12 @@ def _list_files_and_dirs(self, full_key: str, **kwargs) -> list[str]:
files_and_dirs = []
for page in pages:
for obj in page.get("CommonPrefixes", []):
prefix = obj.get("Prefix")
prefix = obj.get("Prefix", "")
directory = prefix[len(full_key) :].rstrip("/")
files_and_dirs.append(directory)

for obj in page.get("Contents", []):
key = obj.get("Key")
key = obj.get("Key", "")
file = key[len(full_key) :]
files_and_dirs.append(file)
return files_and_dirs
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@

if TYPE_CHECKING:
import boto3
from mypy_boto3_bedrock import BedrockClient
from mypy_boto3_bedrock_runtime import BedrockRuntimeClient


@define
Expand All @@ -32,12 +32,12 @@ class AmazonBedrockImageGenerationDriver(BaseMultiModelImageGenerationDriver):
image_width: int = field(default=512, kw_only=True, metadata={"serializable": True})
image_height: int = field(default=512, kw_only=True, metadata={"serializable": True})
seed: Optional[int] = field(default=None, kw_only=True, metadata={"serializable": True})
_client: Optional[BedrockClient] = field(
_client: Optional[BedrockRuntimeClient] = field(
default=None, kw_only=True, alias="client", metadata={"serializable": False}
)

@lazy_property()
def client(self) -> BedrockClient:
def client(self) -> BedrockRuntimeClient:
return self.session.client("bedrock-runtime")

def try_text_to_image(self, prompts: list[str], negative_prompts: Optional[list[str]] = None) -> ImageArtifact:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load(self) -> tuple[list[Run], dict[str, Any]]:
response = self.table.get_item(Key=self._get_key())

if "Item" in response and self.value_attribute_key in response["Item"]:
memory_dict = json.loads(response["Item"][self.value_attribute_key])
memory_dict = json.loads(str(response["Item"][self.value_attribute_key]))
return self._from_params_dict(memory_dict)
return [], {}

Expand Down
Loading
Loading