|
3 | 3 | import enum |
4 | 4 | import os |
5 | 5 | import pathlib |
| 6 | +import pprint |
6 | 7 | from typing import Dict, List, Set, Tuple |
7 | 8 |
|
8 | 9 | import libcst as cst |
@@ -60,6 +61,16 @@ def visit_Call(self, node: cst.Call) -> None: |
60 | 61 | if not matchers.matches(node.func, matchers.Name("torch_op")): |
61 | 62 | return |
62 | 63 |
|
| 64 | + # skip private ops |
| 65 | + if any( |
| 66 | + matchers.matches( |
| 67 | + arg, |
| 68 | + matchers.Arg(value=matchers.Name("True"), keyword=matchers.Name("private")), |
| 69 | + ) |
| 70 | + for arg in node.args |
| 71 | + ): |
| 72 | + return |
| 73 | + |
63 | 74 | function_name = self._stack[-1] |
64 | 75 | overload_names = _cst_arg_to_overload_names(node.args[0]) |
65 | 76 | namespace_op_name = _overload_names_to_namespace_op(overload_names) |
@@ -110,6 +121,16 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Cal |
110 | 121 | if not matchers.matches(original_node.func, matchers.Name("torch_op")): |
111 | 122 | return original_node |
112 | 123 |
|
| 124 | + # skip private ops |
| 125 | + if any( |
| 126 | + matchers.matches( |
| 127 | + arg, |
| 128 | + matchers.Arg(value=matchers.Name("True"), keyword=matchers.Name("private")), |
| 129 | + ) |
| 130 | + for arg in original_node.args |
| 131 | + ): |
| 132 | + return original_node |
| 133 | + |
113 | 134 | original_overload_names = _cst_arg_to_overload_names(original_node.args[0]) |
114 | 135 | namespace_op_name = _overload_names_to_namespace_op(original_overload_names) |
115 | 136 | overload_names = self._overload_names[namespace_op_name][0][1] |
@@ -144,37 +165,49 @@ def add_overload_names( |
144 | 165 |
|
145 | 166 |
|
146 | 167 | def main(): |
147 | | - new_overload_names = { |
148 | | - "aten::add.Tensor", |
149 | | - "aten::clamp.Tensor", |
150 | | - "aten::div.Tensor", |
151 | | - "aten::eq.Scalar", |
152 | | - "aten::eq.Tensor", |
153 | | - "aten::fill.Tensor", |
154 | | - "aten::ge.Scalar", |
155 | | - "aten::ge.Tensor", |
156 | | - "aten::gt.Scalar", |
157 | | - "aten::le.Tensor", |
158 | | - "aten::lt.Scalar", |
159 | | - "aten::mul.Tensor", |
160 | | - "aten::ne.Scalar", |
161 | | - "aten::roll.default", |
162 | | - "aten::rsub.Scalar", |
163 | | - "aten::select.int", |
164 | | - "aten::slice.Tensor", |
165 | | - "aten::split.Tensor", |
166 | | - "aten::sub.Tensor", |
167 | | - "aten::transpose.int", |
168 | | - "aten::unbind.int", |
169 | | - "aten::where.self", |
| 168 | + new_overload_names_from_bench = { |
| 169 | + "aten.add.Tensor": 35510, |
| 170 | + "aten.bitwise_and.Tensor": 12, |
| 171 | + "aten.clamp.Tensor": 2690, |
| 172 | + "aten.div.Tensor": 10622, |
| 173 | + "aten.div.Tensor_mode": 2, |
| 174 | + "aten.empty.memory_format": 12486, |
| 175 | + "aten.eq.Scalar": 72, |
| 176 | + "aten.eq.Tensor": 112, |
| 177 | + "aten.fill.Tensor": 28, |
| 178 | + "aten.ge.Scalar": 4, |
| 179 | + "aten.ge.Tensor": 4, |
| 180 | + "aten.gt.Scalar": 46, |
| 181 | + "aten.le.Tensor": 32, |
| 182 | + "aten.lt.Scalar": 80, |
| 183 | + "aten.masked_fill.Scalar": 360, |
| 184 | + "aten.masked_fill.Tensor": 396, |
| 185 | + "aten.mul.Tensor": 24214, |
| 186 | + "aten.ne.Scalar": 630, |
| 187 | + "aten.pow.Tensor_Scalar": 528, |
| 188 | + "aten.pow.Tensor_Tensor": 1820, |
| 189 | + "aten.rsub.Scalar": 354, |
| 190 | + "aten.scatter_reduce.two": 18, |
| 191 | + "aten.select.int": 4669, |
| 192 | + "aten.slice.Tensor": 17717, |
| 193 | + "aten.split.Tensor": 3182, |
| 194 | + "aten.sub.Tensor": 7868, |
| 195 | + "aten.sum.dim_IntList": 6122, |
| 196 | + "aten.transpose.int": 13219, |
| 197 | + "aten.unbind.int": 1188, |
| 198 | + "aten.where.self": 732, |
170 | 199 | } |
| 200 | + new_overload_names = set( |
| 201 | + {k.replace("aten.", "aten::") for k in new_overload_names_from_bench} |
| 202 | + ) |
171 | 203 | file_paths = [ |
172 | 204 | pathlib.Path(os.path.join(root, file)) |
173 | 205 | for root, dirs, files in os.walk("onnxscript/function_libs/torch_lib/ops") |
174 | 206 | for file in files |
175 | 207 | ] |
176 | 208 | for file_path in file_paths: |
177 | | - print(add_overload_names(file_path, new_overload_names)) |
| 209 | + print("Processing file:", file_path) |
| 210 | + pprint.pprint(add_overload_names(file_path, new_overload_names)) |
178 | 211 |
|
179 | 212 |
|
180 | 213 | if __name__ == "__main__": |
|
0 commit comments