From b605458fef7eb56f7101726e39d1b5ad4e54fddc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Sun, 11 Jan 2026 09:54:12 -0300 Subject: [PATCH 1/2] fix: Include token_type_ids in text classification serving Previously, text_classification set return_token_type_ids: false which broke sentence-pair inputs like query-document pairs used by cross-encoder rerankers. Now token_type_ids are included, making rerankers produce correct scores. Closes #251 --- lib/bumblebee/text/text_classification.ex | 5 ++-- .../text/text_classification_test.exs | 28 +++++++++++++++++++ 2 files changed, 31 insertions(+), 2 deletions(-) diff --git a/lib/bumblebee/text/text_classification.ex b/lib/bumblebee/text/text_classification.ex index be054ae8..c445397b 100644 --- a/lib/bumblebee/text/text_classification.ex +++ b/lib/bumblebee/text/text_classification.ex @@ -32,7 +32,7 @@ defmodule Bumblebee.Text.TextClassification do sequence_length = compile[:sequence_length] tokenizer = - Bumblebee.configure(tokenizer, length: sequence_length, return_token_type_ids: false) + Bumblebee.configure(tokenizer, length: sequence_length) {_init_fun, predict_fun} = Axon.build(model) @@ -58,7 +58,8 @@ defmodule Bumblebee.Text.TextClassification do inputs = %{ "input_ids" => Nx.template({batch_size, sequence_length}, :u32), - "attention_mask" => Nx.template({batch_size, sequence_length}, :u32) + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32) } [params, inputs] diff --git a/test/bumblebee/text/text_classification_test.exs b/test/bumblebee/text/text_classification_test.exs index dd1525c1..f9ca9f04 100644 --- a/test/bumblebee/text/text_classification_test.exs +++ b/test/bumblebee/text/text_classification_test.exs @@ -22,4 +22,32 @@ defmodule Bumblebee.Text.TextClassificationTest do ] } = Nx.Serving.run(serving, text) end + + test "scores sentence pairs correctly for cross-encoder reranking" do + 
{:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + + serving = + Bumblebee.Text.TextClassification.text_classification(model_info, tokenizer, + scores_function: :none + ) + + query = "How many people live in Berlin?" + + # Relevant document should score higher than irrelevant + %{predictions: [%{score: relevant_score}]} = + Nx.Serving.run( + serving, + {query, "Berlin has a population of 3,520,031 registered inhabitants."} + ) + + %{predictions: [%{score: irrelevant_score}]} = + Nx.Serving.run(serving, {query, "New York City is famous for its skyscrapers."}) + + assert relevant_score > irrelevant_score + + # Verify scores match Python sentence-transformers reference values + assert_in_delta relevant_score, 8.76, 0.01 + assert_in_delta irrelevant_score, -11.24, 0.01 + end end From 15cb6577c1a43a0d17616277c7f7413a19c0ed3e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Sun, 11 Jan 2026 10:22:35 -0300 Subject: [PATCH 2/2] feat: Add cross_encoding serving for scoring text pairs --- lib/bumblebee/text.ex | 62 +++++++++++++ lib/bumblebee/text/cross_encoding.ex | 96 +++++++++++++++++++++ test/bumblebee/text/cross_encoding_test.exs | 36 ++++++++ 3 files changed, 194 insertions(+) create mode 100644 lib/bumblebee/text/cross_encoding.ex create mode 100644 test/bumblebee/text/cross_encoding_test.exs diff --git a/lib/bumblebee/text.ex b/lib/bumblebee/text.ex index f40e6c82..c92371d2 100644 --- a/lib/bumblebee/text.ex +++ b/lib/bumblebee/text.ex @@ -361,6 +361,68 @@ defmodule Bumblebee.Text do defdelegate text_classification(model_info, tokenizer, opts \\ []), to: Bumblebee.Text.TextClassification + @type cross_encoding_input :: {String.t(), String.t()} + @type cross_encoding_output :: %{score: number()} + + @doc """ + Builds serving for cross-encoder models. 
+ + Cross-encoders score text pairs by encoding them jointly through a + transformer with full cross-attention. This is commonly used for + reranking search results and scoring semantic similarity. The model must + output a single logit per pair (multi-class NLI heads are not supported). + + The serving accepts `t:cross_encoding_input/0` and returns + `t:cross_encoding_output/0`. A list of inputs is also supported. + + ## Options + + * `:compile` - compiles all computations for predefined input shapes + during serving initialization. Should be a keyword list with the + following keys: + + * `:batch_size` - the maximum batch size of the input. Inputs + are optionally padded to always match this batch size + + * `:sequence_length` - the maximum input sequence length. Input + sequences are always padded/truncated to match that length. + A list can be given, in which case the serving compiles + a separate computation for each length and then inputs are + matched to the smallest bounding length + + It is advised to set this option in production and also configure + a defn compiler using `:defn_options` to maximally reduce inference + time. + + * `:defn_options` - the options for JIT compilation. Defaults to `[]` + + * `:preallocate_params` - when `true`, explicitly allocates params + on the device configured by `:defn_options`. You may want to set + this option when using partitioned serving, to allocate params + on each of the devices. When using this option, you should first + load the parameters into the host. This can be done by passing + `backend: {EXLA.Backend, client: :host}` to `load_model/1` and friends. 
+ Defaults to `false` + + ## Examples + + {:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + + serving = Bumblebee.Text.cross_encoding(model_info, tokenizer) + + Nx.Serving.run(serving, {"How many people live in Berlin?", "Berlin has a population of 3.5 million."}) + #=> %{score: 8.761} + + """ + @spec cross_encoding( + Bumblebee.model_info(), + Bumblebee.Tokenizer.t(), + keyword() + ) :: Nx.Serving.t() + defdelegate cross_encoding(model_info, tokenizer, opts \\ []), + to: Bumblebee.Text.CrossEncoding + @type text_embedding_input :: String.t() @type text_embedding_output :: %{embedding: Nx.Tensor.t()} diff --git a/lib/bumblebee/text/cross_encoding.ex b/lib/bumblebee/text/cross_encoding.ex new file mode 100644 index 00000000..3592a6fc --- /dev/null +++ b/lib/bumblebee/text/cross_encoding.ex @@ -0,0 +1,96 @@ +defmodule Bumblebee.Text.CrossEncoding do + @moduledoc false + + alias Bumblebee.Shared + + def cross_encoding(model_info, tokenizer, opts \\ []) do + %{model: model, params: params, spec: spec} = model_info + Shared.validate_architecture!(spec, :for_sequence_classification) + + opts = + Keyword.validate!(opts, [ + :compile, + defn_options: [], + preallocate_params: false + ]) + + preallocate_params = opts[:preallocate_params] + defn_options = opts[:defn_options] + + compile = + if compile = opts[:compile] do + compile + |> Keyword.validate!([:batch_size, :sequence_length]) + |> Shared.require_options!([:batch_size, :sequence_length]) + end + + batch_size = compile[:batch_size] + sequence_length = compile[:sequence_length] + + tokenizer = + Bumblebee.configure(tokenizer, length: sequence_length) + + {_init_fun, predict_fun} = Axon.build(model) + + scores_fun = fn params, input -> + outputs = predict_fun.(params, input) + Nx.squeeze(outputs.logits, axes: [-1]) + end + + batch_keys = Shared.sequence_batch_keys(sequence_length) + + 
Nx.Serving.new( + fn batch_key, defn_options -> + params = Shared.maybe_preallocate(params, preallocate_params, defn_options) + + scope = {:cross_encoding, batch_key} + + scores_fun = + Shared.compile_or_jit(scores_fun, scope, defn_options, compile != nil, fn -> + {:sequence_length, sequence_length} = batch_key + + inputs = %{ + "input_ids" => Nx.template({batch_size, sequence_length}, :u32), + "attention_mask" => Nx.template({batch_size, sequence_length}, :u32), + "token_type_ids" => Nx.template({batch_size, sequence_length}, :u32) + } + + [params, inputs] + end) + + fn inputs -> + inputs = Shared.maybe_pad(inputs, batch_size) + scores_fun.(params, inputs) |> Shared.serving_post_computation() + end + end, + defn_options + ) + |> Nx.Serving.batch_size(batch_size) + |> Nx.Serving.process_options(batch_keys: batch_keys) + |> Nx.Serving.client_preprocessing(fn input -> + {pairs, multi?} = Shared.validate_serving_input!(input, &validate_pair/1) + + inputs = + Nx.with_default_backend(Nx.BinaryBackend, fn -> + Bumblebee.apply_tokenizer(tokenizer, pairs) + end) + + batch_key = Shared.sequence_batch_key_for_inputs(inputs, sequence_length) + batch = [inputs] |> Nx.Batch.concatenate() |> Nx.Batch.key(batch_key) + + {batch, multi?} + end) + |> Nx.Serving.client_postprocessing(fn {scores, _metadata}, multi? -> + scores + |> Nx.to_list() + |> Enum.map(&%{score: &1}) + |> Shared.normalize_output(multi?) 
+ end) + end + + defp validate_pair({text1, text2}) when is_binary(text1) and is_binary(text2), + do: {:ok, {text1, text2}} + + defp validate_pair(value), + do: {:error, "expected a {string, string} pair, got: #{inspect(value)}"} +end diff --git a/test/bumblebee/text/cross_encoding_test.exs b/test/bumblebee/text/cross_encoding_test.exs new file mode 100644 index 00000000..b57d22d7 --- /dev/null +++ b/test/bumblebee/text/cross_encoding_test.exs @@ -0,0 +1,36 @@ +defmodule Bumblebee.Text.CrossEncodingTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag serving_test_tags() + + test "scores sentence pairs" do + {:ok, model_info} = Bumblebee.load_model({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "cross-encoder/ms-marco-MiniLM-L-6-v2"}) + + serving = Bumblebee.Text.cross_encoding(model_info, tokenizer) + + query = "How many people live in Berlin?" + + # Single pair + assert %{score: score} = + Nx.Serving.run( + serving, + {query, "Berlin has a population of 3,520,031 registered inhabitants."} + ) + + assert_in_delta score, 8.76, 0.01 + + # Multiple pairs (batch) + assert [%{score: relevant_score}, %{score: irrelevant_score}] = + Nx.Serving.run(serving, [ + {query, "Berlin has a population of 3,520,031 registered inhabitants."}, + {query, "New York City is famous for its skyscrapers."} + ]) + + assert relevant_score > irrelevant_score + assert_in_delta relevant_score, 8.76, 0.01 + assert_in_delta irrelevant_score, -11.24, 0.01 + end +end