Teaching Small Language Models to Remember: Giving LLMs a Notebook with Differentiable Neural Computers

Dev.to / 4/25/2026


Key Points

  • Large language models can reliably recall factual knowledge because their huge parameter sets act as compressed internal storage, whereas small language models often forget or hallucinate facts due to limited capacity.
  • The article proposes addressing this limitation by giving small models explicit external memory—similar to how humans use notebooks—so the model can retrieve information when needed.
  • A Differentiable Neural Computer (DNC) provides a learnable, differentiable memory matrix that a neural controller can read from and write to using soft attention-based heads.
  • The described system combines a small GPT-2 controller with DNC memory and relies on content-based addressing plus usage-based allocation for storing and retrieving new information.
  • Because the read/write operations are differentiable, the whole architecture can be trained end-to-end with backpropagation, aiming at “notebook-like” factual recall for small models.

"Large models memorize the world in their weights. Small models need a notepad."

The Problem: Small Models Forget Facts

Large Language Models (LLMs) like GPT-4 are remarkably good at recalling facts, such as "New Delhi is the capital of India" or "Einstein developed the theory of relativity", because they have billions of parameters acting as a massive, compressed knowledge store. The model bakes facts into its weights during pre-training, and retrieval is implicit in the forward pass.

But what happens when you shrink the model?

Small Language Models (SLMs) — the kind you can actually run on a laptop or edge device — have far fewer parameters. There simply isn't enough capacity to reliably encode factual associations. They can handle grammar, style, and short-range reasoning reasonably well, but ask them a factual question and they hallucinate, hedge, or go blank.

The parametric memory paradigm breaks down at small scale.

The Insight

Humans don't store all their knowledge in their neurons alone. We use external memory — notebooks, calendars, books, sticky notes. We offload facts to the environment and look them up when needed. The neural machinery handles reasoning; the notepad handles retrieval.

What if we gave a small language model an explicit, learnable notepad?

That's precisely what a Differentiable Neural Computer (DNC) does.

What Is a DNC?

A Differentiable Neural Computer, introduced by DeepMind in 2016, augments a neural network controller with an external memory matrix — a structured, differentiable store that the network can read from and write to via learned attention mechanisms.

Think of it as RAM for a neural network.

Memory Matrix  M  ∈  ℝ^(N × W)
                   │
          N = number of memory slots (rows)
          W = width of each slot (columns)

The controller (in our case, a small GPT-2) interacts with this memory through soft, differentiable read and write heads — so the whole system is end-to-end trainable with backpropagation.

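Unlike a hash map or database, the DNC doesn't look up memory by exact key. It uses content-based addressing (cosine similarity between a query key and the stored vectors), blended with usage-based allocation to decide where to write new information. The original DNC also records the order of writes in a temporal link matrix and learns gates that choose between content and allocation addressing; the variant described in this article keeps only content and usage-based addressing.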
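To make content-based addressing concrete, here is a minimal, framework-free sketch in plain Python, assuming a single (unbatched) memory stored as nested lists. The helper names `cosine`, `softmax`, and `content_weights` are hypothetical and not taken from the repository; `tau` plays the role of the sharpness factor that multiplies the similarities before the softmax.

```python
import math


def cosine(a, b):
    """Cosine similarity; the epsilon keeps all-zero (empty) slots from dividing by zero."""
    dot = sum(x * y for x, y in zip(a, b))
    norms = math.sqrt(sum(x * x for x in a)) * math.sqrt(sum(y * y for y in b))
    return dot / (norms + 1e-8)


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def content_weights(key, memory, tau):
    """Soft lookup: every slot gets a weight, larger for slots that resemble the key."""
    return softmax([tau * cosine(key, row) for row in memory])


# Toy memory: 3 slots, each 4 wide.
memory = [
    [1.0, 0.0, 0.0, 0.0],
    [0.0, 1.0, 0.0, 0.0],
    [0.7, 0.7, 0.0, 0.0],
]
query = [0.9, 0.1, 0.0, 0.0]  # closest to slot 0, partly similar to slot 2

w = content_weights(query, memory, tau=5.0)
w_sharp = content_weights(query, memory, tau=20.0)
print([round(x, 3) for x in w])
print([round(x, 3) for x in w_sharp])
```

No slot has to match the query exactly: the result is a distribution over all slots, and a larger `tau` concentrates it on the best match. Keeping the lookup soft is what keeps it differentiable.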

Architecture: GPT-2 + DNC Memory


The full model layers two components:

                    ┌─────────────────────────────┐
                    │       GPT-2 Backbone        │
                    │  (Masked Self-Attention +   │
                    │   Feed-Forward Layers)      │
                    └──────────────┬──────────────┘
                                   │  hidden state h_t  (B, D)
                                   ▼
                    ┌─────────────────────────────┐
                    │        DNC Memory           │
                    │  ┌─────────────────────┐    │
                    │  │  M ∈ ℝ^(N × W)      │    │  ← external RAM
                    │  └─────────────────────┘    │
                    │   write → read → update     │
                    └──────────────┬──────────────┘
                                   │  read_vec  (B, R*W)
                                   ▼
                         read_proj → h_t + read_vec
                                   │
                                   ▼
                              LM Head → logits

At each time step t, the GPT-2 hidden state h_t is used to:

  1. Write new information into memory
  2. Read relevant information back out
  3. Fuse the read vector with h_t before projecting to vocabulary logits

The memory persists across time steps within a sequence, making it a form of working memory — information written at step 3 can be retrieved at step 47.

The Memory Module: Read & Write Mechanics

Memory State

The memory at any step is a tensor M ∈ ℝ^(B × N × W): for each of the B sequences in the batch, N slots, each a W-dimensional vector. A usage vector u ∈ ℝ^(B × N) tracks how much each slot has been written to.

Projections from Hidden State

Given the controller hidden state h_t ∈ ℝ^(B × D), the memory module computes:

Projection    Shape        Purpose
write_key     (B, W)       Where to write (content addressing)
write_vec     (B, W)       What to write
erase_vec     (B, W)       What to erase before writing (sigmoid-gated)
write_gate    (B, 1)       How much to write (0 = skip, 1 = full write)
read_keys     (B, R, W)    Where to read from (R read heads)

Write Weighting

The write address w_write ∈ ℝ^(B × N) is a soft attention distribution over slots:

w_content = softmax( cosine(write_key, M) × τ )
w_alloc   = softmax( (1 − u) × τ )

w_write   = 0.5 × w_content + 0.5 × w_alloc

  • Content addressing (w_content): write near slots whose content resembles the current write key, which is useful for updating existing facts.
  • Allocation (w_alloc): prefer less-used slots, which is useful for storing new facts without overwriting old ones.

τ is a learned temperature parameter that sharpens or softens the distribution.
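The write weighting can be sketched in plain Python as below, assuming one unbatched memory and the fixed 0.5/0.5 blend from the formulas above; `write_weighting` is a hypothetical helper name, not the repository's. The two calls contrast updating an existing fact with storing a new one.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def cosine(a, b):
    dot = sum(x * y for x, y in zip(a, b))
    na = math.sqrt(sum(x * x for x in a))
    nb = math.sqrt(sum(y * y for y in b))
    return dot / (na * nb + 1e-8)  # epsilon: empty slots have zero norm


def write_weighting(write_key, memory, usage, tau):
    """w_write = 0.5 * w_content + 0.5 * w_alloc."""
    w_content = softmax([tau * cosine(write_key, row) for row in memory])
    w_alloc = softmax([tau * (1.0 - u) for u in usage])  # favour unused slots
    return [0.5 * c + 0.5 * a for c, a in zip(w_content, w_alloc)]


# Slot 0 already holds a fact and is fully used; slots 1 and 2 are empty.
memory = [[1.0, 0.0], [0.0, 0.0], [0.0, 0.0]]
usage = [1.0, 0.0, 0.0]

w_update = write_weighting([1.0, 0.0], memory, usage, tau=5.0)  # key matches slot 0
w_new = write_weighting([0.0, 1.0], memory, usage, tau=5.0)     # key matches nothing
print([round(x, 3) for x in w_update])
print([round(x, 3) for x in w_new])
```

Even in the update case, slot 0 receives only about half of the write weight: the allocation term always supplies the other half, and it favours the empty slots. The original DNC instead learns an allocation gate that interpolates between the two modes; the fixed blend is simpler but means every write partly lands in free slots.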

Write Operation

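M_new = M × (1 − w_write ⊗ erase_vec) + w_write ⊗ write_vec

M_out = M + write_gate × (M_new − M)

Here ⊗ is the outer product of the N-dimensional write weighting and a W-dimensional vector (an N × W matrix per batch item), and × is element-wise multiplication. With erase_vec = 1 and a one-hot w_write, the addressed slot is replaced outright by write_vec while every other slot is left untouched.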

The write_gate is the key knob:

write_gate ≈ 0  →  memory unchanged  (model relies on parametric knowledge)
write_gate ≈ 1  →  full write        (model externalizes knowledge)

This gate is learned entirely from data. The model discovers when it's worth writing.
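The erase-then-write step and the gate interpolation translate directly into plain Python. This sketch assumes one unbatched memory and a one-hot write address for readability; `gated_write` is a hypothetical name.

```python
def gated_write(memory, w_write, erase_vec, write_vec, write_gate):
    """M_new = M * (1 - w (x) e) + w (x) v, then M_out = M + g * (M_new - M)."""
    out = []
    for i, row in enumerate(memory):
        # Erase in proportion to this slot's write weight, then add the new content.
        new_row = [
            m * (1.0 - w_write[i] * e) + w_write[i] * v
            for m, e, v in zip(row, erase_vec, write_vec)
        ]
        # Interpolate between the old memory and the fully written one.
        out.append([m + write_gate * (n - m) for m, n in zip(row, new_row)])
    return out


memory = [[0.5, 0.5], [0.0, 0.0]]
w_write = [0.0, 1.0]     # one-hot: address slot 1 only
erase_vec = [1.0, 1.0]   # clear the addressed slot completely before writing
write_vec = [0.3, -0.2]

skip = gated_write(memory, w_write, erase_vec, write_vec, write_gate=0.0)
full = gated_write(memory, w_write, erase_vec, write_vec, write_gate=1.0)
half = gated_write(memory, w_write, erase_vec, write_vec, write_gate=0.5)
print(skip, full, half)
```

A gate of 0 returns the memory unchanged, a gate of 1 stores write_vec in slot 1, and 0.5 moves slot 1 only halfway toward the new content.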

Read Operation

w_read  = softmax( read_keys · M^T × τ )   ∈ ℝ^(B × R × N)
read_vec = w_read · M                       ∈ ℝ^(B × R × W)
         → reshape to (B, R*W)
         → projected back to (B, D) via read_proj

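R read heads allow the model to query R different "topics" from memory simultaneously. Note that the read scores use a raw dot product rather than the cosine similarity used for writes, so read weights depend on slot magnitude as well as direction.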
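A plain-Python sketch of the read path with R heads, assuming one unbatched memory. It follows the read formula above, including its dot-product scores, and flattens the R read vectors into one R·W vector, which read_proj would then map back to size D. `read_memory` is a hypothetical name.

```python
import math


def softmax(scores):
    m = max(scores)
    exps = [math.exp(s - m) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def read_memory(memory, read_keys, tau):
    """For each head: w = softmax(tau * key . M^T), r = w . M; heads are concatenated."""
    width = len(memory[0])
    read_vec = []
    for key in read_keys:  # one pass per read head
        scores = [tau * sum(k * m for k, m in zip(key, row)) for row in memory]
        w = softmax(scores)
        r = [sum(w[i] * memory[i][j] for i in range(len(memory))) for j in range(width)]
        read_vec.extend(r)  # flatten (R, W) into a single R*W vector
    return read_vec


memory = [
    [1.0, 0.0, 0.0],
    [0.0, 1.0, 0.0],
    [0.0, 0.0, 1.0],
]
read_keys = [
    [1.0, 0.0, 0.0],  # head 0 asks about the "topic" stored in slot 0
    [0.0, 0.0, 1.0],  # head 1 asks about slot 2
]
r = read_memory(memory, read_keys, tau=10.0)
print(len(r))  # 6 (= R * W)
```

Each head returns an almost pure copy of its matching slot, so the two topics come back side by side in one vector.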

State Update

Usage is updated after each write so the allocator tracks which slots are "full":

usage_new = usage + (1 - usage) * w_write.detach()

The .detach() prevents gradients from flowing back through the usage signal — it's a bookkeeping variable, not a learned one.
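The usage update is a one-liner. This plain-Python sketch (hypothetical `update_usage`, unbatched) applies the same soft write weighting five times to show how usage accumulates.

```python
def update_usage(usage, w_write):
    """usage_new = usage + (1 - usage) * w_write; pure bookkeeping, no gradient needed."""
    return [u + (1.0 - u) * w for u, w in zip(usage, w_write)]


usage = [0.0, 0.0, 0.0, 0.0]
history = []
for step in range(5):
    # The same soft write weighting at every step, concentrated on slot 0.
    usage = update_usage(usage, [0.7, 0.1, 0.1, 0.1])
    history.append(list(usage))

for step, u in enumerate(history):
    print(step, [round(x, 4) for x in u])
```

Usage only grows, because (1 − u)·w is never negative. The original DNC also multiplies usage by a retention term driven by learned free gates so slots can be released; without it, once most slots approach a usage of 1 the allocation scores (1 − u) become nearly equal and allocation stops steering writes. Since memory resets every sequence, this is most likely to matter on long sequences.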

The Write Gate: Knowing When to Remember

The write gate is the most interpretable component of the whole system. After training, you can run inspect_writes() and visualize per-token gate activations:

Token                  Gate   bar
────────────────────────────────────────────────
Albert                 0.821  ████████████████████████
Einstein               0.904  ███████████████████████████
was                    0.112  ███
born                   0.287  ████████
in                     0.094  ██
1879                   0.756  ██████████████████████
in                     0.071  ██
Ulm                    0.683  ████████████████████
He                     0.143  ████
developed              0.201  ██████
the                    0.058  █
theory                 0.388  ███████████
of                     0.062  █
relativity             0.712  █████████████████████

The model learns to write on content-bearing tokens (proper nouns, dates, key concepts) and to skip function words. No token-level labels told it which tokens matter; the pattern emerged from the training objectives.

Loss Functions

Training uses three losses summed together:

1. Language Modelling Loss (Cross-Entropy)

The standard next-token prediction loss:

L_lm = CrossEntropy(logits[:, :-1], input_ids[:, 1:])

This is the primary loss. The model must still predict the next token correctly.
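The slicing pairs the prediction at position t with the token at position t + 1. A plain-Python sketch for one unbatched sequence (`shifted_cross_entropy` is a hypothetical helper). In PyTorch, `F.cross_entropy` expects the class dimension second, so batched logits are typically reshaped to (B·(T−1), V) with targets of shape (B·(T−1),) before the call.

```python
import math


def shifted_cross_entropy(logits, input_ids):
    """Mean of -log softmax(logits[t])[input_ids[t + 1]] over t = 0 .. T-2.

    The last position has no next token, so its logits are ignored.
    """
    total = 0.0
    for t in range(len(input_ids) - 1):
        row = logits[t]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[input_ids[t + 1]]  # negative log-likelihood of the next token
    return total / (len(input_ids) - 1)


input_ids = [0, 1, 2]  # a 3-token sequence over a 3-word vocabulary
uniform = [[0.0, 0.0, 0.0]] * 3
confident = [[0.0, 10.0, 0.0], [0.0, 0.0, 10.0], [0.0, 0.0, 0.0]]
print(shifted_cross_entropy(uniform, input_ids))    # ln 3, about 1.0986
print(shifted_cross_entropy(confident, input_ids))  # close to 0
```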

2. Routing Loss

This loss asks: when the write gate is high, does memory actually change the prediction?

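If the model writes to memory but the output distribution looks identical to the no-memory baseline, that write was pointless. The routing loss ties the reward for writing to how much memory actually changes the prediction:

kl = KL( softmax(p_no_mem) || softmax(p_mem) ).detach()
L_routing = -(gate * kl).mean()

The KL divergence between the memory model and a frozen no-memory baseline is computed per token. Multiplied by the gate and negated, this loss:

  • Rewards high gates when memory changes the prediction (high KL)
  • Gives little or no reward to high gates when memory doesn't matter (low KL, a wasted write)

Its gradient with respect to each gate is proportional to −KL, which is never positive, so this term only ever pushes gates up, most strongly where memory matters. Any pressure to close the gate on low-KL tokens has to come from the other loss terms.

The .detach() on the KL ensures gradients flow only through the gate, not through either set of logits.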
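A plain-Python sketch of the routing loss for one sequence, computing the gradient with respect to each gate by hand instead of with autograd. It assumes `p_no_mem` and `p_mem` in the formula are per-token logits; `routing_loss` and `kl` are hypothetical names.

```python
import math


def softmax(logits):
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]


def kl(p, q):
    """KL(p || q) for two discrete distributions (terms with p_i = 0 contribute 0)."""
    return sum(pi * math.log(pi / qi) for pi, qi in zip(p, q) if pi > 0)


def routing_loss(gates, logits_no_mem, logits_mem):
    """L_routing = -mean(gate_t * KL_t), with KL_t treated as a constant (detached).

    Returns the loss, the per-token KL values, and dL/dgate_t computed by hand.
    """
    n = len(gates)
    kls = [kl(softmax(a), softmax(b)) for a, b in zip(logits_no_mem, logits_mem)]
    loss = -sum(g * k for g, k in zip(gates, kls)) / n
    grads = [-k / n for k in kls]  # never positive, so a descent step never lowers a gate
    return loss, kls, grads


logits_no_mem = [[2.0, 0.0, 0.0], [2.0, 0.0, 0.0]]
logits_mem = [[2.0, 0.0, 0.0], [0.0, 3.0, 0.0]]  # memory changes only token 1
gates = [0.9, 0.9]
loss, kls, grads = routing_loss(gates, logits_no_mem, logits_mem)
print(kls, grads)
```

Token 0 gets a zero gradient and token 1 a negative one, so gradient descent raises only the gate where memory made a difference.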

3. Entropy Loss (Write Sparsity)

A diffuse write weighting — spreading activation uniformly across all N slots — is wasteful. It's like writing one word across every page of your notebook instead of a single page.

The entropy loss encourages sharp, decisive writes:

H = -(w_writes * (w_writes + 1e-8).log()).sum(-1).mean()
L_entropy = H   # minimized during training

Low entropy → sparse write attention → the model commits to specific slots.
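A plain-Python sketch of the entropy term (hypothetical `write_entropy`, unbatched) that compares a diffuse and a sharp write over N = 64 slots, the slot count used in the experiments.

```python
import math


def write_entropy(w_writes):
    """Mean over time steps of H(w) = -sum_i w_i * log(w_i + 1e-8)."""
    per_step = [-sum(w * math.log(w + 1e-8) for w in row) for row in w_writes]
    return sum(per_step) / len(per_step)


n = 64
diffuse = [[1.0 / n] * n]            # one word smeared across every page
sharp = [[1.0] + [0.0] * (n - 1)]    # commit to a single slot

h_diffuse = write_entropy(diffuse)
h_sharp = write_entropy(sharp)
print(round(h_diffuse, 4), round(h_sharp, 4))
```

For N = 64 the entropy ranges from 0 (one slot) to ln 64 ≈ 4.16 (uniform), so with λ_e = 0.05 the entropy term adds at most about 0.21 to the total loss.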

Total Loss

L = L_lm + λ_r · L_routing + λ_e · L_entropy

# defaults: λ_r = 0.1,  λ_e = 0.05

The auxiliary losses are kept small relative to L_lm so language modelling remains the primary objective. The routing and entropy terms act as structural regularizers that shape how the memory is used, not just whether the model gets tokens right.

Code Walkthrough

DNCMemory Module

import torch
import torch.nn as nn


class DNCMemory(nn.Module):
    def __init__(self, mem_slots, mem_width, num_reads, controller_size):
        super().__init__()
        self.N = mem_slots   # number of memory rows
        self.W = mem_width   # width of each row
        self.R = num_reads   # number of read heads

        # All projections from controller hidden state
        self.write_key_proj  = nn.Linear(controller_size, mem_width)
        self.write_vec_proj  = nn.Linear(controller_size, mem_width)
        self.erase_vec_proj  = nn.Linear(controller_size, mem_width)
        self.write_gate_proj = nn.Linear(controller_size, 1)
        self.read_key_proj   = nn.Linear(controller_size, mem_width * num_reads)
        self.temp            = nn.Parameter(torch.ones(1) * 2.0)  # learned sharpness

DNCLLM Forward Pass

The key loop — stepping through time and interleaving memory reads/writes with transformer hidden states:

def forward(self, input_ids, memory, usage):
    # Run all tokens through GPT-2 in parallel (causal masking handles ordering)
    hidden_states = self.transformer(input_ids).last_hidden_state  # (B, T, D)

    all_logits, all_gates, all_ww = [], [], []

    for t in range(input_ids.size(1)):
        h_t = hidden_states[:, t, :]                     # (B, D) — current hidden state

        # Memory interaction for this timestep
        read_vec, memory, usage, write_gate, w_write = self.memory(h_t, memory, usage)

        # Fuse read vector back into hidden state
        h_out = h_t + self.read_proj(read_vec)           # residual addition

        all_logits.append(self.lm_head(h_out))           # project to vocab
        all_gates.append(write_gate)
        all_ww.append(w_write)

    logits      = torch.stack(all_logits, dim=1)         # (B, T, V)
    write_gates = torch.stack(all_gates, dim=1)          # (B, T, 1)
    w_writes    = torch.stack(all_ww, dim=1)             # (B, T, N)
    return logits, memory, usage, write_gates, w_writes

Why the sequential loop? Memory has a causal dependency — memory[t] depends on what was written at steps 0..t-1. This can't be parallelized like self-attention. It's the main compute overhead of DNC over a pure transformer.

Memory Initialization

Memory is initialized to zeros at the start of each sequence:

def init_memory(self, batch_size, device):
    memory = torch.zeros(batch_size, self.cfg.mem_slots, self.cfg.mem_width, device=device)
    usage  = torch.zeros(batch_size, self.cfg.mem_slots, device=device)
    return memory, usage

This means memory is per-sequence, not persistent across batch items or between training steps. It acts as within-sequence working memory, not a cross-sequence knowledge base.

Metrics to Watch

During training, several metrics beyond loss reveal whether the memory system is working correctly:

Metric            What it tells you
avg_gate          Mean write-gate activation. Should settle between 0.2 and 0.7; too high = writing everything, too low = never writing
gate_std          Gate polarization. A high std means the model discriminates: it writes on some tokens and skips others
write_rate        Fraction of timesteps with gate > 0.7. Tracks how aggressively the model uses memory
write_sparsity    How concentrated the write weighting is. High sparsity = sharp slot selection
mem_kl            KL divergence between memory and no-memory predictions. Non-zero means memory is changing outputs

A healthy DNC should show high gate_std (selective writing) and high write_sparsity (concentrated writes), with non-trivial mem_kl (memory actually matters).
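The table does not spell out how each metric is computed, so the definitions below are assumptions. This plain-Python sketch evaluates the gate metrics on the per-token gates from the inspect_writes() listing above, and defines a hypothetical write_sparsity as one minus the normalised entropy of each write distribution (1 = a single slot, 0 = uniform). mem_kl would be the mean of the per-token KL values used by the routing loss.

```python
import math
import statistics


def gate_metrics(gates, threshold=0.7):
    """avg_gate, gate_std and write_rate for one sequence of per-token gates."""
    return {
        "avg_gate": statistics.fmean(gates),
        "gate_std": statistics.pstdev(gates),  # population std over tokens
        "write_rate": sum(g > threshold for g in gates) / len(gates),
    }


def write_sparsity(w_writes):
    """Hypothetical definition: 1 - mean normalised entropy (1 = one slot, 0 = uniform)."""
    n = len(w_writes[0])
    entropies = [-sum(w * math.log(w) for w in row if w > 0) for row in w_writes]
    return 1.0 - statistics.fmean(entropies) / math.log(n)


# Per-token gates from the inspect_writes() listing.
gates = [0.821, 0.904, 0.112, 0.287, 0.094, 0.756, 0.071,
         0.683, 0.143, 0.201, 0.058, 0.388, 0.062, 0.712]
metrics = gate_metrics(gates)

# Toy write weightings over 4 slots: one fully sharp step, one fully diffuse step.
sparsity = write_sparsity([[1.0, 0.0, 0.0, 0.0], [0.25, 0.25, 0.25, 0.25]])
print(metrics, sparsity)
```

On these gates avg_gate comes to about 0.38, inside the suggested 0.2 to 0.7 band, and write_rate to 4/14 (about 0.29).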

Practical Configuration

The config used in the experiments:

class Config:
    # GPT-2 backbone
    hidden_size = 768
    num_layers  = 6
    num_heads   = 8
    seq_len     = 128

    # DNC memory
    mem_slots   = 64     # N: number of memory slots
    mem_width   = 128    # W: width of each slot
    num_reads   = 4      # R: number of read heads

    # Loss weights
    lambda_routing = 0.1
    lambda_entropy = 0.05

    # Training
    batch_size  = 4
    lr          = 3e-4
    grad_clip   = 1.0

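Memory footprint: The external memory adds N × W = 64 × 128 = 8,192 floats (32 KB in fp32) per batch item, negligible compared to the model weights. During training, autograd also keeps intermediate memory tensors from every time step for the backward pass, so activation memory grows with at least T × N × W floats per sequence (128 × 8,192 ≈ 1M at seq_len = 128), which is still modest. The main overhead is the sequential forward loop, not storage.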

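Parameter count: DNC adds the five projections in DNCMemory plus read_proj. Because read_key_proj produces one key per read head, it is R times wider than the other W-wide projections, so the total is roughly D·W·(3 + R) + R·W·D. At D=768, W=128, R=4 that is about 1.08M parameters: roughly 2.5% of the six transformer blocks (about 42.5M) and roughly 1% of the full model once the token embeddings (about 38.6M, assuming GPT-2's 50,257-token vocabulary) are counted.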
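The parameter count can be checked directly from the layer shapes in the DNCMemory constructor. A plain-Python sketch, assuming every projection keeps PyTorch's default bias and that read_proj (whose definition the article does not show) is a Linear(R·W, D) layer; `linear_params` is a hypothetical helper.

```python
def linear_params(fan_in, fan_out, bias=True):
    """Parameters of nn.Linear(fan_in, fan_out): weight matrix plus optional bias."""
    return fan_in * fan_out + (fan_out if bias else 0)


D, N, W, R = 768, 64, 128, 4  # hidden size, memory slots, slot width, read heads

memory_module = (
    linear_params(D, W)        # write_key_proj
    + linear_params(D, W)      # write_vec_proj
    + linear_params(D, W)      # erase_vec_proj
    + linear_params(D, 1)      # write_gate_proj
    + linear_params(D, W * R)  # read_key_proj: one W-wide key per read head
    + 1                        # temp, the learned temperature
)
read_proj = linear_params(R * W, D)  # assumed nn.Linear(R*W, D) in DNCLLM
total_extra = memory_module + read_proj

floats_per_item = N * W        # external memory per batch item
bytes_fp32 = floats_per_item * 4

print(memory_module, read_proj, total_extra)  # 689794 393984 1083778
print(floats_per_item, bytes_fp32)            # 8192 32768
```

read_key_proj and read_proj together account for most of the total, because both scale with the number of read heads.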

Limitations and What's Next

This architecture is a proof of concept. Several known limitations:

Sequential bottleneck: The time-step loop cannot be parallelized. For long sequences, this significantly slows training relative to the pure-transformer baseline.

No cross-sequence persistence: Memory resets between sequences. A truly useful factual memory would persist across the lifetime of the model — closer to a retrieval-augmented generation (RAG) system.

Gradient flow through time: Backpropagating through T sequential memory steps can cause vanishing or exploding gradients for long sequences. Gradient clipping (grad_clip = 1.0) bounds exploding gradients but does nothing for vanishing ones.

Potential extensions:

  • Persistent memory: Keep a global memory matrix that accumulates knowledge across a training corpus and is frozen at inference time (like a learned knowledge base)
  • Sparse attention writes: Replace soft write weighting with a top-k hard selection to reduce memory write diffusion
  • Layer-wise memory: Attach a memory module to each transformer layer, not just the final hidden state
  • Memory-augmented RAG: Use DNC writes as an online summary buffer, and retrieve from it alongside a static vector DB
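One possible reading of the sparse-write extension, sketched in plain Python: keep the k largest entries of a soft write weighting, zero the rest, and renormalise. `topk_write_weighting` is a hypothetical helper, not part of the repository.

```python
def topk_write_weighting(w_write, k):
    """Keep the k largest write weights, zero the rest, and renormalise to sum to 1."""
    keep = set(sorted(range(len(w_write)), key=lambda i: w_write[i], reverse=True)[:k])
    mass = sum(w_write[i] for i in keep)
    return [w_write[i] / mass if i in keep else 0.0 for i in range(len(w_write))]


soft = [0.05, 0.40, 0.10, 0.30, 0.15]
sparse = topk_write_weighting(soft, k=2)
print([round(x, 3) for x in sparse])
```

The selection mask itself has no gradient, but the kept weights are still differentiable functions of the soft weighting, so gradients reach the controller only through the k chosen slots.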

Summary

                     GPT-2 Baseline       GPT-2 + DNC
Factual recall       Parametric only      Parametric + external memory
Memory type          Weights (static)     N×W matrix (dynamic, per-sequence)
Write mechanism      None                 Content + allocation addressing
Selective writing    No                   Yes (learned write gate)
Extra parameters     None                 ~1.1M (roughly 1%)
Training overhead    None                 Sequential loop over T steps

The DNC doesn't replace the transformer's parametric knowledge — it supplements it. The model learns when to trust its weights and when to externalise a fact to the notepad. On a small model operating in a domain with many precise facts, that notepad can make all the difference.

The write gate is the centrepiece of the design. When it fires on "Einstein" and "1879" and stays quiet on "was" and "the", you know the model has learned something non-trivial: not all tokens are worth remembering.

GitHub code: https://github.com/AsishKumarDalal/memoryllm

Implementation: PyTorch. Dataset: WikiText-2. Backbone: GPT-2 (6 layers, 768 hidden, 8 heads). DNC config: N=64, W=128, R=4 read heads. Loss: L_lm + 0.1·L_routing + 0.05·L_entropy.