
InstanceNorm3d crashes on CoreML GPU path but works on CPU path #2667

@0xShug0

Description


🐞Describing the bug

Found by opdiff.

For torch.nn.InstanceNorm3d(256, affine=False, track_running_stats=False), the CPU path works while the GPU-backed path crashes. This appears to hold regardless of precision.

Using FP32 as a concrete example:

  • ct.ComputeUnit.CPU_ONLY: works and matches PyTorch closely
  • ct.ComputeUnit.CPU_AND_GPU: crashes the process
  • ct.ComputeUnit.ALL: also crashes the process

This seems specific to the GPU/MPS-backed execution path for this 5D InstanceNorm3d case.

Stack Trace

The child process aborts with return code -6 (SIGABRT), and stderr shows:

/AppleInternal/Library/BuildRoots/4~CJ7zugCWBHgyY5jF33Pox4qXqx4MT4guvYcGoM4/Library/Caches/com.apple.xbs/TemporaryDirectory.oY3GWR/Sources/MetalPerformanceShaders/MPSNDArray/Kernels/MPSNDArrayStitchedReduction.mm:832: failed assertion `MPSNDArrayStitchedReductionRMSNorm 0x85ce05500 "(null)" Axis = 4. This class only supports axis = 0, 1, 2 or 3.
'

To Reproduce

Minimal self-contained repro script:

coreml_instancenorm3d_gpu_crash.py

import subprocess
import sys

import numpy as np
import torch
import torch.nn as nn
import coremltools as ct


class WrappedModule(nn.Module):
    def __init__(self, mod):
        super().__init__()
        self.mod = mod

    def forward(self, x):
        x = x.clone()
        return self.mod(x)


def run_child(mode: str):
    compute_units = {
        "cpu": ct.ComputeUnit.CPU_ONLY,
        "gpu": ct.ComputeUnit.CPU_AND_GPU,
        "any": ct.ComputeUnit.ALL,
    }[mode]

    torch.manual_seed(0)
    x = torch.randn(1, 256, 32, 28, 28, dtype=torch.float32)

    model = WrappedModule(
        nn.InstanceNorm3d(
            256,
            affine=False,
            track_running_stats=False,
        ).eval()
    ).eval()

    with torch.no_grad():
        y_pt = model(x).detach().cpu().numpy()

    exported = torch.export.export(model, (x,))
    exported = exported.run_decompositions()

    mlmodel = ct.convert(
        exported,
        inputs=[ct.TensorType(shape=x.shape)],
        convert_to="mlprogram",
        minimum_deployment_target=ct.target.iOS18,
        compute_units=compute_units,
        compute_precision=ct.precision.FLOAT32,
    )

    # Read input/output feature names from the model spec (public API,
    # rather than the private _fd_spec field).
    spec = mlmodel.get_spec()
    inp_name = spec.description.input[0].name
    out_name = spec.description.output[0].name
    y_coreml = mlmodel.predict({inp_name: x.detach().cpu().numpy()})[out_name]

    abs_diff = np.abs(y_pt - y_coreml)
    print("max_abs_diff =", float(abs_diff.max()))
    print("mean_abs_diff =", float(abs_diff.mean()))


def run_parent():
    for mode in ["cpu", "gpu", "any"]:
        print(f"=== {mode} ===")
        proc = subprocess.run(
            [sys.executable, __file__, "--child", mode],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            text=True,
        )
        print("returncode =", proc.returncode)
        print(proc.stdout)
        print(proc.stderr)


def main():
    if len(sys.argv) == 3 and sys.argv[1] == "--child":
        run_child(sys.argv[2])
    else:
        run_parent()


if __name__ == "__main__":
    main()

For numerical comparison, the script compares CoreML output against eager PyTorch:

with torch.no_grad():
    y_pt = model(x).detach().cpu().numpy()

y_coreml = mlmodel.predict({inp_name: x.detach().cpu().numpy()})[out_name]
abs_diff = np.abs(y_pt - y_coreml)
print("max_abs_diff =", float(abs_diff.max()))
print("mean_abs_diff =", float(abs_diff.mean()))

Observed behavior on my side for the FP32 example above:

  • CPU_ONLY: succeeds with max_abs_diff = 1.430511474609375e-06
  • CPU_AND_GPU: process aborts with return code -6
  • ALL: process aborts with return code -6

More generally, I’m also seeing the same CPU-vs-GPU split across precisions: CPU works, while GPU-backed execution fails.
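A possible mitigation, sketched below but not yet verified against the GPU path: decompose the normalization into explicit mean/variance reductions before export, so the converter has no fused instance-norm op to lower onto the MPS reduction kernel that asserts. The `DecomposedInstanceNorm3d` name is mine, not from any library:

```python
import torch
import torch.nn as nn


class DecomposedInstanceNorm3d(nn.Module):
    """Hypothetical workaround: InstanceNorm3d(affine=False) written as
    explicit reductions over the spatial axes, instead of the fused op.
    Unverified whether the converted mlprogram avoids the crashing kernel.
    """

    def __init__(self, eps: float = 1e-5):
        super().__init__()
        self.eps = eps

    def forward(self, x):
        # Normalize each (n, c) slice over D, H, W; biased variance to
        # match nn.InstanceNorm3d's eager behavior.
        mean = x.mean(dim=(2, 3, 4), keepdim=True)
        var = x.var(dim=(2, 3, 4), keepdim=True, unbiased=False)
        return (x - mean) / torch.sqrt(var + self.eps)


if __name__ == "__main__":
    torch.manual_seed(0)
    x = torch.randn(1, 8, 4, 6, 6)
    ref = nn.InstanceNorm3d(8, affine=False).eval()(x)
    out = DecomposedInstanceNorm3d()(x)
    print("max diff vs eager InstanceNorm3d:", (out - ref).abs().max().item())
```

In eager mode this matches `nn.InstanceNorm3d(affine=False)` up to floating-point tolerance; whether the converted model actually takes a different GPU code path would need to be confirmed.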

System environment (please complete the following information):

  • coremltools version: 9.0
  • OS (e.g. MacOS version or Linux type): macOS on Apple Silicon
  • Any other relevant version information (e.g. PyTorch or TensorFlow version): PyTorch 2.10.0, Python 3.11

Additional context

The model converts successfully before the crash, and the CPU-only path runs normally, so this looks more like a GPU/MPS execution issue than a general conversion failure.

The MPS assertion mentions:

Axis = 4. This class only supports axis = 0, 1, 2 or 3.

which may be relevant since the repro uses a 5D InstanceNorm3d input of shape [1, 256, 32, 28, 28].
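As a sanity check on what the op computes (plain NumPy reference of my own, not converter internals): InstanceNorm3d with affine=False normalizes each (n, c) slice over the spatial axes (2, 3, 4) of the 5D input, so axis 4 is genuinely part of the reduction:

```python
import numpy as np


def instance_norm_3d(x: np.ndarray, eps: float = 1e-5) -> np.ndarray:
    """Reference InstanceNorm3d (affine=False) for a [N, C, D, H, W] input.

    The reduction runs over axes (2, 3, 4) -- axis 4 is exactly the axis
    the MPS assertion says the fused reduction class cannot handle.
    """
    mean = x.mean(axis=(2, 3, 4), keepdims=True)
    var = x.var(axis=(2, 3, 4), keepdims=True)  # biased, as in PyTorch eager
    return (x - mean) / np.sqrt(var + eps)


if __name__ == "__main__":
    x = np.random.default_rng(0).standard_normal((1, 4, 3, 5, 5)).astype(np.float32)
    y = instance_norm_3d(x)
    # Each (n, c) slice comes out approximately zero-mean / unit-variance.
    print("max |mean| per slice:", float(np.abs(y.mean(axis=(2, 3, 4))).max()))
```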
