技術 約9分で読めます

TRELLIS.2をApple SiliconのMPSで動かすCUDAフリー移植

いけさん目次

MicrosoftのTRELLIS.2(4Bパラメータの image-to-3D モデル)をApple SiliconのPyTorch MPSで動かす移植実装が公開された。CUDA専用ライブラリへの依存を段階的にpure-PyTorch代替に置き換え、M4 Proで約3.5分の動作を確認している。

元のTRELLIS.2はNVIDIA GPU前提の実装で、flash_attnnvdiffrast、スパース3D畳み込みなど複数のCUDA専用コンポーネントを使う。これらをどう置き換えたかが、今回の移植の核心だ。

TRELLIS.2とは

TRELLIS.2はMicrosoft Researchが開発した4Bパラメータのimage-to-3Dモデルで、単一画像から高品質な3Dアセットを生成できる。内部ではO-Voxel(Open Voxel)と呼ぶスパースボクセル表現を使い、通常のiso-surfaceフィールド(SDF、FlexiCubesなど)では扱いにくい開いた面・非多様体ジオメトリを処理できる。

生成する3DアセットはPBR(物理ベースレンダリング)マテリアル全属性(Base Color、Roughness、Metallic、Opacity)に対応しており、512³解像度で約3秒、1536³解像度でも60秒程度で生成できる(CUDA環境での公称値)。
モデル重みはHugging Faceに公開され、MITライセンスで商用利用も可能だ。

CUDAへの依存構造

元のコードベースには、純粋なPyTorchでは実装されていないCUDA専用コンポーネントが複数ある。

コンポーネント役割
flash_attnスパーストランスフォーマー向けアテンション計算
flex_gemmスパース3D畳み込み(行列積カーネル)
o_voxel._CCUDAハッシュマップ(ボクセル→メッシュ変換)
nvdiffrast微分可能ラスタライザー(テクスチャベイク)
cumeshメッシュ後処理(穴埋め・ポリゴン削減)

加えて、コードベース全体に tensor.cuda() を直接呼び出す箇所が散在していた。これはデバイスを動的に切り替える前提を持たない実装で、MPS環境では全て落ちる。

置き換えの実装

移植では各コンポーネントを個別に対応している。

flash_attn → PyTorch SDPA

FlashAttentionは通常のアテンション計算より省メモリ・高速なCUDA専用の実装で、pip install flash-attn でインストールできるが内部がCUDAカーネルなのでMPSでは動かない。

代替として torch.nn.functional.scaled_dot_product_attention(SDPA)を使う。PyTorch 2.0から導入されたAPIで、CUDA環境ではFlashAttention v2をバックエンドとして使い、その他の環境ではPyTorchネイティブな実装にフォールバックする。MPS環境ではこのネイティブフォールバックが走る。

TRELLIS.2のスパースアテンションモジュール(full_attn.py)に対して、可変長シーケンスをパディングしてバッチ化し、SDPAを実行後にアンパディングする処理を追加している。スパーストランスフォーマーは各サンプルのトークン数がバラバラになるため、このパディング処理が必要になる。

# full_attn.py(パッチ後のSDPAバックエンド追加部分)
def _sdpa_backend(q, k, v, ...):
    # 可変長シーケンスを最大長でパディングしてバッチ化
    q_padded = pad_sequence(q, batch_first=True)
    k_padded = pad_sequence(k, batch_first=True)
    v_padded = pad_sequence(v, batch_first=True)
    out = F.scaled_dot_product_attention(q_padded, k_padded, v_padded)
    # パディング分を除いて元のシーケンス長に戻す
    return unpad_sequence(out, lengths)

スパース3D畳み込み → gather-scatter

flex_gemm はスパースボクセルデータに対する行列積カーネルで、TRELLIS.2の特徴抽出部分を担う。スパース畳み込み(sparse convolution)の一種で、ボクセルグリッドの「中身がある点だけ」に対して計算を行う。

