From 38be77019d09ae220b9ade95bf0165400c11bcc1 Mon Sep 17 00:00:00 2001
From: Andrew Phillipo
Date: Thu, 5 Mar 2026 12:56:30 +0000
Subject: [PATCH] Add docs for serialization with NX.serialize

---
 CHANGELOG.md                                  |   4 +
 guides/guides.md                              |   2 +-
 .../serialization/saving_and_loading.livemd   | 187 ++++++++++++++++++
 lib/axon/loop.ex                              |   2 +-
 test/axon/serialization_guide_test.exs        | 131 ++++++++++++
 5 files changed, 324 insertions(+), 2 deletions(-)
 create mode 100644 guides/serialization/saving_and_loading.livemd
 create mode 100644 test/axon/serialization_guide_test.exs

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f57d4340e..74460526e 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -15,6 +15,10 @@
 
 ## v0.7.0 (2024-10-08)
 
+### Breaking Changes
+
+* **Removed `Axon.serialize/2` and `Axon.deserialize/2`** — Use `Nx.serialize/2` and `Nx.deserialize/2` for parameters instead. Axon recommends serializing only the trained parameters (weights) and keeping the model definition in code. See the [Saving and Loading](guides/serialization/saving_and_loading.livemd) guide.
+
 ### Bug Fixes
 
 * Do not cast integers in in Axon.MixedPrecision.cast/2
diff --git a/guides/guides.md b/guides/guides.md
index 4f3a32eac..8c0bbc7e0 100644
--- a/guides/guides.md
+++ b/guides/guides.md
@@ -28,5 +28,5 @@ Axon is a library for creating and training neural networks in Elixir. The Axon
 
 ## Serialization
 
