Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
62 changes: 62 additions & 0 deletions lib/bumblebee/text.ex
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,68 @@ defmodule Bumblebee.Text do
defdelegate text_classification(model_info, tokenizer, opts \\ []),
to: Bumblebee.Text.TextClassification

  @type cross_encoding_input :: {String.t(), String.t()}
  @type cross_encoding_output :: %{score: number()}

  @doc """
  Builds serving for cross-encoder models.

  Cross-encoders score text pairs by encoding them jointly through a
  transformer with full cross-attention. This is commonly used for
  reranking search results, semantic similarity, and natural language
  inference tasks.

  The serving accepts `t:cross_encoding_input/0` (a `{text, text}` pair)
  and returns `t:cross_encoding_output/0`. A list of inputs is also
  supported, in which case a list of outputs is returned in the same
  order.

  The score is the model's raw output (logit) for the pair, so it is
  not bounded to any fixed range; a higher score indicates a better
  match between the two texts.

  ## Options

    * `:compile` - compiles all computations for predefined input shapes
      during serving initialization. Should be a keyword list with the
      following keys:

        * `:batch_size` - the maximum batch size of the input. Inputs
          are optionally padded to always match this batch size

        * `:sequence_length` - the maximum input sequence length. Input
          sequences are always padded/truncated to match that length.
          A list can be given, in which case the serving compiles
          a separate computation for each length and then inputs are
          matched to the smallest bounding length

      It is advised to set this option in production and also configure
      a defn compiler using `:defn_options` to maximally reduce inference
      time.

    * `:defn_options` - the options for JIT compilation. Defaults to `[]`

    * `:preallocate_params` - when `true`, explicitly allocates params
      on the device configured by `:defn_options`. You may want to set
      this option when using partitioned serving, to allocate params
      on each of the devices. When using this option, you should first
      load the parameters into the host. This can be done by passing
      `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends.
      Defaults to `false`

  ## Examples

      {:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
      {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})

      serving = Bumblebee.Text.cross_encoding(model_info, tokenizer)

      Nx.Serving.run(serving, {"How many people live in Berlin?", "Berlin has a population of 3.5 million."})
      #=> %{score: 8.761}

  """
  @spec cross_encoding(
          Bumblebee.model_info(),
          Bumblebee.Tokenizer.t(),
          keyword()
        ) :: Nx.Serving.t()
  defdelegate cross_encoding(model_info, tokenizer, opts \\ []),
    to: Bumblebee.Text.CrossEncoding

@type text_embedding_input :: String.t()
@type text_embedding_output :: %{embedding: Nx.Tensor.t()}

Expand Down
96 changes: 96 additions & 0 deletions lib/bumblebee/text/cross_encoding.ex
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
defmodule Bumblebee.Text.CrossEncoding do
  # Implementation of `Bumblebee.Text.cross_encoding/3`. Builds an
  # `Nx.Serving` that jointly encodes `{text, text}` pairs with a
  # sequence classification model and returns a relevance score per
  # pair. Public docs live on the delegating function in
  # `Bumblebee.Text`.
  @moduledoc false

  alias Bumblebee.Shared

  def cross_encoding(model_info, tokenizer, opts \\ []) do
    %{model: model, params: params, spec: spec} = model_info
    # Cross-encoders are sequence classification models scoring the
    # joint encoding of a pair, so any other architecture is an error.
    Shared.validate_architecture!(spec, :for_sequence_classification)

    opts =
      Keyword.validate!(opts, [
        :compile,
        defn_options: [],
        preallocate_params: false
      ])

    preallocate_params = opts[:preallocate_params]
    defn_options = opts[:defn_options]

    # When :compile is given, both shape options are mandatory so the
    # computation can be fully compiled ahead of time. When absent,
    # `compile` is nil and shapes are inferred per batch (JIT).
    compile =
      if compile = opts[:compile] do
        compile
        |> Keyword.validate!([:batch_size, :sequence_length])
        |> Shared.require_options!([:batch_size, :sequence_length])
      end

    batch_size = compile[:batch_size]
    sequence_length = compile[:sequence_length]

    # Fix the tokenizer output length so token sequences match the
    # compiled template shapes (pad/truncate to :length).
    tokenizer =
      Bumblebee.configure(tokenizer, length: sequence_length)

    {_init_fun, predict_fun} = Axon.build(model)

    # The model emits a single logit per pair; squeeze the trailing
    # unit axis so each input maps to a scalar score.
    scores_fun = fn params, input ->
      outputs = predict_fun.(params, input)
      Nx.squeeze(outputs.logits, axes: [-1])
    end

    # One batch key per configured sequence length, so inputs are
    # routed to the smallest bounding compiled computation.
    batch_keys = Shared.sequence_batch_keys(sequence_length)

    Nx.Serving.new(
      fn batch_key, defn_options ->
        params = Shared.maybe_preallocate(params, preallocate_params, defn_options)

        # Unique compilation cache scope per sequence length.
        scope = {:cross_encoding, batch_key}

        scores_fun =
          Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn ->
            {:sequence_length, sequence_length} = batch_key

            # Templates describing the tokenizer output for this
            # batch key, used for ahead-of-time compilation. Note
            # that token type ids are required for pair encoding.
            inputs = %{
              "input_ids" => Nx.template({batch_size, sequence_length}, :u32),
              "attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
              "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
            }

            [params, inputs]
          end)

        fn inputs ->
          # Pad partial batches up to the compiled batch size.
          inputs = Shared.maybe_pad(inputs, batch_size)
          scores_fun.(params, inputs) |> Shared.serving_post_computation()
        end
      end,
      defn_options
    )
    |> Nx.Serving.batch_size(batch_size)
    |> Nx.Serving.process_options(batch_keys: batch_keys)
    |> Nx.Serving.client_preprocessing(fn input ->
      # Accept a single pair or a list of pairs; `multi?` records
      # which, so postprocessing can mirror the input shape.
      {pairs, multi?} = Shared.validate_serving_input!(input, &validate_pair/1)

      # Tokenize on the host backend; tensors are transferred to the
      # serving's device later.
      inputs =
        Nx.with_default_backend(Nx.BinaryBackend, fn ->
          Bumblebee.apply_tokenizer(tokenizer, pairs)
        end)

      batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length)
      batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key)

      {batch, multi?}
    end)
    |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? ->
      # Wrap each scalar score in the documented output map and
      # unwrap the list again for single-pair inputs.
      scores
      |> Nx.to_list()
      |> Enum.map(&%{score: &1})
      |> Shared.normalize_output(multi?)
    end)
  end

  # Validates a single serving input, which must be a pair of binaries.
  defp validate_pair({text1, text2}) when is_binary(text1) and is_binary(text2),
    do: {:ok, {text1, text2}}

  defp validate_pair(value),
    do: {:error, "expected a {string, string} pair, got: #{inspect(value)}"}
end
5 changes: 3 additions & 2 deletions lib/bumblebee/text/text_classification.ex
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ defmodule Bumblebee.Text.TextClassification do
sequence_length = compile[:sequence_length]

tokenizer =
Bumblebee.configure(tokenizer, length: sequence_length, return_token_type_ids: false)
Bumblebee.configure(tokenizer, length: sequence_length)

{_init_fun, predict_fun} = Axon.build(model)

Expand All @@ -58,7 +58,8 @@ defmodule Bumblebee.Text.TextClassification do

inputs = %{
"input_ids" => Nx.template({batch_size, sequence_length}, :u32),
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32)
"attention_mask" => Nx.template({batch_size, sequence_length}, :u32),
"token_type_ids" => Nx.template({batch_size, sequence_length}, :u32)
}

[params, inputs]
Expand Down
36 changes: 36 additions & 0 deletions test/bumblebee/text/cross_encoding_test.exs
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
defmodule Bumblebee.Text.CrossEncodingTest do
  use ExUnit.Case, async: true

  import Bumblebee.TestHelpers

  @moduletag serving_test_tags()

  # Reference cross-encoder checkpoint used throughout this test.
  @repo {:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}

  test "scores sentence pairs" do
    {:ok, model_info} = Bumblebee.load_model(@repo)
    {:ok, tokenizer} = Bumblebee.load_tokenizer(@repo)

    serving = Bumblebee.Text.cross_encoding(model_info, tokenizer)

    question = "How many people live in Berlin?"
    on_topic = "Berlin has a population of 3,520,031 registered inhabitants."
    off_topic = "New York City is famous for its skyscrapers."

    # A single {query, document} pair returns a single score map.
    assert %{score: single_score} = Nx.Serving.run(serving, {question, on_topic})
    assert_in_delta single_score, 8.76, 0.01

    # A list of pairs is scored as a batch, preserving input order.
    assert [%{score: relevant_score}, %{score: irrelevant_score}] =
             Nx.Serving.run(serving, [
               {question, on_topic},
               {question, off_topic}
             ])

    # The relevant document must outrank the irrelevant one.
    assert relevant_score > irrelevant_score
    assert_in_delta relevant_score, 8.76, 0.01
    assert_in_delta irrelevant_score, -11.24, 0.01
  end
end
28 changes: 28 additions & 0 deletions test/bumblebee/text/text_classification_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,32 @@ defmodule Bumblebee.Text.TextClassificationTest do
]
} = Nx.Serving.run(serving, text)
end

test "scores sentence pairs correctly for cross-encoder reranking" do
{:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})
{:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"})

serving =
Bumblebee.Text.TextClassification.text_classification(model_info, tokenizer,
scores_function: :none
)

query = "How many people live in Berlin?"

# Relevant document should score higher than irrelevant
%{predictions: [%{score: relevant_score}]} =
Nx.Serving.run(
serving,
{query, "Berlin has a population of 3,520,031 registered inhabitants."}
)

%{predictions: [%{score: irrelevant_score}]} =
Nx.Serving.run(serving, {query, "New York City is famous for its skyscrapers."})

assert relevant_score > irrelevant_score

# Verify scores match Python sentence-transformers reference values
assert_in_delta relevant_score, 8.76, 0.01
assert_in_delta irrelevant_score, -11.24, 0.01
end
end
Loading