Skip to content

Refine device by vibecoding#1790

Open
wenhuach21 wants to merge 7 commits into
mainfrom
refine_device
Open

Refine device by vibecoding#1790
wenhuach21 wants to merge 7 commits into
mainfrom
refine_device

Conversation

@wenhuach21
Copy link
Copy Markdown
Contributor

Description

Please briefly describe your main changes, the motivation.

Type of Change

Bug fix

Related Issues

Fixes or relates to #

Checklist Before Submitting

  • My code has been tested locally.
  • Documentation has been updated as needed.
  • New or updated tests are included where applicable.
  • The CUDA CI has passed. You can trigger it by commenting /azp run Unit-Test-CUDA-AutoRound.

Copilot AI review requested due to automatic review settings May 8, 2026 14:01
Copy link
Copy Markdown
Contributor

Copilot AI left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pull request overview

This PR refactors device/accelerator handling across AutoRound into a pluggable backend registry (DeviceBackend) so device selection, memory accounting, OOM detection, and cache clearing can be extended to new accelerators without hard-coded cuda/xpu/hpu conditionals. It also adds a new fused Triton dequant+matmul implementation and a small Triton kernel dtype-safety fix.

Changes:

  • Introduces auto_round.utils.device_backend with a registry + built-in CPU/CUDA/XPU/HPU backends, and rewires device utilities to dispatch through it.
  • Updates multiple call sites (device detection, compile dispatch, memory clearing/monitoring, diffusion/MLLM helpers, delta-loss sync points) to use backend APIs instead of direct torch.cuda/xpu/hpu branching.
  • Adds a new fused Triton kernel path (fused_matmul.py) and fixes dtype alignment in an existing Triton kernel.

Reviewed changes

Copilot reviewed 16 out of 16 changed files in this pull request and generated 6 comments.

Show a summary per file
File Description
auto_round/utils/device.py Switches device utilities to backend-registry dispatch (device selection, compile, memory clear/monitor, OOM detection).
auto_round/utils/device_backend.py Adds the pluggable backend abstraction + registry and built-in CPU/CUDA/XPU/HPU backends.
auto_round/utils/__init__.py Re-exports backend APIs from device_backend for public access.
auto_round/inference/convert_model.py Sources available devices from backend registry (keeps legacy MPS probe).
auto_round/compressors/mllm/compressor.py Uses backend memory accounting / cache clearing when deciding whether to run on accelerator.
auto_round/compressors/diffusion/compressor.py Generalizes “GPU/XPU” checks/messages to “accelerator” via backend helpers.
auto_round/compressors/base.py Ensures mllm forces plain block_forward (no compile) in legacy compressor path.
auto_round/compressors_new/diffusion_mixin.py Generalizes diffusion offload checks/messages via backend helpers.
auto_round/compressors_new/base.py Extends torch.compile eligibility guard to include MLLM path.
auto_round/auto_scheme/delta_loss.py Synchronizes via the selected backend instead of CUDA-only sync.
auto_round/algorithms/transforms/rotation/inplace/hooks.py Resolves compute device via backend auto-selection instead of CUDA/XPU cascade.
auto_round/algorithms/transforms/rotation/inplace/apply.py Uses backend-selected cache clearing instead of CUDA-only empty_cache.
auto_round/algorithms/quantization/base.py Disables/comment-outs compiled block_forward resolution in base quantizer.
auto_round_extension/triton/triton_utils/kernels.py Ensures dequantized b matches a.dtype for tl.dot safety (bf16).
auto_round_extension/triton/triton_utils/fused_matmul.py Adds a new fused Triton dequant+matmul (+ split-K) implementation.
_smoke_device.py Adds a standalone smoke script to exercise backend refactor behavior.

Comment on lines 624 to +642
if isinstance(device_list, (str, torch.device)):
device_list = [device_list]

# -----------------------------------
# CUDA-specific clearing
# -----------------------------------
if torch.cuda.is_available():
# No device_list → clear all GPUs
if not device_list:
# Fix https://github.com/intel/auto-round/issues/1004
torch.cuda.synchronize()
torch.cuda.empty_cache()
else:
# Parse valid CUDA device IDs
devices = []
for dev in device_list:
dev = str(dev)
if not dev.startswith("cuda"):
continue
# cuda / cuda:0 / cuda:1
if ":" in dev:
devid = int(dev.split(":")[-1])
else:
devid = 0
devices.append(devid)

for d in devices:
torch.cuda.synchronize(d)
per_backend: dict[str, list] = {}
if device_list:
for dev in device_list:
dev_str = str(dev)
if dev_str.startswith("cpu"):
continue
# type[:idx]
if ":" in dev_str:
dtype, idx = dev_str.split(":", 1)
try:
idx_val = int(idx)
except ValueError:
idx_val = 0
else:
dtype, idx_val = dev_str, None
per_backend.setdefault(dtype, []).append(idx_val)
current_vram = 0.0
if current_vram <= 0:
continue
key = str(idx)
Comment on lines +654 to +660
mod = self.torch_module
if mod is None:
return 0
try:
return int(mod.memory_cached(index))
except Exception: # pragma: no cover
return 0
Comment on lines +60 to +62
# Built dynamically from the backend registry so a newly-registered device
# (e.g. NPU) automatically appears here without editing this file.
DEVICE_ENVIRON_VARIABLE_MAPPING = get_visible_devices_env_mapping()
this registry, so no other file needs to be edited when a new device
type is added.

See ``docs/adding_new_device.md`` for a concrete worked example.
Comment on lines +380 to +385
# elif self.compress_context.enable_torch_compile:
# compiled = self.__dict__.get("_compiled_block_forward")
# if compiled is None:
# compiled = compile_func(block_forward, self.compress_context.device)
# self._compiled_block_forward = compiled
# self._resolved_block_forward = compiled
wenhuach21 added 2 commits May 9, 2026 08:11
…nto refine_device

# Conflicts:
#	_smoke_device.py
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants