From 513b16ce63c82e56612e2bbe9fa0e4984378c281 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Sat, 3 Jan 2026 18:13:57 -0300 Subject: [PATCH 01/11] feat: Add Nomic BERT model support Add support for nomic-ai/nomic-embed-text-v1.5 embedding model. Architecture: - Postnorm transformer (standard BERT-style) - SwiGLU activation (up * silu(gate)) - Rotary position embeddings (RoPE) with base 1000 - Combined Wqkv projection - No biases in attention and FFN layers - Mean pooling over non-masked tokens Tested against Python transformers with ~2e-6 precision. --- lib/bumblebee.ex | 2 + lib/bumblebee/text/nomic_bert.ex | 411 ++++++++++++++++++++++++ mix.exs | 1 + test/bumblebee/text/nomic_bert_test.exs | 42 +++ 4 files changed, 456 insertions(+) create mode 100644 lib/bumblebee/text/nomic_bert.ex create mode 100644 test/bumblebee/text/nomic_bert_test.exs diff --git a/lib/bumblebee.ex b/lib/bumblebee.ex index 13928ee1..7a78d7df 100644 --- a/lib/bumblebee.ex +++ b/lib/bumblebee.ex @@ -176,6 +176,7 @@ defmodule Bumblebee do "MPNetForTokenClassification" => {Bumblebee.Text.MpNet, :for_token_classification}, "MPNetForQuestionAnswering" => {Bumblebee.Text.MpNet, :for_question_answering}, "MPNetForMultipleChoice" => {Bumblebee.Text.MpNet, :for_multiple_choice}, + "NomicBertModel" => {Bumblebee.Text.NomicBert, :base}, "PhiModel" => {Bumblebee.Text.Phi, :base}, "PhiForCausalLM" => {Bumblebee.Text.Phi, :for_causal_language_modeling}, "PhiForSequenceClassification" => {Bumblebee.Text.Phi, :for_sequence_classification}, @@ -266,6 +267,7 @@ defmodule Bumblebee do "mistral" => :llama, "mbart" => :mbart, "mpnet" => :mpnet, + "nomic_bert" => :bert, "phi" => :code_gen, "phi3" => :llama, "qwen3" => :qwen2, diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex new file mode 100644 index 00000000..2d78aab6 --- /dev/null +++ b/lib/bumblebee/text/nomic_bert.ex @@ -0,0 +1,411 @@ +defmodule Bumblebee.Text.NomicBert do + alias Bumblebee.Shared + + options = + [ + vocab_size: [ + default: 30528, + doc: """ + the vocabulary size of the token embedding. This corresponds to the number of distinct + tokens that can be represented in model input and output + """ + ], + max_positions: [ + default: 8192, + doc: """ + the vocabulary size of the position embedding. This corresponds to the maximum sequence + length that this model can process. Typically this is set to a large value just in case, + such as 512, 1024 or 2048 + """ + ], + max_token_types: [ + default: 2, + doc: """ + the vocabulary size of the token type embedding (also referred to as segment embedding). 
+ This corresponds to how many different token groups can be distinguished in the input + """ + ], + hidden_size: [ + default: 768, + doc: "the dimensionality of hidden layers" + ], + num_blocks: [ + default: 12, + doc: "the number of Transformer blocks in the encoder" + ], + num_attention_heads: [ + default: 12, + doc: "the number of attention heads for each attention layer in the encoder" + ], + intermediate_size: [ + default: 3072, + doc: + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + ], + activation: [ + default: :silu, + doc: "the activation function" + ], + rotary_embedding_base: [ + default: 10_000, + doc: "base for computing rotary embedding frequency" + ], + rotary_embedding_percentage: [ + default: 1.0, + doc: "percentage of hidden size to use for rotary embeddings" + ], + layer_norm_epsilon: [ + default: 1.0e-5, + doc: "the epsilon used by the layer normalization layers" + ], + initializer_scale: [ + default: 0.02, + doc: + "the standard deviation of the normal initializer used for initializing kernel parameters" + ] + ] ++ Shared.common_options([:num_labels, :id_to_label]) + + @moduledoc """ + Nomic BERT model family. + + This is a variant of BERT that uses: + - Rotary position embeddings (RoPE) instead of absolute position embeddings + - SwiGLU activation in the feed-forward network + - Pre-normalization (like GPT) instead of post-normalization + - No biases in attention and feed-forward layers + + ## Architectures + + * `:base` - plain Nomic BERT without any head on top + + ## Inputs + + * `"input_ids"` - `{batch_size, sequence_length}` + + Indices of input sequence tokens in the vocabulary. + + * `"attention_mask"` - `{batch_size, sequence_length}` + + Mask indicating which tokens to attend to. This is used to ignore + padding tokens, which are added when processing a batch of sequences + with different length. + + * `"position_ids"` - `{batch_size, sequence_length}` + + Indices of positions of each input sequence tokens in the position + embeddings. 
+ + ## Global layer options + + #{Shared.global_layer_options_doc([:output_hidden_states, :output_attentions])} + + ## Configuration + + #{Shared.options_doc(options)} + """ + + defstruct [architecture: :base] ++ Shared.option_defaults(options) + + @behaviour Bumblebee.ModelSpec + @behaviour Bumblebee.Configurable + + import Bumblebee.Utils.Model, only: [join: 2] + + alias Bumblebee.Layers + + @impl true + def architectures(), do: [:base] + + @impl true + def config(spec, opts) do + spec + |> Shared.put_config_attrs(opts) + |> Shared.validate_label_options() + end + + @impl true + def input_template(_spec) do + %{"input_ids" => Nx.template({1, 1}, :u32)} + end + + @impl true + def model(%__MODULE__{architecture: :base} = spec) do + inputs = inputs(spec) + + inputs + |> core(spec) + |> Layers.output() + end + + defp inputs(spec) do + shape = {nil, nil} + attention_head_mask_shape = {spec.num_blocks, spec.num_attention_heads} + + Bumblebee.Utils.Model.inputs_to_map([ + Axon.input("input_ids", shape: shape), + Axon.input("attention_mask", optional: true, shape: shape), + Axon.input("token_type_ids", optional: true, shape: shape), + Axon.input("position_ids", optional: true, shape: shape), + Axon.input("attention_head_mask", optional: true, shape: attention_head_mask_shape) + ]) + end + + defp core(inputs, spec) do + token_type_ids = + Layers.default inputs["token_type_ids"] do + Layers.default_token_type_ids(inputs["input_ids"]) + end + + embeddings = embedder(inputs["input_ids"], token_type_ids, spec, name: "embedder") + + position_ids = + Layers.default inputs["position_ids"] do + Layers.default_position_ids(embeddings) + end + + encoder_outputs = + encoder( + embeddings, + position_ids, + inputs["attention_mask"], + inputs["attention_head_mask"], + spec, + name: "encoder" + ) + + # Mean pooling over non-masked tokens + pooled_state = + Layers.if_present inputs["attention_mask"] do + Axon.layer( + fn hidden_state, attention_mask, _opts -> + # Expand mask for broadcasting with hidden_size + mask = Nx.new_axis(attention_mask, -1) + # Mask out padding tokens + masked = Nx.multiply(hidden_state, mask) + # Sum and normalize by actual sequence length + sum = Nx.sum(masked, axes: [1]) + count = Nx.sum(mask, axes: [1]) + Nx.divide(sum, Nx.max(count, 1.0e-9)) + end, + [encoder_outputs.hidden_state, inputs["attention_mask"]] + ) + else + Axon.layer( + fn hidden_state, _opts -> + Nx.mean(hidden_state, axes: [1]) + end, + [encoder_outputs.hidden_state] + ) + end + + %{ + hidden_state: encoder_outputs.hidden_state, + pooled_state: pooled_state, + hidden_states: encoder_outputs.hidden_states, + attentions: encoder_outputs.attentions + } + end + + defp embedder(input_ids, token_type_ids, spec, opts) do + name = opts[:name] + + token_embeddings = + Axon.embedding(input_ids, spec.vocab_size, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_embedding") + ) + + token_type_embeddings = + Axon.embedding(token_type_ids, spec.max_token_types, spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "token_type_embedding") + ) + + Axon.add([token_embeddings, token_type_embeddings]) + |> Axon.layer_norm(epsilon: spec.layer_norm_epsilon, name: join(name, "norm")) + end + + defp encoder(hidden_state, position_ids, attention_mask, attention_head_mask, spec, opts) do + name = opts[:name] + + # Nomic BERT uses postnorm (like standard BERT): + # Each block: + # attn_output = attention(hidden_states) + # hidden_states = norm1(attn_output + 
hidden_states) + # ffn_output = ffn(hidden_states) + # hidden_states = norm2(ffn_output + hidden_states) + + initial_state = %{ + hidden_state: hidden_state, + hidden_states: Axon.container({hidden_state}), + attentions: Axon.container({}) + } + + final_state = + Enum.reduce(0..(spec.num_blocks - 1), initial_state, fn idx, state -> + block_name = join(name, "blocks.#{idx}") + + # Self-attention (no prenorm) + {attention_output, attention_weights, _cache, _bias} = + Layers.Transformer.multi_head_attention( + state.hidden_state, + state.hidden_state, + state.hidden_state, + attention_mask: attention_mask, + attention_head_mask: + Layers.if_present attention_head_mask do + Axon.nx(attention_head_mask, & &1[idx]) + end, + num_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + causal: false, + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + percentage: spec.rotary_embedding_percentage + ], + name: join(block_name, "self_attention") + ) + + # Residual + Norm after attention (postnorm) + hidden_state = Axon.add(attention_output, state.hidden_state) + + hidden_state = + Axon.layer_norm(hidden_state, + epsilon: spec.layer_norm_epsilon, + name: join(block_name, "self_attention_norm") + ) + + # FFN + ffn_output = + gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size, + name: join(block_name, "ffn"), + activation: spec.activation + ) + + # Residual + Norm after FFN (postnorm) + hidden_state = Axon.add(ffn_output, hidden_state) + + hidden_state = + Axon.layer_norm(hidden_state, + epsilon: spec.layer_norm_epsilon, + name: join(block_name, "output_norm") + ) + + %{ + hidden_state: hidden_state, + hidden_states: Layers.append(state.hidden_states, hidden_state), + attentions: Layers.append(state.attentions, attention_weights) + } + end) + + %{ + hidden_state: final_state.hidden_state, + hidden_states: final_state.hidden_states, + attentions: final_state.attentions + } + end + + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do + name = opts[:name] + activation = opts[:activation] + + # Nomic MLP: y = fc11(x) * activation(fc12(x)), then fc2 + # fc11 is "up", fc12 is "gate", fc2 is "down" + up = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "up"), + use_bias: false + ) + + gate = + Axon.dense(hidden_state, intermediate_size, + name: join(name, "gate"), + use_bias: false + ) + + # Nomic applies activation to gate, not up: up * activation(gate) + hidden_state = Axon.multiply(up, Axon.activation(gate, activation)) + + Axon.dense(hidden_state, output_size, name: join(name, "down"), use_bias: false) + end + + defp kernel_initializer(spec) do + Axon.Initializers.normal(scale: spec.initializer_scale) + end + + defimpl Bumblebee.HuggingFace.Transformers.Config do + def load(spec, data) do + import Shared.Converters + + opts = + convert!(data, + vocab_size: {"vocab_size", number()}, + max_positions: {"n_positions", number()}, + max_token_types: {"type_vocab_size", number()}, + hidden_size: {"n_embd", number()}, + num_blocks: {"n_layer", number()}, + num_attention_heads: {"n_head", number()}, + intermediate_size: {"n_inner", optional(number())}, + rotary_embedding_base: {"rotary_emb_base", number()}, + rotary_embedding_percentage: {"rotary_emb_fraction", optional(number())}, + layer_norm_epsilon: {"layer_norm_epsilon", number()}, 
+ initializer_scale: {"initializer_range", number()} + ) ++ Shared.common_options_from_transformers(data, spec) + + # Nomic uses 4 * n_embd for intermediate size if not specified + opts = + if opts[:intermediate_size] do + opts + else + hidden_size = opts[:hidden_size] || spec.hidden_size + Keyword.put(opts, :intermediate_size, 4 * hidden_size) + end + + @for.config(spec, opts) + end + end + + defimpl Bumblebee.HuggingFace.Transformers.Model do + def params_mapping(_spec) do + %{ + "embedder.token_embedding" => "embeddings.word_embeddings", + "embedder.token_type_embedding" => "embeddings.token_type_embeddings", + "embedder.norm" => "emb_ln", + "encoder.blocks.{n}.self_attention.query" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 0), + "encoder.blocks.{n}.self_attention.key" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 1), + "encoder.blocks.{n}.self_attention.value" => qkv_dense("encoder.layers.{n}.attn.Wqkv", 2), + "encoder.blocks.{n}.self_attention.output" => "encoder.layers.{n}.attn.out_proj", + "encoder.blocks.{n}.self_attention_norm" => "encoder.layers.{n}.norm1", + "encoder.blocks.{n}.ffn.up" => "encoder.layers.{n}.mlp.fc11", + "encoder.blocks.{n}.ffn.gate" => "encoder.layers.{n}.mlp.fc12", + "encoder.blocks.{n}.ffn.down" => "encoder.layers.{n}.mlp.fc2", + "encoder.blocks.{n}.output_norm" => "encoder.layers.{n}.norm2" + } + end + + defp qkv_dense(source_layer_name, chunk_idx) do + # Wqkv is [3 * hidden_size, hidden_size] in PyTorch format + # After slicing, transpose to get [hidden_size, hidden_size] for Axon + %{ + "kernel" => { + [{source_layer_name, "weight"}], + fn [kernel] -> + size = Nx.axis_size(kernel, 0) + step = div(size, 3) + + kernel + |> Nx.slice_along_axis(chunk_idx * step, step, axis: 0) + |> Nx.transpose() + end + } + } + end + end +end diff --git a/mix.exs b/mix.exs index 47b2f90f..089c7f9d 100644 --- a/mix.exs +++ b/mix.exs @@ -101,6 +101,7 @@ defmodule Bumblebee.MixProject do Bumblebee.Text.Mbart, Bumblebee.Text.Mistral, Bumblebee.Text.MpNet, + Bumblebee.Text.NomicBert, Bumblebee.Text.Phi, Bumblebee.Text.Phi3, Bumblebee.Text.Roberta, diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs new file mode 100644 index 00000000..e2279c74 --- /dev/null +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -0,0 +1,42 @@ +defmodule Bumblebee.Text.NomicBertTest do + use ExUnit.Case, async: true + + import Bumblebee.TestHelpers + + @moduletag model_test_tags() + + @tag :slow + test ":base" do + assert {:ok, %{model: model, params: params, spec: spec}} = + Bumblebee.load_model({:hf, "nomic-ai/nomic-embed-text-v1.5"}) + + assert %Bumblebee.Text.NomicBert{architecture: :base} = spec + assert spec.hidden_size == 768 + assert spec.num_blocks == 12 + assert spec.num_attention_heads == 12 + assert spec.rotary_embedding_base == 1000 + + inputs = %{ + "input_ids" => Nx.tensor([[101, 2023, 2003, 1037, 3231, 102]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]]) + } + + outputs = Axon.predict(model, params, inputs) + + assert Nx.shape(outputs.hidden_state) == {1, 6, 768} + assert Nx.shape(outputs.pooled_state) == {1, 768} + + # Values verified against Python transformers + assert_all_close( + outputs.hidden_state[[.., 0, 0..4]], + Nx.tensor([[1.3752, 0.7431, -4.6988, -0.6574, 2.1887]]), + atol: 1.0e-3 + ) + + assert_all_close( + outputs.pooled_state[[.., 0..4]], + Nx.tensor([[1.0917, 0.5968, -3.9347, -0.6988, 1.5423]]), + atol: 1.0e-3 + ) + end +end From fa9ac243d8e8ff3940b2fc244445abf500b17a07 Mon Sep 17 00:00:00 2001 From: 
=?UTF-8?q?George=20Guimar=C3=A3es?= Date: Sun, 4 Jan 2026 13:11:32 -0300 Subject: [PATCH 02/11] fix: Address PR review comments for Nomic BERT documentation - Fix max_positions doc: remove incorrect "vocabulary size" reference - Fix rotary_embedding_base default: change from 10_000 to 1000 - Fix normalization doc: correct "pre-normalization" to "post-normalization" - Fix position_ids doc: clarify usage with RoPE instead of position embeddings --- lib/bumblebee/text/nomic_bert.ex | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index 2d78aab6..01efa351 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -13,9 +13,8 @@ defmodule Bumblebee.Text.NomicBert do max_positions: [ default: 8192, doc: """ - the vocabulary size of the position embedding. This corresponds to the maximum sequence - length that this model can process. Typically this is set to a large value just in case, - such as 512, 1024 or 2048 + the maximum sequence length that this model can process. Typically this is set to a large + value just in case, such as 512, 1024 or 2048 """ ], max_token_types: [ @@ -47,7 +46,7 @@ defmodule Bumblebee.Text.NomicBert do doc: "the activation function" ], rotary_embedding_base: [ - default: 10_000, + default: 1000, doc: "base for computing rotary embedding frequency" ], rotary_embedding_percentage: [ @@ -71,7 +70,7 @@ defmodule Bumblebee.Text.NomicBert do This is a variant of BERT that uses: - Rotary position embeddings (RoPE) instead of absolute position embeddings - SwiGLU activation in the feed-forward network - - Pre-normalization (like GPT) instead of post-normalization + - Post-normalization (like original BERT) - No biases in attention and feed-forward layers ## Architectures @@ -92,8 +91,8 @@ defmodule Bumblebee.Text.NomicBert do * `"position_ids"` - `{batch_size, sequence_length}` - Indices of positions of each input sequence tokens in the position - embeddings. + Indices of positions of each input sequence token used when applying + rotary position embeddings (RoPE). 
## Global layer options From 3874426e6f62df2f92df12e0350f37b0f3fb532f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 16:40:38 -0300 Subject: [PATCH 03/11] Remove custom atol override in test --- test/bumblebee/text/nomic_bert_test.exs | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index e2279c74..68ef0d6e 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -29,14 +29,12 @@ defmodule Bumblebee.Text.NomicBertTest do # Values verified against Python transformers assert_all_close( outputs.hidden_state[[.., 0, 0..4]], - Nx.tensor([[1.3752, 0.7431, -4.6988, -0.6574, 2.1887]]), - atol: 1.0e-3 + Nx.tensor([[1.3752, 0.7431, -4.6988, -0.6574, 2.1887]]) ) assert_all_close( outputs.pooled_state[[.., 0..4]], - Nx.tensor([[1.0917, 0.5968, -3.9347, -0.6988, 1.5423]]), - atol: 1.0e-3 + Nx.tensor([[1.0917, 0.5968, -3.9347, -0.6988, 1.5423]]) ) end end From 4aa7c0d6528e75d16da45c35d2de1161c511e06c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 16:45:51 -0300 Subject: [PATCH 04/11] Use standard test inputs and slice pattern --- test/bumblebee/text/nomic_bert_test.exs | 16 +++++++++------- 1 file changed, 9 insertions(+), 7 deletions(-) diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index 68ef0d6e..ee15d8de 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -17,24 +17,26 @@ defmodule Bumblebee.Text.NomicBertTest do assert spec.rotary_embedding_base == 1000 inputs = %{ - "input_ids" => Nx.tensor([[101, 2023, 2003, 1037, 3231, 102]]), - "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1]]) + "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), + "attention_mask" => Nx.tensor([[1, 1, 1, 1, 1, 1, 1, 1, 0, 0]]) } outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.hidden_state) == {1, 6, 768} + assert Nx.shape(outputs.hidden_state) == {1, 10, 768} assert Nx.shape(outputs.pooled_state) == {1, 768} # Values verified against Python transformers assert_all_close( - outputs.hidden_state[[.., 0, 0..4]], - Nx.tensor([[1.3752, 0.7431, -4.6988, -0.6574, 2.1887]]) + outputs.hidden_state[[.., 1..3, 1..3]], + Nx.tensor([[[0.0315, -5.2254, 0.0180], + [0.0877, -5.3772, 0.1800], + [-0.0546, -4.8813, 0.2614]]]) ) assert_all_close( - outputs.pooled_state[[.., 0..4]], - Nx.tensor([[1.0917, 0.5968, -3.9347, -0.6988, 1.5423]]) + outputs.pooled_state[[.., 1..3]], + Nx.tensor([[0.0340, -5.2018, 0.1686]]) ) end end From 3b8941308754cfd1dfda4607ea3af89cdfa9b81c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 16:48:33 -0300 Subject: [PATCH 05/11] Use Layers.Transformer.blocks for encoder --- lib/bumblebee/text/nomic_bert.ex | 109 ++++++++----------------------- 1 file changed, 27 insertions(+), 82 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index 01efa351..5f19a279 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -227,88 +227,33 @@ defmodule Bumblebee.Text.NomicBert do defp encoder(hidden_state, position_ids, attention_mask, attention_head_mask, spec, opts) do name = opts[:name] - # Nomic BERT uses postnorm (like standard BERT): - # Each block: - # attn_output = attention(hidden_states) - # hidden_states = norm1(attn_output + hidden_states) - # 
ffn_output = ffn(hidden_states) - # hidden_states = norm2(ffn_output + hidden_states) - - initial_state = %{ - hidden_state: hidden_state, - hidden_states: Axon.container({hidden_state}), - attentions: Axon.container({}) - } - - final_state = - Enum.reduce(0..(spec.num_blocks - 1), initial_state, fn idx, state -> - block_name = join(name, "blocks.#{idx}") - - # Self-attention (no prenorm) - {attention_output, attention_weights, _cache, _bias} = - Layers.Transformer.multi_head_attention( - state.hidden_state, - state.hidden_state, - state.hidden_state, - attention_mask: attention_mask, - attention_head_mask: - Layers.if_present attention_head_mask do - Axon.nx(attention_head_mask, & &1[idx]) - end, - num_heads: spec.num_attention_heads, - hidden_size: spec.hidden_size, - kernel_initializer: kernel_initializer(spec), - causal: false, - query_use_bias: false, - key_use_bias: false, - value_use_bias: false, - output_use_bias: false, - rotary_embedding: [ - position_ids: position_ids, - max_positions: spec.max_positions, - base: spec.rotary_embedding_base, - percentage: spec.rotary_embedding_percentage - ], - name: join(block_name, "self_attention") - ) - - # Residual + Norm after attention (postnorm) - hidden_state = Axon.add(attention_output, state.hidden_state) - - hidden_state = - Axon.layer_norm(hidden_state, - epsilon: spec.layer_norm_epsilon, - name: join(block_name, "self_attention_norm") - ) - - # FFN - ffn_output = - gated_ffn(hidden_state, spec.intermediate_size, spec.hidden_size, - name: join(block_name, "ffn"), - activation: spec.activation - ) - - # Residual + Norm after FFN (postnorm) - hidden_state = Axon.add(ffn_output, hidden_state) - - hidden_state = - Axon.layer_norm(hidden_state, - epsilon: spec.layer_norm_epsilon, - name: join(block_name, "output_norm") - ) - - %{ - hidden_state: hidden_state, - hidden_states: Layers.append(state.hidden_states, hidden_state), - attentions: Layers.append(state.attentions, attention_weights) - } - end) - - %{ - hidden_state: final_state.hidden_state, - hidden_states: final_state.hidden_states, - attentions: final_state.attentions - } + Layers.Transformer.blocks(hidden_state, + attention_mask: attention_mask, + attention_head_mask: attention_head_mask, + num_blocks: spec.num_blocks, + num_attention_heads: spec.num_attention_heads, + hidden_size: spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + layer_norm: [epsilon: spec.layer_norm_epsilon], + ffn: + &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + name: &2, + activation: spec.activation + ), + block_type: :standard, + causal: false, + rotary_embedding: [ + position_ids: position_ids, + max_positions: spec.max_positions, + base: spec.rotary_embedding_base, + percentage: spec.rotary_embedding_percentage + ], + query_use_bias: false, + key_use_bias: false, + value_use_bias: false, + output_use_bias: false, + name: join(name, "blocks") + ) end defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do From 9e208815572ea0efeb4d47c7be91db23d63a79fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 16:51:35 -0300 Subject: [PATCH 06/11] Make intermediate_size optional, default to 4 * hidden_size --- lib/bumblebee/text/nomic_bert.ex | 15 +++------------ 1 file changed, 3 insertions(+), 12 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index 5f19a279..a274ba2b 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -37,9 +37,9 @@ defmodule 
Bumblebee.Text.NomicBert do doc: "the number of attention heads for each attention layer in the encoder" ], intermediate_size: [ - default: 3072, + default: nil, doc: - "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder" + "the dimensionality of the intermediate layer in the transformer feed-forward network (FFN) in the encoder. Defaults to 4 * hidden_size" ], activation: [ default: :silu, @@ -236,7 +236,7 @@ defmodule Bumblebee.Text.NomicBert do kernel_initializer: kernel_initializer(spec), layer_norm: [epsilon: spec.layer_norm_epsilon], ffn: - &gated_ffn(&1, spec.intermediate_size, spec.hidden_size, + &gated_ffn(&1, spec.intermediate_size || 4 * spec.hidden_size, spec.hidden_size, name: &2, activation: spec.activation ), @@ -303,15 +303,6 @@ defmodule Bumblebee.Text.NomicBert do initializer_scale: {"initializer_range", number()} ) ++ Shared.common_options_from_transformers(data, spec) - # Nomic uses 4 * n_embd for intermediate size if not specified - opts = - if opts[:intermediate_size] do - opts - else - hidden_size = opts[:hidden_size] || spec.hidden_size - Keyword.put(opts, :intermediate_size, 4 * hidden_size) - end - @for.config(spec, opts) end end From e30d3db4b0ec88d3fc6649012f90cfe94a441766 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 16:53:56 -0300 Subject: [PATCH 07/11] Import mlp_fc1_bias and mlp_fc2_bias config attributes --- lib/bumblebee/text/nomic_bert.ex | 24 +++++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index a274ba2b..5025ee3a 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -61,6 +61,14 @@ defmodule Bumblebee.Text.NomicBert do default: 0.02, doc: "the standard deviation of the normal initializer used for initializing kernel parameters" + ], + ffn_gate_bias: [ + default: true, + doc: "whether to use bias in the up and gate projections of the FFN" + ], + ffn_output_bias: [ + default: true, + doc: "whether to use bias in the output projection of the FFN" ] ] ++ Shared.common_options([:num_labels, :id_to_label]) @@ -238,7 +246,9 @@ defmodule Bumblebee.Text.NomicBert do ffn: &gated_ffn(&1, spec.intermediate_size || 4 * spec.hidden_size, spec.hidden_size, name: &2, - activation: spec.activation + activation: spec.activation, + gate_use_bias: spec.ffn_gate_bias, + output_use_bias: spec.ffn_output_bias ), block_type: :standard, causal: false, @@ -259,25 +269,27 @@ defmodule Bumblebee.Text.NomicBert do defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do name = opts[:name] activation = opts[:activation] + gate_use_bias = opts[:gate_use_bias] + output_use_bias = opts[:output_use_bias] # Nomic MLP: y = fc11(x) * activation(fc12(x)), then fc2 # fc11 is "up", fc12 is "gate", fc2 is "down" up = Axon.dense(hidden_state, intermediate_size, name: join(name, "up"), - use_bias: false + use_bias: gate_use_bias ) gate = Axon.dense(hidden_state, intermediate_size, name: join(name, "gate"), - use_bias: false + use_bias: gate_use_bias ) # Nomic applies activation to gate, not up: up * activation(gate) hidden_state = Axon.multiply(up, Axon.activation(gate, activation)) - Axon.dense(hidden_state, output_size, name: join(name, "down"), use_bias: false) + Axon.dense(hidden_state, output_size, name: join(name, "down"), use_bias: output_use_bias) end defp kernel_initializer(spec) do @@ -300,7 +312,9 @@ defmodule Bumblebee.Text.NomicBert do 
rotary_embedding_base: {"rotary_emb_base", number()}, rotary_embedding_percentage: {"rotary_emb_fraction", optional(number())}, layer_norm_epsilon: {"layer_norm_epsilon", number()}, - initializer_scale: {"initializer_range", number()} + initializer_scale: {"initializer_range", number()}, + ffn_gate_bias: {"mlp_fc1_bias", boolean()}, + ffn_output_bias: {"mlp_fc2_bias", boolean()} ) ++ Shared.common_options_from_transformers(data, spec) @for.config(spec, opts) From 2ad695cabdbb520b82a2a2dacde51da4aef69a89 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Mon, 5 Jan 2026 17:03:31 -0300 Subject: [PATCH 08/11] Fix formatting in nomic_bert_test.exs --- test/bumblebee/text/nomic_bert_test.exs | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index ee15d8de..a00503e1 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -29,9 +29,9 @@ defmodule Bumblebee.Text.NomicBertTest do # Values verified against Python transformers assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], - Nx.tensor([[[0.0315, -5.2254, 0.0180], - [0.0877, -5.3772, 0.1800], - [-0.0546, -4.8813, 0.2614]]]) + Nx.tensor([ + [[0.0315, -5.2254, 0.0180], [0.0877, -5.3772, 0.1800], [-0.0546, -4.8813, 0.2614]] + ]) ) assert_all_close( From e987c7574d5bd18adc10bd04b19147113e6f5401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Tue, 6 Jan 2026 11:20:21 -0300 Subject: [PATCH 09/11] Use tiny-random model in tests Round intermediate_size to multiple of 256 to match Python's GatedMLP behavior. --- lib/bumblebee/text/nomic_bert.ex | 9 ++++++++- test/bumblebee/text/nomic_bert_test.exs | 16 +++++----------- 2 files changed, 13 insertions(+), 12 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index 5025ee3a..0ef224fb 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -244,7 +244,7 @@ defmodule Bumblebee.Text.NomicBert do kernel_initializer: kernel_initializer(spec), layer_norm: [epsilon: spec.layer_norm_epsilon], ffn: - &gated_ffn(&1, spec.intermediate_size || 4 * spec.hidden_size, spec.hidden_size, + &gated_ffn(&1, intermediate_size(spec), spec.hidden_size, name: &2, activation: spec.activation, gate_use_bias: spec.ffn_gate_bias, @@ -296,6 +296,13 @@ defmodule Bumblebee.Text.NomicBert do Axon.Initializers.normal(scale: spec.initializer_scale) end + # NomicBERT rounds intermediate_size to nearest multiple of 256 for hardware efficiency + defp intermediate_size(spec) do + size = spec.intermediate_size || div(8 * spec.hidden_size, 3) + multiple_of = 256 + div(size + multiple_of - 1, multiple_of) * multiple_of + end + defimpl Bumblebee.HuggingFace.Transformers.Config do def load(spec, data) do import Shared.Converters diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index a00503e1..825ac5f5 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -5,16 +5,11 @@ defmodule Bumblebee.Text.NomicBertTest do @moduletag model_test_tags() - @tag :slow test ":base" do assert {:ok, %{model: model, params: params, spec: spec}} = - Bumblebee.load_model({:hf, "nomic-ai/nomic-embed-text-v1.5"}) + Bumblebee.load_model({:hf, "bumblebee-testing/tiny-random-NomicBertModel"}) assert %Bumblebee.Text.NomicBert{architecture: :base} = spec - assert spec.hidden_size == 768 - assert spec.num_blocks == 
12 - assert spec.num_attention_heads == 12 - assert spec.rotary_embedding_base == 1000 inputs = %{ "input_ids" => Nx.tensor([[10, 20, 30, 40, 50, 60, 70, 80, 0, 0]]), @@ -23,20 +18,19 @@ defmodule Bumblebee.Text.NomicBertTest do outputs = Axon.predict(model, params, inputs) - assert Nx.shape(outputs.hidden_state) == {1, 10, 768} - assert Nx.shape(outputs.pooled_state) == {1, 768} + assert Nx.shape(outputs.hidden_state) == {1, 10, 32} + assert Nx.shape(outputs.pooled_state) == {1, 32} - # Values verified against Python transformers assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[0.0315, -5.2254, 0.0180], [0.0877, -5.3772, 0.1800], [-0.0546, -4.8813, 0.2614]] + [[1.5269, -0.3709, -0.6235], [0.0301, -0.1500, -1.0316], [-1.4733, -1.1167, 0.2346]] ]) ) assert_all_close( outputs.pooled_state[[.., 1..3]], - Nx.tensor([[0.0340, -5.2018, 0.1686]]) + Nx.tensor([[0.1788, -0.2985, 0.4405]]) ) end end From 67cefe744ec0ec8117136738280f9aa5b1ee62db Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Sun, 11 Jan 2026 09:41:59 -0300 Subject: [PATCH 10/11] Fix expected test values to match tiny-random model --- test/bumblebee/text/nomic_bert_test.exs | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index 825ac5f5..6a0ff6df 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -24,13 +24,13 @@ defmodule Bumblebee.Text.NomicBertTest do assert_all_close( outputs.hidden_state[[.., 1..3, 1..3]], Nx.tensor([ - [[1.5269, -0.3709, -0.6235], [0.0301, -0.1500, -1.0316], [-1.4733, -1.1167, 0.2346]] + [[0.3062, -2.1607, -0.1782], [0.5486, 0.5353, -0.0453], [0.6323, -0.7370, -2.0245]] ]) ) assert_all_close( outputs.pooled_state[[.., 1..3]], - Nx.tensor([[0.1788, -0.2985, 0.4405]]) + Nx.tensor([[0.5064, -0.6647, -0.9303]]) ) end end From f370e48b09e09a72570a83c8a949ed3ca07bd601 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?George=20Guimar=C3=A3es?= Date: Tue, 13 Jan 2026 13:01:11 -0300 Subject: [PATCH 11/11] Use pooler layer instead of mean pooling --- lib/bumblebee/text/nomic_bert.ex | 40 ++++++++++--------------- test/bumblebee/text/nomic_bert_test.exs | 2 +- 2 files changed, 16 insertions(+), 26 deletions(-) diff --git a/lib/bumblebee/text/nomic_bert.ex b/lib/bumblebee/text/nomic_bert.ex index 0ef224fb..98a4ef84 100644 --- a/lib/bumblebee/text/nomic_bert.ex +++ b/lib/bumblebee/text/nomic_bert.ex @@ -180,30 +180,7 @@ defmodule Bumblebee.Text.NomicBert do name: "encoder" ) - # Mean pooling over non-masked tokens - pooled_state = - Layers.if_present inputs["attention_mask"] do - Axon.layer( - fn hidden_state, attention_mask, _opts -> - # Expand mask for broadcasting with hidden_size - mask = Nx.new_axis(attention_mask, -1) - # Mask out padding tokens - masked = Nx.multiply(hidden_state, mask) - # Sum and normalize by actual sequence length - sum = Nx.sum(masked, axes: [1]) - count = Nx.sum(mask, axes: [1]) - Nx.divide(sum, Nx.max(count, 1.0e-9)) - end, - [encoder_outputs.hidden_state, inputs["attention_mask"]] - ) - else - Axon.layer( - fn hidden_state, _opts -> - Nx.mean(hidden_state, axes: [1]) - end, - [encoder_outputs.hidden_state] - ) - end + pooled_state = pooler(encoder_outputs.hidden_state, spec, name: "pooler") %{ hidden_state: encoder_outputs.hidden_state, @@ -266,6 +243,18 @@ defmodule Bumblebee.Text.NomicBert do ) end + defp pooler(hidden_state, spec, opts) do + name = opts[:name] + + hidden_state + |> 
Layers.take_token(index: 0, axis: 1) + |> Axon.dense(spec.hidden_size, + kernel_initializer: kernel_initializer(spec), + name: join(name, "output") + ) + |> Axon.tanh() + end + defp gated_ffn(hidden_state, intermediate_size, output_size, opts) do name = opts[:name] activation = opts[:activation] @@ -342,7 +331,8 @@ defmodule Bumblebee.Text.NomicBert do "encoder.blocks.{n}.ffn.up" => "encoder.layers.{n}.mlp.fc11", "encoder.blocks.{n}.ffn.gate" => "encoder.layers.{n}.mlp.fc12", "encoder.blocks.{n}.ffn.down" => "encoder.layers.{n}.mlp.fc2", - "encoder.blocks.{n}.output_norm" => "encoder.layers.{n}.norm2" + "encoder.blocks.{n}.output_norm" => "encoder.layers.{n}.norm2", + "pooler.output" => "pooler.dense" } end diff --git a/test/bumblebee/text/nomic_bert_test.exs b/test/bumblebee/text/nomic_bert_test.exs index 6a0ff6df..92b1829b 100644 --- a/test/bumblebee/text/nomic_bert_test.exs +++ b/test/bumblebee/text/nomic_bert_test.exs @@ -30,7 +30,7 @@ defmodule Bumblebee.Text.NomicBertTest do assert_all_close( outputs.pooled_state[[.., 1..3]], - Nx.tensor([[0.5064, -0.6647, -0.9303]]) + Nx.tensor([[0.0197, -0.2129, -0.0071]]) ) end end
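
A minimal usage sketch of the model added by this series, for reference. It assumes the real nomic-ai/nomic-embed-text-v1.5 checkpoint and its tokenizer load cleanly once these patches are applied (the test suite itself only exercises the tiny-random checkpoint):

    {:ok, model_info} = Bumblebee.load_model({:hf, "nomic-ai/nomic-embed-text-v1.5"})
    {:ok, tokenizer} = Bumblebee.load_tokenizer({:hf, "nomic-ai/nomic-embed-text-v1.5"})

    # Nomic embedding models typically expect a task prefix on each input,
    # e.g. "search_document: ..." or "search_query: ..."
    inputs = Bumblebee.apply_tokenizer(tokenizer, ["search_document: Hello world"])

    outputs = Axon.predict(model_info.model, model_info.params, inputs)

    # {1, sequence_length, 768} token-level hidden states
    outputs.hidden_state
    # {1, 768} pooled embedding (first token through the pooler dense + tanh, per patch 11)
    outputs.pooled_state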