小型言語モデルに記憶を教える:LLMへ微分可能なノートブック(Differentiable Neural Computers)を与える

Dev.to / 2026/4/25

💬 オピニオンDeveloper Stack & InfrastructureIdeas & Deep AnalysisModels & Research

要点

  • 大規模言語モデルは内部の膨大なパラメータが圧縮された知識ストレージとして機能するため、事実を比較的確実に想起できますが、小型言語モデルは容量が限られるため事実を忘れたり幻覚を起こしたりしがちです。
  • その限界に対する解決として、人間がノートやメモを使って必要なときに参照するのと同様に、小型モデルへ明示的な外部メモリ(ノート)を与えて検索できるようにする提案が示されています。
  • Differentiable Neural Computer(DNC)は、学習可能で微分可能なメモリ行列を提供し、ニューラルコントローラがソフトな注意機構ベースのリード/ライトでそこから読み書きできるようにします。
  • 記事で説明されるシステムは、小型のGPT-2コントローラとDNCメモリを組み合わせ、情報の保存・検索にコンテンツベースのアドレッシングと使用状況に基づく割り当てを用います。
  • 読み書きが微分可能であるため、全体を誤差逆伝播でエンドツーエンド学習でき、小型モデルでも「ノートのような」事実想起を実現しやすくなります。

“大規模モデルはその重みの中に世界を記憶する。小規模モデルはメモ帳が必要だ。”

問題:小規模モデルは事実を忘れる

GPT-4のような大規模言語モデル(LLM)は、事実の想起が非常に得意です——たとえば「デリーはインドの首都である」「アインシュタインは相対性理論を発展させた」——なぜなら、それらは数十億のパラメータをもち、大規模で圧縮された知識ストアとして機能しているからです。モデルは事前学習の間に事実を重みに焼き込み、検索は順伝播(forward pass)の中で暗黙的に行われます。

しかし、モデルを縮小したらどうなるのでしょうか?

小規模言語モデル(SLM) —— 実際にノートパソコンやエッジデバイス上で動かせるようなタイプです —— は、パラメータ数がはるかに少ないです。事実と事実の関連付けを確実に符号化するのに十分な容量が単純にありません。文法、文体、短距離の推論は比較的うまく扱えますが、事実に関する質問をすると、幻覚を見せたり、言いよどんだり、何も答えなくなったりします。

パラメトリック・メモリ(parametric memory)のパラダイムは、小さなスケールでは破綻します。

着眼点

人間は、知識のすべてをニューロンだけに保存しているわけではありません。私たちは外部メモリ—— ノート、カレンダー、本、付箋など —— を使います。知識となる事実を環境に委ね、必要になったときに参照します。ニューラルな仕組みは推論を担当し、メモ帳が検索(retrieval)を担当します。

では、小規模言語モデルに明示的で学習可能なメモ帳を与えたらどうでしょうか?

それをまさに実現するのが微分可能ニューラル計算機(Differentiable Neural Computer: DNC)です。

DNCとは何か?

DeepMindが2016年に提案した微分可能ニューラル計算機は、ニューラルネットワークのコントローラに外部メモリ行列を追加します。これは、ネットワークが学習した注意機構(attention mechanisms)を通じて読み書きできる、構造化され微分可能な保存領域です。

ニューラルネットワークにとってのRAMだと思ってください。

メモリ行列  M  ∈  ℝ^(N × W)
                   │
          N = メモリスロット数(行)
          W = 各スロットの幅(列)

コントローラ(ここでは小規模なGPT-2)は、このメモリとソフトで微分可能な読み書きヘッドを介して相互作用します——したがって、バックプロパゲーションによりシステム全体をエンドツーエンドで学習可能です。

ハッシュマップやデータベースとは異なり、DNCは厳密なキーでメモリを引き当てません。コンテンツベースのアドレッシング—— クエリキーと保存されたベクトル同士の余弦類似度(cosine similarity)—— を使用し、さらに情報を書き込む場所を決めるために使用量に基づく割り当てとブレンドします。

アーキテクチャ:GPT-2 + DNCメモリ


このモデル全体は2つのコンポーネントを重ねて持ちます:

                    ┌─────────────────────────────┐
                    │       GPT-2 バックボーン    │
                    │  (マスク付き自己注意 +   │
                    │   フィードフォワード層)  │
                    └──────────────┬──────────────┘
                                   │  隠れ状態 h_t  (B, D)
                                   ▼
                    ┌─────────────────────────────┐
                    │        DNC メモリ           │
                    │  ┌─────────────────────┐    │
                    │  │  M ∈ ℝ^(N × W)      │    │  ← 外部RAM
                    │  └─────────────────────┘    │
                    │   write → read → update     │
                    └──────────────┬──────────────┘
                                   │  read_vec  (B, R*W)
                                   ▼
                         read_proj → h_t + read_vec
                                   │
                                   ▼
                              LM Head → logits

