
[Feature] Support AMD HIP for cpp extension #460

Merged
MasterJH5574 merged 6 commits into apache:main from DarkSharpness:feat_hip
Feb 18, 2026

Conversation

@DarkSharpness
Contributor

Related issue: #458.

I'm not familiar with AMD at all, and most of the code was generated by Claude Code. I've only cleaned it up a little and tried the following example on one AMD machine. It needs some review from AMD experts.

import torch
from tvm_ffi import Module
import tvm_ffi.cpp

# define the cpp source code
cpp_source = '''
#include <hip/hip_runtime.h>

__global__ void add_one_kernel(const float* __restrict__ x,
                               float* __restrict__ y,
                               int64_t n) {
    int64_t i = (int64_t)blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n) y[i] = x[i] + 1.0f;
}

void add_one_hip(tvm::ffi::TensorView x, tvm::ffi::TensorView y) {
    int64_t n = x.size(0);
    const float* x_ptr = static_cast<const float*>(x.data_ptr());
    float* y_ptr = static_cast<float*>(y.data_ptr());
    constexpr int threads = 256;
    int blocks = (int)((n + threads - 1) / threads);
    hipStream_t stream = 0;  // default stream; replace if your runtime provides a stream
    hipLaunchKernelGGL(add_one_kernel,
                       dim3(blocks), dim3(threads),
                       0, stream,
                       x_ptr, y_ptr, n);
}
'''

# compile the cpp source code and load the module
mod: Module = tvm_ffi.cpp.load_inline(
    name="hello",
    cuda_sources=cpp_source,
    functions="add_one_hip",
)

# use the function from the loaded module to perform the computation
x = torch.tensor([1, 2, 3, 4, 5], dtype=torch.float32, device="cuda")
y = torch.empty_like(x, device="cuda")
mod.add_one_hip(x, y)
torch.testing.assert_close(x + 1, y)
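
Note: on ROCm builds of PyTorch, the "cuda" device string maps to AMD GPUs, so the tensors above are allocated on the HIP device even though the example says device="cuda".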

@gemini-code-assist
Contributor

Summary of Changes

Hello @DarkSharpness, I'm Gemini Code Assist! I'm currently reviewing this pull request and will post my feedback shortly. In the meantime, here's a summary to help you and other reviewers quickly get up to speed!

This pull request introduces significant enhancements by integrating support for AMD HIP (ROCm) into the C++ extension compilation system. The primary goal is to enable users to compile and run GPU-accelerated code on AMD hardware, alongside existing NVIDIA CUDA support. The changes include robust backend detection, dynamic configuration of build tools and flags for both CUDA and HIP, and an extended API for explicit backend selection, making the system more versatile for different GPU environments.

Highlights

  • AMD HIP (ROCm) Support: Added comprehensive support for compiling C++/CUDA extensions with AMD HIP (ROCm), allowing users to leverage AMD GPUs for their computations.
  • Automatic GPU Backend Detection: Implemented an auto-detection mechanism that identifies whether CUDA or HIP is available on the system, defaulting to HIP if ROCm is found, otherwise CUDA. This can be overridden via the TVM_FFI_GPU_BACKEND environment variable.
  • ROCm Environment and Architecture Discovery: Introduced utility functions (_find_rocm_home, _get_rocm_target) to locate the ROCm installation path and automatically determine the target GPU architecture flags (e.g., --offload-arch=gfxXXXX) for compilation.
  • Unified Build System for GPU Backends: Refactored the internal build generation logic (_generate_ninja_build, _build_impl) to dynamically configure compilation flags, include paths, and linker options based on the selected GPU backend (CUDA or HIP), including using hipcc for HIP compilation.
  • API Extension for Backend Selection: Extended public API functions (build_inline, load_inline, build, load) with an optional backend parameter, allowing users to explicitly specify 'cuda' or 'hip' for compilation (see the usage sketch after this list).
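
As a usage sketch of the new backend parameter (reusing cpp_source from the author's example above; the snippet is illustrative and not taken from the PR diff):

import tvm_ffi.cpp

# Explicitly request the HIP toolchain instead of relying on auto-detection.
mod = tvm_ffi.cpp.load_inline(
    name="hello_hip",
    cuda_sources=cpp_source,  # HIP code is still passed via the cuda_sources argument
    functions="add_one_hip",
    backend="hip",            # 'cuda' or 'hip'; omit to auto-detect
)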


Changelog
  • python/tvm_ffi/cpp/extension.py
    • Added BACKEND_STR type alias and _detect_gpu_backend and _resolve_gpu_backend functions for automatic and explicit GPU backend selection (a detection sketch follows after this list).
    • Implemented _find_rocm_home to locate ROCm installations and _get_rocm_target to determine ROCm GPU architecture flags.
    • Modified _generate_ninja_build to accept a backend parameter, dynamically setting compilation flags and tools (hipcc or nvcc) based on the chosen GPU backend.
    • Updated _build_impl to incorporate the new backend parameter and use the resolved backend for build generation.
    • Extended build_inline, load_inline, build, and load functions with an optional backend parameter and updated their docstrings to reflect this new capability.
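
A minimal sketch of the auto-detection flow described above, assuming _find_rocm_home returns the ROCm install path or None; the exact environment-variable values accepted by the PR are an assumption:

import os

def _detect_gpu_backend() -> str:
    # Sketch only: honor an explicit override first, then prefer HIP when a
    # ROCm installation is found, otherwise fall back to CUDA.
    override = os.environ.get("TVM_FFI_GPU_BACKEND")
    if override in ("cuda", "hip"):
        return override
    return "hip" if _find_rocm_home() is not None else "cuda"
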
Activity
  • The author, DarkSharpness, initiated this feature to support AMD HIP for cpp extensions, addressing related issue [Feature Request] Support ROCm platform for tvm_ffi.cpp #458.
  • The initial implementation was largely generated by an AI tool (Claude code), and the author has performed some cleanup.
  • The author has tested the changes on an AMD machine with a provided example and is seeking reviews from AMD experts due to their limited familiarity with the AMD ecosystem.

@gemini-code-assist bot left a comment

Code Review

This pull request adds support for AMD's HIP backend to the C++ extension, allowing users to compile and run HIP code. The changes include auto-detection of the GPU backend (CUDA or HIP), logic to find the ROCm installation and target architecture, and adjustments to the build process to use hipcc and appropriate flags.

My review focuses on improving the maintainability and clarity of the new code. I've pointed out some areas with code duplication that could be refactored. A significant point of feedback is the confusing use of cuda_* naming for parameters and variables that now also handle HIP code; I've suggested renaming them to more generic gpu_* names to improve API clarity. I also found a few incomplete docstrings that should be fixed.

Overall, this is a great feature addition. The feedback provided should help make the code more robust and easier to understand for future contributors.

Comment on lines +234 to +265
try:
    agent_enum = str(Path(_find_rocm_home()) / "bin" / "rocm_agent_enumerator")
    if not Path(agent_enum).exists():
        agent_enum = "rocm_agent_enumerator"
    status = subprocess.run(args=[agent_enum], capture_output=True, check=True, text=True)
    archs = list(
        dict.fromkeys(
            line.strip()
            for line in status.stdout.strip().split("\n")
            if line.strip() and line.strip() != "gfx000"
        )
    )
    if archs:
        return [f"--offload-arch={arch}" for arch in archs]
except (subprocess.CalledProcessError, FileNotFoundError):
    pass
# Try rocminfo
try:
    status = subprocess.run(args=["rocminfo"], capture_output=True, check=True, text=True)
    archs = list(
        dict.fromkeys(
            line.split(":")[-1].strip()
            for line in status.stdout.split("\n")
            if "Name:" in line
            and "gfx" in line.lower()
            and line.split(":")[-1].strip() != "gfx000"
        )
    )
    if archs:
        return [f"--offload-arch={arch}" for arch in archs]
except (subprocess.CalledProcessError, FileNotFoundError):
    pass

medium

The logic for trying rocm_agent_enumerator and rocminfo is very similar. This code duplication could be reduced by extracting the common pattern into a helper function. This would improve maintainability.

For example, you could have a helper that takes the command and a parsing function as arguments:

def _try_get_arch_from_command(args, parse_func):
    try:
        status = subprocess.run(args=args, capture_output=True, check=True, text=True)
        # Using dict.fromkeys to get unique archs while preserving order
        archs = list(dict.fromkeys(parse_func(status.stdout)))
        if archs:
            return [f"--offload-arch={arch}" for arch in archs]
    except (subprocess.CalledProcessError, FileNotFoundError):
        pass
    return None

# In _get_rocm_target:
# ...
if archs := _try_get_arch_from_command(...):
    return archs
# ...
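
For concreteness, the two parsing callbacks passed to such a helper could mirror the existing filters (hypothetical helper names; the filtering logic is lifted from the snippet quoted above):

def _parse_agent_enumerator(stdout: str) -> list[str]:
    # rocm_agent_enumerator prints one gfx arch per line; skip the "gfx000" placeholder
    return [line.strip() for line in stdout.strip().split("\n")
            if line.strip() and line.strip() != "gfx000"]

def _parse_rocminfo(stdout: str) -> list[str]:
    # rocminfo lines look like "  Name:                    gfx90a"
    return [line.split(":")[-1].strip() for line in stdout.split("\n")
            if "Name:" in line and "gfx" in line.lower()
            and line.split(":")[-1].strip() != "gfx000"]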

Comment on lines 612 to 620
cuda_path_list = [str(Path(p).resolve()) for p in _str_seq2list(cuda_files)]
with_cpp = bool(cpp_path_list)
-with_cuda = bool(cuda_path_list)
-assert with_cpp or with_cuda, "Either cpp_files or cuda_files must be provided."
+with_backend = bool(cuda_path_list)
+assert with_cpp or with_backend, "Either cpp_files or cuda_files must be provided."

resolved_backend = _resolve_gpu_backend(backend) if with_backend else None
extra_ldflags_list = list(extra_ldflags) if extra_ldflags is not None else []
extra_cflags_list = list(extra_cflags) if extra_cflags is not None else []
extra_cuda_cflags_list = list(extra_cuda_cflags) if extra_cuda_cflags is not None else []

medium

With the addition of HIP support, using names like cuda_path_list, extra_cuda_cflags_list, and their corresponding function parameters (cuda_files, extra_cuda_cflags) is confusing as they are now used for both CUDA and HIP backends.

For internal functions like _build_impl, these should be renamed to more generic names like gpu_path_list and extra_gpu_cflags_list.

For public-facing functions (build, build_inline, etc.), to maintain backward compatibility, you could introduce new generic parameters (e.g., gpu_files, gpu_sources, extra_gpu_cflags) and treat the existing cuda_* parameters as aliases. You might also consider issuing a warning if cuda_* parameters are used with backend='hip'. This would greatly improve the clarity of the API for users.
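
A rough sketch of that aliasing idea (hypothetical signature for load_inline; this illustrates the reviewer's suggestion, not what the PR implements):

import warnings

def load_inline(name, cuda_sources=None, gpu_sources=None, backend=None, **kwargs):
    # Hypothetical: treat the legacy cuda_sources as an alias of a generic gpu_sources.
    if gpu_sources is None:
        gpu_sources = cuda_sources
    elif cuda_sources is not None:
        raise TypeError("pass either cuda_sources or gpu_sources, not both")
    if cuda_sources is not None and backend == "hip":
        warnings.warn("cuda_sources was given with backend='hip'; "
                      "consider the generic gpu_sources parameter instead")
    ...  # rest of the build logic unchanged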

@DarkSharpness
Contributor Author

This aligns with the PyTorch C++ extension.

DarkSharpness and others added 2 commits February 18, 2026 16:51
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@tqchen tqchen requested a review from MasterJH5574 February 18, 2026 14:53
@MasterJH5574 left a comment

LGTM. Tested locally and there is no issue.

@MasterJH5574
Contributor

There is a lint issue that needs a quick fix before this can go in.

@MasterJH5574 merged commit 65b5e90 into apache:main Feb 18, 2026
8 checks passed