Conversation
paconvert/api_matcher.py
Outdated
|
|
||
|
|
||
| class MseLossMatcher(BaseMatcher): | ||
| def generate_code(self, kwargs): |
There was a problem hiding this comment.
感觉这一块的重复度很高,是否可以统一成一个Matcher
这个文档的链接不对,不是loss的 |
|
PR描述里,docs直接放PR链接,不要写成软连接形式 |
paconvert/api_mapping.json
Outdated
| "reduction" | ||
| ] | ||
| }, | ||
| "torch.nn.Unfold": { |
There was a problem hiding this comment.
这个可以用genericmatcher吧,改成那个吧
There was a problem hiding this comment.
kernel_size 参数 pytorch支持tuple,paddle不支持,改为genericmatcher遇到tuple会报错
paconvert/api_mapping.json
Outdated
| "stride": "strides" | ||
| } | ||
| }, | ||
| "torch.nn.functional.unfold": { |
There was a problem hiding this comment.
这个可以用genericmatcher吧,改成那个吧
paconvert/api_matcher.py
Outdated
| for key in list(kwargs_change.keys()): | ||
| if key in kwargs: | ||
| if "input" not in key: | ||
| if "(" in kwargs[key]: |
There was a problem hiding this comment.
直接判断 if isinstance(kwargs[key] , ast.Tuple): 吧
There was a problem hiding this comment.
这个在generate_code直接这样判断似乎不起作用
paconvert/api_matcher.py
Outdated
| kwargs[kwargs_change[key]] = kwargs[key] | ||
| kwargs.pop(key) | ||
|
|
||
| if "paddings" not in kwargs: |
paconvert/api_matcher.py
Outdated
| kwargs_change = self.api_mapping["kwargs_change"] | ||
| for key in list(kwargs_change.keys()): | ||
| if key in kwargs: | ||
| if "input" not in key: |
paconvert/api_matcher.py
Outdated
| return GenericMatcher.generate_code(self, kwargs) | ||
|
|
||
|
|
||
| class UnfoldMatcher(BaseMatcher): |
There was a problem hiding this comment.
这个可以起一个通用的名字,这个主要功能是把tuple转成list:
可以叫Tuple2ListMatcher
|
|
||
|
|
||
| class UnfoldMatcher(BaseMatcher): | ||
| def generate_code(self, kwargs): |
There was a problem hiding this comment.
逻辑可以写成对每个kwargs遍历,判断是否kwargs,每个分支里再判断是否list,一共4个分支。用new_kwargs来接收kwargs,不然参数顺序会改变,导致代码风格不太好
for k in list(kwargs.keys()):
if kwargs_change:
if tuple:
else:
else:
if tuple:
else:
paconvert/api_mapping.json
Outdated
| "device", | ||
| "requires_grad", | ||
| "memory_format" | ||
| "index" |
paconvert/api_matcher.py
Outdated
| for k in list(kwargs.keys()): | ||
| if k in kwargs_change: | ||
| if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
| new_kwargs[kwargs_change[k]] = list(ast.literal_eval(kwargs[k])) |
There was a problem hiding this comment.
判断前面这个就可以,直接 'list({})'.format(kwargs[k]) 吧,尽量不要用eval执行这些逻辑,容易存在隐患
paconvert/api_matcher.py
Outdated
| new_kwargs[kwargs_change[k]] = kwargs[k] | ||
| else: | ||
| if "(" in kwargs[k] and isinstance(ast.literal_eval(kwargs[k]), tuple): | ||
| new_kwargs[k] = list(ast.literal_eval(kwargs[k])) |
| obj.run(pytorch_code, ["result"]) | ||
|
|
||
|
|
||
| def _test_case_7(): |
| obj.run(pytorch_code, ["result"]) | ||
|
|
||
|
|
||
| def _test_case_7(): |
PR Docs
PaddlePaddle/docs#5928
PR APIs