-
Notifications
You must be signed in to change notification settings - Fork 966
Expand file tree
/
Copy pathWhere.cpp
More file actions
91 lines (78 loc) · 2.46 KB
/
Where.cpp
File metadata and controls
91 lines (78 loc) · 2.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
// Where.cpp
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
* All rights reserved.
*
* This source code is licensed under the BSD-style license found in the
* LICENSE file in the root directory of this source tree.
*/
#include <executorch/backends/vulkan/runtime/graph/ops/OperatorRegistry.h>
#include <executorch/backends/vulkan/runtime/graph/ops/impl/Common.h>
#include <executorch/backends/vulkan/runtime/graph/ops/utils/ShaderNameUtils.h>
namespace vkcompute {
/*
 * Resize callback for the where dispatch node.
 *
 * Recomputes the output sizes as the broadcast of the sizes of every tensor
 * in the input arg group (cond, self, other as wired up by add_where_node);
 * non-tensor values in the group are skipped. Shapes are aligned on their
 * trailing dimensions, missing leading dimensions are treated as 1, and a
 * size-1 dimension takes the size of the other operand (including 0).
 * Inputs are assumed to be broadcast-compatible; no validation is done here.
 *
 * @param graph       graph owning the values being resized
 * @param args        dispatch arg groups: args[0] holds the output tensor,
 *                    args[1] holds the input values
 * @param extra_args  unused
 */
void resize_where_node(
    ComputeGraph* graph,
    const std::vector<ArgGroup>& args,
    const std::vector<ValueRef>& extra_args) {
  (void)extra_args;
  const ValueRef out = args.at(0).refs.at(0);

  std::vector<int64_t> out_sizes;
  for (const ValueRef ref : args.at(1).refs) {
    if (!graph->val_is_tensor(ref)) {
      continue;
    }
    const std::vector<int64_t> in_sizes = graph->sizes_of(ref);
    // Broadcasting aligns trailing dimensions, so any extra rank must be
    // added at the FRONT of the accumulated shape. Appending the 1s (as
    // resize() would) misaligns every existing dimension.
    if (in_sizes.size() > out_sizes.size()) {
      out_sizes.insert(
          out_sizes.begin(), in_sizes.size() - out_sizes.size(), int64_t{1});
    }
    const size_t offset = out_sizes.size() - in_sizes.size();
    for (size_t i = 0; i < in_sizes.size(); ++i) {
      int64_t& dim = out_sizes[offset + i];
      // A size-1 dim broadcasts to the other operand's size. Otherwise the
      // accumulated size is kept: either the sizes are equal, or the input
      // is the size-1 side. Unlike std::max, this keeps 0 when 0 meets 1.
      if (dim == 1) {
        dim = in_sizes[i];
      }
    }
  }
  graph->virtual_resize(out, out_sizes);
}
/*
 * Appends a dynamic dispatch node implementing aten.where.self to the
 * graph's execute nodes. The element-wise selection itself happens in the
 * "where" shader, which is not visible here.
 *
 * The shader variant is chosen from the base name "where" plus suffixes
 * derived from the storage type and dtype of `out`. One metadata UBO per
 * tensor is bound, in the order out, cond, self, other. The node writes
 * `out`, reads `cond`, `self` and `other`, and uses resize_where_node to
 * recompute the output sizes on resize.
 */
void add_where_node(
    ComputeGraph& graph,
    const ValueRef cond,
    const ValueRef self,
    const ValueRef other,
    const ValueRef out) {
  // Pick the shader variant matching the output's storage type and dtype.
  std::string shader_name("where");
  add_storage_type_suffix(shader_name, graph.storage_type_of(out));
  add_dtype_suffix(shader_name, graph.dtype_of(out));

  // NOTE(review): this ordering presumably has to match the binding layout
  // of the where shader. Confirm against the shader template before
  // reordering.
  vkapi::ParamsBindList meta_ubos{
      graph.meta_ubo(out),
      graph.meta_ubo(cond),
      graph.meta_ubo(self),
      graph.meta_ubo(other),
  };

  graph.execute_nodes().emplace_back(new DynamicDispatchNode(
      graph,
      VK_KERNEL_FROM_STR(shader_name),
      default_pick_global_wg_size,
      default_pick_local_wg_size,
      /* tensor bindings: one write target, three read sources */
      {{out, vkapi::kWrite}, {{cond, self, other}, vkapi::kRead}},
      /* uniform parameter buffers */
      meta_ubos,
      /* push constants */ {},
      /* specialization constants */ {},
      /* extra resize args */ {},
      /* resize callback */ resize_where_node));
}
/*
 * Graph-builder entry point for aten.where.self.
 *
 * `args` is read positionally as (cond, self, other, out), and the work is
 * forwarded to add_where_node. Indexing is unchecked, so the caller must
 * supply at least four values.
 */
void where(ComputeGraph& graph, const std::vector<ValueRef>& args) {
  const ValueRef cond = args[0];
  const ValueRef self = args[1];
  const ValueRef other = args[2];
  const ValueRef out = args[3];
  add_where_node(graph, cond, self, other, out);
}
// Operator registration: maps the aten.where.self operator name to the
// `where` builder function defined in this file.
REGISTER_OPERATORS {
  VK_REGISTER_OP(aten.where.self, where);
}
} // namespace vkcompute