-* [Converting ONNX models to Axon](serialization/onnx_to_axon.livemd)
+* [Saving and loading models](serialization/saving_and_loading.livemd)
diff --git a/guides/serialization/saving_and_loading.livemd b/guides/serialization/saving_and_loading.livemd
new file mode 100644
index 000000000..73fb1b186
--- /dev/null
+++ b/guides/serialization/saving_and_loading.livemd
@@ -0,0 +1,187 @@
+# Saving and Loading Models
+
+## Section
+
+```elixir
+Mix.install([
+  {:axon, ">= 0.8.0"}
+])
+```
+
+## Overview
+
+Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach:
+
+* Avoids serialization issues with anonymous functions and complex model structures
+* Makes the model structure explicit and version-controlled in code
+* Works reliably across processes and deployments
+
+The model itself is just code—you define it once and reuse it. Only the learned parameters need to be persisted.
+
+## Saving a Model After Training
+
+When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them:
+
+```elixir
+model =
+  Axon.input("data")
+  |> Axon.dense(8)
+  |> Axon.relu()
+  |> Axon.dense(4)
+  |> Axon.relu()
+  |> Axon.dense(1)
+
+loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)
+
+train_data =
+  Stream.repeatedly(fn ->
+    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
+    {xs, Nx.sin(xs)}
+  end)
+
+trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
+```
+
+The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `.data` field from `ModelState`:
+
+```elixir
+# Extract parameters - ModelState.data contains the nested map of weights
+params =
+  case trained_model_state do
+    %Axon.ModelState{data: data} -> data
+    params when is_map(params) -> params
+  end
+
+# Transfer to binary backend for reliable serialization (avoids issues with Nx 0.11+ on EXLA/Torchx)
+params = Nx.backend_transfer(params)
+
+# Serialize and save
+params_bytes = Nx.serialize(params)
+File.write!("model_params.axon", params_bytes)
+```
+
+## Loading a Model for Inference
+
+To load and run inference, you need:
+
+1. The model definition (in code—the same structure you trained)
+2. The saved parameters
+
+```elixir
+# 1. Define the same model structure (must match training)
+model =
+  Axon.input("data")
+  |> Axon.dense(8)
+  |> Axon.relu()
+  |> Axon.dense(4)
+  |> Axon.relu()
+  |> Axon.dense(1)
+
+# 2. Load parameters
+params = File.read!("model_params.axon") |> Nx.deserialize()
+
+# 3. Run inference
+input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input)
+Axon.predict(model, params, %{"data" => input})
+```
+
+## Checkpointing During Training
+
+To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later.
+
+```elixir
+model =
+  Axon.input("data")
+  |> Axon.dense(8)
+  |> Axon.relu()
+  |> Axon.dense(1)
+
+loop =
+  model
+  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
+  |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)
+
+train_data =
+  Stream.repeatedly(fn ->
+    {xs, _} = Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})
+    {xs, Nx.sin(xs)}
+  end)
+
+Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
+```
+
+Checkpoints are saved to the `checkpoints/` directory. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
+
+## Resuming from a Checkpoint
+
+To resume training from a saved checkpoint:
+
+1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
+2. Attach it to your loop with `Axon.Loop.from_state/2`
+3. Run the loop as usual
+
+```elixir
+# Load the checkpoint (use the path from your checkpoint files)
+checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
+serialized = File.read!(checkpoint_path)
+state = Axon.Loop.deserialize_state(serialized)
+
+# Resume training
+model =
+  Axon.input("data")
+  |> Axon.dense(8)
+  |> Axon.relu()
+  |> Axon.dense(1)
+
+loop =
+  model
+  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
+  |> Axon.Loop.from_state(state)
+
+Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
+```
+
+## Saving Only Parameters from a Checkpoint
+
+If you have a checkpoint file and want to extract parameters for inference (without optimizer state):
+
+```elixir
+checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
+state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()
+
+# Extract model parameters from step_state
+%{model_state: model_state} = state.step_state
+params = model_state.data |> Nx.backend_transfer()
+
+# Save for inference
+File.write!("model_params.axon", Nx.serialize(params))
+```
+
+## Troubleshooting: ArgumentError with Nx.serialize
+
+If you see `(ArgumentError) argument error` or `:erlang.++` errors when calling `Nx.serialize/2` on parameters (common with Nx 0.11+ and EXLA/Torchx backends), transfer tensors to the binary backend first:
+
+```elixir
+params = Nx.backend_transfer(params)
+params_bytes = Nx.serialize(params)
+```
+
+If serialization still fails, you can use `:erlang.term_to_binary/2` when parameters are on the binary backend (e.g. after `Nx.backend_transfer/1`):
+
+```elixir
+params = Nx.backend_transfer(params)
+params_bytes = :erlang.term_to_binary(params)
+File.write!("model_params.axon", params_bytes)
+
+# To load:
+params = File.read!("model_params.axon") |> :erlang.binary_to_term([:safe])
+```
+
+## Summary
+
+| Use Case                       | Save                                                      | Load                                                       |
+| ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- |
+| Inference only                 | `Nx.serialize(params)` → file                             | `Nx.deserialize(file)` + model in code                     |
+| Checkpoint (resume training)   | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` |
+| Extract params from checkpoint | `state.step_state.model_state.data` → `Nx.serialize`      | Use with model in code                                     |
+
diff --git a/lib/axon/loop.ex b/lib/axon/loop.ex
index 1397c06c1..08456f351 100644
--- a/lib/axon/loop.ex
+++ b/lib/axon/loop.ex
@@ -1511,7 +1511,7 @@ defmodule Axon.Loop do
 
   It is the opposite of `Axon.Loop.serialize_state/2`.
 
-  By default, the step state is deserialized using `Nx.deserialize.2`;
+  By default, the step state is deserialized using `Nx.deserialize/2`;
   however, this behavior can be changed if step state is an application
   specific container. For example, if you introduce your own data
   structure into step_state and you customized the serialization logic,
diff --git a/test/axon/serialization_guide_test.exs b/test/axon/serialization_guide_test.exs
new file mode 100644
index 000000000..d4822fb7c
--- /dev/null
+++ b/test/axon/serialization_guide_test.exs
@@ -0,0 +1,131 @@
+defmodule Axon.SerializationGuideTest do
+  @moduledoc """
+  Tests that validate the examples in guides/serialization/saving_and_loading.livemd.
+  Run with: mix test test/axon/serialization_guide_test.exs
+  """
+  use Axon.Case, async: false
+
+  @tmp_path Path.join(System.tmp_dir!(), "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}")
+
+  setup do
+    File.mkdir_p!(@tmp_path)
+    on_exit(fn -> File.rm_rf!(@tmp_path) end)
+    [tmp_path: @tmp_path]
+  end
+
+  describe "saving and loading guide examples" do
+    test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do
+      # Same model as the guide
+      model =
+        Axon.input("data")
+        |> Axon.dense(8)
+        |> Axon.relu()
+        |> Axon.dense(4)
+        |> Axon.relu()
+        |> Axon.dense(1)
+
+      loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0)
+
+      train_data =
+        Stream.repeatedly(fn ->
+          {xs, _} = Nx.Random.normal(Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])})), shape: {8, 1})
+          {xs, Nx.sin(xs)}
+        end)
+
+      # Train
+      trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50)
+
+      # Extract and save params (as in guide)
+      params =
+        case trained_model_state do
+          %Axon.ModelState{data: data} -> data
+          params when is_map(params) -> params
+        end
+
+      params_path = Path.join(tmp_path, "model_params.axon")
+      params = Nx.backend_transfer(params)
+      params_bytes = Nx.serialize(params)
+      File.write!(params_path, params_bytes)
+
+      # Load and predict (input shape must match training: {batch, 1} for 1 feature)
+      loaded_params = File.read!(params_path) |> Nx.deserialize()
+      input = Nx.tensor([[1.0]])
+
+      prediction = Axon.predict(model, loaded_params, %{"data" => input})
+
+      assert Nx.rank(prediction) == 2
+      assert Nx.shape(prediction) == {1, 1}
+    end
+
+    test "checkpoint and resume flow", %{tmp_path: tmp_path} do
+      model =
+        Axon.input("data")
+        |> Axon.dense(4)
+        |> Axon.relu()
+        |> Axon.dense(1)
+
+      checkpoint_path = Path.join(tmp_path, "checkpoints")
+
+      loop =
+        model
+        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
+        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)
+
+      train_data = [
+        {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])},
+        {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])}
+      ]
+
+      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2)
+
+      # Verify checkpoint was saved
+      ckpt_files = File.ls!(checkpoint_path) |> Enum.sort()
+      assert length(ckpt_files) == 2
+      assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_"))
+
+      # Load checkpoint and extract params for inference
+      ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files))
+      state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state()
+
+      %{model_state: model_state} = state.step_state
+      params = model_state.data
+
+      # Run inference with extracted params
+      input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
+      prediction = Axon.predict(model, params, %{"data" => input})
+
+      assert Nx.rank(prediction) == 2
+      assert Nx.shape(prediction) == {1, 1}
+    end
+
+    test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do
+      model =
+        Axon.input("data")
+        |> Axon.dense(2)
+        |> Axon.dense(1)
+
+      checkpoint_path = Path.join(tmp_path, "checkpoints_resume")
+
+      loop =
+        model
+        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
+        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)
+
+      train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}]
+
+      # Run for 1 epoch
+      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1)
+
+      # Load checkpoint and resume
+      [ckpt_file] = File.ls!(checkpoint_path)
+      state = File.read!(Path.join(checkpoint_path, ckpt_file)) |> Axon.Loop.deserialize_state()
+
+      resumed_loop = model |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0) |> Axon.Loop.from_state(state)
+
+      # Resume - should complete without error
+      result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2)
+
+      assert %Axon.ModelState{} = result
+    end
+  end
+end