"Large models memorize the world in their weights. Small models need a notepad."
The Problem: Small Models Forget Facts
Large Language Models (LLMs) like GPT-4 are remarkably good at recalling facts — "Delhi is the capital of India," "Einstein developed the theory of relativity" — because they have billions of parameters acting as a massive, compressed knowledge store. The model bakes facts into weights during pre-training, and retrieval is implicit in the forward pass.
But what happens when you shrink the model?
Small Language Models (SLMs) — the kind you can actually run on a laptop or edge device — have far fewer parameters. There simply isn't enough capacity to reliably encode factual associations. They can handle grammar, style, and short-range reasoning reasonably well, but ask them a factual question and they hallucinate, hedge, or go blank.
The parametric memory paradigm breaks down at small scale.
The Insight
Humans don't store all their knowledge in their neurons alone. We use external memory — notebooks, calendars, books, sticky notes. We offload facts to the environment and look them up when needed. The neural machinery handles reasoning; the notepad handles retrieval.
What if we gave a small language model an explicit, learnable notepad?
That's precisely what a Differentiable Neural Computer (DNC) does.
What Is a DNC?
A Differentiable Neural Computer, introduced by DeepMind in 2016, augments a neural network controller with an external memory matrix — a structured, differentiable store that the network can read from and write to via learned attention mechanisms.
Think of it as RAM for a neural network.
Memory Matrix M ∈ ℝ^(N × W)

- N = number of memory slots (rows)
- W = width of each slot (columns)
The controller (in our case, a small GPT-2) interacts with this memory through soft, differentiable read and write heads — so the whole system is end-to-end trainable with backpropagation.
Unlike a hash map or database, the DNC doesn't look up memory by exact key. It uses content-based addressing — cosine similarity between a query key and stored vectors — blended with usage-based allocation to decide where to write new information.
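To make content-based addressing concrete, here is a minimal, self-contained sketch (my own toy code, not the repository's) of cosine-similarity addressing over a small memory:

```python
import torch
import torch.nn.functional as F

def content_addressing(key, memory, temp=2.0):
    """Soft attention over memory slots by cosine similarity to `key`.

    key:    (B, W) query vector
    memory: (B, N, W) slot contents
    returns (B, N) weighting that sums to 1 per batch item
    """
    sim = F.cosine_similarity(key.unsqueeze(1), memory, dim=-1)  # (B, N)
    return torch.softmax(sim * temp, dim=-1)

# Querying memory with the contents of slot 3 should weight slot 3 highest,
# because a vector has cosine similarity 1 with itself.
B, N, W = 2, 8, 16
memory = torch.randn(B, N, W)
w = content_addressing(memory[:, 3, :], memory)
print(w.argmax(dim=-1))  # tensor([3, 3])
```

Because the addressing is soft, gradients flow through every slot, which is what makes the whole lookup trainable.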
Architecture: GPT-2 + DNC Memory

The full model stacks two components:

```
┌─────────────────────────────┐
│        GPT-2 Backbone       │
│  (Masked Self-Attention +   │
│   Feed-Forward Layers)      │
└──────────────┬──────────────┘
               │ hidden state h_t (B, D)
               ▼
┌─────────────────────────────┐
│         DNC Memory          │
│  ┌───────────────────────┐  │
│  │    M ∈ ℝ^(N × W)      │  │ ← external RAM
│  └───────────────────────┘  │
│   write → read → update     │
└──────────────┬──────────────┘
               │ read_vec (B, R*W)
               ▼
  read_proj → h_t + read_vec
               │
               ▼
       LM Head → logits
```
At each time step t, the GPT-2 hidden state h_t is used to:
- Write new information into memory
- Read relevant information back out
- Fuse the read vector with h_t before projecting to vocabulary logits
The memory persists across time steps within a sequence, making it a form of working memory — information written at step 3 can be retrieved at step 47.
The Memory Module: Read & Write Mechanics
Memory State
The memory at any step is a tensor M ∈ ℝ^(B × N × W) — for each batch item, N slots, each a W-dimensional vector. A usage vector u ∈ ℝ^(B × N) tracks how much each slot has been written to.
Projections from Hidden State
Given the controller hidden state h_t ∈ ℝ^(B × D), the memory module computes:
| Projection | Shape | Purpose |
|---|---|---|
| `write_key` | (B, W) | Where to write (content addressing) |
| `write_vec` | (B, W) | What to write |
| `erase_vec` | (B, W) | What to erase before writing (sigmoid-gated) |
| `write_gate` | (B, 1) | How much to write (0 = skip, 1 = full write) |
| `read_keys` | (B, R, W) | Where to read from (R read heads) |
Write Weighting
The write address w_write ∈ ℝ^(B × N) is a soft attention distribution over slots:

```
w_content = softmax( cosine(write_key, M) × τ )
w_alloc   = softmax( (1 − u) × τ )
w_write   = 0.5 × w_content + 0.5 × w_alloc
```

- Content addressing (w_content): write near slots whose content resembles the current write key — useful for updating existing facts.
- Allocation (w_alloc): prefer less-used slots — useful for storing new facts without overwriting old ones.

τ is a learned temperature parameter that sharpens or softens the distribution.
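The blend can be implemented directly from these formulas. A sketch (the function name and fixed 50/50 blend weights follow the equations above; the helper itself is mine, not the repo's):

```python
import torch
import torch.nn.functional as F

def write_weighting(write_key, memory, usage, temp=2.0):
    """write_key (B,W), memory (B,N,W), usage (B,N) → w_write (B,N)."""
    # Content term: slots whose stored vectors resemble the write key.
    sim = F.cosine_similarity(write_key.unsqueeze(1), memory, dim=-1)  # (B, N)
    w_content = torch.softmax(sim * temp, dim=-1)
    # Allocation term: less-used slots get higher weight.
    w_alloc = torch.softmax((1 - usage) * temp, dim=-1)                # (B, N)
    # Fixed 50/50 blend of the two addressing modes.
    return 0.5 * w_content + 0.5 * w_alloc

B, N, W = 2, 8, 16
memory = torch.randn(B, N, W)
usage = torch.zeros(B, N)  # fresh memory: the allocation term is uniform
w_write = write_weighting(torch.randn(B, W), memory, usage)
```

Since both terms are softmax distributions, the blend still sums to 1 over slots.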
Write Operation
```
M_new = M × (1 − w_write ⊗ erase_vec) + w_write ⊗ write_vec
M_out = M + write_gate × (M_new − M)
```

The write_gate is the key knob:
- write_gate ≈ 0 → memory unchanged (model relies on parametric knowledge)
- write_gate ≈ 1 → full write (model externalizes knowledge)
This gate is learned entirely from data. The model discovers when it's worth writing.
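A direct transcription of the two update equations, with the broadcasting spelled out (a sketch; names are mine):

```python
import torch

def gated_write(memory, w_write, erase_vec, write_vec, write_gate):
    """memory (B,N,W), w_write (B,N), erase/write_vec (B,W), write_gate (B,1)."""
    # Outer products: each slot is erased/written in proportion to its address weight.
    erase = w_write.unsqueeze(-1) * erase_vec.unsqueeze(1)  # (B, N, W)
    add   = w_write.unsqueeze(-1) * write_vec.unsqueeze(1)  # (B, N, W)
    m_new = memory * (1 - erase) + add
    # Interpolate between old and candidate memory with the scalar gate.
    return memory + write_gate.unsqueeze(-1) * (m_new - memory)

B, N, W = 2, 4, 8
memory = torch.randn(B, N, W)
args = (torch.softmax(torch.randn(B, N), -1),   # w_write
        torch.sigmoid(torch.randn(B, W)),       # erase_vec
        torch.randn(B, W))                      # write_vec
unchanged = gated_write(memory, *args, write_gate=torch.zeros(B, 1))
assert torch.allclose(unchanged, memory)  # gate 0 → memory untouched
```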
Read Operation
```
w_read   = softmax( read_keys · Mᵀ × τ )  ∈ ℝ^(B × R × N)
read_vec = w_read · M                     ∈ ℝ^(B × R × W)
         → reshaped to (B, R*W)
         → projected back to (B, D) via read_proj
```
R read heads allow the model to simultaneously query R different "topics" from memory.
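In code, the batched dot products reduce to two einsums (a sketch of the read path, without the final read_proj linear layer):

```python
import torch

def read_memory(read_keys, memory, temp=2.0):
    """read_keys (B,R,W), memory (B,N,W) → flattened read vectors (B, R*W)."""
    # Dot-product scores of every read key against every slot.
    scores = torch.einsum('brw,bnw->brn', read_keys, memory) * temp
    w_read = torch.softmax(scores, dim=-1)                   # (B, R, N)
    # Each head's read vector is its attention-weighted sum of slots.
    read_vec = torch.einsum('brn,bnw->brw', w_read, memory)  # (B, R, W)
    return read_vec.reshape(read_vec.size(0), -1)            # (B, R*W)

B, N, W, R = 2, 8, 16, 4
out = read_memory(torch.randn(B, R, W), torch.randn(B, N, W))
assert out.shape == (B, R * W)
```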
State Update
Usage is updated after each write so the allocator tracks which slots are "full":
```
usage_new = usage + (1 - usage) * w_write.detach()
```
The .detach() prevents gradients from flowing back through the usage signal — it's a bookkeeping variable, not a learned one.
The Write Gate: Knowing When to Remember
The write gate is the most interpretable component of the whole system. After training, you can run inspect_writes() and visualize per-token gate activations:
```
Token        Gate    bar
────────────────────────────────────────────────
Albert       0.821   ████████████████████████
Einstein     0.904   ███████████████████████████
was          0.112   ███
born         0.287   ████████
in           0.094   ██
1879         0.756   ██████████████████████
in           0.071   ██
Ulm          0.683   ████████████████████
He           0.143   ████
developed    0.201   ██████
the          0.058   █
theory       0.388   ███████████
of           0.062   █
relativity   0.712   █████████████████████
```
The model learns to write on content-bearing tokens (proper nouns, dates, key concepts) and skip function words. Nobody taught it this — it emerged from the loss functions.
Loss Functions
Training uses three losses summed together:
1. Language Modelling Loss (Cross-Entropy)
The standard next-token prediction loss:
```
L_lm = CrossEntropy(logits[:, :-1], input_ids[:, 1:])
```
This is the primary loss. The model must still predict the next token correctly.
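In PyTorch, the shift-by-one typically looks like this (a sketch with dummy tensors; `F.cross_entropy` expects flattened logits):

```python
import torch
import torch.nn.functional as F

B, T, V = 2, 16, 100
logits = torch.randn(B, T, V)               # model output
input_ids = torch.randint(0, V, (B, T))     # token ids

# Predict token t+1 from position t: drop the last logit and the first target.
shift_logits = logits[:, :-1, :].reshape(-1, V)  # (B*(T-1), V)
shift_labels = input_ids[:, 1:].reshape(-1)      # (B*(T-1),)
l_lm = F.cross_entropy(shift_logits, shift_labels)
```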
2. Routing Loss
This loss asks: when the write gate is high, does memory actually change the prediction?
If the model writes to memory but the output distribution looks identical to the no-memory baseline, that write was pointless. The routing loss penalises this:
```
kl = KL( softmax(p_no_mem) || softmax(p_mem) ).detach()
L_routing = -(gate * kl).mean()
```
The KL divergence between the memory model and a frozen no-memory baseline is computed per token. Multiplied by the gate and negated, this loss:
- Rewards high gates when memory changes the prediction (high KL)
- Punishes high gates when memory doesn't matter (low KL → wasted write)
The .detach() on the KL ensures gradients only flow through the gate, not the no-memory logits.
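A runnable sketch of this loss (tensor and function names are mine, not the repo's):

```python
import torch
import torch.nn.functional as F

def routing_loss(logits_mem, logits_no_mem, gate):
    """logits_* (B,T,V), gate (B,T,1) → scalar loss."""
    # Per-token KL(no_mem || mem) = sum p_no_mem * log(p_no_mem / p_mem).
    log_p_mem = F.log_softmax(logits_mem, dim=-1)
    p_no_mem = F.softmax(logits_no_mem, dim=-1)
    kl = (p_no_mem * (p_no_mem.clamp_min(1e-9).log() - log_p_mem)).sum(-1)
    kl = kl.detach()  # bookkeeping signal: gradients reach only the gate
    # Negative sign: a high gate where KL is high *lowers* the loss.
    return -(gate.squeeze(-1) * kl).mean()

# Sanity check: if memory never changes the distribution, the loss is ~0,
# so the gate earns nothing by writing.
B, T, V = 2, 8, 50
same = torch.randn(B, T, V)
loss = routing_loss(same, same, torch.sigmoid(torch.randn(B, T, 1)))
assert abs(loss.item()) < 1e-4
```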
3. Entropy Loss (Write Sparsity)
A diffuse write weighting — spreading activation uniformly across all N slots — is wasteful. It's like writing one word across every page of your notebook instead of a single page.
The entropy loss encourages sharp, decisive writes:
```
H = -(w_writes * (w_writes + 1e-8).log()).sum(-1).mean()
L_entropy = H   # minimized during training
```
Low entropy → sparse write attention → the model commits to specific slots.
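The same line, wrapped as a checkable function (a sketch), makes the two extremes easy to see: a one-hot write has entropy near 0, a uniform write over N slots has entropy log N:

```python
import torch

def entropy_loss(w_writes, eps=1e-8):
    """w_writes (..., N) soft write weightings → mean per-step entropy."""
    return -(w_writes * (w_writes + eps).log()).sum(-1).mean()

N = 8
one_hot = torch.eye(N)[None]            # (1, N, N): maximally sharp writes
uniform = torch.full((1, N, N), 1 / N)  # maximally diffuse writes
assert entropy_loss(one_hot) < 0.01                       # ≈ 0
assert abs(entropy_loss(uniform).item() - torch.log(torch.tensor(8.0)).item()) < 0.01
```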
Total Loss
```
L = L_lm + λ_r · L_routing + λ_e · L_entropy
# defaults: λ_r = 0.1, λ_e = 0.05
```
The auxiliary losses are kept small relative to L_lm so language modelling remains the primary objective. The routing and entropy terms act as structural regularizers that shape how the memory is used, not just whether the model gets tokens right.
Code Walkthrough
DNCMemory Module
```python
class DNCMemory(nn.Module):
    def __init__(self, mem_slots, mem_width, num_reads, controller_size):
        super().__init__()
        self.N = mem_slots  # number of memory rows
        self.W = mem_width  # width of each row
        self.R = num_reads  # number of read heads

        # All projections from controller hidden state
        self.write_key_proj = nn.Linear(controller_size, mem_width)
        self.write_vec_proj = nn.Linear(controller_size, mem_width)
        self.erase_vec_proj = nn.Linear(controller_size, mem_width)
        self.write_gate_proj = nn.Linear(controller_size, 1)
        self.read_key_proj = nn.Linear(controller_size, mem_width * num_reads)

        self.temp = nn.Parameter(torch.ones(1) * 2.0)  # learned sharpness
```
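The snippet above shows only the constructor. A plausible single memory step, assembled directly from the equations earlier in the post (this forward body is my reconstruction, not the repository's code), looks like:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class DNCMemorySketch(nn.Module):
    """Reconstruction of one memory step from the write/read equations above."""
    def __init__(self, mem_slots, mem_width, num_reads, controller_size):
        super().__init__()
        self.N, self.W, self.R = mem_slots, mem_width, num_reads
        self.write_key_proj = nn.Linear(controller_size, mem_width)
        self.write_vec_proj = nn.Linear(controller_size, mem_width)
        self.erase_vec_proj = nn.Linear(controller_size, mem_width)
        self.write_gate_proj = nn.Linear(controller_size, 1)
        self.read_key_proj = nn.Linear(controller_size, mem_width * num_reads)
        self.temp = nn.Parameter(torch.ones(1) * 2.0)

    def forward(self, h_t, memory, usage):
        B = h_t.size(0)
        write_key = self.write_key_proj(h_t)                  # (B, W)
        write_vec = self.write_vec_proj(h_t)                  # (B, W)
        erase_vec = torch.sigmoid(self.erase_vec_proj(h_t))   # (B, W)
        write_gate = torch.sigmoid(self.write_gate_proj(h_t)) # (B, 1)
        read_keys = self.read_key_proj(h_t).view(B, self.R, self.W)

        # Write weighting: 50/50 content + allocation addressing
        sim = F.cosine_similarity(write_key.unsqueeze(1), memory, dim=-1)
        w_content = torch.softmax(sim * self.temp, dim=-1)
        w_alloc = torch.softmax((1 - usage) * self.temp, dim=-1)
        w_write = 0.5 * w_content + 0.5 * w_alloc             # (B, N)

        # Gated erase-then-add write
        erase = w_write.unsqueeze(-1) * erase_vec.unsqueeze(1)
        add = w_write.unsqueeze(-1) * write_vec.unsqueeze(1)
        m_new = memory * (1 - erase) + add
        memory = memory + write_gate.unsqueeze(-1) * (m_new - memory)

        # Read with R heads, then flatten to (B, R*W)
        scores = torch.einsum('brw,bnw->brn', read_keys, memory) * self.temp
        w_read = torch.softmax(scores, dim=-1)
        read_vec = torch.einsum('brn,bnw->brw', w_read, memory).reshape(B, -1)

        # Usage bookkeeping (detached: not a learned signal)
        usage = usage + (1 - usage) * w_write.detach()
        return read_vec, memory, usage, write_gate, w_write

mem = DNCMemorySketch(64, 128, 4, 768)
h = torch.randn(2, 768)
M = torch.zeros(2, 64, 128)
u = torch.zeros(2, 64)
read_vec, M, u, gate, ww = mem(h, M, u)
```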
DNCLLM Forward Pass
The key loop — stepping through time and interleaving memory reads/writes with transformer hidden states:
```python
def forward(self, input_ids, memory, usage):
    # Run all tokens through GPT-2 in parallel (causal masking handles ordering)
    hidden_states = self.transformer(input_ids).last_hidden_state  # (B, T, D)

    all_logits, all_gates, all_ww = [], [], []
    for t in range(input_ids.size(1)):
        h_t = hidden_states[:, t, :]  # (B, D) — current hidden state

        # Memory interaction for this timestep
        read_vec, memory, usage, write_gate, w_write = self.memory(h_t, memory, usage)

        # Fuse read vector back into hidden state
        h_out = h_t + self.read_proj(read_vec)  # residual addition

        all_logits.append(self.lm_head(h_out))  # project to vocab
        all_gates.append(write_gate)
        all_ww.append(w_write)

    logits = torch.stack(all_logits, dim=1)       # (B, T, V)
    write_gates = torch.stack(all_gates, dim=1)   # (B, T, 1)
    w_writes = torch.stack(all_ww, dim=1)         # (B, T, N)
    return logits, memory, usage, write_gates, w_writes
```
Why the sequential loop? Memory has a causal dependency — memory[t] depends on what was written at steps 0..t-1. This can't be parallelized like self-attention. It's the main compute overhead of DNC over a pure transformer.
Memory Initialization
Memory is initialized to zeros at the start of each sequence:
```python
def init_memory(self, batch_size, device):
    memory = torch.zeros(batch_size, self.cfg.mem_slots, self.cfg.mem_width, device=device)
    usage = torch.zeros(batch_size, self.cfg.mem_slots, device=device)
    return memory, usage
```
This means memory is per-sequence, not persistent across batch items or between training steps. It acts as within-sequence working memory, not a cross-sequence knowledge base.
Metrics to Watch
During training, several metrics beyond loss reveal whether the memory system is working correctly:
| Metric | What It Tells You |
|---|---|
| `avg_gate` | Mean write gate activation. Should settle between 0.2–0.7; too high = writing everything, too low = never writing |
| `gate_std` | Gate polarization. High std means the model discriminates — writes on some tokens, skips others |
| `write_rate` | Fraction of timesteps with gate > 0.7. Tracks how aggressively the model uses memory |
| `write_sparsity` | How concentrated the write weighting is. High sparsity = sharp slot selection |
| `mem_kl` | KL divergence between memory and no-memory predictions. Non-zero means memory is changing outputs |
A healthy DNC should show high gate_std (selective writing) and high write_sparsity (concentrated writes), with non-trivial mem_kl (memory actually matters).
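Given the `write_gates` (B, T, 1) and `w_writes` (B, T, N) tensors returned by the forward pass, these diagnostics reduce to a few lines. A sketch (the exact sparsity formula varies by implementation; here I use 1 minus normalized entropy as an assumption):

```python
import torch

def memory_metrics(write_gates, w_writes):
    """write_gates (B,T,1), w_writes (B,T,N) → dict of scalar diagnostics."""
    g = write_gates.squeeze(-1)                              # (B, T)
    entropy = -(w_writes * (w_writes + 1e-8).log()).sum(-1)  # (B, T)
    max_entropy = torch.log(torch.tensor(float(w_writes.size(-1))))
    return {
        "avg_gate": g.mean().item(),
        "gate_std": g.std().item(),
        "write_rate": (g > 0.7).float().mean().item(),
        "write_sparsity": (1 - entropy / max_entropy).mean().item(),
    }

B, T, N = 2, 16, 64
stats = memory_metrics(torch.sigmoid(torch.randn(B, T, 1)),
                       torch.softmax(torch.randn(B, T, N), dim=-1))
```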
Practical Configuration
The config used in the experiments:
```python
class Config:
    # GPT-2 backbone
    hidden_size = 768
    num_layers = 6
    num_heads = 8
    seq_len = 128

    # DNC memory
    mem_slots = 64   # N: number of memory slots
    mem_width = 128  # W: width of each slot
    num_reads = 4    # R: number of read heads

    # Loss weights
    lambda_routing = 0.1
    lambda_entropy = 0.05

    # Training
    batch_size = 4
    lr = 3e-4
    grad_clip = 1.0
```
Memory footprint: The external memory adds N × W = 64 × 128 = 8,192 floats per batch item — negligible compared to the model weights themselves. The overhead is in the sequential forward loop, not storage.
Parameter count: DNC adds roughly 5 × (D × W) parameters from the five projection matrices. At D=768, W=128 that's ~490K parameters — about 0.5% overhead on a 6-layer GPT-2.
Limitations and What's Next
This architecture is a proof of concept. Several known limitations:
Sequential bottleneck: The time-step loop cannot be parallelized. For long sequences, this significantly slows training relative to the pure-transformer baseline.
No cross-sequence persistence: Memory resets between sequences. A truly useful factual memory would persist across the lifetime of the model — closer to a retrieval-augmented generation (RAG) system.
Gradient flow through time: Backpropagating through T sequential memory steps can cause vanishing/exploding gradients for long sequences. Gradient clipping (grad_clip = 1.0) helps but doesn't solve it.
Potential extensions:
- Persistent memory: Keep a global memory matrix that accumulates knowledge across a training corpus and is frozen at inference time (like a learned knowledge base)
- Sparse attention writes: Replace soft write weighting with a top-k hard selection to reduce memory write diffusion
- Layer-wise memory: Attach a memory module to each transformer layer, not just the final hidden state
- Memory-augmented RAG: Use DNC writes as an online summary buffer, and retrieve from it alongside a static vector DB
Summary
| | GPT-2 Baseline | GPT-2 + DNC |
|---|---|---|
| Factual recall | Parametric only | Parametric + external memory |
| Memory type | Weights (static) | N×W matrix (dynamic, per-sequence) |
| Write mechanism | None | Content + allocation addressing |
| Selective writing | No | Yes (learned write gate) |
| Extra parameters | — | ~490K (~0.5%) |
| Training overhead | — | Sequential loop over T steps |
The DNC doesn't replace the transformer's parametric knowledge — it supplements it. The model learns when to trust its weights and when to externalise a fact to the notepad. On a small model operating in a domain with many precise facts, that notepad can make all the difference.
The write gate is the centrepiece of the design. When it fires on "Einstein" and "1879" and stays quiet on "was" and "the", you know the model has learned something non-trivial: not all tokens are worth remembering.
GitHub code: https://github.com/AsishKumarDalal/memoryllm
Implementation: PyTorch. Dataset: WikiText-2. Backbone: GPT-2 (6 layers, 768 hidden, 8 heads). DNC config: N=64, W=128, R=4 read heads. Loss: L_lm + 0.1·L_routing + 0.05·L_entropy.



