Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 20 additions & 39 deletions backends/vulkan/custom_ops_lib.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
Expand All @@ -8,7 +8,6 @@

import executorch.backends.vulkan.patterns as vk_patterns
import torch.library

from torch._subclasses.fake_tensor import FakeTensor

namespace = "et_vk"
Expand Down Expand Up @@ -259,7 +258,7 @@
weights, [1, group_size], weight_scales, weight_zeros, torch.int8, -8, 7
)

out = torch.nn.functional.linear(x, weights)
out = torch.nn.functional.linear(x, weights, bias)
return out


Expand All @@ -273,26 +272,23 @@
group_size: int,
bias: Optional[torch.Tensor] = None,
):
return linear_q4gsw(x, weights, weight_scales, group_size)
return linear_q4gsw(x, weights, weight_scales, group_size, bias)


name = "linear_q4gsw"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor self,
Tensor weights,
Tensor weight_scales,
int group_size,
Tensor? bias = None) -> Tensor
"""
)
""")
lib.impl(name, linear_q4gsw, "CompositeExplicitAutograd")
linear_qc4w_op = getattr(getattr(torch.ops, namespace), name)

name = "linear_dq8ca_q4gsw"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor input,
Tensor input_scales,
Expand All @@ -302,8 +298,7 @@
Tensor weight_scales,
int group_size,
Tensor? bias = None) -> Tensor
"""
)
""")
lib.impl(name, linear_dq8ca_q4gsw, "CompositeExplicitAutograd")
linear_dq8ca_q4gsw_op = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -341,8 +336,7 @@


name = "linear_q8ta_q8csw"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -351,8 +345,7 @@
Tensor weight_sums,
Tensor weight_scales,
Tensor? bias = None) -> Tensor
"""
)
""")
lib.impl(name, linear_q8ta_q8csw, "CompositeExplicitAutograd")
qa_q8csw_linear = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -403,8 +396,7 @@


name = "q8ta_linear"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -416,8 +408,7 @@
int output_zero_point,
Tensor? bias = None,
str activation = "none") -> Tensor
"""
)
""")
lib.impl(name, q8ta_linear, "CompositeExplicitAutograd")
q8ta_linear_op = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -468,8 +459,7 @@


name = "q8ta_linear_gemv"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -481,8 +471,7 @@
int output_zero_point,
Tensor? bias = None,
str activation = "none") -> Tensor
"""
)
""")
lib.impl(name, q8ta_linear_gemv, "CompositeExplicitAutograd")
q8ta_linear_gemv_op = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -560,8 +549,7 @@


name = "q8ta_conv2d"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -578,15 +566,13 @@
SymInt[] dilation,
SymInt groups,
str activation) -> Tensor
"""
)
""")
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
q8ta_conv2d_op = getattr(getattr(torch.ops, namespace), name)


name = "q8ta_conv2d_pw"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -603,8 +589,7 @@
SymInt[] dilation,
SymInt groups,
str activation) -> Tensor
"""
)
""")
lib.impl(name, q8ta_conv2d, "CompositeExplicitAutograd")
q8ta_conv2d_pw_op = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -662,8 +647,7 @@


name = "q8ta_conv2d_dw"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -680,8 +664,7 @@
SymInt[] dilation,
SymInt groups,
str activation) -> Tensor
"""
)
""")
lib.impl(name, q8ta_conv2d_dw, "CompositeExplicitAutograd")
conv2d_q8ta_q8csw_dw_op = getattr(getattr(torch.ops, namespace), name)

Expand Down Expand Up @@ -760,8 +743,7 @@


name = "q8ta_conv2d_transposed"
lib.define(
f"""
lib.define(f"""
{name}(
Tensor x,
float input_scale,
Expand All @@ -779,8 +761,7 @@
SymInt[] dilation,
SymInt groups,
str activation) -> Tensor
"""
)
""")
lib.impl(name, q8ta_conv2d_transposed, "CompositeExplicitAutograd")
q8ta_conv2d_transposed_op = getattr(getattr(torch.ops, namespace), name)

Expand Down
14 changes: 3 additions & 11 deletions backends/vulkan/patterns/quantized_linear.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,28 +5,22 @@
# LICENSE file in the root directory of this source tree.

import operator

from typing import Optional

import executorch.backends.vulkan.utils as utils

import torch
import torch.nn.functional as F

from executorch.backends.transforms.utils import (
create_constant_placeholder,
get_param_tensor,
)

from executorch.backends.vulkan.patterns.pattern_registry import (
PatternMatch,
register_pattern_detector,
register_pattern_replacement,
)

