Spectral Compact Training (Reimplementation)
- Animated walkthrough
- 1. The problem SCT is built to solve
- 2. The core idea in one equation
- 3. Counting the savings
- 4. The forward pass: three small matmuls
- 5. The backward pass: gradients on the factors, not on W
- 6. The Stiefel retraction: keeping U and V orthonormal
- 7. Putting one training step together
- 8. What I observed when I reimplemented it
- 9. Honest limitations
- 10. Why I find this interesting
- Citation
Synopsis. SCT is a training method that stores every weight matrix of a neural network as a truncated SVD (three small factors U, s, V) instead of one large dense matrix W. Gradients are computed directly with respect to those factors, so the dense matrix is never built, never differentiated, and never stored. After each Adam step, the orthonormal factors U and V are pulled back onto the Stiefel manifold via a cheap QR decomposition with sign correction. At rank 32, this collapses the per-MLP-layer training footprint of a 70-billion-parameter architecture from ~3.8 GB to ~19 MB, a 199× reduction that is mathematically exact, and which lets a full forward+backward+optimizer+retraction step for a LLaMA-3-70B-shaped network fit inside 8 GB of memory on a Steam Deck. SCT was proposed by Björn Roman Kohlberger in 2026 (Irish Patent Application PTIE20260000000219). This write-up is my reimplementation study of the method.
Animated walkthrough
A ~85-second visual walkthrough of how the truncated-SVD parameterisation, the three small matmuls in the forward pass, and the Stiefel retraction fit together. The written deep-dive below covers the same material in more detail.
1. The problem SCT is built to solve
Training a large language model is bound less by compute than by memory bandwidth and capacity. For every weight matrix W ∈ ℝ^{m×n} you actually want to update, the Adam optimizer needs four buffers of the same shape:
| buffer | shape | what it stores |
|---|---|---|
| W | m × n | the weights themselves |
| ∇W | m × n | the gradient |
| m₁ (Adam) | m × n | first moment: running mean of the gradient |
| m₂ (Adam) | m × n | second moment: running mean of the squared gradient |
That is 4 · m · n floating-point numbers per layer. In FP32 that is 16 bytes per parameter. For a 70-billion-parameter dense model in FP32 with Adam, you are looking at roughly 1,245 GB just for weights + gradients + optimizer state. You cannot fit that on a single GPU. You cannot even fit it on a small cluster without sharding and offloading.
The standard responses to this are:
- Mixed-precision / FP8 / quantised Adam. Shrinks each number but keeps the same shapes.
- LoRA. Freezes W, learns small adapters A and B. The full W still sits in memory.
- GaLore. Computes the dense gradient, then projects it into a low-rank subspace before applying it.
None of those change the underlying object: there is still a dense m × n matrix at the heart of the layer.
SCT changes the underlying object. The weight matrix never exists in dense form. It is stored, used, and updated as three small pieces.
2. The core idea in one equation
Every weight matrix is represented as its rank-k truncated singular value decomposition:
W = U · diag(s) · Vᵀ
with shapes
U ∈ ℝ^{m × k} (orthonormal columns: UᵀU = I_k)
s ∈ ℝ^{k} (the kept singular values, all ≥ 0)
V ∈ ℝ^{n × k} (orthonormal columns: VᵀV = I_k)
and a rank budget k ≪ min(m, n). For a LLaMA-70B MLP layer with m=8192, n=28672, a typical choice is k=32. That is enough to capture the dominant directions of variation in the layer while collapsing the parameter count by orders of magnitude.
The three factors are the parameters. There is no shadow copy of W anywhere. If you want to apply the layer to a batch, you apply the factors. If you want to update the layer, you update the factors.
That single decision is the whole method. Everything else — the forward pass, the backward pass, the optimizer step, the manifold retraction — is the consequence of taking it seriously.
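To make the shapes and constraints concrete, here is a minimal NumPy sketch (NumPy standing in for the PyTorch used elsewhere in this write-up). It builds the three factors from a small random dense matrix purely for illustration; SCT itself never materialises W. The helper name `truncated_svd` and the toy sizes are my own choices, not part of the method description.

```python
import numpy as np

def truncated_svd(W, k):
    """Return (U, s, V) with U: [m, k], s: [k], V: [n, k]."""
    U_full, s_full, Vt_full = np.linalg.svd(W, full_matrices=False)
    return U_full[:, :k], s_full[:k], Vt_full[:k, :].T

rng = np.random.default_rng(0)
m, n, k = 64, 96, 8                  # toy sizes for illustration only
W = rng.standard_normal((m, n))      # toy dense matrix; SCT never builds one

U, s, V = truncated_svd(W, k)

# Orthonormality of the two factors: U^T U = I_k and V^T V = I_k.
orth_err_U = np.linalg.norm(U.T @ U - np.eye(k))
orth_err_V = np.linalg.norm(V.T @ V - np.eye(k))

# Rank-k reconstruction, formed here only to measure the approximation error.
W_k = (U * s) @ V.T
approx_err = np.linalg.norm(W - W_k)

# Eckart-Young: the Frobenius error equals the norm of the discarded singular values.
discarded = np.linalg.svd(W, compute_uv=False)[k:]
eckart_young_err = np.sqrt(np.sum(discarded ** 2))

print(U.shape, s.shape, V.shape)
print(orth_err_U, orth_err_V)
print(approx_err, eckart_young_err)
```

A random Gaussian matrix has a slowly decaying spectrum, so the rank-8 error here is large; the sketch is about the structure of the factors, not about how well a low rank fits real weights.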
3. Counting the savings
The parameter count for a single MLP layer:
| | Dense + Adam | SCT + Adam (rank k) |
|---|---|---|
| weights | m·n | k(m+n+1) |
| gradient | m·n | k(m+n+1) |
| Adam m₁ | m·n | k(m+n+1) |
| Adam m₂ | m·n | k(m+n+1) |
| total | 4·m·n | 4·k(m+n+1) |
The +1 in k(m+n+1) is the diagonal vector s (length k).
Plug in m=8192, n=28672, k=32:
dense : 4 · 8192 · 28672 ≈ 939.5 × 10⁶ floats ≈ 3,758 MB
SCT : 4 · 32 · (8192 + 28672 + 1) ≈ 4.72 × 10⁶ floats ≈ 18.9 MB
ratio : 3,758 / 18.9 ≈ 199.1 ×
The ratio mn / k(m+n+1) grows with the layer dimensions, so the larger the model, the better SCT does:
| Model | Layer (m × n) | Dense + Adam | SCT (k=32) | Compression |
|---|---|---|---|---|
| SmolLM2-135M | 576 × 1536 | 14.2 MB | 1.1 MB | 13× |
| SmolLM2-1.7B | 2048 × 8192 | 268 MB | 5.2 MB | 51× |
| LLaMA-7B | 4096 × 11008 | 721 MB | 7.7 MB | 93× |
| Qwen-27B | 4096 × 17408 | 1,141 MB | 11.0 MB | 104× |
| LLaMA-70B | 8192 × 28672 | 3,758 MB | 18.9 MB | 199× |
These numbers are exact — they are not benchmarks, they are arithmetic.
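The table can be reproduced with a few lines of plain Python. This is a sketch under two assumptions taken from the figures above: FP32 storage (4 bytes per float) and MB meaning 10⁶ bytes. The function names are mine.

```python
BYTES_PER_FP32 = 4

def dense_floats(m, n):
    """Weights + gradient + two Adam moments for a dense m x n layer."""
    return 4 * m * n

def sct_floats(m, n, k):
    """Same four buffers, but for the factors U [m, k], s [k], V [n, k]."""
    return 4 * k * (m + n + 1)

def footprint(m, n, k=32):
    """Return (dense MB, SCT MB, compression ratio), with 1 MB = 1e6 bytes."""
    dense_mb = dense_floats(m, n) * BYTES_PER_FP32 / 1e6
    sct_mb = sct_floats(m, n, k) * BYTES_PER_FP32 / 1e6
    return dense_mb, sct_mb, dense_mb / sct_mb

layers = {
    "SmolLM2-135M": (576, 1536),
    "SmolLM2-1.7B": (2048, 8192),
    "LLaMA-7B": (4096, 11008),
    "Qwen-27B": (4096, 17408),
    "LLaMA-70B": (8192, 28672),
}

for name, (m, n) in layers.items():
    dense_mb, sct_mb, ratio = footprint(m, n)
    print(f"{name:13s} {m} x {n}: dense {dense_mb:,.1f} MB, "
          f"SCT {sct_mb:.1f} MB, {ratio:.0f}x")
```

Because the ratio is a quotient of float counts, it does not depend on the bytes-per-float choice; only the MB columns do.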
4. The forward pass: three small matmuls
The dense forward pass is one big matmul:
y = x · W cost: O(b · m · n)
where x is [b × m] and b is the batch (× sequence length).
Substituting the factored form W = U diag(s) Vᵀ:
y = x · (U diag(s) Vᵀ)
= x · U · diag(s) · Vᵀ
which you compute as three sequential small matmuls:
h = x · U [b × k] project into the spectral basis
h_s = h ⊙ s [b × k] scale by singular values
y = h_s · Vᵀ [b × n] reconstruct in output space
Total cost: O(b · k · (m + n)). Compare with O(b · m · n) for the dense version. At m=n=8192, k=32, the ratio m·n / (k · (m + n)) is 128, so that is a 128× reduction in FLOPs per forward pass on top of the memory savings. The m × n weight matrix is never assembled, not even transiently, at any point in the computation.
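The three-step listing above can be sketched in NumPy as follows (a stand-in for PyTorch, with toy sizes and random orthonormal factors of my own choosing). The dense matrix is formed once at the end, only to confirm that the factored path gives the same output; `sct_forward` and `leading_order_flop_ratio` are hypothetical helper names.

```python
import numpy as np

def sct_forward(x, U, s, V):
    """Apply the layer through its factors; x is [b, m], the result is [b, n]."""
    h = x @ U          # [b, k] project into the spectral basis
    h_s = h * s        # [b, k] scale by singular values (broadcast over rows)
    return h_s @ V.T   # [b, n] reconstruct in output space

def leading_order_flop_ratio(m, n, k):
    """Dense b*m*n multiply-adds divided by factored b*k*(m + n)."""
    return (m * n) / (k * (m + n))

rng = np.random.default_rng(0)
b, m, n, k = 4, 64, 96, 8                          # toy sizes
U, _ = np.linalg.qr(rng.standard_normal((m, k)))   # random orthonormal columns
V, _ = np.linalg.qr(rng.standard_normal((n, k)))
s = rng.uniform(0.5, 2.0, size=k)
x = rng.standard_normal((b, m))

y = sct_forward(x, U, s, V)

# Dense reference, assembled here only for the equivalence check.
W = (U * s) @ V.T
max_abs_diff = np.max(np.abs(y - x @ W))

print(y.shape, max_abs_diff)
print(leading_order_flop_ratio(8192, 28672, 32))   # LLaMA-70B MLP shape at rank 32
```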
5. The backward pass: gradients on the factors, not on W
This is the part that distinguishes SCT from the prior art most clearly, and it is where I had to think the hardest while reimplementing it.
When you train a dense layer, autograd produces a gradient ∂L/∂W of shape m × n. That gradient is the same size as the weight matrix, and Adam keeps two more buffers of that same size. Memory blows up.
In SCT, the layer is parameterised by U, s, and V. Autograd traces the three small matmuls above and produces:
∂L/∂U of shape m × k
∂L/∂s of shape k
∂L/∂V of shape n × k
No tensor of shape m × n is ever allocated, because no tensor of shape m × n ever exists in the computation graph. The cost of backprop is the same O(b · k · (m + n)) as the forward pass.
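To see what autograd is doing here, the factor gradients can be written out by hand. The sketch below uses NumPy instead of PyTorch and assumes a simple squared-error loss, L = ½‖y − target‖², which is my choice for illustration and not something the method prescribes. Each hand-derived gradient is compared against a central finite difference on a single entry.

```python
import numpy as np

def forward(x, U, s, V):
    h = x @ U            # [b, k]
    h_s = h * s          # [b, k]
    return h, h_s, h_s @ V.T

def loss_fn(x, U, s, V, target):
    _, _, y = forward(x, U, s, V)
    return 0.5 * np.sum((y - target) ** 2)

def factor_grads(x, U, s, V, target):
    """Gradients with respect to U, s, V; nothing of shape [m, n] is formed."""
    h, h_s, y = forward(x, U, s, V)
    g_y = y - target                     # [b, n] dL/dy for the squared-error loss
    g_hs = g_y @ V                       # [b, k]
    grad_V = g_y.T @ h_s                 # [n, k]
    grad_s = np.sum(g_hs * h, axis=0)    # [k]
    grad_U = x.T @ (g_hs * s)            # [m, k]
    return grad_U, grad_s, grad_V

rng = np.random.default_rng(0)
b, m, n, k = 4, 64, 96, 8                # toy sizes
U, _ = np.linalg.qr(rng.standard_normal((m, k)))
V, _ = np.linalg.qr(rng.standard_normal((n, k)))
s = rng.uniform(0.5, 2.0, size=k)
x = rng.standard_normal((b, m))
target = rng.standard_normal((b, n))

grad_U, grad_s, grad_V = factor_grads(x, U, s, V, target)

def finite_diff(param, index, eps=1e-5):
    """Central difference of the loss in one entry of param (restored afterwards)."""
    old = param[index]
    param[index] = old + eps
    up = loss_fn(x, U, s, V, target)
    param[index] = old - eps
    down = loss_fn(x, U, s, V, target)
    param[index] = old
    return (up - down) / (2 * eps)

print(grad_U.shape, grad_s.shape, grad_V.shape)
print(grad_U[3, 1], finite_diff(U, (3, 1)))
print(grad_s[2], finite_diff(s, 2))
print(grad_V[5, 4], finite_diff(V, (5, 4)))
```

The largest intermediates are [b, n], [m, k] and [n, k], which matches the O(b · k · (m + n)) cost quoted above.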
A subtle but important point: these gradients are exact with respect to the factored parameterisation. They are not approximations of the dense gradient. They are, however, not the gradients you would get from training a full-rank dense model, because the rank-k model lives on a different loss landscape. The rank cap is a capacity constraint, not a numerical one. If your task genuinely requires higher effective rank, no amount of clever optimisation will recover what the parameterisation cannot represent (the Eckart–Young theorem makes this precise: the best rank-k approximation of a matrix is its truncated SVD, and the error is set by the discarded singular values). In practice, MLP layers at LLM scale tolerate very aggressive rank caps because their singular-value spectra decay quickly.
This is the practical line between SCT and its neighbours:
- LoRA keeps the full dense W in memory and adds a small low-rank adapter alongside it. SCT has no W to keep.
- GaLore computes the dense gradient and then projects it into a low-rank subspace. SCT never produces a dense gradient.
- SVD-LLM and similar compression methods take an already-trained dense network and truncate its SVD post-hoc. SCT trains in the truncated space from the start.
6. The Stiefel retraction: keeping U and V orthonormal
There is one wrinkle. The factorisation W = U diag(s) Vᵀ is only a valid SVD if U and V actually have orthonormal columns. The instant the optimizer takes a step, that constraint breaks: an update such as U ← U − η · ∂L/∂U generally leaves the Stiefel manifold
St(m, k) = { X ∈ ℝ^{m × k} : Xᵀ X = I_k }
If you let U drift freely, the factorisation stops being a truncated SVD and the singular values in s stop meaning what you think they mean.
SCT fixes this with the cheapest possible move: a QR decomposition with sign correction after every optimizer step.
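    import torch

    def retract(U_updated):
        """Map an updated factor back to orthonormal columns via reduced QR."""
        Q, R = torch.linalg.qr(U_updated)  # Q is [m, k], R is [k, k]
        # sign correction: flip columns of Q so the matching diagonal of R is positive;
        # an exactly-zero diagonal maps to +1 so no column is multiplied by 0
        signs = torch.sign(torch.diag(R))
        signs = torch.where(signs == 0, torch.ones_like(signs), signs)
        return Q * signs  # broadcasts across rows, scaling each column of Q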
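The sign correction matters: a QR factorisation is only unique up to the signs of Q's columns, so a vanilla QR can flip a column of U or V from one step to the next, which propagates a discontinuity into the singular-value vector s. Choosing the signs so that every diagonal entry of R is positive keeps each retracted column aligned with the column it came from, and mapping a zero diagonal to +1 (never 0) ensures that no column of U or V is ever annihilated, so orthonormality is preserved smoothly across steps.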
Cost. QR on an m × k matrix is O(m · k²). For m = 8192 and k = 32, that is on the order of 8.4 million FLOPs per layer: negligible at GPU scale, but a large share of step time on CPU/handheld targets (roughly 40% of the step in the Steam Deck run in section 8.1).
Empirical orthonormality. Measured Frobenius norm ‖UᵀU − I‖_F after retraction across the 70B-architecture run: < 2 × 10⁻⁶. The constraint is maintained essentially to numerical precision throughout training.
This is the move that makes SCT a training method rather than a compression method. The manifold constraint is not a one-time projection applied at the end of training; it is maintained continuously, step by step, so that the spectral structure of the parameters is preserved across the entire optimisation trajectory.
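The effect of the retraction is easy to check numerically. The following is a NumPy port of the `retract` function above (my translation, run in float64, so the errors it reports are not comparable to the < 2 × 10⁻⁶ figure from the 70B run). A random perturbation stands in for the optimizer update.

```python
import numpy as np

def retract(X):
    """QR retraction with sign correction; returns a matrix with orthonormal columns."""
    Q, R = np.linalg.qr(X)                     # reduced QR: Q [m, k], R [k, k]
    signs = np.sign(np.diag(R))
    signs = np.where(signs == 0, 1.0, signs)   # never multiply a column by 0
    return Q * signs

def orth_error(X):
    """Frobenius norm of X^T X - I."""
    return np.linalg.norm(X.T @ X - np.eye(X.shape[1]))

rng = np.random.default_rng(0)
m, k = 512, 32                                  # toy sizes
U = retract(rng.standard_normal((m, k)))        # start on the Stiefel manifold

update = 1e-2 * rng.standard_normal((m, k))     # stand-in for an optimizer step
U_stepped = U - update
U_new = retract(U_stepped)

err_before = orth_error(U_stepped)
err_after = orth_error(U_new)

# With the sign fix, each retracted column has a positive inner product with the
# stepped column it came from (that inner product is |R_jj|).
alignment = np.sum(U_new * U_stepped, axis=0)

print(err_before, err_after)
print(alignment.min())
```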
7. Putting one training step together
A complete SCT training step looks like this:
    # Assumes U, s, V are leaf tensors with requires_grad=True that are registered
    # with an Adam optimizer, and that retract() is the function from section 6.
    # Forward
    h = x @ U                 # [b, k]
    h_s = h * s               # [b, k]
    y = h_s @ V.T             # [b, n]
    loss = criterion(y, target)
    # Backward: autograd handles this through the three matmuls
    loss.backward()
    # now: U.grad is [m, k], s.grad is [k], V.grad is [n, k]
    # Optimizer step (Adam on the factors)
    optimizer.step()
    optimizer.zero_grad()
    # Retract U and V back to the Stiefel manifold
    with torch.no_grad():
        U.copy_(retract(U))
        V.copy_(retract(V))
Every operation in this loop touches only small tensors. The peak memory of the step is dominated by activations (which depend on batch and sequence length, not on m · n) and by the factors themselves.
8. What I observed when I reimplemented it
I rebuilt SCT against the public description and ran three classes of experiments.
8.1 70B-architecture memory validation
A full 70B-class transformer (80 layers, d=8192, ffn=28672, SwiGLU activation, matching LLaMA-3-70B layer dimensions) was instantiated in spectral form at rank 32 and run through one complete training step. Attention was simplified (additive, no softmax/masking) to isolate the memory claim from sequence-length effects.
| Hardware | Peak Memory | Forward | Backward | Optimizer | QR | Total |
|---|---|---|---|---|---|---|
| Apple M4 Pro (48 GB) | 7,907 MB | 0.08 s | 0.09 s | 0.22 s | 3.02 s | 3.41 s |
| Steam Deck (16 GB) | 7,236 MB | 0.43 s | 0.92 s | 2.35 s | 2.58 s | 6.28 s |
The dense FP32 baseline would need ~1,245 GB. The measured 7.24 GB on a Steam Deck is a 172× empirical reduction at the full-model level, consistent with the analytical per-layer ratio. Orthonormality error after retraction was < 2 × 10⁻⁶ on both platforms.
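The per-phase timings in the table can be turned into shares of step time with a few lines of Python. The numbers are copied from the table and from the ~1,245 GB baseline quoted above; the only assumption is 1 GB = 1000 MB.

```python
timings_s = {
    "Apple M4 Pro": {"forward": 0.08, "backward": 0.09, "optimizer": 0.22, "qr": 3.02},
    "Steam Deck": {"forward": 0.43, "backward": 0.92, "optimizer": 2.35, "qr": 2.58},
}
peak_mb = {"Apple M4 Pro": 7907, "Steam Deck": 7236}
DENSE_BASELINE_GB = 1245   # figure quoted in the text, not recomputed here

def qr_share(phases):
    """Fraction of the total step time spent in the QR retraction."""
    return phases["qr"] / sum(phases.values())

def reduction_factor(peak_memory_mb):
    """Dense baseline divided by measured peak memory (1 GB = 1000 MB)."""
    return DENSE_BASELINE_GB * 1000 / peak_memory_mb

for name, phases in timings_s.items():
    print(f"{name}: total {sum(phases.values()):.2f} s, "
          f"QR share {qr_share(phases):.0%}, "
          f"memory reduction {reduction_factor(peak_mb[name]):.0f}x")
```

On these figures the retraction is the largest single phase on both machines, which is worth keeping in mind when reading the cost discussion in section 6.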
8.2 Fine-tuning gradient integrity (SmolLM2-135M on Alpaca)
To confirm gradients flow correctly through the spectral factors, I converted pre-trained SmolLM2-135M to spectral form at 95% energy retention and fine-tuned on Alpaca for 400 steps, comparing to a dense baseline with the same seed, data, and learning rate.
| Method | Final Loss | Final PPL | Trainable Params |
|---|---|---|---|
| Dense + AdamW | 0.2356 | 1.3 | 134,515,008 |
| SCT (energy ≥ 0.95) | 0.6480 | 1.9 | 84,333,271 |
SCT recovered from an initial conversion loss spike (9.4 → 0.65) to roughly 1.5× baseline perplexity (1.9 vs 1.3; from the unrounded losses, exp(0.6480 − 0.2356) ≈ 1.51). The math works; gradients flow.
Caveat: at this scale (d=576), the 95% energy threshold produces ranks of 412–466, which is close to the full dimension. This run validates correctness, not compression utility — for that you need a model where the spectral budget is genuinely smaller than the layer dimension.
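For readers who want to see how an energy threshold turns into a rank, here is a small NumPy sketch. It assumes "energy" means the cumulative sum of squared singular values divided by their total, which is the usual convention but is my assumption about this run; `rank_for_energy` is a hypothetical helper name, and the two spectra are synthetic.

```python
import numpy as np

def rank_for_energy(singular_values, energy=0.95):
    """Smallest k whose top-k squared singular values hold at least `energy` of the total."""
    sq = np.sort(np.asarray(singular_values, dtype=float))[::-1] ** 2
    cumulative = np.cumsum(sq) / np.sum(sq)
    k = int(np.searchsorted(cumulative, energy)) + 1
    return min(k, len(sq))

# Synthetic spectra, not measurements from any model.
fast_decay = 0.5 ** np.arange(10)   # each singular value is half the previous one
flat = np.ones(10)                  # every direction matters equally

print(rank_for_energy(fast_decay))  # a few directions carry most of the energy
print(rank_for_energy(flat))        # needs every direction
```

The flat case is the situation described in the caveat: when the spectrum does not decay, a 95% threshold lands close to the full dimension and the factorisation saves very little.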
8.3 Rank sweep at 1.7B (SmolLM2-1.7B on Alpaca, A100 40GB)
Dense baseline vs SCT at ranks 32, 64, 128, 256. MLP layers converted to spectral form; attention, embeddings, and norms kept dense. 2000 steps, batch 4, AdamW.
| Method | Params | MLP Compression | Loss | PPL | GPU Memory | Step Time |
|---|---|---|---|---|---|---|
| Dense | 1,711M | 1.0× | 1.29 | 3.6 | 35.5 GB | 1.17 s |
| SCT r=256 | 692M | 5.9× | 4.33 | 75.6 | 21.3 GB | 1.05 s |
| SCT r=128 | 598M | 11.7× | 4.18 | 65.6 | 20.0 GB | 0.74 s |
| SCT r=64 | 551M | 23.5× | 4.34 | 76.7 | 19.3 GB | 0.62 s |
| SCT r=32 | 527M | 46.9× | 4.47 | 86.9 | 19.0 GB | 0.56 s |
Two things stood out:
- Memory and throughput claims hold at scale. GPU usage dropped from 35.5 GB (dense) to 19.0 GB (rank 32) — a 46% reduction in real VRAM. Step time dropped 2.1×.
- All ranks converge to roughly the same loss floor. Rank 256 (5.9× compression) and rank 32 (46.9× compression) ended within 0.3 loss of each other. At 2000 steps, MLP rank was not the dominant bottleneck.
The ~3 loss gap to the dense baseline turned out to be partly a learning-rate issue. At rank 32, MLP spectral parameters are only 18M of 527M total; attention layers are 403M (77% of the model). All components were training at the same SCT-tuned LR (5e-4, which is 25× the dense baseline). A per-component LR schedule — dense LR for attention/embeddings, higher LR for SCT factors — closes a slice of the gap but not all of it, which suggests the residual is something else (rank cap, interaction with attention, or implementation differences in my reimplementation worth tracking down).
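One way to set up the per-component learning rates is with optimizer parameter groups. The sketch below is plain Python and only builds the list of group dictionaries (`{"params": ..., "lr": ...}`) in the form that PyTorch optimizers accept. The parameter names, the `.U`/`.s`/`.V` suffix convention, and the helper `build_param_groups` are hypothetical; the two learning rates are the 5e-4 SCT rate from the text and the dense rate implied by the 25× ratio.

```python
def build_param_groups(named_params, dense_lr=2e-5, sct_lr=5e-4,
                       sct_suffixes=(".U", ".s", ".V")):
    """Split (name, parameter) pairs into a dense group and an SCT-factor group."""
    dense, sct = [], []
    for name, param in named_params:
        if name.endswith(sct_suffixes):
            sct.append(param)
        else:
            dense.append(param)
    return [
        {"params": dense, "lr": dense_lr},   # attention, embeddings, norms
        {"params": sct, "lr": sct_lr},       # spectral factors of the MLP layers
    ]

# Placeholder strings stand in for real parameter tensors in this demo.
named = [
    ("layers.0.mlp.up.U", "p_U"),
    ("layers.0.mlp.up.s", "p_s"),
    ("layers.0.mlp.up.V", "p_V"),
    ("layers.0.attn.q_proj.weight", "p_q"),
    ("embed.weight", "p_embed"),
]
groups = build_param_groups(named)
for group in groups:
    print(group["lr"], group["params"])

# With a real model this would be used as, for example:
#   optimizer = torch.optim.AdamW(build_param_groups(model.named_parameters()))
```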
9. Honest limitations
- Rank is a real ceiling. A rank-k factorisation can only represent a rank-k weight matrix. By the Eckart–Young theorem, the best any rank-k matrix can do in approximating a dense W is its truncated SVD, so whatever lives in the discarded singular values cannot be recovered. SCT is therefore strongest for pre-training, where the network gets to grow into the spectral budget, and weakest for compressing already-trained dense networks that have learned to use their full spectrum.
- Convergence gap on conversion fine-tuning. My 1.7B rank sweep shows a measurable loss gap to dense after 2000 steps. Part of this is LR; part of it I have not fully isolated. SCT is presented as a memory-efficient training method, not as a method that matches dense loss at arbitrary rank.
- QR retraction cost. O(m · k²) per layer per step. Negligible on A100 at 1.7B scale (0.56 s total step at rank 32). On the Steam Deck 70B run, retraction was ~40% of step time. Small k keeps this bounded; large k does not.
- Small models benefit less. Below ~1.7B parameters (hidden dim < 2048), the rank you need to retain useful energy is close to the full dimension, and the per-layer reduction becomes modest. SCT compression scales with mn / k(m+n+1); it is fundamentally a big-model technique.
10. Why I find this interesting
Three things stand out to me as a reimplementer:
- It is a parameterisation choice, not an optimisation trick. Most memory-efficient training methods modify the optimiser (8-bit Adam, ZeRO, GaLore) or freeze parts of the model (LoRA, adapter tuning). SCT changes what a “weight” is. Once you make that change, everything else — small matmuls, small gradients, small optimiser state — falls out for free.
- The Stiefel constraint is not optional. Without QR retraction, the orthonormality of U and V decays, the meaning of s corrupts, and the parameterisation stops being an SVD. With it, the spectral structure is preserved across optimisation. The constraint is the method.
- The memory savings are arithmetic, not benchmarks. Most “X× speedup” claims in ML are empirical and platform-dependent. 4mn vs 4k(m+n+1) is just counting floats: the 199× number for 70B at rank 32 is true on any hardware, in any framework, in any precision.
The remaining open question — and the one I think is most interesting to chase next — is how aggressively you can shrink k while still recovering dense-equivalent quality on a real downstream task. That is a question about the intrinsic rank of language-model MLPs, and it is largely an empirical question. The math of SCT just gives you a clean place to ask it.
Citation
This method is not mine. Credit goes to the original author:
@misc{kohlberger2026sct,
title = {Spectral Compact Training: Memory-Efficient Neural Network Training
via Truncated SVD Factorization with Stiefel Manifold Retraction},
author = {Kohlberger, Bj{\"o}rn Roman},
year = {2026},
note = {Irish Patent Application PTIE20260000000219}
}
Related work I read while reimplementing
- Absil, Mahony, Sepulchre. Optimization Algorithms on Matrix Manifolds. Princeton, 2008. — the standard reference for Stiefel-manifold optimisation and retractions.
- Hu et al. LoRA: Low-Rank Adaptation of Large Language Models. 2021. arXiv:2106.09685.
- Zhao et al. GaLore: Memory-Efficient LLM Training by Gradient Low-Rank Projection. ICML 2024. arXiv:2403.03507.
- Wang et al. SVD-LLM: Truncation-aware Singular Value Decomposition for Large Language Model Compression. 2024. arXiv:2403.07378.
- Li et al. Efficient Riemannian Optimization on the Stiefel Manifold via the Cayley Transform. ICLR 2020.