スープ餃子が冷める前にレストランに着ける?(そして機械学習にある他の問題)

Dev.to / 2026/5/5

💬 オピニオンIdeas & Deep Analysis

要点

  • この記事は、遅刻しがちな自分の体験を使って、現実世界が非決定的である状況で信頼できる予測を行う難しさを機械学習の問題として説明しています。
  • 学習(教師あり学習)では、人間のように「経路」を理解しているわけではなく、GPSのような手順ではなく、到達できたかどうかのスカラー損失にもとづいてパラメータを調整する仕組みだと述べています。
  • 重要な対比として、この記事の主眼は学習の後、つまり学習済みモデルが、モデルが完全には予測できない変動する環境の中で動作するときに何が起きるかだとしています。
  • 理想ケースを前提に計画すると現実とのズレが生まれ、MLモデルでも環境の変化によって期待される挙動が崩れるのと似た構図になる点を強調しています。
  • 「経路」という比喩については、そのまま当てはめるのは難しいとし、モデルは道を暗記しているのではなく、典型的な条件のもとで正しい目的地に到達するような写像を学んでいるのだと位置づけています。

私は慢性的に遅刻します。失礼にしたいわけではありません。毎回それは本当に嫌な気持ちになるのに、どれくらいで目的地に着くかの見積もりがあまりにも壊滅的に下手だからです。

ところが機械学習アルゴリズムにも、まったく同じ問題があります。

どうやって起きるかというと:夕食は19時です。レストランがどこにあるかは分かっています。頭の中には、完全に明確なルートがあります。オフィスのドア → 廊下 → エレベーター → 通り → 地下鉄 → 徒歩 → レストラン。とても定義がはっきりしています。ドアtoドアで14分としましょう。

問題は、気が散ることです。何かの課題に深く取り組んで、記事を読んだり、何かをデバッグしたりしています。ふと気づくと18時。次に18時30分。さらに18時45分。そして私は考えます。「まあ14分だし、6時46分に出れば大丈夫だな。」

でも実際のルートは非決定的です。

オフィスの掃除担当がいるから、遠回りしなきゃいけないの? エレベーターが混んでる? 地下鉄が遅れてる? それとも30秒差で乗り遅れた? 雨で、傘を大きくさした人たちが歩道にたくさんいるの?

私が自信を持って「14分の旅」と考えているものが、実際には25分かかるかもしれません。理想のケースで間に合うだけの余裕を見て出ます。だって頭の中ではそう計画しているから。おめでとう、今度は「自分ではどうにもできない事情」のせいで、誰かを10〜15分待たせてしまいました(そして小籠包が固まり始めています)。

あるいは現実的に言えば、まったく自分の管理下にあった事情です。もっと早く出ればよかったんです。ごめんなさい、みんな。

午前3時なら、信じられないくらい速い移動になるかもしれません。プエルトリコの独立記念日のパレード中なら、ずっと遅いでしょう。でも、その瞬間の流れがどうなっているかを正確に理解していない限り、そうしたことは分かりません。

どれだけ自分でも知らないルート

[本題に入る前の短い寄り道:機械学習にも「ルート」の別の問題があって、それを認めておく必要があります。なぜなら、この論文が解こうとしていることが何か(=何ではないか)を理解してほしいからです。]

訓練中、教師あり学習は、人間が「ルート」を知っているような意味では、どこかのルートを「知って」いません。ルート自体が何なのかすら、ほとんど理解していないのです。目標(「犬か猫か」)を与えると、最後にスカラーの損失が返ってきて、パラメータが調整されます。GPSはありません。分岐ごとの手順の指示もありません。ただ、到着したかどうかを示すスコアがあるだけです。

レストランの問題に似ていますが、アルゴリズムは並列で何十億もの異なるルートを試して、うまくいったものを覚えます。最良のルートがクイーンズを通って、そこからまた下に戻ってくるのかどうかなんて、分かりません。単に大規模な探索をしているだけです。

ここでのルートという比喩は少し無理があります。システムは通りを暗記しているわけではありません。通常は正しいレストランに着地させてくれる関数を学習しているのです。これは訓練の問題――解釈可能性の問題です。興味深い。未解決。けれども、この投稿が扱う話ではありません。

この投稿が扱うのは、訓練の「後」で何が起きるかです。訓練済みモデルが手に入ったあと――つまり「ルートを知っている」と思ったあとでも――別のまったく違う問題が残ります。同じ移動時間を2度と得られないという問題です。これを解決したのがThinking Machines Labです。では、いつものプログラミングに戻りましょう。

もっと大きな問題:ルートを知っていても

私のように――何度レストランに行っても、同じルートを何十億回も試そうとはしない人と違って――機械学習のアルゴリズムは、並列でいろいろなことを試せます。人間ができるよりはるかに多くの回数を試すことができます。だからモデルは一般にルートについて良い感覚を持っています。分かってしまっているのです。

しかし本当の非決定性の問題はこうです。あなたは「まったく同じルート」を取っていると思う。つまり同じ廊下、同じエレベーター、同じ地下鉄、同じ通り、同じ曲がり順の列。ただし、実際にはそうではありません。

一人で移動するときは、グループで移動するときとは異なるタイミングを別の地点で計測します。SEQUENCE(順序)は同じ。でもMEASUREMENT POINTS(計測地点)が違う。そしてストップウォッチがラップタイムを丸める仕方のせいで、計測地点が違えば最終的な時間も変わります。旅そのものが揺らぐというより、グループの人数によって、計測のしかたが変わるのです。

大きくはありません。せいぜい1分か2分程度。でも、まったく同一にはなりません。

同じルート。違う時間。毎回。

これがどれほどイライラするか想像できますよね。機械が好きなのは、質問に対して「ちょうど同じ答え」を返すことです。(人間もそうですが、わずかな変化にはより寛容です)。そして、物事を何十億回もやるなら、まさに同じ質問に対して、まさに同じ答えが必要になります。

みんなが最新の話題や「これが革命だ」と言っていることは分かっていますが、Thinking Machines Lab は1か月ほど前に、私が本当に大きな転換点だと感じていることを発表しました。彼らは「Defeating Nondeterminism in LLM Inference」を公開し、単に問題を説明しただけではありません。毎回、まったく同じルートがまったく同じ時間になるようにする方法を見つけたのです。

ビジネスモデルはありませんが、すべての推論エンジンがすぐにそれを取り入れると信じざるを得ません。

誰もが「起きている」と思っていたこと(でも間違っていた)

何年もの間、人々は非決定性を「ランダム」な並列性のせいにしてきました。つまり、結合則を満たさない浮動小数点演算と、GPU上での予測不能なスレッドスケジューリングです。たとえるなら「幽霊渋滞」のようなもの。そう聞こえます。並列で足し算すると順番が変わるので、丸め誤差も変わる。もっともらしいですよね。説明としてはこれで終わり。

どこでも繰り返されている受け入れられた説明は、次のとおりです:

"GPUにおける浮動小数点演算は結合則を満たさないため、有限の精度と丸め誤差のせいで $(a+b)+c eq a+(b+c)$ となる。GPUは多数のスレッドにまたがって演算を並列に実行するため、実行順序は予測できない。このランダムな順序が、毎回異なる丸めパターンを生み、非決定性につながる。"

この説明だと、あなたが毎回旅をするときに、見えない予測不能な交通がランダムな通りを遅くすることになります。数学そのものに、スレッドスケジューリングからくる「ランダムさ」が最初から組み込まれている。それは起こり得そうです。よし、これで閉廷。私たちにできることはない。

ところがThinking Machinesのチーム――Horace Heほか――は、うまく噛み合わない何かに気づきました。彼らはこの単純な実験を行いました:

A = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
B = torch.randn(2048, 2048, device='cuda', dtype=torch.bfloat16)
ref = torch.mm(A, B)
for _ in range(1000):
    assert (torch.mm(A, B) - ref).abs().max().item() == 0

GPU上での行列積。同じ行列。1,000回連続で実行した。結果は毎回、ビット単位で完全に同一でした。

浮動小数点の非結合性について、従来の見方が間違っていたわけではありません。これは本当です。ですが、その本質的な洞察を見落としていました。非決定性が“ランダム”ではないのです。予測不能なタイミングで突然現れる、幻の渋滞が起きているわけではありません。私たちが一緒に乗せている人数(バッチサイズ)によって、ストップウォッチの「計測の仕方」(内部での集計手順)が変わってしまっている、ということなのです。

単一のGPU演算を単独で実行すれば、それは完全に決定的です。問題が現れるのは、バッチ設定に応じて演算を合成(COMPOSE)する方法を変えるときだけです。異なるバッチサイズは推論エンジンに対して計算のグループ化の仕方を変え、その結果、演算の順序が変わります。すると丸め(rounding)のパターンも変わります。これはランダムではなく、体系的です。

では、ランダムではないなら、実際に何が起きているのでしょうか?

本当の原因

問題は不気味なランダム性ではなく、加算順(reduction order)です。レストランまでの道を思い出してください。同じ経路で進むとしても、各区間の所要時間を足し合わせる方法をいろいろ変えると、ある程度なら秒単位で誤差が出ます。

極小の数値差 → 温度0での別のargmax → 分岐するトークン。

では夕食の旅の例で考えましょう。14分の移動を、ナノ秒単位まで計算する超高精度なデジタルストップウォッチで計測するとします。たとえ表示は分と秒だけだったとしてもです。

  • シナリオ1:1件のリクエスト。 全行程を1つの区間として計測します。机の上で「start」を押し、レストランで「stop」を押します。ストップウォッチは1つの正確な所要時間を計算します:14.1123173271 分。
  • シナリオ2:3件のリクエストをバッチ化。 3つの区間を別々に計測することにします:(1) オフィスから地下鉄まで、(2) 地下鉄での移動、(3) レストランまでの徒歩。各段階で「lap」ボタンを押します。

ここが重要です。浮動小数点の計算が非結合であるため、つまり $(a+b)+c$ は $a+(b+c)$ と比べて、微視的に異なる数になり得ます。ストップウォッチ内部のチップがラップ時間を足し合わせる方法が違うので、結果も変わります。たとえば (lap1 + lap2) + lap3 のように計算され、最終的な所要時間が 14.1123173274 分になるかもしれません。

差は小数点以下の12桁あたりの丸め誤差です。私には完全に知覚できません。しかしそれは別の数です。

これは、vLLMのような推論サーバでもまさに起きます。GPUの利用率を最大化するため、リクエストをバッチで処理します。

  • 1つのシーケンスを処理する? GPUは計算の1セットをグループ化して実行します。
  • 10のシーケンスを処理する? より効率的になるように、計算を別の仕方でグループ化します。

それぞれのグループ化が、演算の順序を変えます。ランダムではなく、バッチサイズに基づいて体系的かつ決定的に変わるのです。異なる順序は、浮動小数点の丸めパターンを変えます。その結果、わずかに異なる数値結果が生まれます。するとモデルは「最も確からしい」トークンとして別のものを選ぶ可能性があり、そこから先はまったく別の出力へと連鎖します。

問題は幻の渋滞ではありませんでした。バッチに応じて記録するラップ回数が変わると、私たちのストップウォッチが返す測定値が変わってしまう——それが原因だったのです。

加算順は実際に重要 である理由

浮動小数点の演算は結合的ではありません。これはバグではなく、数学です。GPU固有の話でもありません。すべての浮動小数点の計算にこの性質があります。GPUが関係するのは、GPUが演算を並列に処理し、その並列性によって、順序が明示的に制御されない限り「演算の順序」が固定されないからです。

論文では完璧な例が示されています:

(0.1 + 1e20) - 1e20 = 0
0.1 + (1e20 - 1e20) = 0.1

コンピュータは無限に大きくないので、実数(無限精度の数)を有限精度で表現します。丸めは各演算で発生し、演算順序が異なると丸めパターンも異なります。

Kahanの加算が存在するのは、素朴な加算では精度が失われてしまうからです。数値解析という分野全体が存在するのも、こうした細部が重要だからです。

もし私が旅の区間をこのように足し合わせるなら:A + B + C、合計はある値になるかもしれません。しかしもし別のグループ化をして (A + B) + C のようにすると、頭の中の計算における丸めがわずかに別の結果を生む可能性があります。機械学習の文脈では、これを「reduction strategy(縮約戦略)」と呼びます。

論文では、潜在的な解決策を提示しています:

"バッチ不変性(batch invariance)に必要なのは、カーネルのバッチサイズに関係なく、各要素に対する縮約順序(reduction order)が固定されていなければならないということです。なお、これは常に同じ縮約戦略を使わなければならないという意味ではありません。たとえば、縮約対象の要素数を変えたとしても、縮約戦略が変わっていようともバッチ不変であり得ます。"

修正が必要な3つの操作

解決策は簡単に聞こえます。つまり、人数(グループ化する数)に関係なく、毎回同じ経路を通るようにする。実際には、3つの基本的なGPU演算がどのように動いているかを見直す必要があります。

操作1:RMSNorm(ストップウォッチが区間を足し合わせる方法)

"標準的な実装では、削減(reduction)を複数のワーカーに分散して並列化します……分割ポイントが異なれば、削減の順序も異なります。"

RMSNormとは? RMS(Root Mean Square)正規化は、LLaMAのようなモデルでベクトルに対するスケーリング係数を計算するために使われます。これは、何千もの値を1つの数にするために、削減(合計)を行う必要があります。

問題はこれです: 標準的な実装では、作業を複数のGPUワーカーに分割して並列化します。しかし、バッチサイズが異なるとワーカー数も異なり、その結果、加算のパターンも異なります。

ストップウォッチの比喩: あなたは旅の間に1,000の区間タイムを記録しました。ストップウォッチはそれらをどう合計するのでしょうか?

  • 一人で移動(小さいバッチ、GPUワーカー数が少ない): ストップウォッチは各区間タイムをペアで足します—((segment1+segment2) + (segment3+segment4)) + ((segment5+segment6) + (segment7+segment8))。このように、ペア同士を結び、次にペア同士のペアを結び…という形でツリーを構築し、最終的に上へと積み上げていきます。
  • 友達と一緒に移動(大きいバッチ、GPUワーカー数が多い): もっと速く行くために、4つずつのグループで足すかもしれません—(segment1+segment2+segment3+segment4) + (segment5+segment6+segment7+segment8)。こちらはまったく別のツリー構造になります。

