diff --git a/lib/ruby_llm/models.rb b/lib/ruby_llm/models.rb index adf9b9796..12e9adc31 100644 --- a/lib/ruby_llm/models.rb +++ b/lib/ruby_llm/models.rb @@ -120,8 +120,10 @@ def resolve(model_id, provider: nil, assume_exists: false, config: nil) # ruboco model = Models.find model_id, provider rescue ModelNotFoundError # Allow raw model IDs for Bedrock and BedrockConverse (they use ARN-style IDs not in registry) - if %w[bedrock bedrock_converse].include?(provider.to_s) - provider_class = Provider.providers[provider.to_sym] + if %w[bedrock bedrock_converse bedrockconverse].include?(provider.to_s) + # Normalize bedrockconverse -> bedrock_converse for provider lookup + provider_key = provider.to_s == 'bedrockconverse' ? :bedrock_converse : provider.to_sym + provider_class = Provider.providers[provider_key] provider_instance = provider_class.new(config) model = Model::Info.default(model_id, provider_instance.slug) return [model, provider_instance] diff --git a/lib/ruby_llm/providers/bedrock_converse.rb b/lib/ruby_llm/providers/bedrock_converse.rb index 39c270b73..ce8f89f45 100644 --- a/lib/ruby_llm/providers/bedrock_converse.rb +++ b/lib/ruby_llm/providers/bedrock_converse.rb @@ -80,3 +80,5 @@ def configuration_requirements end end end + +require_relative 'bedrock_converse/content' diff --git a/lib/ruby_llm/providers/bedrock_converse/chat.rb b/lib/ruby_llm/providers/bedrock_converse/chat.rb index 071661696..a3f127227 100644 --- a/lib/ruby_llm/providers/bedrock_converse/chat.rb +++ b/lib/ruby_llm/providers/bedrock_converse/chat.rb @@ -232,7 +232,8 @@ def build_converse_system_content(system_messages) content = msg.content if content.is_a?(RubyLLM::Content::Raw) - [{ text: content.value.to_s }] + # Preserve Raw content (e.g., with cachePoint blocks) as-is + Array(content.value) else [{ text: Media.format_content(msg.content).map { |c| c[:text] || c.to_s }.join }] end @@ -304,6 +305,8 @@ def build_converse_message(data, content, tool_use_blocks, response) tool_calls: parse_converse_tool_calls(tool_use_blocks), input_tokens: usage['inputTokens'], output_tokens: usage['outputTokens'], + cached_tokens: usage['cacheReadInputTokenCount'], + cache_creation_tokens: usage['cacheWriteInputTokenCount'], model_id: @model_id, raw: response ) diff --git a/lib/ruby_llm/providers/bedrock_converse/content.rb b/lib/ruby_llm/providers/bedrock_converse/content.rb new file mode 100644 index 000000000..b174e9003 --- /dev/null +++ b/lib/ruby_llm/providers/bedrock_converse/content.rb @@ -0,0 +1,32 @@ +# frozen_string_literal: true + +module RubyLLM + module Providers + class BedrockConverse + # Helper class for building content blocks with cache points for AWS Bedrock Converse API. + # Bedrock uses cachePoint blocks instead of Anthropic's cache_control format. + # + # @example Cache system instructions with default TTL + # content = BedrockConverse::Content.new("System prompt", cache: true) + # + # @example Cache with explicit TTL + # content = BedrockConverse::Content.new("System prompt", ttl: '1h') + class Content + VALID_TTLS = %w[5m 1h].freeze + + def self.new(text = nil, cache: false, ttl: nil, parts: nil) + payload = if parts + Array(parts) + else + raise ArgumentError, 'text or parts required' if text.nil? + + blocks = [{ text: text }] + blocks << { cachePoint: { type: 'default', ttl: ttl || '5m' } } if cache || ttl + blocks + end + RubyLLM::Content::Raw.new(payload) + end + end + end + end +end diff --git a/lib/ruby_llm/providers/bedrock_converse/streaming/content_extraction.rb b/lib/ruby_llm/providers/bedrock_converse/streaming/content_extraction.rb index e620b5188..677567c82 100644 --- a/lib/ruby_llm/providers/bedrock_converse/streaming/content_extraction.rb +++ b/lib/ruby_llm/providers/bedrock_converse/streaming/content_extraction.rb @@ -65,13 +65,11 @@ def extract_output_tokens(data) end def extract_cached_tokens(data) - # Converse API doesn't expose cache metrics in the same way - nil + data.dig('usage', 'cacheReadInputTokenCount') end def extract_cache_creation_tokens(data) - # Converse API doesn't expose cache metrics in the same way - nil + data.dig('usage', 'cacheWriteInputTokenCount') end end end diff --git a/spec/ruby_llm/providers/bedrock_converse/chat_spec.rb b/spec/ruby_llm/providers/bedrock_converse/chat_spec.rb index 27383d071..87a165ecc 100644 --- a/spec/ruby_llm/providers/bedrock_converse/chat_spec.rb +++ b/spec/ruby_llm/providers/bedrock_converse/chat_spec.rb @@ -342,7 +342,22 @@ def initialize result = chat_instance.send(:build_converse_system_content, [message]) - expect(result.first[:text]).to include('raw system prompt') + expect(result).to eq([{ text: 'raw system prompt' }]) + end + + it 'preserves cachePoint blocks in raw content' do + raw_content = RubyLLM::Content::Raw.new([ + { text: 'system prompt' }, + { cachePoint: { type: 'default', ttl: '5m' } } + ]) + message = RubyLLM::Message.new(role: :system, content: raw_content) + + result = chat_instance.send(:build_converse_system_content, [message]) + + expect(result).to eq([ + { text: 'system prompt' }, + { cachePoint: { type: 'default', ttl: '5m' } } + ]) end end @@ -599,6 +614,50 @@ def initialize expect(message.input_tokens).to be_nil expect(message.output_tokens).to be_nil end + + it 'extracts cache metrics from response' do + response_body = { + 'output' => { + 'message' => { + 'content' => [{ 'text' => 'Response' }] + } + }, + 'usage' => { + 'inputTokens' => 100, + 'outputTokens' => 50, + 'cacheReadInputTokenCount' => 80, + 'cacheWriteInputTokenCount' => 20 + } + } + response = instance_double(Faraday::Response, body: response_body) + + message = chat_instance.send(:parse_converse_response, response) + + expect(message.input_tokens).to eq(100) + expect(message.output_tokens).to eq(50) + expect(message.cached_tokens).to eq(80) + expect(message.cache_creation_tokens).to eq(20) + end + + it 'handles missing cache metrics' do + response_body = { + 'output' => { + 'message' => { + 'content' => [{ 'text' => 'Response' }] + } + }, + 'usage' => { + 'inputTokens' => 100, + 'outputTokens' => 50 + } + } + response = instance_double(Faraday::Response, body: response_body) + + message = chat_instance.send(:parse_converse_response, response) + + expect(message.cached_tokens).to be_nil + expect(message.cache_creation_tokens).to be_nil + end end describe '.format_tool_for_converse' do diff --git a/spec/ruby_llm/providers/bedrock_converse/content_spec.rb b/spec/ruby_llm/providers/bedrock_converse/content_spec.rb new file mode 100644 index 000000000..a08e6c932 --- /dev/null +++ b/spec/ruby_llm/providers/bedrock_converse/content_spec.rb @@ -0,0 +1,90 @@ +# frozen_string_literal: true + +require 'spec_helper' + +RSpec.describe RubyLLM::Providers::BedrockConverse::Content do + describe '.new' do + it 'creates a Raw content object with text only when cache is false' do + result = described_class.new('Hello, world!') + + expect(result).to be_a(RubyLLM::Content::Raw) + expect(result.value).to eq([{ text: 'Hello, world!' }]) + end + + it 'creates a Raw content object with text and cachePoint when cache is true' do + result = described_class.new('System prompt', cache: true) + + expect(result).to be_a(RubyLLM::Content::Raw) + expect(result.value).to eq([ + { text: 'System prompt' }, + { cachePoint: { type: 'default', ttl: '5m' } } + ]) + end + + it 'creates a Raw content object with text and cachePoint when ttl is provided' do + result = described_class.new('System prompt', ttl: '1h') + + expect(result).to be_a(RubyLLM::Content::Raw) + expect(result.value).to eq([ + { text: 'System prompt' }, + { cachePoint: { type: 'default', ttl: '1h' } } + ]) + end + + it 'uses provided ttl over default when both cache and ttl are set' do + result = described_class.new('System prompt', cache: true, ttl: '1h') + + expect(result).to be_a(RubyLLM::Content::Raw) + expect(result.value).to eq([ + { text: 'System prompt' }, + { cachePoint: { type: 'default', ttl: '1h' } } + ]) + end + + it 'uses default 5m ttl when cache is true and no ttl provided' do + result = described_class.new('System prompt', cache: true) + + expect(result.value[1][:cachePoint][:ttl]).to eq('5m') + end + + it 'accepts parts directly' do + parts = [ + { text: 'Part 1' }, + { text: 'Part 2' }, + { cachePoint: { type: 'default', ttl: '5m' } } + ] + result = described_class.new(parts: parts) + + expect(result).to be_a(RubyLLM::Content::Raw) + expect(result.value).to eq(parts) + end + + it 'raises ArgumentError when neither text nor parts provided' do + expect { described_class.new }.to raise_error(ArgumentError, 'text or parts required') + end + + it 'raises ArgumentError when text is nil and parts is nil' do + expect { described_class.new(nil) }.to raise_error(ArgumentError, 'text or parts required') + end + + it 'ignores text when parts is provided' do + parts = [{ text: 'From parts' }] + result = described_class.new('From text', parts: parts) + + expect(result.value).to eq(parts) + end + + it 'accepts parts as an array' do + parts = [{ text: 'Single part' }] + result = described_class.new(parts: parts) + + expect(result.value).to eq([{ text: 'Single part' }]) + end + end + + describe 'VALID_TTLS constant' do + it 'includes valid TTL values' do + expect(described_class::VALID_TTLS).to contain_exactly('5m', '1h') + end + end +end diff --git a/spec/ruby_llm/providers/bedrock_converse/streaming/content_extraction_spec.rb b/spec/ruby_llm/providers/bedrock_converse/streaming/content_extraction_spec.rb index 1dcad85a6..e30c3cb6f 100644 --- a/spec/ruby_llm/providers/bedrock_converse/streaming/content_extraction_spec.rb +++ b/spec/ruby_llm/providers/bedrock_converse/streaming/content_extraction_spec.rb @@ -328,43 +328,91 @@ def initialize end describe '#extract_cached_tokens' do - it 'returns nil (not supported in Converse API)' do + it 'extracts cacheReadInputTokenCount from usage' do data = { 'usage' => { - 'cachedTokens' => 100 + 'cacheReadInputTokenCount' => 100 } } result = extractor.extract_cached_tokens(data) - expect(result).to be_nil + expect(result).to eq(100) end - it 'returns nil for any data' do + it 'returns nil when usage is missing' do result = extractor.extract_cached_tokens({}) expect(result).to be_nil end + + it 'returns nil when cacheReadInputTokenCount is missing' do + data = { + 'usage' => { + 'inputTokens' => 50 + } + } + + result = extractor.extract_cached_tokens(data) + + expect(result).to be_nil + end + + it 'handles zero tokens' do + data = { + 'usage' => { + 'cacheReadInputTokenCount' => 0 + } + } + + result = extractor.extract_cached_tokens(data) + + expect(result).to eq(0) + end end describe '#extract_cache_creation_tokens' do - it 'returns nil (not supported in Converse API)' do + it 'extracts cacheWriteInputTokenCount from usage' do data = { 'usage' => { - 'cacheCreationTokens' => 50 + 'cacheWriteInputTokenCount' => 50 } } result = extractor.extract_cache_creation_tokens(data) - expect(result).to be_nil + expect(result).to eq(50) end - it 'returns nil for any data' do + it 'returns nil when usage is missing' do result = extractor.extract_cache_creation_tokens({}) expect(result).to be_nil end + + it 'returns nil when cacheWriteInputTokenCount is missing' do + data = { + 'usage' => { + 'inputTokens' => 50 + } + } + + result = extractor.extract_cache_creation_tokens(data) + + expect(result).to be_nil + end + + it 'handles zero tokens' do + data = { + 'usage' => { + 'cacheWriteInputTokenCount' => 0 + } + } + + result = extractor.extract_cache_creation_tokens(data) + + expect(result).to eq(0) + end end describe 'integration: extracting from realistic streaming data' do