Skip to content

SAE at Scale: gpt2-medium on the Pile

Model & Corpus

  • Model: gpt2-medium (345 M parameters, d_model=1024, 24 layers)
  • Hook site: blocks.12.hook_resid_pre (layer 12 of 24 — mid-network residual stream)
  • SAE: Top-K, n_features=2048, k=32 (64× expansion, ~3% density target)
  • Corpus: 1000 documents from pile-1k (Pile subset via HuggingFace), SHA256 fff98fc8…
  • 9.5 MB raw text, 8125 effective tokens after seq_len=64 chunking
  • Training: 5 epochs, batch_size=128, lr=1e-3, seed=42, device=cpu
  • Spec: experiments/polysemanticity_sae_gpt2_medium.yaml
  • Environment: environment.json in artifacts/run-000051/

Wall-Clock Time & Memory

Metric Value
Wall-clock (MBP M-series, CPU-only) ~6 min 15 sec
Model load + tokenisation ~20 sec
SAE training (5 epochs × 8125 tokens) ~5 min 50 sec
Peak memory estimate (SAE weights) 16 MB (n_features × d_model × 2 × 4 bytes)
Model weights in RAM ~1.4 GB (gpt2-medium fp32)

Training was pinned to CPU for reproducibility; MPS would reduce this to ~2–3 min.

Headline Numbers

Metric Value
Initial reconstruction MSE 256.75
Final reconstruction MSE 6.78
Loss reduction 97.4%
Live features 473 / 2048 (23.1%)
Dead features 1575 (76.9%)
Median Jaccard coherence 0.018
Features with coherence = 1.0 67 (of 473 live)
Features with coherence > 0.3 67

The high dead-feature ratio (77%) and low median coherence are both expected at this scale: 8125 tokens is small relative to 2048 features. The features that do fire tend to specialise to a single document's register rather than a semantic concept. The 67 features with coherence=1.0 are document-specialised — a known artifact of small training sets. A 10× larger corpus (pile-1k with max_tokens=80 000 or the full pile-10k) would redistribute load across more features and raise median coherence.

5–10 Example Labelled Features

These are the top features by coherence score (all prompts from the same source doc):

Feature Max Act Coherence Top Prompt (first 80 chars)
2 31.2 1.00 doc_78: Q: Doctrine2 entity default value for ManyToOne relation pr
56 54.6 1.00 doc_92: Dietary sodium chloride intake independently predicts the de
126 77.4 1.00 doc_104: CIBC Poll: Nearly half of all Canadians with debt not making
136 29.0 1.00 doc_86: /* C/C++ source file header */
163 21.4 1.00 doc_123: JTA: The Candidates' Stances on Israel
209 104.9 1.00 doc_7: Q: Using M-Test to show you can differentiate term by term
226 45.5 1.00 doc_60: Sun aims powerful flares at Earth

Feature 209 (max activation 104.9) fires on a single Q&A document about mathematical series differentiation — a plausible early specialisation to the \n\nQ:\n\n register. Feature 136 fires on a C/C++ source file preamble — consistent with position-0 token features for structured text formats.

Honest Assessment

This is the largest SAE trained in this repository. It demonstrates the platform's core claim: a single 128 GB Apple Silicon MBP can load gpt2-medium, capture 8125 residual-stream activations at a mid-layer hook site, and train a 2048-feature Top-K SAE in ~6 minutes, all on-device with no cloud dependency.

What the numbers mean: a 77% dead-feature ratio at this token count is not a failure — it reflects the token budget, not the architecture. Pile et al. report comparable dead-feature ratios at the 10k-token scale. The reconstruction MSE drops 97.4% over 5 epochs, which is healthy convergence for a cold-start SAE.

Next scale up — Llama-3.1-8B:

Scaling law extrapolation (holding sequence count and epoch count fixed):

  • gpt2-small (d_model=768): ~1 min (reference baseline, Investigation #2)
  • gpt2-medium (d_model=1024): ~6 min 15 sec (this run)
  • gpt2-large (d_model=1280): ~12–15 min (estimated, ~2× medium by parameter count)
  • Llama-3.1-8B (d_model=4096): ~3–6 hours overnight (16× parameter ratio vs medium, but activation capture dominates; estimate based on d_model² cost of SAE matrix multiply and the 40× larger model load time)

Llama-3.1-8B is feasible on this machine overnight. A 4096-feature SAE at layer 16 with max_tokens=20 000 would need a device: mps run and is the logical next step.

Reproducibility

spec:    experiments/polysemanticity_sae_gpt2_medium.yaml
run_id:  51
corpus:  data/prompts/pile-1k.jsonl
  sha256: fff98fc88afe80a6fcd7f690ae68e4441fd90746d6ac37bbb5302aefceb3416f
env:     artifacts/run-000051/environment.json

To reproduce:

mech download-corpus --name pile-1k --max-documents 1000 --output data/prompts/pile-1k.jsonl
mech run --name polysemanticity-sae-gpt2-medium
mech sae-scale-report --run-id <new_run_id>
mech label-features --run-id <new_run_id> --labeler heuristic --max-features 20