浮動小数点演算は結合則を満たさないため、この2つの加算ツリーは、顕微鏡的に異なる合計値を生み出します。

同じ数でも、グループの仕方が違えば、最終的な答えも違う。

解決策: バッチサイズに関係なく、加算ツリーの構造を固定します。たとえ小さいバッチでGPUコアの一部がアイドルになるとしても、常に同じやり方でラップタイムを結合します。これは、潜在的な効率を諦めてでも、決定性(determinism)を保証する取引です。研究、デバッグ、そしてRLトレーニングでは?価値があります。スケールした本番の推論サービングでは?まだ難しいかもしれません。どれを選ぶかはあなた次第です。

Operation Two: Matrix Multiplication (How the Laps are Defined)

"最新のGPUカーネルは演算をタイル状に分割します……タイルサイズは通常、利用可能な並列性に依存します。バッチ内のシーケンスが増える?タイルは大きくなります……タイルが異なれば、蓄積(accumulation)のパターンも異なります。"

ストップウォッチの比喩: これは区間タイムの足し方そのものの話ではありません。そもそも、区間の区切り(boundary)をどこに置くかの話です。

  • 小さいバッチ: 限られた並列性を効率よく使うため、システムは「区間」を2ブロックごとにすると決めます。14ブロックの旅は、7つの同じ長さのラップとして計測されます。
  • 大きいバッチ: 並列性がより利用可能なので、効率のためにラップのサイズが4ブロックに変更されます。すると同じ14ブロックの旅は、3つの完全な区間と、1つの部分ラップになります。

あなたは依然として同じ14ブロックを移動しています。ですが、それらを異なる区間に分けるため、異なる中間和が生まれ、異なる丸めパターンが生まれ、異なる最終合計が生まれます。

解決策: バッチ設定に関係なく、タイルサイズを固定します。区間は常に正確に2ブロックです。

同じセグメンテーション、同じ中間値、同じ丸め、同じ結果。GPU効率の一部は犠牲になります—大きいタイルならそれらのコアはもっと働けたかもしれません—その代わり決定性を買い戻します。

重要なニュアンス: ほかのすべて(ハードウェア、ドライバ、ライブラリ、テンソルの形状、フラグ)を一定に保つと、単独のGPUによる行列乗算は、実行のたびにビット単位で同一になります。「ランダムなスレッドスケジューリングがランダムな結果を引き起こす」という話は、正確には少し違います。本当の問題は、バッチ設定に応じて部分結果の集約(aggregation)が異なるやり方になってしまうことです。偶然ではなく、体系的なものです。

Operation Three: Attention (The Actually Hard Problem)

論文にある最悪のシナリオはこうです:

"KVキャッシュに80トークンあって、そこに48個の新しいトークンを処理しているとします。ブロックサイズが32なら、キャッシュ済みの値に対して3ブロック(2つのフルと1つのマスク)必要で、新しい値に対して2ブロック(1つのフルと1つのマスク)必要—合計で128要素に対して5ブロックです。ですが、キャッシュ済みトークンがゼロで、128を一度に全部処理していたら?必要なのは4ブロックだけ。要素数は同じでも、削減(reduction)の組み立て方が完全に違います。"

Attention は、モデルが異なるトークンの重要度を重み付けできるようにします。FlashAttentionPagedAttentionのような現代的な手法は、とりわけ(KVキャッシュ内の)既存トークンと新規トークンの混在を扱う場合に非常に最適化されています。

ストップウォッチの比喩(ここで無理が出てきています): レストランに到着して、平均の移動時間を計算しようとしているところを想像してください。区間タイムのリストが2つあります。すでに到着している人(KVキャッシュ)からのリストと、今あなたが自分の旅で走っている分(新規トークン)のリストです。

システムはこれらのリストを別々に処理して、それぞれを合計してから合算するかもしれません。あるいは、まず1つのリストにマージしてから、すべてを合計するかもしれません。どちらの戦略を選ぶかは、それぞれのタイプがいくつあるかに依存します。グループ化の仕方が異なれば、加算の順序が異なり、丸めも異なり、結果も異なります。

ここは正直に言います: すべてを夕食の比喩に無理やり押し込むのは、ここでは役に立たなくなってきています。Attentionメカニズムはそれだけで十分に複雑で、比喩は良いより害のほうが大きいです。重要なのは次の点です:

  1. 最初にKVキャッシュを更新する—注意計算(attention calculations)が始まる前に、古いトークンと新しいトークンが1つの一貫したメモリ配置に存在するようにする。
  2. 固定の分割サイズを使い、分割回数を固定しない。 「何があってもこれを4等分する」のではなく、「各チャンクはちょうど32トークン」と言う。100トークンを処理しようが128トークンを処理しようが、チャンクの境界はデータに対する位置が同じままなので、一定の削減パターンが得られます。

