有人在GPU上花10分鐘訓完一個能聊天的885K參數(shù)模型,用的是Google的編譯器技術(shù)——但調(diào)試時連print都用不了。
這是AI GDE TPU Sprint 2026的一個項目。我們把Andrej Karpathy的NanoChat從PyTorch移植到JAX,12400行代碼,核心目標是:同一套代碼,GPU和TPU都能跑,還要能系統(tǒng)性地做縮放定律實驗。
![]()
先搞清楚NanoChat是什么
Karpathy的NanoChat約8600行PyTorch,核心問題很直白:100美元預(yù)算能訓出最好的ChatGPT什么樣?
技術(shù)棧堆得很滿:Flash Attention 3(Hopper架構(gòu))、Muon優(yōu)化器、FP8混合精度、DistMuonAdamW分布式訓練、SFT+GRPO強化學習,外加一個網(wǎng)頁聊天界面。
架構(gòu)層面塞了一堆現(xiàn)代設(shè)計:分組查詢注意力(Grouped-Query Attention)、RoPE位置編碼、無參數(shù)RMSNorm、ReLU平方激活、Value Embeddings、Smear/Backout token混合、每層可學習標量、QK L2歸一化,以及l(fā)ogit softcap。
最后這五個——Value Embeddings、Smear/Backout、層標量、QK L2歸一化、logit softcap——是標準GPT/LLaMA配方里沒有的。移植到JAX時,每個都埋了坑。
Logit softcap:一個tanh的小把戲
NanoChat在注意力分數(shù)進softmax之前加了一道保險:
scores = softcap * tanh(scores / softcap)
這會把分數(shù)幅度鉗制在[-cap, cap]區(qū)間,防止深層網(wǎng)絡(luò)里注意力分布過銳導(dǎo)致的梯度餓死。原版用15.0,我們JAX版用了30.0——這是刻意為之的差異,后文的對比表里有注明。
PyTorch實現(xiàn)很直白:matmul算分數(shù),條件判斷進softcap,masked_fill處理掩碼,最后softmax。JAX版要過XLA編譯,同樣的邏輯寫成函數(shù)式風格,調(diào)試時看不到中間值。
為什么選JAX:兩個硬需求
移植動機很具體。第一,原版沒有縮放定律工具——我們需要系統(tǒng)掃描模型尺寸、數(shù)據(jù)量、算力預(yù)算,擬合Chinchilla風格的冪律。第二,不想維護兩套設(shè)備代碼。
JAX的XLA后端解決第二個問題。編譯一次消除Python開銷,之后純機器碼執(zhí)行。同一套代碼扔給GPU或TPU,XLA自動-target對應(yīng)后端。
代價也明確:沒有vLLM,沒有Flash Attention 3,JIT編譯后的函數(shù)里打斷點等于做夢。
10分鐘訓完的nano模型
我們在單GPU上用TinyStories數(shù)據(jù)集訓了一個885K參數(shù)的nano模型,耗時不到10分鐘。訓完直接掛進流式聊天UI,能跑。
這個體量當然聊不出GPT-4的效果,但驗證了端到端 pipeline:數(shù)據(jù)加載→XLA編譯→訓練→推理→服務(wù),全部走通。
對于想做縮放定律研究的人來說,這意味著可以快速迭代假設(shè)——改超參、改架構(gòu)、改數(shù)據(jù)混合比例,重新編譯再訓,循環(huán)成本從"天"壓到"小時"。
XLA的爽與痛
爽點很實在。Python層面的overhead被編譯抹掉,批量操作自動融合,內(nèi)存布局由編譯器優(yōu)化。我們寫的Flax NNX模塊,TPU上零修改直接跑。
痛點同樣實在。調(diào)試JIT編譯的函數(shù),print被優(yōu)化掉,斷點進不去,只能依賴jax.debug.print這類特殊API。性能問題定位靠猜:是算子融合沒觸發(fā)?還是內(nèi)存帶寬瓶頸?
更麻煩的是生態(tài)缺口。vLLM的PagedAttention、Flash Attention 3的FP8 kernel,這些PyTorch生態(tài)里的高性能組件,JAX側(cè)要么沒有,要么要自己用Pallas手寫。我們這次沒做,所以推理效率打不過原版。
什么時候該考慮這條路線
如果你需要同時占住GPU和TPU兩個池子——比如Google TPU Research Cloud的免費算力+自購的GPU集群——JAX的"寫一次跑兩處"能省大量工程債。
如果你的研究依賴系統(tǒng)性的縮放實驗,需要頻繁掃描超參空間,XLA的編譯開銷攤薄后,迭代速度會反超PyTorch的動態(tài)圖。
但如果你追求SOTA推理效率,或者團隊里沒人愿意啃XLA的調(diào)試黑箱,PyTorch生態(tài)仍是更務(wù)實的選擇。
我們開源了12400行代碼,對比表里有所有刻意保留的差異。拿去跑你的縮放實驗,或者單純看看JAX移植一個現(xiàn)代LLM架構(gòu)要踩多少坑——兩種用法都歡迎。
特別聲明:以上內(nèi)容(如有圖片或視頻亦包括在內(nèi))為自媒體平臺“網(wǎng)易號”用戶上傳并發(fā)布,本平臺僅提供信息存儲服務(wù)。
Notice: The content above (including the pictures and videos if any) is uploaded and posted by a user of NetEase Hao, which is a social media platform and only provides information storage services.