From 023f909994c3f44735e591267d4fbe341e57b63d Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 31 Jan 2026 08:07:28 -0500 Subject: [PATCH 1/3] Add parameter shape and name conveniences --- lib/axon.ex | 238 ++++++++++++++++++--------------------- lib/axon/quantization.ex | 7 +- lib/axon/shape.ex | 116 ------------------- 3 files changed, 113 insertions(+), 248 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 2c381ccf0..4164fcb6c 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -441,6 +441,51 @@ defmodule Axon do } end + def parameter(name, shape_dsl, opts) when is_list(shape_dsl) do + if Enum.all?(shape_dsl, &is_integer/1) do + template = Nx.template(List.to_tuple(shape_dsl), {:f, 32}) + parameter(name, template, opts) + else + template_fn = compile_shape_dsl(shape_dsl, {:f, 32}) + parameter(name, template_fn, opts) + end + end + + defp compile_shape_dsl(shape_dsl, type) do + arity = + Enum.reduce(shape_dsl, 1, fn + {:axis, _n, [input: k]}, acc -> max(acc, k + 1) + _, acc -> acc + end) + + shape_fn = fn shapes -> + shape_dsl + |> Enum.map(fn + {:axis, n} -> + shape = hd(shapes) + axis = normalize_axis(n, tuple_size(shape)) + elem(shape, axis) + + {:axis, n, [input: k]} -> + shape = Enum.at(shapes, k) + axis = normalize_axis(n, tuple_size(shape)) + elem(shape, axis) + + n when is_integer(n) -> + n + end) + |> List.to_tuple() + end + + shape_fun(arity, fn templates -> + shapes = templates |> List.wrap() |> Enum.map(&Nx.shape/1) + Nx.template(shape_fn.(shapes), type) + end) + end + + defp normalize_axis(axis, rank) when axis < 0, do: rank + axis + defp normalize_axis(axis, _rank), do: axis + @doc """ Trainable Axon parameter used to create custom layers. @@ -484,6 +529,20 @@ defmodule Axon do parameter(name, template, opts) end + def param(name, shape_dsl, opts) when is_binary(name) and is_list(shape_dsl) do + opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter) + {type, opts} = Keyword.pop(opts, :type, {:f, 32}) + + # If all elements are integers, treat it as a static shape (like tuples) + if Enum.all?(shape_dsl, &is_integer/1) do + template = Nx.template(List.to_tuple(shape_dsl), type) + parameter(name, template, opts) + else + template_fn = compile_shape_dsl(shape_dsl, type) + parameter(name, template_fn, opts) + end + end + for i <- 0..128 do args = Macro.generate_arguments(i, __MODULE__) @@ -861,14 +920,11 @@ defmodule Axon do |> Map.put(:units, units) |> Map.put(:use_bias, opts[:use_bias]) - kernel_shape = &Axon.Shape.dense_kernel(&1, units) - bias_shape = &Axon.Shape.dense_bias(&1, units) - - kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) + kernel = param("kernel", [{:axis, -1}, units], initializer: opts[:kernel_initializer]) {inputs, op} = if opts[:use_bias] do - bias = param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = param("bias", [units], initializer: opts[:bias_initializer]) {[x, kernel, bias], :dense} else {[x, kernel], :dense} @@ -934,14 +990,14 @@ defmodule Axon do use_bias: true ]) - kernel_shape = &Axon.Shape.bilinear_kernel(&1, &2, units) - bias_shape = &Axon.Shape.bilinear_bias(&1, &2, units) - - kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) + kernel = + param("kernel", [units, {:axis, -1, input: 0}, {:axis, -1, input: 1}], + initializer: opts[:kernel_initializer] + ) {inputs, op} = if opts[:use_bias] do - bias = param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = param("bias", [units], initializer: opts[:bias_initializer]) {[input1, input2, kernel, bias], :bilinear} else {[input1, input2, kernel], :bilinear} @@ -1028,13 +1084,12 @@ defmodule Axon do feature_group_size = opts[:feature_group_size] kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, feature_group_size) - bias_shape = &Axon.Shape.conv_bias(&1, units, kernel_size, channels, feature_group_size) kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) {inputs, op} = if opts[:use_bias] do - bias = param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = param("bias", [units], initializer: opts[:bias_initializer]) {[x, kernel, bias], :conv} else {[x, kernel], :conv} @@ -1121,13 +1176,11 @@ defmodule Axon do channels = opts[:channels] kernel_shape = &Axon.Shape.conv_kernel(&1, units, kernel_size, channels, 1) - bias_shape = &Axon.Shape.conv_bias(&1, units, kernel_size, channels, 1) - kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) {inputs, op} = if opts[:use_bias] do - bias = param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = param("bias", [units], initializer: opts[:bias_initializer]) {[x, kernel, bias], :conv_transpose} else {[x, kernel], :conv_transpose} @@ -1935,16 +1988,11 @@ defmodule Axon do channel_index = opts[:channel_index] - gamma_shape = &Axon.Shape.norm_param(&1, channel_index) - beta_shape = &Axon.Shape.norm_param(&1, channel_index) - mean_shape = &Axon.Shape.norm_param(&1, channel_index) - var_shape = &Axon.Shape.norm_param(&1, channel_index) - - gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer]) - beta = param("beta", beta_shape, initializer: opts[:beta_initializer]) + gamma = param("gamma", [{:axis, channel_index}], initializer: opts[:gamma_initializer]) + beta = param("beta", [{:axis, channel_index}], initializer: opts[:beta_initializer]) - mean = param("mean", mean_shape, initializer: :zeros, kind: :state) - var = param("var", var_shape, initializer: :ones, kind: :state) + mean = param("mean", [{:axis, channel_index}], initializer: :zeros, kind: :state) + var = param("var", [{:axis, channel_index}], initializer: :ones, kind: :state) layer( norm, @@ -2003,11 +2051,8 @@ defmodule Axon do channel_index = opts[:channel_index] - gamma_shape = &Axon.Shape.norm_param(&1, channel_index) - beta_shape = &Axon.Shape.norm_param(&1, channel_index) - - gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer]) - beta = param("beta", beta_shape, initializer: opts[:beta_initializer]) + gamma = param("gamma", [{:axis, channel_index}], initializer: opts[:gamma_initializer]) + beta = param("beta", [{:axis, channel_index}], initializer: opts[:beta_initializer]) layer(norm, [x, gamma, beta], name: opts[:name], @@ -2054,11 +2099,8 @@ defmodule Axon do channel_index = opts[:channel_index] - gamma_shape = &Axon.Shape.norm_param(&1, channel_index) - beta_shape = &Axon.Shape.norm_param(&1, channel_index) - - gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer]) - beta = param("beta", beta_shape, initializer: opts[:beta_initializer]) + gamma = param("gamma", [{:axis, channel_index}], initializer: opts[:gamma_initializer]) + beta = param("beta", [{:axis, channel_index}], initializer: opts[:beta_initializer]) layer(:group_norm, [x, gamma, beta], name: opts[:name], @@ -2124,8 +2166,7 @@ defmodule Axon do ]) channel_index = opts[:channel_index] - gamma_shape = &Axon.Shape.norm_param(&1, channel_index) - gamma = param("gamma", gamma_shape, initializer: opts[:gamma_initializer]) + gamma = param("gamma", [{:axis, channel_index}], initializer: opts[:gamma_initializer]) layer(:rms_norm, [x, gamma], name: opts[:name], @@ -2740,13 +2781,8 @@ defmodule Axon do hidden_state_name = case opts[:name] do - nil -> - fn _, op_counts -> - "lstm_#{op_counts[:lstm]}_hidden_state" - end - - name when is_binary(name) -> - "#{name}_hidden_state" + nil -> "lstm_{n}_hidden_state" + name when is_binary(name) -> "#{name}_hidden_state" end hidden_state = Axon.container(hidden_state, name: hidden_state_name) @@ -2776,35 +2812,20 @@ defmodule Axon do new_c_name = case opts[:name] do - nil -> - fn _, op_counts -> - "lstm_#{op_counts[:lstm]}_c_hidden_state" - end - - name when is_binary(name) -> - "#{name}_c_hidden_state" + nil -> "lstm_{n}_c_hidden_state" + name when is_binary(name) -> "#{name}_c_hidden_state" end new_h_name = case opts[:name] do - nil -> - fn _, op_counts -> - "lstm_#{op_counts[:lstm]}_h_hidden_state" - end - - name when is_binary(name) -> - "#{name}_h_hidden_state" + nil -> "lstm_{n}_h_hidden_state" + name when is_binary(name) -> "#{name}_h_hidden_state" end output_sequence_name = case opts[:name] do - nil -> - fn _, op_counts -> - "lstm_#{op_counts[:lstm]}_output_sequence" - end - - name when is_binary(name) -> - "#{name}_output_sequence" + nil -> "lstm_{n}_output_sequence" + name when is_binary(name) -> "#{name}_output_sequence" end output_sequence = @@ -2975,13 +2996,8 @@ defmodule Axon do hidden_state_name = case opts[:name] do - nil -> - fn _, op_counts -> - "gru_#{op_counts[:gru]}_hidden_state" - end - - name when is_binary(name) -> - "#{name}_hidden_state" + nil -> "gru_{n}_hidden_state" + name when is_binary(name) -> "#{name}_hidden_state" end hidden_state = Axon.container(hidden_state, name: hidden_state_name) @@ -3036,24 +3052,14 @@ defmodule Axon do new_h_name = case opts[:name] do - nil -> - fn _, op_counts -> - "gru_#{op_counts[:gru]}_hidden_state" - end - - name when is_binary(name) -> - "#{name}_hidden_state" + nil -> "gru_{n}_hidden_state" + name when is_binary(name) -> "#{name}_hidden_state" end output_sequence_name = case opts[:name] do - nil -> - fn _, op_counts -> - "gru_#{op_counts[:gru]}_output_sequence" - end - - name when is_binary(name) -> - "#{name}_output_sequence" + nil -> "gru_{n}_output_sequence" + name when is_binary(name) -> "#{name}_output_sequence" end output_sequence = @@ -3192,13 +3198,8 @@ defmodule Axon do hidden_state_name = case opts[:name] do - nil -> - fn _, op_counts -> - "conv_lstm_#{op_counts[:conv_lstm]}_hidden_state" - end - - name when is_binary(name) -> - "#{name}_hidden_state" + nil -> "conv_lstm_{n}_hidden_state" + name when is_binary(name) -> "#{name}_hidden_state" end hidden_state = Axon.container(hidden_state, name: hidden_state_name) @@ -3228,35 +3229,20 @@ defmodule Axon do new_c_name = case opts[:name] do - nil -> - fn _, op_counts -> - "conv_lstm_#{op_counts[:lstm]}_c_hidden_state" - end - - name when is_binary(name) -> - "#{name}_c_hidden_state" + nil -> "conv_lstm_{n}_c_hidden_state" + name when is_binary(name) -> "#{name}_c_hidden_state" end new_h_name = case opts[:name] do - nil -> - fn _, op_counts -> - "conv_lstm_#{op_counts[:lstm]}_h_hidden_state" - end - - name when is_binary(name) -> - "#{name}_h_hidden_state" + nil -> "conv_lstm_{n}_h_hidden_state" + name when is_binary(name) -> "#{name}_h_hidden_state" end output_sequence_name = case opts[:name] do - nil -> - fn _, op_counts -> - "conv_lstm_#{op_counts[:lstm]}_output_sequence" - end - - name when is_binary(name) -> - "#{name}_output_sequence" + nil -> "conv_lstm_{n}_output_sequence" + name when is_binary(name) -> "#{name}_output_sequence" end output_sequence = @@ -3292,14 +3278,8 @@ defmodule Axon do name = case parent_name do - nil -> - fn _, op_counts -> - count = op_counts[rnn_type] || 0 - "#{Atom.to_string(rnn_type)}_#{count}_#{state_name}_hidden_state" - end - - parent_name when is_binary(parent_name) -> - "#{parent_name}_#{state_name}_hidden_state" + nil -> "#{Atom.to_string(rnn_type)}_{n}_#{state_name}_hidden_state" + parent_name when is_binary(parent_name) -> "#{parent_name}_#{state_name}_hidden_state" end initializer = @@ -3365,9 +3345,7 @@ defmodule Axon do def embedding(%Axon{} = x, vocab_size, embedding_size, opts \\ []) do opts = Keyword.validate!(opts, [:name, :meta, kernel_initializer: :uniform]) - kernel_shape = &Axon.Shape.embedding_kernel(&1, vocab_size, embedding_size) - - kernel = param("kernel", kernel_shape, initializer: opts[:kernel_initializer]) + kernel = param("kernel", [vocab_size, embedding_size], initializer: opts[:kernel_initializer]) layer(:embedding, [x, kernel], name: opts[:name], meta: opts[:meta], op_name: :embedding) end @@ -3389,8 +3367,7 @@ defmodule Axon do def bias(%Axon{} = x, opts \\ []) do opts = Keyword.validate!(opts, [:name, :meta, bias_initializer: :zeros]) - bias_shape = fn shape -> {elem(shape, tuple_size(shape) - 1)} end - bias = param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = param("bias", [{:axis, -1}], initializer: opts[:bias_initializer]) layer(:bias, [x, bias], name: opts[:name], meta: opts[:meta], op_name: :bias) end @@ -4259,7 +4236,14 @@ defmodule Axon do end defp name(_type, name) when is_binary(name) do - fn _, _ -> name end + if String.contains?(name, "{n}") do + fn op, op_counts -> + count = op_counts[op] || 0 + String.replace(name, "{n}", Integer.to_string(count)) + end + else + fn _, _ -> name end + end end defp name(_type, name) do diff --git a/lib/axon/quantization.ex b/lib/axon/quantization.ex index ed976b8d0..01a2b5676 100644 --- a/lib/axon/quantization.ex +++ b/lib/axon/quantization.ex @@ -123,11 +123,8 @@ defmodule Axon.Quantization do |> Map.put(:units, units) |> Map.put(:use_bias, opts[:use_bias]) - kernel_shape = &Axon.Shape.dense_kernel(&1, units) - bias_shape = &Axon.Shape.dense_bias(&1, units) - kernel = - Axon.param("kernel", kernel_shape, + Axon.param("kernel", [{:axis, -1}, units], initializer: fn shape, type, key -> fun = case opts[:kernel_initializer] do @@ -153,7 +150,7 @@ defmodule Axon.Quantization do {inputs, op} = if opts[:use_bias] do - bias = Axon.param("bias", bias_shape, initializer: opts[:bias_initializer]) + bias = Axon.param("bias", [units], initializer: opts[:bias_initializer]) {[x, kernel, bias], &Layers.weight_only_quantized_dense/4} else {[x, kernel], &Layers.weight_only_quantized_dense/3} diff --git a/lib/axon/shape.ex b/lib/axon/shape.ex index d06471f64..55fc98f38 100644 --- a/lib/axon/shape.ex +++ b/lib/axon/shape.ex @@ -91,122 +91,6 @@ defmodule Axon.Shape do end end - ## Linear - - @doc """ - Calculates the shape of a dense kernel given the input - shape and output units. - - ## Examples - - iex> Axon.Shape.dense_kernel({nil, 784}, 128) - {784, 128} - - iex> Axon.Shape.dense_kernel({nil, 128}, 256) - {128, 256} - - iex> Axon.Shape.dense_kernel({nil, 3, 256, 256}, 128) - {256, 128} - """ - def dense_kernel(input_shape, units) do - unless Nx.rank(input_shape) >= 2 do - raise ArgumentError, - "input shape must have at least rank 2, got rank" <> - " #{Nx.rank(input_shape)}" - end - - {elem(input_shape, Nx.rank(input_shape) - 1), units} - end - - @doc """ - Calculates the shape of a dense bias given the input - shape and output units. - - ## Examples - - iex> Axon.Shape.dense_bias({nil, 784}, 128) - {128} - - iex> Axon.Shape.dense_bias({nil, 128}, 256) - {256} - - iex> Axon.Shape.dense_bias({nil, 3, 256, 256}, 128) - {128} - """ - def dense_bias(input_shape, units) do - unless Nx.rank(input_shape) >= 2 do - raise ArgumentError, - "input shape must have at least rank 2, got rank" <> - " #{Nx.rank(input_shape)}" - end - - {units} - end - - @doc """ - Calculates the shape of a bilinear kernel given both input - shapes and output units. - - ## Examples - - iex> Axon.Shape.bilinear_kernel({nil, 32}, {nil, 64}, 128) - {128, 32, 64} - - iex> Axon.Shape.bilinear_kernel({nil, 32, 64}, {nil, 16}, 32) - {32, 64, 16} - """ - def bilinear_kernel(parent1, parent2, units) do - unless Nx.rank(parent1) >= 2 and Nx.rank(parent2) >= 2 do - raise ArgumentError, - "input shapes must both have at least rank 2" <> - " got ranks #{Nx.rank(parent1)} and #{Nx.rank(parent2)}" - end - - parent1_features = elem(parent1, Nx.rank(parent1) - 1) - parent2_features = elem(parent2, Nx.rank(parent2) - 1) - {units, parent1_features, parent2_features} - end - - @doc """ - Calculates the shape of a bilinear bias given both input - shapes and output units. - - ## Examples - - iex> Axon.Shape.bilinear_bias({nil, 32}, {nil, 64}, 128) - {128} - - iex> Axon.Shape.bilinear_bias({nil, 32, 64}, {nil, 32, 16}, 32) - {32} - """ - def bilinear_bias(parent1, parent2, units) do - unless Nx.rank(parent1) >= 2 and Nx.rank(parent2) >= 2 do - raise ArgumentError, - "input shapes must both have at least rank 2" <> - " got ranks #{Nx.rank(parent1)} and #{Nx.rank(parent2)}" - end - - {units} - end - - ## Sparse - - @doc """ - Calculates the shape of an embedding kernel given input shape - vocab size and embedding size. - - ## Examples - - iex> Axon.Shape.embedding_kernel({nil, 10}, 128, 32) - {128, 32} - - iex> Axon.Shape.embedding_kernel({nil, 32}, 10, 10) - {10, 10} - """ - def embedding_kernel(_input_shape, vocab_size, embedding_size) do - {vocab_size, embedding_size} - end - ## Conv @doc """ From 6ba2474867a62a87bfdddf17270bafbb9642f065 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 31 Jan 2026 08:09:35 -0500 Subject: [PATCH 2/3] Apply suggestion from @seanmor5 --- lib/axon.ex | 1 - 1 file changed, 1 deletion(-) diff --git a/lib/axon.ex b/lib/axon.ex index 4164fcb6c..654d6cbbe 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -533,7 +533,6 @@ defmodule Axon do opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter) {type, opts} = Keyword.pop(opts, :type, {:f, 32}) - # If all elements are integers, treat it as a static shape (like tuples) if Enum.all?(shape_dsl, &is_integer/1) do template = Nx.template(List.to_tuple(shape_dsl), type) parameter(name, template, opts) From 69992bd5bf0a4859e39d3bbd198791154a2118c4 Mon Sep 17 00:00:00 2001 From: Sean Moriarity Date: Sat, 31 Jan 2026 14:09:01 -0500 Subject: [PATCH 3/3] Update docs --- lib/axon.ex | 70 +++++++++++++++++++++++++++++++++++++++++++++++------ 1 file changed, 62 insertions(+), 8 deletions(-) diff --git a/lib/axon.ex b/lib/axon.ex index 654d6cbbe..764dd56b9 100644 --- a/lib/axon.ex +++ b/lib/axon.ex @@ -402,14 +402,52 @@ defmodule Axon do automatically initialized and used in subsequent applications of Axon models. - You must specify a parameter "template" which can be a static template - tensor or a function which takes model input templates and returns a - template. It's most common to use functions because most parameters' - shapes rely on input shape information. + You may specify the parameter shape as: + + * A static tuple shape, e.g. `{32, 64}` + * A static template tensor + * A function which takes model input templates and returns a template + * A shape list using `{:axis, n}` or `{:axis, n, input: k}` to + reference input dimensions dynamically + + ## Options + + * `:initializer` - parameter initializer. Defaults to `:glorot_uniform`. + * `:type` - parameter type. Defaults to `{:f, 32}`. + * `:kind` - parameter kind. Defaults to `:parameter`. + + ## Examples + + A static shape: + + parameter("kernel", {32, 64}) + + Using a function: + + parameter("kernel", fn input -> + Nx.template({elem(Nx.shape(input), 1), 64}, Nx.type(input)) + end) + + Using the shape DSL: + + parameter("kernel", [{:axis, -1}, 64]) + + With options: + + parameter("kernel", {32, 64}, type: {:bf, 16}, initializer: :lecun_normal) + """ @doc type: :special def parameter(name, template, opts \\ []) + def parameter(name, shape, opts) when is_tuple(shape) do + opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter) + {type, opts} = Keyword.pop!(opts, :type) + + template = Nx.template(shape, type) + parameter(name, template, opts) + end + def parameter(name, %Nx.Tensor{} = template, opts) do opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter) initializer = validate_initializer!(opts[:initializer]) @@ -429,24 +467,40 @@ defmodule Axon do end def parameter(name, function, opts) when is_function(function) do - opts = Keyword.validate!(opts, initializer: :glorot_uniform, kind: :parameter) + opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: nil, kind: :parameter) + {type, opts} = Keyword.pop!(opts, :type) initializer = validate_initializer!(opts[:initializer]) kind = opts[:kind] || :parameter + template = + if type do + {:arity, arity} = Function.info(function, :arity) + + shape_fun(arity, fn templates -> + result = apply(function, List.wrap(templates)) + Nx.template(Nx.shape(result), type) + end) + else + function + end + %Axon.Parameter{ name: name, - template: function, + template: template, initializer: initializer, kind: kind } end def parameter(name, shape_dsl, opts) when is_list(shape_dsl) do + opts = Keyword.validate!(opts, initializer: :glorot_uniform, type: {:f, 32}, kind: :parameter) + {type, opts} = Keyword.pop!(opts, :type) + if Enum.all?(shape_dsl, &is_integer/1) do - template = Nx.template(List.to_tuple(shape_dsl), {:f, 32}) + template = Nx.template(List.to_tuple(shape_dsl), type) parameter(name, template, opts) else - template_fn = compile_shape_dsl(shape_dsl, {:f, 32}) + template_fn = compile_shape_dsl(shape_dsl, type) parameter(name, template_fn, opts) end end