各時刻ステップtで、GPT-2の隠れ状態h_tは次のために使われます:

  1. 新しい情報をメモリに書き込む
  2. 関連する情報を読み出す
  3. 読み出したベクトルをh_t融合し、語彙(vocabulary)のlogitsに射影する前に組み合わせる

メモリは、系列内の各時刻ステップにまたがって保持されます。これによりworking memoryの一種になります——ステップ3で書き込まれた情報は、ステップ47で取り出せます。

メモリモジュール:読み取り・書き込みの仕組み

メモリ状態

任意のステップにおけるメモリは行列M ∈ ℝ^(B × N × W)です——つまり、N個のスロットからなるバッチで、それぞれがW次元のベクトルです。使用量ベクトルu ∈ ℝ^(B × N)は、各スロットがどれくらい書き込まれてきたかを追跡します。

隠れ状態からの射影

コントローラの隠れ状態h_t ∈ ℝ^(B × D)が与えられると、メモリモジュールは次を計算します:

射影 形状 目的
write_key (B, W) どこに 書き込むか(コンテンツアドレッシング)
write_vec (B, W) 何を 書き込むか
erase_vec (B, W) 書き込む前に消すもの(sigmoidでゲート)
write_gate (B, 1) どれくらい 書き込むか(0 = スキップ、1 = 完全に書き込み)
read_keys (B, R, W) どこから読み出すか(R個の読み取りヘッド)

書き込みの重み付け

書き込みアドレスw_write ∈ ℝ^(B × N)は、スロットに対するソフトな注意分布です:

w_content = softmax( cosine(write_key, M) × τ )
w_alloc   = softmax( (1 − u) × τ )

w_write   = 0.5 × w_content + 0.5 × w_alloc
  • コンテンツアドレッシングw_content):現在の書き込みキーに内容が似ているスロットの近くに書き込む——既存の事実を更新するのに有用です。
  • 割り当てw_alloc):使用量が少ないスロットを優先する——古い情報を上書きせずに新しい事実を保存するのに有用です。

τは、分布を尖らせたり(sharp)なだらかにしたり(soft)する学習可能な温度パラメータです。

書き込み操作

M_new = M × (1 − w_write ⊗ erase_vec) + w_write ⊗ write_vec

M_out = M + write_gate × (M_new − M)

write_gateが主要な調整ノブです:

write_gate ≈ 0  →  メモリは不変  (モデルはパラメトリックな知識に依存)
write_gate ≈ 1  →  完全な書き込み  (モデルが知識を外部化する)

このゲートはデータのみから学習されます。モデルはいつ書き込む価値があるのかを見つけ出します。

読み取り操作

w_read  = softmax( read_keys · M^T × τ )   ∈ ℝ^(B × R × N)
read_vec = w_read · M                       ∈ ℝ^(B × R × W)
         → reshape して (B, R*W)
         → read_proj を通して (B, D) に射影し直す

R 読み取りヘッドにより、モデルはメモリから R の異なる「トピック」を同時に照会できます。

状態更新

書き込みのたびに使用量が更新されるため、アロケータはどのスロットが「満杯」かを追跡します:

usage_new = usage + (1 - usage) * w_write.detach()

.detach() は、使用量(usage)信号を通じて勾配が逆流するのを防ぎます。これは学習される変数ではなく、帳簿管理用の変数です。

書き込みゲート:いつ覚えるべきかを知る

書き込みゲートは、システム全体の中で最も解釈しやすいコンポーネントです。学習後は inspect_writes() を実行して、トークンごとのゲート活性を可視化できます:

トークン                 ゲート   bar
────────────────────────────────────────────────
Albert                 0.821  ████████████████████████
Einstein               0.904  ███████████████████████████
was                    0.112  ███
born                   0.287  ████████
in                     0.094  ██
1879                   0.756  ██████████████████████
in                     0.071  ██
Ulm                    0.683  ████████████████████
He                     0.143  ████
developed              0.201  ██████
the                    0.058  █
theory                 0.388  ███████████
of                     0.062  █
relativity             0.712  █████████████████████
aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa

モデルは、内容を持つトークン(固有名詞、日付、重要な概念)では書き込むことを学び、機能語はスキップします。誰もそのように教えたわけではありません。これは損失関数の中から自然に現れたものです。

損失関数

学習では、合計3つの損失を使用します:

1. 言語モデリング損失(交差エントロピー)

標準的な次トークン予測の損失です:

L_lm = CrossEntropy(logits[:, :-1], input_ids[:, 1:])

これは主要な損失です。モデルは次のトークンを正しく予測し続ける必要があります。

2. ルーティング損失

この損失はこう問いかけます:書き込みゲートが高いとき、メモリは実際に予測を変えるのか?

