diff --git a/composer.json b/composer.json index d662461..3b37f4c 100644 --- a/composer.json +++ b/composer.json @@ -22,7 +22,7 @@ "php": "^8.2", "laravel/framework": "^11.0|^12.0", "aws/aws-sdk-php": "^3.339", - "prism-php/prism": ">=0.77.1" + "prism-php/prism": ">=0.80.0" }, "config": { "allow-plugins": { @@ -97,4 +97,4 @@ ] } } -} +} \ No newline at end of file diff --git a/src/Bedrock.php b/src/Bedrock.php index f141dcb..8f90f5b 100644 --- a/src/Bedrock.php +++ b/src/Bedrock.php @@ -2,8 +2,10 @@ namespace Prism\Bedrock; +use Aws\BedrockRuntime\BedrockRuntimeClient; use Aws\Credentials\Credentials; use Aws\Signature\SignatureV4; +use Generator; use Illuminate\Http\Client\PendingRequest; use Illuminate\Http\Client\Request; use Prism\Bedrock\Enums\BedrockSchema; @@ -15,6 +17,7 @@ use Prism\Prism\Providers\Provider; use Prism\Prism\Structured\Request as StructuredRequest; use Prism\Prism\Structured\Response as StructuredResponse; +use Prism\Prism\Text\Chunk; use Prism\Prism\Text\Request as TextRequest; use Prism\Prism\Text\Response as TextResponse; @@ -95,6 +98,23 @@ public function embeddings(EmbeddingRequest $request): EmbeddingsResponse return $handler->handle($request); } + #[\Override] + /** + * @return Generator + */ + public function stream(TextRequest $request): Generator + { + $schema = BedrockSchema::Converse; + + $handler = $schema->streamHandler(); + + $client = $this->bedrockClient(); + + $handler = new $handler($this, $client); + + return $handler->handle($request); + } + public function schema(PrismRequest $request): BedrockSchema { $override = $request->providerOptions(); @@ -109,6 +129,11 @@ public function apiVersion(PrismRequest $request): ?string return $this->schema($request)->defaultApiVersion(); } + protected function bedrockClient(): BedrockRuntimeClient + { + return app(BedrockClientFactory::class)->make(); + } + /** * @param array $options * @param array $retry diff --git a/src/BedrockClientFactory.php b/src/BedrockClientFactory.php new file mode 100644 index 0000000..5e39bd2 --- /dev/null +++ b/src/BedrockClientFactory.php @@ -0,0 +1,27 @@ + config('prism.providers.bedrock.region', 'eu-central-1'), + 'version' => config('prism.providers.bedrock.version', 'latest'), + 'credentials' => [ + 'key' => config('prism.providers.bedrock.api_key', ''), + 'secret' => config('prism.providers.bedrock.api_secret', ''), + ], + ]; + + if ($handler instanceof \GuzzleHttp\HandlerStack) { + $config['http'] = ['handler' => $handler]; + } + + return new BedrockRuntimeClient($config); + } +} diff --git a/src/BedrockServiceProvider.php b/src/BedrockServiceProvider.php index 476db45..a3cc163 100644 --- a/src/BedrockServiceProvider.php +++ b/src/BedrockServiceProvider.php @@ -43,4 +43,9 @@ protected function registerWithPrism(): void return $prismManager; }); } + + protected function registerBedrockClient(): void + { + $this->app->singleton(BedrockClientFactory::class, fn (): \Prism\Bedrock\BedrockClientFactory => new BedrockClientFactory); + } } diff --git a/src/Contracts/BedrockStreamHandler.php b/src/Contracts/BedrockStreamHandler.php new file mode 100644 index 0000000..ecbe2e7 --- /dev/null +++ b/src/Contracts/BedrockStreamHandler.php @@ -0,0 +1,18 @@ + + */ + public function streamHandler(): ?string + { + return match ($this) { + self::Converse => ConverseStreamHandler::class, + default => throw new PrismException('Prism Bedrock only supports streaming for Converse.'), + }; + } + /** * @return null|class-string */ diff --git a/src/Exceptions/BedrockException.php b/src/Exceptions/BedrockException.php new file mode 100644 index 0000000..af31dc8 --- /dev/null +++ b/src/Exceptions/BedrockException.php @@ -0,0 +1,5 @@ +state->reset(); + + $this->validateToolCallDepth($request, $depth); + + yield from $this->processStreamChunks($result, $request, $depth); + + if ($this->state->hasToolCalls()) { + yield from $this->handleToolUseFinish($request, $depth); + } + } + + protected function validateToolCallDepth(Request $request, int $depth): void + { + if ($depth >= $request->maxSteps()) { + throw new PrismException('Maximum tool call chain depth exceeded'); + } + } + + protected function processStreamChunks(Result $result, Request $request, int $depth): Generator + { + + foreach ($result->get('stream') as $event) { + + $outcome = $this->processChunk($event, $request, $depth); + + if ($outcome instanceof Generator) { + yield from $outcome; + } + + if ($outcome instanceof Chunk) { + yield $outcome; + } + } + } + + protected function processChunk(array $chunk, Request $request, int $depth): Generator|Chunk|null + { + return match (array_key_first($chunk) ?? null) { + 'messageStart' => $this->handleMessageStart(data_get($chunk, 'messageStart')), + 'contentBlockDelta' => $this->handleContentBlockDelta(data_get($chunk, 'contentBlockDelta')), + 'contentBlockStop' => $this->handleContentBlockStop(), + 'messageStop' => $this->handleMessageStop(data_get($chunk, 'messageStop'), $depth), + 'metadata' => $this->handleMetadata(data_get($chunk, 'metadata')), + 'modelStreamErrorException' => $this->handleException(data_get($chunk, 'modelStreamErrorException')), + 'serviceUnavailableException' => $this->handleException(data_get($chunk, 'serviceUnavailableException')), + 'throttlingException' => $this->handleException(data_get($chunk, 'throttlingException')), + 'validationException' => $this->handleException(data_get($chunk, 'validationException')), + default => null, + }; + } + + protected function handleMessageStart(array $chunk): null + { + // { + // messageStart: { + // role: assistant + // } + + return null; + } + + protected function handleContentBlockDelta(array $chunk): ?Chunk + { + if ($text = data_get($chunk, 'delta.text')) { + return $this->handleTextBlockDelta($text, (int) data_get($chunk, 'contentBlockIndex')); + } + + if ($reasoningContent = data_get($chunk, 'delta.reasoningContent')) { + return $this->handleReasoningContentBlockDelta($reasoningContent); + } + + if ($toolUse = data_get($chunk, 'delta.toolUse')) { + return $this->handleToolUseBlockDelta($toolUse); + } + + return null; + } + + protected function handleTextBlockDelta(string $text, int $contentBlockIndex): Chunk + { + $this->state->appendText($text); + + return new Chunk( + text: $text, + additionalContent: [ + 'contentBlockIndex' => $contentBlockIndex, + ], + chunkType: ChunkType::Text + ); + } + + protected function handleReasoningContentBlockDelta(array $reasoningContent): Chunk + { + $text = data_get($reasoningContent, 'reasoningText.text', ''); + $signature = data_get($reasoningContent, 'reasoningText.signature', ''); + + $this->state->appendThinking($text); + $this->state->appendThinkingSignature($signature); + + return new Chunk( + text: $text, + chunkType: ChunkType::Thinking + ); + } + + protected function handleContentBlockStop(): void + { + $this->state->resetContentBlock(); + } + + protected function handleMessageStop(array $chunk, int $depth): Generator|Chunk + { + $this->state->setStopReason(data_get($chunk, 'stopReason')); + + if ($this->state->isToolUseFinish()) { + return $this->handleToolUseFinish($chunk, $depth); + } + + return new Chunk( + text: $this->state->text(), + finishReason: FinishReasonMap::map($this->state->stopReason()), + additionalContent: $this->state->buildAdditionalContent(), + chunkType: ChunkType::Meta + ); + } + + protected function handleMetadata(array $chunk): void + { + // {"metadata":{"usage":{"inputTokens":11,"outputTokens":48,"totalTokens":59},"metrics":{"latencyMs":1269}}} + // not sure yet where to store this information. + } + + protected function handleException(array $chunk): void + { + throw new BedrockException(data_get($chunk, 'message')); + } + + protected function handleToolUseBlockDelta(array $toolUse): void + { + throw new \Exception('Tool use not yet supported'); + } + + public function handleToolUseFinish(Request $request, int $depth): void + { + throw new \Exception('Tool use not yet supported'); + } +} diff --git a/src/Schemas/Anthropic/Maps/MessageMap.php b/src/Schemas/Anthropic/Maps/MessageMap.php index 8e98014..4f2451b 100644 --- a/src/Schemas/Anthropic/Maps/MessageMap.php +++ b/src/Schemas/Anthropic/Maps/MessageMap.php @@ -8,8 +8,8 @@ use Exception; use Prism\Prism\Contracts\Message; use Prism\Prism\Exceptions\PrismException; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\AssistantMessage; -use Prism\Prism\ValueObjects\Messages\Support\Image; use Prism\Prism\ValueObjects\Messages\SystemMessage; use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Messages\UserMessage; diff --git a/src/Schemas/Converse/ConverseStreamHandler.php b/src/Schemas/Converse/ConverseStreamHandler.php new file mode 100644 index 0000000..9bd869d --- /dev/null +++ b/src/Schemas/Converse/ConverseStreamHandler.php @@ -0,0 +1,88 @@ +state = new StreamState; + + $this->responseBuilder = new ResponseBuilder; + } + + /** + * @return Generator + * + * @throws PrismChunkDecodeException + * @throws PrismException + * @throws PrismRateLimitedException + */ + public function handle(Request $request): Generator + { + $result = $this->sendRequest($request); + + return $this->processStream($result, $request); + } + + /** + * @return array + */ + public static function buildPayload(Request $request, int $stepCount = 0): array + { + return array_filter([ + 'anthropic_version' => 'bedrock-2023-05-31', + '@http' => [ + 'stream' => true, + ], + 'modelId' => $request->model(), + 'max_tokens' => $request->maxTokens(), + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'messages' => MessageMap::map($request->messages()), + 'system' => MessageMap::mapSystemMessages($request->systemPrompts()), + 'toolConfig' => $request->tools() === [] + ? null + : array_filter([ + 'tools' => ToolMap::map($request->tools()), + 'toolChoice' => $stepCount === 0 ? ToolChoiceMap::map($request->toolChoice()) : null, + ]), + ]); + } + + protected function sendRequest(Request $request): Result + { + try { + $payload = static::buildPayload($request, $this->responseBuilder->steps->count()); + + return $this->client->converseStream($payload); + } catch (Throwable $e) { + throw PrismException::providerRequestError($request->model(), $e); + } + } +} diff --git a/src/Schemas/Converse/ConverseTextHandler.php b/src/Schemas/Converse/ConverseTextHandler.php index c84da55..b8fb94e 100644 --- a/src/Schemas/Converse/ConverseTextHandler.php +++ b/src/Schemas/Converse/ConverseTextHandler.php @@ -112,7 +112,6 @@ protected function prepareTempResponse(): void $this->tempResponse = new TextResponse( steps: new Collection, responseMessages: new Collection, - messages: new Collection, text: data_get($data, 'output.message.content.0.text', ''), finishReason: FinishReasonMap::map(data_get($data, 'stopReason')), toolCalls: $this->extractToolCalls($data), @@ -121,7 +120,8 @@ protected function prepareTempResponse(): void promptTokens: data_get($data, 'usage.inputTokens'), completionTokens: data_get($data, 'usage.outputTokens') ), - meta: new Meta(id: '', model: '') // Not provided in Converse response. + meta: new Meta(id: '', model: ''), + messages: new Collection // Not provided in Converse response. ); } diff --git a/src/Schemas/Converse/Maps/DocumentMapper.php b/src/Schemas/Converse/Maps/DocumentMapper.php index c7fa7fd..a19d790 100644 --- a/src/Schemas/Converse/Maps/DocumentMapper.php +++ b/src/Schemas/Converse/Maps/DocumentMapper.php @@ -5,8 +5,8 @@ use Prism\Bedrock\Enums\Mimes; use Prism\Prism\Contracts\ProviderMediaMapper; use Prism\Prism\Enums\Provider; -use Prism\Prism\ValueObjects\Messages\Support\Document; -use Prism\Prism\ValueObjects\Messages\Support\Media; +use Prism\Prism\ValueObjects\Media\Document; +use Prism\Prism\ValueObjects\Media\Media; class DocumentMapper extends ProviderMediaMapper { diff --git a/src/Schemas/Converse/Maps/ImageMapper.php b/src/Schemas/Converse/Maps/ImageMapper.php index a4387f5..fa8f9c4 100644 --- a/src/Schemas/Converse/Maps/ImageMapper.php +++ b/src/Schemas/Converse/Maps/ImageMapper.php @@ -5,8 +5,8 @@ use Prism\Bedrock\Enums\Mimes; use Prism\Prism\Contracts\ProviderMediaMapper; use Prism\Prism\Enums\Provider; -use Prism\Prism\ValueObjects\Messages\Support\Image; -use Prism\Prism\ValueObjects\Messages\Support\Media; +use Prism\Prism\ValueObjects\Media\Image; +use Prism\Prism\ValueObjects\Media\Media; class ImageMapper extends ProviderMediaMapper { diff --git a/src/Schemas/Converse/Maps/MessageMap.php b/src/Schemas/Converse/Maps/MessageMap.php index 8370636..5170266 100644 --- a/src/Schemas/Converse/Maps/MessageMap.php +++ b/src/Schemas/Converse/Maps/MessageMap.php @@ -7,9 +7,9 @@ use Exception; use Prism\Prism\Contracts\Message; use Prism\Prism\Exceptions\PrismException; +use Prism\Prism\ValueObjects\Media\Document; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\AssistantMessage; -use Prism\Prism\ValueObjects\Messages\Support\Document; -use Prism\Prism\ValueObjects\Messages\Support\Image; use Prism\Prism\ValueObjects\Messages\SystemMessage; use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Messages\UserMessage; diff --git a/tests/Fixtures/BedrockRuntimeClientMockResponse.php b/tests/Fixtures/BedrockRuntimeClientMockResponse.php new file mode 100644 index 0000000..b38ea5e --- /dev/null +++ b/tests/Fixtures/BedrockRuntimeClientMockResponse.php @@ -0,0 +1,48 @@ + new Response(200, [ + 'Content-Type' => 'application/vnd.amazon.eventstream', + ], ''), + ]); + + $handlerStack = HandlerStack::create($mockHandler); + + app()->singleton(BedrockClientFactory::class, fn(): \Prism\Bedrock\BedrockClientFactory => new class($handlerStack, $fakeResult) extends BedrockClientFactory + { + public function __construct( + private readonly HandlerStack $handler, + private readonly Result $fakeResult + ) {} + + public function make(?HandlerStack $handler = null): BedrockRuntimeClient + { + $client = parent::make($this->handler); + + $client->getHandlerList()->setHandler( + fn($command, $request): \GuzzleHttp\Promise\PromiseInterface => $command->getName() === 'ConverseStream' + ? Create::promiseFor($this->fakeResult) + : Create::promiseFor([]) + ); + + return $client; + } + }); + + } +} diff --git a/tests/Fixtures/FixtureResponse.php b/tests/Fixtures/FixtureResponse.php index 6180d20..2fe11fd 100644 --- a/tests/Fixtures/FixtureResponse.php +++ b/tests/Fixtures/FixtureResponse.php @@ -4,6 +4,8 @@ namespace Tests\Fixtures; +use ArrayIterator; +use Aws\Result; use GuzzleHttp\Promise\PromiseInterface; use Illuminate\Support\Facades\Http; @@ -68,4 +70,25 @@ public static function fakeResponseSequence(string $requestPath, string $name, a $requestPath => Http::sequence($responses->toArray()), ])->preventStrayRequests(); } + + public static function fakeConverseStream(string $name): Result + { + $filePath = static::filePath("{$name}-1.jsonl"); + + if (! file_exists($filePath)) { + throw new \RuntimeException("Fixture file not found: {$filePath}"); + } + + $lines = file($filePath, FILE_IGNORE_NEW_LINES | FILE_SKIP_EMPTY_LINES); + $events = array_map(fn ($line): mixed => json_decode($line, true), $lines); + + return new Result([ + 'stream' => new ArrayIterator($events), + '@metadata' => [ + 'statusCode' => 200, + 'headers' => [], + 'effectiveUri' => 'https://bedrock-runtime...', + ], + ]); + } } diff --git a/tests/Fixtures/converse/stream-basic-text-1.jsonl b/tests/Fixtures/converse/stream-basic-text-1.jsonl new file mode 100644 index 0000000..7184468 --- /dev/null +++ b/tests/Fixtures/converse/stream-basic-text-1.jsonl @@ -0,0 +1,14 @@ +{"messageStart":{"role":"assistant"}} +{"contentBlockDelta":{"delta":{"text":"I'm"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" an AI"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" assistant created by"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" Anthropic to"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" be helpful, harm"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":"less, and honest"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":". I"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" don't have a"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" physical body or"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" avatar"},"contentBlockIndex":0}} +{"contentBlockStop":{"contentBlockIndex":0}} +{"messageStop":{"stopReason":"end_turn"}} +{"metadata":{"usage":{"inputTokens":11,"outputTokens":48,"totalTokens":59},"metrics":{"latencyMs":1269}}} \ No newline at end of file diff --git a/tests/Schemas/Anthropic/AnthropicTextHandlerTest.php b/tests/Schemas/Anthropic/AnthropicTextHandlerTest.php index 202f1fc..e678916 100644 --- a/tests/Schemas/Anthropic/AnthropicTextHandlerTest.php +++ b/tests/Schemas/Anthropic/AnthropicTextHandlerTest.php @@ -6,7 +6,7 @@ use Illuminate\Support\Facades\Http; use Prism\Prism\Facades\Tool; use Prism\Prism\Prism; -use Prism\Prism\ValueObjects\Messages\Support\Image; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\SystemMessage; use Prism\Prism\ValueObjects\Messages\UserMessage; use Tests\Fixtures\FixtureResponse; diff --git a/tests/Schemas/Anthropic/Maps/MessageMapTest.php b/tests/Schemas/Anthropic/Maps/MessageMapTest.php index 50cfff0..8be9954 100644 --- a/tests/Schemas/Anthropic/Maps/MessageMapTest.php +++ b/tests/Schemas/Anthropic/Maps/MessageMapTest.php @@ -7,8 +7,8 @@ use Prism\Bedrock\Schemas\Anthropic\Maps\MessageMap; use Prism\Prism\Exceptions\PrismException; use Prism\Prism\Providers\Anthropic\Enums\AnthropicCacheType; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\AssistantMessage; -use Prism\Prism\ValueObjects\Messages\Support\Image; use Prism\Prism\ValueObjects\Messages\SystemMessage; use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Messages\UserMessage; diff --git a/tests/Schemas/Converse/ConverseStreamHandlerTest.php b/tests/Schemas/Converse/ConverseStreamHandlerTest.php new file mode 100644 index 0000000..556dd19 --- /dev/null +++ b/tests/Schemas/Converse/ConverseStreamHandlerTest.php @@ -0,0 +1,35 @@ +using('bedrock', 'anthropic.claude-3-5-sonnet-20240620-v1:0') + ->withMessages([new UserMessage('Who are you?')]) + ->asStream(); + + $text = ''; + $chunks = []; + + foreach ($response as $chunk) { + $chunks[] = $chunk; + $text .= $chunk->text; + } + + expect($chunks)->not->toBeEmpty(); + expect($text)->not->toBeEmpty(); + expect(end($chunks)->finishReason)->toBe(FinishReason::Stop); +}); diff --git a/tests/Schemas/Converse/ConverseTextHandlerTest.php b/tests/Schemas/Converse/ConverseTextHandlerTest.php index 3e7a4d0..b2b7f1f 100644 --- a/tests/Schemas/Converse/ConverseTextHandlerTest.php +++ b/tests/Schemas/Converse/ConverseTextHandlerTest.php @@ -11,8 +11,8 @@ use Prism\Prism\Prism; use Prism\Prism\Testing\TextStepFake; use Prism\Prism\Text\ResponseBuilder; -use Prism\Prism\ValueObjects\Messages\Support\Document; -use Prism\Prism\ValueObjects\Messages\Support\Image; +use Prism\Prism\ValueObjects\Media\Document; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\UserMessage; use Tests\Fixtures\FixtureResponse; diff --git a/tests/Schemas/Converse/Maps/MessageMapTest.php b/tests/Schemas/Converse/Maps/MessageMapTest.php index d148669..ce03ec7 100644 --- a/tests/Schemas/Converse/Maps/MessageMapTest.php +++ b/tests/Schemas/Converse/Maps/MessageMapTest.php @@ -5,9 +5,9 @@ namespace Tests\Schemas\Converse\Maps; use Prism\Bedrock\Schemas\Converse\Maps\MessageMap; +use Prism\Prism\ValueObjects\Media\Document; +use Prism\Prism\ValueObjects\Media\Image; use Prism\Prism\ValueObjects\Messages\AssistantMessage; -use Prism\Prism\ValueObjects\Messages\Support\Document; -use Prism\Prism\ValueObjects\Messages\Support\Image; use Prism\Prism\ValueObjects\Messages\SystemMessage; use Prism\Prism\ValueObjects\Messages\ToolResultMessage; use Prism\Prism\ValueObjects\Messages\UserMessage;