@@ -298,7 +298,7 @@ torch.permute
298298 " dims" : " perm"
299299 }
300300 },
301- " unsupport_args" : {} ,
301+ " unsupport_args" : [] ,
302302 " paddle_default_kwargs" : {}
303303}
304304```
@@ -341,7 +341,7 @@ paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值
341341}
342342```
343343
344- 如果不属于上述分类,则需要开发 ** 自定义的Matcher** ,命名标准为:` API名+Matcher ` 。例如 ` torch.transpose ` 可命名为` TransposeMatcher ` ,` torch.Tensor.transpose ` 可命名为 ` TensorTransposeMatcher ` 。详见下面步骤3 。
344+ 如果不属于上述分类,则需要开发 ** 自定义的Matcher** ,命名标准为:` API名+Matcher ` 。例如 ` torch.transpose ` 可命名为` TransposeMatcher ` ,` torch.Tensor.transpose ` 可命名为 ` TensorTransposeMatcher ` 。详见下面步骤 。
345345
346346## 步骤4:编写Matcher(转换规则)
347347
@@ -399,7 +399,7 @@ class Chain_MatmulMatcher(BaseMatcher):
399399```
400400class TransposeMatcher(BaseMatcher):
401401 def generate_code(self, kwargs):
402- API_TEMPLACE = textwrap.dedent(
402+ API_TEMPLATE = textwrap.dedent(
403403 '''
404404 {} = list(range(len({}.shape)))
405405 {}[{}] = {}
@@ -408,7 +408,7 @@ class TransposeMatcher(BaseMatcher):
408408 '''
409409 )
410410 perm = unique_name('perm')
411- code = API_TEMPLACE .format(perm, kwargs['input'],
411+ code = API_TEMPLATE .format(perm, kwargs['input'],
412412 perm, kwargs['dim0'], kwargs['dim1'],
413413 perm, kwargs['dim1'], kwargs['dim0'],
414414 kwargs['input'], perm)
@@ -564,6 +564,14 @@ x.reshape(2, 3)
564564
5655653 ) API功能缺失。如果是整个API都缺失的,只需在API映射表中标注 ** 功能缺失** 即可,无需其他开发。如果是API局部功能缺失,则对功能缺失点,在代码中返回None表示不支持,同时在API映射表中说明此功能点 ** Paddle暂无转写方式** ,同时编写单测但可以注释掉不运行;对其他功能点正常开发即可。
566566
567+ 4 ) 别名实现。如果一个API是别名API(alias API),例如 ` torch.nn.modules.GroupNorm ` 是 ` torch.nn.GroupNorm ` 的别名,那么就无需编写相关 Matcher,只需在 ` paconvert/api_alias_mapping.json ` 中增加该别名 API 的配置,同时也无需增加相应单测文件,只需在主API的单测文件中增加 ` test_alias_case_1/test_alias_case_2... ` 即可。
568+
569+ ``` bash
570+ {
571+ " torch.nn.modules.GroupNorm" : " torch.nn.GroupNorm"
572+ }
573+ ```
574+
567575# ## 开发技巧
568576
5695771)可以参考一些写的较为规范的Matcher:
@@ -618,7 +626,7 @@ if x:
618626
619627** 单测写法** :
620628
621- * ** 单测位置** :所有的单测文件均放在` tests ` 目录下,单测文件命名以` test_ ` 为前缀,后面接测试的` API ` 名称(PyTorch API全称去掉模块名 ,保留大小写)。例如 ` torch.add ` 命名为 ` test_add .py` , ` torch.Tensor.add ` 命名为 ` test_Tensor_add.py ` 。
629+ * ** 单测位置** :所有的单测文件均放在` tests` 目录下,单测文件命名以` test_` 为前缀,后面接测试的` API` 名称(PyTorch API名称即可 ,保留大小写,无需Module前缀 )。例如 ` torch.nn.functional.relu ` 命名为 ` test_relu .py` , ` torch.Tensor.add` 命名为 ` test_Tensor_add.py` 。
622630
623631* ** 默认检查逻辑** :采用` pytest` 作为单测框架。一般情况下,用户只需要在单测文件中调用 ` APIBase` 类的 ` run()` 方法,传入 ` pytorch_code` 和需要判断的 ` Tensor` 变量名列表即可,参考 [torch.permute测试用例](https://github.com/PaddlePaddle/PaConvert/tree/master/tests/test_permute.py)。 ` run()` 方法会调用`compare ()` 函数,该方法默认检查逻辑为:转换前后两个` Tensor` 的 ` 计算数值、数据类型、stop_gradient属性、形状` 是否一致。
624632
0 commit comments