Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions src/s_tir/transform/lower_thread_allreduce.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,7 +510,9 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
//
// The former may cause dead lock as there is a divergent
// branch with a warp sync call inside.
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset);
bool cast_offset_to_uint = target_->kind->name == "webgpu";
Comment thread
ksgr5566 marked this conversation as resolved.
Outdated
PrimExpr other = WarpShuffle(builtin::tvm_warp_shuffle_down(), mask_buffer, val, offset,
cast_offset_to_uint);
Comment thread
ksgr5566 marked this conversation as resolved.
Outdated
Buffer local_buf = local_bufs[i];
Stmt s = BufferStore(local_buf, other, zero_indices);
seq->push_back(s);
Expand Down Expand Up @@ -699,7 +701,10 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {

// Emit warp shuffle calls.
PrimExpr WarpShuffle(const Op& op, ffi::Optional<Buffer> mask_buffer, PrimExpr val,
PrimExpr delta_or_lane) {
PrimExpr delta_or_lane, bool cast_delta_to_uint = false) {
if (cast_delta_to_uint) {
delta_or_lane = cast(DataType::UInt(32, delta_or_lane.dtype().lanes()), delta_or_lane);
}
ffi::Array<PrimExpr> indices = {0};
PrimExpr mask;
if (mask_buffer.defined()) {
Expand All @@ -719,11 +724,11 @@ class ThreadAllreduceBuilder final : public StmtExprMutator {
bool IsWarpReduction(const std::vector<DataType>& types, int group_extent, int reduce_extent,
int contiguous_reduce_extent) {
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}

Comment thread
ksgr5566 marked this conversation as resolved.
need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
Comment on lines 744 to +749
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

To improve maintainability, consider using std::unordered_set for checking the target kind. This makes it easier to add or remove supported targets in the future.

Suggested change
if ((target_->kind->name != "cuda") && (target_->kind->name != "rocm") &&
(target_->kind->name != "metal")) {
(target_->kind->name != "metal") && (target_->kind->name != "webgpu")) {
return false;
}
need_warp_shuffle_mask_ = target_->kind->name != "metal";
need_warp_shuffle_mask_ = target_->kind->name != "metal" && target_->kind->name != "webgpu";
const std::unordered_set<std::string> supported_targets = {"cuda", "rocm", "metal", "webgpu"};
if (!supported_targets.count(target_->kind->name)) {
return false;
}
const std::unordered_set<std::string> no_mask_targets = {"metal", "webgpu"};
need_warp_shuffle_mask_ = !no_mask_targets.count(target_->kind->name);


// rocm only supports 32 bit operands for shuffling at the moment
if ((target_->kind->name == "rocm") &&
Expand Down
7 changes: 6 additions & 1 deletion src/target/source/codegen_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,9 @@ std::string CodeGenWebGPU::Finish() {
if (enable_fp16_) {
header_stream << "enable f16;\n\n";
}
if (enable_subgroups_) {
header_stream << "enable subgroups;\n\n";
}
return header_stream.str() + decl_stream.str() + this->fwd_decl_stream.str() + stream.str();
}

Expand All @@ -120,7 +123,9 @@ void CodeGenWebGPU::InitFuncState(const PrimFunc& f) {
}
}

// Create a WebGPU code generator bound to the given target.
//
// The target's "supports_subgroups" attribute decides whether subgroup
// (warp-level) intrinsics may be used; when the attribute is absent it
// defaults to false and no "enable subgroups;" directive is emitted.
CodeGenWebGPU::CodeGenWebGPU(Target target) : target_(target) {
  auto supports_subgroups = target_->GetAttr<Bool>("supports_subgroups");
  enable_subgroups_ = supports_subgroups.value_or(Bool(false));
}

runtime::FunctionInfo CodeGenWebGPU::AddFunction(const PrimFunc& f, bool skip_readonly_decl) {
// clear previous generated state.
Expand Down
2 changes: 2 additions & 0 deletions src/target/source/codegen_webgpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,8 @@ class CodeGenWebGPU final : public CodeGenC {

// Whether fp16 is used; emits "enable f16;" in the generated WGSL header.
bool enable_fp16_{false};
// Whether subgroup intrinsics are used; emits "enable subgroups;" in the
// generated WGSL header. Set from the target's "supports_subgroups" attribute.
bool enable_subgroups_{false};

/*! \brief the header stream for function label and enable directive if any, goes before any other
* declaration */
Expand Down
56 changes: 56 additions & 0 deletions src/target/source/intrin_rule_webgpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,29 @@ namespace intrin {

using tir::FLowerIntrinsic;

// warp-level primitives. Follows implementation in intrin_rule_metal.cc
// Maps the generic TIR warp shuffle builtins onto the corresponding
// WebGPU subgroup ops. Follows the implementation in intrin_rule_metal.cc.
struct WebGPUWarpIntrinsic {
  // \param t The value dtype of the shuffle call (unused for WebGPU).
  // \param orig_op One of tvm_warp_shuffle / tvm_warp_shuffle_up / tvm_warp_shuffle_down.
  // \return The matching tir.webgpu.subgroup_shuffle{,_up,_down} op.
  const Op operator()(DataType t, const Op& orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.webgpu.subgroup_shuffle");
    }
    if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.webgpu.subgroup_shuffle_up");
    }
    // Only the three shuffle variants are ever dispatched here.
    ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
    return Op::Get("tir.webgpu.subgroup_shuffle_down");
  }
};

// Lower a generic tvm_warp_shuffle* call to its WebGPU subgroup builtin.
// The TIR builtin carries five args (mask, value, warp_id/delta, width,
// warp_size); the WGSL subgroup shuffles take only (value, lane/delta),
// so the other operands are dropped here.
template <typename T>
static PrimExpr DispatchWebGPUShuffle(const PrimExpr& e) {
  const CallNode* node = e.as<CallNode>();
  ICHECK(node != nullptr);
  ICHECK_EQ(node->args.size(), 5);  // mask, value, warp_id, width, warp_size
  ffi::Array<PrimExpr> webgpu_args{{node->args[1], node->args[2]}};
  const Op webgpu_op = T()(node->dtype, Downcast<Op>(node->op));
  return Call(node->dtype, webgpu_op, webgpu_args);
}

// See full list of builtin: https://www.w3.org/TR/WGSL/#builtin-functions

struct ReturnAbs {
Expand Down Expand Up @@ -113,6 +136,39 @@ TVM_REGISTER_OP("tir.trunc")
// extra dispatch
TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchFastErf);

// Warp-level primitives: attach the WebGPU lowering rule so the generic
// tvm_warp_shuffle* builtins are rewritten to the WebGPU subgroup ops below.
// Follows the implementation in intrin_rule_metal.cc.
TVM_REGISTER_OP("tir.tvm_warp_shuffle")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
    .set_attr<FLowerIntrinsic>("webgpu.FLowerIntrinsic", DispatchWebGPUShuffle<WebGPUWarpIntrinsic>);

// Register the low-level builtin ops. TGlobalSymbol supplies the WGSL builtin
// name emitted by codegen; the calls are marked kOpaque for effect analysis.
TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("lane", "Expr", "The source thread id.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffle")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_up")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be added.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleUp")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));

TVM_REGISTER_OP("tir.webgpu.subgroup_shuffle_down")
    .set_num_inputs(2)
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "subgroupShuffleDown")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque));


} // namespace intrin
} // namespace codegen
} // namespace tvm
19 changes: 19 additions & 0 deletions src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -427,8 +427,27 @@ TVM_REGISTER_TARGET_KIND("vulkan", kDLVulkan)
// Tags
.set_default_keys({"vulkan", "gpu"});

