4 changes: 4 additions & 0 deletions CHANGELOG.md
@@ -15,6 +15,10 @@

## v0.7.0 (2024-10-08)

### Breaking Changes

* **Removed `Axon.serialize/2` and `Axon.deserialize/2`** — Use `Nx.serialize/2` and `Nx.deserialize/2` for parameters instead. Axon recommends serializing only the trained parameters (weights) and keeping the model definition in code. See the [Saving and Loading](guides/serialization/saving_and_loading.livemd) guide.

### Bug Fixes

* Do not cast integers in `Axon.MixedPrecision.cast/2`
2 changes: 1 addition & 1 deletion guides/guides.md
@@ -28,5 +28,5 @@ Axon is a library for creating and training neural networks in Elixir. The Axon

## Serialization

* [Converting ONNX models to Axon](serialization/onnx_to_axon.livemd)
* [Saving and loading models](serialization/saving_and_loading.livemd)

187 changes: 187 additions & 0 deletions guides/serialization/saving_and_loading.livemd
@@ -0,0 +1,187 @@
# Saving and Loading Models

## Section

```elixir
Mix.install([
  {:axon, "~> 0.8"}
])
```

## Overview

Axon recommends a **parameters-only** approach to saving models: serialize only the trained parameters (weights) using `Nx.serialize/2` and `Nx.deserialize/2`, and keep the model definition in your code. This approach:

* Avoids serialization issues with anonymous functions and complex model structures
* Makes the model structure explicit and version-controlled in code
* Works reliably across processes and deployments

The model itself is just code—you define it once and reuse it. Only the learned parameters need to be persisted.
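The parameters-only round trip described above can be sketched in a few lines (a minimal sketch, assuming `model`, a trained `params` map, and an `input` tensor already exist; the file name is illustrative):

```elixir
# Persist only the trained parameters; the model definition stays in code
File.write!("model_params.axon", Nx.serialize(params))

# Later, possibly in another process or deployment: load the bytes back
# into a params map and reuse the model defined in code
params = "model_params.axon" |> File.read!() |> Nx.deserialize()
Axon.predict(model, params, %{"data" => input})
```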
## Saving a Model After Training

When you run a training loop, it returns the trained model state by default. Extract the parameters and serialize them:

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} =
      Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

trained_model_state =
  Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
```

The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, extract the parameters from the `data` field of `ModelState`:

```elixir
# Extract parameters - trained_model_state.data contains the nested map of weights
params =
  case trained_model_state do
    %Axon.ModelState{data: data} -> data
    params when is_map(params) -> params
  end

# Transfer to binary backend for reliable serialization (avoids issues with Nx 0.11+ on EXLA/Torchx)
params = Nx.backend_transfer(params)

# Serialize and save
params_bytes = Nx.serialize(params)
File.write!("model_params.axon", params_bytes)
```

> **Review comment on lines +49 to +53 (the `case` expression):** Is this case really necessary?

> **Review comment on lines +55 to +56 (`Nx.backend_transfer/1`):** Which issues would these be? We should not need to transfer it to serialize. If there are issues, that's a bug in Nx 0.11.

## Loading a Model for Inference

To load and run inference, you need:

1. The model definition (in code—the same structure you trained)
2. The saved parameters

```elixir
# 1. Define the same model structure (must match training)
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(4)
  |> Axon.relu()
  |> Axon.dense(1)

# 2. Load parameters
params = File.read!("model_params.axon") |> Nx.deserialize()

# 3. Run inference
input = Nx.tensor([[1.0]]) # shape {1, 1}: 1 sample with 1 feature (matches model input)
Axon.predict(model, params, %{"data" => input})
```

## Checkpointing During Training

To save checkpoints during training (e.g., every epoch or when validation improves), use `Axon.Loop.checkpoint/2`. This serializes the full loop state—including model parameters and optimizer state—so you can resume training later.

```elixir
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.checkpoint(path: "checkpoints", event: :epoch_completed)

train_data =
  Stream.repeatedly(fn ->
    {xs, _} =
      Nx.Random.key(System.unique_integer([:positive])) |> Nx.Random.normal(shape: {8, 1})

    {xs, Nx.sin(xs)}
  end)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```

Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.


## Resuming from a Checkpoint

To resume training from a saved checkpoint:

1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
2. Attach it to your loop with `Axon.Loop.from_state/2`
3. Run the loop as usual

```elixir
# Load the checkpoint (use the path from your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Resume training
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  |> Axon.Loop.from_state(state)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 5, iterations: 50)
```

> **Review comment on lines +115 to +135:** @seanmor5 I think there's a bit of a dissonance between not having Axon.serialize/deserialize, while checkpoints need their Axon functions. WDYT?

## Saving Only Parameters from a Checkpoint

If you have a checkpoint file and want to extract parameters for inference (without optimizer state):

```elixir
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
state = File.read!(checkpoint_path) |> Axon.Loop.deserialize_state()

# Extract model parameters from step_state
%{model_state: model_state} = state.step_state
params = model_state.data |> Nx.backend_transfer()

# Save for inference
File.write!("model_params.axon", Nx.serialize(params))
```

> **Review comment on `Nx.backend_transfer/1`:** Same comment about needing backend_transfer.

## Troubleshooting: ArgumentError with Nx.serialize

If you see `(ArgumentError) argument error` or `:erlang.++` errors when calling `Nx.serialize/2` on parameters (common with Nx 0.11+ and EXLA/Torchx backends), transfer tensors to the binary backend first:

```elixir
params = Nx.backend_transfer(params)
params_bytes = Nx.serialize(params)
```

If serialization still fails, you can use `:erlang.term_to_binary/2` when parameters are on the binary backend (e.g. after `Nx.backend_transfer/1`):

```elixir
params = Nx.backend_transfer(params)
params_bytes = :erlang.term_to_binary(params)
File.write!("model_params.axon", params_bytes)

# To load:
params = File.read!("model_params.axon") |> :erlang.binary_to_term([:safe])
```
> **Review comment on lines +160 to +178:** We should fix this and release 0.11.1 so we can merge this PR without these bug-related caveats.

## Summary

| Use Case | Save | Load |
| ------------------------------ | --------------------------------------------------------- | ---------------------------------------------------------- |
| Inference only | `Nx.serialize(params)` → file | `Nx.deserialize(file)` + model in code |
| Checkpoint (resume training) | `Axon.Loop.checkpoint/2` or `Axon.Loop.serialize_state/2` | `Axon.Loop.deserialize_state/2` + `Axon.Loop.from_state/2` |
| Extract params from checkpoint | `state.step_state.model_state.data` → `Nx.serialize` | Use with model in code |

2 changes: 1 addition & 1 deletion lib/axon/loop.ex
@@ -1511,7 +1511,7 @@ defmodule Axon.Loop do

It is the opposite of `Axon.Loop.serialize_state/2`.

-By default, the step state is deserialized using `Nx.deserialize.2`;
+By default, the step state is deserialized using `Nx.deserialize/2`;
however, this behavior can be changed if step state is an application
specific container. For example, if you introduce your own data
structure into step_state and you customized the serialization logic,
131 changes: 131 additions & 0 deletions test/axon/serialization_guide_test.exs
@@ -0,0 +1,131 @@
defmodule Axon.SerializationGuideTest do
  @moduledoc """
  Tests that validate the examples in guides/serialization/saving_and_loading.livemd.
  Run with: mix test test/axon/serialization_guide_test.exs
  """
  use Axon.Case, async: false

  @tmp_path Path.join(
              System.tmp_dir!(),
              "axon_serialization_guide_test_#{:erlang.unique_integer([:positive])}"
            )

  setup do
    File.mkdir_p!(@tmp_path)
    on_exit(fn -> File.rm_rf!(@tmp_path) end)
    [tmp_path: @tmp_path]
  end

  describe "saving and loading guide examples" do
    test "full flow: train → save params → load → predict", %{tmp_path: tmp_path} do
      # Same model as the guide
      model =
        Axon.input("data")
        |> Axon.dense(8)
        |> Axon.relu()
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      loop = Axon.Loop.trainer(model, :mean_squared_error, :sgd, log: 0)

      train_data =
        Stream.repeatedly(fn ->
          key = Nx.Random.key(:erlang.phash2({self(), System.unique_integer([:monotonic])}))
          {xs, _} = Nx.Random.normal(key, shape: {8, 1})
          {xs, Nx.sin(xs)}
        end)

      # Train
      trained_model_state =
        Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 50)

      # Extract and save params (as in guide)
      params =
        case trained_model_state do
          %Axon.ModelState{data: data} -> data
          params when is_map(params) -> params
        end

      params_path = Path.join(tmp_path, "model_params.axon")
      params = Nx.backend_transfer(params)
      params_bytes = Nx.serialize(params)
      File.write!(params_path, params_bytes)

      # Load and predict (input shape must match training: {batch, 1} for 1 feature)
      loaded_params = File.read!(params_path) |> Nx.deserialize()
      input = Nx.tensor([[1.0]])

      prediction = Axon.predict(model, loaded_params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}
    end

    test "checkpoint and resume flow", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(4)
        |> Axon.relu()
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [
        {Nx.tensor([[1.0, 2.0, 3.0, 4.0]]), Nx.tensor([[1.0]])},
        {Nx.tensor([[2.0, 3.0, 4.0, 5.0]]), Nx.tensor([[2.0]])}
      ]

      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2)

      # Verify checkpoint was saved
      ckpt_files = File.ls!(checkpoint_path) |> Enum.sort()
      assert length(ckpt_files) == 2
      assert Enum.any?(ckpt_files, &String.contains?(&1, "checkpoint_"))

      # Load checkpoint and extract params for inference
      ckpt_file = Path.join(checkpoint_path, List.first(ckpt_files))
      state = File.read!(ckpt_file) |> Axon.Loop.deserialize_state()

      %{model_state: model_state} = state.step_state
      params = model_state.data

      # Run inference with extracted params
      input = Nx.tensor([[1.0, 2.0, 3.0, 4.0]])
      prediction = Axon.predict(model, params, %{"data" => input})

      assert Nx.rank(prediction) == 2
      assert Nx.shape(prediction) == {1, 1}
    end

    test "resume from checkpoint with from_state", %{tmp_path: tmp_path} do
      model =
        Axon.input("data")
        |> Axon.dense(2)
        |> Axon.dense(1)

      checkpoint_path = Path.join(tmp_path, "checkpoints_resume")

      loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.checkpoint(path: checkpoint_path, event: :epoch_completed)

      train_data = [{Nx.tensor([[1.0, 2.0]]), Nx.tensor([[1.0]])}]

      # Run for 1 epoch
      Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 1)

      # Load checkpoint and resume
      [ckpt_file] = File.ls!(checkpoint_path)

      state =
        File.read!(Path.join(checkpoint_path, ckpt_file))
        |> Axon.Loop.deserialize_state()

      resumed_loop =
        model
        |> Axon.Loop.trainer(:mean_squared_error, :sgd, log: 0)
        |> Axon.Loop.from_state(state)

      # Resume - should complete without error
      result = Axon.Loop.run(resumed_loop, train_data, Axon.ModelState.empty(), epochs: 2)

      assert %Axon.ModelState{} = result
    end
  end
end