MedQA:AMD ROCmで臨床AIをファインチューニング—CUDA不要

Hugging Face Blog / 2026/5/8

💬 オピニオンDeveloper Stack & InfrastructureTools & Practical UsageModels & Research

要点

  • この記事は、AMD GPUとROCmソフトウェアスタックを使って、CUDAに頼らずにファインチューニングできる臨床AI「MedQA」を紹介しています。
  • CUDAを使わずに、開発者がAMDハードウェア上で臨床AIのワークフローを適応・運用できる点を強調しています。
  • 本文は、ROCm対応のセットアップでモデルを動かし調整するための実践的な手順に焦点を当てており、実行可能性を示しています。
  • このアプローチは、臨床AI開発や実験をGPUベンダーをまたいで広げる手段として位置付けられています。
  • 全体として、CUDAに偏りがちな手順でもROCmで成立させ、プラットフォームのロックインを減らすことが主旨です。

MedQA: AMD ROCm上で臨床AIをファインチューニング—CUDA不要

チーム 記事 2026年 5月8日公開

AMD MI300Xを使用し、lablab.aiのAMD Developer Hackathon向けに構築した、MedMCQAでのQwen3-1.7BのLoRAファインチューニングの完全な手順解説です。

アイデア

医療の質問応答は、リスクが本当に高いタスクの一つです。臨床のMCQで自信を持って誤った答えを選んでしまうモデルは、単に間違いなだけでなく、危険です。同時に、ほとんどのオープンソースの医療AIの取り組みでは、NVIDIA GPUがあることを前提にしています。CUDAがデフォルトで、それ以外は後回しにされています。

このプロジェクトは、その前提に挑戦します。

MedQAは、ROCmを使ってAMDのハードウェア上で完全に構築された、LoRAでファインチューニング済みの臨床向け質問応答モデルです。複数選択式の医療問題を入力として、正しい回答の文字 、推論のための臨床的な説明の両方を返します。データローディングからアダプタのエクスポートまで、学習の全パイプラインは、CUDAへの依存を一切含まず、AMD Instinct MI300X上で動作します。


なぜAMD ROCm?

AMD Instinct MI300Xは驚くべきハードウェアです。単一デバイスでHBM3メモリが192 GB搭載されています。LLMのファインチューニングでは、VRAMがしばしばボトルネックになります。VRAMは、バッチサイズ、シーケンス長、そしてそもそも量子化が必要かどうかを左右します。利用可能な192 GBがあったため、4-bitや8-bitの量子化の小手先は使わずに、Qwen3-1.7BをLoRAでフルfp16として学習しました。

さらに重要なのは、HuggingFaceのエコシステム(Transformers、PEFT、TRL、Accelerate)がROCm上でシームレスに動作することを示すことが目的だった点です。実際に動きます。CUDAで動かしている同じ学習コードが、次の3つの環境変数を設定するだけでROCmでも動きます:

os.environ["ROCR_VISIBLE_DEVICES"] = "0"
os.environ["HIP_VISIBLE_DEVICES"] = "0"
os.environ["HSA_OVERRIDE_GFX_VERSION"] = "9.4.2"

以上です。コード変更は不要です。カスタムカーネルも不要です。CUDA互換のための調整(シム)も不要です。

返却形式: {"translated": "翻訳されたHTML"}

データセット:MedMCQA

MedMCQA は、大規模なマルチクラス選択式問題データセットで、インドの医学系入学試験(AIIMS、USMLE形式)から派生しています。各例には次の内容が含まれます:

  • 臨床上の設問
  • 4つの解答候補(A〜D)
  • 正解のインデックス
  • 任意の自由形式の解説(exp フィールド)

本プロジェクトでは、2,000件の学習サンプルを使用しました。これは、意味のある微調整が素早く達成できることを示すための、あえて小さな切り出しです。MI300X上での学習には約5分かかりました。


モデル:Qwen3-1.7B

ベースモデルはQwen/Qwen3-1.7Bです。これは Alibaba の最新の小規模言語モデルです。17億パラメータなので、低コストで微調整できる一方、整った臨床的推論を生成できるだけの能力もあります。trust_remote_code=True をサポートしており、HuggingFace Transformers で問題なく読み込めます。


プロンプト形式

命令微調整において、プロンプト形式の一貫性は極めて重要です。すべての学習例と、すべての推論呼び出しは同じテンプレートを使用します:

### 質問:
{question}

### 選択肢:
A) {opa}
B) {opb}
C) {opc}
D) {opd}

