Training on the ANE¶
ANEForge is not inference-only. aneforge/autograd.py is a small reverse-mode
autograd in which both the forward and the backward pass compile to, and run on,
the Apple Neural Engine through the e5rt path. A model's weights, its
gradients, and the optimizer update all live as ANE graph ops; the host computes
only a scalar learning rate, samples minibatches, and reads results at
checkpoints.
This page describes the model, the Trainer and UnrolledTrainer APIs, the
loss-scaling story, the measured results, and the honest limits of the current
gradient vocabulary. Everything here is reachable with the same e5rt dispatch
backend the rest of ANEForge uses - no CoreML, no entitlement.
The import alias throughout is import aneforge as af.
The model: tiny reverse-mode autograd¶
Trainable weights are graph inputs, not constants¶
A native ANE program bakes its weights into the compiled blob. That is fine for
inference but useless for training, because every weight update would force a
recompile. ANEForge sidesteps this with af.parameter:
af.parameter(init) creates a graph input tagged trainable. It carries an
fp32 master value in attrs["value"], is used in the graph like any other
input, and is fed its current value on each evaluation. Updating a weight is
therefore just writing a new value into the master array and feeding it on the
next step - no recompile per step. This is the general mutable-weight mechanism
the whole training stack rests on.
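The difference between a baked weight and a fed weight can be sketched in plain Python. This is not ANEForge code: the closures stand in for compiled programs, and `compile_baked`, `compile_with_input`, and this `Parameter` class are hypothetical names used only to count how often a rebuild would be needed.

```python
# Hypothetical stand-ins, not the ANEForge API: a "compiled program" is a
# closure, and a counter records how many times we had to compile.
compiles = 0

def compile_baked(w):
    """Weight is captured at compile time (the inference-style path)."""
    global compiles
    compiles += 1
    return lambda x: w * x

def compile_with_input():
    """Weight is an ordinary input, fed on every evaluation."""
    global compiles
    compiles += 1
    return lambda x, w: w * x

class Parameter:
    """Holds the master value that is fed to the program each step."""
    def __init__(self, init):
        self.value = float(init)

# Fed-weight path: compile once, then only the master value changes.
w = Parameter(2.0)
prog = compile_with_input()
outputs = []
for _ in range(3):
    outputs.append(prog(10.0, w.value))
    w.value -= 0.5                    # "update": write a new master value
fed_compiles = compiles

# Baked-weight path: every update forces a rebuild.
compiles = 0
w_baked = 2.0
for _ in range(3):
    prog = compile_baked(w_baked)
    w_baked -= 0.5
baked_compiles = compiles
```

The fed path compiles once for three updates, while the baked path compiles once per update.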
The VJP registry¶
backward walks the forward graph in reverse, and for each op it looks up a
vector-Jacobian-product rule in a registry:
@af.autograd.vjp("mul")
def _vjp_mul(t, g):
    a, b = t.srcs
    return [_unbroadcast(g * b, a.shape), _unbroadcast(g * a, b.shape)]
Each rule takes the node and its incoming cotangent g and returns one gradient
per source. The gradients are themselves ordinary ANEForge tensor ops, so the
backward pass is a graph that compiles and runs on the engine exactly like the
forward pass.
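A registry of this shape can be sketched with NumPy standing in for ANE tensors. The names `VJP_RULES`, `vjp`, and `unbroadcast` below are illustrative assumptions rather than the library's internals, and the rule takes the sources directly instead of a graph node to keep the sketch short.

```python
import numpy as np

VJP_RULES = {}  # op name -> rule; a hypothetical stand-in for the registry

def vjp(op_name):
    """Decorator that registers a VJP rule under an op name."""
    def register(rule):
        VJP_RULES[op_name] = rule
        return rule
    return register

def unbroadcast(grad, shape):
    """Sum grad over the axes that broadcasting expanded, back to `shape`."""
    # Drop leading axes that the source did not have.
    while grad.ndim > len(shape):
        grad = grad.sum(axis=0)
    # Collapse axes the source held at size 1.
    for axis, size in enumerate(shape):
        if size == 1 and grad.shape[axis] != 1:
            grad = grad.sum(axis=axis, keepdims=True)
    return grad

@vjp("mul")
def vjp_mul(srcs, g):
    a, b = srcs
    return [unbroadcast(g * b, a.shape), unbroadcast(g * a, b.shape)]

a = np.arange(6.0).reshape(2, 3)     # shape (2, 3)
b = np.array([[10.0, 20.0, 30.0]])   # shape (1, 3), broadcast over rows
g = np.ones((2, 3))                  # cotangent of sum(a * b)
grad_a, grad_b = VJP_RULES["mul"]((a, b), g)
```

The gradient for `b` comes back with shape (1, 3): the row axis that broadcasting expanded is summed away, which is the job the unbroadcast helper does in the rule above.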
backward builds an on-ANE backward graph¶
af.backward(loss, params, loss_scale) returns, for each parameter, a gradient
Tensor expressed entirely in ANE ops. The seed dL/dloss = loss_scale is folded
into an additive constant rather than emitted as a multiply, which sidesteps the
mul(reduce_output, 0.0) compile wall described below. A companion
af.backward_from(grad_root, root, params) seeds from an explicit gradient at an
intermediate tensor (for example the logits) instead of from a scalar loss.
Both functions accept a stop argument, the stop-gradient frontier (it defaults to params).
For ordinary leaf weights this is a no-op, but it is what makes multi-step
unrolling possible: when one step's updated-weight tensors thread into the next
step's forward, each step's gradient must treat the current weights as leaves
(plain SGD/Adam) rather than differentiate through the previous update.
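The effect of the seed and of the stop frontier can be seen in a toy scalar autograd. This is a sketch under simplifying assumptions, not the ANEForge implementation: `Node`, `mul`, `sub`, and this `backward` are hypothetical, and the traversal sums over paths recursively, which is fine for a few nodes but not how a real engine would do it.

```python
class Node:
    """A scalar graph node: a value plus (source, local-derivative) pairs."""
    def __init__(self, value, srcs=()):
        self.value = value
        self.srcs = srcs

def mul(a, b):
    return Node(a.value * b.value, ((a, b.value), (b, a.value)))

def sub(a, b):
    return Node(a.value - b.value, ((a, 1.0), (b, -1.0)))

def backward(loss, params, loss_scale=1.0, stop=None):
    """Return d(loss_scale * loss)/dp per parameter, not walking past `stop`."""
    stop = set(params) if stop is None else set(stop)
    grads = {}

    def visit(node, g):
        grads[node] = grads.get(node, 0.0) + g
        if node in stop:
            return                      # treat as a leaf: do not go deeper
        for src, local in node.srcs:
            visit(src, g * local)

    visit(loss, loss_scale)             # the seed is the loss scale itself
    return [grads.get(p, 0.0) for p in params]

# One in-graph SGD step on loss = w * w, then the next step's loss.
w0 = Node(3.0)
lr, two = Node(0.25), Node(2.0)
w1 = sub(w0, mul(lr, mul(two, w0)))     # w1 = w0 - lr * (2 * w0)
loss1 = mul(w1, w1)

(leaf_grad,) = backward(loss1, [w1])                      # w1 is a leaf
(through_grad,) = backward(loss1, [w0])                   # goes through the update
(scaled_grad,) = backward(loss1, [w1], loss_scale=1024.0)
```

With `w1` on the frontier, the gradient is the plain `2 * w1` that SGD or Adam expects. Asking for the gradient at `w0` instead differentiates through the previous update and picks up the extra factor `1 - 2 * lr`.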
The Trainer¶
af.Trainer compiles a forward program once and one backward program per
parameter, then runs the training loop. It accepts either a scalar loss Tensor
(regression) or a CEHandle from af.softmax_cross_entropy(logits, target)
(classification). For classification the cross-entropy gradient at the logits is
the analytic, fp16-stable form (softmax(logits) - target) / N - no log
appears in the backward path, which would otherwise overflow fp16.
logits = forward(x) # an ANEForge graph
obj = af.softmax_cross_entropy(logits, target)
tr = af.Trainer(obj, params, lr=0.005, loss_scale=1024.0, optimizer="adam")
tr.set_dataset(x, X_full, target, Y_onehot)
for _ in range(steps):
    tr.step()
print(tr.accuracy(X_test, y_test))
By default the host runs an fp32 SGD or Adam update over the parameters' master
values (af.SGD / af.Adam), which is the byte-for-byte baseline path.
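The analytic cross-entropy gradient used at the logits can be checked independently of the engine. The following NumPy sketch runs in float64 on the host and is not part of ANEForge; it compares (softmax(logits) - target) / N against central finite differences of the mean cross-entropy loss.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)   # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def ce_loss(logits, target):
    """Mean cross-entropy over the batch; target is one-hot."""
    return -(target * np.log(softmax(logits))).sum() / logits.shape[0]

def ce_grad(logits, target):
    """Analytic gradient at the logits: a softmax and a subtract, no log."""
    return (softmax(logits) - target) / logits.shape[0]

rng = np.random.default_rng(0)
logits = rng.normal(size=(4, 3))
target = np.eye(3)[[0, 2, 1, 2]]           # one-hot labels for 4 samples

# Central finite differences as an independent reference.
eps = 1e-6
numeric = np.zeros_like(logits)
for idx in np.ndindex(*logits.shape):
    bump = np.zeros_like(logits)
    bump[idx] = eps
    numeric[idx] = (ce_loss(logits + bump, target)
                    - ce_loss(logits - bump, target)) / (2 * eps)

analytic = ce_grad(logits, target)
max_err = np.abs(numeric - analytic).max()
```

Each row of the analytic gradient sums to zero, because a softmax row and a one-hot target row both sum to one.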
device_optimizer=True: the update on the engine¶
With device_optimizer=True the optimizer arithmetic is compiled to ANE graph
ops as well. In addition to the per-parameter backward programs, a per-parameter
update program computes the new state on the engine, so no training tensor math
runs on the host. The host only computes the scalar lr_t, shuttles state and
gradients in and out, samples the minibatch, and prints.
One subtlety is load-bearing here. The learning rate lr_t fed to the on-engine
Adam update must not divide by loss_scale. Adam's step is the ratio
m / sqrt(v); with a scaled gradient g' = S*g the moments scale as m' = S*m
and v' = S^2*v, so the ratio is scale-invariant and the loss-scale cancels.
Dividing lr_t by S (as a naive SGD-style unscale would) double-unscales and
collapses the step. SGD is linear, so its lr_t = lr / loss_scale is correct;
Adam's is not. The on-engine Adam moments m and v are held as fp16 arrays,
fed each step and read back.
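The cancellation can be checked numerically with a scalar Adam in plain Python. This is a sketch, not the on-engine update: `adam_step` and `run` are illustrative names, and eps is set to zero so that the m / sqrt(v) ratio is exactly the quantity discussed above (with a nonzero eps the invariance is only approximate).

```python
import math

def adam_step(w, g, m, v, lr, t, b1=0.9, b2=0.999, eps=0.0):
    """One scalar Adam step; eps=0 isolates the m / sqrt(v) ratio."""
    m = b1 * m + (1 - b1) * g
    v = b2 * v + (1 - b2) * g * g
    m_hat = m / (1 - b1 ** t)
    v_hat = v / (1 - b2 ** t)
    return w - lr * m_hat / (math.sqrt(v_hat) + eps), m, v

S, lr = 1024.0, 0.01
grads = [0.3, -0.1, 0.25]                  # fixed unscaled gradients

def run(scale, lr_t):
    """Train for len(grads) steps with gradients multiplied by `scale`."""
    w, m, v = 1.0, 0.0, 0.0
    for t, g in enumerate(grads, start=1):
        w, m, v = adam_step(w, scale * g, m, v, lr_t, t)
    return w

w_unscaled = run(1.0, lr)       # reference: no loss scaling
w_scaled = run(S, lr)           # scaled gradients, lr_t = lr
w_double = run(S, lr / S)       # SGD-style unscale applied to Adam
```

`w_scaled` matches `w_unscaled`, while `w_double` moves only 1/S as far from the starting weight, which is the collapsed step described above.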
resident_state=True: optimizer state stays on the engine¶
resident_state=True assembles the whole step - forward, backward, and the
per-parameter update - as one fused multi-output program and keeps the optimizer
state resident on-device across steps. Each updated-state output is aliased back
onto its own input port with Program.share_buffer (via compile_multi /
MultiModel) and the shared buffers are seeded once. The host then feeds only
the minibatch and lr_t each step, and reads the weights off the device only at
checkpoints; nothing is shuttled in and out.
Full MNIST (a 784-256-10 GELU MLP, Adam) trains to 97.79% test accuracy with
all twelve state tensors (four parameters times weight/m/v) resident across
roughly 2,340 steps, in about 1.0 s - faster than the host round-trip path
(~4 s) because there is no per-step shuttle, and with no compile wall.
(The resident-state mechanism is verified bit-for-bit against the host-shuttle path in
tests/test_autograd.py.)
Loss scaling and fp16¶
Gradients are computed in fp16. A loss scale (commonly 1024) lifts small gradients out of fp16 underflow before the backward pass and is divided back out before the optimizer step (or, for Adam, cancels in the ratio as above). The optimizer state itself is fine in fp16: the Adam moments stay O(1-20) and do not overflow, so fp16 optimizer state is sufficient and paired-fp16 is not needed for the models trained here.
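The underflow that loss scaling avoids is easy to reproduce with NumPy's float16 type. The gradient value below is chosen for illustration only, and the unscale is done in fp32 on the host.

```python
import numpy as np

loss_scale = np.float32(1024.0)
true_grad = np.float32(1e-8)          # below fp16's smallest subnormal

# Cast straight to fp16: the value is lost.
unscaled_fp16 = np.float16(true_grad)               # flushes to 0.0

# Scale first, then cast: the value survives as an fp16 subnormal.
scaled_fp16 = np.float16(true_grad * loss_scale)

# Divide the scale back out in fp32.
recovered = np.float32(scaled_fp16) / loss_scale
rel_err = abs(float(recovered) - 1e-8) / 1e-8
```

The recovered value is not exact, since an fp16 subnormal carries only a few significant bits, but it is nonzero and within about one percent of the true gradient.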
There is one chip-specific caveat. On the A13-class ANE (M1), the trainable
conv's width-offset im2col slices route through a fixed-point crop-DMA that
saturates any value past 4094 (= 65504/16) to infinity, so an extreme
loss_scale times a large backward activation could in principle corrupt a conv
weight-gradient. In that case ANEForge emits a warning but never caps
loss_scale. The need for an automatic cap was refuted end-to-end on M1: a real normalized CNN
trains identically at loss_scale 128, 1024, and 65536 - only a synthetic
0.5*sum(y^2) repro with random inputs ever reaches the threshold. The warning
fires solely on A13 targets that train a conv weight with kW > 1; M5 / A16 has
no such route.
Measured results¶
MNIST MLP¶
A 784-256-10 GELU MLP trains on full MNIST to 97.57% test accuracy with the forward, backward, and Adam update all on the ANE - about 0.5 point under the host-fp32-Adam baseline of 98.05% on the same model. With resident state the same model reaches 97.79% in ~1.0 s (above).
Cross-chip parity¶
The identical seeded MNIST CNN (examples/train_mnist_cnn.py, fixed
np.random.default_rng(0), deterministic data order, default loss_scale=1024)
trains to 0.9080 on M1 (h13/A13), 0.9080 on M2 Pro (h14g/A14, deterministic
x3), and 0.9070 on M5 (h17/A16-class) over 300 steps. Every run is internally
deterministic across repeats. A14 lands exactly on M1's number, so the one-sample gap
(1 in 1000, 0.1%) appears only at the A16 generation. This is the end-to-end signature of the
ANE's <=1-ULP-per-op cross-chip fp16 difference accumulating over 300 steps, with the
drift boundary at A15/A16 rather than per-chip noise. The gap is real, but far too small to
matter. Training is chip-portable across all three measured generations. See
cross-chip.md for the portability story.
Versus the GPU and CPU¶
For a small MLP, training on the ANE essentially ties the fastest Apple GPU path and both beat the CPU. On the same 784-256-10 GELU MLP (Adam, B=128, 5 epochs = 2,340 steps on full MNIST), all reaching ~97.5% accuracy:
| device / framework | train time | vs fastest |
|---|---|---|
| GPU Metal (MLX, fp32) | 0.87 s | 1.00x |
| ANE (ANEForge fp16, resident) | 0.93 s | 1.08x |
| CPU (PyTorch fp32) | 1.40 s | 1.62x |
| CPU (MLX fp32) | 1.77 s | 2.05x |
| GPU Metal (PyTorch MPS, fp32) | 1.87 s | 2.16x |
This is the dispatch-bound, tiny-model regime (~0.4 ms/step, ~200K parameters), so the ANE matching the GPU is about per-step overhead, not peak FLOPS; for larger compute-bound models the GPU's raw throughput is expected to pull ahead. ANE training is a capability and an efficiency story, not a peak-speed claim.
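The step count, per-step time, and parameter count quoted above follow from simple arithmetic. The sketch assumes the standard 60,000-image MNIST training set, that the final partial batch of each epoch is dropped, and that both layers carry a bias; none of these three details is stated in the text.

```python
# Assumptions: 60,000 training images, partial batches dropped, biases counted.
steps = 5 * (60000 // 128)                  # 5 epochs of full 128-sample batches
params = 784 * 256 + 256 + 256 * 10 + 10    # weights + biases of 784-256-10
ane_ms_per_step = 0.93 / steps * 1e3        # ANE row of the table
```

Under these assumptions the numbers agree with the text: 2,340 steps, a little over 200K parameters, and just under 0.4 ms per step.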
Host-free multi-step dispatch¶
A fixed number of training steps has no data-dependent control flow, so the whole
forward -> backward -> optimizer-update recurrence unrolls into a single graph,
the same trick the iterative solvers in aneforge.linalg use. af.UnrolledTrainer
builds it: each step() runs K steps in one dispatch on the engine, with the
host feeding K minibatches plus the per-step learning rates (an array move, no
per-step host round-trip inside the block). With resident=True (the default)
the optimizer state is share_buffer-aliased across dispatches, so the host
feeds only the K minibatches and per-step lr and reads weights at checkpoints.
xs = [af.input((B, DIN)) for _ in range(K)]
ts = [af.input((B, C)) for _ in range(K)]
tr = af.UnrolledTrainer(P, forward, "ce", xs, ts, (Xtr, onehot),
                        lr=0.01, loss_scale=1024.0)
for _ in range(epochs):
    tr.step()  # ONE dispatch = K on-engine steps
acc = (tr.predict(Xte).argmax(1) == yte).mean()
One execute_multi drives K on-engine steps with no Path B and no
entitlement. The honest finding is that this is performance-neutral: on the
resident-buffer path the per-step host dispatch was never the bottleneck, so
collapsing the per-step execute() calls into one buys nothing measurable. The
bounded multi-step path is reachable without an entitlement, and by this
measurement the entitlement-gated fully-autonomous loop would not run these
workloads any faster.
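Why a fixed number of steps unrolls cleanly can be sketched on the host with NumPy. This is an illustration of the idea only, using a linear least-squares model and plain SGD; `sgd_step` and `unrolled_block` are hypothetical names, and nothing here touches the engine.

```python
import numpy as np

def sgd_step(w, x, y, lr):
    """One forward -> backward -> update step for loss = mean((x @ w - y)^2)."""
    grad = 2.0 * x.T @ (x @ w - y) / x.shape[0]
    return w - lr * grad

def unrolled_block(w, xs, ys, lrs):
    """K steps as one straight-line function: K is fixed when it is built."""
    for x, y, lr in zip(xs, ys, lrs):
        w = sgd_step(w, x, y, lr)
    return w

rng = np.random.default_rng(0)
K, B, D = 4, 8, 3
xs = [rng.normal(size=(B, D)) for _ in range(K)]   # K minibatches
ys = [rng.normal(size=(B, 1)) for _ in range(K)]
lrs = [0.05] * K                                   # per-step learning rates
w0 = np.zeros((D, 1))

# One "dispatch" that performs K steps ...
w_block = unrolled_block(w0, xs, ys, lrs)

# ... versus K separate dispatches with the host in between.
w_loop = w0
for k in range(K):
    w_loop = sgd_step(w_loop, xs[k], ys[k], lrs[k])
```

Both routes perform the same arithmetic in the same order, so the weights agree exactly. The only thing unrolling changes is how many times the host has to dispatch.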
For a full-batch task the data and learning rate can also be seeded once, after
which the training loop is literally prog.execute() and the host feeds nothing
per dispatch; examples/train_transformer.py trains an attention block this way.
Worked examples¶
The committed demos train a 1000-sample MNIST subset (examples/data/mnist_subset.npz)
end to end on the engine:
- examples/train_mnist_mlp.py - a GELU MLP, K steps unrolled per dispatch.
- examples/train_mnist_cnn.py - conv -> relu -> avg_pool -> fc, with a trainable conv. The native ANE conv requires a baked weight, so af.conv_param / af.conv2d build the conv from primitives (static im2col plus a batched matmul), which gives a weight gradient that runs on the engine; the conv_shape is carried across each in-graph optimizer update automatically.
- examples/train_transformer.py - a multi-head self-attention block trained fully host-free with af.adam_step and resident state.
- examples/train_transformer_prenorm.py - a pre-norm transformer block (layer_norm before attention and before the MLP). The gradient flows through LayerNorm to the trainable projections and MLP; the block converges.
- examples/train_llama_block.py - a LLaMA-style block (rms_norm pre-normalization + a SwiGLU feed-forward built on silu). Exercises the RMSNorm and SiLU gradients end to end. The char-LM examples apply causal attention as an additive mask before a decomposed softmax; for a forward decoder block, causal attention is also available on the native fused-attention layer via af.sdpa(q, k, v, is_causal=True) (see examples/llama_block_causal.py).
- examples/train_charlm.py - training at scale: a 4-layer causal LLaMA-style character language model (token + positional embeddings, RMSNorm + SwiGLU blocks, output projection, next-token cross-entropy), forward + backward + optimizer all on the engine. Cross-entropy falls to near zero, next-char accuracy reaches 100%, and it reconstructs its training text from a seed. Token embedding is a one-hot matmul (onehot @ W_emb) rather than a gather (which has no VJP), so the embedding gradient is an ordinary matmul gradient; causal masking is an additive upper-triangular bias before softmax. Stacking the fused forward + backward + optimizer program converges through eight layers; compile time, not correctness, is the practical ceiling on depth.
- examples/train_charlm_corpus.py - the data-scale companion: the same model trains on random windows from a structured corpus's training split and is evaluated on a disjoint validation split it never trains on. Held-out next-char accuracy reaches about 65% against a roughly 20% unigram baseline, so the model learned transferable structure rather than memorizing a single sequence. Each window is re-based to absolute positions 0..S-1 so the learned positional embeddings apply to every sampled window.
- examples/train_charlm_deep.py - a 16-layer char-LM trained with a layer-streamed compile (aneforge.streaming.CheckpointedStack). When the layers are identical, the per-layer forward and per-layer backward each depend only on one layer's shape, so each is compiled once and reused for every layer; the six programs (embed, layer, head, each with a forward and a backward) compile in about half a second regardless of depth, where the monolithic eight-layer compile alone took about 162 seconds. Cross-entropy falls from about 3.0 to 0.03. This removes compile size as the ceiling on training depth.
Depth without a compile-size wall¶
A monolithic compile fuses a model's whole forward, backward, and optimizer step into
one program, so compile time grows superlinearly with depth and caps how deep a model
can train. aneforge.streaming.CheckpointedStack removes that ceiling for a stack of
identical layers. The per-layer forward and the per-layer backward are each compiled
once and reused for every layer, fed that layer's parameters and its checkpointed
input activation, so compile work is independent of the layer count. The backward is
standard gradient checkpointing: each layer's input is stored, the layer's forward is
recomputed inside the reused backward program, and that program returns the parameter
gradients together with the gradient with respect to the input (the upstream gradient
for the layer below). The streamed gradients are bit-exact against a monolithic
backward. The caller drives the embedding and output stages (each compiled once) and a
host-side optimizer over the streamed gradients, with all forward and backward compute
on the engine.
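The checkpointed backward described above can be sketched in NumPy with a stack of identical tanh layers. This is an illustration of standard gradient checkpointing with one reusable per-layer function, not the CheckpointedStack implementation; `layer_forward`, `layer_backward`, and the layer definition are assumptions made for the example.

```python
import numpy as np

def layer_forward(x, w):
    """The one reusable per-layer forward 'program'."""
    return np.tanh(x @ w)

def layer_backward(x_in, w, g_out):
    """The one reusable per-layer backward: recompute forward, then VJP."""
    y = np.tanh(x_in @ w)                  # recomputed, not stored
    g_pre = g_out * (1.0 - y * y)          # back through tanh
    return x_in.T @ g_pre, g_pre @ w.T     # (grad wrt w, grad wrt x_in)

rng = np.random.default_rng(0)
depth, B, D = 16, 4, 5
weights = [rng.normal(scale=0.5, size=(D, D)) for _ in range(depth)]
x = rng.normal(size=(B, D))

# Forward: keep only each layer's input activation (the checkpoint).
checkpoints = []
h = x
for w in weights:
    checkpoints.append(h)
    h = layer_forward(h, w)

# Backward for loss = sum(h): stream layers top-down through one function.
g = np.ones_like(h)
grads = [None] * depth
for i in reversed(range(depth)):
    grads[i], g = layer_backward(checkpoints[i], weights[i], g)

def loss(ws):
    out = x
    for w in ws:
        out = layer_forward(out, w)
    return out.sum()

# Spot-check one weight entry against central finite differences.
eps = 1e-6
hi = [w.copy() for w in weights]
lo = [w.copy() for w in weights]
hi[3][1, 2] += eps
lo[3][1, 2] -= eps
numeric = (loss(hi) - loss(lo)) / (2 * eps)
fd_err = abs(numeric - grads[3][1, 2])
```

The same two functions serve all 16 layers, so the amount of code that would need compiling does not grow with depth; only the loop length and the list of checkpoints do.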
The two normalization layers in the pre-norm transformer and LLaMA-style example blocks (layer_norm and rms_norm) demonstrate that a real
norm-bearing architecture trains on the engine, affine included. The norm methods
are polymorphic in their affine arguments: a numpy gamma / beta bakes a fixed
affine (the native lowering, gradient flows through the input only), while a
parameter Tensor for gamma / beta composes the unit-affine normalize op
with an explicit learnable scale and shift, so the affine trains alongside the
rest of the model (its gradient flows through the broadcast-aware mul / add
VJPs). Pass gamma / beta shaped to broadcast over the normalized axis
([1, D] for layer_norm / rms_norm, [1, C, 1, 1] for group_norm). Models
that use prelu are likewise trainable through the Trainer.
Honest limits¶
Gradient vocabulary¶
The registered VJPs are all numerically correct - every one matches its closed-form derivative (verified to cosine similarity ~1.0 against finite differences on M1). The differentiable vocabulary covers the structural and linear-algebra ops, the common activations, and the normalization ops:
- structure / linalg: matmul, bmm, conv, avg_pool, max_pool, reduce_sum, reduce_mean, transpose, reshape, flatten2d, slice_by_size, concat, add, sub, mul, muls, adds, square
- activations: relu, relu6, leaky_relu, elu, gelu, silu, sigmoid, sigmoid_hard, tanh, scaled_tanh, softmax, clip
- math: exp, log, sqrt, rsqrt, inverse, erf, cos, abs
- normalization: layer_norm, rms_norm, group_norm, l2_norm
This trains transformer and LLaMA-style blocks, diffusion-UNet and GroupNorm
CNNs, plain CNNs, and MLPs end to end on the engine. The normalization VJPs use
the exact closed-form input gradient; because ANEForge has no constant-tensor op,
the backward graph re-injects gamma as a fed value-input.
If a model uses a forward op with no registered VJP, it compiles and runs its forward
pass, then fails at backward with no vjp for op '<name>'. The norms and silu are
covered (they appear in essentially every modern architecture). The remaining gaps are
low-value and do not block mainstream training: the reduction-routing ops amax / amin
/ cumsum, and a few parametric activations (atan, softplus, clamped variants).
Two pre-existing test caveats are unrelated to gradient coverage and predate this
work: test_conv2d_trainable_grads fails with NaN (the conv weight-gradient
slice-saturation that is root-caused separately - it fails identically with the
op removed), and test_cnn_trains_on_subset can hang late in a full suite run
under many rapid compiles, which the compile backoff in aneforge/_circuit.py
paces as a defensive backstop.
The mul(reduce_output, 0.0) compile wall¶
Multiplying a reduce_sum / reduce_mean result by zero trips an ANECCompile
fusion wall (reduce * nonzero compiles, and a bare mul-by-zero compiles; only
the combination fails). The autograd builds constant tensors with
_const_like(t, c) = (t - t).adds(c) - an exact-zero subtract rather than a
multiply by zero - which sidesteps the wall for any finite tensor, and the
backward seed folds the loss scale into an additive constant for the same reason.
This is handled internally; it matters only if you author backward graphs by hand.
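For anyone authoring backward graphs by hand, the subtract-based constant and its finite-tensor caveat can be reproduced in NumPy fp16. This sketch mirrors the formula above on the host; `const_like` is a local stand-in, and the compile wall itself is an ANE compiler behaviour that NumPy does not model.

```python
import numpy as np

def const_like(t, c):
    """A tensor shaped like t and filled with c, built as (t - t) + c."""
    return (t - t) + np.float16(c)

t = np.array([[1.5, -200.0], [0.0, 6.0e4]], dtype=np.float16)
ones = const_like(t, 1.0)

# The construction relies on t - t being exactly zero, which holds only for
# finite entries: inf - inf is NaN, so the constant is poisoned.
with np.errstate(invalid="ignore"):
    bad = const_like(np.array([np.inf], dtype=np.float16), 1.0)
```

The result keeps the shape and dtype of `t` without ever multiplying by zero, and the NaN in `bad` is why the guarantee is stated for finite tensors only.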
Throughput¶
End-to-end training throughput is dispatch- and round-trip-bound like the rest of the ANE surface at this model scale. On-engine training is a capability (train on-device, low power, chip-portable), not a peak-speed claim against the GPU on large compute-bound models.
See also¶
- aneforge-api.md - the full frontend reference.
- capabilities.md - the hardware op surface and dtype limits.
- getting-started.md - building the e5rt dispatch dylib.