これには PyTorchのFlexAttentionへの貢献(変更) が必要でした—つまり、この修正がどれほど深いかを示しています。賢いアプリケーションレベルのコードをどれだけ工夫しても直せません。プリミティブを変える必要があります。

The Soup Dumplings Experiment

ここからが本番です:

"Qwen/Qwen3-235B-A22B-Instruct-2507を使い、温度0でプロンプト 'Tell me about Richard Feynman'(non-thinking mode)に対して1000件の完了(completions)をサンプリングし、それぞれ1000トークンを生成します。"

温度0は簡単なモードのはずです。常に次に来る最も可能性の高い単一のトークンを選ぶのだから。創造性も、ランダム性もありません。完全に決定的であるべきです。同じプロンプト。同じ温度。同じモデル。

非決定性はサンプリング手順ではありません—そこは問題なく動作しています。問題は、どのトークンが「最も可能性が高いか」を決めるログit(logits)を計算するFORWARD PASSにあります。バッチサイズが変わると、そのlogitsの計算方法が変わり、結果として勝つトークンが変わります。結果が変わるのです。

標準の vLLM では:

  • 1000回の実行から80種類の異なる出力
  • 最も多い出力が78回(全体の8%未満の実行!)
  • トークン103で初めて分岐
  • 992件の完了が同じことを言った
  • 8件の完了が別のことを言った

同じプロンプト。 同じ温度。 同じモデル。 なのに結果は違う。

しかし、バッチ不変(batch-invariant)のカーネルを有効にしたとき:

"...私たちの1000件の完了はすべて同一です。これは、サンプラーによって数学的にこのようなことが起きるはずだと我々が期待する内容です。しかし、バッチ不変カーネルがなければ、決定論的な結果を達成することはできません。"

同じルート。 同じ時間。 いつも。毎回。

1000回の実行。 1つの出力。 ビット単位で同一。

頭の中で予測したのは、毎回そこに着く“ちょうどその時間”でした。

パフォーマンス上のトレードオフ

ただし、これはタダでもらえるものではありません。最初の実装では、約2倍遅く動きます(55秒 対 26秒)。最適化すると1.6倍遅くなります(42秒)。

論文は率直です:

"スローダウンの大部分は、vLLMにおけるFlexAttentionの統合がまだ大きく最適化されていないことに起因します。それにもかかわらず、パフォーマンスが壊滅的でないことを確認しています。"

1.6倍遅いのは許容できるでしょうか? 条件次第です。

数十億件のクエリを本番で処理するために、ミリ秒が重要な場合? まだ難しいかもしれません。

再現性が何より重要な研究の場合? もちろんです。

厳密な繰り返し可能性が必要なモデル開発・テストの場合? 異論なしです。

人間からのフィードバックによる強化学習RLHF)で、ポリシードリフトが学習を壊す可能性がある場合? それは必要であって、選択肢ではないかもしれません。

真のオンポリシーRL:大きな解放

ここから先は、「いいインフラ改善」にとどまらず、「何が可能になるかを根本的に変える」領域です。論文はほとんど何気なくこの点を投げていますが、そのインパクトは深刻です:

"研究者が指摘しているように、学習時と推論時で数値が異なることが、暗黙的に私たちのオンポリシーRLをオフポリシーRLへと変えてしまいます。... [D]決定論的推論により、サンプリングと学習の間でビット単位で同一の結果を得るために学習スタック自体も修正できるため、その結果として真のオンポリシーRLが実現します。"

なぜこれがそんなに大きな話なのかを理解するには、用語を素早く整理する必要があります。

オンポリシーRLとオフポリシーRLの違いは何ですか?

強化学習では、ポリシーとはエージェントの戦略のことです。ここで言う戦略は、たとえばレストランへの特定の行き方(ルート)です。

オンポリシー: 今まさにあなたが取っている“そのルート”から学びます。6番電車に乗る、遅いから、その「この時刻の6番電車は遅い」ということを学ぶ。つまり、改善しているポリシーは、経験を集めるために使っているのと同じポリシーです。

オフポリシー: 今あなたが乗っているルートとは別のルートについて学びます。あなたは6番電車に乗っているけれど、スマホでF番電車の状況を確認する。実際にF番電車に乗らずに、F番電車の性能について学んでいるのです。

調整ミスの時計の問題

文章を生成する(サンプリング)ことと、それを学習すること(学習)の間で数値がドリフトしてしまうことで、オンポリシー手法が偶然オフポリシーのようなものになってしまいます。

