Add docs for serialization with Nx.serialize #630

aphillipo wants to merge 1 commit into elixir-nx:main
Conversation
You are obviously welcome to rip this apart, but it's a decent start. I might have a go at updating axon_onnx at some point.
```elixir
Mix.install([
  {:axon, ">= 0.8.0"}
])
```

**Suggested change:**

```diff
- {:axon, ">= 0.8.0"}
+ {:axon, "~> 0.8"}
```
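For context on the `~>` suggestion: Elixir's `~>` operator pins the upper bound as well, which `>= 0.8.0` does not. A quick sketch using only the standard-library `Version` module (the version strings here are illustrative):

```elixir
# `~> 0.8` means >= 0.8.0 and < 1.0.0; `>= 0.8.0` has no upper bound.
true = Version.match?("0.8.3", "~> 0.8")
false = Version.match?("1.0.0", "~> 0.8")
true = Version.match?("1.0.0", ">= 0.8.0")
IO.puts("version checks pass")
```

So `~> 0.8` protects the guide from a future breaking 1.0 release.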
* Makes the model structure explicit and version-controlled in code
* Works reliably across processes and deployments

**Suggested change:**

```diff
- The model itself is just code—you define it once and reuse it. Only the learned parameters need to be persisted.
+ The model itself is just code — you define it once and reuse it. Only the learned parameters need to be persisted.
```
```elixir
trained_model_state = Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 2, iterations: 100)
```

**Suggested change:**

```diff
- The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `.data` field from `ModelState`:
+ The training loop returns `model_state` by default (from `Axon.Loop.trainer/3`). For inference, we need the parameters—extract the `data` field from `ModelState`:
```
```elixir
# Extract parameters - ModelState.data contains the nested map of weights
params =
  case trained_model_state do
    %Axon.ModelState{data: data} -> data
    params when is_map(params) -> params
  end
```

**Suggested change:**

```diff
- # Extract parameters - ModelState.data contains the nested map of weights
+ # Extract parameters - trained_model_state.data contains the nested map of weights
```

Is this case really necessary?
```elixir
# Transfer to binary backend for reliable serialization (avoids issues with Nx 0.11+ on EXLA/Torchx)
params = Nx.backend_transfer(params)
```

Which issues would these be? We should not need to transfer it to serialize. If there are issues, that's a bug in Nx 0.11.
```elixir
Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```

**Suggested change:**

```diff
- Checkpoints are saved to the `checkpoints/` directory. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
+ Checkpoints are saved to the `checkpoints/` directory, as configured above. Each file contains the serialized loop state from `Axon.Loop.serialize_state/2`.
```
## Resuming from a Checkpoint

To resume training from a saved checkpoint:

1. Load the checkpoint with `Axon.Loop.deserialize_state/2`
2. Attach it to your loop with `Axon.Loop.from_state/2`
3. Run the loop as usual

```elixir
# Load the checkpoint (use the path from your checkpoint files)
checkpoint_path = "checkpoints/checkpoint_2_50.ckpt"
serialized = File.read!(checkpoint_path)
state = Axon.Loop.deserialize_state(serialized)

# Resume training
model =
  Axon.input("data")
  |> Axon.dense(8)
  |> Axon.relu()
  |> Axon.dense(1)
```

@seanmor5 I think there's a bit of a dissonance between not having Axon.serialize/deserialize, while checkpoints need their Axon functions. WDYT?
```elixir
# Extract model parameters from step_state
%{model_state: model_state} = state.step_state
params = model_state.data |> Nx.backend_transfer()
```

Same comment about needing backend_transfer.
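Step 2 of the list above (`Axon.Loop.from_state/2`) is not shown in the snippet. A minimal sketch of what resuming might look like, assuming the `model`, `state`, and `train_data` bindings from the snippets above and the same loss/optimizer as the original loop:

```elixir
# Sketch: attach the deserialized checkpoint state, then continue training.
# Assumes `model`, `state`, and `train_data` are bound as in the snippets above.
loop =
  model
  |> Axon.Loop.trainer(:mean_squared_error, :adam)
  |> Axon.Loop.from_state(state)

resumed_model_state =
  Axon.Loop.run(loop, train_data, Axon.ModelState.empty(), epochs: 3, iterations: 50)
```

With `from_state/2` attached, the loop picks up from the checkpointed epoch and iteration counters rather than starting from scratch.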
## Troubleshooting: ArgumentError with Nx.serialize

If you see `(ArgumentError) argument error` or `:erlang.++` errors when calling `Nx.serialize/2` on parameters (common with Nx 0.11+ and EXLA/Torchx backends), transfer tensors to the binary backend first:

```elixir
params = Nx.backend_transfer(params)
params_bytes = Nx.serialize(params)
```

If serialization still fails, you can use `:erlang.term_to_binary/1` once the parameters are on the binary backend (e.g. after `Nx.backend_transfer/1`):

```elixir
params = Nx.backend_transfer(params)
params_bytes = :erlang.term_to_binary(params)
File.write!("model_params.axon", params_bytes)

# To load:
params = File.read!("model_params.axon") |> :erlang.binary_to_term([:safe])
```

We should fix this and release 0.11.1 so we can merge this PR without these bug-related caveats.
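Worth noting for the fallback path: `:erlang.term_to_binary/1` round-trips plain Elixir terms regardless of the Nx version. A minimal self-contained sketch with a stand-in map (hypothetical data; real parameters would contain Nx tensors):

```elixir
# Stand-in parameter map; real params would hold Nx tensors.
params = %{"dense_0" => %{"kernel" => <<0, 0, 128, 63>>, "bias" => <<0, 0, 0, 0>>}}

bytes = :erlang.term_to_binary(params)
restored = :erlang.binary_to_term(bytes, [:safe])

true = restored == params
IO.puts("round-trip ok: #{byte_size(bytes)} bytes")
```

The `[:safe]` option refuses to create new atoms on load, which is why the guide's example uses string keys in the parameter map.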