Add docs for serialization with Nx.serialize #630
base: main
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
| @@ -0,0 +1,187 @@ | ||||||
| # Saving and Loading Models | ||||||
|
|
||||||
| ## Section | ||||||
|
|
||||||
| ```elixir | ||||||
| Mix.install([ | ||||||
| {:axon, ">= 0.8.0"} | ||||||
| ]) | ||||||
| ``` | ||||||
|
|
||||||
| ## Overview | ||||||
|
|
||||||
| Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach: | ||||||
|
|
||||||
| * Avoids serialization issues with anonymous functions and complex model structures | ||||||
| * Makes the model structure explicit and version-controlled in code | ||||||
| * Works reliably across processes and deployments | ||||||
|
|
||||||
| The model itself is just code—you define it once and reuse it. Only the learned parameters need to be persisted. | ||||||
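As a minimal sketch of the round trip described above, the snippet below serializes a hand-built nested map of tensors and reads it back. The layer names, shapes, and the `params_example.nx` file name are made up for illustration; real parameters would come from a trained model state. It assumes the same `Mix.install/1` setup as the guide's first cell (Axon brings Nx in as a dependency) and the default binary backend.

```elixir
# Same dependency list as the guide's setup cell.
Mix.install([{:axon, ">= 0.8.0"}])

# Hypothetical parameters shaped like an Axon parameter map:
# layer name => parameter name => tensor.
params = %{
  "dense_0" => %{
    "kernel" => Nx.tensor([[0.5, -0.25]]),
    "bias" => Nx.tensor([0.0, 0.1])
  }
}

# File.write!/2 accepts the serialized output directly.
path = Path.join(System.tmp_dir!(), "params_example.nx")
File.write!(path, Nx.serialize(params))

# Nx.deserialize/2 rebuilds the same nested container.
loaded = path |> File.read!() |> Nx.deserialize()

IO.inspect(Map.keys(loaded["dense_0"]), label: "restored parameters")
```

Nothing about the model architecture is stored in the file, which is why the definition has to stay available in code.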
|
|
||||||
|
|
||||||
| ## Saving a Model After Training | ||||||
|
|
||||||
| When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them: | ||||||
|
|
||||||
| ```elixir | ||||||
| model = | ||||||
| Axon.input("data") | ||||||
| |> Axon.dense(8) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(4) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(1) | ||||||
|
|
||||||
| loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd) | ||||||
|
|
||||||
| train_data = | ||||||
| Stream.repeatedly(fn -> | ||||||
| {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1}) | ||||||
| {xs, Nx.sin(xs)} | ||||||
| end) | ||||||
|
|
||||||
| trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100) | ||||||
| ``` | ||||||
|
|
||||||
| `Axon.Loop.run/4` returns the trained model state by default for loops built with `Axon.Loop.trainer/3`. Inference only needs the parameters, so extract the `data` field from the `Axon.ModelState` struct: | ||||||
|
|
||||||
|
|
||||||
| ```elixir | ||||||
| # Extract parameters - ModelState.data contains the nested map of weights | ||||||
|
|
||||||
| params = | ||||||
| case trained_model_state do | ||||||
| %Axon.ModelState{data: data} -> data | ||||||
| params when is_map(params) -> params | ||||||
| end | ||||||
|
Contributor comment on lines +49 to +53: Is this case really necessary?
|
||||||
|
|
||||||
| # Transfer to binary backend for reliable serialization (avoids issues with Nx 0.11+ on EXLA/Torchx) | ||||||
| params = Nx.backend_transfer(params) | ||||||
|
Contributor comment on lines +55 to +56: Which issues would these be? We should not need to transfer it to serialize. If there are issues, that's a bug in Nx 0.11.
|
||||||
|
|
||||||
| # Serialize and save | ||||||
| params_bytes = Nx.serialize(params) | ||||||
| File.write!("model_params.axon", params_bytes) | ||||||
| ``` | ||||||
|
|
||||||
| ## Loading a Model for Inference | ||||||
|
|
||||||
| To load and run inference, you need: | ||||||
|
|
||||||
| 1. The model definition (in code—the same structure you trained) | ||||||
| 2. The saved parameters | ||||||
|
|
||||||
| ```elixir | ||||||
| # 1. Define the same model structure (must match training) | ||||||
| model = | ||||||
| Axon.input("data") | ||||||
| |> Axon.dense(8) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(4) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(1) | ||||||
|
|
||||||
| # 2. Load parameters | ||||||
| params = File.read!("model_params.axon") |> Nx.deserialize() | ||||||
|
|
||||||
| # 3. Run inference | ||||||
| input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input) | ||||||
| Axon.predict(model, params, %{"data" => input}) | ||||||
| ``` | ||||||
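Since the loaded parameters only make sense with a matching architecture, one option is to define the model once in a module function and call it from both the training and the inference code. This is only a sketch: `MyApp.Model` and `build/0` are hypothetical names that do not appear in the guide, and the dependency list simply mirrors the guide's setup cell.

```elixir
# Same dependency list as the guide's setup cell.
Mix.install([{:axon, ">= 0.8.0"}])

defmodule MyApp.Model do
  # Hypothetical module: a single source of truth for the architecture,
  # shared by training and inference code.
  def build do
    Axon.input("data")
    |> Axon.dense(8)
    |> Axon.relu()
    |> Axon.dense(4)
    |> Axon.relu()
    |> Axon.dense(1)
  end
end

# Training and inference both start from the same definition.
model = MyApp.Model.build()
```

With this layout, a change to the architecture happens in one place and shows up in version control next to the code that saves and loads the parameters.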
|
|
||||||
| ## Checkpointing During Training | ||||||
|
|
||||||
| To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later. | ||||||
|
|
||||||
| ```elixir | ||||||
| model = | ||||||
| Axon.input("data") | ||||||
| |> Axon.dense(8) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(1) | ||||||
|
|
||||||
| loop = | ||||||
| model | ||||||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd) | ||||||
| |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed) | ||||||
|
|
||||||
| train_data = | ||||||
| Stream.repeatedly(fn -> | ||||||
| {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1}) | ||||||
| {xs, Nx.sin(xs)} | ||||||
| end) | ||||||
|
|
||||||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50) | ||||||
| ``` | ||||||
|
|
||||||
| Checkpoints are saved to the `checkpoints/` directory. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`. | ||||||
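Because each checkpoint is its own file, a resume script usually has to pick one. A minimal sketch, assuming you want the most recently written `.ckpt` file: the `LatestCheckpoint` module below is a hypothetical helper, not part of Axon, and it relies only on `File` and `Path` from the standard library.

```elixir
defmodule LatestCheckpoint do
  # Hypothetical helper, not part of Axon.
  # Returns {:ok, path} for the most recently modified *.ckpt file in `dir`,
  # or :error when the directory has no checkpoint files.
  def find(dir) do
    dir
    |> Path.join("*.ckpt")
    |> Path.wildcard()
    |> Enum.max_by(fn path -> File.stat!(path, time: :posix).mtime end, fn -> nil end)
    |> case do
      nil -> :error
      path -> {:ok, path}
    end
  end
end

case LatestCheckpoint.find("checkpoints") do
  {:ok, path} -> IO.puts("latest checkpoint: #{path}")
  :error -> IO.puts("no checkpoints found")
end
```

Sorting by modification time avoids depending on how the checkpoint file names are formatted; if files may be copied or touched after training, parsing the epoch and iteration out of the name would be the safer choice.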
|
|
||||||
|
|
||||||
| ## Resuming from a Checkpoint | ||||||
|
|
||||||
| To resume training from a saved checkpoint: | ||||||
|
|
||||||
| 1. Load the checkpoint with `Axon.Loop.deserialize_state/2` | ||||||
| 2. Attach it to your loop with `Axon.Loop.from_state/2` | ||||||
| 3. Run the loop as usual | ||||||
|
|
||||||
| ```elixir | ||||||
| # Load the checkpoint (use the path from your checkpoint files) | ||||||
| checkpoint_path = "checkpoints/checkpoint_2_50.ckpt" | ||||||
| serialized = File.read!(checkpoint_path) | ||||||
| state = Axon.Loop.deserialize_state(serialized) | ||||||
|
|
||||||
| # Resume training | ||||||
| model = | ||||||
| Axon.input("data") | ||||||
| |> Axon.dense(8) | ||||||
| |> Axon.relu() | ||||||
| |> Axon.dense(1) | ||||||
|
|
||||||
|
Contributor comment on lines +115 to +135: @seanmor5 I think there's a bit of a dissonance between not having Axon.serialize/deserialize, while checkpoints need their Axon functions. WDYT?
|
||||||
| loop = | ||||||
| model | ||||||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd) | ||||||
| |> Axon.Loop.from_state(state) | ||||||
|
|
||||||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50) | ||||||
| ``` | ||||||
|
|
||||||
| ## Saving Only Parameters from a Checkpoint | ||||||
|
|
||||||
| If you have a checkpoint file and want to extract parameters for inference (without optimizer state): | ||||||
|
|
||||||
| ```elixir | ||||||
| checkpoint_path = "checkpoints/checkpoint_2_50.ckpt" | ||||||
| state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state() | ||||||
|
|
||||||
| # Extract model parameters from step_state | ||||||
| %{model_state: model_state} = state.step_state | ||||||
| params = model_state.data |> Nx.backend_transfer() | ||||||
|
Contributor comment: Same comment about needing backend_transfer
|
||||||
|
|
||||||
| # Save for inference | ||||||
| File.write!("model_params.axon", Nx.serialize(params)) | ||||||
| ``` | ||||||
|
|
||||||
| ## Troubleshooting: ArgumentError with Nx.serialize | ||||||
|
|
||||||
| If you see `(ArgumentError) argument error` or `:erlang.++` errors when calling `Nx.serialize/2` on parameters (common with Nx 0.11+ and EXLA/Torchx backends), transfer tensors to the binary backend first: | ||||||
|
|
||||||
| ```elixir | ||||||
| params = Nx.backend_transfer(params) | ||||||
| params_bytes = Nx.serialize(params) | ||||||
| ``` | ||||||
|
|
||||||
| If serialization still fails, you can fall back to `:erlang.term_to_binary/1` once the parameters are on the binary backend (e.g. after `Nx.backend_transfer/1`): | ||||||
|
|
||||||
| ```elixir | ||||||
| params = Nx.backend_transfer(params) | ||||||
| params_bytes = :erlang.term_to_binary(params) | ||||||
| File.write!("model_params.axon", params_bytes) | ||||||
|
|
||||||
| # To load: | ||||||
| params = File.read!("model_params.axon") |> :erlang.binary_to_term([:safe]) | ||||||
| ``` | ||||||
|
Contributor comment on lines +160 to +178: We should fix this and release 0.11.1 so we can merge this PR without these bug-related caveats.
|
||||||
|
|
||||||
| ## Summary | ||||||
|
|
||||||
| | Use Case | Save | Load | | ||||||
| | ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- | | ||||||
| | Inference only | `Nx.serialize(params)` → file | `Nx.deserialize(file)` + model in code | | ||||||
| | Checkpoint (resume training) | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` | | ||||||
| | Extract params from checkpoint | `state.step_state.model_state.data` → `Nx.serialize` | Use with model in code | | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,131 @@ | ||
| defmodule Axon.SerializationGuideTest do | ||
| @moduledoc """ | ||
| Tests that validate the examples in guides/serialization/saving_and_loading.livemd. | ||
| Run with: mix test test/axon/serialization_guide_test.exs | ||
| """ | ||
| use Axon.Case, async: false | ||
|
|
||
| @tmp_path Path.join(System.tmp_dir!(), "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}") | ||
|
|
||
| setup do | ||
| File.mkdir_p!(@tmp_path) | ||
| on_exit(fn -> File.rm_rf!(@tmp_path) end) | ||
| [tmp_path: @tmp_path] | ||
| end | ||
|
|
||
| describe "saving and loading guide examples" do | ||
| test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do | ||
| # Same model as the guide | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(8) | ||
| |> Axon.relu() | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0) | ||
|
|
||
| train_data = | ||
| Stream.repeatedly(fn -> | ||
| {xs, _} = Nx.Random.normal(Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])})), shape: {8, 1}) | ||
| {xs, Nx.sin(xs)} | ||
| end) | ||
|
|
||
| # Train | ||
| trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50) | ||
|
|
||
| # Extract and save params (as in guide) | ||
| params = | ||
| case trained_model_state do | ||
| %Axon.ModelState{data: data} -> data | ||
| params when is_map(params) -> params | ||
| end | ||
|
|
||
| params_path = Path.join(tmp_path, "model_params.axon") | ||
| params = Nx.backend_transfer(params) | ||
| params_bytes = Nx.serialize(params) | ||
| File.write!(params_path, params_bytes) | ||
|
|
||
| # Load and predict (input shape must match training: {batch, 1} for 1 feature) | ||
| loaded_params = File.read!(params_path) |> Nx.deserialize() | ||
| input = Nx.tensor([[1.0]]) | ||
|
|
||
| prediction = Axon.predict(model, loaded_params, %{"data" => input}) | ||
|
|
||
| assert Nx.rank(prediction) == 2 | ||
| assert Nx.shape(prediction) == {1, 1} | ||
| end | ||
|
|
||
| test "checkpoint and resume flow", %{tmp_path: tmp_path} do | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(4) | ||
| |> Axon.relu() | ||
| |> Axon.dense(1) | ||
|
|
||
| checkpoint_path = Path.join(tmp_path, "checkpoints") | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) | ||
| |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed) | ||
|
|
||
| train_data = [ | ||
| {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])}, | ||
| {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])} | ||
| ] | ||
|
|
||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2) | ||
|
|
||
| # Verify checkpoint was saved | ||
| ckpt_files = File.ls!(checkpoint_path) |> Enum.sort() | ||
| assert length(ckpt_files) == 2 | ||
| assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_")) | ||
|
|
||
| # Load checkpoint and extract params for inference | ||
| ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files)) | ||
| state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state() | ||
|
|
||
| %{model_state: model_state} = state.step_state | ||
| params = model_state.data | ||
|
|
||
| # Run inference with extracted params | ||
| input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]]) | ||
| prediction = Axon.predict(model, params, %{"data" => input}) | ||
|
|
||
| assert Nx.rank(prediction) == 2 | ||
| assert Nx.shape(prediction) == {1, 1} | ||
| end | ||
|
|
||
| test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do | ||
| model = | ||
| Axon.input("data") | ||
| |> Axon.dense(2) | ||
| |> Axon.dense(1) | ||
|
|
||
| checkpoint_path = Path.join(tmp_path, "checkpoints_resume") | ||
|
|
||
| loop = | ||
| model | ||
| |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) | ||
| |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed) | ||
|
|
||
| train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}] | ||
|
|
||
| # Run for 1 epoch | ||
| Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1) | ||
|
|
||
| # Load checkpoint and resume | ||
| [ckpt_file] = File.ls!(checkpoint_path) | ||
| state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state() | ||
|
|
||
| resumed_loop = model |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) |> Axon.Loop.from_state(state) | ||
|
|
||
| # Resume - should complete without error | ||
| result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2) | ||
|
|
||
| assert %Axon.ModelState{} = result | ||
| end | ||
| end | ||
| end |