# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API."""

# Every list below holds operator names as plain strings. Entry order is
# kept exactly as authored; the upper-case / lower-case / `_sparse_` /
# `_contrib_` spellings of one operator are distinct entries.

# Functions that should be cast to lower precision
FP16_FUNCS = [
    'Convolution',
    'Deconvolution',
    'FullyConnected',
    'RNN',
    ]

# Functions that should not be cast, either because
# they are irrelevant (not used in the network itself
# like image transformations or optimizers) or they
# are dtype neutral (can work in both fp16 and fp32)
FP16_FP32_FUNCS = [
    'BatchNorm',
    'BatchNorm_v1',
    'BilinearSampler',
    'BlockGrad',
    'Cast',
    'cast',
    'cast_storage',
    'Crop',
    'Dropout',
    'Embedding',
    '_sparse_Embedding',
    '_sparse_FullyConnected',
    'Flatten',
    'GridGenerator',
    'Pad',
    'Pooling',
    'Pooling_v1',
    'ROIPooling',
    'Reshape',
    'SequenceLast',
    'SequenceMask',
    'SequenceReverse',
    'SliceChannel',
    'SpatialTransformer',
    'SwapAxis',
    'UpSampling',
    '_CachedOp',
    '_CrossDeviceCopy',
    '_CustomFunction',
    '_DivScalar',
    '_EqualScalar',
    '_GreaterScalar',
    '_GreaterEqualScalar',
    '_LesserScalar',
    '_LesserEqualScalar',
    '_LogicalAndScalar',
    '_LogicalOrScalar',
    '_LogicalXorScalar',
    '_MaximumScalar',
    '_MinimumScalar',
    '_MinusScalar',
    '_ModScalar',
    '_MulScalar',
    '_NoGradient',
    '_NotEqualScalar',
    '_PlusScalar',
    '_RMinusScalar',
    '_RModScalar',
    '_adamw_update',
    '_add',
    '_arange',
    '_broadcast_backward',
    '_cond',
    '_contrib_AdaptiveAvgPooling2D',
    '_contrib_BilinearResize2D',
    '_contrib_SparseEmbedding',
    '_contrib_bipartite_matching',
    '_contrib_dequantize',
    '_contrib_div_sqrt_dim',
    '_contrib_boolean_mask',
    '_contrib_getnnz',
    '_contrib_gradientmultiplier',
    '_contrib_group_adagrad_update',
    '_contrib_ifft',
    '_contrib_index_array',
    '_contrib_index_copy',
    '_contrib_quadratic',
    '_contrib_quantize',
    '_contrib_quantize_v2',
    '_contrib_quantized_concat',
    '_contrib_quantized_conv',
    '_contrib_quantized_flatten',
    '_contrib_quantized_fully_connected',
    '_contrib_quantized_pooling',
    '_contrib_quantized_elemwise_add',
    '_contrib_quantized_act',
    '_image_crop',
    '_linspace',
    '_contrib_requantize',
    '_copy',
    '_copyto',
    '_crop_assign',
    '_crop_assign_scalar',
    '_cvcopyMakeBorder',
    '_cvimdecode',
    '_cvimread',
    '_cvimresize',
    '_div_scalar',
    '_equal_scalar',
    '_eye',
    '_foreach',
    '_while_loop',
    '_full',
    '_grad_add',
    '_greater_scalar',
    '_greater_equal_scalar',
    '_histogram',
    '_identity_with_attr_like_rhs',
    '_image_adjust_lighting',
    '_image_flip_left_right',
    '_image_flip_top_bottom',
    '_image_normalize',
    '_image_random_brightness',
    '_image_random_color_jitter',
    '_image_random_contrast',
    '_image_random_flip_left_right',
    '_image_random_flip_top_bottom',
    '_image_random_hue',
    '_image_random_lighting',
    '_image_random_saturation',
    '_image_resize',
    '_image_to_tensor',
    '_imdecode',
    '_lesser_scalar',
    '_lesser_equal_scalar',
    '_logical_and_scalar',
    '_logical_or_scalar',
    '_logical_xor_scalar',
    '_maximum_scalar',
    '_minimum_scalar',
    '_minus_scalar',
    '_mod_scalar',
    '_mp_adamw_update',
    '_mul_scalar',
    '_not_equal_scalar',
    '_onehot_encode',
    '_ones',
    '_plus_scalar',
    '_random_exponential',
    '_random_exponential_like',
    '_random_gamma',
    '_random_gamma_like',
    '_random_generalized_negative_binomial',
    '_random_generalized_negative_binomial_like',
    '_random_negative_binomial',
    '_random_negative_binomial_like',
    '_random_normal',
    '_random_normal_like',
    '_random_poisson',
    '_random_poisson_like',
    '_random_randint',
    '_random_uniform',
    '_random_uniform_like',
    '_ravel_multi_index',
    '_rminus_scalar',
    '_rmod_scalar',
    '_rnn_param_concat',
    '_sample_exponential',
    '_sample_gamma',
    '_sample_generalized_negative_binomial',
    '_sample_multinomial',
    '_sample_negative_binomial',
    '_sample_normal',
    '_sample_poisson',
    '_sample_uniform',
    '_sample_unique_zipfian',
    '_scatter_minus_scalar',
    '_scatter_plus_scalar',
    '_scatter_set_nd',
    '_set_value',
    '_shuffle',
    '_slice_assign',
    '_slice_assign_scalar',
    '_sparse_abs',
    '_sparse_adagrad_update',
    '_sparse_adam_update',
    '_sparse_arccosh',
    '_sparse_arcsinh',
    '_sparse_arctan',
    '_sparse_cast_storage',
    '_sparse_cbrt',
    '_sparse_ceil',
    '_sparse_clip',
    '_sparse_concat',
    '_sparse_cos',
    '_sparse_degrees',
    '_sparse_fix',
    '_sparse_floor',
    '_sparse_ftrl_update',
    '_sparse_negative',
    '_sparse_radians',
    '_sparse_relu',
    '_sparse_retain',
    '_sparse_rint',
    '_sparse_round',
    '_sparse_sgd_mom_update',
    '_sparse_sgd_update',
    '_sparse_sigmoid',
    '_sparse_sign',
    '_sparse_sin',
    '_sparse_sinh',
    '_sparse_slice',
    '_sparse_sqrt',
    '_sparse_stop_gradient',
    '_sparse_tanh',
    '_sparse_trunc',
    '_sparse_zeros_like',
    '_split_v2',
    '_split_v2_backward',
    '_unravel_index',
    '_zeros',
    '_zeros_without_dtype',
    'abs',
    'adam_update',
    'all_finite',
    'amp_cast',
    'amp_multicast',
    'arccosh',
    'arcsinh',
    'arctan',
    'argmax',
    'argmax_channel',
    'argmin',
    'batch_take',
    'broadcast_axes',
    'broadcast_axis',
    'broadcast_like',
    'broadcast_to',
    'cbrt',
    'ceil',
    'choose_element_0index',
    'clip',
    'cos',
    'crop',
    'degrees',
    'depth_to_space',
    'diag',
    'erf',
    'expand_dims',
    'fill_element_0index',
    'fix',
    'flatten',
    'flip',
    'floor',
    'ftml_update',
    'ftrl_update',
    'gather_nd',
    'hard_sigmoid',
    'identity',
    'logical_not',
    'max_axis',
    'max',
    'min',
    'min_axis',
    'mp_sgd_mom_update',
    'mp_sgd_update',
    'multi_all_finite',
    'multi_mp_sgd_mom_update',
    'multi_mp_sgd_update',
    'multi_sgd_mom_update',
    'multi_sgd_update',
    'negative',
    'normal',
    'one_hot',
    'ones_like',
    'pad',
    'pick',
    'radians',
    'random_exponential',
    'random_gamma',
    'random_generalized_negative_binomial',
    'random_negative_binomial',
    'random_normal',
    'random_poisson',
    'random_randint',
    'random_uniform',
    'ravel_multi_index',
    'relu',
    'repeat',
    'reshape',
    'reshape_like',
    'reverse',
    'rint',
    'rmsprop_update',
    'rmspropalex_update',
    'round',
    'sample_exponential',
    'sample_gamma',
    'sample_generalized_negative_binomial',
    'sample_multinomial',
    'sample_negative_binomial',
    'sample_normal',
    'sample_poisson',
    'sample_uniform',
    'scatter_nd',
    'sgd_mom_update',
    'sgd_update',
    'shape_array',
    'shuffle',
    'sigmoid',
    'sign',
    'signsgd_update',
    'signum_update',
    'sin',
    'size_array',
    'slice',
    'slice_axis',
    'slice_like',
    'softsign',
    'sort',
    'space_to_depth',
    'split',
    'sqrt',
    'squeeze',
    'stop_gradient',
    'swapaxes',
    'take',
    'tanh',
    'tile',
    'transpose',
    'trunc',
    'uniform',
    'unravel_index',
    'zeros_like',
    '_sg_mkldnn_conv',
    '_sg_mkldnn_fully_connected',
    'CuDNNBatchNorm',
    '_TensorRT',
    ]