from executorch.exir import ExportedProgram
from executorch.exir.dialects._ops import ops as exir_ops

from torch.export.graph_signature import InputKind


Expand Down Expand Up @@ -407,6 +401,7 @@ def make_linear_q4gsw_op(
match.weight_node,
match.weight_scales_node,
group_size,
match.bias_node,
),
)

Expand Down Expand Up @@ -474,6 +469,7 @@ def make_linear_dq8ca_q4gsw_op(
weight_sums_node,
match.weight_scales_node,
group_size,
match.bias_node,
),
)

Expand Down Expand Up @@ -538,6 +534,7 @@ def make_linear_q8ta_q8csw_custom_op(
match.weight_node,
weight_sums_node,
match.weight_scales_node,
match.bias_node,
),
)

Expand Down Expand Up @@ -637,7 +634,6 @@ def replace_quantized_linear_patterns(
assert weight_zeros_tensor is not None

# Route to appropriate custom op.
# q8ta_linear supports bias, so check it first before the bias guard.
if (
match.is_input_static_per_tensor_quantized()
and match.is_weight_perchannel_quantized()
Expand All @@ -646,10 +642,6 @@ def replace_quantized_linear_patterns(
make_q8ta_linear_custom_op(ep, graph_module, match, weight_tensor)
return

# Remaining ops do not support bias
if match.bias_node is not None:
return

if (
match.is_weight_only_quantized()
and match.is_weight_pergroup_quantized()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -144,5 +144,11 @@ void main() {
group_size);
}

if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}

write_output_tile_with_checks(out_tile, n4, m, N4, M);
}
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,16 @@ void apply_weight_scales_and_biases(
}
}

// Element-wise add of per-output-channel bias values into an output tile.
// Each of the TILE_M rows receives the same bias vector: entry [n4] of the
// bias (a packed group of output channels) is added to column group [n4]
// of every row in the tile.
void add_bias_to_out_tile(
    inout FPOutTile tile,
    const FPPerOutChannelParams bias) {
  [[unroll]] for (int row = 0; row < TILE_M; ++row) {
    [[unroll]] for (int col4 = 0; col4 < TILE_N4; ++col4) {
      // In-place accumulate; bias is broadcast across the M (row) dimension.
      tile.data[row][col4] += bias.data[col4];
    }
  }
}

void accumulate_out_tile_with_out_tile(
inout FPOutTile accum,
const FPOutTile other) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,11 @@ void main() {
// Only the first thread will write out result
if (lid == 0) {
out_tile = partial_sums[0];
if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}
write_output_tile_with_checks(out_tile, n4, 0, N4, 1);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -110,5 +110,11 @@ void main() {
}
}

if (apply_bias > 0) {
FPPerOutChannelParams bias_tile;
load_bias_tile(bias_tile, n4);
add_bias_to_out_tile(out_tile, bias_tile);
}

write_output_tile_with_checks(out_tile, n4, m, N4, M);
}
16 changes: 5 additions & 11 deletions backends/vulkan/test/custom_ops/q4gsw_linear.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ TestCase create_test_case_from_config(
input_dtype,
storage_type,
utils::kWidthPacked,
DataGenType::ZEROS);
config.has_bias ? DataGenType::RANDOM : DataGenType::ZEROS);
bias.set_constant(true);
if (!config.has_bias) {
bias.set_none(true);
Expand Down Expand Up @@ -237,9 +237,10 @@ std::vector<TestCase> generate_quantized_linear_test_cases() {
{32, 64, 32, 16},
{32, 128, 64, 32},
{32, 256, 128, 64},
// No bias tests
{32, 128, 64, 32, false},
{32, 256, 128, 64, false},
// With bias
{4, 64, 32, 16, true},
{4, 128, 64, 32, true},
{32, 128, 64, 32, true},
// Performance test cases
{1, 2048, 2048, 128},
{128, 2048, 2048, 128},
Expand Down Expand Up @@ -499,13 +500,6 @@ void reference_impl(TestCase& test_case) {
}

int64_t quantized_linear_flop_calculator(const TestCase& test_case) {
int input_idx = 0;
int weight_idx = 1;
if (test_case.operator_name().find("dq8ca") != std::string::npos) {
input_idx = 0;
weight_idx = 3; // Weight comes after input, input_scale, input_zero_point
}

// Get input and weight dimensions
const auto& input_sizes = test_case.inputs()[0].get_tensor_sizes();
const auto& output_sizes = test_case.outputs()[0].get_tensor_sizes();
Expand Down
Loading