Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions lib/ruby_llm/models.rb
Original file line number Diff line number Diff line change
Expand Up @@ -120,8 +120,10 @@ def resolve(model_id, provider: nil, assume_exists: false, config: nil) # ruboco
model = Models.find model_id, provider
rescue ModelNotFoundError
# Allow raw model IDs for Bedrock and BedrockConverse (they use ARN-style IDs not in registry)
if %w[bedrock bedrock_converse].include?(provider.to_s)
provider_class = Provider.providers[provider.to_sym]
if %w[bedrock bedrock_converse bedrockconverse].include?(provider.to_s)
# Normalize bedrockconverse -> bedrock_converse for provider lookup
provider_key = provider.to_s == 'bedrockconverse' ? :bedrock_converse : provider.to_sym
provider_class = Provider.providers[provider_key]
provider_instance = provider_class.new(config)
model = Model::Info.default(model_id, provider_instance.slug)
return [model, provider_instance]
Expand Down
2 changes: 2 additions & 0 deletions lib/ruby_llm/providers/bedrock_converse.rb
Original file line number Diff line number Diff line change
Expand Up @@ -80,3 +80,5 @@ def configuration_requirements
end
end
end

require_relative 'bedrock_converse/content'
5 changes: 4 additions & 1 deletion lib/ruby_llm/providers/bedrock_converse/chat.rb
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,8 @@ def build_converse_system_content(system_messages)
content = msg.content

if content.is_a?(RubyLLM::Content::Raw)
[{ text: content.value.to_s }]
# Preserve Raw content (e.g., with cachePoint blocks) as-is
Array(content.value)
else
[{ text: Media.format_content(msg.content).map { |c| c[:text] || c.to_s }.join }]
end
Expand Down Expand Up @@ -304,6 +305,8 @@ def build_converse_message(data, content, tool_use_blocks, response)
tool_calls: parse_converse_tool_calls(tool_use_blocks),
input_tokens: usage['inputTokens'],
output_tokens: usage['outputTokens'],
cached_tokens: usage['cacheReadInputTokenCount'],
cache_creation_tokens: usage['cacheWriteInputTokenCount'],
model_id: @model_id,
raw: response
)
Expand Down
32 changes: 32 additions & 0 deletions lib/ruby_llm/providers/bedrock_converse/content.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# frozen_string_literal: true

module RubyLLM
  module Providers
    class BedrockConverse
      # Helper class for building content blocks with cache points for AWS Bedrock Converse API.
      # Bedrock uses cachePoint blocks instead of Anthropic's cache_control format.
      #
      # @example Cache system instructions with default TTL
      #   content = BedrockConverse::Content.new("System prompt", cache: true)
      #
      # @example Cache with explicit TTL
      #   content = BedrockConverse::Content.new("System prompt", ttl: '1h')
      class Content
        # TTL values accepted by Bedrock's cachePoint block.
        VALID_TTLS = %w[5m 1h].freeze

        # Factory method. Note: deliberately overrides .new to return a
        # RubyLLM::Content::Raw payload (not a Content instance), so the
        # result flows through the Raw-content path in the chat pipeline.
        #
        # @param text [String, nil] text for a single text block; ignored when parts is given
        # @param cache [Boolean] when true, append a cachePoint block (default '5m' TTL)
        # @param ttl [String, nil] explicit cachePoint TTL ('5m' or '1h'); implies caching
        # @param parts [Array<Hash>, Hash, nil] pre-built content blocks, used verbatim
        # @return [RubyLLM::Content::Raw]
        # @raise [ArgumentError] if ttl is invalid, or neither text nor parts is given
        def self.new(text = nil, cache: false, ttl: nil, parts: nil)
          # Enforce the TTLs Bedrock actually accepts; previously VALID_TTLS
          # was declared but never checked, letting bad values reach AWS.
          raise ArgumentError, "ttl must be one of #{VALID_TTLS.join(', ')}" if ttl && !VALID_TTLS.include?(ttl)

          payload = if parts
                      # NOTE: do not use Kernel#Array here — Array(hash) calls
                      # Hash#to_a and would mangle a single part into key/value
                      # pairs. Wrap non-array parts explicitly instead.
                      parts.is_a?(Array) ? parts : [parts]
                    else
                      raise ArgumentError, 'text or parts required' if text.nil?

                      blocks = [{ text: text }]
                      # An explicit ttl implies caching even without cache: true.
                      blocks << { cachePoint: { type: 'default', ttl: ttl || '5m' } } if cache || ttl
                      blocks
                    end
          RubyLLM::Content::Raw.new(payload)
        end
      end
    end
  end
