# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys
import itertools
import unittest

import mxnet as mx
import numpy as np
from mxnet.test_utils import assert_almost_equal_with_err, rand_shape_nd

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed

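# bfloat16 has no native NumPy representation; MXNet identifies bf16 arrays
# through this structured uint16 dtype.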
bfloat16 = np.dtype([('bfloat16', np.uint16)])

def check_operator_accuracy(sym_fp32, sym_bf16, data_shape, num_input_data=1, bf16_use_fp32_params=False, rtol=1e-1, atol=5e-1, etol=0):
    """Check the accuracy of a bfloat16 operator against its fp32 counterpart.

    Parameters
    ----------
    sym_fp32 : Symbol
        The fp32 operator.
    sym_bf16 : Symbol
        The bf16 operator.
    data_shape : tuple of int
        Input data shape for the fp32/bf16 symbols.
    num_input_data : int
        Number of input data arrays. Defaults to 1; set a larger value for
        operators with multiple inputs, such as concat and elemwise_add.
    bf16_use_fp32_params : bool
        If True, bind fp32 parameters to the bf16 executor. Currently only
        BatchNorm sets this to True, since bf16 BatchNorm accepts bf16 data
        only with fp32 mean/var/scale/shift.
    rtol : float
        The relative tolerance.
    atol : float
        The absolute tolerance.
    etol : float
        The error-rate tolerance: the fraction of elements allowed to differ
        between the bf16 and fp32 outputs beyond the given tolerances.
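
    Examples
    --------
    A minimal sketch (mirroring ``test_bf16_abs`` below): build matching fp32
    and bf16 symbols over the same variable and compare their outputs::

        data_fp32 = mx.sym.Variable(name='data')
        data_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        check_operator_accuracy(mx.sym.abs(data_fp32), mx.sym.abs(data_bf16),
                                data_shape=(3, 16), bf16_use_fp32_params=True)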
60    """
    if not isinstance(data_shape, tuple):
        data_shape = tuple(data_shape)
    data_range = (0.0, 10.0)
    # draw shared random inputs; the bf16 copies are amp-casts of the fp32 ones
    data_list_fp32 = []
    data_list_bf16 = []
    for i in range(num_input_data):
        data_list_fp32.append(mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=data_shape))
        data_list_bf16.append(mx.nd.amp_cast(data_list_fp32[i], dtype=bfloat16))

    arg_shapes, _, aux_shapes = sym_fp32.infer_shape(data=data_shape)
    arg_names = sym_fp32.list_arguments()
    aux_names = sym_fp32.list_auxiliary_states()

    # bind and run the fp32 reference executor
    exe_fp32 = sym_fp32.simple_bind(ctx=mx.cpu(), data=data_shape)

    arg_params_fp32 = {}
    aux_params_fp32 = {}
    type_dict = {}
    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_fp32.arg_dict[arg_name][:] = data_list_fp32[i]
            continue
        arg_params_fp32[arg_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=arg_shapes[i])
        exe_fp32.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        # non-input arguments are bound as bf16 unless bf16_use_fp32_params is set
        if not bf16_use_fp32_params:
            type_dict.update({arg_name: bfloat16})

    for i, aux_name in enumerate(aux_names):
        aux_params_fp32[aux_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=aux_shapes[i])
        exe_fp32.aux_dict[aux_name][:] = aux_params_fp32[aux_name]

    output_fp32 = exe_fp32.forward()[0]

    # bind the bf16 executor and feed it the same inputs and parameters,
    # cast to bf16 where appropriate
    exe_bf16 = sym_bf16.simple_bind(ctx=mx.cpu(), data=data_shape, type_dict=type_dict)

    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_bf16.arg_dict[arg_name][:] = data_list_bf16[i]
            continue
        if bf16_use_fp32_params:
            exe_bf16.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        else:
            exe_bf16.arg_dict[arg_name][:] = mx.nd.amp_cast(arg_params_fp32[arg_name], dtype=bfloat16)

    for aux_name in aux_names:
        if bf16_use_fp32_params:
            exe_bf16.aux_dict[aux_name][:] = aux_params_fp32[aux_name]
        else:
            exe_bf16.aux_dict[aux_name][:] = mx.nd.amp_cast(aux_params_fp32[aux_name], dtype=bfloat16)

    # cast the bf16 output back to fp32 for comparison against the reference
    output_bf16 = exe_bf16.forward()[0]
    output_bf16.wait_to_read()
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal_with_err(output_bf16_2_fp32, output_fp32, rtol=rtol, atol=atol, etol=etol)

@with_seed()
def test_bf16_bn():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bn_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bn_fp32 = mx.sym.BatchNorm(data_sym_fp32, **bn_params)
    bn_bf16 = mx.sym.BatchNorm(data_sym_bf16, **bn_params)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=True, etol=1e-2)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(32, 16, 64, 64), bf16_use_fp32_params=True, etol=1e-2)

@with_seed()
def test_bf16_bnrelu():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bnrelu_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bnrelu_fp32 = mx.sym.contrib.BatchNormWithReLU(data_sym_fp32, **bnrelu_params)
    bnrelu_bf16 = mx.sym.contrib.BatchNormWithReLU(data_sym_bf16, **bnrelu_params)
    check_operator_accuracy(sym_fp32=bnrelu_fp32, sym_bf16=bnrelu_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=True, etol=1e-2)
    check_operator_accuracy(sym_fp32=bnrelu_fp32, sym_bf16=bnrelu_bf16, data_shape=(32, 16, 64, 64), bf16_use_fp32_params=True, etol=1e-2)

@with_seed()
def test_bf16_conv():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    conv_params = {"kernel": (3, 3), "num_filter": 128, "pad": (1, 1), "stride": (1, 1), "no_bias": True, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=False)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(128, 56, 14, 14), bf16_use_fp32_params=False)

    conv_params = {"kernel": (1, 1), "num_filter": 32, "pad": (0, 0), "stride": (1, 1), "no_bias": False, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=False)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(128, 56, 14, 14), bf16_use_fp32_params=False)

@with_seed()
def test_bf16_fc():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    fc_params = {"num_hidden": 10, "no_bias": True, "flatten": True, "name": "fc"}
    fc_fp32 = mx.sym.FullyConnected(data_sym_fp32, **fc_params)
    fc_bf16 = mx.sym.FullyConnected(data_sym_bf16, **fc_params)
    check_operator_accuracy(fc_fp32, fc_bf16, data_shape=(3, 3, 16, 16), bf16_use_fp32_params=False)

    fc_params = {"num_hidden": 10, "no_bias": False, "flatten": False, "name": "fc"}
    fc_fp32 = mx.sym.FullyConnected(data_sym_fp32, **fc_params)
    fc_bf16 = mx.sym.FullyConnected(data_sym_bf16, **fc_params)
    check_operator_accuracy(fc_fp32, fc_bf16, data_shape=(3, 3, 16, 16), bf16_use_fp32_params=False)

@with_seed()
def test_bf16_pooling():
    pool_params = {"kernel": (3, 3), "stride": (1, 1), "pad": (0, 0), "name": "pool"}
    data_shapes = [(3, 16, 28, 28), (3, 32, 7, 7)]
    pool_types = ["max", "avg"]
    pool_conventions = ["full", "valid"]
    for data_shape, pool_type, pool_convention in itertools.product(data_shapes, pool_types, pool_conventions):
        pool_params.update({"pool_type": pool_type, "pooling_convention": pool_convention})

        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        pool_fp32 = mx.sym.Pooling(data_sym_fp32, **pool_params)
        pool_bf16 = mx.sym.Pooling(data_sym_bf16, **pool_params)
        check_operator_accuracy(pool_fp32, pool_bf16, data_shape=data_shape, bf16_use_fp32_params=False)

@with_seed()
def test_bf16_activation():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    dshapes = [(3, 16), (3, 16, 16), (3, 3, 16, 16)]
    act_types = ['relu', 'sigmoid', 'tanh']
    for data_shape, act_type in itertools.product(dshapes, act_types):
        act_fp32 = mx.sym.Activation(data_sym_fp32, act_type=act_type)
        act_bf16 = mx.sym.Activation(data_sym_bf16, act_type=act_type)

        check_operator_accuracy(act_fp32, act_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_elemwiseadd():
    dshape = rand_shape_nd(4)

    a_sym_fp32 = mx.sym.Variable("data")
    b_sym_fp32 = mx.sym.Variable("data_1")
    sym_fp32 = mx.sym.elemwise_add(a_sym_fp32, b_sym_fp32)

    a_sym_bf16 = mx.sym.Variable("data", dtype=bfloat16)
    b_sym_bf16 = mx.sym.Variable("data_1", dtype=bfloat16)
    sym_bf16 = mx.sym.elemwise_add(a_sym_bf16, b_sym_bf16)

    check_operator_accuracy(sym_fp32, sym_bf16, dshape, num_input_data=2, bf16_use_fp32_params=True)

@unittest.skip("environment dependent, needs further investigation.")
@with_seed()
def test_bf16_concat():
    dshape = rand_shape_nd(4)
    a_shape = tuple(dshape)
    b_shape = tuple(dshape)

    a_sym_fp32 = mx.sym.Variable("data", shape=a_shape)
    b_sym_fp32 = mx.sym.Variable("data_1", shape=b_shape)

    a_sym_bf16 = mx.sym.Variable("data", dtype=bfloat16, shape=a_shape)
    b_sym_bf16 = mx.sym.Variable("data_1", dtype=bfloat16, shape=b_shape)
    for axis in range(4):
        concat_sym_fp32 = mx.sym.concat(a_sym_fp32, b_sym_fp32, dim=axis)
        concat_sym_bf16 = mx.sym.concat(a_sym_bf16, b_sym_bf16, dim=axis)

        check_operator_accuracy(concat_sym_fp32, concat_sym_bf16, dshape, num_input_data=2, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_abs():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.abs(data_sym_fp32)
        sym_bf16 = mx.sym.abs(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_sqrt():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.sqrt(data_sym_fp32)
        sym_bf16 = mx.sym.sqrt(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_square():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.square(data_sym_fp32)
        sym_bf16 = mx.sym.square(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_flatten_slice_after_conv():
    data_fp32 = mx.symbol.Variable('data')
    data_bf16 = mx.symbol.Variable('data', dtype=bfloat16)

    conv_fp32 = mx.symbol.Convolution(data=data_fp32, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
    flatten_fp32 = mx.symbol.flatten(data=conv_fp32)
    slice_fp32 = mx.symbol.slice(data=flatten_fp32, begin=0, end=1)

    conv_bf16 = mx.symbol.Convolution(data=data_bf16, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
    flatten_bf16 = mx.symbol.flatten(data=conv_bf16)
    slice_bf16 = mx.symbol.slice(data=flatten_bf16, begin=0, end=1)

    shape = (2, 16, 16, 16)
    check_operator_accuracy(slice_fp32, slice_bf16, shape, bf16_use_fp32_params=False)

@with_seed()
def test_bf16_fallback():
    # 5-d inputs (and the matching 3-d convolution) are expected to exercise
    # the fallback path for operators without a dedicated bf16 implementation.
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bn_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bn_fp32 = mx.sym.BatchNorm(data_sym_fp32, **bn_params)
    bn_bf16 = mx.sym.BatchNorm(data_sym_bf16, **bn_params)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(3, 32, 28, 28, 3), bf16_use_fp32_params=True, etol=1e-2)

    conv_params = {"kernel": (3, 3, 3), "num_filter": 128, "pad": (1, 1, 1), "stride": (1, 1, 1), "no_bias": True, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28, 4), bf16_use_fp32_params=False)

if __name__ == '__main__':
    import nose
    nose.runmodule()