Skip to content

Commit c74000c

Browse files
authored
Merge branch 'master' into loss
2 parents f95a8f0 + 4833ccd commit c74000c

296 files changed

Lines changed: 25639 additions & 5831 deletions

File tree

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

README.md

Lines changed: 13 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ torch.permute
298298
"dims": "perm"
299299
}
300300
},
301-
"unsupport_args": {},
301+
"unsupport_args": [],
302302
"paddle_default_kwargs": {}
303303
}
304304
```
@@ -341,7 +341,7 @@ paddle_default_kwargs :可选,当 paddle 参数更多 或者 参数默认值
341341
}
342342
```
343343

344-
如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为`TransposeMatcher``torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`详见下面步骤3
344+
如果不属于上述分类,则需要开发 **自定义的Matcher**,命名标准为:`API名+Matcher` 。例如 `torch.transpose` 可命名为`TransposeMatcher``torch.Tensor.transpose` 可命名为 `TensorTransposeMatcher`详见下面步骤
345345

346346
## 步骤4:编写Matcher(转换规则)
347347

@@ -399,7 +399,7 @@ class Chain_MatmulMatcher(BaseMatcher):
399399
```
400400
class TransposeMatcher(BaseMatcher):
401401
def generate_code(self, kwargs):
402-
API_TEMPLACE = textwrap.dedent(
402+
API_TEMPLATE = textwrap.dedent(
403403
'''
404404
{} = list(range(len({}.shape)))
405405
{}[{}] = {}
@@ -408,7 +408,7 @@ class TransposeMatcher(BaseMatcher):
408408
'''
409409
)
410410
perm = unique_name('perm')
411-
code = API_TEMPLACE.format(perm, kwargs['input'],
411+
code = API_TEMPLATE.format(perm, kwargs['input'],
412412
perm, kwargs['dim0'], kwargs['dim1'],
413413
perm, kwargs['dim1'], kwargs['dim0'],
414414
kwargs['input'], perm)
@@ -564,6 +564,14 @@ x.reshape(2, 3)
564564

565565
3) API功能缺失。如果是整个API都缺失的,只需在API映射表中标注 **功能缺失** 即可,无需其他开发。如果是API局部功能缺失,则对功能缺失点,在代码中返回None表示不支持,同时在API映射表中说明此功能点 **Paddle暂无转写方式**,同时编写单测但可以注释掉不运行;对其他功能点正常开发即可。
566566

567+
4) 别名实现。如果一个API是别名API(alias API),例如 `torch.nn.modules.GroupNorm``torch.nn.GroupNorm` 的别名,那么就无需编写相关 Matcher,只需在 `paconvert/api_alias_mapping.json` 中增加该别名 API 的配置,同时也无需增加相应单测文件,只需在主API的单测文件中增加 `test_alias_case_1/test_alias_case_2...` 即可。
568+
569+
```bash
570+
{
571+
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm"
572+
}
573+
```
574+
567575
### 开发技巧
568576

569577
1)可以参考一些写的较为规范的Matcher:
@@ -618,7 +626,7 @@ if x:
618626
619627
**单测写法**
620628
621-
* **单测位置**:所有的单测文件均放在`tests`目录下,单测文件命名以`test_`为前缀,后面接测试的`API`名称(PyTorch API全称去掉模块名,保留大小写)。例如 `torch.add` 命名为 `test_add.py``torch.Tensor.add` 命名为 `test_Tensor_add.py`
629+
* **单测位置**:所有的单测文件均放在`tests`目录下,单测文件命名以`test_`为前缀,后面接测试的`API`名称(PyTorch API名称即可,保留大小写,无需Module前缀)。例如 `torch.nn.functional.relu` 命名为 `test_relu.py``torch.Tensor.add` 命名为 `test_Tensor_add.py`
622630
623631
* **默认检查逻辑**:采用`pytest`作为单测框架。一般情况下,用户只需要在单测文件中调用 `APIBase` 类的 `run()` 方法,传入 `pytorch_code` 和需要判断的 `Tensor` 变量名列表即可,参考 [torch.permute测试用例](https://github.com/PaddlePaddle/PaConvert/tree/master/tests/test_permute.py)。 `run()` 方法会调用`compare()`函数,该方法默认检查逻辑为:转换前后两个`Tensor``计算数值、数据类型、stop_gradient属性、形状` 是否一致。
624632

README_EN.md

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -220,7 +220,7 @@ First you need to add **Matcher** in paconvert/api_matcher.py one by one, and ov
220220
221221
class TransposeMatcher(BaseMatcher):
222222
def generate_code(self, kwargs):
223-
API_TEMPLACE = textwrap.dedent(
223+
API_TEMPLATE = textwrap.dedent(
224224
'''
225225
{} = list(range(len({}.shape)))
226226
{}[{}] = {}
@@ -229,7 +229,7 @@ class TransposeMatcher(BaseMatcher):
229229
'''
230230
)
231231
perm = unique_name('perm')
232-
code = API_TEMPLACE.format(perm, kwargs['input'],
232+
code = API_TEMPLATE.format(perm, kwargs['input'],
233233
perm, kwargs['dim0'], kwargs['dim1'],
234234
perm, kwargs['dim1'], kwargs['dim0'],
235235
kwargs['input'], perm)

paconvert/api_alias_mapping.json

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,13 @@
11
{
22
"torch.nn.modules.GroupNorm": "torch.nn.GroupNorm",
3-
"torch.nn.parameter.Parameter": "torch.nn.Parameter",
43
"torch.nn.modules.activation.ReLU": "torch.nn.ReLU",
54
"torch.nn.modules.conv.Conv2d": "torch.nn.Conv2d",
6-
"torch.utils.data.sampler.BatchSampler": "torch.utils.data.BatchSampler",
75
"torch.nn.modules.module.Module": "torch.nn.Module",
6+
"torch.nn.parameter.Parameter": "torch.nn.Parameter",
7+
"torch.utils.data._utils.collate.default_collate": "torch.utils.data.default_collate",
8+
"torch.utils.data.dataloader.default_collate": "torch.utils.data.default_collate",
9+
"torch.utils.data.sampler.BatchSampler": "torch.utils.data.BatchSampler",
810
"torch.utils.data.sampler.RandomSampler": "torch.utils.data.RandomSampler",
9-
"torch.utils.data.sampler.SequentialSampler": "torch.utils.data.SequentialSampler",
10-
"torch.utils.data.sampler.Sampler": "torch.utils.data.Sampler"
11+
"torch.utils.data.sampler.Sampler": "torch.utils.data.Sampler",
12+
"torch.utils.data.sampler.SequentialSampler": "torch.utils.data.SequentialSampler"
1113
}

0 commit comments

Comments
 (0)