通常の密行列積と何が違うかというと、3D空間上で大半のボクセルは空(ゼロ)なので、ゼロ要素を全部計算するのは無駄だ。スパース畳み込みは活性ボクセルだけ拾って計算し、結果を書き戻す。

移植版(backends/conv_none.py)ではgather-scatter方式で実装する。

flowchart TD
    A[活性ボクセル座標<br/>のハッシュマップ構築] --> B[カーネル位置ごとに<br/>近傍ボクセルをgather]
    B --> C[収集した特徴量に<br/>重み行列を適用<br/>torch.mm]
    C --> D[結果を元の<br/>ボクセル座標に<br/>scatter-add]
    D --> E[近傍マップを<br/>テンソルごとにキャッシュ<br/>再計算回避]

CUDAカーネルのような並列度はないが、PyTorchの行列積(torch.mm)はMPSバックエンドで動く。公称値では純粋PyTorchスパース畳み込みはCUDAの flex_gemm に比べて約10倍遅いとされており、これが現状のボトルネックだ。

o_voxel._C ハッシュマップ → Pythonディクショナリ

o_voxel._C はCUDA上で実装されたハッシュマップで、O-Voxelのデュアルグリッドからメッシュを抽出するときに座標→インデックスの変換に使う。

移植版(backends/mesh_extract.py)では flexible_dual_grid_to_mesh をPythonのディクショナリで再実装する。各エッジの接続ボクセルを探して、法線アライメントヒューリスティックでクワッドを三角形化する処理をPythonループで置き換えている。GPUハッシュマップほどの速度は出ないが、動作の正確さは維持されている。

nvdiffrast・cumesh → スタブ化

nvdiffrast(微分可能ラスタライザー、テクスチャベイク用)と cumesh(穴埋め・ポリゴン削減)は現時点ではスタブとして扱い、呼び出し時にスキップする。

この結果、現在の移植版では以下の制限がある。

  • テクスチャ出力なし(頂点カラーのみ)
  • メッシュの穴埋めなし(小さな穴が残る場合がある)

テクスチャベイクはCUDA専用ラスタライザーに強く依存しており、MPS対応の代替実装は現時点では存在しない。

.cuda() 呼び出しのパッチ

コードベース全体に散在する .cuda() を動的デバイス参照に書き換えることも移植の一部だ。具体的には現在の推論デバイスを取得して .to(device) で渡す形に変更している。

# 変更前(CUDA固定)
tensor = tensor.cuda()

# 変更後(デバイス非依存)
tensor = tensor.to(device)  # device = torch.device('mps') or 'cpu' or 'cuda'

パフォーマンス

M4 Pro(24GB統合メモリ)でのベンチマーク。パイプラインタイプ 512

ステージ時間
モデルロード約45秒
画像前処理約5秒
スパース構造サンプリング約15秒
Shape SLat サンプリング約90秒
Texture SLat サンプリング約50秒
メッシュデコード約30秒
合計約3.5分

ピーク時のメモリ使用量は18GB前後。24GB統合メモリがあれば動作する。

出力は単一画像から400K以上の頂点を持つメッシュ(OBJ・GLB形式)で、PBRマテリアルに対応したファイルとして書き出される(テクスチャはなく頂点カラー)。

セットアップ

git clone https://github.com/shivampkumar/trellis-mac.git
cd trellis-mac

# HuggingFaceログイン(ゲートモデルのアクセスに必要)
hf auth login

# セットアップスクリプト(venv作成・依存インストール・TRELLIS.2パッチ適用)
bash setup.sh

source .venv/bin/activate

# 画像から3D生成
python generate.py photo.png

モデル重みは初回実行時にHugging Faceから自動ダウンロードされる(約15GB)。事前にDINOv3とRMBG-2.0のゲートモデルへのアクセス申請が必要(通常即時承認)。

パイプラインタイプは3種類あり、512(デフォルト)、10241024_cascade から選べる。

CUDAフリー移植という文脈

