55 changes: 40 additions & 15 deletions lib/axon/compiler.ex
@@ -916,23 +916,11 @@ defmodule Axon.Compiler do
else
# Parameters are just accessed in the layer sub-map of the nested
# parameter map, so we just need to extract them and then apply
# freezing and dtype policy
# freezing and dtype policy. Parameters may be SharedParameter
# structs for tied weights, which are resolved to their source.
parameter_inputs =
Enum.map(layer_params, fn %{name: v, frozen: frz} ->
param = params[name][v]

cond do
param != nil ->
safe_policy_cast(maybe_freeze(param, frz), policy, :compute)

true ->
raise ArgumentError,
"parameter #{inspect(v)} for layer: #{inspect(name)} in" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"
end
resolve_parameter!(params, name, v, frz, policy)
end)

# Reorder the inputs according to the original input ordering
@@ -1188,5 +1176,42 @@ defmodule Axon.Compiler do
defp propagating_none?(%Axon.None{__propagate__: true}), do: true
defp propagating_none?(_), do: false

defp resolve_parameter!(params, layer_name, param_name, freeze?, policy) do
# Special case where this is a SharedParameter at the layer level, so we
# need to resolve that before forwarding. Otherwise this falls through and
# is handled at the next step
layer_params =
with %Axon.ModelState.SharedParameter{path: path} <- params[layer_name] do
get_in(params, path)
end
Comment on lines +1183 to +1186 (Contributor)

Suggested change (rewrite the `with` expression as a `case`):

    layer_params =
      case params[layer_name] do
        %Axon.ModelState.SharedParameter{path: path} ->
          get_in(params, path)

        nil ->
          nil
      end

I think this is slightly more readable.

parameter =
case layer_params[param_name] do
nil ->
raise ArgumentError,
"parameter #{inspect(param_name)} for layer: #{inspect(layer_name)}" <>
" was not present in the given parameter map, this can" <>
" happen if you are using parameters intended for another" <>
" model or did not initialize portions of your model with" <>
" Axon.init/3"

%Axon.ModelState.SharedParameter{path: path, transform: transform} ->
tensor =
with nil <- get_in(params, path) do
raise ArgumentError,
"shared parameter for #{inspect(param_name)} in layer:" <>
" #{inspect(layer_name)}, references non-existent parameter" <>
" #{inspect(path)}"
end
Comment on lines +1200 to +1205 (Contributor)

Suggested change (rewrite the `with` expression as an `is_nil` check):

    if is_nil(get_in(params, path)) do
      raise ArgumentError,
            "shared parameter for #{inspect(param_name)} in layer:" <>
              " #{inspect(layer_name)}, references non-existent parameter" <>
              " #{inspect(path)}"
    end

if transform, do: transform.(tensor), else: tensor

parameter ->
parameter
end

safe_policy_cast(maybe_freeze(parameter, freeze?), policy, :compute)
end

defp us_to_ms(time), do: Float.round(time / 1000, 1)
end
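
For illustration, here is a minimal sketch of the lookup the new resolve_parameter! clause performs once a kernel has been tied. The layer names and the shape of the parameter map below are assumptions for the example, not part of this diff; it assumes Nx and the SharedParameter module added in this PR are available:

    # Hypothetical state data after ModelState.tie/4: dense_1's kernel is a
    # SharedParameter that points back at dense_0's kernel.
    params = %{
      "dense_0" => %{"kernel" => Nx.eye(2)},
      "dense_1" => %{
        "kernel" => Axon.ModelState.SharedParameter.new(["dense_0", "kernel"])
      }
    }

    # The compiler sees the SharedParameter under dense_1/kernel and follows
    # its path with get_in/2, landing on the source tensor under dense_0.
    %Axon.ModelState.SharedParameter{path: path} = params["dense_1"]["kernel"]
    get_in(params, path)
    #=> the identity tensor stored under dense_0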
57 changes: 56 additions & 1 deletion lib/axon/model_state.ex
@@ -191,6 +191,54 @@ defmodule Axon.ModelState do
}
end

@doc """
Ties a parameter to another parameter, enabling weight sharing.

The destination parameter will reference the source parameter's tensor,
optionally applying a transformation. Both `destination` and `source`
are access paths (lists of strings) into the model state data.

## Options

* `:transform` - a function to transform the source tensor before
use at the destination. For example, `&Nx.transpose/1` for tying
an embedding layer to an output projection.

## Examples

# Tie output projection to embedding weights (transposed)
model_state = Axon.ModelState.tie(
model_state,
["output", "kernel"],
["embed", "kernel"],
transform: &Nx.transpose/1
)

"""
def tie(model_state, destination, source, opts \\ []) do
update_in(model_state, [Access.key!(:data)], fn data ->
shared = Axon.ModelState.SharedParameter.new(source, opts)
[key | rest] = Enum.reverse(destination)

shared =
Enum.reduce(rest, %{key => shared}, fn next, acc ->
%{next => acc}
end)

deep_merge(data, shared)
end)
end
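
To make the path handling concrete, this is the nested map tie/4 builds by reversing the destination path and wrapping the shared parameter, before deep-merging it into the state data. The layer and parameter names are hypothetical, chosen only for the sketch:

    # destination = ["dense_1", "kernel"], source = ["dense_0", "kernel"]
    # Enum.reverse(destination)  #=> ["kernel", "dense_1"]
    # [key | rest]               #=> key = "kernel", rest = ["dense_1"]
    # Reducing over rest wraps the struct one level per remaining path segment:
    %{
      "dense_1" => %{
        "kernel" => %Axon.ModelState.SharedParameter{
          path: ["dense_0", "kernel"],
          transform: nil
        }
      }
    }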

defp deep_merge(left, right) do
Map.merge(left, right, fn
_key, left_val, right_val when is_map(left_val) and is_map(right_val) ->
deep_merge(left_val, right_val)

_key, _left_val, right_val ->
right_val
end)
end

defp transform_to_parameters(%Nx.Tensor{}), do: nil

defp transform_to_parameters(map) when is_map(map) do
Expand Down Expand Up @@ -249,6 +297,9 @@ defmodule Axon.ModelState do
defp tree_get(data, access) when is_list(access) do
Enum.reduce(access, %{}, fn key, acc ->
case data do
%{^key => %Axon.ModelState.SharedParameter{}} ->
acc

%{^key => val} ->
Map.put(acc, key, val)

@@ -261,9 +312,13 @@ defmodule Axon.ModelState do
defp tree_get(data, access) when is_map(access) do
Enum.reduce(access, %{}, fn {key, value}, acc ->
case data do
%{^key => %Axon.ModelState.SharedParameter{}} ->
# Skip shared parameters - they reference another parameter
acc

%{^key => val} ->
tree = tree_get(val, value)
Map.put(acc, key, tree)
if map_size(tree) == 0, do: acc, else: Map.put(acc, key, tree)

%{} ->
acc
19 changes: 19 additions & 0 deletions lib/axon/model_state/shared_parameter.ex
@@ -0,0 +1,19 @@
defmodule Axon.ModelState.SharedParameter do
# Represents a tied or shared parameter for layers whose
# weights are connected but don't necessarily perform the
# same operation. This implements the Nx.Container behavior
# and contains an access path to the parameter that holds the
# original weight.

@moduledoc false

@derive {Nx.Container, containers: [], keep: [:path, :transform]}
defstruct [:path, :transform]

def new(path, opts \\ []) do
%__MODULE__{
path: path,
transform: Keyword.get(opts, :transform)
}
end
end
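
As a standalone illustration of the struct itself (the paths and the transform here are assumptions for the example, not values taken from this diff):

    shared =
      Axon.ModelState.SharedParameter.new(["embed", "kernel"], transform: &Nx.transpose/1)
    #=> %Axon.ModelState.SharedParameter{path: ["embed", "kernel"], transform: &Nx.transpose/1}

    # Deriving Nx.Container with containers: [] means the struct carries no
    # tensors of its own, so Nx tree traversals pass it through unchanged and
    # only the compiler resolves the path and transform at execution time.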
114 changes: 114 additions & 0 deletions test/axon/compiler_test.exs
@@ -5538,6 +5538,120 @@ defmodule CompilerTest do
end
end

describe "weight tying" do
test "tied parameter uses source parameter value" do
# Both dense layers have same input/output size so kernels are compatible
model =
Axon.input("input", shape: {nil, 4})
|> Axon.dense(4, name: "dense_0", use_bias: false)
|> Axon.dense(4, name: "dense_1", use_bias: false)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])

model_state = init_fn.(input, ModelState.empty())

# Set dense_0 kernel to identity matrix so we can trace the computation
identity = Nx.eye(4)
model_state = put_in(model_state.data["dense_0"]["kernel"], identity)
model_state = put_in(model_state.data["dense_1"]["kernel"], Nx.broadcast(0.0, {4, 4}))

# Without tying: input -> identity -> zeros = zeros
output_untied = predict_fn.(model_state, input)
assert_equal(output_untied, Nx.tensor([[0.0, 0.0, 0.0, 0.0]]))

# With tying: input -> identity -> identity = input
tied_state =
ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])

output_tied = predict_fn.(tied_state, input)
assert_equal(output_tied, input)
end

test "tied parameter with transform applies transformation" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0", use_bias: false)
|> Axon.dense(2, name: "dense_1", use_bias: false)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.tensor([[1.0, 2.0]])

model_state = init_fn.(input, ModelState.empty())

# Set a known kernel value
kernel = Nx.tensor([[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0]])
model_state = put_in(model_state.data["dense_0"]["kernel"], kernel)

# Tie with transpose: dense_1 uses kernel^T which is {4, 2}
tied_state =
ModelState.tie(
model_state,
["dense_1", "kernel"],
["dense_0", "kernel"],
transform: &Nx.transpose/1
)

# input {1,2} @ kernel {2,4} = {1,4}, then @ kernel^T {4,2} = {1,2}
# [[1,2]] @ [[1,0,0,0],[0,1,0,0]] = [[1,2,0,0]]
# [[1,2,0,0]] @ [[1,0],[0,1],[0,0],[0,0]] = [[1,2]]
output = predict_fn.(tied_state, input)
assert_equal(output, input)
end

test "modifying source parameter affects tied layers" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(2, name: "dense_0", use_bias: false)
|> Axon.dense(2, name: "dense_1", use_bias: false)

{init_fn, predict_fn} = Axon.build(model)
input = Nx.tensor([[1.0, 0.0]])

model_state = init_fn.(input, ModelState.empty())

tied_state =
ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])

# Set source kernel to a specific value
kernel_v1 = Nx.tensor([[1.0, 0.0], [0.0, 1.0]])
tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v1)
output_v1 = predict_fn.(tied_state, input)

# Change source kernel - tied layer should see the change
kernel_v2 = Nx.tensor([[2.0, 0.0], [0.0, 2.0]])
tied_state = put_in(tied_state.data["dense_0"]["kernel"], kernel_v2)
output_v2 = predict_fn.(tied_state, input)

# Outputs should differ because the shared kernel changed
refute Nx.all(Nx.equal(output_v1, output_v2)) |> Nx.to_number() == 1

# Verify expected values: input @ kernel @ kernel
# v1: [1,0] @ I @ I = [1,0]
# v2: [1,0] @ 2I @ 2I = [4,0]
assert_equal(output_v1, Nx.tensor([[1.0, 0.0]]))
assert_equal(output_v2, Nx.tensor([[4.0, 0.0]]))
end

test "raises on non-existent shared parameter source" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0")

{init_fn, predict_fn} = Axon.build(model)
input = Nx.tensor([[1.0, 2.0]])

model_state = init_fn.(input, ModelState.empty())

tied_state =
ModelState.tie(model_state, ["dense_0", "kernel"], ["nonexistent", "kernel"])

assert_raise ArgumentError, ~r/shared parameter.*references non-existent/, fn ->
predict_fn.(tied_state, input)
end
end
end

describe "instrumentation" do
@describetag :capture_log

86 changes: 86 additions & 0 deletions test/axon/model_state_test.exs
@@ -0,0 +1,86 @@
defmodule Axon.ModelStateTest do
use ExUnit.Case, async: true

alias Axon.ModelState

describe "tie/4" do
test "creates shared parameter at destination path" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0")
|> Axon.dense(4, name: "dense_1")

{init_fn, _} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())

tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])

assert %Axon.ModelState.SharedParameter{path: ["dense_0", "kernel"], transform: nil} =
tied.data["dense_1"]["kernel"]
end

test "stores transform function" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0")
|> Axon.dense(4, name: "dense_1")

{init_fn, _} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())

tied =
ModelState.tie(
model_state,
["dense_1", "kernel"],
["dense_0", "kernel"],
transform: &Nx.transpose/1
)

assert %Axon.ModelState.SharedParameter{transform: transform} =
tied.data["dense_1"]["kernel"]

assert is_function(transform, 1)
end
end

describe "trainable_parameters/1 with tied weights" do
test "excludes tied parameters" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0")
|> Axon.dense(4, name: "dense_1")

{init_fn, _} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())

# Before tying, both layers have kernel in trainable params
trainable_before = ModelState.trainable_parameters(model_state)
assert Map.has_key?(trainable_before["dense_0"], "kernel")
assert Map.has_key?(trainable_before["dense_1"], "kernel")

# After tying, dense_1 kernel should be excluded
tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
trainable_after = ModelState.trainable_parameters(tied)

assert Map.has_key?(trainable_after["dense_0"], "kernel")
refute Map.has_key?(trainable_after["dense_1"], "kernel")
assert Map.has_key?(trainable_after["dense_1"], "bias")
end

test "excludes layer when all parameters are tied" do
model =
Axon.input("input", shape: {nil, 2})
|> Axon.dense(4, name: "dense_0", use_bias: false)
|> Axon.dense(4, name: "dense_1", use_bias: false)

{init_fn, _} = Axon.build(model)
model_state = init_fn.(Nx.template({1, 2}, :f32), ModelState.empty())

tied = ModelState.tie(model_state, ["dense_1", "kernel"], ["dense_0", "kernel"])
trainable = ModelState.trainable_parameters(tied)

assert Map.has_key?(trainable, "dense_0")
refute Map.has_key?(trainable, "dense_1")
end
end
end