[BugFix][ROCm] Prefer upstream PyTorch DLPack API in torch extension loader (#585)
Conversation
Code Review
This pull request refactors the PyTorch backend detection logic into helper functions within the torch_c_dlpack_ext and tvm_ffi modules to better support ROCm alongside CUDA and CPU. It also adds an exception handler for OSError during extension loading to gracefully handle incompatible prebuilt libraries and introduces new tests for GPU tensor metadata. The review feedback highlights potential AttributeError risks when accessing version attributes directly on PyTorch builds missing specific backend support, recommending the use of getattr for safer detection.
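The fallback chain the summary describes can be sketched as follows. This is a minimal illustration, not the module's real API: `load_torch_c_dlpack_ext`, `load_prebuilt`, and `load_jit` are hypothetical names standing in for the loader logic.

```python
def load_torch_c_dlpack_ext(load_prebuilt, load_jit):
    # Hypothetical sketch: try the prebuilt extension first; a missing
    # wheel (ImportError) or an incompatibly linked wheel (OSError)
    # both fall through to the JIT-compiled path.
    try:
        return load_prebuilt()
    except ImportError:
        pass  # no prebuilt wheel installed; keep trying JIT
    except OSError:
        pass  # prebuilt wheel present but linked against the wrong backend
    return load_jit()
```

Catching `OSError` here is what makes an incompatible prebuilt wheel degrade gracefully instead of aborting the import.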
def _torch_extension_suffix() -> str:
    """Return the backend suffix used by the prebuilt extension library."""
    if torch.cuda.is_available():
        return "rocm" if torch.version.hip is not None else "cuda"
use the same code as
def _torch_extension_device(torch_module: Any) -> str:
    """Return the torch backend name used in the optional extension library name."""
    if torch_module.cuda.is_available():
        if torch_module.version.cuda is not None:
            return "cuda"
        if torch_module.version.hip is not None:
            return "rocm"
        raise ValueError("Cannot determine whether to build with CUDA or ROCm.")
    return "cpu"

    # Keep trying JIT
    pass
except OSError:
    # A prebuilt torch-c-dlpack-ext wheel can be present but linked
assuming we detect correctly, is this still necessary?
Mainly want to know whether this is a temporary measure we can delete after a few version cycles; if so, please add a comment.
Hi @tqchen. The scope of this PR has shifted from fixing ROCm-specific Torch C DLPack extension loading to making tvm-ffi avoid loading the optional out-of-tree torch-c-dlpack-ext when newer PyTorch ROCm builds already expose a working `__dlpack_c_exchange_api__`.
tqchen left a comment
thanks, looking good after fixing some minor nits
def _torch_extension_device(torch_module: Any) -> str:
    """Return the torch backend name used in the optional extension library name."""
    if torch_module.cuda.is_available():
        if getattr(torch_module.version, "hip", None) is not None:
It would be good to use the following pattern, which also detects the CUDA version, to be robust:
if torch_module.cuda.is_available():
    if getattr(torch_module.version, "cuda", None) is not None:
        return "cuda"
    if getattr(torch_module.version, "hip", None) is not None:
        return "rocm"
    return "cuda"
return "cpu"
def _torch_extension_device(torch_module: Any) -> str:
    """Return the torch backend name used in the optional extension library name."""
    if torch_module.cuda.is_available():
        if getattr(torch_module.version, "cuda", None) is not None:
            return "cuda"
        if getattr(torch_module.version, "hip", None) is not None:
            return "rocm"
        return "cuda"
    return "cpu"
Motivation:

- When PyTorch exposes `__dlpack_c_exchange_api__`, tvm-ffi uses it directly on all backends.
- The torch-c-dlpack-ext path is only used as a fallback for older PyTorch builds that do not provide the API.

Tests are added for backend detection, ROCm short-circuit behavior when PyTorch already provides the API, and GPU tensor metadata through DLPack.
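The short-circuit in the first bullet can be sketched as a simple capability probe. Only the attribute name comes from this PR; where it actually hangs (on `torch.Tensor` here) and the `needs_dlpack_ext` helper are assumptions for illustration.

```python
from types import SimpleNamespace

def needs_dlpack_ext(torch_module) -> bool:
    # Skip the optional extension entirely when PyTorch already ships
    # the C DLPack exchange API; only older builds fall back to it.
    return getattr(torch_module.Tensor, "__dlpack_c_exchange_api__", None) is None

# Stub modules mimicking a new and an old PyTorch build:
new_torch = SimpleNamespace(Tensor=SimpleNamespace(__dlpack_c_exchange_api__=object()))
old_torch = SimpleNamespace(Tensor=SimpleNamespace())

assert not needs_dlpack_ext(new_torch)  # API present: use it directly
assert needs_dlpack_ext(old_torch)      # API absent: fall back to the extension
```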
Related PR: tile-ai/tilelang#2179, I have finished the A/B test locally.