レストランへのルートを最適化しようとしているのに、時計が毎回ほんの少しずつズレているせいで、「変化が本当に役に立ったのか、それとも測定が単に揺れただけなのか」が確信できない、というのに似ています。

一般的な解決策は、重要度重み付けと呼ばれるパッチで、ドリフトを数学的に補正しようとします。しかし本質的には、そもそも存在するはずのない問題を“つぎはぎ”しているだけです。

時計を調整する

バッチ不変カーネルは、計算の「ルート(道筋)」が毎回ビット単位で同一になることを保証することで、この問題を解決します。これにより真のオンポリシーRLが成立します。論文の結果は目を引きます:

  • パッチなし: モデルの性能がすぐに崩れます。
  • パッチあり(重要度重み付け): 学習は動きますが不安定で、小さなドリフトの周りでぐらつきます。
  • 真のオンポリシー(バッチ不変): ドリフトはゼロのフラットラインになります。学習は完全に安定します。

論文が述べているように:

"...「真のオンポリシーRL」を実行しているとき、KLダイバージェンスが0のままフラットに保たれます。つまり学習ポリシーとサンプリングポリシーの間に発散が存在しないことを示しています。"

私の時計はついに調整されました。 同じルート、同じ時間。 これで本当に最適化できます。

これは単なる机上の話ではありません。LLM開発のポストトレーニング全体に直接影響し、RLHFDPO のような手法を、より安定かつ信頼性の高いものにします。

実際に今、何が変わるのか

具体的に言います。多くの人が「THIS CHANGES EVERYTHING(すべてが変わる)」と叫びますが、具体的なユースケースがなければ、それだけではあまり意味がありません。

研究の再現性が“現実”になる

現時点では、モデルAがモデルBより優れていることを検証するには、複数回の試行を行い統計的有意性を計算する必要があります。温度0ではモデルが本質的にランダムだからではありません。測定装置が一貫していないからです。

決定論的推論があると:

  • 厳密な範囲でのA/Bテスト
  • 確実に検証可能な主張になる
  • 再現実験が実際にまったく同じ結果になる
  • メタ分析で実装差を気にしなくてよくなる

AI研究における再現性危機は、一部には研究者が詳細を共有しないことが原因です。しかしそれだけではありません。実装が微妙に違ってしまうことも原因です。バッチ不変カーネルは、その変動要因を1つ取り除きます。

デバッグが桁違いに楽になる

モデルが悪い出力を生成したとき、それを正確に再現できるということは、次のことができるという意味です:

  • 中間アクティベーションの検査も含めて、実行をステップごとに追跡する
  • 計測用のインストゥルメンテーションを追加し、まったく同じ結果を再実行する
  • トークン列を二分探索して、どこで問題が起きたのかを突き止める

非決定論的なシステムでは、デバッグが確率的になります。「たいていXをするが、時々Yをする」といった状況は、開発者にとって最悪の悪夢です。

決定論になると、デバッグは体系的になります。毎回の実行が同一です。通常のデバッグツールをすべて使え、次回も同じものが得られると信頼できます。

キャッシュが“弾丸のように”堅牢になる

本番のLLM提供では、攻めたキャッシュが使われます。よくあるクエリ、プレフィックスキャッシュ、共通のKVキャッシュを用いたcontinuous batching――これらはいずれも「同一の入力なら同一の出力が得られる」と仮定しています。

しかし非決定性があると、この前提は漏れます。キャッシュヒット率は、本来あるべき値よりも低くなります。

決定論的推論なら:

  • 同一入力に対する完璧なキャッシュヒット率
  • 中間計算をキャッシュして再利用できる
  • キャッシュ無効化ロジックをよりシンプルにできる

何百万件ものクエリをさばく企業にとって、これはそのままインフラコストの削減につながります。

モデルテストが精密になる

LLMの品質保証は現在、統計的です。「このまったく同一の入力に対して、モデルはこのまったく同一の出力を出さねばならない」と言い切れるテストを書くことはできません。厳密な出力を保証できないからです。

決定論的推論があると:

  • モデル更新が特定の振る舞いを変えたことを検出する、精密な回帰テストを書く
  • フレークしない包括的なテストスイートを構築する
  • モデルバージョン間の差分テストを自信を持って行える

テストを信頼できるということは、より速い反復と、本番での驚きを減らすことにつながります。

研究の速度が加速する

おそらく最大の効果はこれです。研究者は統計に費やす時間が減り、その分、科学により多くの時間を使えるようになります。

現在、かなりのML研究時間が以下に費やされています:

  • 統計的パワーを得るのに十分な試行回数を回すこと
  • さまざまな分散(ばらつき)の要因を補正すること
  • その差が有意かどうかをめぐって議論すること

