Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
184 changes: 184 additions & 0 deletions lib/axon.ex
Original file line number Diff line number Diff line change
Expand Up @@ -3962,6 +3962,190 @@ defmodule Axon do
end
end

@doc """
Captures a model or subgraph, returning a reusable function.

This returns an arity-1 function that accepts new inputs and returns
the model with rewired inputs. For single-input models, pass an Axon
graph directly. For multi-input models, pass a map of input names to
Axon graphs.

This is useful for transfer learning where you want to extract a
pretrained model's feature extractor:

    # Load a pretrained model
    resnet = MyModels.resnet50()

    # Capture at the pooling layer (returns a function)
    backbone = Axon.capture(resnet, to: "avg_pool")

    # Use with new inputs
    new_input = Axon.input("my_features")
    features = backbone.(new_input)

    # Add your own head
    my_model =
      features
      |> Axon.dense(256, activation: :relu)
      |> Axon.dense(num_classes)

For models with multiple inputs, pass a map:

    model = Axon.capture(multi_input_model)
    output = model.(%{"image" => image_input, "text" => text_input})

You can also capture an entire model without truncation:

    encoder = Axon.capture(encoder_model)
    encoded = encoder.(my_input)

Layer names can be discovered using `Axon.properties/1`:

    Axon.properties(model) |> Map.keys()

## Options

  * `:to` - the name of the layer to capture as the output. The
    captured graph's output is rewired to that single layer's node,
    so anything downstream of it is dropped. In particular, for a
    multi-output model, truncating to a layer on one head discards
    the other heads and the container that joins them. If not
    provided, captures the entire model.

"""
@doc type: :graph
def capture(%Axon{} = axon, opts \\ []) when is_list(opts) do
  # Optionally truncate the graph so the named layer becomes the output.
  truncated =
    case Keyword.fetch(opts, :to) do
      {:ok, layer_name} when is_binary(layer_name) ->
        name_to_id = build_name_to_id_map(axon)

        target_id =
          case Map.fetch(name_to_id, layer_name) do
            {:ok, id} ->
              id

            :error ->
              available = name_to_id |> Map.keys() |> Enum.sort()

              raise ArgumentError,
                    "layer #{inspect(layer_name)} not found in model. " <>
                      "Available layers: #{inspect(available)}"
          end

        %Axon{axon | output: target_id}

      {:ok, other} ->
        raise ArgumentError,
              "expected :to option to be a string layer name, got: #{inspect(other)}"

      :error ->
        axon
    end

  # Only inputs reachable from the (possibly truncated) output are rewired.
  input_name_to_id = get_input_name_to_id_map(truncated)

  fn new_inputs ->
    rewire_inputs(truncated, input_name_to_id, new_inputs)
  end
end

# Walks the graph from the output node and returns a map from each
# layer's generated name to its node id.
defp build_name_to_id_map(%Axon{output: output_id, nodes: nodes}) do
  output_id
  |> do_build_name_to_id(nodes, {%{}, %{}, %{}})
  |> elem(0)
end

# Post-order traversal accumulating {name_to_id, visited-cache, op_counts}.
# Parents are visited before the node itself so op counters advance in
# the same order the model was built.
defp do_build_name_to_id(id, nodes, {_names, visited, _counts} = acc) do
  if Map.has_key?(visited, id) do
    acc
  else
    %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]

    {names, visited, counts} =
      Enum.reduce(parents, acc, &do_build_name_to_id(&1, nodes, &2))

    # The name is generated with the count *before* this node is tallied,
    # mirroring how default layer names are numbered at compile time.
    name = name_fn.(op_name, counts)
    counts = Map.update(counts, op_name, 1, &(&1 + 1))

    {Map.put(names, name, id), Map.put(visited, id, name), counts}
  end
end

# Collects the reachable input nodes and maps their names to node ids.
# Input names never depend on op counts, so name_fn gets empty counts.
defp get_input_name_to_id_map(%Axon{output: output_id, nodes: nodes}) do
  {inorder_nodes, _visited} = traverse_nodes(output_id, nodes, [], MapSet.new())

  for %Axon.Node{op: :input, id: id, name: name_fn} <- inorder_nodes, into: %{} do
    {name_fn.(:input, %{}), id}
  end
end