もしモデルがメモリに書き込んでも、出力分布がメモリなしのベースラインと同じに見えるなら、その書き込みは無意味です。ルーティング損失はこれを罰します:

kl = KL(softmax(p_no_mem) || softmax(p_mem) ).detach()
L_routing = -(gate * kl).mean()

メモリモデルと固定されたメモリなしベースラインの間の KL ダイバージェンスをトークンごとに計算します。ゲートで掛け、符号を反転させることで、この損失は:

  • 高いゲート値に報酬:メモリが予測を変える(KLが高い)とき
  • 高いゲート値を罰則:メモリが重要でない(KLが低い → 無駄な書き込み)とき

KL に対する .detach() により、勾配は no-memory ロジットではなく、ゲートのみを通って流れます。

3. エントロピー損失(書き込みの疎性)

広く拡散した書き込み重み—すべての N スロットに活性を一様に広げる—は無駄です。これは、ノートのすべてのページに1語を書き込むのに等しく、1ページに書くだけではありません。

エントロピー損失は、鋭く決断的な書き込みを促します:

H = -(w_writes * (w_writes + 1e-8).log()).sum(-1).mean()
L_entropy = H   # 学習中に最小化される

エントロピーが低いほど—疎な書き込み注意—モデルは特定のスロットにコミットします。

合計損失

L = L_lm + λ_r · L_routing + λ_e · L_entropy

# デフォルト:λ_r = 0.1、 λ_e = 0.05

補助的な損失は L_lm に比べて小さく保たれるため、言語モデリングが主要な目的として維持されます。ルーティング項とエントロピー項は、モデルがトークンを当てるかどうかだけでなく、メモリをどのように使うかを形作る構造的正則化として働きます。

コード解説

DNCMemory モジュール

class DNCMemory(nn.Module):
    def __init__(self, mem_slots, mem_width, num_reads, controller_size):
        super().__init__()
        self.N = mem_slots   # メモリ行数
        self.W = mem_width   # 各行の幅
        self.R = num_reads   # 読み取りヘッド数

        # コントローラの隠れ状態からの全ての射影
        self.write_key_proj  = nn.Linear(controller_size, mem_width)
        self.write_vec_proj  = nn.Linear(controller_size, mem_width)
        self.erase_vec_proj  = nn.Linear(controller_size, mem_width)
        self.write_gate_proj = nn.Linear(controller_size, 1)
        self.read_key_proj   = nn.Linear(controller_size, mem_width * num_reads)
        self.temp            = nn.Parameter(torch.ones(1) * 2.0)  # 学習されたシャープネス

DNCLLM フォワードパス

キーとなるループ——時間方向に進めながら、メモリの読み書きをトランスフォーマーの隠れ状態と交互に行う:

def forward(self, input_ids, memory, usage):
    # すべてのトークンを GPT-2 に並列で通す(因果マスキングが順序を処理する)
    hidden_states = self.transformer(input_ids).last_hidden_state  # (B, T, D)
    all_logits, all_gates, all_ww = [], [], []

    for t in range(input_ids.size(1)):
        h_t = hidden_states[:, t, :]                     # (B, D) — 現在の隠れ状態
        # このタイムステップにおけるメモリの相互作用
        read_vec, memory, usage, write_gate, w_write = self.memory(h_t, memory, usage)

        # 読み出しベクトルを隠れ状態へ合成する
        h_out = h_t + self.read_proj(read_vec)           # 残差加算
        all_logits.append(self.lm_head(h_out))           # 語彙へ投影
        all_gates.append(write_gate)
        all_ww.append(w_write)

    logits      = torch.stack(all_logits, dim=1)         # (B, T, V)
    write_gates = torch.stack(all_gates, dim=1)          # (B, T, 1)
    w_writes    = torch.stack(all_ww, dim=1)             # (B, T, N)
    return logits, memory, usage, write_gates, w_writes

なぜ逐次ループなのか? メモリには因果的な依存関係があります——memory[t] はステップ 0..t-1 で書き込まれた内容に依存します。自己注意のようには並列化できません。純粋なトランスフォーマーに対する DNC の主要な計算オーバーヘッドです。

メモリの初期化

メモリは各シーケンスの開始時にゼロで初期化されます:

def init_memory(self, batch_size, device):
    memory = torch.zeros(batch_size, self.cfg.mem_slots, self.cfg.mem_width, device=device)
    usage  = torch.zeros(batch_size, self.cfg.mem_slots, device=device)
    return memory, usage

これは、メモリがシーケンスごとであることを意味します。バッチ内の各アイテム間、または学習ステップ間で永続化されるわけではありません。これは、シーケンス内の作業用メモリであって、シーケンスをまたいだ知識ベースではありません。

Metrics to Watch

学習中、損失以外にもいくつかの指標によって、メモリシステムが正しく機能しているかどうかを判断できます:

Metric What It Tells You
avg_gate 平均の書き込みゲート活性。0.2〜0.7の間に落ち着くはずです。高すぎる = 全てを書き込んでいる、低すぎる = 一度も書き込まない
gate_std ゲートの分極。標準偏差が大きいほど、モデルは識別できていることを意味します — あるトークンでは書き込み、他はスキップします
write_rate ゲートが > 0.7 となるタイムステップの割合。モデルがメモリをどれだけ積極的に使っているかを追跡します
write_sparsity 書き込み重みがどれだけ集中しているか。スパース性が高い = スロット選択が鋭い
mem_kl メモリありの予測と、メモリなしの予測の間のKLダイバージェンス。0でない場合、メモリが出力を変えていることを意味します

健全なDNCは高いgate_std(選択的な書き込み)と高いwrite_sparsity(書き込みが集中している)を示し、さらに非自明なmem_kl(実際にメモリが効いている)を持つはずです。

Practical Configuration

実験で使用したconfig:

class Config:
    # GPT-2 のバックボーン
    hidden_size = 768
    num_layers  = 6
    num_heads   = 8
    seq_len     = 128

    # DNC のメモリ
    mem_slots   = 64     # N: メモリスロット数
    mem_width   = 128    # W: 各スロットの幅
    num_reads   = 4      # R: 読み取りヘッド数

    # 損失の重み
    lambda_routing = 0.1
    lambda_entropy = 0.05

    # 学習
    batch_size  = 4
    lr          = 3e-4
    grad_clip   = 1.0

メモリのフットプリント: 外部メモリは、バッチアイテムあたり N × W = 64 × 128 = 8,192 個の浮動小数点数を追加します — モデルの重みそのものと比べればごくわずかです。オーバーヘッドはストレージではなく、逐次フォワードループにあります。

パラメータ数: DNCは5つの射影行列からおおよそ 5 × (D × W) 個のパラメータを追加します。D=768, W=128 のとき、これは約490Kパラメータで、6層のGPT-2に対して約0.5%のオーバーヘッドです。

Limitations and What's Next

このアーキテクチャは概念実証です。いくつかの既知の制限があります:

逐次ボトルネック: タイムステップのループは並列化できません。長いシーケンスでは、純粋なトランスフォーマのベースラインに比べて学習が大幅に遅くなります。

シーケンス間の永続性なし: メモリはシーケンス間でリセットされます。真に有用な事実のメモリは、モデルのライフタイムを通して永続化されるべきで、検索拡張生成(RAG)システムに近い形になります。

時間を通じた勾配伝播: T 個の逐次メモリステップを通して誤差逆伝播(backprop)すると、長いシーケンスでは勾配消失/爆発が起き得ます。勾配クリッピング(grad_clip = 1.0)は役立ちますが、問題を完全には解決しません。

潜在的な拡張:

  • 永続メモリ: 学習コーパス全体で知識を蓄積し、推論時に凍結するグローバルなメモリ行列を保持する(学習済みの知識ベースのように)
  • スパースな注意による書き込み: ソフトな書き込み重み付けを top-k のハード選択に置き換えて、メモリ書き込みの拡散を減らす
  • 層ごとのメモリ: メモリモジュールを最終の隠れ状態だけでなく、各トランスフォーマ層に取り付ける
  • メモリ拡張 RAG: DNCの書き込みをオンラインの要約バッファとして使い、静的なベクトルDBと並行してそこから検索する

Summary

GPT-2 ベースライン GPT-2 + DNC
事実の想起 パラメトリックのみ パラメトリック + 外部メモリ
メモリの種類 重み(静的) N×W 行列(動的、シーケンスごと)
書き込み機構 なし コンテンツ + 配分(アロケーション)アドレッシング
選択的な書き込み いいえ はい(学習された書き込みゲート)
追加パラメータ 約490K(約0.5%)
学習オーバーヘッド T ステップにわたる逐次ループ

DNCはトランスフォーマのパラメトリックな知識を置き換えるのではなく、それを補完するものです。このモデルは、重みをどのタイミングで信頼し、事実をメモ帳(notepad)に外部化するべきかを学習します。多くの正確な事実がある領域で動作する小さなモデルでは、このメモ帳がすべての違いを生み得ます。

書き込みゲートは設計の中心です。「Einstein」や「1879」ではゲートが発火し、「was」や「the」では沈黙しているなら、モデルが非自明な何かを学習したことが分かります: すべてのトークンが記憶する価値を持つわけではない

Github コード: https://github.com/AsishKumarDalal/memoryllm

実装: PyTorch。データセット: WikiText-2。バックボーン: GPT-2(6層、768 hidden、8ヘッド)。DNC設定: N=64、W=128、R=4つの読み取りヘッド。損失: L_lm + 0.1·L_routing + 0.05·L_entropy。