決定論的推論(deterministic inference)では、主要な分散要因の1つを取り除けます。実験はよりシンプルになります。結果はより明確になります。計測手法に費やす時間が減り、実際の研究課題により多くの時間を回せるようになります。

それは、バージョン管理があなたのコードを直接よくするわけではないのに、摩擦や連携のためのオーバーヘッドを取り除くことでチームがより速く動けるようにするのと似ています。決定論的推論は、研究プロセスから摩擦を取り除きます。

これが重要なインフラである理由

私はこの投稿を、移動時間を予測できないせいで夕食に慢性的に遅れてしまう話から始めました。機械学習も同じ問題を抱えていますが、さらに悪い形です。たとえ経路がわかっていても、時間を予測できません。

Thinking Machines Labは、この仕組みを独自のものとして抱え込むこともできたはずです。「サービスとしての決定論的推論(Deterministic Inference as a Service)」を作り、高額なプレミアム価格を設定し、そして堀(もとい)を築けたでしょう。

しかし彼らは次のようにしました:

ここにはビジネスモデルがありません。そもそもあってはいけないからです。

論文は次のように結論づけています:

「現代のソフトウェアシステムには、多くの階層的な抽象化が含まれています。機械学習では、決定論性の欠如や微妙な数値差に行き当たったとき、ついそれらを“紙の上で”処理して済ませたくなることがあります。結局、私たちのシステムはすでに“確率的”なのだから、もう少し決定論性の欠如があっても何が問題でしょうか? 失敗したユニットテストでatol/rtolを少し上げればいいのでは? トレーナーとサンプラーの間のlogprobsの差は、おそらく本当のバグではないでしょう? 私たちはこの敗北主義を退けます。」

「私たちはこの敗北主義を退けます。」 いい言葉です。

誰もが決定論性の欠如は避けられないものとして受け入れました。回避策を作り、許容誤差を調整しました。問題を解決するのではなく、その問題を回避するようにエコシステム全体が適応していったのです。

Thinking Machines Labは問いかけました。「本当に解決したらどうなるだろう?」 「影響を最小化する方法は?」「それに対して統計的に補正する方法は?」ではなく、「根本原因は何で、それをどうやって取り除くのか?」でした。

これはシステム思考をインフラに適用したものです。問題はそもそも浮動小数点演算やGPUの並行性そのものではありません。バッチ構成が異なる状況で、仕事をどう編成するかです。解決策は浮動小数点の挙動と戦うことではなく、文脈に関わらず一貫した運用上の順序(operational ordering)が保証されるようにすることです。

より深いパターン

ほとんどの「インフラ改善」は、既存のものを速くしたり安くしたりすることです。学習を高速化する。配信コストを削減する。モデルを圧縮する。これらは重要です。FlashAttentionは重要です。量子化も重要です。効率的なアーキテクチャも重要です。

しかし、時々、問題そのものがあまりに根本的で、そもそも問題だと見えなくなるほどのものを誰かが修正します。人々は適応し、回避策を作り、問題は陣地の一部になっていきます。まるで、埋めるよりも避けると学んでしまう穴(ポットホール)のように。

  • Kubernetesはコンテナを速くしたわけではありません。あらゆる企業にとってコンテナオーケストレーションがカスタム地獄になることを防いだのです。
  • Gitはコードを良くしたわけではありません。コラボレーションが調整の地獄にならないようにしました。
  • Rustはシステムプログラミングを速くしたわけではありません。メモリ安全性を、ガベージコレクションなしで実現できるようにしました。

これらは基盤です。問題の“クラス”全体を取り除きます。回避するのではなく、そもそも考える必要がなくなるのです。

バッチ不変カーネルは、LLM推論の再現性のためにこれを実現します。 これは回避策ではありません。パッチでもありません。「ドリフトを補正するために重要度付けを追加しよう」ではありません。解決策です。問題は存在しなくなります。

回避策と解決策の違いはこれです。回避策は、まだその問題について考え続ける必要があることを意味します。解決策は、その問題があなたの頭の中のモデルから完全に消えることを意味します。最後に、あなたのテキストエディタがUnicodeを正しく扱えるかどうかを考えたのはいつでしょう? 考えないはずです。なぜなら、その問題がインフラの段階で解決されて、あなたは考えるのをやめてしまったからです。

決定論的推論がまさにそれを行います。「あれ、なぜ今回は挙動が違ったんだ?」という問いかけのための丸ごと1カテゴリーを取り除きます。デバッグセッションは短くなります。テストはより信頼できるようになります。研究はより再現可能になります。決定論性の欠如への回避が上手くなったからではありません。決定論性の欠如が、回避しなければならない“もの”ではなくなったからです。

