Skip to content

Commit f477069

Browse files
committed
Add fused_batch_norm_add_relu and fix a bug.
1 parent 8923bcc commit f477069

2 files changed

Lines changed: 97 additions & 2 deletions

File tree

api/deploy/collect_api_info.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ def _import_api(test_module_name, basename):
7272
for filename in sorted(os.listdir(tests_path)):
7373
api_name = os.path.splitext(filename)[0]
7474
file_extension = os.path.splitext(filename)[1]
75-
if file_extension == '.py' and _is_special_module(api_name):
75+
if file_extension == '.py' and not _is_special_module(api_name):
7676
module = _import_api(test_module_name, api_name)
7777
if module:
7878
test_cases_dict[api_name] = module
@@ -138,7 +138,7 @@ def main(args):
138138
parser.add_argument(
139139
'--test_module_name',
140140
type=str,
141-
default="tests",
141+
default="tests_v2",
142142
help='The module_name under benchmark/api (tests|tests_v2|dynamic_tests_v2).'
143143
)
144144
parser.add_argument(
Lines changed: 95 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from common_import import *
16+
from batch_norm import BatchNormConfig
17+
18+
19+
class FusedBatchNormAddReluConfig(BatchNormConfig):
20+
def __init__(self):
21+
super(FusedBatchNormAddReluConfig,
22+
self).__init__("fused_batch_norm_add_relu")
23+
self.alias_name = "batch_norm"
24+
25+
26+
class PDFusedBatchNormAddRelu(PaddleAPIBenchmarkBase):
27+
def build_program(self, config):
28+
def _create_parameter(name, value, stop_gradient):
29+
param = paddle.create_parameter(
30+
name=name,
31+
shape=[config.num_channels],
32+
dtype=config.x_dtype,
33+
attr=paddle.ParamAttr(
34+
initializer=paddle.nn.initializer.Constant(value)))
35+
param.stop_gradient = stop_gradient
36+
return param
37+
38+
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
39+
y = self.variable(name='y', shape=config.x_shape, dtype=config.x_dtype)
40+
41+
running_mean = _create_parameter(
42+
name='running_mean', value=0.5, stop_gradient=True)
43+
running_var = _create_parameter(
44+
name='running_var', value=0.1, stop_gradient=True)
45+
46+
scale = _create_parameter(name='scale', value=0.5, stop_gradient=False)
47+
bias = _create_parameter(name='bias', value=0.1, stop_gradient=False)
48+
49+
bn_out = paddle.nn.functional.batch_norm(
50+
x=x,
51+
running_mean=running_mean,
52+
running_var=running_var,
53+
weight=scale,
54+
bias=bias,
55+
epsilon=config.epsilon,
56+
momentum=config.momentum,
57+
training=config.training,
58+
data_format=config.data_format)
59+
add_out = bn_out + y
60+
relu_out = paddle.nn.functional.relu(add_out)
61+
62+
self.feed_vars = [x, y]
63+
self.fetch_vars = [bn_out, add_out, relu_out]
64+
if config.backward:
65+
self.append_gradients(relu_out, [x, scale, bias, bn_out, add_out])
66+
67+
68+
class TFFusedBatchNormAddRelu(TensorflowAPIBenchmarkBase):
69+
def build_graph(self, config):
70+
x = self.variable(name='x', shape=config.x_shape, dtype=config.x_dtype)
71+
y = self.variable(name='y', shape=config.x_shape, dtype=config.x_dtype)
72+
bn = tf.keras.layers.BatchNormalization(
73+
axis=config.axis,
74+
momentum=config.momentum,
75+
epsilon=config.epsilon,
76+
beta_initializer=tf.constant_initializer(0.1),
77+
gamma_initializer=tf.constant_initializer(0.5),
78+
moving_mean_initializer=tf.constant_initializer(0.5),
79+
moving_variance_initializer=tf.constant_initializer(0.1))
80+
bn_out = bn(x, training=config.training)
81+
add_out = bn_out + y
82+
relu_out = tf.nn.relu(add_out)
83+
84+
self.feed_list = [x, y]
85+
self.fetch_list = [bn_out, add_out, relu_out]
86+
if config.backward:
87+
self.append_gradients(relu_out,
88+
[x, bn.gamma, bn.beta, bn_out, add_out])
89+
90+
91+
if __name__ == '__main__':
92+
test_main(
93+
PDFusedBatchNormAddRelu(),
94+
TFFusedBatchNormAddRelu(),
95+
config=FusedBatchNormAddReluConfig())

0 commit comments

Comments
 (0)