end
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,11 @@ def extract_output_tokens(data)
end

def extract_cached_tokens(data)
# Converse API doesn't expose cache metrics in the same way
nil
data.dig('usage', 'cacheReadInputTokenCount')
end

def extract_cache_creation_tokens(data)
# Converse API doesn't expose cache metrics in the same way
nil
data.dig('usage', 'cacheWriteInputTokenCount')
end
end
end
Expand Down
61 changes: 60 additions & 1 deletion spec/ruby_llm/providers/bedrock_converse/chat_spec.rb
Original file line number Diff line number Diff line change
Expand Up @@ -342,7 +342,22 @@ def initialize

result = chat_instance.send(:build_converse_system_content, [message])

expect(result.first[:text]).to include('raw system prompt')
expect(result).to eq([{ text: 'raw system prompt' }])
end

it 'preserves cachePoint blocks in raw content' do
raw_content = RubyLLM::Content::Raw.new([
{ text: 'system prompt' },
{ cachePoint: { type: 'default', ttl: '5m' } }
])
message = RubyLLM::Message.new(role: :system, content: raw_content)

result = chat_instance.send(:build_converse_system_content, [message])

expect(result).to eq([
{ text: 'system prompt' },
{ cachePoint: { type: 'default', ttl: '5m' } }
])
end
end

Expand Down Expand Up @@ -599,6 +614,50 @@ def initialize
expect(message.input_tokens).to be_nil
expect(message.output_tokens).to be_nil
end

it 'extracts cache metrics from response' do
response_body = {
'output' => {
'message' => {
'content' => [{ 'text' => 'Response' }]
}
},
'usage' => {
'inputTokens' => 100,
'outputTokens' => 50,
'cacheReadInputTokenCount' => 80,
'cacheWriteInputTokenCount' => 20
}
}
response = instance_double(Faraday::Response, body: response_body)

message = chat_instance.send(:parse_converse_response, response)

expect(message.input_tokens).to eq(100)
expect(message.output_tokens).to eq(50)
expect(message.cached_tokens).to eq(80)
expect(message.cache_creation_tokens).to eq(20)
end

it 'handles missing cache metrics' do
response_body = {
'output' => {
'message' => {
'content' => [{ 'text' => 'Response' }]
}
},
'usage' => {
'inputTokens' => 100,
'outputTokens' => 50
}
}
response = instance_double(Faraday::Response, body: response_body)

message = chat_instance.send(:parse_converse_response, response)

expect(message.cached_tokens).to be_nil
expect(message.cache_creation_tokens).to be_nil
end
end

describe '.format_tool_for_converse' do
Expand Down
90 changes: 90 additions & 0 deletions spec/ruby_llm/providers/bedrock_converse/content_spec.rb
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# frozen_string_literal: true

require 'spec_helper'