### 答え:
{answer_letter}) {answer_text}

### 説明:
{explanation}

学習中、モデルは答えと解説を含む完全なシーケンスを目にします。推論時には### Answer: までのすべてを提示し、そこからモデルに続きを生成させます。


LoRA を用いた学習

1.5B(15億)パラメータすべてを微調整するのではなく、PEFTライブラリを通じてLoRA(Low-Rank Adaptation)を使用します。LoRA は注意層に小さな学習可能なランク分解行列を注入し、ベースの重みは凍結したままにします。

LoRA の設定

from peft import LoraConfig, get_peft_model, TaskType

lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj"],
    bias="none",
)

返却形式: {"translated": "翻訳されたHTML"}model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
# trainable params: 2,228,224 || all params: 1,543,901,184 || trainable%: 0.1443

モデルの15億パラメータのうち、学習するのは約220万のみです。これによりメモリ使用量が少なく、学習が高速になります。

学習引数

from transformers import TrainingArguments

args = TrainingArguments(
    output_dir="./outputs",
    num_train_epochs=2,
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,     # effective batch size = 16
    learning_rate=2e-4,
    fp16=True,
    bf16=False,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    gradient_checkpointing=True,
    optim="adamw_torch",
    warmup_ratio=0.05,
    lr_scheduler_type="cosine",
    report_to="none",
)

いくつか注目すべき点があります:

  • fp16=True, bf16=False — 標準のfp16を使用します。bfloat16での初期実験ではNaNの損失が発生しましたが、fp16に切り替えることで完全に解決しました。
  • gradient_checkpointing=True — 計算量と引き換えにメモリを節約します。MI300Xでは192 GBのVRAMがあるため厳密には不要ですが、より小さなGPUでの再現性を高めるためには良い実践です。
  • gradient_accumulation_steps=4 — 実効バッチサイズは4の物理バッチから16になります。
  • ウォームアップ付きのCosine LRスケジュール — 学習が短い実行では、フラットなスケジュールよりも収束が滑らかになります。

完全な学習ループ

from transformers import DataCollatorForSeq2Seq, Trainer

collator = DataCollatorForSeq2Seq(
    tokenizer,
    model=model,
    padding=True,
    pad_to_multiple_of=8,
)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=train_ds,
    eval_dataset=val_ds,
    data_collator=collator,
)

trainer.train()

# adapter + トークナイザを保存
model.save_pretrained("./outputs")
tokenizer.save_pretrained("./outputs")

学習後、./outputs にはLoRAアダプタの重みが含まれます。これは、数GBのモデル全体のチェックポイントの代わりに、数MB程度のファイルです。


推論(Inference)

推論時には、基となるモデルを読み込み、LoRAアダプタを追加し、必要に応じて重みをマージします:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained("./outputs", trust_remote_code=True)
tokenizer.pad_token = tokenizer.eos_token

base_model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.float16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base_model, "./outputs")
model.eval()

生成は貪欲デコーディング(do_sample=False)を使用し、モデルがループしないよう反復ペナルティを適用します:

def generate(prompt, model, tokenizer):
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

返却形式: {"translated": "翻訳されたHTML"}with torch.no_grad():
        output = model.generate(
            **inputs,
            max_new_tokens=200,
            do_sample=False,
            temperature=1.0,
            repetition_penalty=1.1,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.eos_token_id,
        )

    new_tokens = output[0][inputs["input_ids"].shape[-1]:]
    return tokenizer.decode(new_tokens, skip_special_tokens=True)

サンプル出力

質問: 高血圧性緊急症(hypertensive emergency)の第一選択の治療として正しいのはどれですか?

A) 経口アムロジピン
B) 静注ラベタロールまたは静注ニトロプルシド
C) 舌下ニフェジピン
D) 筋注ヒドララジン

モデル出力:

B) 静注ラベタロールまたは静注ニトロプルシド

説明:
静注ラベタロール(ベータ遮断薬)またはニトロプルシドは、緊急の場で血圧を迅速に下げます。経口薬は、高血圧性緊急症で必要なように、臓器障害を防ぐための即時の血圧コントロールを行うには作用が遅すぎます。

このモデルは文字(選択肢)を出すだけではありません。なぜそうなるかも説明します。これが臨床的に役立つ理由です。


HuggingFace Hub から読み込み

微調整済みアダプタは公開されています。リポジトリをクローンせずに、そのまま直接読み込めます:

from transformers import AutoTokenizer, AutoModelForCausalLM
from peft import PeftModel
import torch

tokenizer = AutoTokenizer.from_pretrained(
    "Qwen/Qwen3-1.7B", trust_remote_code=True
)

base = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen3-1.7B",
    torch_dtype=torch.bfloat16,
    device_map="auto",
    trust_remote_code=True,
)

model = PeftModel.from_pretrained(base, "HK2184/medqa-qwen3-lora")
model = model.merge_and_unload()
model.eval()

課題と修正

AMD ROCm のプロジェクトは、戦いの逸話(war story)の章がないと完成しません。そこで、私たちが遭遇したことはこちらです:

課題 根本原因 修正
NaN loss 混合精度による不安定さ bfloat16 → fp16 に切り替え
GPU が検出されない ROCm の環境変数が不足 ROCR_VISIBLE_DEVICESHIP_VISIBLE_DEVICESHSA_OVERRIDE_GFX_VERSION を設定
bitsandbytes がサポートされていない bitsandbytes の ROCm ビルドがない 量子化を完全に削除 — MI300X には十分な VRAM がある
推論出力がゴミになる トークナイザのパディング設定が誤っている pad_token = eos_token を設定し、padding_side を修正
Trainer の評価エラー Transformers のバージョン不一致 transformers>=4.40.0 に固定

bitsandbytes の問題には一言添える価値があります。NVIDIA のハードウェアでは、4-bit 量子化はモデルをメモリに収めるために 必要 であることが多いです。一方、192GB の HBM3 を搭載する MI300X では、単に不要です。これは本物のハードウェア上の利点です。よりクリーンな学習で、量子化アーティファクトもありません。

返却形式: {"translated": "翻訳されたHTML"}

結果

指標
学習可能パラメータ 約2.2M(全体の0.15%)
MI300Xでの学習時間 約5分
使用したデータセットサイズ 2,000サンプル
ベースライン MedMCQA の精度 約45%
フレームワーク PyTorch + ROCm 6.1

自分で試してみる

GPUがない?問題ありません。 ライブのGradioデモはHuggingFace Spaces(CPU推論)で動作します:

HuggingFace Spacesでのライブデモ

AMDのハードウェアをお持ちですか? リポジトリをクローンしてネイティブに実行してください:

git clone https://github.com/HK2184/MedQA-Medical-AI-on-AMD-ROCm.git
cd MedQA-Medical-AI-on-AMD-ROCm
pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/rocm6.1
pip install transformers datasets peft accelerate trl gradio
python train.py   # 約5分
python infer.py   # サンプルの質問を実行
python app.py     # Gradio UIを起動

次に何をする?

このプロジェクトは、パイプラインが機能することを証明しています。次のステップは、スケールさせ、さらに堅牢にすることです:

  • より大きなデータセット — MedMCQAの全コーパス(約18万問)で学習し、PubMedQAも追加する
  • 信頼度スコアリング — 回答と併せて較正済みの信頼度推定を追加する
  • RAG統合 — リアルタイムの医療文献検索で回答を根拠づける
  • 評価用ハーネス — 学習分割の外側での適切なホールドアウト精度ベンチマークを行う

結論

MedQAは、オープンソースのAMDハードウェア上で、能力があり、説明可能な医療AIを構築することは不可能ではなく、むしろ簡単だということを示しています。HuggingFaceエコシステムのROCm対応は、実際にかなり良好です。MI300Xのメモリの余裕は、エンジニアリング上の一群の問題を丸ごと取り除きます。そしてLoRAにより、1.7Bモデルの微調整が5分の作業になります。

もしAMDのROCm上で構築していて壁にぶつかっているなら、上記の修正で数時間を節約できるはずです。また、医療AIを構築するなら、単なる生の精度よりも説明性を重視することは真剣に受け止める価値があります。


lablab.ai で開催されたAMD Developer Hackathon向けに構築 · AMD ROCm + HuggingFaceエコシステムによって支えられています

*— Harikrishna Sivanand Iyer と Srijan Sivaram A

2026-05-07 14-26-07のスクリーンショット

この記事で言及されているモデル 2

この記事で言及されているデータセット 1

コミュニティ

編集プレビュー
テキスト入力欄にドラッグして、貼り付けるか、ここをクリックして画像、音声、動画をアップロードしてください。
ここをタップまたは貼り付けて、画像をアップロード
コメント

· 登録する または ログインしてコメントする

この記事で言及されているモデル 2

この記事で言及されているデータセット 1