Running TRELLIS.2 on Apple Silicon MPS: a CUDA-free port
A port of Microsoft’s TRELLIS.2 (a 4B-parameter image-to-3D model) that runs on Apple Silicon’s PyTorch MPS backend has been published.
It replaces the CUDA-only dependencies with pure-PyTorch equivalents one by one, and confirms about 3.5 minutes of inference on an M4 Pro.
The original TRELLIS.2 assumes NVIDIA GPUs and uses several CUDA-specific components — flash_attn, nvdiffrast, sparse 3D convolutions, and more.
How each of those was swapped out is the heart of this port.
What TRELLIS.2 is
TRELLIS.2 is a 4B-parameter image-to-3D model from Microsoft Research that produces high-quality 3D assets from a single image.
Internally it uses a sparse voxel representation called O-Voxel (Open Voxel), which can handle open surfaces and non-manifold geometry that are awkward for typical iso-surface fields like SDFs or FlexiCubes.
Generated 3D assets cover the full PBR (Physically Based Rendering) material attribute set (Base Color, Roughness, Metallic, Opacity), and take around 3 seconds at 512³ resolution or about 60 seconds at 1536³ (nominal numbers on CUDA).
Model weights are on Hugging Face under an MIT license, so commercial use is fine on the code side.
The CUDA dependency surface
The original codebase pulls in multiple CUDA-only components that aren’t available in plain PyTorch.
| Component | Role |
|---|---|
| flash_attn | Attention for the sparse transformer |
| flex_gemm | Sparse 3D convolution (the matmul kernel) |
| o_voxel._C | CUDA hash map (voxel → mesh conversion) |
| nvdiffrast | Differentiable rasterizer (texture baking) |
| cumesh | Mesh post-processing (hole filling, decimation) |
On top of that, direct tensor.cuda() calls are scattered throughout the codebase.
The code assumes the device never changes, so everything breaks under MPS.
How each piece is replaced
The port handles each component individually.
flash_attn → PyTorch SDPA
FlashAttention is a memory-efficient attention implementation that is faster than vanilla attention. It installs with pip install flash-attn, but its internals are CUDA kernels, so it doesn't run on MPS.
The replacement is torch.nn.functional.scaled_dot_product_attention (SDPA).
It’s an API added in PyTorch 2.0 that uses FlashAttention v2 as a backend on CUDA and falls back to a native PyTorch implementation elsewhere — including MPS.
For TRELLIS.2’s sparse attention module (full_attn.py), the port adds a code path that pads variable-length sequences to batch them, runs SDPA, then un-pads the result.
Sparse transformers have per-sample token counts that vary, so this pad/unpad step is required.
# full_attn.py (added SDPA backend after patching)
import torch
import torch.nn.functional as F
from torch.nn.utils.rnn import pad_sequence, unpad_sequence

def _sdpa_backend(q, k, v):
    # q, k, v: lists of per-sample (seq_len, dim) tensors
    lengths = torch.tensor([t.shape[0] for t in q])
    # Batch variable-length sequences by padding to the max length
    q_padded = pad_sequence(q, batch_first=True)
    k_padded = pad_sequence(k, batch_first=True)
    v_padded = pad_sequence(v, batch_first=True)
    # (A key-padding mask would be needed for exact results; omitted here)
    out = F.scaled_dot_product_attention(q_padded, k_padded, v_padded)
    # Strip padding back to the original sequence lengths
    return unpad_sequence(out, lengths, batch_first=True)
Sparse 3D convolution → gather-scatter
flex_gemm is a matmul kernel over sparse voxel data and handles TRELLIS.2’s feature extraction.
It’s a sparse convolution: instead of operating on the full grid, it only computes on voxels that actually contain something.
The difference from dense matmul: most voxels in a 3D grid are empty (zero), so computing every zero element is wasted work.
Sparse convolution gathers only active voxels, computes, and writes back.
The port (backends/conv_none.py) implements this with a gather-scatter approach.
flowchart TD
A[Build a hash map<br/>of active voxel coords] --> B[For each kernel position,<br/>gather neighbor voxels]
B --> C[Apply weight matrix<br/>to gathered features<br/>torch.mm]
C --> D[scatter-add results<br/>back to original<br/>voxel coordinates]
D --> E[Cache neighbor maps<br/>per tensor<br/>to avoid recomputation]
You don’t get the parallelism of a CUDA kernel, but PyTorch’s matmul (torch.mm) runs on MPS.
The reported numbers put pure-PyTorch sparse convolution at roughly 10× slower than CUDA’s flex_gemm, and that’s the current bottleneck.
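The gather-scatter idea above can be sketched in pure PyTorch. This is a minimal illustration of the technique, not the port's actual conv_none.py; the function name, the per-offset weight layout, and the Python-dict hash map are assumptions made for clarity.

```python
import torch

def sparse_conv3d_gather_scatter(coords, feats, weight, offsets):
    """Gather-scatter sparse 3D convolution sketch.

    coords:  (N, 3) int tensor of active voxel coordinates
    feats:   (N, C_in) features for those voxels
    weight:  (K, C_in, C_out) one weight matrix per kernel offset
    offsets: (K, 3) int tensor of kernel offsets (e.g. the 27 neighbors)
    """
    # Hash map from voxel coordinate to row index
    index = {tuple(c.tolist()): i for i, c in enumerate(coords)}
    out = torch.zeros(coords.shape[0], weight.shape[2])
    for k, off in enumerate(offsets):
        # Gather: find active voxels that have a neighbor at this offset
        src, dst = [], []
        for i, c in enumerate(coords):
            j = index.get(tuple((c + off).tolist()))
            if j is not None:
                src.append(j)
                dst.append(i)
        if not src:
            continue
        gathered = feats[torch.tensor(src)]
        # Compute: dense matmul on the gathered features (runs on MPS)
        contrib = torch.mm(gathered, weight[k])
        # Scatter-add the results back to the output voxel rows
        out.index_add_(0, torch.tensor(dst), contrib)
    return out
```

Only active voxels ever touch the matmul, so the cost scales with occupancy rather than grid volume, at the price of Python-side neighbor lookups that a CUDA kernel would fuse away.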
o_voxel._C hash map → Python dictionary
o_voxel._C is a CUDA-side hash map used for coordinate→index lookups when extracting a mesh from O-Voxel’s dual grid.
The port (backends/mesh_extract.py) reimplements flexible_dual_grid_to_mesh with a Python dictionary.
It replaces the “find the connected voxels for each edge and triangulate quads using a normal-alignment heuristic” logic with a Python loop.
It doesn’t match GPU hash map speed, but correctness is preserved.
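The dictionary replacement can be made reasonably fast by packing each (x, y, z) coordinate into a single integer key. This is a minimal sketch of that idea, not the port's mesh_extract.py; the 10-bits-per-axis layout (grids up to 1024³, non-negative coordinates) is an assumption.

```python
import torch

def build_coord_index(coords):
    """Pack (x, y, z) into one int key and build a Python dict.

    A stand-in for the CUDA hash map in o_voxel._C. Assumes
    non-negative coordinates and 10 bits per axis (up to 1024^3).
    """
    keys = (coords[:, 0].to(torch.int64) << 20) \
         | (coords[:, 1].to(torch.int64) << 10) \
         |  coords[:, 2].to(torch.int64)
    return {int(k): i for i, k in enumerate(keys)}

def lookup(index, x, y, z):
    # Row index for a voxel coordinate, or None if the voxel is inactive
    return index.get((x << 20) | (y << 10) | z)
```

Packed integer keys avoid tuple hashing overhead in the inner loop, which matters when the mesh extraction walks every dual-grid edge.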
nvdiffrast / cumesh → stubbed out
nvdiffrast (the differentiable rasterizer used for texture baking) and cumesh (hole filling and decimation) are stubbed out for now — the calls no-op.
The current port therefore has these limitations.
- No texture output (vertex colors only)
- No mesh hole filling (small holes may remain)
Texture baking is tightly coupled to the CUDA-only rasterizer, and no MPS-compatible substitute exists today.
Patching .cuda() calls
Part of the port is rewriting the .cuda() calls scattered through the codebase to use dynamic device references.
Concretely, it fetches the current inference device and hands it off via .to(device).
# Before (CUDA-pinned)
tensor = tensor.cuda()
# After (device-agnostic)
tensor = tensor.to(device) # device = torch.device('mps') or 'cpu' or 'cuda'
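A device-selection helper along these lines is the usual way to feed that `.to(device)` call; this is a generic sketch, not code from the port.

```python
import torch

def get_device():
    # Prefer CUDA, then MPS, then fall back to CPU
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = get_device()
tensor = torch.randn(4, 4).to(device)  # works on any backend
```

Resolving the device once at startup and threading it through keeps every call site backend-agnostic.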
Performance
Benchmarks on an M4 Pro (24GB unified memory), pipeline type 512.
| Stage | Time |
|---|---|
| Model load | ~45s |
| Image preprocessing | ~5s |
| Sparse structure sampling | ~15s |
| Shape SLat sampling | ~90s |
| Texture SLat sampling | ~50s |
| Mesh decode | ~30s |
| Total | ~3.5min |
Peak memory sits around 18GB. 24GB of unified memory is enough to run it.
Output from a single image is a mesh with 400K+ vertices, exported as OBJ and GLB with PBR material attributes carried as vertex colors (no textures).
Setup
git clone https://github.com/shivampkumar/trellis-mac.git
cd trellis-mac
# Hugging Face login (required for gated model access)
hf auth login
# Setup script (creates venv, installs deps, applies TRELLIS.2 patches)
bash setup.sh
source .venv/bin/activate
# Generate 3D from an image
python generate.py photo.png
Model weights download from Hugging Face automatically on first run (~15GB).
You’ll need to pre-approve access to the DINOv3 and RMBG-2.0 gated models (usually granted instantly).
Three pipeline types are available: 512 (default), 1024, and 1024_cascade.
Where a CUDA-free port fits
What makes this port interesting is that it changes the “you basically can’t touch TRELLIS.2 without a CUDA environment” situation for large 3D models.
PyTorch’s MPS backend has progressed a lot over the past two years. Issues like ComfyUI’s Qwen Image Edit hitting BF16 limitations on the MPS path still exist, but the major matmul and attention paths now run on MPS.
The slow BF16 situation on M1–M3 (native hardware support starts at M4) is a general MPS constraint, but TRELLIS.2’s inference doesn’t depend on BF16, so the impact is limited.
As seen in zero-copy inference with WebAssembly + Metal and the Flash-MoE port of a 397B-parameter model, the techniques for Apple Silicon inference optimization have settled down in the LLM space.
This TRELLIS.2 port is the extension of that into 3D generation, a comparatively newer area.
The gather-scatter approach for replacing sparse 3D convolution with pure PyTorch is general-purpose, and the same technique can port other sparse-convolution models (like some 3D object detectors) to MPS.
Even with the 10× slowdown against CUDA kernels, going from “doesn’t run without NVIDIA” to “runs, even if slowly” is a big shift.
On the license side, the port code itself is MIT, but the DINOv3 model (Meta custom license) and RMBG-2.0 (CC BY-NC 4.0) come with commercial-use restrictions.
If you’re putting this into commercial use, the model-side licensing is worth checking separately.
The AI-generated 3D context so far
This blog has touched on AI-generated 3D assets a few times before TRELLIS.2, and it’s worth looking at where this MPS port lines up with that history — it also makes the past sticking points visible.
In the 2026 comparison of AI 3D generation tools, I surveyed the major services — Hyper3D Rodin, Hitem3D, Tripo AI, Hunyuan 3D, TRELLIS, Meshy — and sorted out what each one expects for input images (resolution, whether three-view images are needed, background handling).
TRELLIS proper sat at 7th in that ranking, with 1536³ high-resolution output and 4B parameters as the main sell, but running it required NVIDIA GPUs.
This MPS port is the piece that chips away at that requirement.
In setting up Blender MCP, practical image-to-3D model generation, and an experiment on accuracy with multiple input images, I drove Blender via Claude and tried a workflow where Hyper3D Rodin turns images into 3D assets.
Rodin runs inference in the cloud, so the local GPU isn’t involved, but it has constraints — 23,332 polygons fixed, and even with multiple images only the shape precision improves (vertex count doesn’t go up).
If locally runnable TRELLIS.2-style ports mature, you can step out of that “fit whatever the service ships” pattern.
Looking at the video side, Meta AI’s ActionMesh goes from video directly to an animated .glb mesh, attacking 3D asset creation through a different route from the still-image → 3D path.
TRELLIS.2 is still-image → high-detail mesh, ActionMesh is video → animated mesh, and for both, “how far can a single Apple Silicon machine push this” becomes the next focus.
At a broader level of “3D,” NVIDIA Cosmos’s world model goes in the direction of predicting the behavior of 3D space itself, not generating assets.
That’s a different layer from the story here, but it’s worth noting that “generating 3D” now spans from asset creation to simulation.
Lined up this way, the TRELLIS.2 MPS port isn’t an isolated topic.
It’s a single move inside the broader trajectory of “making 3D assets locally” that shaves off the NVIDIA dependency — the last remaining hurdle.
Source repository: shivampkumar/trellis-mac