
Commit 572112b

Update on "Tool to modify torchlib overload names via libcst"

[ghstack-poisoned]

2 parents: c2577eb + 3521a1a

2 files changed: 63 additions & 30 deletions

onnxscript/function_libs/tools/torch_lib/modify_overload_names.py
Lines changed: 57 additions & 24 deletions
@@ -3,6 +3,7 @@
 import enum
 import os
 import pathlib
+import pprint
 from typing import Dict, List, Set, Tuple

 import libcst as cst
@@ -60,6 +61,16 @@ def visit_Call(self, node: cst.Call) -> None:
         if not matchers.matches(node.func, matchers.Name("torch_op")):
             return

+        # skip private ops
+        if any(
+            matchers.matches(
+                arg,
+                matchers.Arg(value=matchers.Name("True"), keyword=matchers.Name("private")),
+            )
+            for arg in node.args
+        ):
+            return
+
         function_name = self._stack[-1]
         overload_names = _cst_arg_to_overload_names(node.args[0])
         namespace_op_name = _overload_names_to_namespace_op(overload_names)
@@ -110,6 +121,16 @@ def leave_Call(self, original_node: cst.Call, updated_node: cst.Call) -> cst.Call:
         if not matchers.matches(original_node.func, matchers.Name("torch_op")):
             return original_node

+        # skip private ops
+        if any(
+            matchers.matches(
+                arg,
+                matchers.Arg(value=matchers.Name("True"), keyword=matchers.Name("private")),
+            )
+            for arg in original_node.args
+        ):
+            return original_node
+
         original_overload_names = _cst_arg_to_overload_names(original_node.args[0])
         namespace_op_name = _overload_names_to_namespace_op(original_overload_names)
         overload_names = self._overload_names[namespace_op_name][0][1]
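
For reference, the `private=True` matcher used in both visitors can be exercised on its own. A minimal sketch, assuming a decorator call of the kind the tool rewrites (the source string and the `aten::_helper` name are hypothetical, not code from this repository):

import libcst as cst
import libcst.matchers as matchers

# Hypothetical decorator call of the kind visit_Call/leave_Call inspect.
call = cst.parse_expression('torch_op("aten::_helper", private=True)')
assert isinstance(call, cst.Call)

# True when any argument is the keyword argument `private=True`,
# using the same matcher as in the hunks above.
is_private = any(
    matchers.matches(
        arg,
        matchers.Arg(value=matchers.Name("True"), keyword=matchers.Name("private")),
    )
    for arg in call.args
)
print(is_private)  # prints: True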
@@ -144,37 +165,49 @@ def add_overload_names(


 def main():
-    new_overload_names = {
-        "aten::add.Tensor",
-        "aten::clamp.Tensor",
-        "aten::div.Tensor",
-        "aten::eq.Scalar",
-        "aten::eq.Tensor",
-        "aten::fill.Tensor",
-        "aten::ge.Scalar",
-        "aten::ge.Tensor",
-        "aten::gt.Scalar",
-        "aten::le.Tensor",
-        "aten::lt.Scalar",
-        "aten::mul.Tensor",
-        "aten::ne.Scalar",
-        "aten::roll.default",
-        "aten::rsub.Scalar",
-        "aten::select.int",
-        "aten::slice.Tensor",
-        "aten::split.Tensor",
-        "aten::sub.Tensor",
-        "aten::transpose.int",
-        "aten::unbind.int",
-        "aten::where.self",
+    new_overload_names_from_bench = {
+        "aten.add.Tensor": 35510,
+        "aten.bitwise_and.Tensor": 12,
+        "aten.clamp.Tensor": 2690,
+        "aten.div.Tensor": 10622,
+        "aten.div.Tensor_mode": 2,
+        "aten.empty.memory_format": 12486,
+        "aten.eq.Scalar": 72,
+        "aten.eq.Tensor": 112,
+        "aten.fill.Tensor": 28,
+        "aten.ge.Scalar": 4,
+        "aten.ge.Tensor": 4,
+        "aten.gt.Scalar": 46,
+        "aten.le.Tensor": 32,
+        "aten.lt.Scalar": 80,
+        "aten.masked_fill.Scalar": 360,
+        "aten.masked_fill.Tensor": 396,
+        "aten.mul.Tensor": 24214,
+        "aten.ne.Scalar": 630,
+        "aten.pow.Tensor_Scalar": 528,
+        "aten.pow.Tensor_Tensor": 1820,
+        "aten.rsub.Scalar": 354,
+        "aten.scatter_reduce.two": 18,
+        "aten.select.int": 4669,
+        "aten.slice.Tensor": 17717,
+        "aten.split.Tensor": 3182,
+        "aten.sub.Tensor": 7868,
+        "aten.sum.dim_IntList": 6122,
+        "aten.transpose.int": 13219,
+        "aten.unbind.int": 1188,
+        "aten.where.self": 732,
     }
+    new_overload_names = set(
+        {k.replace("aten.", "aten::") for k in new_overload_names_from_bench}
+    )
     file_paths = [
         pathlib.Path(os.path.join(root, file))
         for root, dirs, files in os.walk("onnxscript/function_libs/torch_lib/ops")
         for file in files
     ]
     for file_path in file_paths:
-        print(add_overload_names(file_path, new_overload_names))
+        print("Processing file:", file_path)
+        pprint.pprint(add_overload_names(file_path, new_overload_names))


 if __name__ == "__main__":
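
The bench dictionary keys use PyTorch's `aten.op.overload` spelling, while `torch_op` registrations use the native-schema `aten::op.overload` spelling; `main` keeps only the normalized names and discards the hit counts. A standalone check of that normalization (the two entries are copied from the bench data above):

counts = {"aten.add.Tensor": 35510, "aten.sum.dim_IntList": 6122}
# Drop the counts and rewrite the namespace separator to schema form.
names = {k.replace("aten.", "aten::") for k in counts}
print(sorted(names))  # ['aten::add.Tensor', 'aten::sum.dim_IntList']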

onnxscript/function_libs/torch_lib/ops/core.py
Lines changed: 6 additions & 6 deletions
@@ -2299,7 +2299,7 @@ def aten_embedding_sparse_backward(
     raise NotImplementedError()


-@torch_op("aten::empty")
+@torch_op(("aten::empty", "aten::empty.memory_format"))
 def aten_empty(size: IntType, dtype: int = FLOAT.dtype) -> TTensor:  # type: ignore[type-var]
     # empty(SymInt[] size, *, ScalarType? dtype=None, Layout? layout=None, Device? device=None, bool? pin_memory=None, MemoryFormat? memory_format=None) -> Tensor

@@ -3957,7 +3957,7 @@ def aten_margin_ranking_loss(
     raise NotImplementedError()


-@torch_op("aten::masked_fill")
+@torch_op(("aten::masked_fill", "aten::masked_fill.Scalar", "aten::masked_fill.Tensor"))
 def aten_masked_fill(self: TTensor, mask: BOOL, value: TTensor) -> TTensor:
     """masked_fill.Tensor(Tensor self, Tensor mask, Tensor value) -> Tensor"""
     # NOTE: Do not attempt to cast `mask` to BOOL because mask should not take any other types.
@@ -4883,7 +4883,7 @@ def aten_native_norm(self: TensorType, p: float = 2.0) -> TensorType:
     raise NotImplementedError()


-@torch_op(("aten::ne", "aten::ne.Scalar"))
+@torch_op(("aten::ne", "aten::ne.Scalar", "aten::ne.Tensor"))
 def aten_ne(self: TReal, other: TReal) -> BOOL:
     """ne.Tensor(Tensor self, Tensor other) -> Tensor"""

@@ -5223,7 +5223,7 @@ def aten_positive(self: TensorType) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::pow")
+@torch_op(("aten::pow", "aten::pow.Tensor_Tensor", "aten::pow.Tensor_Scalar"))
 def aten_pow(self: TReal, exponent: TTensor) -> TReal:
     """pow(Tensor self, Tensor exponent) -> Tensor"""

@@ -5785,7 +5785,7 @@ def aten_scatter_add(
     return op.ScatterElements(self, index, src, axis=dim, reduction="add")


-@torch_op("aten::scatter_reduce", trace_only=True)
+@torch_op(("aten::scatter_reduce", "aten::scatter_reduce.two"), trace_only=True)
 def aten_scatter_reduce(
     self: TReal,
     dim: int,  # we have to use int here because ScatterElements() will use this attribute
@@ -6324,7 +6324,7 @@ def aten_subtract(self: TensorType, other: TensorType, alpha: float = 1.0) -> TensorType:
     raise NotImplementedError()


-@torch_op("aten::sum", trace_only=True)
+@torch_op(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
 def aten_sum_dim_IntList(
     self: TReal, dim: Optional[INT64] = None, keepdim: bool = False, dtype: int = -1
 ) -> TReal:
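
Every hunk above follows the same pattern: `torch_op` accepts either a single name or a tuple of names, and the tuple form registers the implementation under each overload name. A simplified, hypothetical stand-in for that registration behavior (`torch_op_sketch` and `_REGISTRY` are illustrative only; the real decorator in onnxscript also wraps the function for tracing and script compilation):

from typing import Callable, Dict, Sequence, Union

_REGISTRY: Dict[str, Callable] = {}  # hypothetical name-to-function table

def torch_op_sketch(name: Union[str, Sequence[str]], trace_only: bool = False):
    """Register a function under one or more 'namespace::op.overload' names."""
    names = (name,) if isinstance(name, str) else tuple(name)

    def wrapper(func: Callable) -> Callable:
        for n in names:
            _REGISTRY[n] = func
        return func

    return wrapper

@torch_op_sketch(("aten::sum", "aten::sum.dim_IntList"), trace_only=True)
def aten_sum_dim_IntList(self, dim=None, keepdim=False, dtype=-1):
    ...

print(sorted(_REGISTRY))  # ['aten::sum', 'aten::sum.dim_IntList']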
