-
Notifications
You must be signed in to change notification settings - Fork 123
Add Axon.capture
#626
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Add Axon.capture
#626
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -3962,6 +3962,190 @@ defmodule Axon do | |||||
| end | ||||||
| end | ||||||
|
|
||||||
| @doc """ | ||||||
| Captures a model or subgraph, returning a reusable function. | ||||||
|
|
||||||
| This returns an arity-1 function that accepts new inputs and returns | ||||||
| the model with rewired inputs. For single-input models, pass an Axon | ||||||
| graph directly. For multi-input models, pass a map of input names to | ||||||
| Axon graphs. | ||||||
|
|
||||||
| This is useful for transfer learning where you want to extract a | ||||||
| pretrained model's feature extractor: | ||||||
|
|
||||||
| # Load a pretrained model | ||||||
| resnet = MyModels.resnet50() | ||||||
|
|
||||||
| # Capture at the pooling layer (returns a function) | ||||||
| backbone = Axon.capture(resnet, to: "avg_pool") | ||||||
|
|
||||||
| # Use with new inputs | ||||||
| new_input = Axon.input("my_features") | ||||||
| features = backbone.(new_input) | ||||||
|
|
||||||
| # Add your own head | ||||||
| my_model = features | ||||||
| |> Axon.dense(256, activation: :relu) | ||||||
| |> Axon.dense(num_classes) | ||||||
|
|
||||||
| For models with multiple inputs, pass a map: | ||||||
|
|
||||||
| model = Axon.capture(multi_input_model) | ||||||
| output = model.(%{"image" => image_input, "text" => text_input}) | ||||||
|
|
||||||
| You can also capture an entire model without truncation: | ||||||
|
|
||||||
| encoder = Axon.capture(encoder_model) | ||||||
| encoded = encoder.(my_input) | ||||||
|
|
||||||
| Layer names can be discovered using `Axon.properties/1`: | ||||||
|
|
||||||
| Axon.properties(model) |> Map.keys() | ||||||
|
|
||||||
| ## Options | ||||||
|
|
||||||
| * `:to` - the name of the layer to capture as the output. If not | ||||||
| provided, captures the entire model. | ||||||
|
Comment on lines
+4005
to
+4008
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is a bit confusing without a more didactic example, with explicit layers and how they interact. For instance, what happens if you have a multi-output model and you capture |
||||||
|
|
||||||
| """ | ||||||
| @doc type: :graph | ||||||
| def capture(%Axon{} = axon, opts \\ []) when is_list(opts) do | ||||||
| truncated = | ||||||
| case Keyword.fetch(opts, :to) do | ||||||
| {:ok, layer_name} when is_binary(layer_name) -> | ||||||
| name_to_id = build_name_to_id_map(axon) | ||||||
|
|
||||||
| target_id = | ||||||
| case Map.fetch(name_to_id, layer_name) do | ||||||
| {:ok, id} -> | ||||||
| id | ||||||
|
|
||||||
| :error -> | ||||||
| available = name_to_id |> Map.keys() |> Enum.sort() | ||||||
|
|
||||||
| raise ArgumentError, | ||||||
| "layer #{inspect(layer_name)} not found in model. " <> | ||||||
| "Available layers: #{inspect(available)}" | ||||||
| end | ||||||
|
|
||||||
| %Axon{axon | output: target_id} | ||||||
|
|
||||||
| {:ok, other} -> | ||||||
| raise ArgumentError, | ||||||
| "expected :to option to be a string layer name, got: #{inspect(other)}" | ||||||
|
|
||||||
| :error -> | ||||||
| axon | ||||||
| end | ||||||
|
|
||||||
| input_name_to_id = get_input_name_to_id_map(truncated) | ||||||
|
|
||||||
| fn new_inputs -> | ||||||
| rewire_inputs(truncated, input_name_to_id, new_inputs) | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
| defp build_name_to_id_map(%Axon{output: id, nodes: nodes}) do | ||||||
| {name_to_id, _, _} = do_build_name_to_id(id, nodes, {%{}, %{}, %{}}) | ||||||
| name_to_id | ||||||
| end | ||||||
|
|
||||||
| defp do_build_name_to_id(id, nodes, {_name_to_id, cache, _op_counts} = acc) do | ||||||
| case cache do | ||||||
| %{^id => _} -> | ||||||
| acc | ||||||
|
|
||||||
| %{} -> | ||||||
| %Axon.Node{parent: parents, name: name_fn, op_name: op_name} = nodes[id] | ||||||
|
|
||||||
| {name_to_id, cache, op_counts} = | ||||||
| Enum.reduce(parents, acc, fn parent_id, acc -> | ||||||
| do_build_name_to_id(parent_id, nodes, acc) | ||||||
| end) | ||||||
|
|
||||||
| name = name_fn.(op_name, op_counts) | ||||||
| op_counts = Map.update(op_counts, op_name, 1, fn x -> x + 1 end) | ||||||
| name_to_id = Map.put(name_to_id, name, id) | ||||||
|
|
||||||
| {name_to_id, Map.put(cache, id, name), op_counts} | ||||||
| end | ||||||
| end | ||||||
|
|
||||||
| defp get_input_name_to_id_map(%Axon{output: id, nodes: nodes}) do | ||||||
| {inorder_nodes, _} = traverse_nodes(id, nodes, [], MapSet.new()) | ||||||
|
|
||||||
| inorder_nodes | ||||||
| |> Enum.filter(fn %Axon.Node{op: op} -> op == :input end) | ||||||
| |> Map.new(fn %Axon.Node{id: id, name: name_fn} -> | ||||||
| {name_fn.(:input, %{}), id} | ||||||
| end) | ||||||
| end | ||||||
|
|
||||||
| defp rewire_inputs(%Axon{output: output_id, nodes: nodes}, input_name_to_id, new_inputs) do | ||||||
| new_inputs_map = | ||||||
| case new_inputs do | ||||||
| %Axon{} = single_input -> | ||||||
| if map_size(input_name_to_id) != 1 do | ||||||
| raise ArgumentError, | ||||||
| "model has #{map_size(input_name_to_id)} inputs, expected a map " <> | ||||||
| "with keys: #{inspect(Map.keys(input_name_to_id))}" | ||||||
| end | ||||||
|
|
||||||
| [{input_name, _}] = Map.to_list(input_name_to_id) | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. (style suggestion) |
||||||
| %{input_name => single_input} | ||||||
|
|
||||||
| %{} = inputs_map -> | ||||||
| inputs_map | ||||||
|
|
||||||
| other -> | ||||||
| raise ArgumentError, | ||||||
| "expected an Axon graph or a map of input names to Axon graphs, " <> | ||||||
| "got: #{inspect(other)}" | ||||||
| end | ||||||
|
|
||||||
| # Validate all expected inputs are provided | ||||||
| expected_names = Map.keys(input_name_to_id) |> MapSet.new() | ||||||
| provided_names = Map.keys(new_inputs_map) |> MapSet.new() | ||||||
|
|
||||||
| missing = MapSet.difference(expected_names, provided_names) | ||||||
| extra = MapSet.difference(provided_names, expected_names) | ||||||
|
Comment on lines
+4110
to
+4111
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think you can use
|
||||||
|
|
||||||
| if MapSet.size(missing) > 0 do | ||||||
| raise ArgumentError, "missing inputs: #{inspect(MapSet.to_list(missing))}" | ||||||
| end | ||||||
|
|
||||||
| if MapSet.size(extra) > 0 do | ||||||
| raise ArgumentError, "unexpected inputs: #{inspect(MapSet.to_list(extra))}" | ||||||
| end | ||||||
|
|
||||||
| id_mapping = | ||||||
| Map.new(input_name_to_id, fn {name, old_id} -> | ||||||
| {old_id, new_inputs_map[name]} | ||||||
| end) | ||||||
|
|
||||||
| merged_nodes = | ||||||
| Enum.reduce(Map.values(new_inputs_map), nodes, fn %Axon{nodes: new_nodes}, acc_nodes -> | ||||||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
slightly more efficient |
||||||
| Map.merge(new_nodes, acc_nodes) | ||||||
| end) | ||||||
|
|
||||||
| merged_nodes = Map.drop(merged_nodes, Map.values(input_name_to_id)) | ||||||
|
|
||||||
| updated_nodes = | ||||||
| Map.new(merged_nodes, fn {node_id, node} -> | ||||||
| updated_parents = | ||||||
| Enum.map(node.parent, fn parent_id -> | ||||||
| case Map.fetch(id_mapping, parent_id) do | ||||||
| {:ok, %Axon{output: new_output_id}} -> new_output_id | ||||||
| :error -> parent_id | ||||||
| end | ||||||
| end) | ||||||
|
|
||||||
| {node_id, %{node | parent: updated_parents}} | ||||||
| end) | ||||||
|
|
||||||
| %Axon{output: output_id, nodes: updated_nodes} | ||||||
| end | ||||||
|
|
||||||
| @doc """ | ||||||
| Builds the given model to `{init_fn, predict_fn}`. | ||||||
|
|
||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -893,4 +893,132 @@ defmodule AxonTest do | |
| assert %Axon.None{} = Axon.get_output_shape(model, %{"values" => Nx.template({1, 1}, :f32)}) | ||
| end | ||
| end | ||
|
|
||
| describe "capture" do | ||
| test "captures entire model with single input" do | ||
| model = | ||
| Axon.input("features", shape: {nil, 10}) | ||
| |> Axon.dense(32, name: "dense1") | ||
| |> Axon.dense(16, name: "dense2") | ||
|
|
||
| captured = Axon.capture(model) | ||
| assert is_function(captured, 1) | ||
|
|
||
| # Use with new input | ||
| new_input = Axon.input("my_input", shape: {nil, 10}) | ||
| new_model = captured.(new_input) | ||
|
|
||
| # Verify new model has the new input | ||
| assert %{"my_input" => _} = Axon.get_inputs(new_model) | ||
| refute Map.has_key?(Axon.get_inputs(new_model), "features") | ||
|
|
||
| # Verify layers are preserved | ||
| props = Axon.properties(new_model) | ||
| assert Map.has_key?(props, "dense1") | ||
| assert Map.has_key?(props, "dense2") | ||
| end | ||
|
|
||
| test "captures up to specific layer with :to option" do | ||
| model = | ||
| Axon.input("features", shape: {nil, 10}) | ||
| |> Axon.dense(32, name: "hidden1") | ||
| |> Axon.relu(name: "relu1") | ||
| |> Axon.dense(16, name: "hidden2") | ||
| |> Axon.dense(2, name: "output") | ||
|
|
||
|
Comment on lines
+921
to
+928
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. this should be a doctest or at least an example in the docstring |
||
| captured = Axon.capture(model, to: "hidden2") | ||
|
|
||
| new_input = Axon.input("x", shape: {nil, 10}) | ||
| new_model = captured.(new_input) | ||
|
|
||
| props = Axon.properties(new_model) | ||
| assert Map.has_key?(props, "hidden1") | ||
| assert Map.has_key?(props, "relu1") | ||
| assert Map.has_key?(props, "hidden2") | ||
| refute Map.has_key?(props, "output") | ||
| end | ||
|
|
||
| test "works with multiple inputs using map" do | ||
| input1 = Axon.input("image", shape: {nil, 784}) | ||
| input2 = Axon.input("text", shape: {nil, 128}) | ||
|
|
||
| model = | ||
| Axon.concatenate(input1, input2) | ||
| |> Axon.dense(64, name: "combined") | ||
|
|
||
| captured = Axon.capture(model) | ||
|
|
||
| new_image = Axon.input("my_image", shape: {nil, 784}) | ||
| new_text = Axon.input("my_text", shape: {nil, 128}) | ||
|
|
||
| new_model = captured.(%{"image" => new_image, "text" => new_text}) | ||
|
|
||
| inputs = Axon.get_inputs(new_model) | ||
| assert Map.has_key?(inputs, "my_image") | ||
| assert Map.has_key?(inputs, "my_text") | ||
| refute Map.has_key?(inputs, "image") | ||
| refute Map.has_key?(inputs, "text") | ||
| end | ||
|
|
||
| test "raises on invalid layer name" do | ||
| model = | ||
| Axon.input("features", shape: {nil, 10}) | ||
| |> Axon.dense(32, name: "dense1") | ||
|
|
||
| assert_raise ArgumentError, ~r/layer "nonexistent" not found/, fn -> | ||
| Axon.capture(model, to: "nonexistent") | ||
| end | ||
| end | ||
|
|
||
| test "raises on missing inputs for multi-input model" do | ||
| input1 = Axon.input("a", shape: {nil, 10}) | ||
| input2 = Axon.input("b", shape: {nil, 10}) | ||
|
|
||
| model = Axon.add(input1, input2) |> Axon.dense(5) | ||
| captured = Axon.capture(model) | ||
|
|
||
| new_a = Axon.input("new_a", shape: {nil, 10}) | ||
|
|
||
| assert_raise ArgumentError, ~r/missing inputs/, fn -> | ||
| captured.(%{"a" => new_a}) | ||
| end | ||
| end | ||
|
|
||
| test "raises when single input passed to multi-input model" do | ||
| input1 = Axon.input("a", shape: {nil, 10}) | ||
| input2 = Axon.input("b", shape: {nil, 10}) | ||
|
|
||
| model = Axon.add(input1, input2) | ||
| captured = Axon.capture(model) | ||
|
|
||
| single_input = Axon.input("x", shape: {nil, 10}) | ||
|
|
||
| assert_raise ArgumentError, ~r/model has 2 inputs/, fn -> | ||
| captured.(single_input) | ||
| end | ||
| end | ||
|
|
||
| test "captured model can be executed" do | ||
| model = | ||
| Axon.input("features", shape: {nil, 2}) | ||
| |> Axon.dense(4, name: "dense1", kernel_initializer: :ones, bias_initializer: :zeros) | ||
| |> Axon.relu(name: "relu") | ||
| |> Axon.dense(2, name: "dense2", kernel_initializer: :ones, bias_initializer: :zeros) | ||
|
|
||
| # Capture up to relu | ||
| captured = Axon.capture(model, to: "relu") | ||
| new_input = Axon.input("x", shape: {nil, 2}) | ||
| new_model = captured.(new_input) | ||
|
|
||
| # Build and run | ||
| {init_fn, predict_fn} = Axon.build(new_model) | ||
| params = init_fn.(Nx.template({1, 2}, :f32), Axon.ModelState.empty()) | ||
|
|
||
| input = Nx.tensor([[1.0, 2.0]]) | ||
| result = predict_fn.(params, input) | ||
|
|
||
| # With ones kernel: [1,2] dot ones(2,4) = [3,3,3,3], relu keeps it positive | ||
| assert Nx.shape(result) == {1, 4} | ||
| end | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. there's a test missing for the multi-output case I mentioned in the doc coment |
||
| end | ||
| end | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.