From 34d88ee6c5869504eacd3dcf3889f39e546dd0b4 Mon Sep 17 00:00:00 2001 From: Martijn van Nieuwenhoven Date: Fri, 30 May 2025 19:26:58 +0200 Subject: [PATCH 1/2] conversesStream support WIP! --- src/Bedrock.php | 17 +- src/BedrockClientFactory.php | 27 +++ src/BedrockServiceProvider.php | 5 + src/Contracts/BedrockStreamHandler.php | 18 ++ src/Enums/BedrockSchema.php | 13 ++ src/Exceptions/BedrockException.php | 5 + src/HandlesStream.php | 164 ++++++++++++++++++ .../Converse/ConverseStreamHandler.php | 88 ++++++++++ .../BedrockRuntimeClientMockResponse.php | 48 +++++ tests/Fixtures/FixtureResponse.php | 23 +++ .../converse/stream-basic-text-1.jsonl | 14 ++ .../Converse/ConverseStreamHandlerTest.php | 35 ++++ 12 files changed, 456 insertions(+), 1 deletion(-) create mode 100644 src/BedrockClientFactory.php create mode 100644 src/Contracts/BedrockStreamHandler.php create mode 100644 src/Exceptions/BedrockException.php create mode 100644 src/HandlesStream.php create mode 100644 src/Schemas/Converse/ConverseStreamHandler.php create mode 100644 tests/Fixtures/BedrockRuntimeClientMockResponse.php create mode 100644 tests/Fixtures/converse/stream-basic-text-1.jsonl create mode 100644 tests/Schemas/Converse/ConverseStreamHandlerTest.php diff --git a/src/Bedrock.php b/src/Bedrock.php index 67d0860..7be1e2b 100644 --- a/src/Bedrock.php +++ b/src/Bedrock.php @@ -2,6 +2,7 @@ namespace Prism\Bedrock; +use Aws\BedrockRuntime\BedrockRuntimeClient; use Aws\Credentials\Credentials; use Aws\Signature\SignatureV4; use Generator; @@ -16,6 +17,7 @@ use Prism\Prism\Exceptions\PrismException; use Prism\Prism\Structured\Request as StructuredRequest; use Prism\Prism\Structured\Response as StructuredResponse; +use Prism\Prism\Text\Chunk; use Prism\Prism\Text\Request as TextRequest; use Prism\Prism\Text\Response as TextResponse; @@ -100,7 +102,15 @@ public function embeddings(EmbeddingRequest $request): EmbeddingsResponse */ public function stream(TextRequest $request): Generator { - throw new PrismException('Prism Bedrock does not support streaming yet.'); + $schema = BedrockSchema::Converse; + + $handler = $schema->streamHandler(); + + $client = $this->bedrockClient(); + + $handler = new $handler($this, $client); + + return $handler->handle($request); } public function schema(PrismRequest $request): BedrockSchema @@ -117,6 +127,11 @@ public function apiVersion(PrismRequest $request): ?string return $this->schema($request)->defaultApiVersion(); } + protected function bedrockClient(): BedrockRuntimeClient + { + return app(BedrockClientFactory::class)->make(); + } + /** * @param array $options * @param array $retry diff --git a/src/BedrockClientFactory.php b/src/BedrockClientFactory.php new file mode 100644 index 0000000..70ad6ef --- /dev/null +++ b/src/BedrockClientFactory.php @@ -0,0 +1,27 @@ + config('services.bedrock.region', 'eu-central-1'), + 'version' => config('services.bedrock.version', 'latest'), + 'credentials' => [ + 'key' => config('services.bedrock.api_key', ''), + 'secret' => config('services.bedrock.api_secret', ''), + ], + ]; + + if ($handler instanceof \GuzzleHttp\HandlerStack) { + $config['http'] = ['handler' => $handler]; + } + + return new BedrockRuntimeClient($config); + } +} diff --git a/src/BedrockServiceProvider.php b/src/BedrockServiceProvider.php index 476db45..a3cc163 100644 --- a/src/BedrockServiceProvider.php +++ b/src/BedrockServiceProvider.php @@ -43,4 +43,9 @@ protected function registerWithPrism(): void return $prismManager; }); } + + protected function registerBedrockClient(): void + { + $this->app->singleton(BedrockClientFactory::class, fn (): \Prism\Bedrock\BedrockClientFactory => new BedrockClientFactory); + } } diff --git a/src/Contracts/BedrockStreamHandler.php b/src/Contracts/BedrockStreamHandler.php new file mode 100644 index 0000000..ecbe2e7 --- /dev/null +++ b/src/Contracts/BedrockStreamHandler.php @@ -0,0 +1,18 @@ + + */ + public function streamHandler(): ?string + { + return match ($this) { + self::Converse => ConverseStreamHandler::class, + default => throw new PrismException('Prism Bedrock only supports streaming for Converse.'), + }; + } + /** * @return null|class-string */ diff --git a/src/Exceptions/BedrockException.php b/src/Exceptions/BedrockException.php new file mode 100644 index 0000000..af31dc8 --- /dev/null +++ b/src/Exceptions/BedrockException.php @@ -0,0 +1,5 @@ +state->reset(); + + $this->validateToolCallDepth($request, $depth); + + yield from $this->processStreamChunks($result, $request, $depth); + + if ($this->state->hasToolCalls()) { + yield from $this->handleToolUseFinish($request, $depth); + } + } + + protected function validateToolCallDepth(Request $request, int $depth): void + { + if ($depth >= $request->maxSteps()) { + throw new PrismException('Maximum tool call chain depth exceeded'); + } + } + + protected function processStreamChunks(Result $result, Request $request, int $depth): Generator + { + + foreach ($result->get('stream') as $event) { + + $outcome = $this->processChunk($event, $request, $depth); + + if ($outcome instanceof Generator) { + yield from $outcome; + } + + if ($outcome instanceof Chunk) { + yield $outcome; + } + } + } + + protected function processChunk(array $chunk, Request $request, int $depth): Generator|Chunk|null + { + return match (array_key_first($chunk) ?? null) { + 'messageStart' => $this->handleMessageStart(data_get($chunk, 'messageStart')), + 'contentBlockDelta' => $this->handleContentBlockDelta(data_get($chunk, 'contentBlockDelta')), + 'contentBlockStop' => $this->handleContentBlockStop(), + 'messageStop' => $this->handleMessageStop(data_get($chunk, 'messageStop'), $depth), + 'metadata' => $this->handleMetadata(data_get($chunk, 'metadata')), + 'modelStreamErrorException' => $this->handleException(data_get($chunk, 'modelStreamErrorException')), + 'serviceUnavailableException' => $this->handleException(data_get($chunk, 'serviceUnavailableException')), + 'throttlingException' => $this->handleException(data_get($chunk, 'throttlingException')), + 'validationException' => $this->handleException(data_get($chunk, 'validationException')), + default => null, + }; + } + + protected function handleMessageStart(array $chunk): null + { + // { + // messageStart: { + // role: assistant + // } + + return null; + } + + protected function handleContentBlockDelta(array $chunk): ?Chunk + { + if ($text = data_get($chunk, 'delta.text')) { + return $this->handleTextBlockDelta($text, (int) data_get($chunk, 'contentBlockIndex')); + } + + if ($reasoningContent = data_get($chunk, 'delta.reasoningContent')) { + return $this->handleReasoningContentBlockDelta($reasoningContent); + } + + if ($toolUse = data_get($chunk, 'delta.toolUse')) { + return $this->handleToolUseBlockDelta($toolUse); + } + + return null; + } + + protected function handleTextBlockDelta(string $text, int $contentBlockIndex): Chunk + { + $this->state->appendText($text); + + return new Chunk( + text: $text, + additionalContent: [ + 'contentBlockIndex' => $contentBlockIndex, + ], + chunkType: ChunkType::Text + ); + } + + protected function handleReasoningContentBlockDelta(array $reasoningContent): Chunk + { + $text = data_get($reasoningContent, 'reasoningText.text', ''); + $signature = data_get($reasoningContent, 'reasoningText.signature', ''); + + $this->state->appendThinking($text); + $this->state->appendThinkingSignature($signature); + + return new Chunk( + text: $text, + chunkType: ChunkType::Thinking + ); + } + + protected function handleContentBlockStop(): void + { + $this->state->resetContentBlock(); + } + + protected function handleMessageStop(array $chunk, int $depth): Generator|Chunk + { + $this->state->setStopReason(data_get($chunk, 'stopReason')); + + if ($this->state->isToolUseFinish()) { + return $this->handleToolUseFinish($chunk, $depth); + } + + return new Chunk( + text: $this->state->text(), + finishReason: FinishReasonMap::map($this->state->stopReason()), + additionalContent: $this->state->buildAdditionalContent(), + chunkType: ChunkType::Meta + ); + } + + protected function handleMetadata(array $chunk): void + { + // {"metadata":{"usage":{"inputTokens":11,"outputTokens":48,"totalTokens":59},"metrics":{"latencyMs":1269}}} + // not sure yet where to store this information. + } + + protected function handleException(array $chunk): void + { + throw new BedrockException(data_get($chunk, 'message')); + } + + protected function handleToolUseBlockDelta(array $toolUse): void + { + throw new \Exception('Tool use not yet supported'); + } + + public function handleToolUseFinish(Request $request, int $depth): void + { + throw new \Exception('Tool use not yet supported'); + } +} diff --git a/src/Schemas/Converse/ConverseStreamHandler.php b/src/Schemas/Converse/ConverseStreamHandler.php new file mode 100644 index 0000000..8fad43d --- /dev/null +++ b/src/Schemas/Converse/ConverseStreamHandler.php @@ -0,0 +1,88 @@ +state = new StreamState; + + $this->responseBuilder = new ResponseBuilder; + } + + /** + * @return Generator + * + * @throws PrismChunkDecodeException + * @throws PrismException + * @throws PrismRateLimitedException + */ + public function handle(Request $request): Generator + { + $result = $this->sendRequest($request); + + return $this->processStream($result, $request); + } + + /** + * @return array + */ + public static function buildPayload(Request $request, int $stepCount = 0): array + { + return array_filter([ + 'anthropic_version' => 'bedrock-2023-05-31', + '@http' => [ + 'stream' => true, + ], + 'modelId' => $request->model(), + 'max_tokens' => $request->maxTokens(), + 'temperature' => $request->temperature(), + 'top_p' => $request->topP(), + 'messages' => MessageMap::map($request->messages()), + 'system' => MessageMap::mapSystemMessages($request->systemPrompts()), + 'toolConfig' => $request->tools() === [] + ? null + : array_filter([ + 'tools' => ToolMap::map($request->tools()), + 'toolChoice' => $stepCount === 0 ? ToolChoiceMap::map($request->toolChoice()) : null, + ]), + ]); + } + protected function sendRequest(Request $request): Result + { + try { + $payload = static::buildPayload($request, $this->responseBuilder->steps->count()); + + return $this->client->converseStream($payload); + } catch (Throwable $e) { + throw PrismException::providerRequestError($request->model(), $e); + } + } +} diff --git a/tests/Fixtures/BedrockRuntimeClientMockResponse.php b/tests/Fixtures/BedrockRuntimeClientMockResponse.php new file mode 100644 index 0000000..b38ea5e --- /dev/null +++ b/tests/Fixtures/BedrockRuntimeClientMockResponse.php @@ -0,0 +1,48 @@ + new Response(200, [ + 'Content-Type' => 'application/vnd.amazon.eventstream', + ], ''), + ]); + + $handlerStack = HandlerStack::create($mockHandler); + + app()->singleton(BedrockClientFactory::class, fn(): \Prism\Bedrock\BedrockClientFactory => new class($handlerStack, $fakeResult) extends BedrockClientFactory + { + public function __construct( + private readonly HandlerStack $handler, + private readonly Result $fakeResult + ) {} + + public function make(?HandlerStack $handler = null): BedrockRuntimeClient + { + $client = parent::make($this->handler); + + $client->getHandlerList()->setHandler( + fn($command, $request): \GuzzleHttp\Promise\PromiseInterface => $command->getName() === 'ConverseStream' + ? Create::promiseFor($this->fakeResult) + : Create::promiseFor([]) + ); + + return $client; + } + }); + + } +} diff --git a/tests/Fixtures/FixtureResponse.php b/tests/Fixtures/FixtureResponse.php index 6180d20..dc86735 100644 --- a/tests/Fixtures/FixtureResponse.php +++ b/tests/Fixtures/FixtureResponse.php @@ -4,6 +4,8 @@ namespace Tests\Fixtures; +use ArrayIterator; +use Aws\Result; use GuzzleHttp\Promise\PromiseInterface; use Illuminate\Support\Facades\Http; @@ -68,4 +70,25 @@ public static function fakeResponseSequence(string $requestPath, string $name, a $requestPath => Http::sequence($responses->toArray()), ])->preventStrayRequests(); } + + public static function fakeConverseStream(string $name): Result + { + $filePath = static::filePath("{$name}-1.jsonl"); + + if (! file_exists($filePath)) { + throw new \RuntimeException("Fixture file not found: {$filePath}"); + } + + $lines = file($filePath, FILE_IGNORE_NEW_LINES | FILE_SKIP_EMPTY_LINES); + $events = array_map(fn ($line): mixed => json_decode((string) $line, true), $lines); + + return new Result([ + 'stream' => new ArrayIterator($events), + '@metadata' => [ + 'statusCode' => 200, + 'headers' => [], + 'effectiveUri' => 'https://bedrock-runtime...', + ], + ]); + } } diff --git a/tests/Fixtures/converse/stream-basic-text-1.jsonl b/tests/Fixtures/converse/stream-basic-text-1.jsonl new file mode 100644 index 0000000..7184468 --- /dev/null +++ b/tests/Fixtures/converse/stream-basic-text-1.jsonl @@ -0,0 +1,14 @@ +{"messageStart":{"role":"assistant"}} +{"contentBlockDelta":{"delta":{"text":"I'm"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" an AI"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" assistant created by"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" Anthropic to"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" be helpful, harm"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":"less, and honest"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":". I"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" don't have a"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" physical body or"},"contentBlockIndex":0}} +{"contentBlockDelta":{"delta":{"text":" avatar"},"contentBlockIndex":0}} +{"contentBlockStop":{"contentBlockIndex":0}} +{"messageStop":{"stopReason":"end_turn"}} +{"metadata":{"usage":{"inputTokens":11,"outputTokens":48,"totalTokens":59},"metrics":{"latencyMs":1269}}} \ No newline at end of file diff --git a/tests/Schemas/Converse/ConverseStreamHandlerTest.php b/tests/Schemas/Converse/ConverseStreamHandlerTest.php new file mode 100644 index 0000000..556dd19 --- /dev/null +++ b/tests/Schemas/Converse/ConverseStreamHandlerTest.php @@ -0,0 +1,35 @@ +using('bedrock', 'anthropic.claude-3-5-sonnet-20240620-v1:0') + ->withMessages([new UserMessage('Who are you?')]) + ->asStream(); + + $text = ''; + $chunks = []; + + foreach ($response as $chunk) { + $chunks[] = $chunk; + $text .= $chunk->text; + } + + expect($chunks)->not->toBeEmpty(); + expect($text)->not->toBeEmpty(); + expect(end($chunks)->finishReason)->toBe(FinishReason::Stop); +}); From 360087dac7201d0e1499b80a40bd96a540812ad3 Mon Sep 17 00:00:00 2001 From: Martijn van Nieuwenhoven Date: Fri, 30 May 2025 21:30:24 +0200 Subject: [PATCH 2/2] Fix config keys --- src/BedrockClientFactory.php | 8 ++++---- tests/Fixtures/FixtureResponse.php | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/src/BedrockClientFactory.php b/src/BedrockClientFactory.php index 70ad6ef..5e39bd2 100644 --- a/src/BedrockClientFactory.php +++ b/src/BedrockClientFactory.php @@ -10,11 +10,11 @@ class BedrockClientFactory public function make(?HandlerStack $handler = null): BedrockRuntimeClient { $config = [ - 'region' => config('services.bedrock.region', 'eu-central-1'), - 'version' => config('services.bedrock.version', 'latest'), + 'region' => config('prism.providers.bedrock.region', 'eu-central-1'), + 'version' => config('prism.providers.bedrock.version', 'latest'), 'credentials' => [ - 'key' => config('services.bedrock.api_key', ''), - 'secret' => config('services.bedrock.api_secret', ''), + 'key' => config('prism.providers.bedrock.api_key', ''), + 'secret' => config('prism.providers.bedrock.api_secret', ''), ], ]; diff --git a/tests/Fixtures/FixtureResponse.php b/tests/Fixtures/FixtureResponse.php index dc86735..2fe11fd 100644 --- a/tests/Fixtures/FixtureResponse.php +++ b/tests/Fixtures/FixtureResponse.php @@ -80,7 +80,7 @@ public static function fakeConverseStream(string $name): Result } $lines = file($filePath, FILE_IGNORE_NEW_LINES | FILE_SKIP_EMPTY_LINES); - $events = array_map(fn ($line): mixed => json_decode((string) $line, true), $lines); + $events = array_map(fn ($line): mixed => json_decode($line, true), $lines); return new Result([ 'stream' => new ArrayIterator($events),