# Functions that have to be cast to FP32 due to possible
# overflows
FP32_FUNCS = [
    'Convolution_v1',
    'IdentityAttachKLSparseReg',
    'arccos',
    '_sparse_arccos',
    'arcsin',
    'cosh',
    '_sparse_cosh',
    'erfinv',
    'sinh',
    'tan',
    '_sparse_tan',
    'arctanh',
    '_sparse_arcsin',
    '_sparse_arctanh',
    '_contrib_MultiBoxDetection',
    '_contrib_MultiBoxPrior',
    '_contrib_MultiBoxTarget',

    # Exponents
    'exp',
    'expm1',
    '_sparse_exp',
    '_sparse_expm1',
    'log',
    'log10',
    'log2',
    'log1p',

    # Powers
    'broadcast_power',
    'square',
    '_sparse_square',
    'reciprocal',
    '_RDivScalar',
    '_rdiv_scalar',
    'rsqrt',
    'rcbrt',
    '_Power',
    '_PowerScalar',
    '_power',
    '_power_scalar',
    '_RPowerScalar',
    '_rpower_scalar',
    'linalg_sumlogdiag',
    '_Hypot',
    '_HypotScalar',
    '_hypot',
    '_hypot_scalar',
    'broadcast_hypot',
    '_square_sum',
    '_contrib_hawkesll',

    # Reductions
    'sum',
    'sum_axis',
    'nansum',
    'prod',
    'nanprod',
    'mean',
    'norm',
    'softmin',
    'khatri_rao',
    'moments',

    # Misc
    'gamma',
    'gammaln',
    '_linalg_gelqf',
    '_linalg_gemm',
    '_linalg_gemm2',
    '_linalg_potrf',
    '_linalg_potri',
    '_linalg_sumlogdiag',
    '_linalg_syevd',
    '_linalg_syrk',
    '_linalg_trmm',
    '_linalg_trsm',
    '_linalg_makediag',
    '_linalg_extractdiag',
    '_linalg_maketrian',
    '_linalg_extracttrian',
    '_linalg_inverse',
    '_linalg_det',
    '_linalg_slogdet',
    'linalg_syrk',
    'linalg_potrf',
    'linalg_potri',
    'linalg_gemm2',
    'linalg_gemm',
    'linalg_gelqf',
    'linalg_trmm',
    'linalg_trsm',
    'linalg_makediag',
    'linalg_extractdiag',
    'linalg_maketrian',
    'linalg_extracttrian',
    'linalg_inverse',
    'linalg_det',
    'linalg_slogdet',
    '_NDArray',
    '_Native',
    '_contrib_count_sketch',
    '_contrib_SyncBatchNorm',
    '_contrib_fft',
    '_sparse_gamma',
    '_sparse_gammaln',
    '_sparse_log',
    '_sparse_log10',
    '_sparse_log1p',
    '_sparse_log2',
    '_sparse_make_loss',
    '_sparse_mean',
    '_sparse_norm',
    '_sparse_rsqrt',
    'argsort',
    'topk',

    # Neural network
    'SoftmaxOutput',
    'softmax',
    'Softmax',
    'log_softmax',
    'InstanceNorm',
    'LayerNorm',
    'GroupNorm',
    'L2Normalization',
    'LRN',
    'SoftmaxActivation',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    '_sparse_LinearRegressionOutput',
    '_sparse_LogisticRegressionOutput',
    '_sparse_MAERegressionOutput',
    'SVMOutput',
    'softmax_cross_entropy',
    'smooth_l1',
    'MakeLoss',
    'make_loss',
    'Custom',
    'CTCLoss',
    '_contrib_CTCLoss',
    '_contrib_ctc_loss',
    'ctc_loss',
    '_contrib_DeformableConvolution',
    '_contrib_DeformablePSROIPooling',
    ]

# Functions that have to be cast to FP32 only for
# some values of their parameters.
# Each entry is a 3-tuple; it reads as
# (operator name, parameter name, parameter values that require FP32)
# NOTE(review): confirm this field order against the code that consumes the list.
CONDITIONAL_FP32_FUNCS = [
    ('Activation', 'act_type', ['softrelu']),
    ('LeakyReLU', 'act_type', ['elu', 'selu']),
    ]

# Functions with multiple inputs, that need the same
# type of all their inputs
WIDEST_TYPE_CASTS = [
    '_Plus',
    '_plus',
    '_Minus',
    '_sub',
    '_Mul',
    '_Div',
    '_div',
    '_scatter_elemwise_div',
    '_Mod',
    '_Not_Equal',
    '_Equal',
    '_equal',
    '_Greater',
    '_greater',
    '_Greater_Equal',
    '_greater_equal',
    '_Lesser',
    '_Lesser_Equal',
    '_lesser',
    '_lesser_equal',
    '_Logical_And',
    '_Logical_Or',
    '_Logical_Xor',
    '_logical_and',
    '_logical_or',
    '_logical_xor',
    '_maximum',
    '_minimum',
    '_minus',
    '_mod',
    '_mul',
    '_not_equal',
    'Concat',
    'concat',
    'Correlation',
    'ElementWiseSum',
    '_sparse_ElementWiseSum',
    'add_n',
    '_sparse_add_n',
    'batch_dot',
    'broadcast_add',
    'broadcast_plus',
    'broadcast_div',
    'broadcast_equal',
    'broadcast_greater',
    'broadcast_greater_equal',
    'broadcast_lesser',
    'broadcast_lesser_equal',
    'broadcast_logical_and',
    'broadcast_logical_or',
    'broadcast_logical_xor',
    'broadcast_maximum',
    'broadcast_minimum',
    'broadcast_minus',
    'broadcast_mod',
    'broadcast_mul',
    'broadcast_not_equal',
    'broadcast_sub',
    'dot',
    'elemwise_add',
    'elemwise_div',
    'elemwise_mul',
    'elemwise_sub',
    'stack',
    '_Maximum',
    '_Minimum',
    '_contrib_MultiProposal',
    '_contrib_PSROIPooling',
    '_contrib_Proposal',
    '_contrib_ROIAlign',
    '_contrib_box_iou',
    '_contrib_box_nms',
    '_contrib_box_non_maximum_suppression',
    '_contrib_dgl_adjacency',
    '_contrib_dgl_csr_neighbor_non_uniform_sample',
    '_contrib_dgl_csr_neighbor_uniform_sample',
    '_contrib_dgl_graph_compact',
    '_contrib_dgl_subgraph',
    '_contrib_edge_id',
    '_contrib_interleaved_matmul_encdec_qk',
    '_contrib_interleaved_matmul_encdec_valatt',
    '_contrib_interleaved_matmul_selfatt_qk',
    '_contrib_interleaved_matmul_selfatt_valatt',
    'where',
    '_sparse_where',
    '_sparse_broadcast_add',
    '_sparse_broadcast_div',
    '_sparse_broadcast_minus',
    '_sparse_broadcast_mul',
    '_sparse_broadcast_plus',
    '_sparse_broadcast_sub',
    '_sparse_dot',
    '_sparse_elemwise_add',
    '_sparse_elemwise_div',
    '_sparse_elemwise_mul',
    '_sparse_elemwise_sub',
    '_sparse_sum',

    'random_pdf_gamma',
    'random_pdf_exponential',
    'random_pdf_uniform',
    'random_pdf_negative_binomial',
    'random_pdf_generalized_negative_binomial',
    'random_pdf_dirichlet',
    'random_pdf_normal',
    'random_pdf_poisson',
    '_random_pdf_gamma',
    '_random_pdf_exponential',
    '_random_pdf_uniform',
    '_random_pdf_negative_binomial',
    '_random_pdf_generalized_negative_binomial',
    '_random_pdf_dirichlet',
    '_random_pdf_normal',
    '_random_pdf_poisson',
    ]

# Loss output operators. Every name here is also listed in FP32_FUNCS.
# NOTE(review): presumably the AMP conversion code treats these outputs
# specially -- confirm against the consumer of this list.
LOSS_OUTPUT_FUNCTIONS = [
    'SoftmaxOutput',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    ]