diff --git a/lib/axon/compiler.ex b/lib/axon/compiler.ex
index 50b0d565..beb8d6b4 100644
--- a/lib/axon/compiler.ex
+++ b/lib/axon/compiler.ex
@@ -916,23 +916,11 @@ defmodule Axon.Compiler do
     else
       # Parameters are just accessed in the layer sub-map of the nested
       # parameter map, so we just need to extract them and then apply
-      # freezing and dtype policy
+      # freezing and dtype policy. Parameters may be SharedParameter
+      # structs for tied weights, which are resolved to their source.
       parameter_inputs =
         Enum.map(layer_params, fn %{name: v, frozen: frz} ->
-          param = params[name][v]
-
-          cond do
-            param != nil ->
-              safe_policy_cast(maybe_freeze(param, frz), policy, :compute)
-
-            true ->
-              raise ArgumentError,
-                    "parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
-                      " was not present in the given parameter map, this can" <>
-                      " happen if you are using parameters intended for another" <>
-                      " model or did not initialize portions of your model with" <>
-                      " Axon.init/3"
-          end
+          resolve_parameter!(params, name, v, frz, policy)
        end)

       # Reorder the inputs according to the original input ordering
@@ -1188,5 +1176,42 @@ defmodule Axon.Compiler do
   defp propagating_none?(%Axon.None{__propagate__: true}), do: true
   defp propagating_none?(_), do: false

+  defp resolve_parameter!(params, layer_name, param_name, freeze?, policy) do
+    # Special case where the entire layer entry is a SharedParameter, so we
+    # need to resolve it before looking up individual parameters. Otherwise
+    # this falls through and is handled by the `case` below.
+    layer_params =
+      with %Axon.ModelState.SharedParameter{path: path} <- params[layer_name] do
+        get_in(params, path)
+      end
+
+    parameter =
+      case layer_params[param_name] do
+        nil ->
+          raise ArgumentError,
+                "parameter #{inspect(param_name)} for layer: #{inspect(layer_name)}" <>
+                  " was not present in the given parameter map; this can" <>
+                  " happen if you are using parameters intended for another" <>
+                  " model or did not initialize portions of your model with" <>
+                  " Axon.init/3"
+
+        %Axon.ModelState.SharedParameter{path: path, transform: transform} ->
+          tensor =
+            with nil <- get_in(params, path) do
+              raise ArgumentError,
+                    "shared parameter for #{inspect(param_name)} in layer:" <>
+                      " #{inspect(layer_name)} references non-existent parameter" <>
+                      " #{inspect(path)}"
+            end
+
+          if transform, do: transform.(tensor), else: tensor
+
+        parameter ->
+          parameter
+      end
+
+    safe_policy_cast(maybe_freeze(parameter, freeze?), policy, :compute)
+  end
+
   defp us_to_ms(time), do: Float.round(time / 1000, 1)
 end
diff --git a/lib/axon/model_state.ex b/lib/axon/model_state.ex
index b9088a1b..8304d3fb 100644
--- a/lib/axon/model_state.ex
+++ b/lib/axon/model_state.ex
@@ -191,6 +191,57 @@ defmodule Axon.ModelState do
     }
   end

+  @doc """
+  Ties a parameter to another parameter, enabling weight sharing.
+
+  The destination parameter will reference the source parameter's tensor,
+  optionally applying a transformation. Both `destination` and `source`
+  are access paths (lists of strings) into the model state data.
+
+  ## Options
+
+    * `:transform` - a function to transform the source tensor before
+      use at the destination. For example, `&Nx.transpose/1` for tying
+      an embedding layer to an output projection.
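+
+  Tied parameters are dropped from `trainable_parameters/1`, so only the
+  source parameter is updated during training.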
+
+  ## Examples
+
+      # Tie output projection to embedding weights (transposed)
+      model_state = Axon.ModelState.tie(
+        model_state,
+        ["output", "kernel"],
+        ["embed", "kernel"],
+        transform: &Nx.transpose/1
+      )
+
+  """
+  def tie(model_state, destination, source, opts \\ []) do
+    update_in(model_state, [Access.key!(:data)], fn data ->
+      shared = Axon.ModelState.SharedParameter.new(source, opts)
+      [key | rest] = Enum.reverse(destination)
+
+      shared =
+        Enum.reduce(rest, %{key => shared}, fn next, acc ->
+          %{next => acc}
+        end)
+
+      deep_merge(data, shared)
+    end)
+  end
+
+  defp deep_merge(left, right) do
+    Map.merge(left, right, fn
+      _key, left_val, right_val when is_map(left_val) and is_map(right_val) ->
+        deep_merge(left_val, right_val)
+
+      _key, _left_val, right_val ->
+        right_val
+    end)
+  end
+
   defp transform_to_parameters(%Nx.Tensor{}), do: nil

   defp transform_to_parameters(map) when is_map(map) do
@@ -249,6 +300,9 @@ defmodule Axon.ModelState do
   defp tree_get(data, access) when is_list(access) do
     Enum.reduce(access, %{}, fn key, acc ->
       case data do
+        %{^key => %Axon.ModelState.SharedParameter{}} ->
+          acc
+
         %{^key => val} ->
           Map.put(acc, key, val)

@@ -261,9 +315,13 @@
   defp tree_get(data, access) when is_map(access) do
     Enum.reduce(access, %{}, fn {key, value}, acc ->
       case data do
+        %{^key => %Axon.ModelState.SharedParameter{}} ->
+          # Skip shared parameters; they resolve to another parameter's value
+          acc
+
         %{^key => val} ->
           tree = tree_get(val, value)
-          Map.put(acc, key, tree)
+          if map_size(tree) == 0, do: acc, else: Map.put(acc, key, tree)

         %{} ->
           acc
diff --git a/lib/axon/model_state/shared_parameter.ex b/lib/axon/model_state/shared_parameter.ex
new file mode 100644
index 00000000..bc22671c
--- /dev/null
+++ b/lib/axon/model_state/shared_parameter.ex
@@ -0,0 +1,27 @@
+defmodule Axon.ModelState.SharedParameter do
+  # Represents a tied or shared parameter for layers whose
+  # weights are connected but don't necessarily perform the
+  # same operation. It derives the Nx.Container protocol
+  # and contains an access path to the parameter that holds
+  # the original weight.
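+  #
+  # For example, a kernel tied to an embedding with a transposed
+  # transform is stored as:
+  #
+  #   %Axon.ModelState.SharedParameter{
+  #     path: ["embed", "kernel"],
+  #     transform: &Nx.transpose/1
+  #   }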
+ + @moduledoc false + + @derive {Nx.Container, containers: [], keep: [:path, :transform]} + defstruct [:path, :transform] + + def new(path, opts \\ []) do + %__MODULE__{ + path: path, + transform: Keyword.get(opts, :transform) + } + end +end diff --git a/test/axon/compiler_test.exs b/test/axon/compiler_test.exs index ad4298e4..c152b114 100644 --- a/test/axon/compiler_test.exs +++ b/test/axon/compiler_test.exs @@ -5538,6 +5538,120 @@ defmodule CompilerTest do end end + describe "weight tying" do + test "tied parameter uses source parameter value" do + # Both dense layers have same input/output size so kernels are compatible + model = + Axon.input("input", shape: {nil, 4}) + |> Axon.dense(4, name: "dense_0", use_bias: false) + |> Axon.dense(4, name: "dense_1", use_bias: false) + + {init_fn, predict_fn} = Axon.build(model) + input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]]) + + model_state = init_fn.(input, ModelState.empty()) + + # Set dense_0 kernel to identity matrix so we can trace the computation + identity = Nx.eye(4) + model_state = put_in(model_state.data["dense_0"]["kernel"], identity) + model_state = put_in(model_state.data["dense_1"]["kernel"], Nx.broadcast(0.0, {4, 4})) + + # Without tying: input -> identity -> zeros = zeros + output_untied = predict_fn.(model_state, input) + assert_equal(output_untied, Nx.tensor([[0.0, 0.0, 0.0, 0.0]])) + + # With tying: input -> identity -> identity = input + tied_state = + ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"]) + + output_tied = predict_fn.(tied_state, input) + assert_equal(output_tied, input) + end + + test "tied parameter with transform applies transformation" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0", use_bias: false) + |> Axon.dense(2, name: "dense_1", use_bias: false) + + {init_fn, predict_fn} = Axon.build(model) + input = Nx.tensor([[1.0, 2.0]]) + + model_state = init_fn.(input, ModelState.empty()) + + # Set a known kernel value + kernel = Nx.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]]) + model_state = put_in(model_state.data["dense_0"]["kernel"], kernel) + + # Tie with transpose: dense_1 uses kernel^T which is {4, 2} + tied_state = + ModelState.tie( + model_state, + ["dense_1", "kernel"], + ["dense_0", "kernel"], + transform: &Nx.transpose/1 + ) + + # input {1,2} @ kernel {2,4} = {1,4}, then @ kernel^T {4,2} = {1,2} + # [[1,2]] @ [[1,0,0,0],[0,1,0,0]] = [[1,2,0,0]] + # [[1,2,0,0]] @ [[1,0],[0,1],[0,0],[0,0]] = [[1,2]] + output = predict_fn.(tied_state, input) + assert_equal(output, input) + end + + test "modifying source parameter affects tied layers" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(2, name: "dense_0", use_bias: false) + |> Axon.dense(2, name: "dense_1", use_bias: false) + + {init_fn, predict_fn} = Axon.build(model) + input = Nx.tensor([[1.0, 0.0]]) + + model_state = init_fn.(input, ModelState.empty()) + + tied_state = + ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"]) + + # Set source kernel to a specific value + kernel_v1 = Nx.tensor([[1.0, 0.0], [0.0, 1.0]]) + tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v1) + output_v1 = predict_fn.(tied_state, input) + + # Change source kernel - tied layer should see the change + kernel_v2 = Nx.tensor([[2.0, 0.0], [0.0, 2.0]]) + tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v2) + output_v2 = predict_fn.(tied_state, input) + + # Outputs should differ because the shared kernel changed + refute Nx.all(Nx.equal(output_v1, 
output_v2)) |> Nx.to_number() == 1 + + # Verify expected values: input @ kernel @ kernel + # v1: [1,0] @ I @ I = [1,0] + # v2: [1,0] @ 2I @ 2I = [4,0] + assert_equal(output_v1, Nx.tensor([[1.0, 0.0]])) + assert_equal(output_v2, Nx.tensor([[4.0, 0.0]])) + end + + test "raises on non-existent shared parameter source" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0") + + {init_fn, predict_fn} = Axon.build(model) + input = Nx.tensor([[1.0, 2.0]]) + + model_state = init_fn.(input, ModelState.empty()) + + tied_state = + ModelState.tie(model_state, ["dense_0", "kernel"], ["nonexistent", "kernel"]) + + assert_raise ArgumentError, ~r/shared parameter.*references non-existent/, fn -> + predict_fn.(tied_state, input) + end + end + end + describe "instrumentation" do @describetag :capture_log diff --git a/test/axon/model_state_test.exs b/test/axon/model_state_test.exs new file mode 100644 index 00000000..8033f27e --- /dev/null +++ b/test/axon/model_state_test.exs @@ -0,0 +1,86 @@ +defmodule Axon.ModelStateTest do + use ExUnit.Case, async: true + + alias Axon.ModelState + + describe "tie/4" do + test "creates shared parameter at destination path" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0") + |> Axon.dense(4, name: "dense_1") + + {init_fn, _} = Axon.build(model) + model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty()) + + tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"]) + + assert %Axon.ModelState.SharedParameter{path: ["dense_0", "kernel"], transform: nil} = + tied.data["dense_1"]["kernel"] + end + + test "stores transform function" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0") + |> Axon.dense(4, name: "dense_1") + + {init_fn, _} = Axon.build(model) + model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty()) + + tied = + ModelState.tie( + model_state, + ["dense_1", "kernel"], + ["dense_0", "kernel"], + transform: &Nx.transpose/1 + ) + + assert %Axon.ModelState.SharedParameter{transform: transform} = + tied.data["dense_1"]["kernel"] + + assert is_function(transform, 1) + end + end + + describe "trainable_parameters/1 with tied weights" do + test "excludes tied parameters" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0") + |> Axon.dense(4, name: "dense_1") + + {init_fn, _} = Axon.build(model) + model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty()) + + # Before tying, both layers have kernel in trainable params + trainable_before = ModelState.trainable_parameters(model_state) + assert Map.has_key?(trainable_before["dense_0"], "kernel") + assert Map.has_key?(trainable_before["dense_1"], "kernel") + + # After tying, dense_1 kernel should be excluded + tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"]) + trainable_after = ModelState.trainable_parameters(tied) + + assert Map.has_key?(trainable_after["dense_0"], "kernel") + refute Map.has_key?(trainable_after["dense_1"], "kernel") + assert Map.has_key?(trainable_after["dense_1"], "bias") + end + + test "excludes layer when all parameters are tied" do + model = + Axon.input("input", shape: {nil, 2}) + |> Axon.dense(4, name: "dense_0", use_bias: false) + |> Axon.dense(4, name: "dense_1", use_bias: false) + + {init_fn, _} = Axon.build(model) + model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty()) + + tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"]) + 
trainable = ModelState.trainable_parameters(tied) + + assert Map.has_key?(trainable, "dense_0") + refute Map.has_key?(trainable, "dense_1") + end + end +end
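
Taken together, a minimal usage sketch of the API these changes introduce
(the model, layer names, and shapes below are illustrative, not taken from
the diff):

    model =
      Axon.input("input", shape: {nil, 8})
      |> Axon.dense(8, name: "embed", use_bias: false)
      |> Axon.dense(8, name: "output", use_bias: false)

    {init_fn, predict_fn} = Axon.build(model)
    model_state = init_fn.(Nx.template({1, 8}, :f32), Axon.ModelState.empty())

    # Tie the output projection to the embedding kernel. The "output" kernel
    # becomes an Axon.ModelState.SharedParameter that is resolved at predict
    # time and is excluded from Axon.ModelState.trainable_parameters/1.
    tied_state =
      Axon.ModelState.tie(model_state, ["output", "kernel"], ["embed", "kernel"])

    predict_fn.(tied_state, Nx.iota({1, 8}, type: :f32))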