この移植が意味を持つのは、TRELLIS.2のような大型3Dモデルが「CUDA環境なしでは実質触れなかった」状態を変える点だ。

PyTorchのMPSバックエンドはここ2年で大きく進歩しており、ComfyUI上のQwen Image EditがMPS経路でBF16の制約に引っかかるような問題は残りつつも、主要な行列積・アテンション計算はMPSで動く段階になっている。M1〜M3でのBF16が遅い問題(ハードウェアのネイティブサポートがM4から)はMPS全般の制約だが、TRELLIS.2の推論はBF16依存ではないため影響は限定的だ。

WebAssembly+Metal のゼロコピー推論Flash-MoEの397Bモデル移植で見てきたように、Apple Silicon向けの推論最適化はLLM領域で手法が固まってきた。
TRELLIS.2の移植はその延長で、3D生成という比較的新しい領域に同じ動きが広がってきた形だ。

スパース3D畳み込みをpure-PyTorchのgather-scatterで代替する手法は汎用的で、他のスパース畳み込みを使うモデル(3Dオブジェクト検出系など)のMPS移植にも同じアプローチが使える。CUDAカーネル比で10倍遅いという制約はあるが、「NVIDIAがなければ動かない」から「遅くても動く」への変化は大きい。

ライセンス面では、移植コード自体はMITだが、DINOv3(Meta custom license)とRMBG-2.0(CC BY-NC 4.0)のモデルが商用利用に制限を持つ。
商用に回す場合は、このモデル側のライセンスを別途確認しておきたい。

AIで3Dを作ってきた文脈

TRELLIS.2を触る前に、このブログでもAIで3Dアセットを作る試みを何度か扱ってきた。今回のMPS移植が噛み合う場所を見渡しておくと、どこで詰まっていたかも見えてくる。

AI 3D生成ツール比較 2026年版では、Hyper3D Rodin、Hitem3D、Tripo AI、Hunyuan 3D、TRELLIS、Meshyなど主要サービスを並べ、それぞれの入力画像の仕様(解像度・三面図の要否・背景処理)を整理した。
TRELLIS本家は当時のランキングで7位につけていて、1536³の高解像度出力と4Bパラメータの大きさが売りだったが、動かすにはNVIDIA GPUが前提だった。今回のMPS移植はそこを崩しに行くピースだ。

Blender MCPのセットアップ画像から3Dモデル生成の実践複数画像からの精度向上実験では、Claude経由でBlenderに指示を出し、Hyper3D Rodinで画像から3Dアセットを起こすワークフローを試した。
Rodinはクラウド側で推論が走るサービスなので手元のGPUは関係ないが、ポリゴン数が23,332で固定だったり、複数画像を送っても形状の精度だけが上がる(頂点数は増えない)といった制約がある。ローカルで回せるTRELLIS.2系の移植が進めば、この「サービスの枠に合わせる」作法から一歩外に出られる。

動画側の流れも見ておくと、Meta AIのActionMeshは動画からアニメーション付きの.glbメッシュを直接吐く方向で、静止画→3Dとは別ルートで3Dアセット化を攻めている。
TRELLIS.2が静止画→高精細メッシュ、ActionMeshが動画→アニメ付きメッシュ、という棲み分けで、どちらも「Apple Silicon単機でどこまで動くか」がこの先の焦点になる。

さらに広い意味での3Dとしては、NVIDIA Cosmosの世界モデルのように、3Dアセットを作るのではなく3D空間の挙動そのものを予測する方向もある。今回の話とは層が違うが、「3Dを生成する」という言葉の守備範囲が、アセット生成からシミュレーションまで広がっていることは押さえておきたい。

こうして並べると、TRELLIS.2のMPS移植は単体のネタではなく、「ローカルで3Dアセットを作る」という流れの中で、NVIDIA依存という最後のハードルを削りにいった一手だ。


元リポジトリ: shivampkumar/trellis-mac

参考: