# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""

# Functions that should be cast to lower precision
BF16_FUNCS = [
    'Convolution',
    'FullyConnected',
]

# Functions that should not be cast, either because
# they are irrelevant (not used in the network itself,
# like image transformations or optimizers) or because
# they are dtype neutral (can work in both bf16 and fp32)
BF16_FP32_FUNCS = [
    'abs',
    '_add',
    'BatchNorm',
    'BatchNormWithReLU',
    'clip',
    'Concat',
    'concat',
    'LRN',
    'Pooling',
    'relu',
    'shuffle',
    '_shuffle',
    'sqrt',
    'square',
    'tanh',
]

# Functions whose listed params still need float32 when running with bfloat16.
BF16_USE_FP32_PARAMS = {
    # Keyed by operator name; each value is a list of parameter-name entries.
    # NOTE(review): the meaning of the leading "" entry is not visible in this
    # file - confirm against the code that consumes this mapping.
    'BatchNormWithReLU': ["", "gamma", "beta", "moving_mean", "moving_var"],
    'BatchNorm': ["", "gamma", "beta", "moving_mean", "moving_var"],
}

# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
    'Deconvolution',
    'RNN',
    'BatchNorm_v1',
    'BilinearSampler',
    'BlockGrad',
    'Cast',
    'cast',
    'cast_storage',
    'Crop',
    'Dropout',
    'Embedding',
    '_sparse_Embedding',
    '_sparse_FullyConnected',
    'Flatten',
    'GridGenerator',
    'Pad',
    'Pooling_v1',
    'ROIPooling',
    'Reshape',
    'SequenceLast',
    'SequenceMask',
    'SequenceReverse',
    'SliceChannel',
    'SpatialTransformer',
    'SwapAxis',
    'UpSampling',
    '_CachedOp',
    '_CrossDeviceCopy',
    '_CustomFunction',
    '_DivScalar',
    '_EqualScalar',
    '_GreaterScalar',
    '_GreaterEqualScalar',
    '_LesserScalar',
    '_LesserEqualScalar',
    '_LogicalAndScalar',
    '_LogicalOrScalar',
    '_LogicalXorScalar',
    '_MaximumScalar',
    '_MinimumScalar',
    '_MinusScalar',
    '_ModScalar',
    '_MulScalar',
    '_NoGradient',
    '_NotEqualScalar',
    '_PlusScalar',
    '_RMinusScalar',
    '_RModScalar',
    '_adamw_update',
    '_arange',
    '_broadcast_backward',
    '_cond',
    '_contrib_AdaptiveAvgPooling2D',
    '_contrib_BilinearResize2D',
    '_contrib_SparseEmbedding',
    '_contrib_bipartite_matching',
    '_contrib_dequantize',
    '_contrib_div_sqrt_dim',
    '_contrib_boolean_mask',
    '_contrib_getnnz',
    '_contrib_gradientmultiplier',
    '_contrib_group_adagrad_update',
    '_contrib_ifft',
    '_contrib_index_array',
    '_contrib_index_copy',
    '_contrib_quadratic',
    '_contrib_quantize',
    '_contrib_quantize_v2',
    '_contrib_quantized_concat',
    '_contrib_quantized_conv',
    '_contrib_quantized_flatten',
    '_contrib_quantized_fully_connected',
    '_contrib_quantized_pooling',
    '_contrib_quantized_elemwise_add',
    '_contrib_quantized_act',
    '_image_crop',
    '_linspace',
    '_contrib_requantize',
    '_copy',
    '_copyto',
    '_crop_assign',
    '_crop_assign_scalar',
    '_cvcopyMakeBorder',
    '_cvimdecode',
    '_cvimread',
    '_cvimresize',
    '_div_scalar',
    '_equal_scalar',
    '_eye',
    '_foreach',
    '_while_loop',
    '_full',
    '_grad_add',
    '_greater_scalar',
    '_greater_equal_scalar',
    '_histogram',
    '_identity_with_attr_like_rhs',
    '_image_adjust_lighting',
    '_image_flip_left_right',
    '_image_flip_top_bottom',
    '_image_normalize',
    '_image_random_brightness',
    '_image_random_color_jitter',
    '_image_random_contrast',
    '_image_random_flip_left_right',
    '_image_random_flip_top_bottom',
    '_image_random_hue',
    '_image_random_lighting',
    '_image_random_saturation',
    '_image_resize',
    '_image_to_tensor',
    '_imdecode',
    '_lesser_scalar',
    '_lesser_equal_scalar',
    '_logical_and_scalar',
    '_logical_or_scalar',
    '_logical_xor_scalar',
    '_maximum_scalar',
    '_minimum_scalar',
    '_minus_scalar',
    '_mod_scalar',
    '_mp_adamw_update',
    '_mul_scalar',
    '_not_equal_scalar',
    '_onehot_encode',
    '_ones',
    '_plus_scalar',
    '_random_exponential',
    '_random_exponential_like',
    '_random_gamma',
    '_random_gamma_like',
    '_random_generalized_negative_binomial',
    '_random_generalized_negative_binomial_like',
    '_random_negative_binomial',
    '_random_negative_binomial_like',
    '_random_normal',
    '_random_normal_like',
    '_random_poisson',
    '_random_poisson_like',
    '_random_randint',
    '_random_uniform',
    '_random_uniform_like',
    '_ravel_multi_index',
    '_rminus_scalar',
    '_rmod_scalar',
    '_rnn_param_concat',
    '_sample_exponential',
    '_sample_gamma',
    '_sample_generalized_negative_binomial',
    '_sample_multinomial',
    '_sample_negative_binomial',
    '_sample_normal',
    '_sample_poisson',
    '_sample_uniform',
    '_sample_unique_zipfian',
    '_scatter_minus_scalar',
    '_scatter_plus_scalar',
    '_scatter_set_nd',
    '_set_value',
    '_slice_assign',
    '_slice_assign_scalar',
    '_sparse_abs',
    '_sparse_adagrad_update',
    '_sparse_adam_update',
    '_sparse_arccosh',
    '_sparse_arcsinh',
    '_sparse_arctan',
    '_sparse_cast_storage',
    '_sparse_cbrt',
    '_sparse_ceil',
    '_sparse_clip',
    '_sparse_concat',
    '_sparse_cos',
    '_sparse_degrees',
    '_sparse_fix',
    '_sparse_floor',
    '_sparse_ftrl_update',
    '_sparse_negative',
    '_sparse_radians',
    '_sparse_relu',
    '_sparse_retain',
    '_sparse_rint',
    '_sparse_round',
    '_sparse_sgd_mom_update',
    '_sparse_sgd_update',
    '_sparse_sigmoid',
    '_sparse_sign',
    '_sparse_sin',
    '_sparse_sinh',
    '_sparse_slice',
    '_sparse_sqrt',
    '_sparse_stop_gradient',
    '_sparse_tanh',
    '_sparse_trunc',
    '_sparse_zeros_like',
    '_split_v2',
    '_split_v2_backward',
    '_unravel_index',
    '_zeros',
    '_zeros_without_dtype',
    'adam_update',
    'all_finite',
    # 'amp_cast',
    # 'amp_multicast',
    'arccosh',
    'arcsinh',
    'arctan',
    'argmax',
    'argmax_channel',
    'argmin',
    'batch_take',
    'broadcast_axes',
    'broadcast_axis',
    'broadcast_like',
    'broadcast_to',
    'cbrt',
    'ceil',
    'choose_element_0index',
    'cos',
    'crop',
    'degrees',
    'depth_to_space',
    'diag',
    'erf',
    'expand_dims',
    'fill_element_0index',
    'fix',
    'flatten',
    'flip',
    'floor',
    'ftml_update',
    'ftrl_update',
    'gather_nd',
    'hard_sigmoid',
    'identity',
    'logical_not',
    'max_axis',
    'max',
    'min',
    'min_axis',
    'mp_sgd_mom_update',
    'mp_sgd_update',
    'multi_all_finite',
    'multi_mp_sgd_mom_update',
    'multi_mp_sgd_update',
    'multi_sgd_mom_update',
    'multi_sgd_update',
    'negative',
    'normal',
    'one_hot',
    'ones_like',
    'pad',
    'pick',
    'radians',
    'random_exponential',
    'random_gamma',
    'random_generalized_negative_binomial',
    'random_negative_binomial',
    'random_normal',
    'random_poisson',
    'random_randint',
    'random_uniform',
    'ravel_multi_index',
    'repeat',
    'reshape',
    'reshape_like',
    'reverse',
    'rint',
    'rmsprop_update',
    'rmspropalex_update',
    'round',
    'sample_exponential',
    'sample_gamma',
    'sample_generalized_negative_binomial',
    'sample_multinomial',
    'sample_negative_binomial',
    'sample_normal',
    'sample_poisson',
    'sample_uniform',
    'scatter_nd',
    'sgd_mom_update',
    'sgd_update',
    'shape_array',
    'sigmoid',
    'sign',
    'signsgd_update',
    'signum_update',
    'sin',
    'size_array',
    'slice',
    'slice_axis',
    'slice_like',
    'softsign',
    'sort',
    'space_to_depth',
    'split',
    'squeeze',
    'stop_gradient',
    'swapaxes',
    'take',
    'tile',
    'transpose',
    'trunc',
    'uniform',
    'unravel_index',
    'zeros_like',
    '_sg_mkldnn_conv',
    '_sg_mkldnn_fully_connected',
    'broadcast_mul',
    'Convolution_v1',
    'IdentityAttachKLSparseReg',
    'arccos',
    '_sparse_arccos',
    'arcsin',
    'cosh',
    '_sparse_cosh',
    'erfinv',
    'sinh',
    'tan',
    '_sparse_tan',
    'arctanh',
    '_sparse_arcsin',
    '_sparse_arctanh',

    # Exponents
    'exp',
    'expm1',
    '_sparse_exp',
    '_sparse_expm1',
    'log',
    'log10',
    'log2',
    'log1p',

    # Powers
    'broadcast_power',
    '_sparse_square',
    'reciprocal',
    '_RDivScalar',
    '_rdiv_scalar',
    'rsqrt',
    'rcbrt',
    '_Power',
    '_PowerScalar',
    '_power',
    '_power_scalar',
    '_RPowerScalar',
    '_rpower_scalar',
    'linalg_sumlogdiag',
    '_Hypot',
    '_HypotScalar',
    '_hypot',
    '_hypot_scalar',
    'broadcast_hypot',
    '_square_sum',
    '_contrib_hawkesll',

    # Reductions
    'sum',
    'sum_axis',
    'nansum',
    'prod',
    'nanprod',
    'mean',
    'norm',
    'softmin',
    'khatri_rao',
    'moments',

    # Misc
    'gamma',
    'gammaln',
    '_linalg_gelqf',
    '_linalg_gemm',
    '_linalg_gemm2',
    '_linalg_potrf',
    '_linalg_potri',
    '_linalg_sumlogdiag',
    '_linalg_syevd',
    '_linalg_syrk',
    '_linalg_trmm',
    '_linalg_trsm',
    '_linalg_makediag',
    '_linalg_extractdiag',
    '_linalg_maketrian',
    '_linalg_extracttrian',
    '_linalg_inverse',
    '_linalg_det',
    '_linalg_slogdet',
    'linalg_syrk',
    'linalg_potrf',
    'linalg_potri',
    'linalg_gemm2',
    'linalg_gemm',
    'linalg_gelqf',
    'linalg_trmm',
    'linalg_trsm',
    'linalg_makediag',
    'linalg_extractdiag',
    'linalg_maketrian',
    'linalg_extracttrian',
    'linalg_inverse',
    'linalg_det',
    'linalg_slogdet',
    '_NDArray',
    '_Native',
    '_contrib_count_sketch',
    '_contrib_SyncBatchNorm',
    '_contrib_fft',
    '_sparse_gamma',
    '_sparse_gammaln',
    '_sparse_log',
    '_sparse_log10',
    '_sparse_log1p',
    '_sparse_log2',
    '_sparse_make_loss',
    '_sparse_mean',
    '_sparse_norm',
    '_sparse_rsqrt',
    'argsort',
    'topk',

    # Neural network
    'SoftmaxOutput',
    'softmax',
    'Softmax',
    'log_softmax',
    'InstanceNorm',
    'LayerNorm',
    'GroupNorm',
    'L2Normalization',
    'SoftmaxActivation',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    '_sparse_LinearRegressionOutput',
    '_sparse_LogisticRegressionOutput',
    '_sparse_MAERegressionOutput',
    'SVMOutput',
    'softmax_cross_entropy',
    'smooth_l1',
    'MakeLoss',
    'make_loss',
    'Custom',
    'CTCLoss',
    '_contrib_CTCLoss',
    '_contrib_ctc_loss',
    'ctc_loss',
    '_contrib_DeformableConvolution',
    '_contrib_DeformablePSROIPooling',
]

# Functions that have to be cast to FP32 only for
# some values of their parameters.
# Each entry is (operator name, parameter name, list of parameter values).
CONDITIONAL_FP32_FUNCS = [
    ('Activation', 'act_type', ['softrelu']),
    ('LeakyReLU', 'act_type', ['elu', 'selu']),
]

# Functions with multiple inputs that need all of
# their inputs to have the same type
WIDEST_TYPE_CASTS = [
    '_Plus',
    '_plus',
    '_Minus',
    '_sub',
    '_Mul',
    '_Div',
    '_div',
    '_scatter_elemwise_div',
    '_Mod',
    '_Not_Equal',
    '_Equal',
    '_equal',
    '_Greater',
    '_greater',
    '_Greater_Equal',
    '_greater_equal',
    '_Lesser',
    '_Lesser_Equal',
    '_lesser',
    '_lesser_equal',
    '_Logical_And',
    '_Logical_Or',
    '_Logical_Xor',
    '_logical_and',
    '_logical_or',
    '_logical_xor',
    '_maximum',
    '_minimum',
    '_minus',
    '_mod',
    '_mul',
    '_not_equal',
    'Correlation',
    'ElementWiseSum',
    '_sparse_ElementWiseSum',
    'add_n',
    '_sparse_add_n',
    'batch_dot',
    'broadcast_add',
    'broadcast_plus',
    'broadcast_div',
    'broadcast_equal',
    'broadcast_greater',
    'broadcast_greater_equal',
    'broadcast_lesser',
    'broadcast_lesser_equal',
    'broadcast_logical_and',
    'broadcast_logical_or',
    'broadcast_logical_xor',
    'broadcast_maximum',
    'broadcast_minimum',
    'broadcast_minus',
    'broadcast_mod',
    'broadcast_not_equal',
    'broadcast_sub',
    'dot',
    'elemwise_add',
    'elemwise_div',
    'elemwise_mul',
    'elemwise_sub',
    'stack',
    '_Maximum',
    '_Minimum',
    '_contrib_MultiBoxDetection',
    '_contrib_MultiBoxPrior',
    '_contrib_MultiBoxTarget',
    '_contrib_MultiProposal',
    '_contrib_PSROIPooling',
    '_contrib_Proposal',
    '_contrib_ROIAlign',
    '_contrib_box_iou',
    '_contrib_box_nms',
    '_contrib_box_non_maximum_suppression',
    '_contrib_dgl_adjacency',
    '_contrib_dgl_csr_neighbor_non_uniform_sample',
    '_contrib_dgl_csr_neighbor_uniform_sample',
    '_contrib_dgl_graph_compact',
    '_contrib_dgl_subgraph',
    '_contrib_edge_id',
    'where',
    '_sparse_where',
    '_sparse_broadcast_add',
    '_sparse_broadcast_div',
    '_sparse_broadcast_minus',
    '_sparse_broadcast_mul',
    '_sparse_broadcast_plus',
    '_sparse_broadcast_sub',
    '_sparse_dot',
    '_sparse_elemwise_add',
    '_sparse_elemwise_div',
    '_sparse_elemwise_mul',
    '_sparse_elemwise_sub',
    '_sparse_sum',

    'random_pdf_gamma',
    'random_pdf_exponential',
    'random_pdf_uniform',
    'random_pdf_negative_binomial',
    'random_pdf_generalized_negative_binomial',
    'random_pdf_dirichlet',
    'random_pdf_normal',
    'random_pdf_poisson',
    '_random_pdf_gamma',
    '_random_pdf_exponential',
    '_random_pdf_uniform',
    '_random_pdf_negative_binomial',
    '_random_pdf_generalized_negative_binomial',
    '_random_pdf_dirichlet',
    '_random_pdf_normal',
    '_random_pdf_poisson',
]

# Operator names listed as loss-output functions.
LOSS_OUTPUT_FUNCTIONS = [
    'SoftmaxOutput',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
]