AI Navigate

モデルの不確実性空間が平坦か曲率を持つかを測定する方法

Dev.to / 2026/3/15

💬 オピニオンIdeas & Deep AnalysisTools & Practical UsageModels & Research

要点

  • 記事は、モデルの不確実性空間が平坦であるという仮定が曲率を持つ可能性を含んでおり、曲率がOOD検知、敵対的堅牢性、およびAI安全性に影響を与えると主張します。
  • ATICアーキテクチャの下で実装された独立した軸と対角距離メトリックを備えたAletheionLLM-v2の5次元エピステミック多様体のベースラインを紹介し、実行可能なPythonコードを添付します。
  • 曲率は単純な相関とは異なることを明確にし、軸間の関係を捉えるために対角メトリックから完全なマハラノビス距離計へ移行することを検討します。
  • リンク付きのGitHubリポジトリ経由で提供される曲率のテスト方法に関する実践的なガイダンスとコードを提示し、WikiText-103に対する較正メトリクスとOOD性能を、GPT-2 MediumやOPT-350Mと比較して競争力があることを報告します。

言語モデルにおけるリーマン的認識幾何学の実用ガイド(コード付き)。

ほとんどのキャリブレーション研究は不確実性をスカラーまたはベクトルとして扱います。信頼度スコアを計算し、それをグラウンドトゥルースと比較し、ECEを最小化します。その不確実性が存在する空間は平坦であると想定されます。

その仮定は間違っている可能性があります。そしてもしそれが間違っていれば、OOD検出、敵対的頑健性、AIの安全性に具体的な影響を及ぼします。

この投稿では、それをテストする方法を、私の現在の研究でのコードを用いて説明します。AletheionLLM-v2

ベースライン: 5次元の認識多様体における対角距離

AletheionLLM-v2 は、ATICと呼ばれる統合的な認識アーキテクチャを備えた、354Mパラメータのデコーダー専用LLMです。単一の信頼度スコアを出す代わりに、各軸が不確実性のそれぞれの成分を表す5次元の多様体をモデルが維持します。これは BayesianTau によって学習されます。

現在の距離指標(ブランチ main)は対角的です:

def distance_diagonal(x1, x2, tau_sq):
    diff = x1 - x2
    tau_sq_safe = np.maximum(tau_sq, 1e-8)
    return np.sqrt(np.sum(diff**2 / tau_sq_safe))

各軸にはそれぞれ学習された分散があります。軸は独立しています。空間は R5、再スケーリングされています。

これですでに良好に機能します。ECE 0.0176、Brier Score 0.1528、WikiText-103のOODにおける最上位で、GPT-2 MediumおよびOPT-350Mよりもエピステミックキャリブレーションで優れています。

しかし、対角線だけでは答えられない疑問があります。エピステミック空間には曲率があるのか?

曲率が相関とは異なる問いである理由

先に進む前に、1つの区別が重要です。

完全なマハラノビス距離では、G が定数の5x5行列として学習される場合、相関を捉えるが曲率は生じません。

G が定数の場合、Christoffel記号はすべてゼロです:

Gamma^k_ij = (1/2) g^kl (d_i g_jl + d_j g_il - d_l g_ij) = 0

ゼロの Christoffel 記号はゼロのリーマン曲率を意味します。空間は依然として平坦で、斜交座標のままです。測地線は依然として直線です。

真の曲率には、G は位置に応じて変化する必要があります。G(x) は定数行列ではなく、テンソル場でなければなりません。

ブランチ real_geodesic: 計量を場として作る

In the real_geodesic branch, a lightweight network (5 -> 32 -> 15, roughly 700 parameters) produces a position-dependent SPD tensor at every point in the manifold:

class MetricNet(nn.Module):
    def __init__(self, dim=5, hidden_dim=32):
        super().__init__()
        self.dim = dim
        self.n_chol = dim * (dim + 1) // 2  # 15 for dim=5
        self.net = nn.Sequential(
            nn.Linear(dim, hidden_dim),
            nn.Tanh(),  # Tanh, not ReLU -- G(x) must be smooth (C1)
            nn.Linear(hidden_dim, self.n_chol),
        )

        # Zero init on last layer -> G(x) ~ I at start
        nn.init.zeros_(self.net[-1].weight)
        nn.init.zeros_(self.net[-1].bias)

(このセクションの残りはコードブロックとして続く)

# 下三角形構築の事前計算インデックス tril_idx = torch.tril_indices(dim, dim) self.register_buffer("tril_row", tril_idx[0]) self.register_buffer("tril_col", tril_idx[1]) self.register_buffer("diag_idx", torch.arange(dim)) def forward(self, coords): """coords: [..., 5] -> G: [..., 5, 5] SPD""" raw = self.net(coords) # [..., 15] batch_shape = raw.shape[:-1] L = torch.zeros(*batch_shape, self.dim, self.dim, device=raw.device, dtype=raw.dtype) L[... , self.tril_row, self.tril_col] = raw # 正の対角は softplus + オフセット(exp ではなく、より安定) L[... , self.diag_idx, self.diag_idx] = ( F.softplus(L[... self.diag_idx, self.diag_idx]) + 1e-3 ) return torch.matmul(L, L.transpose(-1, -2)) # SPD が保証されます