# Specs for the BedrockConverse::Content factory. .new is overridden to
# return a RubyLLM::Content::Raw payload rather than a Content instance,
# so every example asserts on the Raw object's value.
RSpec.describe RubyLLM::Providers::BedrockConverse::Content do
  describe '.new' do
    it 'creates a Raw content object with text only when cache is false' do
      result = described_class.new('Hello, world!')

      expect(result).to be_a(RubyLLM::Content::Raw)
      expect(result.value).to eq([{ text: 'Hello, world!' }])
    end

    it 'creates a Raw content object with text and cachePoint when cache is true' do
      result = described_class.new('System prompt', cache: true)

      expect(result).to be_a(RubyLLM::Content::Raw)
      expect(result.value).to eq([
        { text: 'System prompt' },
        { cachePoint: { type: 'default', ttl: '5m' } }
      ])
    end

    # An explicit ttl implies caching even without cache: true.
    it 'creates a Raw content object with text and cachePoint when ttl is provided' do
      result = described_class.new('System prompt', ttl: '1h')

      expect(result).to be_a(RubyLLM::Content::Raw)
      expect(result.value).to eq([
        { text: 'System prompt' },
        { cachePoint: { type: 'default', ttl: '1h' } }
      ])
    end

    it 'uses provided ttl over default when both cache and ttl are set' do
      result = described_class.new('System prompt', cache: true, ttl: '1h')

      expect(result).to be_a(RubyLLM::Content::Raw)
      expect(result.value).to eq([
        { text: 'System prompt' },
        { cachePoint: { type: 'default', ttl: '1h' } }
      ])
    end

    it 'uses default 5m ttl when cache is true and no ttl provided' do
      result = described_class.new('System prompt', cache: true)

      expect(result.value[1][:cachePoint][:ttl]).to eq('5m')
    end

    # parts bypasses the text/cache handling entirely and is used verbatim.
    it 'accepts parts directly' do
      parts = [
        { text: 'Part 1' },
        { text: 'Part 2' },
        { cachePoint: { type: 'default', ttl: '5m' } }
      ]
      result = described_class.new(parts: parts)

      expect(result).to be_a(RubyLLM::Content::Raw)
      expect(result.value).to eq(parts)
    end

    it 'raises ArgumentError when neither text nor parts provided' do
      expect { described_class.new }.to raise_error(ArgumentError, 'text or parts required')
    end

    it 'raises ArgumentError when text is nil and parts is nil' do
      expect { described_class.new(nil) }.to raise_error(ArgumentError, 'text or parts required')
    end

    it 'ignores text when parts is provided' do
      parts = [{ text: 'From parts' }]
      result = described_class.new('From text', parts: parts)

      expect(result.value).to eq(parts)
    end

    it 'accepts parts as an array' do
      parts = [{ text: 'Single part' }]
      result = described_class.new(parts: parts)

      expect(result.value).to eq([{ text: 'Single part' }])
    end
  end

  describe 'VALID_TTLS constant' do
    # Mirrors the TTLs the Bedrock Converse cachePoint block accepts.
    it 'includes valid TTL values' do
      expect(described_class::VALID_TTLS).to contain_exactly('5m', '1h')
    end
  end
end
Original file line number Diff line number Diff line change
Expand Up @@ -328,43 +328,91 @@ def initialize
end

describe '#extract_cached_tokens' do
it 'returns nil (not supported in Converse API)' do
it 'extracts cacheReadInputTokenCount from usage' do
data = {
'usage' => {
'cachedTokens' => 100
'cacheReadInputTokenCount' => 100
}
}

result = extractor.extract_cached_tokens(data)

expect(result).to be_nil
expect(result).to eq(100)
end

it 'returns nil for any data' do
it 'returns nil when usage is missing' do
result = extractor.extract_cached_tokens({})

expect(result).to be_nil
end

it 'returns nil when cacheReadInputTokenCount is missing' do
data = {
'usage' => {
'inputTokens' => 50
}
}

result = extractor.extract_cached_tokens(data)

expect(result).to be_nil
end

it 'handles zero tokens' do
data = {
'usage' => {
'cacheReadInputTokenCount' => 0
}
}

result = extractor.extract_cached_tokens(data)

expect(result).to eq(0)
end
end

describe '#extract_cache_creation_tokens' do
it 'returns nil (not supported in Converse API)' do
it 'extracts cacheWriteInputTokenCount from usage' do
data = {
'usage' => {
'cacheCreationTokens' => 50
'cacheWriteInputTokenCount' => 50
}
}

result = extractor.extract_cache_creation_tokens(data)

expect(result).to be_nil
expect(result).to eq(50)
end

it 'returns nil for any data' do
it 'returns nil when usage is missing' do
result = extractor.extract_cache_creation_tokens({})

expect(result).to be_nil
end

it 'returns nil when cacheWriteInputTokenCount is missing' do
data = {
'usage' => {
'inputTokens' => 50
}
}

result = extractor.extract_cache_creation_tokens(data)

expect(result).to be_nil
end

it 'handles zero tokens' do
data = {
'usage' => {
'cacheWriteInputTokenCount' => 0
}
}

result = extractor.extract_cache_creation_tokens(data)

expect(result).to eq(0)
end
end

describe 'integration: extracting from realistic streaming data' do
Expand Down