diff --git a/lib/axon.ex b/lib/axon.ex
index 764dd56b..3d9e22a6 100644
--- a/lib/axon.ex
+++ b/lib/axon.ex
@@ -3962,6 +3962,190 @@ defmodule Axon do
    end
  end

  @doc """
  Captures a model or subgraph, returning a reusable function.

  This returns an arity-1 function that accepts new inputs and returns
  the model with rewired inputs. For single-input models, pass an Axon
  graph directly. For multi-input models, pass a map of input names to
  Axon graphs.

  This is useful for transfer learning where you want to extract a
  pretrained model's feature extractor:

      # Load a pretrained model
      resnet = MyModels.resnet50()

      # Capture at the pooling layer (returns a function)
      backbone = Axon.capture(resnet, to: "avg_pool")

      # Use with new inputs
      new_input = Axon.input("my_features")
      features = backbone.(new_input)

      # Add your own head
      my_model =
        features
        |> Axon.dense(256, activation: :relu)
        |> Axon.dense(num_classes)

  For models with multiple inputs, pass a map:

      model = Axon.capture(multi_input_model)
      output = model.(%{"image" => image_input, "text" => text_input})

  You can also capture an entire model without truncation:

      encoder = Axon.capture(encoder_model)
      encoded = encoder.(my_input)

  Layer names can be discovered using `Axon.properties/1`:

      Axon.properties(model) |> Map.keys()

  ## Options

    * `:to` - the name of the layer to capture as the output. If not
      provided, captures the entire model.

  """
  @doc type: :graph
  def capture(%Axon{} = axon, opts \\ []) when is_list(opts) do
    truncated =
      case Keyword.fetch(opts, :to) do
        {:ok, layer_name} when is_binary(layer_name) ->
          name_to_id = build_name_to_id_map(axon)

          target_id =
            case Map.fetch(name_to_id, layer_name) do
              {:ok, id} ->
                id

              :error ->
                available = name_to_id |> Map.keys() |> Enum.sort()

                raise ArgumentError,
                      "layer #{inspect(layer_name)} not found in model. " <>
                        "Available layers: #{inspect(available)}"
            end

          # Re-root the graph at the requested layer; unreachable nodes are
          # simply never traversed, so they need not be pruned here.
          %Axon{axon | output: target_id}

        {:ok, other} ->
          raise ArgumentError,
                "expected :to option to be a string layer name, got: #{inspect(other)}"

        :error ->
          axon
      end

    # Resolve the input layers once, up-front, so the returned closure only
    # has to rewire node ids when it is invoked.
    input_name_to_id = get_input_name_to_id_map(truncated)

    fn new_inputs ->
      rewire_inputs(truncated, input_name_to_id, new_inputs)
    end
  end

  # Walks the graph rooted at the model's output and returns a map of
  # layer name => node id for every reachable node.
  defp build_name_to_id_map(%Axon{output: id, nodes: nodes}) do
    {name_to_id, _, _} = do_build_name_to_id(id, nodes, {%{}, %{}, %{}})
    name_to_id
  end

  # Post-order traversal accumulating {name_to_id, cache (id => name), op_counts}.
  # Op counts are accumulated in traversal order so generated names
  # (e.g. "dense_0") line up with the ones Axon produces at build time —
  # TODO(review): confirm this matches the naming order used by Axon.build/2.
  defp do_build_name_to_id(id, nodes, {_name_to_id, cache, _op_counts} = acc) do
    case cache do
      %{^id => _} ->
        # Already visited (shared subgraph); keep counts stable.
        acc

      %{} ->
        %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id]

        {name_to_id, cache, op_counts} =
          Enum.reduce(parents, acc, fn parent_id, acc ->
            do_build_name_to_id(parent_id, nodes, acc)
          end)

        # The name is computed against the counts *before* this node
        # increments them.
        name = name_fn.(op_name, op_counts)
        op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end)
        name_to_id = Map.put(name_to_id, name, id)

        {name_to_id, Map.put(cache, id, name), op_counts}
    end
  end

  # Returns a map of input layer name => node id for every input node
  # reachable from the graph's output.
  defp get_input_name_to_id_map(%Axon{output: id, nodes: nodes}) do
    {inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new())

    inorder_nodes
    |> Enum.filter(fn %Axon.Node{op: op} -> op == :input end)
    |> Map.new(fn %Axon.Node{id: id, name: name_fn} ->
      {name_fn.(:input, %{}), id}
    end)
  end

  # Replaces the input nodes of the captured graph with the outputs of the
  # given replacement Axon graphs, returning a new graph rooted at the
  # (possibly remapped) original output.
  #
  # Raises ArgumentError when inputs are missing, unexpected, of the wrong
  # shape (single vs. map), or not Axon graphs.
  defp rewire_inputs(%Axon{output: output_id, nodes: nodes}, input_name_to_id, new_inputs) do
    new_inputs_map =
      case new_inputs do
        %Axon{} = single_input ->
          if map_size(input_name_to_id) != 1 do
            raise ArgumentError,
                  "model has #{map_size(input_name_to_id)} inputs, expected a map " <>
                    "with keys: #{inspect(Map.keys(input_name_to_id))}"
          end

          [{input_name, _}] = Map.to_list(input_name_to_id)
          %{input_name => single_input}

        %{} = inputs_map ->
          inputs_map

        other ->
          raise ArgumentError,
                "expected an Axon graph or a map of input names to Axon graphs, " <>
                  "got: #{inspect(other)}"
      end

    # Validate all expected inputs are provided
    expected_names = Map.keys(input_name_to_id) |> MapSet.new()
    provided_names = Map.keys(new_inputs_map) |> MapSet.new()

    missing = MapSet.difference(expected_names, provided_names)
    extra = MapSet.difference(provided_names, expected_names)

    if MapSet.size(missing) > 0 do
      raise ArgumentError, "missing inputs: #{inspect(MapSet.to_list(missing))}"
    end

    if MapSet.size(extra) > 0 do
      raise ArgumentError, "unexpected inputs: #{inspect(MapSet.to_list(extra))}"
    end

    # Every replacement must itself be an Axon graph; fail with a clear
    # message here instead of a MatchError deep inside the merge below.
    Enum.each(new_inputs_map, fn
      {_name, %Axon{}} ->
        :ok

      {name, other} ->
        raise ArgumentError,
              "expected replacement for input #{inspect(name)} to be an " <>
                "Axon graph, got: #{inspect(other)}"
    end)

    # old input node id => replacement Axon graph
    id_mapping =
      Map.new(input_name_to_id, fn {name, old_id} ->
        {old_id, new_inputs_map[name]}
      end)

    # Bring the replacement graphs' nodes in; on an id collision the
    # captured model's node wins.
    merged_nodes =
      Enum.reduce(Map.values(new_inputs_map), nodes, fn %Axon{nodes: new_nodes}, acc_nodes ->
        Map.merge(new_nodes, acc_nodes)
      end)

    # Drop the old input nodes; anything that referenced them is rewired to
    # the corresponding replacement graph's output below.
    merged_nodes = Map.drop(merged_nodes, Map.values(input_name_to_id))

    updated_nodes =
      Map.new(merged_nodes, fn {node_id, node} ->
        updated_parents =
          Enum.map(node.parent, fn parent_id ->
            case Map.fetch(id_mapping, parent_id) do
              {:ok, %Axon{output: new_output_id}} -> new_output_id
              :error -> parent_id
            end
          end)

        {node_id, %{node | parent: updated_parents}}
      end)

    # If the captured output was itself an input node, it has just been
    # dropped, so the graph's output must be remapped to the replacement
    # graph's output as well — otherwise we would return a graph whose
    # output references a nonexistent node.
    new_output_id =
      case Map.fetch(id_mapping, output_id) do
        {:ok, %Axon{output: new_id}} -> new_id
        :error -> output_id
      end

    %Axon{output: new_output_id, nodes: updated_nodes}
  end

  @doc """
  Builds the given model to `{init_fn, predict_fn}`.
diff --git a/test/axon_test.exs b/test/axon_test.exs
index 87c4878e..c3ebda2c 100644
--- a/test/axon_test.exs
+++ b/test/axon_test.exs
@@ -893,4 +893,132 @@ defmodule AxonTest do
      assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)})
    end
  end

  describe "capture" do
    test "captures entire model with single input" do
      base =
        Axon.input("features", shape: {nil, 10})
        |> Axon.dense(32, name: "dense1")
        |> Axon.dense(16, name: "dense2")

      capture_fn = Axon.capture(base)
      assert is_function(capture_fn, 1)

      # Rewire the captured graph onto a fresh input
      rewired = capture_fn.(Axon.input("my_input", shape: {nil, 10}))

      # Only the new input feeds the rewired graph
      inputs = Axon.get_inputs(rewired)
      assert %{"my_input" => _} = inputs
      refute Map.has_key?(inputs, "features")

      # Both dense layers survive the rewiring
      layer_names = rewired |> Axon.properties() |> Map.keys()
      assert "dense1" in layer_names
      assert "dense2" in layer_names
    end

    test "captures up to specific layer with :to option" do
      base =
        Axon.input("features", shape: {nil, 10})
        |> Axon.dense(32, name: "hidden1")
        |> Axon.relu(name: "relu1")
        |> Axon.dense(16, name: "hidden2")
        |> Axon.dense(2, name: "output")

      capture_fn = Axon.capture(base, to: "hidden2")
      rewired = capture_fn.(Axon.input("x", shape: {nil, 10}))

      layer_names = rewired |> Axon.properties() |> Map.keys()

      # Everything up to and including the truncation point is retained...
      for kept <- ["hidden1", "relu1", "hidden2"], do: assert kept in layer_names
      # ...while layers past it are gone
      refute "output" in layer_names
    end

    test "works with multiple inputs using map" do
      base =
        Axon.input("image", shape: {nil, 784})
        |> Axon.concatenate(Axon.input("text", shape: {nil, 128}))
        |> Axon.dense(64, name: "combined")

      capture_fn = Axon.capture(base)

      rewired =
        capture_fn.(%{
          "image" => Axon.input("my_image", shape: {nil, 784}),
          "text" => Axon.input("my_text", shape: {nil, 128})
        })

      inputs = Axon.get_inputs(rewired)
      assert Map.has_key?(inputs, "my_image")
      assert Map.has_key?(inputs, "my_text")
      refute Map.has_key?(inputs, "image")
      refute Map.has_key?(inputs, "text")
    end

    test "raises on invalid layer name" do
      base =
        Axon.input("features", shape: {nil, 10})
        |> Axon.dense(32, name: "dense1")

      assert_raise ArgumentError, ~r/layer "nonexistent" not found/, fn ->
        Axon.capture(base, to: "nonexistent")
      end
    end

    test "raises on missing inputs for multi-input model" do
      base =
        Axon.add(Axon.input("a", shape: {nil, 10}), Axon.input("b", shape: {nil, 10}))
        |> Axon.dense(5)

      capture_fn = Axon.capture(base)

      # Only "a" is supplied; "b" is missing
      assert_raise ArgumentError, ~r/missing inputs/, fn ->
        capture_fn.(%{"a" => Axon.input("new_a", shape: {nil, 10})})
      end
    end

    test "raises when single input passed to multi-input model" do
      base = Axon.add(Axon.input("a", shape: {nil, 10}), Axon.input("b", shape: {nil, 10}))
      capture_fn = Axon.capture(base)

      assert_raise ArgumentError, ~r/model has 2 inputs/, fn ->
        capture_fn.(Axon.input("x", shape: {nil, 10}))
      end
    end

    test "captured model can be executed" do
      base =
        Axon.input("features", shape: {nil, 2})
        |> Axon.dense(4, name: "dense1", kernel_initializer: :ones, bias_initializer: :zeros)
        |> Axon.relu(name: "relu")
        |> Axon.dense(2, name: "dense2", kernel_initializer: :ones, bias_initializer: :zeros)

      # Truncate at the activation and rewire onto a fresh input
      capture_fn = Axon.capture(base, to: "relu")
      rewired = capture_fn.(Axon.input("x", shape: {nil, 2}))

      # Build and run the rewired graph end-to-end
      {init_fn, predict_fn} = Axon.build(rewired)
      params = init_fn.(Nx.template({1, 2}, :f32), Axon.ModelState.empty())

      # With a ones kernel: [1, 2] . ones(2, 4) = [3, 3, 3, 3]; relu keeps it positive
      result = predict_fn.(params, Nx.tensor([[1.0, 2.0]]))
      assert Nx.shape(result) == {1, 4}
    end
  end
end