設計の主なポイント:

  • tanh活性化 ReLU の代わり。G(x) は計量場であり、滑らかでなければなりません。ReLU は微分不可能な点を作り、Christoffel 記号を未定義にします。
  • softplus + 1e-3 on diagonal exp の代わり。訓練中に数値的に安定し、勾配の爆発を避けます。
  • 最後の層のゼロ初期化。初期化時、ネットワークはすべての入力に対してゼロを出力するため、G(x) は全体でおよそ 0.48 * I の状態から始まります。訓練は安定して開始します。

2つのエピステミック状態間の距離は Gauss-Legendre 求積法を用いた線積分として計算されます:

def line_integral_distance(self, p, q):
    """p: [B, T, 5], q: [5] -> 距離: [B, T, 1]"""
    if q.dim() == 1:
        q = q.unsqueeze(0).unsqueeze(0).expand_as(p)

    delta = q - p
    total = torch.zeros(p.shape[0], p.shape[1], 1,
                         device=p.device, dtype=p.dtype)
for i in range(self.n_quad): t = self.gl_points[i] w = self.gl_weights[i] x_t = + t * delta # point along straight line G_t = .forward(x_t) # G(x) at that point Gd = matmul(delta.unsqueeze(-2), G_t).squeeze(-2) integrand = (Gd * delta).sum(dim=-1, keepdim=True) total = total + w * sqrt(integrand.clamp(min=1e-8)) return total

One clarification worth being explicit about: this computes the length of the straight line between p and q under the varying metric, not the true geodesic (which would minimize path length and would be shorter). The true geodesic requires a shooting method or ODE solver. The straight-line approximation is differentiable, cheap (5 evaluations of MetricNet per distance), and sufficient to detect whether G(x) varies along the path -- which is the primary question.

When G depends on position, the Christoffel symbols are no longer zero. Geodesics are curves. The space has intrinsic curvature.

The experiment: three branches, one falsifiable question

Branch Metric Geometry
main G = diag(tau) Flat, orthogonal axes
full_mahalanobis G = constant 5x5 Flat, oblique axes
real_geodesic G(x) = learned field Potentially curved

The test uses three categories of input pairs:

probes = {
    \"high_confidence\": [
        (\"The capital of France is\", \"Paris\"),
        (\"2 + 2 =\", \"4\"),
    ],
    \"low_confidence\": [
        (\"The exact number of neurons in the human brain is\", \"86\"),
    ],
    \"context_sensitive\": [
        (\"The bank was steep and\", \"muddy\"),    # bank = riverbank
        (\"The bank was closed and\", \"dark\"),    # bank = institution
        (\"He left the plant near\", \"water\"),    # plant = vegetation
        (\"He left the plant near\", \"the door\"), # plant = factory
    ]
}

The context-sensitive pairs are the key. Same surface token, different semantic region of the manifold. If G(x) learned real structure, the geodesic distance between \"bank=riverbank\" and \"bank=institution\" will be larger than the distance between two within-domain contexts, even though the diagonal distance would treat them similarly.

Detecting curvature directly: metric variation along a path

def measure_metric_variation(metric_net, x_start, x_end, n_samples=20):
    G_samples = []

"}for t in np.linspace(0, 1, n_samples):
        x_t = x_start + t * (x_end - x_start)
        x_tensor = torch.tensor(x_t, dtype=torch.float32).to(device)
        G_t = metric_net(x_tensor.unsqueeze(0).unsqueeze(0))
        G_samples.append(G_t[0, 0].cpu().numpy())

    G_stack = np.stack(G_samples)
    variation = np.std(G_stack, axis=0)

    print(f"Mean metric variation: {variation.mean():.6f}")
    print(f"Max element variation: {variation.max():.6f}")
    print(f"Verdict: {'CURVED' if variation.max() > 0.01 else 'FLAT'}")

    return variation

もし G が高信頼状態から低信頼状態へと経路に沿って変化するなら、多様体には非自明な局所幾何学が存在します。もし定数へ収束するなら、対角は近似ではなく、根本的な理由で正しかったのです。

各結果の意味

real_geodesic が G(x) を概ね定数として学習する場合:

354M の LLM の認識的多様体は本質的に平坦です。対角計量は怠惰な近似ではありませんでした。幾何学的に正しかったのです。ECE 0.0176 は真のキャリブレーションを反映しており、部分空間のアーティファクトではありません。

G(x) が構造的変動を学習する場合:

多様体には異なる幾何学を持つ領域が存在します。対角座標で等距離に見える2つの認識状態は、測地距離が非常に異なることがあります。これは直接的な影響をもたらします:

  • OOD検出は幾何学的信号を得ます。高曲率領域に着地する入力は、赤チーミングで似た入力が現れたかどうかに関係なく構造的に異常です。
  • キャリブレーションの閾値は局所的になり、グローバルにはなりません。平坦な領域は自信を持つべきです。高曲率の領域は慎重さを要し、幾何学が真の値を見ずにどちらかを示します。
  • 訓練コーパスは幾何学的署名を残します。有害な内容で訓練されたモデルは悪意を持つようにはなりません。むしろ、多様体が平坦で十分にサンプルされている場所では、有害な出力が幾何学的に安価になるシステムになります。これは、明示的な有害意図よりも構造的に異なる、より懸念される故障モードです。

トレーニングに関する考慮事項

MetricNet は 354M モデルに約 700 個のパラメータを追加します。これらのパラメータに到達する勾配信号は本質的に弱いです。これに対処する二つの対策があります:

1. 学習率を別々に設定する。 MetricNet はベースの LR の 10 倍を取得します(5e-4 対 5e-5)。これがないと、G(x) は空間が平坦だから収束したのではなく、構造を学習する信号が弱すぎたためアイデンティティに収束してしまう可能性があります。

2. 平滑性正則化。 入力座標の小さな摂動の下での G の変動に対するペナルティです:

def metric_smoothness_loss(metric_net, coords, eps=0.01):
    G = metric_net(coords)
    noise = torch.randn_like(coords) * eps
    G_perturbed = metric_net((coords + noise).clamp(0, 1))
    return (G - G_perturbed.detach()).pow(2).sum(dim=(-2, -1))

もし G が高信頼状態から低信頼状態へと経路に沿って変化するなら、多様体には非自明な局所幾何学が存在します。もし定数へ収束するなら、対角は近似ではなく、根本的な理由で正しかったのです。

各結果の意味

real_geodesic が G(x) を概ね定数として学習する場合:

354M の LLM の認識的多様体は本質的に平坦です。対角計量は怠惰な近似ではありませんでした。幾何学的に正しかったのです。ECE 0.0176 は真のキャリブレーションを反映しており、サブスペースのアーティファクトではありません。

G(x) が構造的変動を学習する場合:

多様体には異なる幾何学を持つ領域が存在します。対角座標で等距離に見える二つの認識状態は、測地距離が非常に異なることがあります。これには直接的な影響があります:

  • OOD検出は幾何学的信号を得ます。高曲率領域に落ちる入力は、レッドチーミングで似た入力が現れたかどうかに関係なく構造的に異常です。
  • キャリブレーション閾値は局所的になり、グローバルにはなりません。平坦な領域は信頼を示します。高曲率の領域は慎重さを示し、幾何学が真値を見ずにどちらかを示します。
  • 訓練コーパスは幾何学的署名を残します。有害な内容で訓練されたモデルが悪意を持つようにはなりません。むしろ、多様体が平坦で十分にサンプルされている場所では、有害な出力が幾何学的に安価になるシステムになります。これらは明示的な有害意図よりも構造的に異なる、より懸念される故障モードです。

この点がないと、G(x) は線積分を数値的に不安定にし、勾配をノイズ状にする不連続性を学習する可能性があります。

数値積分の安定性についてのノート

実装はデフォルトで 5 点の Gauss-Legendre 点を使用し、効率のために事前計算されたノードと重みを用います。Tanh 活性化は高周波数の変動を起こりにくくしますが、収束を検証できます:

def check_quadrature_convergence(metric_net, x1, x2,
                                  n_points_list=[5, 8, 16]):
    for n in n_points_list:
        t_nodes, weights = np.polynomial.legendre.leggauss(n)
        t_nodes = (t_nodes + 1) / 2
        weights = weights / 2

        dx = x2 - x1
        total = 0.0
        for t, w in zip(t_nodes, weights):
            x_t = x1 + t * dx
            x_tensor = torch.tensor(x_t, dtype=torch.float32).to(device)
            G_t = metric_net(x_tensor.unsqueeze(0).unsqueeze(0))
            G_np = G_t[0, 0].cpu().numpy()
            ds2 = dx @ G_np @ dx
            total += w * np.sqrt(max(ds2, 1e-12))

        print(f"  n={n:2d}: distance = {total:.6f}")

距離が 5 点と 16 点の間で安定しない場合、指標には高周波の局所的変動があります。Tanh を使えば、多くの多様体の幾何学に対して 5 点で十分であるはずです。

現状と再現性

公開リポジトリには 3 つのブランチすべてがライブ状態です。ベースライン(ブランチ main)は完全に再現可能です。トレーニングコード、評価スクリプト、完全な方法論を含む論文はすべて公開されています。

3 ブランチの比較の結果は、トレーニングが完了次第、ここおよび ResearchGate に公開されます。

キャリブレーション、OOD 検出、または言語モデルの不確実性に対する幾何学的アプローチに取り組んでいる方とはお話ししたいと考えています。リポジトリは公開されており、方法論は完全に文書化されています。

Felipe Maya Muniz は AletheionAGI の創設者であり、AI システムにおける認識論的自己認識のための幾何学的認知アーキテクチャである ATIC を開発する独立研究者です。