# Rebuilds the captured graph with its input nodes replaced by the given
# Axon graphs and returns the rewired %Axon{} struct.
#
# `new_inputs` is either a single %Axon{} (only valid when the captured
# model has exactly one input) or a map of input name => %Axon{} graph.
# Raises ArgumentError on a shape mismatch or on missing/extra names.
defp rewire_inputs(%Axon{output: output_id, nodes: nodes}, input_name_to_id, new_inputs) do
  new_inputs_map = normalize_new_inputs(new_inputs, input_name_to_id)

  validate_input_names!(input_name_to_id, new_inputs_map)

  # Map each old input node id to the Axon graph that replaces it.
  # fetch! is safe here: validation guaranteed every name is provided.
  id_mapping =
    Map.new(input_name_to_id, fn {name, old_id} ->
      {old_id, Map.fetch!(new_inputs_map, name)}
    end)

  # Merge the replacement graphs' nodes in (existing nodes win on id
  # collisions), then drop the old input nodes that were replaced.
  merged_nodes =
    new_inputs_map
    |> Enum.reduce(nodes, fn {_name, %Axon{nodes: new_nodes}}, acc_nodes ->
      Map.merge(new_nodes, acc_nodes)
    end)
    |> Map.drop(Map.values(input_name_to_id))

  # Re-point parent references from the dropped input nodes to the
  # outputs of their replacement graphs.
  updated_nodes =
    Map.new(merged_nodes, fn {node_id, node} ->
      updated_parents =
        Enum.map(node.parent, fn parent_id ->
          case Map.fetch(id_mapping, parent_id) do
            {:ok, %Axon{output: new_output_id}} -> new_output_id
            :error -> parent_id
          end
        end)

      {node_id, %{node | parent: updated_parents}}
    end)

  %Axon{output: output_id, nodes: updated_nodes}
end

# Normalizes the caller-supplied inputs to a name => %Axon{} map.
# The %Axon{} clause must come first: a struct also matches %{}.
defp normalize_new_inputs(%Axon{} = single_input, input_name_to_id) do
  if map_size(input_name_to_id) != 1 do
    raise ArgumentError,
          "model has #{map_size(input_name_to_id)} inputs, expected a map " <>
            "with keys: #{inspect(Map.keys(input_name_to_id))}"
  end

  [{input_name, _id}] = Map.to_list(input_name_to_id)
  %{input_name => single_input}
end

defp normalize_new_inputs(%{} = inputs_map, _input_name_to_id), do: inputs_map

defp normalize_new_inputs(other, _input_name_to_id) do
  raise ArgumentError,
        "expected an Axon graph or a map of input names to Axon graphs, " <>
          "got: #{inspect(other)}"
end

# Ensures the provided input names exactly match the expected ones.
# Plain list difference (--) avoids MapSet round-trips; input counts
# are tiny, so quadratic subtraction is irrelevant here.
defp validate_input_names!(input_name_to_id, new_inputs_map) do
  expected_names = Map.keys(input_name_to_id)
  provided_names = Map.keys(new_inputs_map)

  case expected_names -- provided_names do
    [] -> :ok
    missing -> raise ArgumentError, "missing inputs: #{inspect(missing)}"
  end

  case provided_names -- expected_names do
    [] -> :ok
    extra -> raise ArgumentError, "unexpected inputs: #{inspect(extra)}"
  end
end

@doc """
Builds the given model to `{init_fn, predict_fn}`.

Expand Down
128 changes: 128 additions & 0 deletions test/axon_test.exs
Original file line number Diff line number Diff line change
Expand Up @@ -893,4 +893,132 @@ defmodule AxonTest do
assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)})
end
end

describe "capture" do
  test "captures entire model with single input" do
    model =
      Axon.input("features", shape: {nil, 10})
      |> Axon.dense(32, name: "dense1")
      |> Axon.dense(16, name: "dense2")

    captured = Axon.capture(model)
    assert is_function(captured, 1)

    # Use with new input
    new_input = Axon.input("my_input", shape: {nil, 10})
    new_model = captured.(new_input)

    # Verify new model has the new input
    assert %{"my_input" => _} = Axon.get_inputs(new_model)
    refute Map.has_key?(Axon.get_inputs(new_model), "features")

    # Verify layers are preserved
    props = Axon.properties(new_model)
    assert Map.has_key?(props, "dense1")
    assert Map.has_key?(props, "dense2")
  end

  test "captures up to specific layer with :to option" do
    model =
      Axon.input("features", shape: {nil, 10})
      |> Axon.dense(32, name: "hidden1")
      |> Axon.relu(name: "relu1")
      |> Axon.dense(16, name: "hidden2")
      |> Axon.dense(2, name: "output")

    captured = Axon.capture(model, to: "hidden2")

    new_input = Axon.input("x", shape: {nil, 10})
    new_model = captured.(new_input)

    props = Axon.properties(new_model)
    assert Map.has_key?(props, "hidden1")
    assert Map.has_key?(props, "relu1")
    assert Map.has_key?(props, "hidden2")
    refute Map.has_key?(props, "output")
  end

  test "truncating a multi-output model to one head drops the others" do
    # Two heads share a trunk; capture at one head's layer should keep
    # only that branch and discard the sibling head and the container.
    input = Axon.input("features", shape: {nil, 10})
    branch1 = Axon.dense(input, 4, name: "branch1")
    branch2 = Axon.dense(input, 8, name: "branch2")
    model = Axon.container({branch1, branch2})

    captured = Axon.capture(model, to: "branch1")
    new_model = captured.(Axon.input("x", shape: {nil, 10}))

    props = Axon.properties(new_model)
    assert Map.has_key?(props, "branch1")
    refute Map.has_key?(props, "branch2")
    assert %{"x" => _} = Axon.get_inputs(new_model)
  end

  test "works with multiple inputs using map" do
    input1 = Axon.input("image", shape: {nil, 784})
    input2 = Axon.input("text", shape: {nil, 128})

    model =
      Axon.concatenate(input1, input2)
      |> Axon.dense(64, name: "combined")

    captured = Axon.capture(model)

    new_image = Axon.input("my_image", shape: {nil, 784})
    new_text = Axon.input("my_text", shape: {nil, 128})

    new_model = captured.(%{"image" => new_image, "text" => new_text})

    inputs = Axon.get_inputs(new_model)
    assert Map.has_key?(inputs, "my_image")
    assert Map.has_key?(inputs, "my_text")
    refute Map.has_key?(inputs, "image")
    refute Map.has_key?(inputs, "text")
  end

  test "raises on invalid layer name" do
    model =
      Axon.input("features", shape: {nil, 10})
      |> Axon.dense(32, name: "dense1")

    assert_raise ArgumentError, ~r/layer "nonexistent" not found/, fn ->
      Axon.capture(model, to: "nonexistent")
    end
  end

  test "raises on missing inputs for multi-input model" do
    input1 = Axon.input("a", shape: {nil, 10})
    input2 = Axon.input("b", shape: {nil, 10})

    model = Axon.add(input1, input2) |> Axon.dense(5)
    captured = Axon.capture(model)

    new_a = Axon.input("new_a", shape: {nil, 10})

    assert_raise ArgumentError, ~r/missing inputs/, fn ->
      captured.(%{"a" => new_a})
    end
  end

  test "raises when single input passed to multi-input model" do
    input1 = Axon.input("a", shape: {nil, 10})
    input2 = Axon.input("b", shape: {nil, 10})

    model = Axon.add(input1, input2)
    captured = Axon.capture(model)

    single_input = Axon.input("x", shape: {nil, 10})

    assert_raise ArgumentError, ~r/model has 2 inputs/, fn ->
      captured.(single_input)
    end
  end

  test "captured model can be executed" do
    model =
      Axon.input("features", shape: {nil, 2})
      |> Axon.dense(4, name: "dense1", kernel_initializer: :ones, bias_initializer: :zeros)
      |> Axon.relu(name: "relu")
      |> Axon.dense(2, name: "dense2", kernel_initializer: :ones, bias_initializer: :zeros)

    # Capture up to relu
    captured = Axon.capture(model, to: "relu")
    new_input = Axon.input("x", shape: {nil, 2})
    new_model = captured.(new_input)

    # Build and run
    {init_fn, predict_fn} = Axon.build(new_model)
    params = init_fn.(Nx.template({1, 2}, :f32), Axon.ModelState.empty())

    input = Nx.tensor([[1.0, 2.0]])
    result = predict_fn.(params, input)

    # With ones kernel: [1,2] dot ones(2,4) = [3,3,3,3], relu keeps it positive
    assert Nx.shape(result) == {1, 4}
    assert Nx.to_flat_list(result) == [3.0, 3.0, 3.0, 3.0]
  end
end
end
Loading