Add docs for serialization with NX.serialize #630

Open
aphillipo wants to merge 1 commit into elixir-nx:main from aphillipo:ap-serialization-docs

aphillipo commented Mar 5, 2026

  • Added docs for Nx.serialize with examples.
  • Removed link to non-working onnx docs (axon_onnx needs updating)
  • Added tests
  • Updated CHANGELOG for 0.7 to explain what's happened (debatable if this is necessary)
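
For context, the parameter round trip that the new docs cover can be sketched roughly as follows. This is a minimal illustration and not the text added by the PR: the parameter map, its layer names, the file name, and the `~> 0.9` version requirement are all made up for the example, and it only exercises the default binary backend.

```elixir
# Version requirement is an assumption; adjust it to the Nx release you use.
Mix.install([{:nx, "~> 0.9"}])

# Hypothetical parameters, shaped like a trained model's nested map of weights.
params = %{
  "dense_0" => %{
    "kernel" => Nx.tensor([[0.1, 0.2], [0.3, 0.4]]),
    "bias" => Nx.tensor([0.0, 0.0])
  }
}

# Nx.serialize/2 accepts a tensor or a container of tensors.
path = Path.join(System.tmp_dir!(), "model_params.nx")
File.write!(path, Nx.serialize(params))

# Nx.deserialize/2 rebuilds the same nested structure from the saved bytes.
restored = path |> File.read!() |> Nx.deserialize()

# Compare one tensor element-wise; Nx.all/1 yields 1 when every element matches.
matches =
  restored["dense_0"]["kernel"]
  |> Nx.equal(params["dense_0"]["kernel"])
  |> Nx.all()
  |> Nx.to_number()

IO.inspect(matches == 1, label: "kernel round-trips")
```

Only the parameters are written to disk here; the model definition stays in code, which is the split the docs in this PR argue for.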

aphillipo commented Mar 5, 2026

You are obviously welcome to rip this apart, but it's a decent start. I might have a go at updating axon_onnx at some point.


```elixir
Mix.install([
{:axon, ">= 0.8.0"}
Contributor

Suggested change
{:axon, ">= 0.8.0"}
{:axon, "~> 0.8"}

* Makes the model structure explicit and version-controlled in code
* Works reliably across processes and deployments

The model itself is just code—you define it once and reuse it. Only the learned parameters need to be persisted.
Contributor

Suggested change
The model itself is just codeyou define it once and reuse it. Only the learned parameters need to be persisted.
The model itself is just codeyou define it once and reuse it. Only the learned parameters need to be persisted.

trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
```

The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `.data` field from `ModelState`:
Contributor

Suggested change
The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `.data` field from `ModelState`:
The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `data` field from `ModelState`:

```elixir
# Extract parameters - ModelState.data contains the nested map of weights
Contributor

Suggested change
# Extract parameters - ModelState.data contains the nested map of weights
# Extract parameters - trained_model_state.data contains the nested map of weights

Comment on lines +49 to +53
params =
  case trained_model_state do
    %Axon.ModelState{data: data} -> data
    params when is_map(params) -> params
  end
Contributor

Is this case really necessary?

Comment on lines +55 to +56
# Transfer to binary backend for reliable serialization (avoids issues with Nx 0.11+ on EXLA/Torchx)
params = Nx.backend_transfer(params)
Contributor

Which issues would these be? We should not need to transfer the parameters in order to serialize them. If there are issues, that's a bug in Nx 0.11.

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```

Checkpoints are saved to the `checkpoints/` directory. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
Contributor

Suggested change
Checkpoints are saved to the `checkpoints/` directory. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
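
The checkpoint configuration that this sentence refers to is not visible in the excerpt above. As a rough sketch, one way to set it up is with `Axon.Loop.checkpoint/2` and its `:event` and `:path` options. The toy model, the synthetic data, the temporary directory, and the version requirement below are all assumptions made for illustration, not taken from the PR.

```elixir
# Version requirement follows the reviewer's suggestion; treat it as an assumption.
Mix.install([{:axon, "~> 0.8"}])

# Toy model, matching the layer layout used elsewhere in the docs under review.
model =
  Axon.input("data", shape: {nil, 4})
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

# Synthetic batches: the target is simply the row sum of the input.
train_data =
  Stream.repeatedly(fn ->
    x = Nx.iota({8, 4}, type: :f32) |> Nx.divide(32)
    y = Nx.sum(x, axes: [1], keep_axes: true)
    {%{"data" => x}, y}
  end)

# Hypothetical checkpoint directory, kept in the temp dir to avoid clutter.
checkpoint_dir = Path.join(System.tmp_dir!(), "axon_checkpoints_demo")

loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :sgd)
  # Write a checkpoint file each time an epoch completes.
  |> Axon.Loop.checkpoint(event: :epoch_completed, path: checkpoint_dir)

Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 10)

# List whatever checkpoint files the run produced.
IO.inspect(File.ls!(checkpoint_dir), label: "checkpoints")
```

Any file listed this way can then be read back with `File.read!/1` and passed to `Axon.Loop.deserialize_state/2`, as the resuming section describes.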

Comment on lines +115 to +135
## Resuming from a Checkpoint

To resume training from a saved checkpoint:

1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
2. Attach it to your loop with `Axon.Loop.from_state/2`
3. Run the loop as usual

```elixir
# Load the checkpoint (use the path from your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Resume training
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)

Contributor

@seanmor5 I think there's a bit of a dissonance in not having Axon.serialize/deserialize while checkpoints still need their own Axon functions. WDYT?


# Extract model parameters from step_state
%{model_state: model_state} = state.step_state
params = model_state.data |> Nx.backend_transfer()
Contributor

Same comment about needing backend_transfer

Comment on lines +160 to +178
## Troubleshooting: ArgumentError with Nx.serialize

If you see `(ArgumentError) argument error` or `:erlang.++` errors when calling `Nx.serialize/2` on parameters (common with Nx 0.11+ and EXLA/Torchx backends), transfer tensors to the binary backend first:

```elixir
params = Nx.backend_transfer(params)
params_bytes = Nx.serialize(params)
```

If serialization still fails, you can use `:erlang.term_to_binary/2` when parameters are on the binary backend (e.g. after `Nx.backend_transfer/1`):

```elixir
params = Nx.backend_transfer(params)
params_bytes = :erlang.term_to_binary(params)
File.write!("model_params.axon", params_bytes)

# To load:
params = File.read!("model_params.axon") |> :erlang.binary_to_term([:safe])
```
Contributor

We should fix this and release 0.11.1 so we can merge this PR without these bug-related caveats.