/*!
 * \brief Target parser for WebGPU subgroup support.
 *
 * When the "supports_subgroups" attribute is present and true, raise
 * thread_warp_size to 32 so that TIR lowering uses warp-level shuffle
 * reductions instead of shared-memory reductions.
 */
TargetJSON UpdateWebGPUAttrs(TargetJSON target) {
  // Leave the target untouched unless subgroup support is explicitly enabled.
  if (target.count("supports_subgroups") &&
      Downcast<Bool>(target.at("supports_subgroups"))) {
    target.Set("thread_warp_size", refl::DefaultValue(32));
  }
  return target;
}

TVM_REGISTER_TARGET_KIND("webgpu", kDLWebGPU)
    .add_attr_option<int64_t>("max_num_threads", refl::DefaultValue(256))
    // Whether the device exposes subgroup (warp-level) operations; the parser
    // below raises thread_warp_size when this is true.
    .add_attr_option<bool>("supports_subgroups", refl::DefaultValue(false))
    // Default thread_warp_size=1: is_subwarp_reduction and is_multiwarp_reduction
    // return false, so no subgroup ops are emitted.
    .add_attr_option<int64_t>("thread_warp_size", refl::DefaultValue(1))
    .set_target_parser(UpdateWebGPUAttrs)
    .set_default_keys({"webgpu", "gpu"});
Comment thread
ksgr5566 marked this conversation as resolved.

TVM_REGISTER_TARGET_KIND("hexagon", kDLHexagon)
Expand Down
3 changes: 3 additions & 0 deletions web/src/webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,9 @@ export async function detectGPUDevice(powerPreference: "low-power" | "high-perfo
if (adapter.features.has("shader-f16")) {
requiredFeatures.push("shader-f16");
}
if (adapter.features.has("subgroups")) {
requiredFeatures.push("subgroups");
}

// requestAdapterInfo() is deprecated, causing requestAdapterInfo to raise
// issue when building. However, it is still needed for older browsers, hence `as any`.
Expand Down