
[BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader#585

Merged
tqchen merged 4 commits into apache:main from zihaomu:zihao/tvmffi-001-rocm-torch-dlpack-ext on May 13, 2026

Conversation

@zihaomu
Contributor

@zihaomu zihaomu commented May 12, 2026

Motivation:

  • If PyTorch already provides __dlpack_c_exchange_api__, tvm-ffi uses it directly on all backends.
  • The optional torch-c-dlpack-ext path is only used as a fallback for older PyTorch builds that do not provide the API.
  • The fallback extension library name is selected with backend-aware detection for CPU, CUDA, and ROCm.

Tests are added for backend detection, ROCm short-circuit behavior when PyTorch already provides the API, and GPU tensor metadata through DLPack.

Related PR: tile-ai/tilelang#2179. I have finished the A/B test locally.

Contributor

@gemini-code-assist gemini-code-assist Bot left a comment


Code Review

This pull request refactors the PyTorch backend detection logic into helper functions within the torch_c_dlpack_ext and tvm_ffi modules to better support ROCm alongside CUDA and CPU. It also adds an exception handler for OSError during extension loading to gracefully handle incompatible prebuilt libraries and introduces new tests for GPU tensor metadata. The review feedback highlights potential AttributeError risks when accessing version attributes directly on PyTorch builds missing specific backend support, recommending the use of getattr for safer detection.

Comment thread addons/torch_c_dlpack_ext/torch_c_dlpack_ext/core.py Outdated
Comment thread python/tvm_ffi/_optional_torch_c_dlpack.py
@zihaomu zihaomu marked this pull request as ready for review May 12, 2026 07:57
def _torch_extension_suffix() -> str:
    """Return the backend suffix used by the prebuilt extension library."""
    if torch.cuda.is_available():
        return "rocm" if torch.version.hip is not None else "cuda"
Member


use the same code as

def _torch_extension_device(torch_module: Any) -> str:
    """Return the torch backend name used in the optional extension library name."""
    if torch_module.cuda.is_available():
        if torch_module.version.cuda is not None:
            return "cuda"
        if torch_module.version.hip is not None:
            return "rocm"
        raise ValueError("Cannot determine whether to build with CUDA or ROCm.")
    return "cpu" 

        # Keep trying JIT
        pass
    except OSError:
        # A prebuilt torch-c-dlpack-ext wheel can be present but linked
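The intent of this hunk — tolerate an incompatible prebuilt library and keep trying the JIT build — can be sketched in isolation. The helper name and the use of `ctypes` here are assumptions for illustration, not the actual loader code:

```python
import ctypes

def try_load_prebuilt(lib_path: str):
    """Sketch: return the loaded library handle, or None to fall back to JIT."""
    try:
        return ctypes.CDLL(lib_path)
    except OSError:
        # A prebuilt torch-c-dlpack-ext wheel can be present but linked
        # against an incompatible backend; signal the caller to keep
        # trying the JIT build instead of crashing at import time.
        return None

print(try_load_prebuilt("/nonexistent/libtorch_c_dlpack.so"))  # None
```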
Member


assuming we detect correctly, is this still necessary?

Member


Mainly want to know whether this is a temporary measure that we can delete after a few version cycles; if so, please add a comment.


Comment thread python/tvm_ffi/_optional_torch_c_dlpack.py Outdated
@zihaomu zihaomu changed the title from "[BugFix] ROCm torch C DLPack extension loading" to "[BugFix] Prefer upstream PyTorch DLPack API in torch extension loader" May 12, 2026
@zihaomu
Contributor Author

zihaomu commented May 12, 2026

Hi @tqchen. The scope of this PR has shifted from fixing ROCm-specific Torch C DLPack extension loading to making tvm-ffi avoid loading the optional out-of-tree torch-c-dlpack-ext package when upstream PyTorch already provides the C DLPack exchange API.

If newer PyTorch ROCm builds already expose a working __dlpack_c_exchange_api__, tvm-ffi should trust PyTorch and return early instead of forcing a local extension override. The local extension remains only as a compatibility fallback for older PyTorch versions.

@zihaomu zihaomu changed the title from "[BugFix] Prefer upstream PyTorch DLPack API in torch extension loader" to "[BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader" May 12, 2026
Member

@tqchen tqchen left a comment


thanks, looking good after fixing some minor nits

def _torch_extension_device(torch_module: Any) -> str:
    """Return the torch backend name used in the optional extension library name."""
    if torch_module.cuda.is_available():
        if getattr(torch_module.version, "hip", None) is not None:
Member


It would be good to use the following pattern, mainly to also detect the CUDA version to be robust:

if torch_module.cuda.is_available():
    if getattr(torch_module.version, "cuda", None) is not None:
        return "cuda"
    if getattr(torch_module.version, "hip", None) is not None:
        return "rocm"
    return "cuda"
return "cpu"

Contributor Author


Done




@tqchen tqchen merged commit 4dbc260 into apache:main May 13, 2026
9 checks passed
