June 29, 2026
Building a 200M Parameter Time Series LLM from Scratch
Decoder-only patched transformer for zero-shot forecasting

By Fareed Khan
50 min read
Building a single forecasting model that works on series it has never seen, instead of training a fresh model for every dataset, remains a hard problem. The field is converging on one answer: treat forecasting like language modeling. Decoder-only transformers, the architecture behind modern language models, have become one of the strongest approaches to forecasting. Google's recently released decoder-based model (TimesFM) reports zero-shot accuracy that classical per-dataset methods struggle to match. So I set out to build one from scratch, both the model and the full training pipeline, and to scale it to 200 million parameters.
The model we are building is a patched decoder, and it has six parts:
- Patch tokenizer: turns each 32-number window into one token.
- RevIN normalization: rescales each window so its units and size do not matter.
- Decoder stack: transformer layers where each token looks back at the earlier ones.
- Quantile head: outputs an uncertainty band, not a single number.
- All-positions training: every token predicts the stretch that follows it.
- Autoregressive forecast: rolls forward, feeding each prediction back in to extend the forecast as far as we need.
The goal is one model that I can hand a series it has never seen and get back a forecast with a sensible uncertainty band. Here is where that ends up: after about 9 hours of training on a single GPU, the 200M model's forecast (green) tracks series it never saw, with the actual values shown in black.
In this blog we are going to build this decoder-only forecaster from scratch, the model and the full training pipeline, and scale it from a tiny 17M version up to the 200M one above. We do it step by step, hitting and fixing the problems we run into at every size along the way.
My trained model weights are available here:
FareedKhan/timesfm-from-scratch-70m · Hugging Face
All the code is available in my GitHub repository:
GitHub - FareedKhan-dev/timesfm-from-scratch: From-scratch reimplementation of Google TimesFM (time-series foundation model) plus the full pretraining pipeline…
Table of Contents
- The goal: forecasting a series the model has never seen
- What we are building, and how we measure it
- A map of the codebase
- Setup and imports
- The synthetic data generator
- Patches as tokens, and RevIN normalization
- The transformer block, one mechanism at a time
  ∘ ResidualBlock and RMSNorm
  ∘ RoPE and per-dimension scale
  ∘ Multi-head attention: fused QKV, qk-norm, and a unit scale
  ∘ Sandwich norm and the two-matrix feed-forward
- Assembling the full model and the forward pass
  ∘ From training outputs to an autoregressive forecast
- The objective: predict every next 128 points,
- The training loop: optimizer and schedule
- Training the 17M model on synthetic data
- Scaling to 70M, and the lesson that data beats size
- The overfitting problem
- The fix: corpus diversity
- A bug: the model was weakest where it starts
- The evaluation harness
- Random-walk data, and a partial fix
- Calibrated uncertainty: conformal recalibration
- Scaling to the 200M model
- Limitations and what I would build next
- What building it taught me
The goal: forecasting a series the model has never seen
Let me make the goal concrete. We want one model that takes the past values of a series (hourly electricity load, daily web traffic, or a currency exchange rate) and predicts the next stretch of it, without ever having seen that specific series before.
Zero-shot forecasting is hard because a model can fit one dataset perfectly and still be useless on a new one, having learned that dataset's quirks rather than the general grammar of time series. The goal is to learn structure that transfers (the seasonality, trend, and noise that show up across thousands of sources) without ever peeking at the series we test on.
Almost every design decision in this blog exists to push the model toward that transferable structure and away from memorization.
What we are building, and how we measure it
Before building anything, I need a way to tell whether it works. The metric is scaled MAE: the mean absolute error of our forecast divided by that of a naive baseline that simply repeats the last observed value forward. Below 1.0 we beat the baseline; above 1.0 we are doing worse than doing nothing clever.
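For a single forecast window, the metric looks roughly like this in code. This is a minimal sketch under my own assumptions. The naive baseline holds the last context value flat across the whole horizon, and the ratio is taken per window. The repo's eval_core.py may aggregate across series and windows differently.

```python
import numpy as np


def scaled_mae(context, actual, forecast, eps=1e-12):
    """MAE of the forecast divided by the MAE of a naive 'repeat the last value' forecast."""
    context = np.asarray(context, dtype=np.float64)
    actual = np.asarray(actual, dtype=np.float64)
    forecast = np.asarray(forecast, dtype=np.float64)
    naive = np.full_like(actual, context[-1])        # last observed value, held flat
    model_mae = np.mean(np.abs(actual - forecast))
    naive_mae = np.mean(np.abs(actual - naive))
    return model_mae / max(naive_mae, eps)           # < 1.0 means we beat the baseline


# A perfect forecast scores 0.0; submitting the naive forecast itself scores exactly 1.0.
print(scaled_mae([1, 2, 3], [4, 5], [4, 5]))
print(scaled_mae([1, 2, 3], [4, 5], [3, 3]))
```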
The benchmark is the held-out ETT dataset (electricity transformer temperature), which the model never trains on. For scale, a strong publicly released 200M forecasting model scores about 0.215 there, and that is the target for our final model.
The model comes in three sizes that we use throughout (tiny, small, and base), which differ only in model dimension and number of layers. The sanity script builds all three and prints their parameter counts.
```
#### OUTPUT ####
=== param counts ===
tiny 17.65M dim=512 layers=10 heads=16 head_dim=32 outputs=10
small 67.80M dim=1024 layers=10 heads=16 head_dim=64 outputs=10
base 203.44M dim=1280 layers=20 heads=16 head_dim=80 outputs=10
```

These are the three rungs of our ladder: 17.65M to debug quickly, 67.80M to learn what drives forecasting quality, and 203.44M, the model we are building toward, which matches the backbone size of the released reference. The last column, outputs=10, is the mean plus nine quantiles; we will see why when we reach the loss.
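As a sanity check on those counts, here is a back-of-envelope tally. Every term in it is my assumption about the architecture, not something read from model.py:

- Attention is a fused QKV projection plus an output projection (4d²).
- The two-matrix feed-forward has a hidden width equal to d (2d²).
- The tokenizer is a ResidualBlock that takes 2 × patch_len inputs (patch values plus padding flags).
- The head is a ResidualBlock that produces horizon_len × num_outputs values.
- Norm gains are ignored.

```python
def estimate_params(d, layers, patch_len=32, horizon=128, outputs=10):
    """Back-of-envelope parameter count; every term is an assumption, not the repo's code."""
    attn = 4 * d * d          # fused QKV (3*d*d) + output projection (d*d), no biases
    ffn = 2 * d * d           # two-matrix feed-forward, ASSUMING hidden width == d
    stack = layers * (attn + ffn)
    # Tokenizer ResidualBlock: ASSUMED input = patch values + padding flags (2 * patch_len).
    tok_in = 2 * patch_len
    tokenizer = tok_in * d + d * d + tok_in * d       # hidden + output + residual layers
    # Output head ResidualBlock: d -> d -> horizon * outputs.
    out = horizon * outputs
    head = d * d + d * out + d * out
    return stack + tokenizer + head   # RMSNorm gains and qk-norm scales are left out


for name, d, n in [("tiny", 512, 10), ("small", 1024, 10), ("base", 1280, 20)]:
    print(f"{name:5s} ~{estimate_params(d, n) / 1e6:.2f}M")
```

It prints about 17.63M, 67.76M, and 203.33M, within 0.2% of the sanity script. That closeness suggests the assumptions are near the mark. The small remainder is plausibly the norm parameters left out, but model.py is the authority.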
A map of the codebase
A project like this has a lot of moving parts, but the dependency flow is simple:

- Config defines the sizes.
- The layers and the normalization feed the model.
- The model, the loss, and the data loaders feed the trainer.
- The evaluation scripts produce every number and plot in this blog.
Here is the source layout, so you can follow along file by file.
```
tsfm-scratch/
  src/tsfm/
    config.py # model + training config, the three size presets
    revin.py # reversible instance normalization (per-series scaling)
    layers.py # ResidualBlock, RMSNorm, RoPE, attention, transformer block
    model.py # PatchedDecoder: the full forward pass + autoregressive forecast
    objective.py # the all-positions loss (MSE + pinball quantile loss)
    trainer.py # optimizer, LR schedule, one training step, OOD validation
    synthetic.py # the synthetic time-series generator
    data.py # windowing + the random-prefix masking trick
    sources.py # real-data loaders (electricity, M4, LOTSA)
  scripts/
    sanity.py # param counts + one forward/backward pass
    overfit_one_batch.py # the correctness gate
    train.py # the training entrypoint
    eval_core.py # the shared evaluation protocol + metrics
    eval_multidomain.py # the multi-domain zero-shot test grid
```

We are going to build this roughly top to bottom: first the data the model learns from, then the model itself, then the loss and the training loop, and finally the long climb up the three sizes.
Let us start with the environment.
Setup and imports
The dependency list is deliberately tiny. The core package needs only NumPy, and PyTorch is an optional extra we install for the actual training.
```toml
[project]
name = "tsfm"
version = "0.1.0"
description = "From-scratch decoder-only time-series foundation model + training pipeline"
requires-python = ">=3.10"
dependencies = ["numpy>=1.26"]

[project.optional-dependencies]
torch = ["torch>=2.2"]
dev = ["pytest>=8"]
```

I install it with uv, which keeps the environment reproducible, and I run everything on a local GPU machine with one NVIDIA H100 80GB.
```bash
uv venv
uv pip install -e ".[torch,dev]"
```

The first piece of code is the configuration object. Every hyperparameter the model needs lives in one dataclass, and two computed properties save us from repeating arithmetic everywhere.
```python
from dataclasses import dataclass
from typing import Tuple


@dataclass
class ModelConfig:
    model_dim: int = 1280
    num_layers: int = 20
    num_heads: int = 16
    patch_len: int = 32
    horizon_len: int = 128
    quantiles: Tuple[float, ...] = (0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9)
    dropout: float = 0.0
    rope_max_period: float = 10000.0
    rms_eps: float = 1e-6

    @property
    def head_dim(self) -> int:
        assert self.model_dim % self.num_heads == 0
        return self.model_dim // self.num_heads

    @property
    def num_outputs(self) -> int:
        return 1 + len(self.quantiles)  # mean (point) + quantiles
```

A few of these are worth pausing on, because they define the shape of everything downstream. The patch_len of 32 is how many raw points make up one token.
The horizon_len of 128 is how far ahead the model predicts at once. The quantiles tuple has nine entries, and num_outputs is ten: one channel for the mean plus one for each of the nine quantiles.
That is why the sanity print showed outputs=10, and it is what lets the model express uncertainty instead of a single guess.
The synthetic data generator
Here is the first problem. To train a model that generalizes across many kinds of series, we need a lot of varied series.
I do not have a giant labelled corpus lying around, so I grow one. The synthetic generator builds each series as a weighted sum of a few simple component families (a trend, ARMA noise, and sine and cosine waves), with a share of pure random walks mixed in.
Let us build one component family at a time. The most interesting one is ARMA, an autoregressive moving-average process, because it is where I hit my first bug.
To generate an ARMA series we need coefficients that produce a stable process. If we pick them carelessly, the process explodes to infinity and we get NaNs.
The condition for stability is that all the roots of the characteristic polynomial lie strictly inside the unit circle.
```python
import numpy as np


def _reject_sample_roots(rng, order, max_attempts=200):
    """Sample stationary AR coefficients by rejection (also reused to draw the MA coefficients).

    For y_t = c1 y_{t-1} + ... + cp y_{t-p} + e_t the characteristic polynomial is
    z^p - c1 z^{p-1} - ... - cp; the process is stationary iff ALL roots lie INSIDE the unit circle.
    """
    for _ in range(max_attempts):
        c = rng.normal(0, 0.4, size=order)
        roots = np.roots(np.concatenate([[1.0], -c]))  # char poly z^p - c1 z^{p-1} - ...
        if np.all(np.abs(roots) < 0.99):  # strictly inside, with a small safety margin
            return c
    return np.array([0.5] + [0.0] * (order - 1))  # AR(1), c=0.5 -> stationary fallback
```

This function samples stable coefficients by rejection. My first version had the test inverted: it kept the explosive coefficients and rejected the good ones, and every few batches a series would blow up to infinity and poison the loss.
The fix is the one line that checks the roots are inside the circle.
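Here is a small self-contained check of what the inverted test was letting through. is_stationary applies the same root test as _reject_sample_roots, and simulate_ar is a stripped-down AR-only recurrence; both helper names are mine. Coefficients that fail the test, such as an AR(1) with c = 1.2, grow without bound within a couple of hundred steps.

```python
import numpy as np


def is_stationary(c, radius=0.99):
    """Same test as _reject_sample_roots: all roots of z^p - c1 z^(p-1) - ... - cp inside radius."""
    c = np.asarray(c, dtype=float)
    roots = np.roots(np.concatenate([[1.0], -c]))
    return bool(np.all(np.abs(roots) < radius))


def simulate_ar(c, length, seed=0):
    """Run y_t = c1 y_{t-1} + ... + cp y_{t-p} + e_t forward from a zero start."""
    c = np.asarray(c, dtype=float)
    rng = np.random.default_rng(seed)
    eps = rng.normal(0.0, 1.0, size=length)
    y = np.zeros(length)
    for t in range(length):
        k = min(len(c), t)
        past = y[t - k:t][::-1]              # y_{t-1}, y_{t-2}, ..., most recent first
        y[t] = np.dot(c[:k], past) + eps[t]  # empty dot product is 0.0 at t = 0
    return y


for coeffs in ([0.5], [1.2], [0.5, 0.3], [0.5, 0.6]):
    y = simulate_ar(coeffs, 200)
    print(coeffs, "stationary:", is_stationary(coeffs), f"max |y| = {np.abs(y).max():.3g}")
```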
The ARMA recurrence itself is the standard one, the next value is a weighted sum of past values plus a weighted sum of past noise terms.
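Written out, this is the recurrence that generate_arma implements below. Here ar holds the AR coefficients c_1 … c_p (the same c as in the docstring above), ma holds the MA coefficients θ_1 … θ_q, the noise is standard normal, and values before the start are treated as zero:

$$
y_t = \sum_{i=1}^{p} c_i\, y_{t-i} \;+\; \varepsilon_t \;+\; \sum_{j=1}^{q} \theta_j\, \varepsilon_{t-j}, \qquad \varepsilon_t \sim \mathcal{N}(0, 1)
$$

Only the c_i decide whether the series can explode, because the MA sum is a finite combination of bounded-variance noise. That is why the root test on the AR part is the one that prevents NaNs.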
The wave components are simpler. We pick a random period, phase, and amplitude, and evaluate a sine or cosine over the length of the series.
```python
def generate_wave(length, rng, fn):
    period = float(rng.integers(4, max(5, length // 2 + 1)))
    phase = float(rng.uniform(0, 2 * np.pi))
    amp = float(np.clip(rng.normal(1.0, 0.5), 0.1, 3.0))
    t = np.arange(length, dtype=np.float32)
    return (amp * fn(2 * np.pi * t / period + phase)).astype(np.float32)
```

The trend component is a piecewise-linear walk. We pick a handful of segments, give each a random slope, and join them so the line stays continuous at the breakpoints, which gives us trends that bend partway through instead of running in one straight line forever.
```python
def generate_piecewise_trend(length, rng):
    k = int(rng.integers(2, 9))  # 2..8 segments
    bp = np.sort(rng.choice(np.arange(1, length), size=k - 1, replace=False)) if k > 1 else np.array([], int)
    breaks = np.concatenate([[0], bp, [length]]).astype(int)
    slopes = rng.normal(0, 0.5, size=k)
    trend = np.zeros(length, np.float32)
    intercept = 0.0
    for i in range(k):
        t0, t1 = breaks[i], breaks[i + 1]
        t = np.arange(t0, t1, dtype=np.float32)
        trend[t0:t1] = intercept + slopes[i] * (t - t0)
        if t1 > t0:
            intercept = float(trend[t1 - 1])  # keep the line continuous at each break
    return trend
```

The ARMA component is the one that needs the stable coefficients we just sampled. It runs the recurrence forward over a burn-in period plus the requested length, then drops the burn-in so the process has settled into its natural behavior before we keep any of it.
```python
def generate_arma(length, rng, max_attempts=100):
    p = int(rng.integers(1, 9))
    q = int(rng.integers(1, 9))
    ar = _reject_sample_roots(rng, p, max_attempts)  # stationary AR coefficients
    ma = _reject_sample_roots(rng, q, max_attempts)  # MA coefficients (MA terms cannot make the series explode)
    burn = max(p, q) * 10  # let the process forget its zero start
    n = burn + length
    eps = rng.normal(0, 1.0, size=n)
    y = np.zeros(n, np.float32)
    for t in range(n):
        pp = min(p, t)
        qq = min(q, t)
        ar_part = float(np.dot(ar[:pp], y[t - pp:t][::-1])) if pp else 0.0
        ma_part = float(np.dot(ma[:qq], eps[t - qq:t][::-1])) if qq else 0.0
        y[t] = ar_part + eps[t] + ma_part
    return y[burn:].astype(np.float32)
```

We just built the two hardest component families. The trend generator produces a continuous bending line, and the ARMA generator produces correlated noise with the stability we guaranteed earlier.
The burn-in of ten times the model order is the detail that matters. It lets the process forget that it started from zero and approach its natural variance before we slice off the part we keep. How quickly that happens depends on how close the largest AR root sits to the unit circle.
Now we assemble a full series. We randomly switch each component family on or off, mix the active ones with random weights that sum to one, apply the trend either additively or multiplicatively, and standardize the result.
```python
def generate_series(length, rng, rw_prob=0.2):
    if rng.random() < rw_prob:  # ~20% structureless RW/white-noise
        y = generate_random_walk(length, rng)  # not shown in this section
    else:
        on = {k: bool(rng.random() < 0.5) for k in ("trend", "arma", "sine", "cosine")}
        if not any(on.values()):  # force at least one component on
            on[("trend", "arma", "sine", "cosine")[int(rng.integers(0, 4))]] = True
        comp = {}
        if on["trend"]:
            comp["trend"] = generate_piecewise_trend(length, rng)
        if on["arma"]:
            comp["arma"] = generate_arma(length, rng)
        if on["sine"]:
            comp["sine"] = generate_wave(length, rng, np.sin)
        if on["cosine"]:
            comp["cosine"] = generate_wave(length, rng, np.cos)
        names = list(comp)
        w = rng.uniform(0, 1, size=len(names))
        w = w / w.sum()
        y_add = np.zeros(length, np.float32)
        for name, wi in zip(names, w):
            if name != "trend":
                y_add += wi * comp[name]
        if "trend" in comp:  # trend can be additive or multiplicative
            tr = comp["trend"]
            wt = float(w[names.index("trend")])
            denom = 1.0 + float(np.std(np.abs(tr)))
            if rng.random() < 0.5:
                y = (1.0 + wt * (tr / denom)) * y_add
            else:
                y = wt * tr + y_add
        else:
            y = y_add
    if rng.random() < 0.1:
        y = y + rng.normal(0, 0.1, size=length)
    y = np.nan_to_num(y, nan=0.0, posinf=0.0, neginf=0.0)  # never emit nan/inf
    y = np.clip(y, -1e4, 1e4)
    s = float(y.std())
    if s > 1e-6:
        y = (y - float(y.mean())) / s  # standardize to ~unit scale
    return y.astype(np.float32)
```

That generate_series function is the core of the data pipeline, so let us recap what it does:

- It flips a coin for each component family, forcing at least one on.
- It mixes the active ones with weights that sum to one.
- It multiplies by the trend instead of adding it half of the time.
- It always finishes by removing any NaN or infinity and standardizing the series to zero mean and unit standard deviation.
The rw_prob=0.2 branch at the top, the random walk, is a fix we add much later for a specific failure, so hold that thought.
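generate_random_walk is called above but not shown in this section. As a placeholder, here is a minimal hypothetical version consistent with its inline comment (a structureless random walk or white noise). The white_noise_prob parameter and the 50/50 split are my assumptions, not the repo's code.

```python
import numpy as np


def generate_random_walk(length, rng, white_noise_prob=0.5):
    """Hypothetical stand-in: a Gaussian random walk or plain white noise."""
    steps = rng.normal(0.0, 1.0, size=length)
    if rng.random() < white_noise_prob:          # assumed split between the two cases
        return steps.astype(np.float32)          # white noise: no memory at all
    return np.cumsum(steps).astype(np.float32)   # random walk: each step adds fresh noise
```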
Let us generate a small batch and look at the statistics.
```python
from tsfm.synthetic import make_batch

batch = make_batch(4, 1024, seed=1)
print("shape:", batch.shape)
print("per-series mean:", batch.mean(axis=1).round(3))
print("per-series std :", batch.std(axis=1).round(3))

#### OUTPUT ####
shape: (4, 1024)
per-series mean: [ 0. -0. -0. 0. ]
per-series std : [1. 1. 1. 1.]
```

Every series comes out with mean zero and standard deviation one, which is exactly what we asked for. This is a deliberate first layer of normalization, and the model has a second one (we build it next) that handles the scale of each individual context window.
Standardizing the raw series keeps the synthetic magnitudes sane so training does not diverge, and it puts every series on a comparable footing before the model ever sees it.
Patches as tokens, and RevIN normalization
We have series. Now we have to turn them into something a transformer can read.
Two things happen here. We cut the series into fixed-length patches, and we normalize each window so the model cares about shape, not units.
First the patching and the masking. The series is reshaped into N patches of length p, so a context of length 512 with a patch length of 32 becomes 16 patches.
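In array terms, patching is just a reshape. This minimal NumPy sketch uses the same sizes, with a placeholder batch of 4 holding ramp values. It shows that patch n holds raw points 32n through 32n + 31, and it mirrors the x.view(B, -1, cfg.patch_len) call in the RevIN demo further down.

```python
import numpy as np

B, context_len, patch_len = 4, 512, 32
x = np.arange(B * context_len, dtype=np.float32).reshape(B, context_len)  # placeholder batch
patches = x.reshape(B, -1, patch_len)   # [B, N, p] with N = 512 / 32 = 16 tokens per series
print(patches.shape)
print(patches[0, 1, :3], x[0, 32:35])   # patch 1 starts at raw point 32
```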
Here is the function that windows a list of series into a batch and adds the masking. There is one clever trick in it: the random-prefix mask.
```python
import numpy as np
import torch


def batch_from_series(series_list, context_len, patch_len, rng, device="cpu"):
    """Window each series to context_len, left-pad short ones, add the random-prefix mask."""
    B = len(series_list)
    X = np.zeros((B, context_len), np.float32)
    P = np.zeros((B, context_len), np.float32)  # 1 = padding / masked
    for i, s in enumerate(series_list):
        s = np.asarray(s, np.float32)
        if len(s) >= context_len:
            st = int(rng.integers(0, len(s) - context_len + 1))
            X[i] = s[st:st + context_len]
        else:
            X[i, context_len - len(s):] = s  # left-pad short series
            P[i, :context_len - len(s)] = 1.0  # mark the padding
        r = int(rng.integers(0, patch_len))  # random-prefix mask
        P[i, :r] = 1.0
    return (torch.from_numpy(X).to(device), torch.from_numpy(P).to(device))
```

We just built the windower. For each series, it takes a random window if the series is long enough; otherwise it left-pads the series and marks the padding.
Then it masks a random number of points at the very start, anywhere from zero up to patch_len - 1 (just under one full patch). Why mask a random prefix?
Because at inference the model sees contexts of every possible length, not just neat multiples of the patch size. By masking a random prefix during training, the model learns to handle every alignment, so it never gets surprised by a context that does not divide evenly.
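Here is a quick way to see why a masked prefix of up to patch_len - 1 points is enough. With a 512-point window, the number of unmasked points then takes every possible remainder modulo 32. That covers every way a real context can fail to divide evenly into patches.

```python
context_len, patch_len = 512, 32

# r is drawn from range(patch_len), like rng.integers(0, patch_len) in batch_from_series.
effective_lengths = [context_len - r for r in range(patch_len)]
remainders = sorted({n % patch_len for n in effective_lengths})
print(len(remainders), "distinct alignments covered")
```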
Now the second normalization, and this is the one that makes the model scale-blind. It is called RevIN, reversible instance normalization.
The idea is to standardize each context by the mean and standard deviation of its first patch, run the model in that normalized space, and then reverse the normalization on the output. Electricity load in the thousands and a currency rate near one both become the same standardized shape, so the model only ever has to learn shapes.
There is a trap here that cost me a training run. If the first patch happens to be nearly flat, its standard deviation is almost zero, and dividing by it sends the normalized values to infinity.
The fix is to floor the standard deviation at a fraction of the whole-context standard deviation, and to clamp the normalized values to a safe range.
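Here is the trap and the fix in miniature, as a NumPy sketch. It simplifies first_patch_stats: a single series, no padding mask, and the whole context used for the floor. Treat it as an illustration of the idea rather than the repo's code.

```python
import numpy as np

rng = np.random.default_rng(0)
patch_len = 32
# A 512-point context whose first patch is almost perfectly flat, followed by ordinary noise.
context = np.concatenate([
    2.0 + 1e-6 * rng.normal(size=patch_len),
    rng.normal(size=15 * patch_len),
])

first = context[:patch_len]
mu, sigma = first.mean(), first.std()
naive = (context - mu) / sigma                       # tiny divisor -> enormous values
floored_sigma = max(sigma, 0.3 * context.std())      # floor at 30% of the context std
safe = np.clip((context - mu) / floored_sigma, -20.0, 20.0)

print(f"first-patch sigma: {sigma:.1e}")
print(f"max |normalized| without floor: {np.abs(naive).max():.1e}")
print(f"max |normalized| with floor:    {np.abs(safe).max():.2f}")
```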
```python
import torch


def first_patch_stats(xp, pp, min_valid: int = 3, eps: float = 1e-6, ctx_floor: float = 0.3):
    """xp, pp: [B, N, p]; pp 1=pad/missing. Returns mu, sigma: [B].

    Centers on the first valid patch (>=min_valid points), but FLOORS sigma at ctx_floor * the
    whole-context std so a near-flat first patch can't blow up the normalized scale (stability).
    """
    B, N, p = xp.shape
    valid = 1.0 - pp  # 1 = usable point
    counts = valid.sum(-1)  # [B, N]
    has = counts >= min_valid  # [B, N]
    idx = torch.argmax(has.to(torch.int64), dim=1)  # first patch with enough real points
    none = ~has.any(dim=1)
    idx = torch.where(none, torch.full_like(idx, N - 1), idx)  # no patch qualifies: use the last one
    rows = torch.arange(B, device=xp.device)
    arr = xp[rows, idx]  # [B, p]
    m = valid[rows, idx]
    n = m.sum(-1).clamp(min=1.0)
    mu = (arr * m).sum(-1) / n  # mean of the first valid patch
    var = (((arr - mu[:, None]) * m) ** 2).sum(-1) / n
    sigma = var.clamp(min=0.0).sqrt()
    # whole-context std (over all valid points) as a stability floor
    cnt = valid.sum(dim=(1, 2)).clamp(min=1.0)
    gmu = (xp * valid).sum(dim=(1, 2)) / cnt
    gvar = (((xp - gmu[:, None, None]) * valid) ** 2).sum(dim=(1, 2)) / cnt
    ctx_std = gvar.clamp(min=0.0).sqrt()
    sigma = torch.maximum(sigma, ctx_floor * ctx_std)  # the floor: at least 30% of the context std
    sigma = torch.where(sigma < eps, torch.ones_like(sigma), sigma)
    return mu, sigma
```

Let us read this carefully, because the floor is what matters most here.
- mu is the mean of the first patch that has at least three real points.
- sigma is the standard deviation of that same patch.
- The floor raises sigma to at least 30 percent of the whole-context standard deviation, so a flat first patch can never produce a tiny divisor.
- The clamp (applied in the model, to plus or minus 20) catches anything that still slips through.
When the model forecasts, it does the reverse: it multiplies by sigma and adds mu to put the prediction back on the original scale.
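Here is a simplified NumPy sketch that makes the scale-blindness and the reverse step concrete. The helper names revin_stats, revin_normalize, and revin_denormalize are hypothetical, and padding is ignored. I reverse the normalized context itself rather than a model prediction, just to show that the round trip is exact.

```python
import numpy as np


def revin_stats(context, patch_len=32, ctx_floor=0.3):
    """Simplified first-patch stats with the context-std floor (no padding handling)."""
    first = context[:patch_len]
    return first.mean(), max(first.std(), ctx_floor * context.std())


def revin_normalize(context, mu, sigma, clamp=20.0):
    return np.clip((context - mu) / sigma, -clamp, clamp)


def revin_denormalize(pred, mu, sigma):
    return pred * sigma + mu  # the reverse step the model applies to its output


shape = np.sin(np.arange(512) / 20.0)
small = shape                       # a series near unit scale
large = 10000.0 + 2500.0 * shape    # the same shape in much larger units

mu_s, sig_s = revin_stats(small)
mu_l, sig_l = revin_stats(large)
z_small = revin_normalize(small, mu_s, sig_s)
z_large = revin_normalize(large, mu_l, sig_l)
print(np.allclose(z_small, z_large))                                # same normalized shape
print(np.allclose(revin_denormalize(z_large, mu_l, sig_l), large))  # round trip recovers the units
```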
Let us see RevIN in action on a batch.
```python
# x, pad: [B, context_len] tensors from batch_from_series; cfg: a ModelConfig
xp = x.view(B, -1, cfg.patch_len)
pp = pad.view(B, -1, cfg.patch_len)
mu, sigma = first_patch_stats(xp, pp)
xn = ((xp - mu[:, None, None]) / sigma[:, None, None]).clamp(-20.0, 20.0)
print("mu :", mu.numpy().round(3))
print("sigma:", sigma.numpy().round(3))
print("normalized range:", float(xn.min()), "to", float(xn.max()))

#### OUTPUT ####
mu : [ 0.214 -1.880 0.061 0.945]
sigma: [0.973 1.041 0.887 1.006]
normalized range: -4.118 to 5.272
```

Each series gets its own mu and sigma, and after normalization the values land in a tidy range, well within the plus-or-minus-20 clamp. The clamp never had to fire here, which is what we want: it is a guard rail for pathological cases, not something normal data should ever hit.
With patching and RevIN done, the model finally has tokens it can read.
The transformer block, one mechanism at a time
Now we build the transformer itself. I am going to assemble it from named pieces, one at a time, and explain why each piece is there.
Every mechanism here is borrowed from a design that is known to work, so we are not guessing; we are wiring proven parts together in the right order.
ResidualBlock and RMSNorm
The first building block is a small two-layer network with a skip connection, used both as the tokenizer at the input and as the prediction head at the output. It runs the input through a hidden layer with a SiLU activation, then an output layer, and adds a linear projection of the input on top.
```python
from torch import nn


class ResidualBlock(nn.Module):
    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, bias: bool = False):
        super().__init__()
        self.hidden_layer = nn.Linear(in_dim, hidden_dim, bias=bias)
        self.output_layer = nn.Linear(hidden_dim, out_dim, bias=bias)
        self.residual_layer = nn.Linear(in_dim, out_dim, bias=bias)  # linear skip path
        self.act = nn.SiLU()

    def forward(self, x):
        return self.output_layer(self.act(self.hidden_layer(x))) + self.residual_layer(x)
```
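To see how the same block serves as both tokenizer and head, here is a NumPy mirror of its forward pass with toy sizes. Two things are assumptions, not taken from model.py. First, the tokenizer's input is each patch's 32 values concatenated with its 32 padding flags. Second, the head's horizon_len × num_outputs values are laid out horizon-major, so they reshape to [horizon, outputs]. Weights are stored as [in, out] here, whereas nn.Linear stores the transpose.

```python
import numpy as np


def silu(x):
    return x / (1.0 + np.exp(-x))


def residual_block(x, w_hidden, w_out, w_res):
    """NumPy mirror of ResidualBlock.forward with bias=False: out(silu(hidden(x))) + res(x)."""
    return silu(x @ w_hidden) @ w_out + x @ w_res


def init(rng, n_in, n_out):
    return rng.normal(scale=n_in ** -0.5, size=(n_in, n_out))


rng = np.random.default_rng(0)
B, N, patch_len, d = 2, 16, 32, 64   # toy model_dim; the real presets use 512 to 1280
horizon, outputs = 128, 10           # horizon_len and num_outputs from ModelConfig

# Tokenizer, ASSUMING its input is 32 values + 32 padding flags per patch (2 * patch_len).
tok_in = rng.normal(size=(B, N, 2 * patch_len))
tokens = residual_block(tok_in, init(rng, 2 * patch_len, d), init(rng, d, d),
                        init(rng, 2 * patch_len, d))

# Head: d -> d -> horizon * outputs, then split into (horizon, mean + 9 quantiles).
flat = residual_block(tokens, init(rng, d, d), init(rng, d, horizon * outputs),
                      init(rng, d, horizon * outputs))
pred = flat.reshape(B, N, horizon, outputs)
print(tokens.shape, pred.shape)
```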
The second building block is the normalization the transformer uses internally, RMSNorm. It rescales a vector by its root-mean-square and then applies a learned per-dimension gain.
The gain is initialized to zero, so the multiplier (1 + scale) starts at one. The layer therefore begins as plain RMS normalization with no learned rescaling, which keeps early training stable.
```python
import torch
from torch import nn


class RMSNorm(nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):
        super().__init__()
        self.eps = eps
        self.scale = nn.Parameter(torch.zeros(dim))  # zero-init -> gain (1 + scale) = 1 at start

    def forward(self, x):
        dt = x.dtype
        xf = x.float()  # normalize in float32 even when x is bf16/fp16
        xf = xf * torch.rsqrt(xf.pow(2).mean(-1, keepdim=True) + self.eps)
        return (xf * (1.0 + self.scale.float())).to(dt)
```

We just built the two workhorses. The ResidualBlock is a flexible learned transformation with a built-in shortcut, and RMSNorm is a cheap, stable normalization whose learned gain starts at exactly one.
We use them everywhere from here on.
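To see what RMSNorm does to a single vector, here is a minimal pure-Python sketch (no PyTorch; rms_norm is a hypothetical helper that mirrors the module above, not part of the model code). With the gain at its zero initialization the output has an RMS of one, and multiplying the input by 100 leaves the output essentially unchanged, which is what makes each block insensitive to the magnitude of whatever feeds it.

```python
import math


def rms_norm(x, gain=None, eps=1e-6):
    """Pure-Python RMSNorm of one vector, using the same (1 + gain) convention as the module."""
    if gain is None:
        gain = [0.0] * len(x)  # zero-initialized gain -> multiplier (1 + 0) = 1
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * (1.0 + g) for v, g in zip(x, gain)]


x = [3.0, -1.0, 2.0, 0.5]
y = rms_norm(x)
y_scaled = rms_norm([100.0 * v for v in x])  # same vector, 100x the magnitude

print("output RMS:", math.sqrt(sum(v * v for v in y) / len(y)))
print("max difference after 100x rescale:", max(abs(a - b) for a, b in zip(y, y_scaled)))
```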
RoPE and per-dimension scale
A transformer needs to know the order of its tokens. We encode position with rotary position embedding, RoPE, which rotates each query and key vector by an angle that depends on its position.
The angle grows with position and shrinks with the dimension index, so different dimensions rotate at different rates.
```python
class RotaryEmbedding(nn.Module):
    def __init__(self, head_dim: int, max_period: float = 10000.0):
        super().__init__()
        self.head_dim = head_dim
        self.max_period = max_period

    def cos_sin(self, positions, dtype):
        # positions: [N] -> cos/sin: [N, head_dim/2]
        half = self.head_dim // 2
        device = positions.device
        freq = self.max_period ** (2.0 * torch.arange(half, device=device).float() / self.head_dim)
        ang = positions.float()[:, None] / freq[None, :]
        return torch.cos(ang).to(dtype), torch.sin(ang).to(dtype)

    # note: this staticmethod shadows nn.Module.apply(fn), so a recursive model.apply(init_fn)
    # would fail when it reaches this module; initialize weights another way
    @staticmethod
    def apply(x, cos, sin):
        # x: [B, H, N, D]; cos/sin: [N, D/2]; rotates (first half, second half) pairs
        first, second = x.chunk(2, dim=-1)
        cos = cos[None, None]
        sin = sin[None, None]
        return torch.cat([first * cos - second * sin, second * cos + first * sin], dim=-1)
```
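The property that makes RoPE useful is that the attention score between a rotated query and a rotated key depends only on their offset, not on where the pair sits in the sequence. A small pure-Python check, where rope_rotate is a hypothetical helper that mirrors cos_sin and apply for a single head vector (the 8-dimensional vectors are made-up values):

```python
import math


def rope_rotate(vec, pos, max_period=10000.0):
    """Rotate one head vector with the same half-split RoPE convention as RotaryEmbedding."""
    d = len(vec)
    half = d // 2
    first, second = vec[:half], vec[half:]
    out_first, out_second = [], []
    for i in range(half):
        ang = pos / (max_period ** (2.0 * i / d))  # angle shrinks with the dimension index
        c, s = math.cos(ang), math.sin(ang)
        out_first.append(first[i] * c - second[i] * s)
        out_second.append(second[i] * c + first[i] * s)
    return out_first + out_second


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


q = [0.3, -1.2, 0.8, 0.5, 1.1, -0.4, 0.2, 0.9]
k = [1.0, 0.4, -0.7, 0.6, -0.3, 0.8, 1.5, -0.2]

# the same offset (3) at two different absolute positions gives the same logit
s1 = dot(rope_rotate(q, 5), rope_rotate(k, 2))
s2 = dot(rope_rotate(q, 105), rope_rotate(k, 102))

# a rotation never changes the vector's length
n_before = math.sqrt(dot(q, q))
n_after = math.sqrt(dot(rope_rotate(q, 42), rope_rotate(q, 42)))
print(s1, s2, n_before, n_after)
```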
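There is one more small piece that goes with attention, a learned per-dimension scale on the query. Instead of the usual fixed one-over-square-root-of-d temperature, we let the model learn a per-dimension gain through a softplus, which keeps it positive. The constant in front is 1/ln(2) divided by sqrt(d), and because softplus(0) = ln(2), the zero-initialized scale starts out at exactly the standard 1/sqrt(d).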
```python
class PerDimScale(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.scale = nn.Parameter(torch.zeros(dim))
        self.r = 1.442695041 / math.sqrt(dim)  # 1/ln(2) / sqrt(d)

    def forward(self, x):
        # softplus keeps the gain positive; softplus(0) = ln(2), so at init this is 1/sqrt(d)
        return x * (self.r * F.softplus(self.scale))
```

Multi-head attention: fused QKV, qk-norm, and a unit scale
Now the attention itself. There are a few deliberate choices baked in here.
The query, key, and value projections are fused into one linear layer for speed. The queries and keys are each passed through their own RMSNorm before the rotation, which keeps their magnitudes controlled.
RoPE is applied to both, the per-dimension scale is applied to the query, and the actual attention is computed with a scale of 1.0, because the temperature is already handled by the per-dimension scale.
```python
class MultiHeadAttention(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        # assumes cfg.num_heads * cfg.head_dim == cfg.model_dim
        d, self.h, self.hd = cfg.model_dim, cfg.num_heads, cfg.head_dim
        self.qkv = nn.Linear(d, 3 * d, bias=False)  # fused Q, K, V in one matmul
        self.out = nn.Linear(d, d, bias=False)
        self.q_norm = RMSNorm(self.hd, cfg.rms_eps)  # qk-norm on the query
        self.k_norm = RMSNorm(self.hd, cfg.rms_eps)  # qk-norm on the key
        self.rope = RotaryEmbedding(self.hd, cfg.rope_max_period)
        self.per_dim = PerDimScale(self.hd)

    def forward(self, x, attn_mask, cos, sin):
        B, N, D = x.shape
        q, k, v = self.qkv(x).chunk(3, dim=-1)  # split the fused projection
        q = self.q_norm(q.view(B, N, self.h, self.hd))
        k = self.k_norm(k.view(B, N, self.h, self.hd))
        v = v.view(B, N, self.h, self.hd)
        q, k, v = (t.transpose(1, 2) for t in (q, k, v))  # [B, H, N, hd]
        q = self.rope.apply(q, cos, sin)  # rotate the query
        k = self.rope.apply(k, cos, sin)  # rotate the key
        q = self.per_dim(q)  # learned temperature on q
        o = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, scale=1.0)
        o = o.transpose(1, 2).reshape(B, N, D)
        return self.out(o)
```
We just built attention. The fused QKV is one matmul instead of three, the queries and keys are normalized before they are rotated, and the softmax temperature is learned rather than fixed.
The attn_mask argument is how we keep the model causal, and we build that mask in the model itself in a moment.
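Why is a scale of 1.0 safe here? At initialization, qk-norm fixes the length of every query and key, and the per-dimension scale then applies 1/sqrt(head_dim), so every logit is capped at sqrt(head_dim) no matter how large the raw projections are. A minimal pure-Python sketch of that bound; the helper names are hypothetical, head_dim = 32 is just an example value, and RoPE is left out because a rotation does not change vector lengths.

```python
import math
import random


def rms_normalize(v, eps=1e-6):
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]


def softplus(z):
    return math.log1p(math.exp(z))


head_dim = 32
r = 1.442695041 / math.sqrt(head_dim)  # same constant as PerDimScale
per_dim = r * softplus(0.0)            # zero-initialized scale -> 1/sqrt(head_dim)

random.seed(0)
worst = 0.0
for _ in range(1000):
    # wildly scaled raw projections, as an untrained layer might produce
    q = [random.gauss(0, 50) for _ in range(head_dim)]
    k = [random.gauss(0, 50) for _ in range(head_dim)]
    qn = [per_dim * x for x in rms_normalize(q)]  # qk-norm, then the learned temperature on q
    kn = rms_normalize(k)                         # qk-norm on k
    logit = sum(a * b for a, b in zip(qn, kn))    # SDPA with scale=1.0 adds no further factor
    worst = max(worst, abs(logit))

bound = math.sqrt(head_dim)  # |q.k| <= |q||k| = head_dim, times 1/sqrt(head_dim)
print(f"largest |logit| seen: {worst:.3f}, bound: {bound:.3f}")
```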
Sandwich norm and the two-matrix feed-forward
The last piece is the block that wraps attention and a feed-forward network. The norm placement here is a sandwich, there is a normalization both before and after each sublayer, not just before.
The feed-forward is two matrices with a SiLU in between, which is simpler than the gated variants and is exactly what the design we are following uses.
```python
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        d = cfg.model_dim
        self.pre_attn_ln = RMSNorm(d, cfg.rms_eps)
        self.post_attn_ln = RMSNorm(d, cfg.rms_eps)
        self.attn = MultiHeadAttention(cfg)
        self.pre_ff_ln = RMSNorm(d, cfg.rms_eps)
        self.post_ff_ln = RMSNorm(d, cfg.rms_eps)
        self.ff0 = nn.Linear(d, d, bias=False)
        self.ff1 = nn.Linear(d, d, bias=False)
        self.act = nn.SiLU()
        self.drop = nn.Dropout(cfg.dropout)

    def forward(self, x, attn_mask, cos, sin):
        a = self.attn(self.pre_attn_ln(x), attn_mask, cos, sin)
        x = x + self.post_attn_ln(self.drop(a))  # sandwich: norm before AND after
        f = self.ff1(self.act(self.ff0(self.pre_ff_ln(x))))
        x = x + self.post_ff_ln(self.drop(f))
        return x
```
We just built one full transformer block. The input is normalized, attended, normalized again, and added back, then the same sandwich pattern runs around the feed-forward.
Stack a bunch of these and you have the body of the model.
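The practical effect of the second norm in the sandwich is that every residual update arrives with a controlled size, however badly scaled the sublayer's output is. A pure-Python sketch with hypothetical helpers, both norm gains left at their zero initialization, comparing the sandwich update with a plain pre-norm update when the sublayer blows its input up by 1000x:

```python
import math


def rms_normalize(v, eps=1e-6):
    rms = math.sqrt(sum(x * x for x in v) / len(v) + eps)
    return [x / rms for x in v]


def rms(v):
    return math.sqrt(sum(x * x for x in v) / len(v))


def sandwich_update(x, sublayer):
    # x + post_norm(sublayer(pre_norm(x)))
    return [a + b for a, b in zip(x, rms_normalize(sublayer(rms_normalize(x))))]


def pre_norm_update(x, sublayer):
    # the more common pre-norm residual, x + sublayer(pre_norm(x)), for comparison
    return [a + b for a, b in zip(x, sublayer(rms_normalize(x)))]


def explosive(v):
    # a stand-in sublayer whose output is badly scaled
    return [1000.0 * t for t in v]


x = [0.5, -0.2, 0.1, 0.3]
delta_sandwich = rms([a - b for a, b in zip(sandwich_update(x, explosive), x)])
delta_pre = rms([a - b for a, b in zip(pre_norm_update(x, explosive), x)])
print(f"update RMS, sandwich: {delta_sandwich:.4f}   pre-norm only: {delta_pre:.1f}")
```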
Assembling the full model and the forward pass
We have all the parts. Now we wire them into the PatchedDecoder.
The tokenizer is a ResidualBlock that takes a patch and its mask and produces a model-dimension embedding. The body is a stack of transformer blocks.
The head is another ResidualBlock that maps the final embedding to the full forecast, every horizon step times every output channel.
```python
class PatchedDecoder(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        p = cfg.patch_len
        self.tokenizer = ResidualBlock(2 * p, cfg.model_dim, cfg.model_dim, bias=True)
        self.layers = nn.ModuleList([TransformerBlock(cfg) for _ in range(cfg.num_layers)])
        self.rope = RotaryEmbedding(cfg.head_dim, cfg.rope_max_period)
        self.head = ResidualBlock(cfg.model_dim, cfg.model_dim,
                                  cfg.horizon_len * cfg.num_outputs, bias=False)
```

Notice the tokenizer takes 2 * p inputs, not p. That is because we feed it the normalized patch values concatenated with the patch mask, so the model always knows which points were real and which were padding.
The head produces horizon_len * num_outputs numbers, which we reshape into a horizon-by-channels grid.
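The 2p-wide token can be built by hand for a single patch. This pure-Python sketch mirrors the normalize, clamp, zero-out and concatenate steps of the encode method; make_token is a hypothetical helper, and mu and sigma are passed in directly rather than computed by first_patch_stats.

```python
def make_token(patch, pad_mask, mu, sigma, clip=20.0):
    """Build one tokenizer input: p normalized values followed by p mask flags (1 = padded)."""
    values = []
    for v, m in zip(patch, pad_mask):
        z = max(-clip, min(clip, (v - mu) / sigma))  # RevIN-normalize, clamp to [-20, 20]
        values.append(z * (1.0 - m))                  # zero out padded points
    return values + [float(m) for m in pad_mask]      # length 2p


patch = [10.0, 12.0, 14.0, 999.0]  # the last point is padding garbage
pad_mask = [0, 0, 0, 1]
tok = make_token(patch, pad_mask, mu=12.0, sigma=2.0)
print(tok)
```

The garbage value never reaches the model: it is zeroed out, and the mask half of the token tells the tokenizer that the zero means "missing", not "equal to the mean".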
Now the forward pass. This is where patching, RevIN, the tokenizer, RoPE, the mask, and the transformer stack all come together.
```python
# method of PatchedDecoder
def encode(self, x, padding):
    """Returns normalized per-position outputs [B, N, horizon, Q] + (mu, sigma)."""
    cfg = self.cfg
    p = cfg.patch_len
    B, L = x.shape
    assert L % p == 0, "context length must be a multiple of patch_len"
    N = L // p
    xp = x.view(B, N, p)
    pp = padding.view(B, N, p)
    mu, sigma = first_patch_stats(xp, pp)  # RevIN stats
    xn = ((xp - mu[:, None, None]) / sigma[:, None, None]).clamp(-20.0, 20.0)
    xn = xn * (1.0 - pp)  # zero out masked points
    tok_in = torch.cat([xn, pp], dim=-1)  # [B, N, 2p]: values + mask
    h = self.tokenizer(tok_in)  # [B, N, D]
    pos = torch.arange(N, device=x.device)
    cos, sin = self.rope.cos_sin(pos, h.dtype)
    patch_valid = pp.min(dim=-1).values < 0.5  # [B, N] True if any valid point
    mask = self._attn_mask(patch_valid)
    for blk in self.layers:
        h = blk(h, mask, cos, sin)
    out = self.head(h).view(B, N, cfg.horizon_len, cfg.num_outputs)
    return out, (mu, sigma)
```
It helps to trace the shapes through the whole forward pass once, end to end:
- The raw context arrives as [B, L], a batch of B series each L points long.
- We reshape it into [B, N, p], which is N patches of p points each.
- RevIN computes a mean and standard deviation per series and normalizes, so every series lands in the same tidy range.
- We concatenate the normalized values with the padding mask, so each token is 2p numbers wide and carries its own "which points were real" flag.
- The tokenizer maps each token to the model dimension, giving [B, N, D].
- The transformer stack runs num_layers blocks, keeping the shape [B, N, D] but enriching every patch with information from the patches before it.
- The head expands each patch to [B, N, horizon, 10], a full 128-step forecast with ten channels per step.
Every step keeps the batch and patch dimensions intact, and that is exactly what lets a single forward pass produce a prediction at every position at once, which is the whole efficiency of the decoder-only objective.
The output shape says what the model produces. For every position in the sequence, the model predicts the next 128 steps, and for each of those steps it produces 10 channels, the mean plus nine quantiles.
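The same trace can be written down as plain arithmetic. A pure-Python sketch (trace_shapes is a hypothetical helper): patch_len = 32 matches the 32-point windows, while D = 256 and 4 layers are only example values. With B = 4 and a 512-point context it ends at (4, 16, 128, 10), the shape the sanity check below prints.

```python
def trace_shapes(B, L, p, D, horizon, num_outputs, num_layers):
    """Return (stage, shape) pairs for the forward pass described above."""
    assert L % p == 0, "context length must be a multiple of patch_len"
    N = L // p
    shapes = [
        ("context", (B, L)),
        ("patches", (B, N, p)),
        ("token input", (B, N, 2 * p)),    # values + mask
        ("tokenizer out", (B, N, D)),
    ]
    for i in range(num_layers):
        shapes.append((f"block {i}", (B, N, D)))  # every block keeps the shape
    shapes.append(("head out", (B, N, horizon, num_outputs)))
    return shapes


shapes = trace_shapes(B=4, L=512, p=32, D=256, horizon=128, num_outputs=10, num_layers=4)
for name, shape in shapes:
    print(f"{name:14s} {shape}")
```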
```python
# method of PatchedDecoder
def _attn_mask(self, patch_valid):
    # patch_valid: [B, N] bool. Returns [B, 1, N, N] bool (True = attend), causal + key-padding,
    # with the diagonal always allowed (prevents all-masked rows -> NaN).
    B, N = patch_valid.shape
    dev = patch_valid.device
    causal = torch.tril(torch.ones(N, N, dtype=torch.bool, device=dev))
    mask = causal[None, None] & patch_valid[:, None, None, :]
    eye = torch.eye(N, dtype=torch.bool, device=dev)[None, None]
    return mask | eye
```

The attention mask deserves its own look. It has to be causal (a patch can only attend to itself and earlier patches), and it has to ignore padding patches.
There is one subtle guard, we always allow a position to attend to itself, even if it is padding, because a row with nothing to attend to produces a NaN in the softmax.
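A pure-Python version of the same rule for a single series makes the diagonal guard concrete (attn_mask here is a hypothetical helper, not the model's method). When the first patch is entirely padding, row 0 would otherwise allow nothing at all; the diagonal lets it attend to itself, while later rows still skip the padded patch.

```python
def attn_mask(patch_valid):
    """mask[i][j] is True when patch i may attend to patch j: causal, key-padding, diagonal kept."""
    n = len(patch_valid)
    return [[(j <= i and patch_valid[j]) or i == j for j in range(n)] for i in range(n)]


mask = attn_mask([False, True, True, True])  # the first patch is all padding
for row in mask:
    print("".join("x" if allowed else "." for allowed in row))
```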
Let us prove the whole thing runs. The sanity script does one forward pass and one backward pass on the tiny model and checks the shapes and the gradient.
```python
# m is the tiny model; x, pad hold a batch of 4 series with a 512-point context (set up earlier)
out, (mu, sigma) = m(x, pad)
print("forward out shape:", tuple(out.shape), " expect [4, 16, 128, 10]")
loss, mse = training_step(m, x, pad, cfg)
loss.backward()
gnorm = float(sum((p.grad.float() ** 2).sum() for p in m.parameters() if p.grad is not None) ** 0.5)
print(f"loss={loss.item():.4f} mse={mse.item():.4f} grad_norm={gnorm:.3f}")
#### OUTPUT ####
=== forward + backward (tiny, context 512) ===
forward out shape: (4, 16, 128, 10) expect [4, 16, 128, 10]
loss=44.7183 mse=27.6402 grad_norm=39.842
RESULT: OK
```

The shape is exactly what we predicted, four series, sixteen patches, a 128-step horizon, and ten channels. The loss is finite and the gradient norm is non-zero, which means the whole graph is connected and learning can flow.
The loss starts high, around 44, because the untrained head produces large random values, and that is fine, it is about to come down fast.
From training outputs to an autoregressive forecast
During training the model predicts the next 128 steps from every position at once. At inference we want a forecast of arbitrary length, so we roll forward.
We encode the context, take the prediction at the last position, append the median of that prediction to the context, and repeat until we have enough steps.
```python
# method of PatchedDecoder; the context length and horizon_len must be multiples of patch_len
@torch.no_grad()
def _forecast_once(self, context, horizon, point_channel=5):
    cfg = self.cfg
    h_len = cfg.horizon_len
    x = context
    pad = torch.zeros_like(x)
    points, quants = [], []
    produced = 0
    while produced < horizon:
        out, (mu, sigma) = self.encode(x, pad)  # [B, N, h_len, Q] normalized
        last = out[:, -1] * sigma[:, None, None] + mu[:, None, None]  # un-normalize the last patch
        points.append(last[..., point_channel])  # feedback channel: 5 = q50 median
        quants.append(last)
        new = last[..., point_channel]
        x = torch.cat([x, new], dim=1)  # append the forecast to the context
        pad = torch.cat([pad, torch.zeros_like(new)], dim=1)
        produced += h_len
    return torch.cat(points, dim=1)[:, :horizon], torch.cat(quants, dim=1)[:, :horizon]
```
We feed back channel 5, the median, because the median is the most robust point estimate to roll forward. The channel is a parameter rather than a hardcoded choice, and that flexibility matters later, because which channel works best depends on how well the model is trained.
With the forward pass and the forecast loop done, the model is complete. Now it needs something to learn from.
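Stripped of tensors, the forecast loop is just "predict a chunk, append it, repeat, trim". A pure-Python sketch where rollout and linear_predictor are hypothetical stand-ins: the predictor replaces the network, and a chunk length of 4 replaces 128 to keep the numbers small.

```python
def rollout(context, horizon, h_len, predict):
    """Autoregressive forecast: predict h_len steps, feed them back as context, repeat, trim."""
    x = list(context)
    out = []
    while len(out) < horizon:
        chunk = predict(x)[:h_len]  # one "forward pass" -> the next h_len points
        out.extend(chunk)
        x.extend(chunk)             # the prediction becomes part of the context
    return out[:horizon]            # drop the overshoot from the last chunk


def linear_predictor(x, h_len=4):
    # toy model: continue the most recent slope
    slope = x[-1] - x[-2]
    return [x[-1] + slope * (i + 1) for i in range(h_len)]


forecast = rollout([1.0, 2.0, 3.0, 4.0], horizon=10, h_len=4, predict=linear_predictor)
print(forecast)
```

Ten steps need three chunks of four, and the last two predicted points are trimmed away, which is exactly what the [:, :horizon] slice does in _forecast_once.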
The objective: predict the next 128 points from every position
Here is the decoder-only trick that makes training efficient. We do not just supervise the last position.
Every patch position is trained to predict the 128 steps that follow it. One forward pass produces N training signals, one per patch.
Positions whose target would run past the end of the context are masked out.
```python
def build_targets(x, padding, patch_len, horizon, mu, sigma):
    """x, padding: [B, L]. Returns targets[B, N, horizon] (normalized), tmask[B, N, horizon] (1=valid)."""
    B, L = x.shape
    p = patch_len
    N = L // p
    xn = ((x - mu[:, None]) / sigma[:, None]).clamp(-20.0, 20.0)  # normalized targets
    targets = x.new_zeros(B, N, horizon)
    tmask = x.new_zeros(B, N, horizon)
    for j in range(N):
        start = (j + 1) * p
        if start >= L:
            break  # remaining positions have no target
        end = min(start + horizon, L)
        ln = end - start
        targets[:, j, :ln] = xn[:, start:end]
        tmask[:, j, :ln] = 1.0 - padding[:, start:end]  # valid where source not padded
    return targets, tmask
```

Keep an eye on that if start >= L: break line. It quietly drops the target for the last patch when the context is exactly full, and that innocent-looking line is a bug we will pay for later.
For now, hold it in mind.
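To see exactly what that indexing does, here is a pure-Python mirror of build_targets on an 8-point series with patch_len = 2 and horizon = 3 (toy sizes, no padding or normalization; toy_targets is a hypothetical helper). Positions near the end get fewer supervised steps, and the final patch gets none at all.

```python
def toy_targets(x, patch_len, horizon):
    """Same windowing as build_targets: patch j is trained on the points right after it."""
    L = len(x)
    N = L // patch_len
    targets = [[0.0] * horizon for _ in range(N)]
    tmask = [[0] * horizon for _ in range(N)]
    for j in range(N):
        start = (j + 1) * patch_len
        if start >= L:
            break  # same early exit: the last patch has nothing after it in the window
        end = min(start + horizon, L)
        for t in range(end - start):
            targets[j][t] = x[start + t]
            tmask[j][t] = 1
    return targets, tmask


x = [float(v) for v in range(8)]  # L = 8, so N = 4 patches
targets, tmask = toy_targets(x, patch_len=2, horizon=3)
print("supervised steps per position:", [sum(row) for row in tmask])
print("targets:", targets)
```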
The loss has two parts. The mean channel is trained with plain squared error.
The nine quantile channels are trained with pinball loss, an asymmetric loss that pulls each quantile to the right place. For the 0.1 quantile it penalizes over-prediction more, for the 0.9 quantile it penalizes under-prediction more, and together they teach the model to express a calibrated spread.
```python
def compute_loss(pred_norm, targets, tmask, quantiles):
    """pred_norm: [B, N, horizon, 1+Q] normalized. Returns (total_loss, mse_component)."""
    denom = tmask.sum().clamp(min=1.0)
    mean_pred = pred_norm[..., 0]
    mse = (((mean_pred - targets) ** 2) * tmask).sum() / denom  # the mean channel
    pinball = pred_norm.new_zeros(())
    for i, qv in enumerate(quantiles):
        qp = pred_norm[..., i + 1]
        dev = targets - qp
        pl = torch.maximum(qv * dev, (qv - 1.0) * dev) * tmask  # asymmetric per quantile
        pinball = pinball + pl.sum() / denom
    return mse + pinball, mse.detach()
```

We just built the objective. build_targets lines up the next 128 points for every patch and masks the parts that fall off the end.
compute_loss scores the mean channel with squared error and each quantile channel with its own asymmetric pinball penalty, then adds them up. The masking by tmask means a padded or out-of-range point contributes nothing, so the gradient only ever sees valid targets.
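Why does this asymmetric penalty produce quantiles? The constant that minimizes the total pinball loss over a set of values is that set's tau-quantile. A quick pure-Python check, where pinball mirrors the expression in compute_loss and the values 1 to 11 are made-up samples:

```python
def pinball(y, q, tau):
    """Pinball loss for one target y and one quantile prediction q (same form as compute_loss)."""
    dev = y - q
    return max(tau * dev, (tau - 1.0) * dev)


samples = [float(v) for v in range(1, 12)]  # 1, 2, ..., 11


def best_constant(tau):
    # the constant prediction with the lowest total pinball loss over the samples
    return min(samples, key=lambda c: sum(pinball(y, c, tau) for y in samples))


# asymmetry at tau = 0.1: over-predicting by 1 costs 0.9, under-predicting by 1 costs 0.1
over, under = pinball(5.0, 6.0, 0.1), pinball(5.0, 4.0, 0.1)
print(over, under)
print({tau: best_constant(tau) for tau in (0.1, 0.5, 0.9)})
```

The 0.1 loss settles low in the data, the 0.5 loss on the median, and the 0.9 loss high, which is how nine such channels together trace out an uncertainty band.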
The training loop: optimizer and schedule
The loop is short, but the details around it are everything. We use AdamW, but we only apply weight decay to the two-dimensional weight matrices, never to the norms, biases, or learned scales.
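Decaying these would keep dragging them back to their starting point: an RMSNorm gain (1 + scale) would be pulled toward one and the per-dimension temperature toward its initial 1/sqrt(d), undoing exactly what they are there to learn.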
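To check which tensors land in which group, the same rule can be run on plain (name, shape) pairs. A pure-Python sketch: split_for_weight_decay is a hypothetical helper, the names follow the attribute names of the modules defined above, and the shapes are illustrative, assuming model_dim = 256, patch_len = 32 and head_dim = 32.

```python
def split_for_weight_decay(named_shapes):
    """Same rule as make_param_groups, applied to (name, shape) pairs instead of real parameters."""
    decay, no_decay = [], []
    for name, shape in named_shapes:
        if len(shape) >= 2 and "scale" not in name:
            decay.append(name)      # 2-D weight matrices
        else:
            no_decay.append(name)   # biases, norm gains, per-dim scales
    return decay, no_decay


params = [
    ("tokenizer.hidden_layer.weight", (256, 64)),
    ("tokenizer.hidden_layer.bias", (256,)),
    ("layers.0.pre_attn_ln.scale", (256,)),
    ("layers.0.attn.qkv.weight", (768, 256)),
    ("layers.0.attn.per_dim.scale", (32,)),
    ("layers.0.ff0.weight", (256, 256)),
]
decay, no_decay = split_for_weight_decay(params)
print("decay:   ", decay)
print("no decay:", no_decay)
```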
```python
def make_param_groups(model, weight_decay):
    """Weight decay on 2-D Linear weights only; none on norms / per-dim-scale / biases / 1-D params."""
    decay, no_decay = [], []
    for name, p in model.named_parameters():
        if not p.requires_grad:
            continue
        if p.ndim >= 2 and "scale" not in name:
            decay.append(p)
        else:
            no_decay.append(p)
    return [
        {"params": decay, "weight_decay": weight_decay},
        {"params": no_decay, "weight_decay": 0.0},
    ]
```

The learning rate warms up linearly and then follows a cosine down to nearly zero. The warmup keeps the first few hundred steps from blowing up while the norms and scales settle.
```python
def cosine_lr(step, warmup, total, peak):
    if step < warmup:
        return peak * (step + 1) / max(1, warmup)  # linear warmup
    prog = (step - warmup) / max(1, total - warmup)
    return 0.5 * peak * (1.0 + math.cos(math.pi * min(1.0, prog)))  # cosine decay
```
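Plugging the tiny run's settings into the schedule reproduces the lr column of its log. The peak of 5e-4 and the 1,000 warmup steps are inferred from that log (step 0 prints 5.00e-07 and step 1000 prints 5.00e-04), not read from the training script. A standalone copy of the function so the check runs on its own:

```python
import math


def cosine_lr(step, warmup, total, peak):
    # same schedule as above: linear warmup, then cosine decay to zero at `total`
    if step < warmup:
        return peak * (step + 1) / max(1, warmup)
    prog = (step - warmup) / max(1, total - warmup)
    return 0.5 * peak * (1.0 + math.cos(math.pi * min(1.0, prog)))


for step in (0, 1000, 10000):
    print(f"step {step:6d} lr {cosine_lr(step, 1000, 50000, 5e-4):.2e}")
```

Step 0 gives 5e-7, step 1000 reaches the full 5e-4, and step 10,000 gives about 4.60e-4, the same values the log shows.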
There is one more scaling rule that matters when the batch size changes. The original recipe used a batch of 4096, and we train with a much smaller batch, so we scale the peak learning rate by the square root of the ratio.
This keeps the size of each update comparable.
```python
def scaled_peak_lr(base_lr, batch, ref_batch=4096):
    """Square-root LR scaling: lr = base_lr * sqrt(batch / ref_batch). Keeps the per-step update
    variance comparable to the paper's batch-4096 regime. (batch 256 -> 5e-4 becomes 1.25e-4.)"""
    return base_lr * (batch / ref_batch) ** 0.5
```

One training step is just the forward pass, the loss, and a return.
```python
def training_step(model, x, padding, mcfg):
    """Forward + all-positions loss. Returns (loss, mse)."""
    pred_norm, (mu, sigma) = model(x, padding)
    targets, tmask = build_targets(x, padding, mcfg.patch_len, mcfg.horizon_len, mu, sigma)
    return compute_loss(pred_norm, targets, tmask, mcfg.quantiles)
```

And the loop that drives it clips the gradient, steps the optimizer, and logs the loss and throughput every so often.
```python
# model, opt, dl, mcfg, args, peak_lr and use_amp are set up earlier, with step = 0 and t0 = time.time()
avg = None
for x, pad in dl:
    x = x.to(args.device, non_blocking=True)
    pad = pad.to(args.device, non_blocking=True)
    for g in opt.param_groups:
        g["lr"] = cosine_lr(step, args.warmup, args.steps, peak_lr)
    with torch.autocast("cuda", dtype=torch.bfloat16, enabled=use_amp):
        loss, mse = training_step(model, x, pad, mcfg)
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)  # clip to a norm of 1.0
    opt.step()
    lv = float(loss.detach())
    avg = lv if avg is None else 0.99 * avg + 0.01 * lv  # smoothed loss for logging (EMA, assumed)
    if step % args.log_every == 0:
        rate = (step + 1) / (time.time() - t0)
        print(f"step {step:6d} loss {lv:.4f} avg {avg:.4f} "
              f"mse {float(mse):.4f} lr {opt.param_groups[0]['lr']:.2e} {rate:.1f} it/s")
    step += 1
    if step >= args.steps:
        break
```
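The clipping line matters most early on, when gradients are large (the sanity check above measured a norm of about 40 on the untrained model). A pure-Python sketch of global-norm clipping, written to mirror what torch.nn.utils.clip_grad_norm_ does with its default L2 norm; clip_by_global_norm is a hypothetical name. Every gradient is multiplied by the same factor, so the update keeps its direction and only its length is capped.

```python
import math


def clip_by_global_norm(grads, max_norm=1.0, eps=1e-6):
    """Rescale a list of gradient vectors so their combined L2 norm is at most max_norm."""
    total = math.sqrt(sum(g * g for vec in grads for g in vec))
    coef = max_norm / (total + eps)
    if coef < 1.0:
        grads = [[g * coef for g in vec] for vec in grads]
    return grads, total  # like clip_grad_norm_, report the norm before clipping


grads, total = clip_by_global_norm([[30.0, 0.0], [0.0, 40.0]])
clipped_norm = math.sqrt(sum(g * g for vec in grads for g in vec))
print(f"norm before: {total:.1f}, after: {clipped_norm:.6f}")
```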
It is time to prove that it learns.
Training the 17M model on synthetic data
Before I spend hours training a large model, I want a fast, strict test that the whole pipeline is correct. The test is simple.
Take one fixed batch of eight synthetic series, turn off all randomness, and train on that single batch for 1500 steps. If the patching, the normalization, the target indexing, the masking, the loss, and the optimizer are all wired correctly, the model should memorize those eight series and the loss should collapse to nearly zero.
If it cannot overfit eight series, nothing else matters.
```python
# small fast model for the gate (correctness, not scale)
cfg = ModelConfig(model_dim=256, num_layers=4, num_heads=8, horizon_len=128, dropout=0.0)
tcfg = TrainConfig(lr=2e-3, weight_decay=0.0, max_steps=1500)
m = build_model(cfg)

# ONE fixed batch of 8 synthetic series (no per-step randomness -> overfit)
series = make_batch(8, 1024, seed=2)
x, pad = batch_from_series(series, 512, cfg.patch_len, np.random.default_rng(3))

opt = make_optimizer(m, tcfg)
m.train()
loss_hist = []  # bookkeeping for the pass check below
for step in range(tcfg.max_steps):
    for g in opt.param_groups:
        g["lr"] = cosine_lr(step, 50, tcfg.max_steps, tcfg.lr)
    loss, mse = training_step(m, x, pad, cfg)
    opt.zero_grad()
    loss.backward()
    torch.nn.utils.clip_grad_norm_(m.parameters(), 1.0)
    opt.step()
    loss_hist.append(loss.item())
final_mse = mse.item()

# pass only if the MSE is small in absolute terms AND at least 100x below the starting loss
passed = final_mse < 0.05 and final_mse < 0.01 * loss_hist[0]
print("RESULT:", "PASS (overfit gate cleared, pipeline is correct)" if passed else "FAIL")
#### OUTPUT ####
gate model: 3.41M params dim=256 layers=4
step 0 loss 13.42851 mse 8.31204 lr 4.00e-05
step 100 loss 1.05823 mse 0.43210 lr 1.98e-03
step 500 loss 0.18402 mse 0.05133 lr 1.62e-03
step 1000 loss 0.04931 mse 0.01188 lr 6.91e-04
step 1499 loss 0.01797 mse 0.00382 lr 2.00e-06
start loss 13.4285 -> final loss 0.01797 final mse 0.00382
reduction: 3515x
RESULT: PASS (overfit gate cleared, pipeline is correct)
```

The loss falls from 13.4 to 0.018, and the final mean-squared error collapses to 0.0038, a reduction of more than three thousand times from the starting loss, far below our threshold. That single PASS tells us the entire pipeline is correct, every gradient flows to the right place and the model can fit data.
Now we can scale with confidence.
The first training run is the 17M tiny model on synthetic data only, 50,000 steps.
```bash
uv run python scripts/train.py --size tiny --corpus synth --steps 50000 --bf16 --out runs/tier1
#### OUTPUT ####
model tiny: 17.65M params on cuda
generating synthetic pool (50000 series) with 24 workers ...
pool (50000, 2048) ready in 12.7s
step 0 loss 47.0134 avg 47.0134 mse 28.5015 lr 5.00e-07 1.4 it/s
step 200 loss 10.1466 avg 18.4428 mse 4.3752 lr 1.01e-04 25.7 it/s
step 400 loss 7.1502 avg 8.3086 mse 2.7762 lr 2.01e-04 27.2 it/s
step 1000 loss 5.7897 avg 6.1124 mse 2.2031 lr 5.00e-04 28.1 it/s
step 10000 loss 3.8803 avg 3.8806 mse 1.4529 lr 4.60e-04 28.4 it/s
step 50000 loss 3.2832 avg 3.4224 mse 1.2242 lr 4.96e-05 28.5 it/s
done.
```

The loss drops from 47 to 10 in the first 200 steps, then settles into a long slow grind down to about 3.3 by the end. You can see the warmup in the lr column: it ramps up over the first thousand steps, and then the cosine decay takes it back down.
On synthetic test series this 17M model beats the naive baseline by roughly two times, which is great. On the real ETT benchmark, though, it lands right around the naive baseline, a scaled MAE near 0.97.
The pipeline works, but synthetic data alone is not enough to win on real series. Synthetic series of smooth sines and trends happen to resemble periodic electricity data, which helps us a little, but it is no substitute for real data.
We need real data, and we need more capacity.
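Going by the figures reported for the 70M run in the next section (an ETT MAE of 0.315 against 0.261 for the naive baseline, scaled 1.21, and 0.315 / 0.261 ≈ 1.21), scaled MAE is the model's MAE divided by the MAE of the repeat-the-last-value baseline, so anything below 1.0 beats doing nothing. A pure-Python sketch with hypothetical helper names on a single toy series; how the evaluation aggregates across series and windows is not shown here.

```python
def mae(pred, actual):
    return sum(abs(p - a) for p, a in zip(pred, actual)) / len(actual)


def naive_forecast(context, horizon):
    # the baseline: repeat the last observed value
    return [context[-1]] * horizon


def scaled_mae(pred, actual, context):
    """Model MAE over naive MAE: 1.0 means no better than repeating the last value."""
    return mae(pred, actual) / mae(naive_forecast(context, len(actual)), actual)


context = [1.0, 2.0, 3.0, 4.0]
actual = [5.0, 6.0, 7.0]
print(scaled_mae([5.0, 6.0, 8.0], actual, context))         # a good forecast
print(scaled_mae(naive_forecast(context, 3), actual, context))  # the baseline itself
```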
Scaling to 70M, and the lesson that data beats size
The obvious next move is more parameters and real data, so we jump to the 70M small model and mix in the M4 competition dataset, a large pile of business and finance series. I expected real data plus four times the parameters to be a clear win.
The numbers said otherwise.
#### OUTPUT ####
=== zero-shot ETT (70M, M4 + synthetic) ===
ours ETT MAE 0.315 scaled 1.21
naive ETT MAE 0.261 scaled 1.00
A scaled MAE of 1.21 means the 70M model is 21 percent worse than just repeating the last value. A bigger model trained on the wrong real data lost to doing nothing.
The reason is domain mismatch. M4 is full of short, spiky business and finance series, while ETT is smooth, periodic energy data.
The model spent its capacity learning the wrong shapes. This was the lesson that reset my thinking for the rest of the project: data composition matters more than model size.
So I swapped the corpus. Electricity load data from the UCI repository is long, smooth, and periodic, much closer to ETT.
Here is the loader, which reads the raw file, strips each client's leading inactive period, and standardizes what is left.
import numpy as np
def load_electricity(path, min_len=512, norm_mode="zscore", verbose=True):
    """UCI Electricity LD2011_2014.txt: ';'-separated, ','-decimal, 370 client columns
    (15-min kWh). Each client = one long periodic series; strip leading zeros, standardize."""
    import pandas as pd
    df = pd.read_csv(path, sep=";", decimal=",", index_col=0, low_memory=False)
    series = []
    for col in df.columns:
        v = pd.to_numeric(df[col], errors="coerce").to_numpy(dtype=np.float32)
        v = np.nan_to_num(v, nan=0.0)
        nz = np.nonzero(v)[0]
        if len(nz) == 0:
            continue  # client never recorded any load
        v = v[nz[0]:]  # drop leading zeros (client not yet active)
        if len(v) >= min_len:
            series.append(_standardize(v, mode=norm_mode))  # _standardize: project helper defined earlier
    if verbose:
        print(f"Electricity: {len(series)} series (len>={min_len})")
    return series
To mix sources by weight, we use a small weighted sampler. Each source is either a list of series or a pre-generated pool array, and a draw picks a source by its weight and then a random series from it.
import numpy as np
from torch.utils.data import IterableDataset, get_worker_info
class WeightedMixDataset(IterableDataset):
    """Sample from multiple sources by weight. Lets us weight ETT-like electricity high and M4 low."""
    def __init__(self, sources, weights, context_len, patch_len, seed=0):
        self.sources = sources
        w = np.asarray(weights, dtype=np.float64)
        self.cum = np.cumsum(w / w.sum())  # cumulative weights for a fast draw
        self.cum[-1] = 1.0  # float rounding must not leave a gap below 1.0 (would index past the last source)
        self.context_len = context_len
        self.patch_len = patch_len
        self.seed = seed
    def __iter__(self):
        info = get_worker_info()  # None when iterated in the main process
        worker_id = info.id if info is not None else 0
        rng = np.random.default_rng(self.seed + 9973 + worker_id)  # distinct stream per DataLoader worker
        while True:
            si = int(np.searchsorted(self.cum, rng.random()))  # pick a source by weight
            src = self.sources[si]
            if isinstance(src, np.ndarray):
                s = src[int(rng.integers(0, src.shape[0]))]  # pool: [num_series, length]
            else:
                s = src[int(rng.integers(0, len(src)))]  # list of 1D series
            x, p = batch_from_series([s], self.context_len, self.patch_len, rng)  # defined earlier
            yield x[0], p[0]
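The searchsorted draw is inverse-CDF sampling: a uniform number in [0, 1) falls into the slot of the cumulative weights that belongs to one source, so each source is picked in proportion to its weight. Here is a quick standalone check, with hypothetical weights, that the empirical frequencies come out right. The `cum[-1] = 1.0` line guards against float rounding leaving the total a hair below 1.0, which would let a draw fall past the last slot.

```python
import numpy as np

def draw_sources(weights, n_draws, seed=0):
    """Inverse-CDF draw of source indices with np.searchsorted on cumulative weights."""
    w = np.asarray(weights, dtype=np.float64)
    cum = np.cumsum(w / w.sum())
    cum[-1] = 1.0  # guard: rng.random() < 1.0 must always land in a valid slot
    rng = np.random.default_rng(seed)
    return np.searchsorted(cum, rng.random(n_draws))

weights = [3.0, 1.0, 1.0]  # hypothetical: one source weighted 3x each of two others
idx = draw_sources(weights, 100_000)
freq = np.bincount(idx, minlength=len(weights)) / len(idx)
print(freq)  # should sit close to [0.6, 0.2, 0.2]
```

The weights do not need to sum to one beforehand, because the function normalizes them.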
With electricity weighted heavily, the 70M model finally beats the baseline.
#### OUTPUT ####
=== zero-shot ETT vs step (70M, electricity + M4 + synthetic) ===
step 25000 ETT MAE 0.2418 scaled 0.93
naive 0.2610 scaled 1.00
At step 25,000 the model hits 0.2418, a scaled MAE of 0.93, comfortably under the naive floor.
Here is how the model sizes compare on ETT so far: the naive floor, the 17M synthetic model, our 70M, and a strong released reference for scale.
The training loss curve looks healthy too, with a steep early drop and a long plateau.
And here is what the forecasts look like on held-out ETT windows: our 70M median against the truth, the naive line, and the shaded interval.
This looked good. The model beats naive, the loss is falling, the forecasts look reasonable.
So naturally I let it keep training, expecting it to get even better. That is where the problem appeared.
The overfitting problem
I kept training past 25,000 steps and watched the ETT error. It did not keep falling.
It started climbing, and it kept climbing, even though the training loss kept dropping the whole time. The model was getting better at the data it trained on and worse at the world.
This is the generalization gap widening. The model is memorizing the narrow electricity corpus instead of learning transferable structure.
Here is the clearest evidence, the ETT error measured at every checkpoint of that run.
#### OUTPUT ####
step,ett_mae
25000,0.2418
50000,0.2501
75000,0.2647
100000,0.2692
125000,0.2819
150000,0.2846
175000,0.2829
final,0.3085
naive 0.2610
The best result, 0.2418, was at step 25,000. By the end the error had climbed to 0.3085, which is now worse than naive.
The model peaked early and then spent 175,000 more steps getting worse at the only thing we care about. Plotted, the curve is unmistakable.
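Dividing each checkpoint by the naive 0.2610 puts that table on the same scaled axis as the earlier results. Here is a small sketch that uses only the logged values:

```python
# ETT MAE per checkpoint, copied from the log above
ett_mae = {25000: 0.2418, 50000: 0.2501, 75000: 0.2647, 100000: 0.2692,
           125000: 0.2819, 150000: 0.2846, 175000: 0.2829}
final_mae = 0.3085
naive_mae = 0.2610

scaled = {step: mae / naive_mae for step, mae in ett_mae.items()}  # below 1.0 beats naive
best_step = min(ett_mae, key=ett_mae.get)
first_worse = next(step for step in sorted(scaled) if scaled[step] > 1.0)

for step in sorted(scaled):
    print(f"step {step:>6}  scaled {scaled[step]:.3f}")
print(f"best step {best_step}, first checkpoint worse than naive {first_worse}, "
      f"final scaled {final_mae / naive_mae:.3f}")
```

The scaled error is already above 1.0 at step 75,000, so the run was losing to naive for most of its length, not just at the end.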
Why does this happen at all? A model with seventy million parameters has enormous capacity, and a corpus of a few hundred electricity series simply does not contain enough distinct shapes to fill that capacity with general structure.
So the optimizer does the next best thing, it starts memorizing the exact series it trains on, quirks and noise included. Those memorized details do not transfer to ETT, which the model never sees, so the test error climbs even while the training loss keeps falling.
The widening gap between the two curves is the model learning things that are true of the training set and false of the world.
The naive reaction is to stop early at step 25,000 and call it done. But that is treating the symptom.
The actual cause is that a model with 70 million parameters trained on a narrow corpus of a few hundred electricity series has more than enough capacity to memorize it. The fix is not less training.
The fix is more variety.
The fix: corpus diversity
If a narrow corpus is why the model overfits, then the fix is breadth. I rebuilt the training corpus to span many domains and many frequencies, pulling real series from the LOTSA collection (traffic, web, cloud, finance, and health) and organizing them into frequency tiers.
The loader reads each configured dataset directory and caps how many series any single dataset can contribute, so one giant dataset cannot dominate.
LOTSA_BUNDLES = {
    "hourly": ["LOOP_SEATTLE", "PEMS04", "PEMS08",
               "azure_vm_traces_2017", "kdd_cup_2018_with_missing"],
    "daily": ["m4_daily", "bitcoin_with_missing"],
    "weekly": ["m4_weekly", "kaggle_web_traffic_weekly", "cdc_fluview_ilinet"],
    "monthly": ["m4_monthly", "m1_monthly", "cif_2016_12", "car_parts_with_missing"],
    "subhour": ["BEIJING_SUBWAY_30MIN", "australian_electricity_demand"],
}
def load_lotsa_bundle(root, tier, min_len=256, per_config_cap=20000, norm_mode="zscore", verbose=True):
    """Load all configs in a frequency tier -> one combined list of series."""
    out = []
    for cfg in LOTSA_BUNDLES[tier]:
        out.extend(load_lotsa_config(root, cfg, min_len=min_len, max_series=per_config_cap,
                                     norm_mode=norm_mode, verbose=verbose))
    return out
Each tier leans on a per-config loader that has to survive the real world, where a download can be incomplete or a shard can be corrupt. It wraps the read in a try-and-skip so one bad dataset never takes down the whole run, and it drops any series that is too short or completely flat.
import os
import numpy as np
def load_lotsa_config(root, config, min_len=256, max_series=None, norm_mode="zscore", verbose=True):
    """Load one LOTSA config directory -> list of standardized 1D series."""
    from datasets import load_from_disk
    path = os.path.join(root, config)
    if not os.path.isdir(path):
        return []  # dataset was never downloaded: skip it
    out = []
    try:  # robust to missing/corrupt shards
        ds = load_from_disk(path)
        for ex in ds:
            for s in _series_from_lotsa_row(ex["target"]):  # project helper defined elsewhere
                s = np.nan_to_num(s, nan=0.0, posinf=0.0, neginf=0.0)
                if s.shape[0] >= min_len and float(s.std()) > 1e-6:  # long enough and not flat
                    out.append(_standardize(s, mode=norm_mode))
                    if max_series and len(out) >= max_series:
                        return out  # respect the per-config cap
    except Exception as e:
        if verbose:
            print(f" [warn] {config}: load error after {len(out)} series ({e!r})")
    return out
That try block matters more than it looks. On one run a single dataset failed to download, and the loader simply skipped it instead of crashing the whole job, which would otherwise have wasted the entire training setup.
The price is that the corpus can quietly shrink if a source goes missing, so it pays to read the startup print and confirm the series counts are what you expect before you commit to a long run.
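One cheap guard is to compare the per-source counts against expected minimums before training starts, and to fail loudly if any source came in short. Here is a minimal sketch. `missing_sources` and the thresholds are hypothetical, and the counts are the ones the startup print reports for this corpus.

```python
def missing_sources(counts, minimums):
    """Hypothetical pre-flight check: names whose series count fell below the expected minimum."""
    return sorted(name for name, need in minimums.items() if counts.get(name, 0) < need)

# counts as reported by the startup print for the diverse corpus
counts = {"electricity": 370, "hourly": 22024, "daily": 3597, "weekly": 609,
          "monthly": 18329, "subhour": 557}
# illustrative thresholds, not values from the original run
minimums = {"electricity": 300, "hourly": 10000, "daily": 1000, "weekly": 300,
            "monthly": 10000, "subhour": 300}

short = missing_sources(counts, minimums)
if short:
    raise RuntimeError(f"corpus sources below expected size: {short}")
```

A missing dataset then stops the run in seconds, rather than quietly shaping a nine-hour training job.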
When we build the diverse corpus, the startup print shows just how much wider it is.
#### OUTPUT ####
loading DIVERSE corpus: electricity + LOTSA freq-tiers + synthetic ...
Electricity: 370 series (len>=1024)
LOTSA tier 'hourly': 22024 series
LOTSA tier 'daily': 3597 series
LOTSA tier 'weekly': 609 series
LOTSA tier 'monthly': 18329 series
LOTSA tier 'subhour': 557 series
corpus(diverse): elec 370 h 22024 d 3597 w 609 m 18329 sub 557 synth 80000 | w=[0.15, 0.18, 0.16, 0.12, 0.12, 0.07, 0.2]
We went from about 370 energy-only series to 45,486 real series across nine domains, plus the synthetic pool. That diversity is the goal.
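The mixture weights apply per source, not per series, so each series in a small source gets drawn far more often. Here is a quick arithmetic sketch. It assumes the seven weights in the print line up with the seven sources in the order listed (elec, hourly, daily, weekly, monthly, subhour, synth), because the log does not label them.

```python
sources = ["elec", "hourly", "daily", "weekly", "monthly", "subhour", "synth"]
counts = [370, 22024, 3597, 609, 18329, 557, 80000]
weights = [0.15, 0.18, 0.16, 0.12, 0.12, 0.07, 0.2]  # assumed to follow the printed source order

# probability that a single draw lands on one particular series of each source
per_series = {name: w / n for name, w, n in zip(sources, weights, counts)}
for name, p in sorted(per_series.items(), key=lambda kv: -kv[1]):
    print(f"{name:>8}  {p:.2e} per draw")

ratio = per_series["elec"] / per_series["hourly"]
print(f"an electricity series is drawn about {ratio:.0f}x as often as an hourly LOTSA series")
```

Under that assumption, each electricity series is drawn roughly 50 times as often as each hourly LOTSA series, which is one way the ETT-like data can keep its influence inside a much larger corpus.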
But now I have a new problem. If I am not allowed to peek at the ETT test set, how do I know when to stop training?
I need a validation signal from a domain that is neither in the training corpus nor in the test set.
The answer is out-of-domain validation. I hold out a couple of Monash datasets, daily series that are not energy and not in training, and I measure forecast error on them at every checkpoint.
That number tells me whether the model is still generalizing, without ever touching ETT.
import numpy as np
import torch
@torch.no_grad()
def validate(model, val_ctxs, val_truths, device, H=96, batch=128, point_channel=5):
    """OUT-OF-DOMAIN validation MAE. point_channel selects the point/AR feedback channel.
    MANDATORY model.eval()/train() toggle: forecast() is @no_grad but dropout is still ACTIVE
    in train mode, which would inject noise into the early-stopping signal."""
    was_training = model.training
    model.eval()
    errs = []
    for i in range(0, len(val_ctxs), batch):
        c = torch.from_numpy(np.asarray(val_ctxs[i:i + batch], np.float32)).to(device)
        pt, _ = model.forecast(c, H, point_channel=point_channel)
        # .float() keeps .numpy() working if the forecast comes back in bf16
        errs.append(np.abs(pt.float().cpu().numpy() - np.asarray(val_truths[i:i + batch], np.float32)))
    if was_training:
        model.train()  # restore dropout for the next training step
    return float(np.concatenate(errs, axis=0).mean())
The validation log shows the signal is noisy, which is itself useful information.
#### OUTPUT ####
building OOD val windows (Monash non-energy: nn5_daily + weather) ...
Monash[nn5_daily/train]: 60 series (len>=616)
Monash[weather/train]: 60 series (len>=616)
val windows: 48 (H=96)
[val] step 2500 OOD_val_MAE 0.6806 best inf@-1 <-- new best
[val] step 5000 OOD_val_MAE 0.7056 best 0.6806@2500
[val] step 10000 OOD_val_MAE 0.6908 best 0.6806@2500
[val] step 22500 OOD_val_MAE 0.7179 best 0.6806@2500
BEST OOD val: 0.6806 @ step 2500 -> ckpt_best.pt
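The `best inf@-1` on the first line suggests the best score starts at infinity and step -1. Each check then either records a new best (the moment to write ckpt_best.pt) or leaves it alone. Here is a minimal sketch of that bookkeeping, with an optional patience counter added for illustration. `BestTracker` is a hypothetical name, not the training script's code. The values are the four checks shown in the log.

```python
import math

class BestTracker:
    """Hypothetical helper: remembers the lowest validation score and when it happened."""
    def __init__(self, patience=None):
        self.best = math.inf      # mirrors the "best inf@-1" printed before the first check
        self.best_step = -1
        self.patience = patience  # None = never stop early, only track the best checkpoint
        self.bad_checks = 0

    def update(self, step, val):
        """Return True when `val` is a new best, i.e. when a checkpoint should be saved."""
        if val < self.best:
            self.best, self.best_step = val, step
            self.bad_checks = 0
            return True
        self.bad_checks += 1
        return False

    def should_stop(self):
        return self.patience is not None and self.bad_checks >= self.patience

tracker = BestTracker(patience=3)
for step, val in [(2500, 0.6806), (5000, 0.7056), (10000, 0.6908), (22500, 0.7179)]:
    is_new = tracker.update(step, val)
    print(f"step {step} val {val:.4f} best {tracker.best:.4f}@{tracker.best_step}"
          + ("  <-- new best" if is_new else ""))
print("stop early:", tracker.should_stop())
```

Keeping the best checkpoint separate from the stopping rule means a noisy signal can still pick a good model, even if you choose to train on past it.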
The validation curve is bumpy and the best checkpoint comes very early, which tells us that 48 windows give a noisy signal rather than a precise one. But here is the result that matters.
When I let the diverse model train all the way out to 120,000 steps and measured ETT at every step, the error stayed flat, hovering around 0.25, instead of climbing the way the narrow model did.
This is the central result of the whole project. The narrow model had already climbed to about 0.28 by the same point, and it ended its run at 0.31.
The diverse model holds steady. Diversity, not less training, is what fixes the over-training. The model now has so many different shapes to fit that it cannot memorize its way out of generalizing.
A bug: the model was weakest where it starts
With the over-training handled, I went looking for accuracy, and I found a bug I had missed. Remember the line in build_targets that breaks out of the loop when a target would start at or past the end of the training window.
for j in range(N):
    start = (j + 1) * p
    if start >= L:
        break  # remaining positions have no target
With a training window of exactly the context length, the last patch never gets a target. Its prediction would land past the end of the window, so the loop skips it.
That sounds harmless until you remember how autoregressive forecasting works. At inference, the first thing the model does is predict from the last patch.
So the model was least trained exactly at the position it predicts from first.
The fix is to train on a window that is the context plus one horizon, so the last context patch gets a full target. Because attention is causal, the extra horizon points only add supervision, they do not leak any future information into the context patches.
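To see which positions receive supervision, here is a small sketch that counts how many future points each patch is trained to predict. It assumes 32-point patches, a 128-point horizon (the 640-point training window minus the 512-point context), and targets clipped at the end of the window. `target_lengths` is a hypothetical helper that mirrors the `break` in build_targets.

```python
def target_lengths(window_len, patch_len, horizon_len):
    """Future points each patch position is trained on, assuming targets are clipped at the
    window end and positions whose target would start past the end get nothing."""
    lengths = []
    for j in range(window_len // patch_len):
        start = (j + 1) * patch_len  # target begins right after patch j
        if start >= window_len:
            lengths.append(0)  # the case the original loop breaks on
        else:
            lengths.append(min(horizon_len, window_len - start))
    return lengths

print(target_lengths(512, 32, 128)[12:16])  # last four context patches, 512 window -> [96, 64, 32, 0]
print(target_lengths(640, 32, 128)[12:16])  # same patches, 640 window -> [128, 128, 128, 128]
```

With the longer window, every context patch, including patch 15 where autoregressive inference starts, gets a full 128-point target.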
# OBJECTIVE FIX: train on context+horizon windows so the LAST context patch (the AR inference
# entry point) receives a FULL target. Causal attention => patch-15's representation is identical
# to the 512-context inference case; the extra horizon points only add supervision, not leakage.
train_win = args.context + mcfg.horizon_len
print(f"train window = {train_win} (context {args.context} + horizon {mcfg.horizon_len}) "
      f"-> supervises the AR entry point")
The improvement shows up at long horizons, where the autoregressive entry point matters most.
#### OUTPUT ####
=== long-horizon MAE (70M, even/K50) ===
before fix h=96 0.231 h=192 ~0.58 ratio ~2.5x
after fix h=96 0.231 h=192 0.280 ratio 1.21x
Short-horizon accuracy was already fine at 0.231. The long horizon was the problem.
Before the fix, the 192-step error was roughly two and a half times the 96-step error. After the fix it is only 1.21 times, because the model now has a trained starting point at the exact patch it has to extrapolate from.
One training window that was too short, one quiet break, and the long-horizon error fell by roughly half, from about 0.58 to 0.280, once it was fixed.
The evaluation harness
Up to now I have been quoting single numbers, but to trust any of them we need a rigorous protocol. A good evaluation has to do four things:
- Choose its test windows fairly.
- Compare against a strong baseline, not just the trivial one.
- Measure the uncertainty as well as the point forecast.
- Test whether a difference between two models is genuine or just noise.
First the windows. We slice each test series into evenly spaced windows across the whole test region, not just the tail, so we are not cherry-picking the easy recent part of the series.
This even-window protocol is stricter than testing only the recent tail, and it raises the naive floor on ETT from the 0.261 we were comparing against earlier to 0.298. Every number from here on is measured against that tougher baseline.
import numpy as np
def make_windows(z, context, H, K, train_end=0, test_end=None, scheme="even"):
    """z: 1D standardized array. Returns (ctxs [W, context], truths [W, H]).
    Windows live strictly inside the test region [train_end, test_end]."""
    n = len(z)
    test_end = n if test_end is None else min(test_end, n)
    lo = (train_end or 0) + context  # first forecast origin whose full context is in the test region
    hi = test_end - H  # last forecast origin whose horizon still fits before test_end
    if hi <= lo:
        return np.empty((0, context), np.float32), np.empty((0, H), np.float32)
    if scheme == "tail":
        # newest non-overlapping horizons, counting back from test_end
        starts = [(test_end - i * H) - H for i in range(K)]
        starts = [s for s in starts if s - context >= (train_end or 0) and s + H <= test_end]
    else:  # even: K origins spread across the whole test region
        starts = np.unique(np.linspace(lo, hi, num=K, dtype=np.int64)).tolist()
    ctxs, truths = [], []
    for t in starts:
        ctxs.append(z[t - context:t])
        truths.append(z[t:t + H])
    return np.asarray(ctxs, np.float32), np.asarray(truths, np.float32)
Next the baselines. The naive baseline repeats the last value.
The much stronger one is the seasonal-naive baseline, which repeats the last full season. For periodic data, beating seasonal-naive is a meaningful achievement, beating plain naive is the bare minimum.
import numpy as np
def forecast_naive(ctxs, H):
    return np.repeat(np.asarray(ctxs)[:, -1:], H, axis=1)  # hold the last value flat
def seasonal_naive(ctxs, H, period):
    """Stronger baseline than last-value: repeat the last `period` context values cyclically."""
    ctxs = np.asarray(ctxs)
    W, L = ctxs.shape
    out = np.empty((W, H), ctxs.dtype)
    for t in range(H):
        out[:, t] = ctxs[:, L - period + (t % period)]  # same phase in the last full season
    return out
Because we trained nine quantiles, we must also measure the quality of the distribution, not just the point. We use CRPS, a proper scoring rule for probabilistic forecasts that we can approximate from the quantiles, and coverage, the fraction of true values that actually fall inside the 80 percent interval.
import numpy as np
# QLEVELS: the nine quantile levels of the head (0.1 ... 0.9), defined alongside the model
def crps_from_quantiles(quants, truth, qlevels=QLEVELS):
    """Quantile-approx CRPS per window = 2 * mean_tau pinball_tau, averaged over the horizon."""
    # quants: [W, H, C]; quants[..., i + 1] holds the quantile at qlevels[i] (channel 0 is not a quantile)
    truth = np.asarray(truth)
    tot = np.zeros(truth.shape, np.float64)
    for i, tau in enumerate(qlevels):
        dev = truth - quants[..., i + 1]
        tot += np.maximum(tau * dev, (tau - 1.0) * dev)  # pinball loss at level tau
    return (2.0 * tot / len(qlevels)).mean(axis=1)
def coverage(quants, truth, lo_i=1, hi_i=9):
    """Fraction of truth inside [q10, q90] per window (should be ~0.8 if calibrated)."""
    truth = np.asarray(truth)
    lo = np.minimum(quants[..., lo_i], quants[..., hi_i])  # min/max guards against crossed quantiles
    hi = np.maximum(quants[..., lo_i], quants[..., hi_i])
    return ((truth >= lo) & (truth <= hi)).mean(axis=1)
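Twice the mean pinball loss is the standard quantile approximation of CRPS. One way to see how close nine quantiles get is to score a Gaussian forecast both ways, since its CRPS has a closed form. Here is a standalone sketch, assuming the nine levels are 0.1 through 0.9:

```python
import math
from statistics import NormalDist

QLEVELS = [0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9]  # assumed levels of the quantile head

def pinball(tau, q, y):
    dev = y - q
    return max(tau * dev, (tau - 1.0) * dev)

def quantile_crps(quantiles, y, qlevels=QLEVELS):
    """Same approximation as crps_from_quantiles, for a single forecast point."""
    return 2.0 * sum(pinball(t, q, y) for t, q in zip(qlevels, quantiles)) / len(qlevels)

def gaussian_crps(mu, sigma, y):
    """Closed-form CRPS of a N(mu, sigma^2) forecast against observation y."""
    z = (y - mu) / sigma
    std = NormalDist()
    return sigma * (z * (2.0 * std.cdf(z) - 1.0) + 2.0 * std.pdf(z) - 1.0 / math.sqrt(math.pi))

forecast = NormalDist(mu=0.0, sigma=1.0)
quantiles = [forecast.inv_cdf(t) for t in QLEVELS]
for y in (0.0, 1.0, 2.5):
    print(f"y={y}: quantile approx {quantile_crps(quantiles, y):.4f}  "
          f"exact {gaussian_crps(0.0, 1.0, y):.4f}")
```

For these three observations, the nine-quantile version runs roughly 5 to 10 percent above the exact value. CRPS numbers are therefore best compared between models scored the same way.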
Finally, significance. Comparing two overlapping error bars is not a proper test.
The right test is a paired bootstrap on the per-window difference between our model and the baseline. If the whole confidence interval of that difference sits below zero, we genuinely beat the baseline.
import numpy as np
def paired_bootstrap_ci(ae_a, ae_b, n_boot=2000, seed=0, alpha=0.05):
    """Bootstrap CI of the per-window difference (ae_a - ae_b). hi<0 => A significantly better than B."""
    d = np.asarray(ae_a, np.float64) - np.asarray(ae_b, np.float64)
    m = float(d.mean())
    if len(d) < 2:
        return m, m, m
    rng = np.random.default_rng(seed)
    n = len(d)
    # resample windows with replacement; each window keeps its A/B pair together
    boot = np.array([d[rng.integers(0, n, n)].mean() for _ in range(n_boot)])
    return m, float(np.quantile(boot, alpha / 2)), float(np.quantile(boot, 1 - alpha / 2))
Running the full grid on the 70M model gives the picture, the good and the bad together.
#### OUTPUT ####
=== Multi-domain zero-shot TEST (even, K=50, flip=OFF, point=median) ===
domain    clean  naive  snaive  OURS   scaled  CRPS   cov80  ours-vs-naive d[95% CI]
ETT       clean  0.298  0.272   0.246  0.826   0.241  0.682  -0.052[-0.071,-0.034] sig
Exchange  clean  0.239  0.244   0.345  1.445   0.371  0.552  +0.106[+0.072,+0.141] ns
--- verdict (CLEAN domains only: ['ETT', 'Exchange']) ---
beats naive (paired-significant): 1/2 ['ETT']
On ETT the 70M model scores 0.246, beating both naive (0.298) and the stronger seasonal-naive (0.272), and the paired interval [-0.071, -0.034] sits entirely below zero, so the result is genuine and not noise. On Exchange the same model scores 0.345, a scaled 1.445, meaning it is about 44 percent worse than naive. Its "ns" flag only says it does not significantly beat naive; the interval [+0.072, +0.141] sits entirely above zero, so it is in fact significantly worse.
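Every "scaled" figure here matches the model's MAE divided by the naive MAE on the same windows. That is the case for each pair of numbers quoted so far, for example 0.315 / 0.261 gives the earlier 1.21. Here is a minimal sketch that assumes that definition; `scaled_mae` is a hypothetical helper.

```python
import numpy as np

def scaled_mae(pred, truth, ctxs):
    """MAE of `pred` divided by the MAE of the last-value forecast on the same windows.
    Below 1.0 beats naive; 1.445 means about 44 percent worse than naive."""
    pred, truth, ctxs = (np.asarray(a, dtype=np.float64) for a in (pred, truth, ctxs))
    naive = np.repeat(ctxs[:, -1:], truth.shape[1], axis=1)
    return float(np.abs(pred - truth).mean() / np.abs(naive - truth).mean())

# Toy check: the naive forecast misses by 1.0 everywhere, the "model" by 1.5.
ctxs = np.zeros((4, 8))
truth = np.ones((4, 3))
pred = np.full((4, 3), 2.5)
print(scaled_mae(pred, truth, ctxs))
print(round(0.246 / 0.298, 3))  # the ETT row: ours / naive
```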
The model is genuinely good on structured data and genuinely bad on one whole category of data.
Two more things stand out in that table. The coverage column reads 0.682 and 0.552, both well below the 0.80 they should be, so our uncertainty bands are too narrow.
The CRPS column is worth a glance too. At 0.241 on ETT it is the single number that scores the whole predicted distribution rather than just the median, so it is the metric I watch when the uncertainty matters as much as the point forecast, and a lower value means the nine quantiles as a set sit closer to the truth.
And the Exchange failure needs explaining, because a scaled MAE above one means the model is actively hurting us there. Those are the next two problems.
Random-walk data, and a partial fix
The Exchange dataset is currency rates, and currency rates are close to a random walk. The defining property of a random walk is that the best possible forecast is simply the last value, because the next step is the last value plus unpredictable noise.
There is no trend to extend and no season to repeat.
Our model fails here because it learned its lesson too well. Trained on a corpus full of trends and seasons, it cannot help but see trend and season everywhere, so on a random walk it confidently extrapolates a pattern that does not exist and drifts off course.
It does not know how to give up and just persist.
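A quick simulation makes the point concrete. It is my own illustration, not part of the training code. On zero-drift random walks, persisting the last value beats extending a fitted trend, and on white noise, forecasting the context mean beats persisting the last value.

```python
import numpy as np

def naive_forecast(ctx, H):
    return np.repeat(ctx[:, -1:], H, axis=1)  # persist the last value

def mean_forecast(ctx, H):
    return np.repeat(ctx.mean(axis=1, keepdims=True), H, axis=1)  # revert to the context mean

def trend_forecast(ctx, H, fit_len=96):
    """Fit a least-squares line to the last `fit_len` points and extend it H steps."""
    t = np.arange(fit_len, dtype=np.float64)
    y = ctx[:, -fit_len:]
    t_c = t - t.mean()
    slope = (y * t_c).sum(axis=1, keepdims=True) / (t_c ** 2).sum()
    intercept = y.mean(axis=1, keepdims=True) - slope * t.mean()
    return intercept + slope * np.arange(fit_len, fit_len + H)

def mae(forecast, series, L):
    return float(np.abs(forecast - series[:, L:]).mean())

rng = np.random.default_rng(0)
W, L, H = 2000, 512, 96
walks = np.cumsum(rng.normal(size=(W, L + H)), axis=1)  # zero-drift random walks
noise = rng.normal(size=(W, L + H))  # white noise

rw = {"naive": mae(naive_forecast(walks[:, :L], H), walks, L),
      "trend": mae(trend_forecast(walks[:, :L], H), walks, L)}
wn = {"naive": mae(naive_forecast(noise[:, :L], H), noise, L),
      "mean": mae(mean_forecast(noise[:, :L], H), noise, L)}
print("random walk:", rw)
print("white noise:", wn)
```

On a random walk, a trend fitted to noise only adds error on top of the walk's own spread. That is the behavior the model has to unlearn.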
The fix is to teach it that "no structure" is a valid answer. We add a slice of pure random-walk and white-noise series to the synthetic generator, the rw_prob=0.2 branch we saw earlier, so the model learns that for some series the right forecast is to hold the last value, or to revert to the mean.
import numpy as np
def generate_random_walk(length, rng):
    """Structureless series so the model learns the RIGHT zero-structure forecast: PERSIST (random
    walk -> optimal forecast = last value) or revert to the mean (white noise -> optimal = mean).
    This is the targeted fix for the random-walk/FX failure (model was 1.45x WORSE than naive there)."""
    if rng.random() < 0.25:
        y = rng.normal(0.0, 1.0, size=length)  # white noise (optimal = mean)
    else:
        drift = float(rng.normal(0.0, 0.02))
        y = np.cumsum(rng.normal(0.0, 1.0, size=length) + drift)  # random walk (optimal = last value)
    return y.astype(np.float32)
After retraining the 70M model with 20 percent random-walk data, the Exchange error improves, but only partway.
#### OUTPUT ####
=== Exchange (FX, random-walk) scaled MAE (70M) ===
no RW data 1.445 (44% worse than naive)
+20% RW data 1.313 (still 31% worse than naive)
The error dropped from 1.445 to 1.313, a real improvement, but the model is still worse than just repeating the last value. This is a genuine limitation of the 70M model on this corpus.
Pushing the random-walk fraction higher would chase the Exchange number but cost us accuracy on the structured data we care about more. We note the limit plainly and move on to the calibration problem, which has a cleaner fix.
Calibrated uncertainty: conformal recalibration
The coverage numbers told us the model is overconfident. Its 80 percent intervals catch only about 68 percent of the true values on ETT and 55 percent on Exchange, both well short of the 80 percent they should reach. We trained the quantiles with pinball loss, but on a finite corpus the model still ends up with bands that are too tight.
The fix does not require retraining. We use conformal recalibration, a post-hoc method that widens the interval by an amount measured on a held-out calibration split. It carries a finite-sample guarantee that the widened interval reaches at least the target coverage, provided the calibration and test windows are exchangeable.
import numpy as np
def conformal_widen(cal_lo, cal_hi, cal_truth, test_lo, test_hi, target=0.8):
    """Conformalized Quantile Regression: additively widen the [lo, hi] interval by the conformity
    quantile measured on a CALIBRATION set so empirical coverage reaches `target`."""
    cl = np.asarray(cal_lo).ravel(); ch = np.asarray(cal_hi).ravel(); cy = np.asarray(cal_truth).ravel()
    E = np.maximum(cl - cy, cy - ch)  # conformity score (negative when inside)
    n = max(1, len(E))
    qlev = min(1.0, np.ceil((n + 1) * target) / n)  # finite-sample correction
    Q = float(np.quantile(E, qlev))
    return np.asarray(test_lo) - Q, np.asarray(test_hi) + Q, Q
We validate it on the 70M model without touching the point forecast.
```
#### OUTPUT ####
ETT 80% interval coverage: before 0.682 -> after conformal 0.807 (widen Q=0.31)
Exchange 80% interval coverage: before 0.552 -> after conformal 0.760 (widen Q=0.48)
```
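Coverage on ETT jumps from 0.682 to 0.807, right on the 0.80 target, and Exchange goes from 0.552 to 0.760, much closer but still short. That shortfall is a reminder that the guarantee assumes the calibration windows are exchangeable with the test windows, which may not fully hold for a drifting random walk. The point forecast does not change at all, only the width of the band, so we get far more trustworthy uncertainty without retraining the model.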
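To see the mechanism in isolation, here is a small synthetic check of the same CQR widening. The forecaster is hypothetical: the truth has noise with standard deviation 1.0, but its 80 percent band is sized as if the deviation were 0.5, so it is overconfident in the same way our model is. The function is repeated so the sketch runs on its own.

```python
import numpy as np


def conformal_widen(cal_lo, cal_hi, cal_truth, test_lo, test_hi, target=0.8):
    # Same CQR logic as the function above, repeated so this sketch is self-contained.
    cl, ch, cy = (np.asarray(a, dtype=float).ravel() for a in (cal_lo, cal_hi, cal_truth))
    E = np.maximum(cl - cy, cy - ch)
    n = max(1, len(E))
    qlev = min(1.0, np.ceil((n + 1) * target) / n)
    Q = float(np.quantile(E, qlev))
    return np.asarray(test_lo) - Q, np.asarray(test_hi) + Q, Q


def coverage(lo, hi, y):
    return float(np.mean((y >= lo) & (y <= hi)))


rng = np.random.default_rng(0)
Z90 = 1.2815515655446004  # standard-normal 0.9 quantile


def overconfident_forecasts(n):
    # Hypothetical forecaster: true noise sd is 1.0, but the band assumes sd 0.5.
    mu = rng.normal(size=n)
    y = mu + rng.normal(size=n)
    return mu - 0.5 * Z90, mu + 0.5 * Z90, y


cal_lo, cal_hi, cal_y = overconfident_forecasts(2000)     # calibration split
test_lo, test_hi, test_y = overconfident_forecasts(2000)  # held-out test split

raw_cov = coverage(test_lo, test_hi, test_y)
new_lo, new_hi, Q = conformal_widen(cal_lo, cal_hi, cal_y, test_lo, test_hi, target=0.8)
new_cov = coverage(new_lo, new_hi, test_y)
print(f"raw {raw_cov:.3f} -> conformal {new_cov:.3f} (Q={Q:.2f})")
```

Raw coverage comes out near 0.48 and the widened band lands near 0.80. Because calibration and test points here are drawn from the same distribution, the exchangeability assumption holds by construction, which is exactly the condition a real random-walk series can violate.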
With the over-training handled, the entry point fixed, the random-walk weakness reduced, and the intervals calibrated, the 70M model is as good as it is going to get on this corpus. It sits at about 0.246 on ETT.
To go further, we need scale.
Scaling to the 200M model
Everything so far has been the 70M model, and it has hit a ceiling. The ETT error is stuck around 0.246 no matter how I tweak the 70M corpus.
The hypothesis the whole project was built to test is that with the data composition fixed, the diversity in place, and every bug closed, scale should finally help. So we take all of it to the 200M base model.
The parameter count of a transformer grows roughly with the number of layers times the model dimension squared, which is why doubling the layers and widening the model takes us from 70M to 200M.
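A back-of-envelope version of that scaling rule, as a sketch: it counts only the attention projections (4·d²) and a two-matrix MLP (2·d·h) per layer, ignoring biases, norms, and the patch embedding and output head. The MLP hidden width h is an assumption, because the config shown below does not expose it, so two widths are compared.

```python
def approx_params(num_layers, model_dim, mlp_hidden):
    """Rough backbone parameter count (biases, norms, embeddings, and heads ignored)."""
    d, h = model_dim, mlp_hidden
    per_layer = 4 * d * d + 2 * d * h   # Q, K, V, O projections + MLP up/down matrices
    return num_layers * per_layer


# model_dim=1280 and num_layers=20 come from the base() config; MLP widths are assumptions.
narrow = approx_params(20, 1280, mlp_hidden=1280)         # MLP width equal to model_dim
gpt_style = approx_params(20, 1280, mlp_hidden=4 * 1280)  # classic 4x-wide MLP
print(f"h=d: {narrow / 1e6:.1f}M   h=4d: {gpt_style / 1e6:.1f}M")

# The scaling rule itself: params are linear in layers and quadratic in width.
print(approx_params(40, 1280, 1280) / narrow)   # doubling layers -> 2x
print(approx_params(20, 2560, 2560) / narrow)   # doubling width  -> 4x
```

The narrow estimate (about 196.6M) sits much closer to the 203.44M the training run reports than the 4x estimate (about 393.2M), which hints that the feed-forward width is nearer model_dim than four times it, but the config shown here does not confirm that.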
The config change is a single function. Same architecture, more layers and a wider model.
```python
def base() -> ModelConfig:  # ~200M backbone (Tier 3)
    return ModelConfig(model_dim=1280, num_layers=20, num_heads=16)
```

We launch the full run on the diverse, random-walk-augmented corpus, with the 640-window objective fix, for 200,000 steps.
```bash
uv run python scripts/train.py --size base --corpus diverse --steps 200000 --pool_size 80000 --bf16 \
  --weights 0.15,0.18,0.16,0.12,0.12,0.07,0.20 --val_every 5000 --out runs/base200m
```

```
#### OUTPUT ####
model base: 203.44M params on cuda
generating synthetic pool (80000 series) with 24 workers ...
pool (80000, 2048) ready in 21.4s
train window = 640 (context 512 + horizon 128) -> supervises the AR entry point
corpus(diverse): elec 370 h 22024 d 3597 w 609 m 18329 sub 557 synth 80000 | w=[0.15, 0.18, 0.16, 0.12, 0.12, 0.07, 0.2]
step 0 loss 47.8312 avg 47.8312 mse 28.9011 lr 8.33e-08 1.0 it/s
step 200 loss 12.4419 avg 19.8830 mse 5.1142 lr 2.00e-05 6.6 it/s
step 1000 loss 6.2034 avg 7.3318 mse 2.4458 lr 1.00e-04 6.5 it/s
step 50000 loss 3.4129 avg 3.4567 mse 1.1043 lr 1.71e-04 6.5 it/s
step 120000 loss 3.0712 avg 3.0844 mse 0.9788 lr 5.84e-05 6.5 it/s
step 200000 loss 3.0188 avg 3.0201 mse 0.9512 lr 1.00e-06 6.5 it/s
[val] step 165000 OOD_val_MAE 0.5912 best 0.5947@150000 <-- new best
BEST OOD val: 0.5912 @ step 165000 -> ckpt_best.pt
done.
```

The throughput is 6.5 iterations per second, less than half the 70M speed, which is the price of roughly three times the parameters. The training loss settles at about 3.02, with the mean-squared-error component down at 0.95, both lower than the floors the 70M run reached. The out-of-domain validation MAE lands at 0.5912, comfortably better than the 70M model's 0.68.
The bigger model fits the diverse corpus more closely and generalizes better out of domain at the same time.
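The `--weights` flag and the `train window = 640 (context 512 + horizon 128)` line describe a weighted mixture over seven corpus sources. The sketch below is a hypothetical illustration of how such a sampler can work, not the code in `scripts/train.py`: pick a source by weight, pick a series uniformly within it, take a random 640-step crop, and split it into context and horizon.

```python
import numpy as np

CONTEXT, HORIZON = 512, 128
WINDOW = CONTEXT + HORIZON  # 640, matching "train window = 640" in the log


def sample_batch(sources, weights, batch_size, rng):
    """Hypothetical weighted-mixture sampler.

    sources: one pool per corpus source; each pool is a list of 1-D arrays >= WINDOW long.
    weights: per-source sampling probabilities, like the --weights flag.
    """
    w = np.asarray(weights, dtype=float)
    w = w / w.sum()                                       # tolerate weights not summing to exactly 1
    src_idx = rng.choice(len(sources), size=batch_size, p=w)
    batch = np.empty((batch_size, WINDOW), dtype=np.float32)
    for b, s in enumerate(src_idx):
        pool = sources[s]
        series = pool[rng.integers(len(pool))]            # uniform series choice within the source
        start = rng.integers(0, len(series) - WINDOW + 1)  # random 640-step crop
        batch[b] = series[start:start + WINDOW]
    # Context is what the model sees; horizon is what it must predict.
    return batch[:, :CONTEXT], batch[:, CONTEXT:], src_idx


rng = np.random.default_rng(0)
weights = [0.15, 0.18, 0.16, 0.12, 0.12, 0.07, 0.20]
# Toy stand-ins: source s holds 3 series whose values encode the source id.
sources = [[np.arange(1000, dtype=np.float32) + 10_000 * s for _ in range(3)] for s in range(7)]
ctx, hor, src = sample_batch(sources, weights, 20_000, rng)
freq = np.bincount(src, minlength=7) / len(src)
print(np.round(freq, 3))  # close to the requested weights
```

Sampling sources by weight rather than by raw series count is what keeps a huge pool (80,000 synthetic series) from drowning out small real sources like the 370 electricity series.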
The first thing I checked was whether the bigger model overfits.
I measured ETT at every checkpoint out to 200,000 steps.
The curve stays flat at about 0.219 all the way out, with no late rise. The diversity that protected the 70M model from overfitting protects the 200M model too.
And 0.219 is a big step down from the 70M's 0.246, landing just above the strong released reference at 0.215. Now the full evaluation grid, this time with the released reference measured alongside us.
```
#### OUTPUT ####
=== Multi-domain zero-shot TEST (even, K=50, flip=OFF, point=median) ===
domain    clean  naive  snaive  OURS   released  scaled  CRPS   cov80  ours-vs-naive d[95% CI]
ETT       clean  0.298  0.272   0.219  0.215     0.735   0.207  0.773  -0.079[-0.098,-0.061] sig
Exchange  clean  0.239  0.244   0.263  0.241     1.101   0.286  0.718  +0.024[-0.003,+0.051] ns
--- verdict (CLEAN domains only: ['ETT', 'Exchange']) ---
beats naive (paired-significant): 1/2 ['ETT']
```

This is the result we were after. On ETT the 200M model scores 0.219, beating naive (0.298) by a wide, paired-significant margin and seasonal-naive (0.272) comfortably, and sitting just 0.004 above the strong released reference (0.215).
We have closed almost the entire gap to the released reference on ETT, from scratch. The scaling-law picture across all three of our sizes tells the story in one frame.
The random-walk story improves too. On Exchange the 200M model scores a scaled 1.101, down from the 70M's 1.313, and the paired interval [-0.003, +0.051] now straddles zero, which means it is no longer significantly worse than naive.
The bigger model, trained with the random-walk data, has finally learned to mostly persist on a random walk instead of inventing structure that is not there.
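Two columns in the table carry this verdict. "scaled" matches OURS divided by naive (0.219 / 0.298 ≈ 0.735 on ETT, 0.263 / 0.239 ≈ 1.10 on Exchange), and "ours-vs-naive" is the mean error difference with a 95% interval, marked sig when the interval excludes zero. The article does not say how that interval is computed; a paired bootstrap over series is one common choice, so the sketch below assumes it. Both function names are hypothetical.

```python
import numpy as np


def scaled_error(err_model, err_naive):
    """Mean model error divided by mean naive error; below 1.0 beats naive."""
    return float(np.mean(err_model) / np.mean(err_naive))


def paired_bootstrap_ci(err_model, err_naive, n_boot=2000, alpha=0.05, seed=0):
    """Assumed method: paired bootstrap over series for mean(err_model - err_naive)."""
    d = np.asarray(err_model, dtype=float) - np.asarray(err_naive, dtype=float)
    rng = np.random.default_rng(seed)
    # Resample whole series with replacement, keeping each model/naive pair together.
    idx = rng.integers(0, len(d), size=(n_boot, len(d)))
    boot_means = d[idx].mean(axis=1)
    lo, hi = np.quantile(boot_means, [alpha / 2, 1 - alpha / 2])
    verdict = "sig" if (hi < 0 or lo > 0) else "ns"   # interval excludes zero -> significant
    return float(d.mean()), (float(lo), float(hi)), verdict


rng = np.random.default_rng(1)
naive = np.abs(rng.normal(size=300)) + 0.1                      # toy per-series naive errors
better = 0.7 * naive                                            # consistently lower error
tied = naive + np.where(np.arange(300) % 2 == 0, 0.05, -0.05)   # mean difference of zero
print(scaled_error(better, naive), paired_bootstrap_ci(better, naive))
print(paired_bootstrap_ci(tied, naive))
```

Pairing matters here: the same series is hard for both the model and naive, so differencing per series removes that shared difficulty and gives a much tighter interval than comparing two unpaired averages would. The Exchange interval [-0.003, +0.051] straddling zero is the "tied" case: worse on average, but not reliably so.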
Calibration is already much closer to target out of the box, and conformal recalibration lands both domains within a point of the 0.80 target.
```
#### OUTPUT ####
ETT 80% interval coverage: raw 0.773 -> after conformal 0.81
Exchange 80% interval coverage: raw 0.718 -> after conformal 0.79
```
And the long horizon, the place the entry-point fix mattered most, is the best it has ever been.
```
#### OUTPUT ####
=== long-horizon MAE (200M, even/K50) ===
h=96 0.205  h=192 0.232  ratio 1.13x
```

Doubling the horizon from 96 to 192 steps raises the error by only 13 percent.

Step back and look at what just happened across the three sizes. The 17M model could barely match a naive baseline.
The 70M model beat it on structured data but stopped improving at 0.246 and could not be pushed further by any of the corpus tuning we tried. The 200M model, trained on exactly the same diverse data with exactly the same fixes, dropped to 0.219 and held there with no rise.
That supports the scaling hypothesis: capacity helps only once the data and the objective are right, and at that point it helps consistently.
Finally, the forecasts themselves. Here is the 200M model forecasting four held-out ETT series it never trained on, with the median in green and the 80 percent band shaded.
The green median tracks the black truth through the daily cycles, the band widens as the horizon grows and uncertainty builds, and the naive baseline sits flat and useless next to it. This is a model that was built from an empty file, trained on data we assembled ourselves, and it forecasts a series it has never seen.
Limitations and what I would build next
A build-along owes you the rough edges as well as the wins, so here they are plainly. The model still scores a scaled 1.10 on the random-walk Exchange series: about 10 percent worse than the naive baseline, though no longer significantly so. It is not yet a truly universal forecaster.
It has no key-value cache in the autoregressive loop, so every forecast step re-encodes the full context, which makes long forecasts slower than they need to be. It uses a post-hoc conformal wrapper for calibration rather than a native continuous-quantile head, which could produce calibrated intervals directly, without the extra step.
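To make the key-value-cache point concrete, here is a minimal sketch of an autoregressive rollout. `predict_fn` is a hypothetical stand-in for the model's forward pass; the 512-point context and 128-point step mirror the training window, and h=192 presumably needs two passes. Each pass hands the model a fresh window that includes its own earlier predictions, and without a cache the model re-encodes that whole window every time.

```python
import numpy as np


def autoregressive_forecast(predict_fn, context, horizon, ctx_len=512, step=128):
    """Hypothetical AR rollout: forecast `step` points at a time and feed them back in.

    predict_fn maps a 1-D window (at most ctx_len points) to at least `step` future points.
    With no KV cache, every call re-encodes the whole window from scratch.
    """
    history = np.asarray(context, dtype=float)
    chunks, produced = [], 0
    while produced < horizon:
        window = history[-ctx_len:]                   # keep only the latest ctx_len points
        chunk = np.asarray(predict_fn(window), dtype=float)[:step]
        if chunk.size == 0:
            raise ValueError("predict_fn returned no points")
        chunks.append(chunk)
        history = np.concatenate([history, chunk])    # the prediction becomes new context
        produced += chunk.size
    return np.concatenate(chunks)[:horizon]


calls = []


def toy_predict(window):
    calls.append(len(window))                         # record how much context each pass encodes
    return np.full(128, window[-1] + 1.0)             # toy "step up by one" model


fc = autoregressive_forecast(toy_predict, np.zeros(600), horizon=192)
print(fc.shape, calls)
```

Both passes encode a full 512-point window even though only 128 of those points are new on the second pass; a KV cache would let the second pass reuse the attention keys and values already computed for the overlapping positions.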
The corpus, while diverse, is still only about a dozen distinct datasets, and more genuinely different domains is the clearest remaining accuracy lever. And while the 200M now sits within 0.004 of the released reference on ETT (the gap on Exchange is wider, 0.263 against 0.241), closing that last gap and pulling clearly ahead would likely take a larger and more varied corpus still.
None of these are deep architectural problems. They are the next four pieces of work, and each one has a clear path.
The model we built is sound, calibrated, and competitive, and it leaves obvious room to grow.
What building it taught me
The same lesson kept showing up at every rung of the ladder. At 17M, synthetic data proved the pipeline but could not win on real series.
At 70M, the wrong real data lost to a naive baseline, the right data overfit, and only diversity and a careful objective brought it back. The thing that made the difference was almost never raw size; it was data composition and careful measurement, right up until the data and the fixes were in place and the added scale finally helped at 200M.
That is the shape of building one of these from scratch. You do not win by reaching for more parameters first.
You win by getting the data right, by measuring without fooling yourself, by fixing the quiet bugs that hide at the exact spot the model is weakest, and then, once all of that is true, by scaling up and watching it land. We started with an empty file and a naive baseline we could not beat.
We finished with a 200 million parameter model that forecasts a series it has never seen, with calibrated uncertainty, from scratch.
If you enjoyed this blog, you can follow me on Medium. I only post here.