
Training & autograd

Reverse-mode autograd over the graph, the Trainer loop (with the optimizer step on the engine), and layer-streamed training for deep stacks.

autograd

A tiny reverse-mode autograd over the aneforge graph: forward and backward both run on the ANE. Trainable parameters are graph inputs (fed each step, updated host-side, no recompile).

CEHandle

A softmax-cross-entropy training objective: carries the logits and the one-hot target (a graph input). The gradient at the logits is the analytic fused form (softmax(logits) - target)/N, which is fp16-stable (no log). The loss VALUE + accuracy are computed host-side in fp32 by the Trainer.
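As a rough host-only check of the fused gradient, the NumPy sketch below computes the mean cross-entropy with a max-shifted log-softmax and compares (softmax(logits) - target)/N against a central finite difference. It does not use aneforge; the helper names (`softmax`, `mean_cross_entropy`, `fused_grad`) and the random test data are assumptions made for illustration.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=1, keepdims=True)        # shift rows for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=1, keepdims=True)

def mean_cross_entropy(logits, target):
    # Stable log-softmax: subtract the row max before exponentiating.
    z = logits - logits.max(axis=1, keepdims=True)
    log_softmax = z - np.log(np.exp(z).sum(axis=1, keepdims=True))
    return float(-(target * log_softmax).sum(axis=1).mean())

def fused_grad(logits, target):
    # Gradient of the mean cross-entropy with respect to the logits.
    return (softmax(logits) - target) / logits.shape[0]

rng = np.random.default_rng(0)
N, K = 4, 3
logits = rng.normal(size=(N, K))
target = np.eye(K)[rng.integers(0, K, size=N)]   # one-hot rows

analytic = fused_grad(logits, target)

# Central finite differences, one logit at a time.
numeric = np.zeros_like(logits)
h = 1e-5
for i in range(N):
    for j in range(K):
        bump = np.zeros_like(logits)
        bump[i, j] = h
        numeric[i, j] = (mean_cross_entropy(logits + bump, target)
                         - mean_cross_entropy(logits - bump, target)) / (2 * h)

max_abs_diff = float(np.abs(analytic - numeric).max())
```

The gradient expression itself needs only a softmax and a subtraction, which is why it can be evaluated without taking a log; the log appears only in the loss value.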

Source code in aneforge/autograd.py
class CEHandle:
    """A softmax-cross-entropy training objective: carries the logits and the
    one-hot target (a graph input). The gradient at the logits is the analytic
    fused form (softmax(logits) - target)/N, which is fp16-stable (no log). The
    loss VALUE + accuracy are computed host-side in fp32 by the Trainer."""
    def __init__(self, logits: Tensor, target: Tensor):
        if len(logits.shape) != 2 or logits.shape != target.shape:
            raise ValueError(f"softmax_cross_entropy expects 2-D logits/target [N,K]; "
                             f"got {logits.shape}, {target.shape}")
        self.logits, self.target, self.n = logits, target, int(logits.shape[0])

    def seed(self, loss_scale: float) -> Tensor:
        """dL/dlogits * loss_scale = (softmax(logits) - target) * (loss_scale / n)."""
        return (self.logits.softmax(-1) - self.target) * (float(loss_scale) / self.n)

seed

seed(loss_scale: float) -> Tensor

dL/dlogits * loss_scale = (softmax(logits) - target) * (loss_scale / n).

Source code in aneforge/autograd.py
def seed(self, loss_scale: float) -> Tensor:
    """dL/dlogits * loss_scale = (softmax(logits) - target) * (loss_scale / n)."""
    return (self.logits.softmax(-1) - self.target) * (float(loss_scale) / self.n)

SGD

Host-side fp32 SGD over the parameters' master values. Trainer applies loss scaling, so gradients arrive multiplied by loss_scale; step() divides them by loss_scale before applying the update.
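To see why the gradients are scaled at all, here is a small NumPy sketch (not aneforge code; the gradient value and the scale of 1024 are arbitrary assumptions). A gradient below fp16's smallest subnormal is lost when cast directly, but survives if it is multiplied by the loss scale first and divided out again in fp32 on the host.

```python
import numpy as np

true_grad = np.float32(1e-8)          # hypothetical tiny gradient value
loss_scale = 1024.0                   # hypothetical scale

# Cast directly: 1e-8 is below fp16's smallest subnormal (about 6e-8), so it rounds to 0.
unscaled_fp16 = np.float16(true_grad)

# Scale first: about 1e-5 is representable in fp16 (as a subnormal).
scaled_fp16 = np.float16(true_grad * loss_scale)

# Host side, mirroring the unscale in SGD.step: cast to fp32, then divide by the scale.
recovered = scaled_fp16.astype(np.float32) / np.float32(loss_scale)

relative_error = abs(float(recovered) - 1e-8) / 1e-8
```

The recovered value is not exact, since it still carries fp16 rounding, but it is no longer zero.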

Source code in aneforge/autograd.py
class SGD:
    """Host fp32 SGD over the parameters' master values. `Trainer` applies loss
    scaling (grads come in scaled; divide before the step)."""
    def __init__(self, params, lr: float, loss_scale: float = 1.0):
        self.params, self.lr, self.scale = list(params), float(lr), float(loss_scale)
        self._nonfinite_skips = 0

    def step(self, grads):
        if not _check_finite_grads(self, grads):
            return
        for p, g in zip(self.params, grads):
            p.attrs["value"] = p.attrs["value"] - self.lr * (g.astype(np.float32) / self.scale)

Adam

Host-side fp32 Adam over the parameters' master values. Loss-scaled gradients are divided by loss_scale before the moment update. The updated fp32 master value is stored back on the parameter and cast to fp16 when it is fed to the next eval (sibling to SGD).

Source code in aneforge/autograd.py
class Adam:
    """Host fp32 Adam over the parameters' master values. Loss-scaled grads are
    divided by `loss_scale` before the moment update; fp16 written back into the
    fed param value (sibling to SGD)."""
    def __init__(self, params, lr: float = 1e-3, betas=(0.9, 0.999),
                 eps: float = 1e-8, loss_scale: float = 1.0):
        self.params = list(params)
        self.lr, (self.b1, self.b2), self.eps, self.scale = float(lr), betas, float(eps), float(loss_scale)
        self.m = [np.zeros(p.attrs["value"].shape, np.float32) for p in self.params]
        self.v = [np.zeros_like(x) for x in self.m]
        self.t = 0
        self._nonfinite_skips = 0

    def step(self, grads):
        if not _check_finite_grads(self, grads):
            return
        self.t += 1
        bc1, bc2 = 1.0 - self.b1 ** self.t, 1.0 - self.b2 ** self.t
        for i, (p, g) in enumerate(zip(self.params, grads)):
            g = g.astype(np.float32) / self.scale
            self.m[i] = self.b1 * self.m[i] + (1 - self.b1) * g
            self.v[i] = self.b2 * self.v[i] + (1 - self.b2) * (g * g)
            mhat, vhat = self.m[i] / bc1, self.v[i] / bc2
            p.attrs["value"] = p.attrs["value"] - self.lr * mhat / (np.sqrt(vhat) + self.eps)

Trainer

Compiles a forward program ONCE plus one backward program PER PARAMETER (each emitting that param's gradient in its natural 2-D shape); step evals the backward programs on the ANE and applies the optimizer (params update host-side, fed back next eval - no recompile). One program per param avoids an ANECCompile wall hit by reshaping a large weight grad into a wide row and concatenating it with a differently-sized row (the math is unchanged).

Accepts either
  • a scalar loss Tensor (regression): forward program outputs the loss scalar; loss() reads it; backward seeds from a ones-seed at the loss.
  • a CEHandle (classification): forward program outputs the logits; backward seeds from the analytic on-ANE gradient (softmax(logits) - target) * (loss_scale / N) at the logits. Host-side loss() (fp32 cross-entropy) and accuracy(X, y_labels) (argmax) read the logits program.

optimizer="sgd"|"adam" selects the optimizer.

device_optimizer=True runs the OPTIMIZER STEP on the ANE: alongside the per-param backward programs (-> grads), a per-param UPDATE program computes the new state with ANE graph ops, so no training tensor-math runs on the host. The host only computes the scalar lr_t, shuttles state/grads in-out (the deferred host<->device round-trip), samples minibatch indices, and prints. The Adam moments m/v are held host-side as fp16 arrays, fed each step and read back. device_optimizer=False (default) keeps the host fp32 optimizer path byte-for-byte unchanged (the regression guard + the 98% baseline).
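In the source listing, the device Adam path computes lr_t = lr * sqrt(1 - b2^t) / (1 - b1^t) and does not divide it by loss_scale. The NumPy sketch below is a host-only illustration of both points, with eps omitted and random gradients as stand-in data: folding the bias correction into one scalar reproduces the textbook step, and scaling every gradient by S leaves the ratio m/sqrt(v) unchanged. The function `adam_moments` is a hypothetical helper, not part of aneforge.

```python
import math
import numpy as np

def adam_moments(grads, b1=0.9, b2=0.999):
    """Run the raw (uncorrected) Adam moment recurrences over a list of grads."""
    m = np.zeros_like(grads[0])
    v = np.zeros_like(grads[0])
    for g in grads:
        m = b1 * m + (1 - b1) * g
        v = b2 * v + (1 - b2) * g * g
    return m, v

rng = np.random.default_rng(0)
lr, b1, b2, t = 1e-3, 0.9, 0.999, 5
grads = [rng.normal(size=(3, 2)) for _ in range(t)]

m, v = adam_moments(grads, b1, b2)

# Textbook step (eps omitted): lr * mhat / sqrt(vhat).
textbook = lr * (m / (1 - b1 ** t)) / np.sqrt(v / (1 - b2 ** t))

# Folded form: the bias correction is absorbed into one host-side scalar.
lr_t = lr * math.sqrt(1 - b2 ** t) / (1 - b1 ** t)
folded = lr_t * m / np.sqrt(v)

# Loss-scaled gradients: m scales by S and v by S**2, so m / sqrt(v) is unchanged.
S = 256.0
m_s, v_s = adam_moments([S * g for g in grads], b1, b2)
folded_scaled = lr_t * m_s / np.sqrt(v_s)
```

With a nonzero eps the cancellation is only approximate, which matches the source comment that eps is the one scale-sensitive term.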

Source code in aneforge/autograd.py
class Trainer:
    """Compiles a forward program ONCE plus one backward program PER PARAMETER
    (each emitting that param's gradient in its natural 2-D shape); `step` evals
    the backward programs on the ANE and applies the optimizer (params update
    host-side, fed back next eval - no recompile). One program per param avoids an
    ANECCompile wall hit by reshaping a large weight grad into a wide row and
    concatenating it with a differently-sized row (the math is unchanged).

    Accepts either:
      * a scalar `loss` Tensor (regression): forward program outputs the loss
        scalar; `loss()` reads it; backward seeds from a ones-seed at the loss.
      * a `CEHandle` (classification): forward program outputs the logits;
        backward seeds from the analytic on-ANE gradient
        `(softmax(logits) - target) * (loss_scale / N)` at the logits.
        Host-side `loss()` (fp32 cross-entropy) and `accuracy(X, y_labels)`
        (argmax) read the logits program.

    `optimizer="sgd"|"adam"` selects the optimizer.

    `device_optimizer=True` runs the OPTIMIZER STEP on the ANE: alongside the
    per-param backward programs (-> grads), a per-param UPDATE program computes the
    new state with ANE graph ops, so no training tensor-math runs on the host. The
    host only computes the scalar `lr_t`, shuttles state/grads in-out (the deferred
    host<->device round-trip), samples minibatch indices, and prints. The Adam
    moments `m`/`v` are held host-side as fp16 arrays, fed each step and read back.
    `device_optimizer=False` (default) keeps the host fp32 optimizer path
    byte-for-byte unchanged (the regression guard + the 98% baseline)."""
    def __init__(self, objective, params, lr: float, loss_scale: float = 1.0,
                 data_inputs: dict | None = None, optimizer: str = "sgd",
                 betas=(0.9, 0.999), eps: float = 1e-8, device_optimizer: bool = False,
                 resident_state: bool = False):
        from . import _compile as _c
        self.params = list(params)
        # A13/M1 conv weight-grad saturation guard: warn if loss_scale could push the
        # width-offset im2col backward slices past the x16 crop-DMA threshold (no-op off A13
        # or for non-conv graphs). Done before scale is consumed by the optimizer/seed.
        loss_scale = _guard_a13_conv_loss_scale(self.params, objective, float(loss_scale))
        self.data = dict(data_inputs or {})       # {input Tensor: numpy value}
        self.scale = float(loss_scale)
        self.lr = float(lr)
        self.b1, self.b2 = float(betas[0]), float(betas[1])
        self.eps = float(eps)
        self.optimizer = optimizer
        # resident_state implies the optimizer runs on-device (it IS the on-device
        # update, with state held resident); device_optimizer is forced on.
        self.resident_state = bool(resident_state)
        self.device_optimizer = bool(device_optimizer) or self.resident_state
        self.opt = (Adam(self.params, lr, betas, eps=eps, loss_scale=loss_scale)
                    if optimizer == "adam" else SGD(self.params, lr, loss_scale))
        if isinstance(objective, CEHandle):
            self.ce = objective
            grads = backward_from(objective.seed(loss_scale), objective.logits, self.params)
            fwd_out = objective.logits                 # forward program -> logits
        else:
            self.ce = None
            grads = backward(objective, self.params, loss_scale=loss_scale)
            fwd_out = objective                        # forward program -> loss scalar
        # Backward: ONE e5rt program per parameter, each emitting that param's grad
        # in its NATURAL shape. Concatenating all grads into a single wide row trips
        # an ANECCompile wall when a large weight grad (e.g. 784x128) is reshaped to a
        # 100k-wide row and concatenated with a differently-sized row (verified: the
        # reshape-to-wide-row-into-concat is the trigger, not the matmul). Per-param
        # programs stay inside the verified 2-D matmul envelope, and the optimizer
        # consumes the grads in param shape anyway.
        # aneforge's own training kernels (forward loss, per-param backward, optimizer
        # update) contain structural subtracts - the loss pred-target, gradient axpys,
        # the w-lr*g update - that trip the generic cancel_sub precision heuristic. These
        # are vouched, accuracy-tested kernels (the MNIST baselines + the corpus), not
        # user-data modeling choices, so they skip the user-facing precision check.
        self._fwd = _c.compile(fwd_out, _check_precision=False)
        if self.resident_state:
            # The whole step (forward + backward + per-param update) is ONE fused
            # multi-output program; state tensors stay resident on-device across
            # steps (no host round-trip). _fwd is kept only for checkpoint accuracy.
            self._build_resident(_c, grads)
        else:
            self._bwd = [_c.compile(grads[p], _check_precision=False) for p in self.params]
            if self.device_optimizer:
                self._build_device_optimizer(_c)

    def _build_device_optimizer(self, _c):
        """Compile a per-param UPDATE program so the optimizer arithmetic runs on
        the ANE. SGD: `(w_p, g_p, lr_t) -> w_p'` (single output). Adam:
        `(w_p, m_p, v_p, g_p, lr_t) -> stack(w_p', m_p', v_p')` via `_stack3` (the
        single-output STACK path), split host-side. `g_p`/`m_p`/`v_p`/`lr_t` are
        plain (non-trainable) leaves fed each step; `w_p` is the param leaf itself
        (its master value is fed, the new value written back)."""
        self._upd, self._upd_g, self._upd_lr = [], [], []
        if self.optimizer == "adam":
            self._m = [np.zeros(p.shape, np.float16) for p in self.params]
            self._v = [np.zeros(p.shape, np.float16) for p in self.params]
            self._upd_m, self._upd_v = [], []
            self._t = 0
        for p in self.params:
            g_in = graph.input(p.shape)
            lr_in = graph.input((1, 1))
            self._upd_g.append(g_in); self._upd_lr.append(lr_in)
            if self.optimizer == "adam":
                m_in = graph.input(p.shape); v_in = graph.input(p.shape)
                self._upd_m.append(m_in); self._upd_v.append(v_in)
                w2, m2, v2 = _adam_update(p, m_in, v_in, g_in, lr_in, self.b1, self.b2, self.eps)
                out = _stack3(w2, m2, v2)
            else:
                out = _sgd_update(p, g_in, lr_in)
            self._upd.append(_c.compile(out, _check_precision=False))

    def _build_resident(self, _c, grads):
        """Assemble the whole training step as ONE fused multi-output program with
        optimizer state RESIDENT on-device. Each param `p` (and, for Adam, its
        moments `m_p`/`v_p`) is a graph input whose updated value is a program
        OUTPUT aliased back onto that input port via `share_buffer` - so state
        lives on the engine across steps and the host feeds only the minibatch
        (x, target) + the scalar `lr_t`, reading state back at checkpoints. The
        forward and the update share the same resident param buffer: within one
        execute the stream's FIFO ordering has the forward read the pre-step param
        and the update overwrite it last (the next step reads the advanced value)."""
        lr_in = graph.input((1, 1))
        self._res_lr = lr_in
        self._t = 0
        outs, alias, self._res_state = [], [], []
        for p in self.params:
            g = grads[p]
            entry = {"p": p}
            if self.optimizer == "adam":
                m_in, v_in = graph.input(p.shape), graph.input(p.shape)
                w2, m2, v2 = _adam_update(p, m_in, v_in, g, lr_in, self.b1, self.b2, self.eps)
                outs += [w2, m2, v2]
                alias += [(w2, p), (m2, m_in), (v2, v_in)]
                entry.update(m_in=m_in, v_in=v_in, w_out=w2)
            else:
                w2 = _sgd_update(p, g, lr_in)
                outs += [w2]
                alias += [(w2, p)]
                entry.update(w_out=w2)
            self._res_state.append(entry)

        mm = _c.compile_multi(outs)
        self._res = mm
        prog = mm.prog
        self._res_in_name = {t: n for t, n in mm.input_ports}
        self._res_out_name = {t: n for t, n in mm.output_ports}
        # alias each updated-state output onto its own input port (resident), then
        # seed the (now shared) buffers once - params to their masters, moments to 0.
        for out_t, in_t in alias:
            prog.share_buffer(0, self._res_out_name[out_t], 0, self._res_in_name[in_t])
        for entry in self._res_state:
            p = entry["p"]
            prog.set_input(self._res_in_name[p], p.attrs["value"].astype(np.float16))
            if self.optimizer == "adam":
                prog.set_input(self._res_in_name[entry["m_in"]], np.zeros(p.shape, np.float16))
                prog.set_input(self._res_in_name[entry["v_in"]], np.zeros(p.shape, np.float16))
        # the remaining inputs (neither state nor lr) are the data ports (x, target)
        state_ids = set()
        for entry in self._res_state:
            state_ids.add(id(entry["p"]))
            if self.optimizer == "adam":
                state_ids.update({id(entry["m_in"]), id(entry["v_in"])})
        self._res_data_inputs = [t for t, _ in mm.input_ports
                                 if id(t) not in state_ids and t is not lr_in]
        self._res_dirty = False

    def _resident_step(self) -> None:
        """One resident step: feed ONLY the minibatch + lr_t, execute. State ports
        are never written (they live on-device); the host does no tensor math and
        no state copy."""
        if getattr(self, "_ds", None) is not None:
            xin, X, tin, Y = self._ds
            idx = self._next_batch()
            self.data[xin] = X[idx]; self.data[tin] = Y[idx]
        prog = self._res.prog
        for t in self._res_data_inputs:
            prog.set_input(self._res_in_name[t], np.asarray(self.data[t], np.float16))
        if self.optimizer == "adam":
            self._t += 1
            lr_t = self.lr * math.sqrt(1.0 - self.b2 ** self._t) / (1.0 - self.b1 ** self._t)
        else:
            lr_t = self.lr / self.scale
        prog.set_input(self._res_in_name[self._res_lr], np.full((1, 1), lr_t, np.float16))
        prog.execute()
        self._res_dirty = True

    def _sync_params_from_device(self) -> None:
        """Checkpoint read: copy the resident params off the device into the host
        masters (for accuracy/loss via the forward program). Moments stay resident."""
        prog = self._res.prog
        for i, entry in enumerate(self._res_state):
            w = prog.read_output(self._res_out_name[entry["w_out"]]).astype(np.float32)
            if not np.isfinite(w).all():
                import warnings
                warnings.warn(
                    f"aneforge.Trainer: resident param {i} (shape {tuple(entry['p'].shape)}) read "
                    f"back non-finite values (inf/nan) at checkpoint - the on-device update was "
                    f"poisoned by an overflowed fp16 gradient; if you see inf/nan weight-grads, "
                    f"lower loss_scale.",
                    stacklevel=3)
            entry["p"].attrs["value"] = w.reshape(entry["p"].shape)
        self._res_dirty = False

    def _feed_update(self, net, p, extra):
        """Feed a per-param update program: the param leaf -> its master value;
        every other input -> the value mapped in `extra` (keyed by Tensor)."""
        vals = []
        for t in net._input_tensors:
            if t.attrs.get("trainable"):
                vals.append(t.attrs["value"].astype(np.float16))
            elif t in extra:
                vals.append(np.asarray(extra[t]).astype(np.float16))
            else:
                vals.append(np.asarray(t.attrs["value"]).astype(np.float16))   # baked value-input (e.g. norm gamma)
        return vals

    def _feed(self, model):
        # Map each compiled input Tensor to a value: trainable -> current master
        # value; provided data -> its data value; otherwise a baked value-input
        # carrying attrs['value'] (e.g. the gamma a norm VJP re-injects).
        # `model._input_tensors` is the ordered input list stored by compile
        # (matches the call signature).
        vals = []
        for t in model._input_tensors:
            if t.attrs.get("trainable"):
                vals.append(t.attrs["value"].astype(np.float16))
            elif t in self.data:
                vals.append(np.asarray(self.data[t]).astype(np.float16))
            else:
                vals.append(np.asarray(t.attrs["value"]).astype(np.float16))   # baked value-input (e.g. norm gamma)
        return vals

    def set_dataset(self, x_input, X_full, target_input, Y_onehot, seed: int = 0):
        """Provide the full dataset for mini-batch sampling. `x_input`/`target_input`
        are the batch-B graph input placeholders the objective was built from."""
        self._ds = (x_input, np.asarray(X_full, np.float32), target_input, np.asarray(Y_onehot, np.float32))
        self._ds_B = int(x_input.shape[0])
        self._ds_rng = np.random.default_rng(seed)
        self._ds_perm = self._ds_rng.permutation(len(self._ds[1]))
        self._ds_pos = 0

    def _next_batch(self):
        B = self._ds_B
        if self._ds_pos + B > len(self._ds_perm):       # reshuffle each epoch
            self._ds_perm = self._ds_rng.permutation(len(self._ds[1]))
            self._ds_pos = 0
        idx = self._ds_perm[self._ds_pos:self._ds_pos + B]
        self._ds_pos += B
        return idx

    def step(self) -> None:
        if self.resident_state:
            self._resident_step()
            return
        if getattr(self, "_ds", None) is not None:
            xin, X, tin, Y = self._ds
            idx = self._next_batch()
            self.data[xin] = X[idx]
            self.data[tin] = Y[idx]
        grads = [np.asarray(net(*self._feed(net))).reshape(p.shape)
                 for net, p in zip(self._bwd, self.params)]
        if not self.device_optimizer:
            self.opt.step(grads)
            return
        # On-ANE optimizer: the update arithmetic runs as graph ops. The host only
        # computes the scalar lr_t (folding loss-scale + Adam bias correction) and
        # shuttles state/grads in-out - no host tensor math.
        if self.optimizer == "adam":
            self._device_adam_step(grads)
        else:
            self._device_sgd_step(grads)

    def _device_sgd_step(self, grads):
        lr_t = self.lr / self.scale          # fold grad-unscale into lr_t
        lr_arr = np.full((1, 1), lr_t, np.float16)
        for i, p in enumerate(self.params):
            net, g_in, lr_in = self._upd[i], self._upd_g[i], self._upd_lr[i]
            extra = {g_in: grads[i], lr_in: lr_arr}
            w2 = np.asarray(net(*self._feed_update(net, p, extra))).reshape(p.shape)
            p.attrs["value"] = w2.astype(np.float32)

    def _device_adam_step(self, grads):
        self._t += 1
        # Host scalar lr_t = lr * sqrt(1-b2^t)/(1-b1^t) (bias correction folded in).
        # The loss-scale is not divided out here: Adam's update is the ratio
        # m/sqrt(v), and with the scaled grad g'=scale*g the moments scale as
        # m'=scale*m, v'=scale^2*v, so m'/sqrt(v') = m/sqrt(v) - the loss-scale
        # cancels in the ratio. Dividing lr_t by scale (a naive SGD-style unscale)
        # double-unscales and collapses the step. (eps is the only scale-sensitive
        # term and is negligible vs scale*sqrt(v).)
        lr_t = self.lr * math.sqrt(1.0 - self.b2 ** self._t) / (1.0 - self.b1 ** self._t)
        lr_arr = np.full((1, 1), lr_t, np.float16)
        for i, p in enumerate(self.params):
            net = self._upd[i]
            extra = {self._upd_g[i]: grads[i], self._upd_m[i]: self._m[i],
                     self._upd_v[i]: self._v[i], self._upd_lr[i]: lr_arr}
            out = np.asarray(net(*self._feed_update(net, p, extra)))
            w2, m2, v2 = _split3(out, p.shape)
            p.attrs["value"] = w2.astype(np.float32)
            self._m[i] = m2.astype(np.float16)       # fp16 optimizer state held host-side
            self._v[i] = v2.astype(np.float16)

    def accuracy(self, X, y_labels) -> float:
        """Argmax accuracy over X (any length) via the batch-B forward program,
        chunking X into B-row pieces (last padded then truncated)."""
        assert self.ce is not None, "accuracy() is for classification objectives"
        if self.resident_state and getattr(self, "_res_dirty", False):
            self._sync_params_from_device()
        X = np.asarray(X, np.float32); y = np.asarray(y_labels)
        # the feature input is the non-trainable, non-target input whose per-sample
        # shape matches X (handles 2-D [B,D] MLP and N-D [B,...] CNN inputs).
        feat_shape = X.shape[1:]
        xin = next(t for t in self._fwd._input_tensors
                   if not t.attrs.get("trainable") and t is not getattr(self.ce, "target", None)
                   and tuple(t.shape[1:]) == tuple(feat_shape))
        B = xin.shape[0]
        saved = self.data.get(xin)
        preds = []
        for s in range(0, X.shape[0], B):
            chunk = X[s:s + B]
            m = chunk.shape[0]
            if m < B:
                pad = np.zeros((B - m,) + tuple(feat_shape), np.float32)
                chunk = np.concatenate([chunk, pad], axis=0)
            self.data[xin] = chunk
            logits = np.asarray(self._fwd(*self._feed(self._fwd)))
            preds.append(logits[:m].argmax(1))
        if saved is not None:
            self.data[xin] = saved
        return float((np.concatenate(preds) == y).mean())

    def loss(self) -> float:
        if self.resident_state and getattr(self, "_res_dirty", False):
            self._sync_params_from_device()
        out = np.asarray(self._fwd(*self._feed(self._fwd)))
        if self.ce is None:
            return float(out.reshape(-1)[0])
        logits = out.reshape(self.ce.logits.shape)
        t = np.asarray(self.data[self.ce.target])
        z = logits - logits.max(1, keepdims=True)
        logsm = z - np.log(np.exp(z).sum(1, keepdims=True))
        return float(-(t * logsm).sum(1).mean())

    def release(self) -> None:
        self._fwd.release()
        if getattr(self, "_res", None) is not None:
            self._res.release()
        for net in getattr(self, "_bwd", []):
            net.release()
        for net in getattr(self, "_upd", []):
            net.release()

set_dataset

set_dataset(x_input, X_full, target_input, Y_onehot, seed: int = 0)

Provide the full dataset for mini-batch sampling. x_input/target_input are the batch-B graph input placeholders the objective was built from.
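The sampling that set_dataset prepares can be sketched on its own in NumPy. The class below is a hypothetical stand-alone version that mirrors the permutation logic of `_next_batch` in the Trainer source: it walks a shuffled index permutation in B-sized slices and reshuffles when fewer than B indices remain.

```python
import numpy as np

class EpochSampler:
    """Hypothetical stand-alone version of the permutation-based sampler."""

    def __init__(self, n_samples, batch_size, seed=0):
        self.n, self.B = n_samples, batch_size
        self.rng = np.random.default_rng(seed)
        self.perm = self.rng.permutation(self.n)
        self.pos = 0

    def next_batch(self):
        # Fewer than B indices left: reshuffle and start a new epoch.
        if self.pos + self.B > len(self.perm):
            self.perm = self.rng.permutation(self.n)
            self.pos = 0
        idx = self.perm[self.pos:self.pos + self.B]
        self.pos += self.B
        return idx

sampler = EpochSampler(n_samples=10, batch_size=4, seed=0)
batches = [sampler.next_batch() for _ in range(4)]
```

With 10 samples and a batch size of 4, each epoch yields two full batches and the remaining two indices of that permutation are skipped rather than padded, so every batch has exactly B rows.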

Source code in aneforge/autograd.py
def set_dataset(self, x_input, X_full, target_input, Y_onehot, seed: int = 0):
    """Provide the full dataset for mini-batch sampling. `x_input`/`target_input`
    are the batch-B graph input placeholders the objective was built from."""
    self._ds = (x_input, np.asarray(X_full, np.float32), target_input, np.asarray(Y_onehot, np.float32))
    self._ds_B = int(x_input.shape[0])
    self._ds_rng = np.random.default_rng(seed)
    self._ds_perm = self._ds_rng.permutation(len(self._ds[1]))
    self._ds_pos = 0

accuracy

accuracy(X, y_labels) -> float

Argmax accuracy over X (any length) via the batch-B forward program, chunking X into B-row pieces (last padded then truncated).
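The chunk, pad, and truncate pattern can be illustrated without the ANE. In the NumPy sketch below, `forward_fixed_batch` is a stand-in for a forward program compiled for exactly B rows (an assumption for illustration, here a plain matrix product); the last chunk is zero-padded up to B rows and only its first m predictions are kept.

```python
import numpy as np

def chunked_argmax(forward_fixed_batch, X, B):
    """Run a forward function that only accepts exactly B rows over any-length X."""
    X = np.asarray(X, np.float32)
    preds = []
    for s in range(0, X.shape[0], B):
        chunk = X[s:s + B]
        m = chunk.shape[0]
        if m < B:                                   # last chunk: pad with zero rows
            pad = np.zeros((B - m,) + X.shape[1:], np.float32)
            chunk = np.concatenate([chunk, pad], axis=0)
        logits = forward_fixed_batch(chunk)
        preds.append(logits[:m].argmax(1))          # drop predictions for the padding
    return np.concatenate(preds)

rng = np.random.default_rng(0)
B = 4
# Small integer values keep the float32 arithmetic exact in this toy model.
W = rng.integers(-3, 4, size=(5, 3)).astype(np.float32)

def forward_fixed_batch(x):
    assert x.shape[0] == B, "stand-in for a program compiled at batch size B"
    return x @ W

X = rng.integers(-3, 4, size=(10, 5)).astype(np.float32)
y = (X @ W).argmax(1)                               # labels taken from the same toy model
preds = chunked_argmax(forward_fixed_batch, X, B)
accuracy = float((preds == y).mean())
```

Because the padded rows are sliced off before the argmax results are collected, they never contribute to the accuracy.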

Source code in aneforge/autograd.py
def accuracy(self, X, y_labels) -> float:
    """Argmax accuracy over X (any length) via the batch-B forward program,
    chunking X into B-row pieces (last padded then truncated)."""
    assert self.ce is not None, "accuracy() is for classification objectives"
    if self.resident_state and getattr(self, "_res_dirty", False):
        self._sync_params_from_device()
    X = np.asarray(X, np.float32); y = np.asarray(y_labels)
    # the feature input is the non-trainable, non-target input whose per-sample
    # shape matches X (handles 2-D [B,D] MLP and N-D [B,...] CNN inputs).
    feat_shape = X.shape[1:]
    xin = next(t for t in self._fwd._input_tensors
               if not t.attrs.get("trainable") and t is not getattr(self.ce, "target", None)
               and tuple(t.shape[1:]) == tuple(feat_shape))
    B = xin.shape[0]
    saved = self.data.get(xin)
    preds = []
    for s in range(0, X.shape[0], B):
        chunk = X[s:s + B]
        m = chunk.shape[0]
        if m < B:
            pad = np.zeros((B - m,) + tuple(feat_shape), np.float32)
            chunk = np.concatenate([chunk, pad], axis=0)
        self.data[xin] = chunk
        logits = np.asarray(self._fwd(*self._feed(self._fwd)))
        preds.append(logits[:m].argmax(1))
    if saved is not None:
        self.data[xin] = saved
    return float((np.concatenate(preds) == y).mean())

UnrolledTrainer

Train with K Adam steps UNROLLED into ONE fused ANE program, so the whole forward -> backward -> optimizer-update recurrence runs on the engine with NO per-step host loop. Each step() runs K steps in a SINGLE dispatch: the host feeds K minibatches plus the K per-step learning rates and shuttles the (params, m, v) arrays in and out between K-step blocks. That shuttle is an array move, not tensor math, and there's no per-step host<->device round-trip inside the block.

The bounded-K, fully-on-engine analogue of Trainer (whose step() dispatches once per step). Enabled by the stop-gradient frontier in backward/backward_from: threading one step's updated-weight tensors into the next step's forward needs each step's gradient to treat the current weights as leaves (plain SGD/Adam), not differentiate through the previous update.
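A scalar example may help with what "treat the current weights as leaves" means. The sketch below uses plain SGD on a one-parameter quadratic loss rather than Adam on an ANE graph, purely as an illustration; all names and values are assumptions. It contrasts the ordinary training gradient at step 2 with the gradient obtained by differentiating through step 1's update.

```python
# Scalar model: loss(w) = 0.5 * (w - c)**2, so d loss / d w = (w - c).
c, lr = 3.0, 0.1
w0 = 0.0

def grad(w):
    return w - c

# Step 1 produces w1 from w0.
w1 = w0 - lr * grad(w0)

# Step 2, plain SGD: w1 is treated as a leaf, so the gradient is just grad(w1).
g_leaf = grad(w1)
w2 = w1 - lr * g_leaf

# Differentiating through step 1 instead gives the chain-rule product
# d loss(w1) / d w0 = grad(w1) * d w1 / d w0 = grad(w1) * (1 - lr).
g_through = grad(w1) * (1.0 - lr)

def unrolled_two_steps(w):
    # Two steps in one function; each step uses the gradient at its own current weight.
    for _ in range(2):
        w = w - lr * grad(w)
    return w
```

Here `g_leaf` is -2.7 while `g_through` is -2.43. An unrolled program that reproduces a per-step training loop needs the first quantity at every step, which is the role the text assigns to the stop-gradient frontier.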

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| params | list | trainable leaves (af.parameter / af.conv_param). | required |
| forward | callable | forward(P, x) -> output building the model from the current-step weight tensors P (a list aligned with params) and a data input tensor x. Used both per unrolled step and for predict. | required |
| kind | str | "ce" (output = logits, softmax-cross-entropy) or "mse". | required |
| x_inputs | list | list of K data input placeholders (af.input), one per step. | required |
| t_inputs | list | list of K target input placeholders, one per step. | required |
| dataset | tuple | (X, Y) numpy arrays; X already shaped like x_inputs[k] per sample (leading axis = sample), Y one-hot ([N, C]) for "ce". | required |
| resident | bool | if True (default) the optimizer state (params, m, v) stays RESIDENT on-device across dispatches - each updated-state output is aliased (share_buffer) onto its own input port, seeded once. The host then feeds ONLY the K minibatches + per-step lr and reads weights at checkpoints; nothing is shuttled in/out between dispatches. The end-to-end fully-on-engine path (no per-step loop AND no state move). | True |

Source code in aneforge/autograd.py
class UnrolledTrainer:
    """Train with `K` Adam steps UNROLLED into ONE fused ANE program, so the whole
    forward -> backward -> optimizer-update recurrence runs on the engine with NO
    per-step host loop. Each `step()` runs K steps in a SINGLE dispatch: the host
    feeds K minibatches plus the K per-step learning rates and shuttles the (params,
    m, v) arrays in and out between K-step blocks. That shuttle is an array move, not
    tensor math, and there's no per-step host<->device round-trip inside the block.

    The bounded-K, fully-on-engine analogue of `Trainer` (whose `step()` dispatches
    once per step). Enabled by the stop-gradient frontier in
    `backward`/`backward_from`: threading one step's updated-weight tensors into the
    next step's forward needs each step's gradient to treat the current weights as
    leaves (plain SGD/Adam), not differentiate through the previous update.

    Args:
      params (list): trainable leaves (`af.parameter` / `af.conv_param`).
      forward (callable): `forward(P, x) -> output` building the model from the
        current-step weight tensors `P` (a list aligned with `params`) and a data
        input tensor `x`. Used both per unrolled step and for `predict`.
      kind (str): `"ce"` (output = logits, softmax-cross-entropy) or `"mse"`.
      x_inputs (list): list of K data input placeholders (`af.input`), one per step.
      t_inputs (list): list of K target input placeholders, one per step.
      dataset (tuple): `(X, Y)` numpy arrays; `X` already shaped like `x_inputs[k]`
        per sample (leading axis = sample), `Y` one-hot ([N, C]) for `"ce"`.
      resident (bool): if True (default) the optimizer state (params, m, v) stays
        RESIDENT on-device across dispatches - each updated-state output is aliased
        (`share_buffer`) onto its own input port, seeded once. The host then feeds
        ONLY the K minibatches + per-step lr and reads weights at checkpoints;
        nothing is shuttled in/out between dispatches. The end-to-end fully-on-
        engine path (no per-step loop AND no state move).
    """
    def __init__(self, params, forward, kind, x_inputs, t_inputs, dataset, lr,
                 loss_scale: float = 1.0, betas=(0.9, 0.999), eps: float = 1e-8,
                 seed: int = 0, resident: bool = True):
        from . import _compile as _c
        if kind not in ("ce", "mse"):
            raise ValueError("kind must be 'ce' or 'mse'")
        self.params = list(params)
        self.forward = forward
        self.kind = kind
        self.K = len(x_inputs)
        self.lr = float(lr); self.scale = float(loss_scale)
        self.b1, self.b2 = float(betas[0]), float(betas[1]); self.eps = float(eps)
        self.X = np.asarray(dataset[0], np.float32)
        self.Y = np.asarray(dataset[1], np.float32)
        self.B = int(x_inputs[0].shape[0])
        self.m = [np.zeros(p.shape, np.float16) for p in self.params]
        self.v = [np.zeros(p.shape, np.float16) for p in self.params]
        self.t = 0
        self.rng = np.random.default_rng(seed)
        self._perm = self.rng.permutation(len(self.X)); self._pos = 0

        # build the unrolled graph: thread (P, m, v) through K Adam steps
        m_in = [graph.input(p.shape) for p in self.params]
        v_in = [graph.input(p.shape) for p in self.params]
        lr_ins = [graph.input((1, 1)) for _ in range(self.K)]
        P, M, V = list(self.params), list(m_in), list(v_in)
        for k in range(self.K):
            out = forward(P, x_inputs[k])
            if kind == "ce":
                g = backward_from(softmax_cross_entropy(out, t_inputs[k]).seed(self.scale), out, P)
            else:
                g = backward(mse(out, t_inputs[k]), P, loss_scale=self.scale)
            P, M, V = adam_step(P, M, V, g, lr_ins[k], (self.b1, self.b2), self.eps)
        self._net = _c.compile_multi([*P, *M, *V])
        self._oname = {t: n for t, n in self._net.output_ports}
        self._P_out, self._M_out, self._V_out = P, M, V
        self._m_in, self._v_in, self._lr_ins = m_in, v_in, lr_ins
        # map each data input tensor -> (step k, 'x'|'t') for feeding
        self._data_map = {}
        for k in range(self.K):
            self._data_map[id(x_inputs[k])] = (k, "x")
            self._data_map[id(t_inputs[k])] = (k, "t")

        # a separate single-batch forward program for checkpoint predict (on the ANE).
        # It uses its OWN weight-input leaves (fed the masters), NOT the trainable
        # params: compile mutates each Tensor's internal name, so sharing the param
        # objects between the two programs would clobber the training program's ports.
        ev_w = [graph.input(p.shape) for p in self.params]
        for ew, p in zip(ev_w, self.params):
            if "conv_shape" in p.attrs:
                ew.attrs["conv_shape"] = p.attrs["conv_shape"]
        xe = graph.input(x_inputs[0].shape)
        self._ev_w = ev_w
        self._eval = _c.compile(forward(ev_w, xe), _check_precision=False)

        self.resident = bool(resident)
        if self.resident:
            # Alias each final state OUTPUT onto its own initial INPUT port, so params/
            # m/v live on-device across dispatches; seed the shared buffers once.
            prog = self._net.prog
            inm = {id(t): n for t, n in self._net.input_ports}
            self._res_inm = inm
            self._res_lr_names = [inm[id(t)] for t in lr_ins]
            self._res_data = ([(inm[id(x_inputs[k])], k, "x") for k in range(self.K)] +
                              [(inm[id(t_inputs[k])], k, "t") for k in range(self.K)])
            pairs = (list(zip(self._P_out, self.params)) +
                     list(zip(self._M_out, m_in)) + list(zip(self._V_out, v_in)))
            for out_t, in_t in pairs:
                prog.share_buffer(0, self._oname[out_t], 0, inm[id(in_t)])
            for i, p in enumerate(self.params):
                prog.set_input(inm[id(p)], p.attrs["value"].astype(np.float16))
                prog.set_input(inm[id(m_in[i])], np.zeros(p.shape, np.float16))
                prog.set_input(inm[id(v_in[i])], np.zeros(p.shape, np.float16))
            self._res_dirty = False

    def _next(self):
        if self._pos + self.B > len(self._perm):
            self._perm = self.rng.permutation(len(self.X)); self._pos = 0
        idx = self._perm[self._pos:self._pos + self.B]; self._pos += self.B
        return idx

    def step(self) -> None:
        """Run K training steps on the ANE in ONE dispatch. Resident: feed only the K
        minibatches + per-step lr; state stays on-device. Else: shuttle params/m/v."""
        batches = [self._next() for _ in range(self.K)]
        if self.resident:
            prog = self._net.prog
            for name, k, which in self._res_data:
                idx = batches[k]
                prog.set_input(name, (self.X[idx] if which == "x" else self.Y[idx]).astype(np.float16))
            for k, name in enumerate(self._res_lr_names):
                gt = self.t + k + 1
                prog.set_input(name, np.full((1, 1), self.lr * math.sqrt(1.0 - self.b2 ** gt) /
                                             (1.0 - self.b1 ** gt), np.float16))
            prog.execute()
            self.t += self.K
            self._res_dirty = True
            return
        vals = []
        for t in self._net.input_tensors:
            if t.attrs.get("trainable"):
                vals.append(t.attrs["value"].astype(np.float16))
            elif t in self._m_in:
                vals.append(self.m[self._m_in.index(t)])
            elif t in self._v_in:
                vals.append(self.v[self._v_in.index(t)])
            elif t in self._lr_ins:
                gt = self.t + self._lr_ins.index(t) + 1          # global step for bias correction
                lr_t = self.lr * math.sqrt(1.0 - self.b2 ** gt) / (1.0 - self.b1 ** gt)
                vals.append(np.full((1, 1), lr_t, np.float16))
            else:
                k, which = self._data_map[id(t)]
                idx = batches[k]
                vals.append((self.X[idx] if which == "x" else self.Y[idx]).astype(np.float16))
        out = self._net(*vals)
        for i, p in enumerate(self.params):
            p.attrs["value"] = out[self._oname[self._P_out[i]]].reshape(p.shape)
            self.m[i] = out[self._oname[self._M_out[i]]].astype(np.float16).reshape(p.shape)
            self.v[i] = out[self._oname[self._V_out[i]]].astype(np.float16).reshape(p.shape)
        self.t += self.K

    def _sync_from_device(self) -> None:
        """Checkpoint read: copy the resident params off-device into the host masters."""
        prog = self._net.prog
        for i, p in enumerate(self.params):
            w = prog.read_output(self._oname[self._P_out[i]]).astype(np.float32)
            p.attrs["value"] = w.reshape(p.shape)
        self._res_dirty = False

    def predict(self, X) -> np.ndarray:
        """Run the trained weights forward on the ANE in B-sized chunks; returns the
        model output (logits for 'ce', prediction for 'mse')."""
        if self.resident and getattr(self, "_res_dirty", False):
            self._sync_from_device()
        X = np.asarray(X, np.float32)
        feeds_w = [p.attrs["value"].astype(np.float16) for p in self.params]
        outs = []
        for s in range(0, len(X), self.B):
            chunk = X[s:s + self.B]
            pad = self.B - len(chunk)
            if pad:
                chunk = np.concatenate([chunk, np.zeros((pad, *chunk.shape[1:]), np.float32)])
            # eval inputs are [weight leaves..., xe]; feed masters then the chunk
            args = []
            for t in self._eval._input_tensors:
                args.append(feeds_w[self._ev_w.index(t)] if t in self._ev_w
                            else chunk.astype(np.float16))
            o = np.asarray(self._eval(*args), np.float32)
            # `chunk` is already padded to B rows here, so trim by the pad count
            outs.append(o[:self.B - pad] if pad else o)
        return np.concatenate(outs)

    def release(self) -> None:
        self._net.release(); self._eval.release()

step

step() -> None

Run K training steps on the ANE in ONE dispatch. Resident: feed only the K minibatches + per-step lr; state stays on-device. Else: shuttle params/m/v.

Source code in aneforge/autograd.py
def step(self) -> None:
    """Run K training steps on the ANE in ONE dispatch. Resident: feed only the K
    minibatches + per-step lr; state stays on-device. Else: shuttle params/m/v."""
    batches = [self._next() for _ in range(self.K)]
    if self.resident:
        prog = self._net.prog
        for name, k, which in self._res_data:
            idx = batches[k]
            prog.set_input(name, (self.X[idx] if which == "x" else self.Y[idx]).astype(np.float16))
        for k, name in enumerate(self._res_lr_names):
            gt = self.t + k + 1
            prog.set_input(name, np.full((1, 1), self.lr * math.sqrt(1.0 - self.b2 ** gt) /
                                         (1.0 - self.b1 ** gt), np.float16))
        prog.execute()
        self.t += self.K
        self._res_dirty = True
        return
    vals = []
    for t in self._net.input_tensors:
        if t.attrs.get("trainable"):
            vals.append(t.attrs["value"].astype(np.float16))
        elif t in self._m_in:
            vals.append(self.m[self._m_in.index(t)])
        elif t in self._v_in:
            vals.append(self.v[self._v_in.index(t)])
        elif t in self._lr_ins:
            gt = self.t + self._lr_ins.index(t) + 1          # global step for bias correction
            lr_t = self.lr * math.sqrt(1.0 - self.b2 ** gt) / (1.0 - self.b1 ** gt)
            vals.append(np.full((1, 1), lr_t, np.float16))
        else:
            k, which = self._data_map[id(t)]
            idx = batches[k]
            vals.append((self.X[idx] if which == "x" else self.Y[idx]).astype(np.float16))
    out = self._net(*vals)
    for i, p in enumerate(self.params):
        p.attrs["value"] = out[self._oname[self._P_out[i]]].reshape(p.shape)
        self.m[i] = out[self._oname[self._M_out[i]]].astype(np.float16).reshape(p.shape)
        self.v[i] = out[self._oname[self._V_out[i]]].astype(np.float16).reshape(p.shape)
    self.t += self.K
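
The per-step learning rate that `step()` feeds folds Adam's bias correction into a single scalar, `lr * sqrt(1 - b2**t) / (1 - b1**t)`. The NumPy sketch below is illustrative only and is not the library's `_adam_update` (which is not shown on this page); `bias_corrected_lr`, `adam_reference` and `adam_folded` are hypothetical names. It checks that the folded form matches textbook Adam when epsilon is rescaled by `sqrt(1 - b2**t)`; how the library actually applies `eps` is an assumption.

```python
import math
import numpy as np

def bias_corrected_lr(lr, t, b1=0.9, b2=0.999):
    """Scalar fed per step: lr * sqrt(1 - b2^t) / (1 - b1^t), with t counted from 1."""
    return lr * math.sqrt(1.0 - b2 ** t) / (1.0 - b1 ** t)

def adam_reference(w, m, v, g, lr, t, b1=0.9, b2=0.999, eps=1e-8):
    """Textbook Adam with explicit bias-corrected moments."""
    m = b1 * m + (1.0 - b1) * g
    v = b2 * v + (1.0 - b2) * g * g
    m_hat = m / (1.0 - b1 ** t)
    v_hat = v / (1.0 - b2 ** t)
    return w - lr * m_hat / (np.sqrt(v_hat) + eps), m, v

def adam_folded(w, m, v, g, lr_t, b1=0.9, b2=0.999, eps_hat=1e-8):
    """Adam on the raw moments; the bias correction lives entirely in lr_t."""
    m = b1 * m + (1.0 - b1) * g
    v = b2 * v + (1.0 - b2) * g * g
    return w - lr_t * m / (np.sqrt(v) + eps_hat), m, v

rng = np.random.default_rng(0)
w_ref = w_fold = rng.normal(size=5)
m_ref = m_fold = np.zeros(5)
v_ref = v_fold = np.zeros(5)
for t in range(1, 4):
    g = rng.normal(size=5)
    w_ref, m_ref, v_ref = adam_reference(w_ref, m_ref, v_ref, g, 1e-2, t)
    # rescaled epsilon makes the two forms algebraically identical
    eps_hat = 1e-8 * math.sqrt(1.0 - 0.999 ** t)
    w_fold, m_fold, v_fold = adam_folded(
        w_fold, m_fold, v_fold, g, bias_corrected_lr(1e-2, t), eps_hat=eps_hat)
```

Feeding the corrected rate as a `(1, 1)` input keeps the step counter out of the compiled graph, which is presumably why the K rates are separate inputs rather than constants.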

predict

predict(X) -> np.ndarray

Run the trained weights forward on the ANE in B-sized chunks; returns the model output (logits for 'ce', prediction for 'mse').

Source code in aneforge/autograd.py
def predict(self, X) -> np.ndarray:
    """Run the trained weights forward on the ANE in B-sized chunks; returns the
    model output (logits for 'ce', prediction for 'mse')."""
    if self.resident and getattr(self, "_res_dirty", False):
        self._sync_from_device()
    X = np.asarray(X, np.float32)
    feeds_w = [p.attrs["value"].astype(np.float16) for p in self.params]
    outs = []
    for s in range(0, len(X), self.B):
        chunk = X[s:s + self.B]
        pad = self.B - len(chunk)
        if pad:
            chunk = np.concatenate([chunk, np.zeros((pad, *chunk.shape[1:]), np.float32)])
        # eval inputs are [weight leaves..., xe]; feed masters then the chunk
        args = []
        for t in self._eval._input_tensors:
            args.append(feeds_w[self._ev_w.index(t)] if t in self._ev_w
                        else chunk.astype(np.float16))
        o = np.asarray(self._eval(*args), np.float32)
        # `chunk` is already padded to B rows here, so trim by the pad count
        outs.append(o[:self.B - pad] if pad else o)
    return np.concatenate(outs)
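
The chunk-and-pad pattern that `predict` relies on (a program compiled for a fixed batch `B`, applied to an arbitrary number of rows) can be sketched in plain NumPy. This is a minimal illustration, not the library code: `predict_in_chunks` and `fixed_batch_model` are hypothetical stand-ins, and the "model" is just a row sum. The valid-row count is recorded before padding so that the padded rows are dropped from the result.

```python
import numpy as np

def predict_in_chunks(run_batch, X, batch_size):
    """Apply a fixed-batch function to X in batch_size chunks, zero-padding the tail."""
    X = np.asarray(X, np.float32)
    outs = []
    for s in range(0, len(X), batch_size):
        chunk = X[s:s + batch_size]
        n_valid = len(chunk)                     # remember before padding
        pad = batch_size - n_valid
        if pad:
            zeros = np.zeros((pad, *chunk.shape[1:]), np.float32)
            chunk = np.concatenate([chunk, zeros])
        out = np.asarray(run_batch(chunk), np.float32)
        outs.append(out[:n_valid])               # drop rows produced by padding
    return np.concatenate(outs)

def fixed_batch_model(chunk):
    """Stand-in for a compiled program that only accepts exactly 4 rows."""
    assert chunk.shape[0] == 4
    return chunk.sum(axis=1, keepdims=True)

X_demo = np.arange(21, dtype=np.float32).reshape(7, 3)
Y_demo = predict_in_chunks(fixed_batch_model, X_demo, batch_size=4)
```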

vjp

vjp(*names: str)

Register a vjp rule fn(node, g) -> list[grad|None] (one per node.srcs).

Source code in aneforge/autograd.py
def vjp(*names: str):
    """Register a vjp rule `fn(node, g) -> list[grad|None]` (one per node.srcs)."""
    def reg(fn):
        for n in names:
            VJP[n] = fn
        return fn
    return reg
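
As a rough illustration of how such a registry is used, the sketch below re-creates the decorator in isolation and registers a few toy rules. The op names, the `SimpleNamespace` node, and the scalar "tensors" are assumptions made for the example; the real rules operate on graph nodes and Tensors.

```python
from types import SimpleNamespace

VJP = {}  # op name -> rule(node, g) returning one gradient (or None) per source

def vjp(*names):
    """Register one rule under one or more op names."""
    def reg(fn):
        for n in names:
            VJP[n] = fn
        return fn
    return reg

@vjp("add")
def _add_vjp(node, g):
    # d(a + b)/da = d(a + b)/db = 1
    return [g, g]

@vjp("mul")
def _mul_vjp(node, g):
    a, b = node.srcs
    return [g * b, g * a]

@vjp("scale_by_const", "mul_const")  # hypothetical op names sharing one rule
def _scale_vjp(node, g):
    # the second source is a constant, so it receives no gradient (None)
    _, c = node.srcs
    return [g * c, None]

node = SimpleNamespace(op="mul", srcs=[3.0, 4.0])
grads = VJP[node.op](node, 2.0)  # upstream gradient g = 2.0
```

Returning `None` for a source is how a rule signals that no gradient flows to that input.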

parameter

parameter(init) -> Tensor

A trainable leaf: a graph input tagged trainable, holding an fp32 master value in attrs['value']. Used like any input; fed its current value each eval and updated by the optimizer.

Source code in aneforge/autograd.py
def parameter(init) -> Tensor:
    """A trainable leaf: a graph input tagged trainable, holding an fp32 master
    value in attrs['value']. Used like any input; fed its current value each
    eval and updated by the optimizer."""
    init = np.asarray(init, dtype=np.float32)
    t = graph.input(init.shape)
    t.attrs["trainable"] = True
    t.attrs["value"] = init
    return t
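
One reason to keep the master value in fp32 while feeding fp16 to the engine is that small updates vanish in fp16. The NumPy sketch below is an illustration of that effect, not library code; the update size and step count are arbitrary.

```python
import numpy as np

update = 1e-4  # smaller than half the fp16 spacing at 1.0 (2**-10, about 9.8e-4)

# Accumulating directly in fp16: every update rounds back to the same value.
w16 = np.float16(1.0)
for _ in range(100):
    w16 = np.float16(w16 + np.float16(update))

# Accumulating in an fp32 master, casting to fp16 only when feeding the device.
w32 = np.float32(1.0)
for _ in range(100):
    w32 = np.float32(w32 + np.float32(update))
fed = np.float16(w32)  # what the device would see after 100 steps
```

Here `w16` never moves, while the fp32 master reaches roughly 1.01 and the fp16 value fed to the device follows it.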

backward

backward(loss: Tensor, params, loss_scale: float = 1.0, stop=None) -> dict

Reverse-mode grads of scalar loss wrt each Tensor in params. Returns {param: grad_Tensor}. The seed dL/dloss = loss_scale (folded into the seed's additive constant, avoiding a muls on the reduced loss output).

stop is the stop-gradient (detach) frontier: gradient reaches these tensors but does not propagate past them. Defaults to params - a no-op when the params are true graph leaves (the usual case), but it matters when an UNROLLED training step threads one step's updated-weight TENSORS into the next step's forward: there each step's gradient must treat the current weights as leaves (plain SGD), not differentiate through the previous update (second-order).

Source code in aneforge/autograd.py
def backward(loss: Tensor, params, loss_scale: float = 1.0, stop=None) -> dict:
    """Reverse-mode grads of scalar `loss` wrt each Tensor in `params`. Returns
    {param: grad_Tensor}. The seed dL/dloss = loss_scale (folded into the seed's
    additive constant, avoiding a muls on the reduced loss output).

    `stop` is the stop-gradient (detach) frontier: gradient reaches these tensors
    but does not propagate past them. Defaults to `params` - a no-op when the
    params are true graph leaves (the usual case), but it matters when an UNROLLED
    training step threads one step's updated-weight TENSORS into the next step's
    forward: there each step's gradient must treat the current weights as leaves
    (plain SGD), not differentiate through the previous update (second-order)."""
    stop_ids = {id(t) for t in (params if stop is None else stop)}
    order = _topo(loss, stop_ids)
    return _reverse(order, {id(loss): _const_like(loss, float(loss_scale))}, params, stop_ids)
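
The effect of the stop frontier in an unrolled step can be seen on a scalar toy problem. The sketch below is plain Python arithmetic, not the library's graph machinery: with loss `0.5 * (w*x - y)**2` and one SGD update `w1 = w0 - lr * grad(w0)`, the second step's gradient either treats `w1` as a leaf (the frontier) or is chained through the update back to `w0` (the second-order path the frontier cuts off).

```python
def grad(w, x, y):
    """d/dw of the loss 0.5 * (w * x - y)**2."""
    return (w * x - y) * x

x, y, lr = 2.0, 1.0, 0.1
w0 = 0.0
w1 = w0 - lr * grad(w0, x, y)        # first unrolled step: w1 = 0.2

# With the stop frontier at w1: ordinary step-2 gradient, w1 is a leaf.
g_stop = grad(w1, x, y)

# Without it: the gradient keeps flowing through the update into w0.
# d(w1)/d(w0) = 1 - lr * x**2, which involves the loss's second derivative.
dw1_dw0 = 1.0 - lr * x * x
g_through = g_stop * dw1_dw0
```

Plain SGD or Adam wants `g_stop` applied to `w1`; `g_through` is a different quantity (a gradient with respect to `w0`).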

backward_from

backward_from(grad_root, root, params, stop=None) -> dict

Reverse-mode from an explicit gradient grad_root at root (e.g. logits), rather than from a scalar loss + ones-seed. stop is the stop-gradient frontier (defaults to params); see backward - it matters for unrolled training.

Source code in aneforge/autograd.py
def backward_from(grad_root, root, params, stop=None) -> dict:
    """Reverse-mode from an explicit gradient `grad_root` at `root` (e.g. logits),
    rather than from a scalar loss + ones-seed. `stop` is the stop-gradient frontier
    (defaults to `params`); see `backward` - it matters for unrolled training."""
    stop_ids = {id(t) for t in (params if stop is None else stop)}
    return _reverse(_topo(root, stop_ids), {id(root): grad_root}, params, stop_ids)
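
To make "an explicit gradient at the root" concrete, the NumPy sketch below seeds the backward pass at the logits with the fused softmax-cross-entropy form `(softmax(logits) - target) / N` and pushes it one step back through a linear layer. It is a hand-written illustration with hypothetical helper names, checked against central finite differences; it does not use the library.

```python
import numpy as np

def softmax(z):
    z = z - z.max(axis=-1, keepdims=True)   # shift for numerical stability
    e = np.exp(z)
    return e / e.sum(axis=-1, keepdims=True)

def ce_loss(W, X, T):
    """Mean cross-entropy of softmax(X @ W) against one-hot targets T."""
    p = softmax(X @ W)
    return -np.mean(np.sum(T * np.log(p), axis=-1))

rng = np.random.default_rng(0)
X = rng.normal(size=(4, 3))                      # [N, D] inputs
W = rng.normal(size=(3, 5))                      # [D, K] weights
T = np.eye(5)[rng.integers(0, 5, size=4)]        # [N, K] one-hot targets

# Explicit seed at the root (the logits): no scalar loss node is needed.
seed = (softmax(X @ W) - T) / X.shape[0]
dW = X.T @ seed                                  # VJP of logits = X @ W wrt W

# Central finite differences on the scalar loss, for comparison.
eps = 1e-6
dW_num = np.zeros_like(W)
for i in range(W.shape[0]):
    for j in range(W.shape[1]):
        Wp, Wm = W.copy(), W.copy()
        Wp[i, j] += eps
        Wm[i, j] -= eps
        dW_num[i, j] = (ce_loss(Wp, X, T) - ce_loss(Wm, X, T)) / (2 * eps)
```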

conv_param

conv_param(weight_init) -> Tensor

A trainable conv weight parameter. weight_init is [Cout, Cin, kH, kW] (PyTorch conv layout); stored internally as the flat patch matrix [Cin*kH*kW, Cout] that conv2d consumes. The patch (row) order is ci*(kH*kW) + (u*kW + v), matching the im2col built in conv2d.

Source code in aneforge/autograd.py
def conv_param(weight_init) -> Tensor:
    """A trainable conv weight parameter. `weight_init` is [Cout, Cin, kH, kW]
    (PyTorch conv layout); stored internally as the flat patch matrix
    [Cin*kH*kW, Cout] that `conv2d` consumes. The patch (row) order is
    `ci*(kH*kW) + (u*kW + v)`, matching the im2col built in `conv2d`."""
    W = np.asarray(weight_init, dtype=np.float32)
    Cout, Cin, kH, kW = W.shape
    flat = W.reshape(Cout, Cin * kH * kW).T.copy()        # [Cin*kH*kW, Cout]
    p = parameter(flat)
    p.attrs["conv_shape"] = (Cout, Cin, kH, kW)
    return p
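
The flat layout and its row order can be checked directly in NumPy. The sketch below repeats the same reshape-and-transpose on a small weight and confirms that row `ci*(kH*kW) + (u*kW + v)`, column `co` holds `W[co, ci, u, v]`; the sizes are arbitrary and `patch_row` is a hypothetical helper.

```python
import numpy as np

Cout, Cin, kH, kW = 2, 3, 2, 2
W = np.arange(Cout * Cin * kH * kW, dtype=np.float32).reshape(Cout, Cin, kH, kW)

# Same flattening as conv_param: [Cout, Cin*kH*kW] transposed to [Cin*kH*kW, Cout].
flat = W.reshape(Cout, Cin * kH * kW).T.copy()

def patch_row(ci, u, v):
    """Row of the flat matrix that holds kernel tap (u, v) of input channel ci."""
    return ci * (kH * kW) + (u * kW + v)

layout_ok = all(
    flat[patch_row(ci, u, v), co] == W[co, ci, u, v]
    for co in range(Cout) for ci in range(Cin)
    for u in range(kH) for v in range(kW)
)

# Recovering the PyTorch-layout weight from the flat matrix is the inverse reshape.
W_back = flat.T.reshape(Cout, Cin, kH, kW)
```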

conv2d

conv2d(x: Tensor, weight: Tensor, stride: int = 1, pad: int = 0) -> Tensor

A trainable stride-1 2-D conv built from primitives so weight is a real graph parameter (see conv_param). x is [N, Cin, H, W]; weight is a conv_param (flat [Cin*kH*kW, Cout], carrying conv_shape). Returns [N, Cout, Hout, Wout]. stride must be 1 (strided slicing is unavailable); pad >= 0 zero-pads H and W IN-GRAPH (a zero-border concat before the im2col), so a 'same' conv stays inside one fused program and the padding differentiates through the existing concat VJP. With pad=0 the behaviour is byte-for-byte the previous implementation.

COMPILE SCALES WITH BATCH N: the im2col materialises [N, Cin*kH*kW, Hout*Wout] tensors, so the compile (tiling/partition) time grows with N - on M1/h13 a very large full batch (e.g. N~1000 over 28x28) can take minutes or hang the compiler (M5 compiles it fine). Train large datasets in MINI-BATCHES (a modest N, e.g. <=128, fed per step) rather than one full-batch graph.

Source code in aneforge/autograd.py
def conv2d(x: Tensor, weight: Tensor, stride: int = 1, pad: int = 0) -> Tensor:
    """A trainable stride-1 2-D conv built from primitives so `weight` is a real
    graph parameter (see `conv_param`). `x` is [N, Cin, H, W]; `weight` is a
    `conv_param` (flat [Cin*kH*kW, Cout], carrying `conv_shape`). Returns
    [N, Cout, Hout, Wout]. `stride` must be 1 (strided slicing is unavailable);
    `pad` >= 0 zero-pads H and W IN-GRAPH (a zero-border `concat` before the
    im2col), so a 'same' conv stays inside one fused program and the padding
    differentiates through the existing concat VJP. With `pad=0` the behaviour is
    byte-for-byte the previous implementation.

    COMPILE SCALES WITH BATCH N: the im2col materialises [N, Cin*kH*kW, Hout*Wout]
    tensors, so the *compile* (tiling/partition) time grows with N - on M1/h13 a
    very large full batch (e.g. N~1000 over 28x28) can take minutes or hang the
    compiler (M5 compiles it fine). Train large datasets in MINI-BATCHES (a modest
    N, e.g. <=128, fed per step) rather than one full-batch graph."""
    if stride != 1:
        raise NotImplementedError("conv2d (trainable) supports stride=1 only; "
                                  "downsample with avg_pool/max_pool.")
    if pad < 0:
        raise ValueError(f"conv2d: pad must be >= 0, got {pad}")
    if "conv_shape" not in weight.attrs:
        raise ValueError("conv2d weight must come from af.conv_param([Cout,Cin,kH,kW])")
    N, Cin, H, W = x.shape
    Cout, Cin_w, kH, kW = weight.attrs["conv_shape"]
    if Cin_w != Cin:
        raise ValueError(f"conv2d: weight Cin {Cin_w} != input Cin {Cin}")
    if pad:
        # In-graph zero padding: concat zero-constant borders onto H (axis 2) then W
        # (axis 3). The constant is a non-trainable input leaf; concat's VJP emits a
        # grad slice for it but it's discarded (not a parameter), adding no
        # trainable VJP.
        zh = graph.input((N, Cin, pad, W)); zh.attrs["value"] = np.zeros((N, Cin, pad, W), np.float32)
        x = graph.concat([zh, x, zh], axis=2)            # [N, Cin, H+2pad, W]
        H = H + 2 * pad
        zw = graph.input((N, Cin, H, pad)); zw.attrs["value"] = np.zeros((N, Cin, H, pad), np.float32)
        x = graph.concat([zw, x, zw], axis=3)            # [N, Cin, H+2pad, W+2pad]
        W = W + 2 * pad
    Hout, Wout = H - kH + 1, W - kW + 1
    L, K = Hout * Wout, Cin * kH * kW
    parts = []
    for u in range(kH):
        for v in range(kW):
            # patch index goes on axis 2 (NOT the last axis): the concat's backward is a
            # slice along that axis, and the h13 x16 crop-DMA saturation (>4094 -> +/-inf) fires
            # ONLY on a nonzero begin-offset of the LAST (width) axis. With the patches on a
            # non-last axis, the large loss-scaled input-gradient never transits the saturating
            # slice, so multi-layer conv training is numerically correct on M1 at any loss_scale.
            # (The forward x-slice still carries the small input; the transpose below moves K to
            # the last axis, and transpose's backward is a pure permute with no slice.)
            parts.append(x.slice_by_size([0, 0, u, v], [N, Cin, Hout, Wout]).reshape(N, Cin, 1, L))
    patches = graph.concat(parts, axis=2).transpose([0, 3, 1, 2]).reshape(N, L, K)   # [N,L,K]
    y = patches @ weight.reshape(1, K, Cout)                  # broadcast bmm -> [N,L,Cout]
    return y.transpose([0, 2, 1]).reshape(N, Cout, Hout, Wout)
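
The slice, concat, transpose, matmul sequence above can be mirrored in NumPy to see that it computes an ordinary stride-1 convolution (cross-correlation, as in PyTorch). The sketch below is a host-side reference under that assumption, with hypothetical function names; it uses `np.pad` where the graph version concatenates zero borders, and compares against a direct tap-by-tap accumulation.

```python
import numpy as np

def conv2d_im2col(x, w, pad=0):
    """x: [N, Cin, H, W], w: [Cout, Cin, kH, kW] -> [N, Cout, Hout, Wout]."""
    N, Cin, H, W = x.shape
    Cout, Cin_w, kH, kW = w.shape
    assert Cin == Cin_w
    flat = w.reshape(Cout, Cin * kH * kW).T              # [K, Cout], as in conv_param
    if pad:
        x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
        H, W = H + 2 * pad, W + 2 * pad
    Hout, Wout = H - kH + 1, W - kW + 1
    L, K = Hout * Wout, Cin * kH * kW
    # one shifted window per kernel tap, stacked on axis 2 (the patch axis)
    parts = [x[:, :, u:u + Hout, v:v + Wout].reshape(N, Cin, 1, L)
             for u in range(kH) for v in range(kW)]
    patches = np.concatenate(parts, axis=2)              # [N, Cin, kH*kW, L]
    patches = patches.transpose(0, 3, 1, 2).reshape(N, L, K)
    y = patches @ flat                                   # [N, L, Cout]
    return y.transpose(0, 2, 1).reshape(N, Cout, Hout, Wout)

def conv2d_direct(x, w, pad=0):
    """Reference: accumulate one kernel tap at a time."""
    if pad:
        x = np.pad(x, ((0, 0), (0, 0), (pad, pad), (pad, pad)))
    N, Cin, H, W = x.shape
    Cout, _, kH, kW = w.shape
    Hout, Wout = H - kH + 1, W - kW + 1
    y = np.zeros((N, Cout, Hout, Wout))
    for u in range(kH):
        for v in range(kW):
            window = x[:, :, u:u + Hout, v:v + Wout]
            y += np.einsum("nchw,oc->nohw", window, w[:, :, u, v])
    return y

rng = np.random.default_rng(0)
x_demo = rng.normal(size=(2, 3, 5, 6))
w_demo = rng.normal(size=(4, 3, 3, 2))
```

The `[N, L, K]` patch tensor is the one whose size scales with the batch, which is the quantity the compile-time note above refers to.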

mse

mse(y: Tensor, target: Tensor) -> Tensor

Mean squared error over all axes (a scalar loss).

Source code in aneforge/autograd.py
def mse(y: Tensor, target: Tensor) -> Tensor:
    """Mean squared error over all axes (a scalar loss)."""
    diff = y - target
    return diff.square().mean(tuple(range(len(y.shape))))

adam_step

adam_step(params, m, v, grads: dict, lr_t, betas=(0.9, 0.999), eps: float = 1e-08)

One Adam update as graph ops over lists params/m/v (grads keyed by param), returning the new (params, m, v) TENSOR lists. Used to UNROLL K training steps into one program: thread the returned lists into the next step's forward. Propagates a conv_param weight's conv_shape onto the updated tensor so an advanced conv weight still works as a conv2d weight in the next unrolled step.

Source code in aneforge/autograd.py
def adam_step(params, m, v, grads: dict, lr_t, betas=(0.9, 0.999), eps: float = 1e-8):
    """One Adam update as graph ops over lists `params`/`m`/`v` (grads keyed by
    param), returning the new (params, m, v) TENSOR lists. Used to UNROLL K training
    steps into one program: thread the returned lists into the next step's forward.
    Propagates a `conv_param` weight's `conv_shape` onto the updated tensor so an
    advanced conv weight still works as a `conv2d` weight in the next unrolled step."""
    b1, b2 = betas
    nP, nM, nV = [], [], []
    for p, mi, vi in zip(params, m, v):
        w2, m2, v2 = _adam_update(p, mi, vi, grads[p], lr_t, b1, b2, eps)
        if "conv_shape" in p.attrs:
            w2.attrs["conv_shape"] = p.attrs["conv_shape"]
        nP.append(w2); nM.append(m2); nV.append(v2)
    return nP, nM, nV

Layer-streamed training

streaming

Layer-streamed (gradient-checkpointed) training for deep stacks of identical layers.

A monolithic compile fuses a model's whole forward, backward, and optimizer step into ONE e5rt program, so compile time grows superlinearly with depth and caps how deep a model can train. When the layers are structurally identical (a transformer stack, a deep MLP), that cost is avoidable: the per-layer forward and backward each depend only on one layer's shape, not the depth, so they are compiled ONCE and reused for every layer. CheckpointedStack does exactly that.

The backward is the standard gradient-checkpointing trick: store only each layer's INPUT activation, not every intermediate, and recompute the layer's forward inside its backward program. The reused backward program takes a layer's params, its checkpointed input, and the upstream gradient, and returns the param gradients plus the gradient with respect to the input (the upstream gradient for the layer below). The result is bit-identical to a monolithic backward (verified), with total compile work independent of layer count.
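
The recompute-in-backward scheme can be sketched on the host with NumPy. This is an illustration under simplifying assumptions (a `tanh(x @ W + b)` layer, fp64 arithmetic, hypothetical function names), not the compiled programs themselves: the forward keeps only each layer's input, and the single reusable per-layer backward recomputes the layer before backpropagating through it.

```python
import numpy as np

def layer_forward(params, x):
    W, b = params
    return np.tanh(x @ W + b)

def layer_backward(params, x_in, g_out):
    """Recompute the layer from its checkpointed input, then backprop through it."""
    W, b = params
    y = np.tanh(x_in @ W + b)          # recomputed here, never stored by the forward
    g_pre = g_out * (1.0 - y * y)      # through tanh
    g_W = x_in.T @ g_pre
    g_b = g_pre.sum(axis=0)
    g_in = g_pre @ W.T                 # upstream gradient for the layer below
    return [g_W, g_b], g_in

def stack_forward(layers_params, x0):
    x, checkpoints = x0, []
    for lp in layers_params:
        checkpoints.append(x)          # store only the layer INPUT
        x = layer_forward(lp, x)
    return x, checkpoints

def stack_backward(layers_params, checkpoints, g_out):
    g = g_out
    param_grads = [None] * len(layers_params)
    for i in range(len(layers_params) - 1, -1, -1):
        param_grads[i], g = layer_backward(layers_params[i], checkpoints[i], g)
    return param_grads, g

rng = np.random.default_rng(0)
D = 4
layers = [[rng.normal(scale=0.5, size=(D, D)), rng.normal(scale=0.1, size=D)]
          for _ in range(3)]
x0 = rng.normal(size=(2, D))
out, ckpts = stack_forward(layers, x0)
# gradient of loss = out.sum() is all ones at the stack output
param_grads, g_in = stack_backward(layers, ckpts, np.ones_like(out))
```

Memory for activations is one tensor per layer, at the cost of one extra forward per layer during the backward pass.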

This module compiles the repeated stack; the surrounding embedding and output stages are ordinary compiled graphs the caller drives (each compiled once). The optimizer runs host-side over the streamed gradients, like the default autograd.Trainer path.

CheckpointedStack

A depth-independent compile for a stack of identical layers.

layer_fn(params, x) builds one layer: params is a list of graph Tensor parameters and x is the input activation Tensor; it returns the output activation Tensor (same shape as x). example_params is a list of numpy arrays giving one layer's parameter shapes, and io_shape is the activation shape that flows between layers.

Two programs are compiled: the per-layer forward and the per-layer backward (a multi-output program returning each param gradient and the input gradient). Both are reused for every layer, so compile cost does not grow with depth.

Source code in aneforge/streaming.py
class CheckpointedStack:
    """A depth-independent compile for a stack of identical layers.

    `layer_fn(params, x)` builds one layer: `params` is a list of graph `Tensor`
    parameters and `x` is the input activation `Tensor`; it returns the output
    activation `Tensor` (same shape as `x`). `example_params` is a list of numpy
    arrays giving one layer's parameter shapes, and `io_shape` is the activation shape
    that flows between layers.

    Two programs are compiled: the per-layer forward and the per-layer backward (a
    multi-output program returning each param gradient and the input gradient). Both are
    reused for every layer, so compile cost does not grow with depth.
    """

    def __init__(self, layer_fn, example_params, io_shape):
        self.io_shape = tuple(io_shape)
        self._nparam = len(example_params)

        # per-layer forward: y = layer_fn(params, x)
        self._x = _g.input(self.io_shape)
        self._p = [_ag.parameter(np.asarray(p, np.float32)) for p in example_params]
        y = layer_fn(self._p, self._x)
        if tuple(y.shape) != self.io_shape:
            raise ValueError(f"layer_fn output shape {y.shape} != io_shape {self.io_shape}")
        self._fwd = _compile(y)

        # per-layer backward: given the upstream gradient at the output, return the
        # param gradients and the input gradient (recompute-in-backward checkpointing).
        self._xb = _g.input(self.io_shape)
        self._pb = [_ag.parameter(np.asarray(p, np.float32)) for p in example_params]
        self._gout = _g.input(self.io_shape)
        yb = layer_fn(self._pb, self._xb)
        grads = _ag.backward_from(self._gout, yb, [*self._pb, self._xb])
        self._g_param = [grads[p] for p in self._pb]
        self._g_in = grads[self._xb]
        self._bwd = _compile_multi([*self._g_param, self._g_in])
        self._bwd_in = {id(t): n for t, n in self._bwd.input_ports}
        self._bwd_out = {t: n for t, n in self._bwd.output_ports}
        # baked-constant input ports (e.g. a causal mask) carried into the backward graph
        _fed = {id(self._xb), id(self._gout), *(id(p) for p in self._pb)}
        self._bwd_consts = [(t, n) for t, n in self._bwd.input_ports if id(t) not in _fed]

    def forward(self, layers_params, x0):
        """Run the stack. `layers_params` is a list (per layer) of lists (per-layer
        parameter numpy arrays). Returns `(output, checkpoints)` where `checkpoints[i]`
        is the input activation to layer `i` (needed by `backward`)."""
        x = np.asarray(x0, np.float32)
        checkpoints = []
        for lp in layers_params:
            checkpoints.append(x)
            feed = {id(self._x): x.astype(_F16)}
            for t, v in zip(self._p, lp):
                feed[id(t)] = np.asarray(v, _F16)
            # any other input port is a baked constant (e.g. a causal mask): feed its value
            vals = [feed[id(t)] if id(t) in feed else np.asarray(t.attrs["value"], _F16)
                    for t in self._fwd._input_tensors]
            x = np.asarray(self._fwd(*vals), np.float32)
        return x, checkpoints

    def backward(self, layers_params, checkpoints, g_out):
        """Backprop the stack. `g_out` is the gradient at the stack output. Returns
        `(param_grads, g_in)`: `param_grads[i]` is the list of gradients for layer
        `i`'s params, and `g_in` is the gradient at the stack input."""
        g = np.asarray(g_out, np.float32)
        param_grads = [None] * len(layers_params)
        for i in range(len(layers_params) - 1, -1, -1):
            self._bwd.prog.set_input(self._bwd_in[id(self._gout)], g.astype(_F16))
            self._bwd.prog.set_input(self._bwd_in[id(self._xb)], checkpoints[i].astype(_F16))
            for t, v in zip(self._pb, layers_params[i]):
                self._bwd.prog.set_input(self._bwd_in[id(t)], np.asarray(v, _F16))
            for t, n in self._bwd_consts:                   # baked constants (e.g. mask)
                self._bwd.prog.set_input(n, np.asarray(t.attrs["value"], _F16))
            self._bwd.prog.execute()
            param_grads[i] = [np.asarray(self._bwd.prog.read_output(self._bwd_out[gp]), np.float32)
                              for gp in self._g_param]
            g = np.asarray(self._bwd.prog.read_output(self._bwd_out[self._g_in]), np.float32)
        return param_grads, g

    def release(self):
        self._fwd.release()
        self._bwd.release()

forward

forward(layers_params, x0)

Run the stack. layers_params is a list (per layer) of lists (per-layer parameter numpy arrays). Returns (output, checkpoints) where checkpoints[i] is the input activation to layer i (needed by backward).

Source code in aneforge/streaming.py
def forward(self, layers_params, x0):
    """Run the stack. `layers_params` is a list (per layer) of lists (per-layer
    parameter numpy arrays). Returns `(output, checkpoints)` where `checkpoints[i]`
    is the input activation to layer `i` (needed by `backward`)."""
    x = np.asarray(x0, np.float32)
    checkpoints = []
    for lp in layers_params:
        checkpoints.append(x)
        feed = {id(self._x): x.astype(_F16)}
        for t, v in zip(self._p, lp):
            feed[id(t)] = np.asarray(v, _F16)
        # any other input port is a baked constant (e.g. a causal mask): feed its value
        vals = [feed[id(t)] if id(t) in feed else np.asarray(t.attrs["value"], _F16)
                for t in self._fwd._input_tensors]
        x = np.asarray(self._fwd(*vals), np.float32)
    return x, checkpoints

backward

backward(layers_params, checkpoints, g_out)

Backprop the stack. g_out is the gradient at the stack output. Returns (param_grads, g_in): param_grads[i] is the list of gradients for layer i's params, and g_in is the gradient at the stack input.

Source code in aneforge/streaming.py
def backward(self, layers_params, checkpoints, g_out):
    """Backprop the stack. `g_out` is the gradient at the stack output. Returns
    `(param_grads, g_in)`: `param_grads[i]` is the list of gradients for layer
    `i`'s params, and `g_in` is the gradient at the stack input."""
    g = np.asarray(g_out, np.float32)
    param_grads = [None] * len(layers_params)
    for i in range(len(layers_params) - 1, -1, -1):
        self._bwd.prog.set_input(self._bwd_in[id(self._gout)], g.astype(_F16))
        self._bwd.prog.set_input(self._bwd_in[id(self._xb)], checkpoints[i].astype(_F16))
        for t, v in zip(self._pb, layers_params[i]):
            self._bwd.prog.set_input(self._bwd_in[id(t)], np.asarray(v, _F16))
        for t, n in self._bwd_consts:                   # baked constants (e.g. mask)
            self._bwd.prog.set_input(n, np.asarray(t.attrs["value"], _F16))
        self._bwd.prog.execute()
        param_grads[i] = [np.asarray(self._bwd.prog.read_output(self._bwd_out[gp]), np.float32)
                          for gp in self._g_param]
        g = np.asarray(self._bwd.prog.read_output(self._bwd_out[self._g_in]), np.float32)
    return param_grads, g
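
Since the optimizer runs host-side over the streamed gradients, the nested `param_grads` structure returned by `backward` maps directly onto a per-layer update loop. The sketch below is a minimal fp32 SGD under stated assumptions: `sgd_update` is a hypothetical helper, the parameters are fp32 NumPy arrays updated in place, and `loss_scale` only matters if the caller scaled `g_out` before calling `backward`.

```python
import numpy as np

def sgd_update(layers_params, param_grads, lr, loss_scale=1.0):
    """In-place fp32 SGD over nested [layer][param] lists.

    Gradients are divided by loss_scale first, undoing any scaling the caller
    applied to the gradient at the stack output.
    """
    for layer_p, layer_g in zip(layers_params, param_grads):
        for p, g in zip(layer_p, layer_g):
            p -= (lr / loss_scale) * np.asarray(g, np.float32)

# Two layers, each with a weight and a bias (shapes are arbitrary).
params_demo = [[np.ones((2, 2), np.float32), np.zeros(2, np.float32)]
               for _ in range(2)]
# Pretend these came back from backward() with the output gradient scaled by 128.
grads_demo = [[np.full((2, 2), 128.0, np.float32), np.full(2, 64.0, np.float32)]
              for _ in range(2)]
sgd_update(params_demo, grads_demo, lr=0.1, loss_scale=128.0)
```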