Thinking Machines Labはそれを理解していました。ここでビジネスを作り得ることも分かっていました。しかし、この革新は誰にでも当たり前に使われるようになってこそ、より価値が高いことも理解していました。自由に利用可能にすることで、あらゆる推論エンジンがこれらの技術を採用できます。あらゆる研究ラボが結果を再現できます。あらゆる本番環境のデプロイが一貫した挙動を得られます。

基盤が盤石で共有されていると、分野はより速く進みます。

次に何が起きるか

バッチ不変カーネルは、次のようなところに登場してくると思います:

  • vLLMにまずはオプトインのフラグとして入り、その後はデフォルトになる可能性
  • TensorRT-LLMに数か月以内に入る
  • Text Generation Inference(HuggingFace)
  • llama.cppでローカル推論
  • 主要クラウド事業者の配信(サービング)インフラ

これは長くは競争上の差別化要因にはならないでしょう。推論がそういうふうに動く“標準”になっていくだけです。HTTPSがかつては任意だったのが、いまでは当然の期待になっているのと同じです。Unicode対応がかつては機能だったのが、いまでは当然だと思われているのと同じです。

それがまさに正しい方向です。

私たちは途方もないエネルギーを、ブレークスルーとなるモデルのアーキテクチャについて語ることに使っています。Mixture of ExpertsState Space ModelsLong context。これらは重要です。アーキテクチャは重要です。

しかし、こうしたインフラ――派手ではなく、技術的に深く、しかも自由に共有されるもの――こそが、体系的な進歩を可能にします。基盤が盤石なら、その上に作られるものすべてがより信頼できるようになります。測定が一貫していれば、最適化が可能になります。実験が再現可能なら、科学は機能します。

2つの問題、1つの解決策

これを現実に引き寄せて話しましょう。私たちは、2つの異なる非決定性(nondeterminism)の問題から始めました:

問題 #1: ルートが分からない。 学習中、機械学習は「何が重要か」を教えてくれません。最後に損失スコアは得られますが、その途中でのGPS追跡はありません。これは解釈可能性の問題であり、学習の問題であり、「モデルが学習したことを理解する」という問題です。ですが、まだ未解決です。人々は取り組んでいます――機械的解釈可能性、アトリビューション手法、どの特徴が効いているのかの理解。次のフロンティアです。

問題 #2: ルートは分かっていても、毎回違う時間になる。 学習済みモデルがあります。重みも分かっています。やっているのは推論だけ――同じ順伝播、同じ計算です。ですが、バッチサイズが操作のまとまり方を変え、そのまとまり方が縮約(reduction)の順序を変え、それが丸め(rounding)を変え、ひいては結果を変えてしまいます。温度0の非決定性です。そこではランダム性がゼロであるべきなのに。結果を再現できませんでした。信頼できる形でデバッグできませんでした。本当の意味でのオンポリシーRL(on-policy RL)ができませんでした。フレークしないテストを書くこともできませんでした。

Thinking Machines Labは問題#2を解決しました。単なる回避策ではなく、解決策で。根本原因(バッチ設定が操作のグループ化を変えていたこと)を突き止め、カーネルレベルで修正しました。これで同じルートを取れば、同じ時間が返ってきます。ビット単位で完全に同一、100%の確率で、1000回連続でも同じです。

私たちはまだ問題#1を解決できていません。ですが今では、正確に測定できます。そして測定は科学の土台です。安定して一貫して測れないものは、最適化できません。

私の腕時計がようやく動きました。同じルート、同じ時間、毎回。

さて、失礼。予約があるので行かないといけません。さらに、決定論的推論があるなら、私は絶対的な確実さをもって、いつ到着するかを正確にお伝えできます。

(それでも遅れます。ですが、少なくとも遅れるという事実は再現可能になります。)

前へ。

完全な論文「Defeating Nondeterminism in LLM Inference」には、広範な技術的詳細、ベンチマーク、アブレーション(比較削除)研究が含まれています。バッチ不変(batch-invariant)のカーネル実装は以下で利用できます。 github.com/thinking-machines-lab/batch-invariant-ops。関連する取り組みとして、 FlexAttentionの改善 はPyTorchにアップストリームされています。

これはかなり熱い話です。ぜひ読んでください!

MLシステムにおいてなぜ決定論的計算が重要なのかについて、さらに詳しくは Reproducible Machine Learning Numerical Reproducibility in HPC、そしてAIにおけるより広い レプリケーション危機をご覧ください。

知的なデータパイプラインが、AIコストをどのように削減できるのか学びたいですか? Expansoをチェック。それとも別にいいです。私があなたに何をすべきかを指示する立場ではありません。*

注記: 現在、機械学習のためのデータ準備における現実的な課題(主に運用、コンプライアンス、コスト)について、私が見てきたことをもとに本を書いています。 ぜひあなたの考えも聞かせてください

もともとは Will I Make It To The Restaurant Before The Soup Dumplings Get Cold? (And Other Problems In Machine Learning) にて公開されました。