# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

import os
import sys
import itertools
import unittest
import mxnet as mx
import numpy as np
from mxnet.test_utils import assert_almost_equal_with_err, rand_shape_nd

curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
sys.path.insert(0, os.path.join(curr_path, '../unittest'))
from common import with_seed

# bfloat16 is exposed to numpy as a structured dtype wrapping the raw 16-bit
# payload; mx.nd.amp_cast recognizes it as the bfloat16 storage type.
bfloat16 = np.dtype([('bfloat16', np.uint16)])

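# Background sketch (not used by the tests below): bfloat16 keeps only the
# upper 16 bits of an IEEE-754 fp32 value (1 sign, 8 exponent, 7 mantissa
# bits), which is why a uint16 payload suffices above. The helpers here
# illustrate that truncation with plain numpy; note that amp_cast may round
# rather than truncate, so treat this as an approximation of the cast, not
# its exact semantics.
def _fp32_to_bf16_bits(x):
    """Return the bfloat16 bit pattern(s) of fp32 input (high 16 bits kept)."""
    bits = np.asarray(x, dtype=np.float32).view(np.uint32)
    return (bits >> 16).astype(np.uint16)

def _bf16_bits_to_fp32(bits):
    """Widen bfloat16 bit pattern(s) back to fp32 (low mantissa bits zeroed)."""
    return (np.asarray(bits, dtype=np.uint32) << 16).view(np.float32)
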
def check_operator_accuracy(sym_fp32, sym_bf16, data_shape, num_input_data=1, bf16_use_fp32_params=False, rtol=1e-1, atol=5e-1, etol=0):
    """
    Check the accuracy of a bfloat16 operator against its fp32 counterpart.

    sym_fp32: Symbol
        the fp32 operator
    sym_bf16: Symbol
        the bf16 operator
    data_shape: tuple of int
        input data shape for the fp32/bf16 symbols
    num_input_data: int
        number of data inputs; the default is 1, so set a larger value for
        operators with multiple inputs, such as concat or elemwise_add
    bf16_use_fp32_params: bool
        if True, feed the bf16 executor fp32 parameters; currently only
        BatchNorm sets this, since the bf16 kernel accepts bf16 data but
        requires fp32 mean/var/gamma/beta
    rtol: float
        relative tolerance
    atol: float
        absolute tolerance
    etol: float
        error-rate tolerance; allows a small fraction of elements to disagree
        between bf16 and fp32 beyond rtol/atol
    """
    if not isinstance(data_shape, tuple):
        data_shape = tuple(data_shape)
    data_range = (0.0, 10.0)
    data_list_fp32 = list()
    data_list_bf16 = list()
    for i in range(num_input_data):
        data_list_fp32.append(mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=data_shape))
        data_list_bf16.append(mx.nd.amp_cast(data_list_fp32[i], dtype=bfloat16))

    arg_shapes, _, aux_shapes = sym_fp32.infer_shape(data=data_shape)
    arg_names = sym_fp32.list_arguments()
    aux_names = sym_fp32.list_auxiliary_states()

    exe_fp32 = sym_fp32.simple_bind(ctx=mx.cpu(), data=data_shape)

    arg_params_fp32 = {}
    aux_params_fp32 = {}
    type_dict = {}
    for i, arg_name in enumerate(arg_names):
        # the first num_input_data arguments are the data inputs
        if i < num_input_data:
            exe_fp32.arg_dict[arg_name][:] = data_list_fp32[i]
            continue
        arg_params_fp32[arg_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=arg_shapes[i])
        exe_fp32.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        # specify the dtype of the parameter arguments for the bf16 bind
        if not bf16_use_fp32_params:
            type_dict.update({arg_name: bfloat16})

    for i, aux_name in enumerate(aux_names):
        aux_params_fp32[aux_name] = mx.nd.random.uniform(low=data_range[0], high=data_range[1], shape=aux_shapes[i])
        exe_fp32.aux_dict[aux_name][:] = aux_params_fp32[aux_name]

    output_fp32 = exe_fp32.forward()[0]

    exe_bf16 = sym_bf16.simple_bind(ctx=mx.cpu(), data=data_shape, type_dict=type_dict)

    for i, arg_name in enumerate(arg_names):
        if i < num_input_data:
            exe_bf16.arg_dict[arg_name][:] = data_list_bf16[i]
            continue

        if bf16_use_fp32_params:
            exe_bf16.arg_dict[arg_name][:] = arg_params_fp32[arg_name]
        else:
            exe_bf16.arg_dict[arg_name][:] = mx.nd.amp_cast(arg_params_fp32[arg_name], dtype=bfloat16)

    for aux_name in aux_names:
        if bf16_use_fp32_params:
            exe_bf16.aux_dict[aux_name][:] = aux_params_fp32[aux_name]
        else:
            exe_bf16.aux_dict[aux_name][:] = mx.nd.amp_cast(aux_params_fp32[aux_name], dtype=bfloat16)

    output_bf16 = exe_bf16.forward()[0]
    output_bf16.wait_to_read()
    # compare in fp32: cast the bf16 output up before checking tolerances
    output_bf16_2_fp32 = mx.nd.amp_cast(output_bf16, dtype="float32")
    assert_almost_equal_with_err(output_bf16_2_fp32, output_fp32, rtol=rtol, atol=atol, etol=etol)

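# A hedged illustration of the etol parameter above (not part of the suite;
# _etol_demo is deliberately not named test_* so nose does not collect it).
# With etol=1e-2, up to 1% of the elements may violate rtol/atol before the
# assertion fails, which is how the BatchNorm cases below absorb occasional
# bf16 outliers. The exact accounting lives in
# mxnet.test_utils.assert_almost_equal_with_err; this is only a usage sketch.
def _etol_demo():
    a = np.zeros((1000,), dtype=np.float32)
    b = a.copy()
    b[0] = 1.0  # one element out of 1000 (0.1%) far outside tolerance
    # passes: the 0.1% error rate is within the 1% etol budget
    assert_almost_equal_with_err(a, b, rtol=1e-3, atol=1e-3, etol=1e-2)
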
@with_seed()
def test_bf16_bn():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bn_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bn_fp32 = mx.sym.BatchNorm(data_sym_fp32, **bn_params)
    bn_bf16 = mx.sym.BatchNorm(data_sym_bf16, **bn_params)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=True, etol=1e-2)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(32, 16, 64, 64), bf16_use_fp32_params=True, etol=1e-2)

@with_seed()
def test_bf16_bnrelu():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bnrelu_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bnrelu_fp32 = mx.sym.contrib.BatchNormWithReLU(data_sym_fp32, **bnrelu_params)
    bnrelu_bf16 = mx.sym.contrib.BatchNormWithReLU(data_sym_bf16, **bnrelu_params)
    check_operator_accuracy(sym_fp32=bnrelu_fp32, sym_bf16=bnrelu_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=True, etol=1e-2)
    check_operator_accuracy(sym_fp32=bnrelu_fp32, sym_bf16=bnrelu_bf16, data_shape=(32, 16, 64, 64), bf16_use_fp32_params=True, etol=1e-2)

@with_seed()
def test_bf16_conv():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    conv_params = {"kernel": (3, 3), "num_filter": 128, "pad": (1, 1), "stride": (1, 1), "no_bias": True, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=False)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(128, 56, 14, 14), bf16_use_fp32_params=False)

    conv_params = {"kernel": (1, 1), "num_filter": 32, "pad": (0, 0), "stride": (1, 1), "no_bias": False, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28), bf16_use_fp32_params=False)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(128, 56, 14, 14), bf16_use_fp32_params=False)

@with_seed()
def test_bf16_fc():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    fc_params = {"num_hidden": 10, "no_bias": True, "flatten": True, "name": "fc"}
    fc_fp32 = mx.sym.FullyConnected(data_sym_fp32, **fc_params)
    fc_bf16 = mx.sym.FullyConnected(data_sym_bf16, **fc_params)
    check_operator_accuracy(fc_fp32, fc_bf16, data_shape=(3, 3, 16, 16), bf16_use_fp32_params=False)

    fc_params = {"num_hidden": 10, "no_bias": False, "flatten": False, "name": "fc"}
    fc_fp32 = mx.sym.FullyConnected(data_sym_fp32, **fc_params)
    fc_bf16 = mx.sym.FullyConnected(data_sym_bf16, **fc_params)
    check_operator_accuracy(fc_fp32, fc_bf16, data_shape=(3, 3, 16, 16), bf16_use_fp32_params=False)

@with_seed()
def test_bf16_pooling():
    pool_params = {"kernel": (3, 3), "stride": (1, 1), "pad": (0, 0), "name": "pool"}
    data_shapes = [(3, 16, 28, 28), (3, 32, 7, 7)]
    pool_types = ["max", "avg"]
    pool_conventions = ["full", "valid"]
    for data_shape, pool_type, pool_convention in itertools.product(data_shapes, pool_types, pool_conventions):
        pool_params.update({"pool_type": pool_type, "pooling_convention": pool_convention})

        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        pool_fp32 = mx.sym.Pooling(data_sym_fp32, **pool_params)
        pool_bf16 = mx.sym.Pooling(data_sym_bf16, **pool_params)
        check_operator_accuracy(pool_fp32, pool_bf16, data_shape=data_shape, bf16_use_fp32_params=False)

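# For reference, the two pooling conventions exercised above differ only in
# how a partial window at the border is counted: per MXNet's Pooling docs,
# "valid" floors the fractional extent while "full" ceils it. A small sketch
# of that output-size arithmetic (helper name is ours, for illustration):
def _pool_out_size(in_size, kernel, stride, pad, convention="valid"):
    import math
    frac = (in_size + 2 * pad - kernel) / stride
    round_fn = math.floor if convention == "valid" else math.ceil
    return int(round_fn(frac)) + 1
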
@with_seed()
def test_bf16_activation():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    dshapes = [(3, 16), (3, 16, 16), (3, 3, 16, 16)]
    act_types = ['relu', 'sigmoid', 'tanh']
    for data_shape, act_type in itertools.product(dshapes, act_types):
        act_fp32 = mx.sym.Activation(data_sym_fp32, act_type=act_type)
        act_bf16 = mx.sym.Activation(data_sym_bf16, act_type=act_type)
        check_operator_accuracy(act_fp32, act_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_elemwiseadd():
    dshape = rand_shape_nd(4)

    a_sym_fp32 = mx.sym.Variable("data")
    b_sym_fp32 = mx.sym.Variable("data_1")
    sym_fp32 = mx.sym.elemwise_add(a_sym_fp32, b_sym_fp32)

    a_sym_bf16 = mx.sym.Variable("data", dtype=bfloat16)
    b_sym_bf16 = mx.sym.Variable("data_1", dtype=bfloat16)
    sym_bf16 = mx.sym.elemwise_add(a_sym_bf16, b_sym_bf16)

    check_operator_accuracy(sym_fp32, sym_bf16, dshape, num_input_data=2, bf16_use_fp32_params=True)

@unittest.skip("environment dependent; needs further investigation.")
@with_seed()
def test_bf16_concat():
    dshape = rand_shape_nd(4)
    a_shape = tuple(dshape)
    b_shape = tuple(dshape)

    a_sym_fp32 = mx.sym.Variable("data", shape=a_shape)
    b_sym_fp32 = mx.sym.Variable("data_1", shape=b_shape)

    a_sym_bf16 = mx.sym.Variable("data", dtype=bfloat16, shape=a_shape)
    b_sym_bf16 = mx.sym.Variable("data_1", dtype=bfloat16, shape=b_shape)
    for axis in range(0, 4):
        print(axis, a_shape)
        concat_sym_fp32 = mx.sym.concat(a_sym_fp32, b_sym_fp32, dim=axis)
        concat_sym_bf16 = mx.sym.concat(a_sym_bf16, b_sym_bf16, dim=axis)

        check_operator_accuracy(concat_sym_fp32, concat_sym_bf16, dshape, num_input_data=2, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_abs():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.abs(data_sym_fp32)
        sym_bf16 = mx.sym.abs(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_sqrt():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.sqrt(data_sym_fp32)
        sym_bf16 = mx.sym.sqrt(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_square():
    dshapes = [(16,), (3, 16), (3, 16, 16), (3, 16, 16, 16)]
    for data_shape in dshapes:
        data_sym_fp32 = mx.sym.Variable(name='data')
        data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)
        sym_fp32 = mx.sym.square(data_sym_fp32)
        sym_bf16 = mx.sym.square(data_sym_bf16)

        check_operator_accuracy(sym_fp32, sym_bf16, data_shape, bf16_use_fp32_params=True)

@with_seed()
def test_bf16_flatten_slice_after_conv():
    data_fp32 = mx.sym.Variable('data')
    data_bf16 = mx.sym.Variable('data', dtype=bfloat16)

    conv_fp32 = mx.sym.Convolution(data=data_fp32, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
    flatten_fp32 = mx.sym.flatten(data=conv_fp32)
    slice_fp32 = mx.sym.slice(data=flatten_fp32, begin=0, end=1)

    conv_bf16 = mx.sym.Convolution(data=data_bf16, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1))
    flatten_bf16 = mx.sym.flatten(data=conv_bf16)
    slice_bf16 = mx.sym.slice(data=flatten_bf16, begin=0, end=1)

    shape = (2, 16, 16, 16)
    check_operator_accuracy(slice_fp32, slice_bf16, shape, bf16_use_fp32_params=False)

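# The fallback case below feeds 5-D inputs (3-D BatchNorm and Convolution),
# for which the optimized bf16 kernels used by the other tests are presumably
# not available; as the name suggests, the intent is that execution falls
# back to a reference path and still matches fp32 within the usual tolerances.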
@with_seed()
def test_bf16_fallback():
    data_sym_fp32 = mx.sym.Variable(name='data')
    data_sym_bf16 = mx.sym.Variable(name='data', dtype=bfloat16)

    bn_params = {"eps": 2e-05, "fix_gamma": False, "use_global_stats": True, "name": "bn"}
    bn_fp32 = mx.sym.BatchNorm(data_sym_fp32, **bn_params)
    bn_bf16 = mx.sym.BatchNorm(data_sym_bf16, **bn_params)
    check_operator_accuracy(sym_fp32=bn_fp32, sym_bf16=bn_bf16, data_shape=(3, 32, 28, 28, 3), bf16_use_fp32_params=True, etol=1e-2)

    conv_params = {"kernel": (3, 3, 3), "num_filter": 128, "pad": (1, 1, 1), "stride": (1, 1, 1), "no_bias": True, "name": "conv"}
    conv_fp32 = mx.sym.Convolution(data_sym_fp32, **conv_params)
    conv_bf16 = mx.sym.Convolution(data_sym_bf16, **conv_params)
    check_operator_accuracy(sym_fp32=conv_fp32, sym_bf16=conv_bf16, data_shape=(3, 32, 28, 28, 4), bf16_use_fp32_params=False)

if __name__ == '__main__':
    import nose
    nose.runmodule()
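# Typical invocations (filename is a placeholder; requires nose):
#   python <this_file>.py                          # run everything via nose.runmodule()
#   nosetests -v <this_file>.py:test_bf16_conv     # run a single case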