TRELLIS.2をApple SiliconのMPSで動かすCUDAフリー移植
目次
MicrosoftのTRELLIS.2(4Bパラメータの image-to-3D モデル)をApple SiliconのPyTorch MPSで動かす移植実装が公開された。CUDA専用ライブラリへの依存を段階的にpure-PyTorch代替に置き換え、M4 Proで約3.5分の動作を確認している。
元のTRELLIS.2はNVIDIA GPU前提の実装で、flash_attn、nvdiffrast、スパース3D畳み込みなど複数のCUDA専用コンポーネントを使う。これらをどう置き換えたかが、今回の移植の核心だ。
TRELLIS.2とは
TRELLIS.2はMicrosoft Researchが開発した4Bパラメータのimage-to-3Dモデルで、単一画像から高品質な3Dアセットを生成できる。内部ではO-Voxel(Open Voxel)と呼ぶスパースボクセル表現を使い、通常のiso-surfaceフィールド(SDF、FlexiCubesなど)では扱いにくい開いた面・非多様体ジオメトリを処理できる。
生成する3DアセットはPBR(物理ベースレンダリング)マテリアル全属性(Base Color、Roughness、Metallic、Opacity)に対応しており、512³解像度で約3秒、1536³解像度でも60秒程度で生成できる(CUDA環境での公称値)。
モデル重みはHugging Faceに公開され、MITライセンスで商用利用も可能だ。
CUDAへの依存構造
元のコードベースには、純粋なPyTorchでは実装されていないCUDA専用コンポーネントが複数ある。
| コンポーネント | 役割 |
|---|---|
flash_attn | スパーストランスフォーマー向けアテンション計算 |
flex_gemm | スパース3D畳み込み(行列積カーネル) |
o_voxel._C | CUDAハッシュマップ(ボクセル→メッシュ変換) |
nvdiffrast | 微分可能ラスタライザー(テクスチャベイク) |
cumesh | メッシュ後処理(穴埋め・ポリゴン削減) |
加えて、コードベース全体に tensor.cuda() を直接呼び出す箇所が散在していた。これはデバイスを動的に切り替える前提を持たない実装で、MPS環境では全て落ちる。
置き換えの実装
移植では各コンポーネントを個別に対応している。
flash_attn → PyTorch SDPA
FlashAttentionは通常のアテンション計算より省メモリ・高速なCUDA専用の実装で、pip install flash-attn でインストールできるが内部がCUDAカーネルなのでMPSでは動かない。
代替として torch.nn.functional.scaled_dot_product_attention(SDPA)を使う。PyTorch 2.0から導入されたAPIで、CUDA環境ではFlashAttention v2をバックエンドとして使い、その他の環境ではPyTorchネイティブな実装にフォールバックする。MPS環境ではこのネイティブフォールバックが走る。
TRELLIS.2のスパースアテンションモジュール(full_attn.py)に対して、可変長シーケンスをパディングしてバッチ化し、SDPAを実行後にアンパディングする処理を追加している。スパーストランスフォーマーは各サンプルのトークン数がバラバラになるため、このパディング処理が必要になる。
# full_attn.py(パッチ後のSDPAバックエンド追加部分)
def _sdpa_backend(q, k, v, ...):
# 可変長シーケンスを最大長でパディングしてバッチ化
q_padded = pad_sequence(q, batch_first=True)
k_padded = pad_sequence(k, batch_first=True)
v_padded = pad_sequence(v, batch_first=True)
out = F.scaled_dot_product_attention(q_padded, k_padded, v_padded)
# パディング分を除いて元のシーケンス長に戻す
return unpad_sequence(out, lengths)
スパース3D畳み込み → gather-scatter
flex_gemm はスパースボクセルデータに対する行列積カーネルで、TRELLIS.2の特徴抽出部分を担う。スパース畳み込み(sparse convolution)の一種で、ボクセルグリッドの「中身がある点だけ」に対して計算を行う。
通常の密行列積と何が違うかというと、3D空間上で大半のボクセルは空(ゼロ)なので、ゼロ要素を全部計算するのは無駄だ。スパース畳み込みは活性ボクセルだけ拾って計算し、結果を書き戻す。
移植版(backends/conv_none.py)ではgather-scatter方式で実装する。
flowchart TD
A[活性ボクセル座標<br/>のハッシュマップ構築] --> B[カーネル位置ごとに<br/>近傍ボクセルをgather]
B --> C[収集した特徴量に<br/>重み行列を適用<br/>torch.mm]
C --> D[結果を元の<br/>ボクセル座標に<br/>scatter-add]
D --> E[近傍マップを<br/>テンソルごとにキャッシュ<br/>再計算回避]
CUDAカーネルのような並列度はないが、PyTorchの行列積(torch.mm)はMPSバックエンドで動く。公称値では純粋PyTorchスパース畳み込みはCUDAの flex_gemm に比べて約10倍遅いとされており、これが現状のボトルネックだ。
o_voxel._C ハッシュマップ → Pythonディクショナリ
o_voxel._C はCUDA上で実装されたハッシュマップで、O-Voxelのデュアルグリッドからメッシュを抽出するときに座標→インデックスの変換に使う。
移植版(backends/mesh_extract.py)では flexible_dual_grid_to_mesh をPythonのディクショナリで再実装する。各エッジの接続ボクセルを探して、法線アライメントヒューリスティックでクワッドを三角形化する処理をPythonループで置き換えている。GPUハッシュマップほどの速度は出ないが、動作の正確さは維持されている。
nvdiffrast・cumesh → スタブ化
nvdiffrast(微分可能ラスタライザー、テクスチャベイク用)と cumesh(穴埋め・ポリゴン削減)は現時点ではスタブとして扱い、呼び出し時にスキップする。
この結果、現在の移植版では以下の制限がある。
- テクスチャ出力なし(頂点カラーのみ)
- メッシュの穴埋めなし(小さな穴が残る場合がある)
テクスチャベイクはCUDA専用ラスタライザーに強く依存しており、MPS対応の代替実装は現時点では存在しない。
.cuda() 呼び出しのパッチ
コードベース全体に散在する .cuda() を動的デバイス参照に書き換えることも移植の一部だ。具体的には現在の推論デバイスを取得して .to(device) で渡す形に変更している。
# 変更前(CUDA固定)
tensor = tensor.cuda()
# 変更後(デバイス非依存)
tensor = tensor.to(device) # device = torch.device('mps') or 'cpu' or 'cuda'
パフォーマンス
M4 Pro(24GB統合メモリ)でのベンチマーク。パイプラインタイプ 512。
| ステージ | 時間 |
|---|---|
| モデルロード | 約45秒 |
| 画像前処理 | 約5秒 |
| スパース構造サンプリング | 約15秒 |
| Shape SLat サンプリング | 約90秒 |
| Texture SLat サンプリング | 約50秒 |
| メッシュデコード | 約30秒 |
| 合計 | 約3.5分 |
ピーク時のメモリ使用量は18GB前後。24GB統合メモリがあれば動作する。
出力は単一画像から400K以上の頂点を持つメッシュ(OBJ・GLB形式)で、PBRマテリアルに対応したファイルとして書き出される(テクスチャはなく頂点カラー)。
セットアップ
git clone https://github.com/shivampkumar/trellis-mac.git
cd trellis-mac
# HuggingFaceログイン(ゲートモデルのアクセスに必要)
hf auth login
# セットアップスクリプト(venv作成・依存インストール・TRELLIS.2パッチ適用)
bash setup.sh
source .venv/bin/activate
# 画像から3D生成
python generate.py photo.png
モデル重みは初回実行時にHugging Faceから自動ダウンロードされる(約15GB)。事前にDINOv3とRMBG-2.0のゲートモデルへのアクセス申請が必要(通常即時承認)。
パイプラインタイプは3種類あり、512(デフォルト)、1024、1024_cascade から選べる。
CUDAフリー移植という文脈
この移植が意味を持つのは、TRELLIS.2のような大型3Dモデルが「CUDA環境なしでは実質触れなかった」状態を変える点だ。
PyTorchのMPSバックエンドはここ2年で大きく進歩しており、ComfyUI上のQwen Image EditがMPS経路でBF16の制約に引っかかるような問題は残りつつも、主要な行列積・アテンション計算はMPSで動く段階になっている。M1〜M3でのBF16が遅い問題(ハードウェアのネイティブサポートがM4から)はMPS全般の制約だが、TRELLIS.2の推論はBF16依存ではないため影響は限定的だ。
WebAssembly+Metal のゼロコピー推論やFlash-MoEの397Bモデル移植で見てきたように、Apple Silicon向けの推論最適化はLLM領域で手法が固まってきた。
TRELLIS.2の移植はその延長で、3D生成という比較的新しい領域に同じ動きが広がってきた形だ。
スパース3D畳み込みをpure-PyTorchのgather-scatterで代替する手法は汎用的で、他のスパース畳み込みを使うモデル(3Dオブジェクト検出系など)のMPS移植にも同じアプローチが使える。CUDAカーネル比で10倍遅いという制約はあるが、「NVIDIAがなければ動かない」から「遅くても動く」への変化は大きい。
ライセンス面では、移植コード自体はMITだが、DINOv3(Meta custom license)とRMBG-2.0(CC BY-NC 4.0)のモデルが商用利用に制限を持つ。
商用に回す場合は、このモデル側のライセンスを別途確認しておきたい。
AIで3Dを作ってきた文脈
TRELLIS.2を触る前に、このブログでもAIで3Dアセットを作る試みを何度か扱ってきた。今回のMPS移植が噛み合う場所を見渡しておくと、どこで詰まっていたかも見えてくる。
AI 3D生成ツール比較 2026年版では、Hyper3D Rodin、Hitem3D、Tripo AI、Hunyuan 3D、TRELLIS、Meshyなど主要サービスを並べ、それぞれの入力画像の仕様(解像度・三面図の要否・背景処理)を整理した。
TRELLIS本家は当時のランキングで7位につけていて、1536³の高解像度出力と4Bパラメータの大きさが売りだったが、動かすにはNVIDIA GPUが前提だった。今回のMPS移植はそこを崩しに行くピースだ。
Blender MCPのセットアップ、画像から3Dモデル生成の実践、複数画像からの精度向上実験では、Claude経由でBlenderに指示を出し、Hyper3D Rodinで画像から3Dアセットを起こすワークフローを試した。
Rodinはクラウド側で推論が走るサービスなので手元のGPUは関係ないが、ポリゴン数が23,332で固定だったり、複数画像を送っても形状の精度だけが上がる(頂点数は増えない)といった制約がある。ローカルで回せるTRELLIS.2系の移植が進めば、この「サービスの枠に合わせる」作法から一歩外に出られる。
動画側の流れも見ておくと、Meta AIのActionMeshは動画からアニメーション付きの.glbメッシュを直接吐く方向で、静止画→3Dとは別ルートで3Dアセット化を攻めている。
TRELLIS.2が静止画→高精細メッシュ、ActionMeshが動画→アニメ付きメッシュ、という棲み分けで、どちらも「Apple Silicon単機でどこまで動くか」がこの先の焦点になる。
さらに広い意味での3Dとしては、NVIDIA Cosmosの世界モデルのように、3Dアセットを作るのではなく3D空間の挙動そのものを予測する方向もある。今回の話とは層が違うが、「3Dを生成する」という言葉の守備範囲が、アセット生成からシミュレーションまで広がっていることは押さえておきたい。
こうして並べると、TRELLIS.2のMPS移植は単体のネタではなく、「ローカルで3Dアセットを作る」という流れの中で、NVIDIA依存という最後のハードルを削りにいった一手だ。
元リポジトリ: shivampkumar/trellis-mac
参考: