1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18# coding: utf-8
"""Lists of functions whitelisted/blacklisted for bfloat16 automatic mixed
precision (AMP) in the symbol API."""
20
# Operators that should be cast to the lower-precision type (bfloat16).
BF16_FUNCS = ['Convolution', 'FullyConnected']
26
# Operators that should not be cast at all. They are either irrelevant to
# the network itself (for example image transformations or optimizers) or
# dtype-neutral, meaning they can run in either bfloat16 or float32.
BF16_FP32_FUNCS = [
    'abs', '_add',
    'BatchNorm', 'BatchNormWithReLU',
    'clip',
    'Concat', 'concat',
    'LRN', 'Pooling', 'relu',
    'shuffle', '_shuffle',
    'sqrt', 'square', 'tanh',
]
48
# Operators whose listed parameters must stay in float32 even when the
# operator runs in bfloat16, keyed by operator name.
# NOTE(review): the meaning of the empty-string entry is not visible in this
# file; presumably it refers to the operator's own node/output rather than a
# named parameter -- confirm against the AMP converter that reads this table.
BF16_USE_FP32_PARAMS = {
    # The list literal is re-evaluated per key, so each operator gets its own
    # independent list object (no aliasing between entries).
    op_name: ["", "gamma", "beta", "moving_mean", "moving_var"]
    for op_name in ('BatchNormWithReLU', 'BatchNorm')
}
54
# Functions that have to be cast to FP32 due to possible
# overflows. The first, ungrouped part of the list covers assorted
# operators (for example random samplers, image transforms, optimizer
# updates and sparse-storage variants). Numerically sensitive math is
# grouped by category further down (exponents, powers, reductions,
# misc/linear algebra, neural-network layers).
FP32_FUNCS = [
    'Deconvolution',
    'RNN',
    'BatchNorm_v1',
    'BilinearSampler',
    'BlockGrad',
    'Cast',
    'cast',
    'cast_storage',
    'Crop',
    'Dropout',
    'Embedding',
    '_sparse_Embedding',
    '_sparse_FullyConnected',
    'Flatten',
    'GridGenerator',
    'Pad',
    'Pooling_v1',
    'ROIPooling',
    'Reshape',
    'SequenceLast',
    'SequenceMask',
    'SequenceReverse',
    'SliceChannel',
    'SpatialTransformer',
    'SwapAxis',
    'UpSampling',
    '_CachedOp',
    '_CrossDeviceCopy',
    '_CustomFunction',
    '_DivScalar',
    '_EqualScalar',
    '_GreaterScalar',
    '_GreaterEqualScalar',
    '_LesserScalar',
    '_LesserEqualScalar',
    '_LogicalAndScalar',
    '_LogicalOrScalar',
    '_LogicalXorScalar',
    '_MaximumScalar',
    '_MinimumScalar',
    '_MinusScalar',
    '_ModScalar',
    '_MulScalar',
    '_NoGradient',
    '_NotEqualScalar',
    '_PlusScalar',
    '_RMinusScalar',
    '_RModScalar',
    '_adamw_update',
    '_arange',
    '_broadcast_backward',
    '_cond',
    '_contrib_AdaptiveAvgPooling2D',
    '_contrib_BilinearResize2D',
    '_contrib_SparseEmbedding',
    '_contrib_bipartite_matching',
    '_contrib_dequantize',
    '_contrib_div_sqrt_dim',
    '_contrib_boolean_mask',
    '_contrib_getnnz',
    '_contrib_gradientmultiplier',
    '_contrib_group_adagrad_update',
    '_contrib_ifft',
    '_contrib_index_array',
    '_contrib_index_copy',
    '_contrib_quadratic',
    '_contrib_quantize',
    '_contrib_quantize_v2',
    '_contrib_quantized_concat',
    '_contrib_quantized_conv',
    '_contrib_quantized_flatten',
    '_contrib_quantized_fully_connected',
    '_contrib_quantized_pooling',
    '_contrib_quantized_elemwise_add',
    '_contrib_quantized_act',
    '_image_crop',
    '_linspace',
    '_contrib_requantize',
    '_copy',
    '_copyto',
    '_crop_assign',
    '_crop_assign_scalar',
    '_cvcopyMakeBorder',
    '_cvimdecode',
    '_cvimread',
    '_cvimresize',
    '_div_scalar',
    '_equal_scalar',
    '_eye',
    '_foreach',
    '_while_loop',
    '_full',
    '_grad_add',
    '_greater_scalar',
    '_greater_equal_scalar',
    '_histogram',
    '_identity_with_attr_like_rhs',
    '_image_adjust_lighting',
    '_image_flip_left_right',
    '_image_flip_top_bottom',
    '_image_normalize',
    '_image_random_brightness',
    '_image_random_color_jitter',
    '_image_random_contrast',
    '_image_random_flip_left_right',
    '_image_random_flip_top_bottom',
    '_image_random_hue',
    '_image_random_lighting',
    '_image_random_saturation',
    '_image_resize',
    '_image_to_tensor',
    '_imdecode',
    '_lesser_scalar',
    '_lesser_equal_scalar',
    '_logical_and_scalar',
    '_logical_or_scalar',
    '_logical_xor_scalar',
    '_maximum_scalar',
    '_minimum_scalar',
    '_minus_scalar',
    '_mod_scalar',
    '_mp_adamw_update',
    '_mul_scalar',
    '_not_equal_scalar',
    '_onehot_encode',
    '_ones',
    '_plus_scalar',
    '_random_exponential',
    '_random_exponential_like',
    '_random_gamma',
    '_random_gamma_like',
    '_random_generalized_negative_binomial',
    '_random_generalized_negative_binomial_like',
    '_random_negative_binomial',
    '_random_negative_binomial_like',
    '_random_normal',
    '_random_normal_like',
    '_random_poisson',
    '_random_poisson_like',
    '_random_randint',
    '_random_uniform',
    '_random_uniform_like',
    '_ravel_multi_index',
    '_rminus_scalar',
    '_rmod_scalar',
    '_rnn_param_concat',
    '_sample_exponential',
    '_sample_gamma',
    '_sample_generalized_negative_binomial',
    '_sample_multinomial',
    '_sample_negative_binomial',
    '_sample_normal',
    '_sample_poisson',
    '_sample_uniform',
    '_sample_unique_zipfian',
    '_scatter_minus_scalar',
    '_scatter_plus_scalar',
    '_scatter_set_nd',
    '_set_value',
    '_slice_assign',
    '_slice_assign_scalar',
    '_sparse_abs',
    '_sparse_adagrad_update',
    '_sparse_adam_update',
    '_sparse_arccosh',
    '_sparse_arcsinh',
    '_sparse_arctan',
    '_sparse_cast_storage',
    '_sparse_cbrt',
    '_sparse_ceil',
    '_sparse_clip',
    '_sparse_concat',
    '_sparse_cos',
    '_sparse_degrees',
    '_sparse_fix',
    '_sparse_floor',
    '_sparse_ftrl_update',
    '_sparse_negative',
    '_sparse_radians',
    '_sparse_relu',
    '_sparse_retain',
    '_sparse_rint',
    '_sparse_round',
    '_sparse_sgd_mom_update',
    '_sparse_sgd_update',
    '_sparse_sigmoid',
    '_sparse_sign',
    '_sparse_sin',
    '_sparse_sinh',
    '_sparse_slice',
    '_sparse_sqrt',
    '_sparse_stop_gradient',
    '_sparse_tanh',
    '_sparse_trunc',
    '_sparse_zeros_like',
    '_split_v2',
    '_split_v2_backward',
    '_unravel_index',
    '_zeros',
    '_zeros_without_dtype',
    'adam_update',
    'all_finite',
    # NOTE(review): 'amp_cast' and 'amp_multicast' are deliberately left out
    # of this list. Presumably that is because they are the casting operators
    # AMP itself inserts and must not be forced to FP32 -- confirm before
    # re-enabling either entry.
    # 'amp_cast',
    # 'amp_multicast',
    'arccosh',
    'arcsinh',
    'arctan',
    'argmax',
    'argmax_channel',
    'argmin',
    'batch_take',
    'broadcast_axes',
    'broadcast_axis',
    'broadcast_like',
    'broadcast_to',
    'cbrt',
    'ceil',
    'choose_element_0index',
    'cos',
    'crop',
    'degrees',
    'depth_to_space',
    'diag',
    'erf',
    'expand_dims',
    'fill_element_0index',
    'fix',
    'flatten',
    'flip',
    'floor',
    'ftml_update',
    'ftrl_update',
    'gather_nd',
    'hard_sigmoid',
    'identity',
    'logical_not',
    'max_axis',
    'max',
    'min',
    'min_axis',
    'mp_sgd_mom_update',
    'mp_sgd_update',
    'multi_all_finite',
    'multi_mp_sgd_mom_update',
    'multi_mp_sgd_update',
    'multi_sgd_mom_update',
    'multi_sgd_update',
    'negative',
    'normal',
    'one_hot',
    'ones_like',
    'pad',
    'pick',
    'radians',
    'random_exponential',
    'random_gamma',
    'random_generalized_negative_binomial',
    'random_negative_binomial',
    'random_normal',
    'random_poisson',
    'random_randint',
    'random_uniform',
    'ravel_multi_index',
    'repeat',
    'reshape',
    'reshape_like',
    'reverse',
    'rint',
    'rmsprop_update',
    'rmspropalex_update',
    'round',
    'sample_exponential',
    'sample_gamma',
    'sample_generalized_negative_binomial',
    'sample_multinomial',
    'sample_negative_binomial',
    'sample_normal',
    'sample_poisson',
    'sample_uniform',
    'scatter_nd',
    'sgd_mom_update',
    'sgd_update',
    'shape_array',
    'sigmoid',
    'sign',
    'signsgd_update',
    'signum_update',
    'sin',
    'size_array',
    'slice',
    'slice_axis',
    'slice_like',
    'softsign',
    'sort',
    'space_to_depth',
    'split',
    'squeeze',
    'stop_gradient',
    'swapaxes',
    'take',
    'tile',
    'transpose',
    'trunc',
    'uniform',
    'unravel_index',
    'zeros_like',
    '_sg_mkldnn_conv',
    '_sg_mkldnn_fully_connected',
    'broadcast_mul',
    'Convolution_v1',
    'IdentityAttachKLSparseReg',
    'arccos',
    '_sparse_arccos',
    'arcsin',
    'cosh',
    '_sparse_cosh',
    'erfinv',
    'sinh',
    'tan',
    '_sparse_tan',
    'arctanh',
    '_sparse_arcsin',
    '_sparse_arctanh',

    # Exponents
    'exp',
    'expm1',
    '_sparse_exp',
    '_sparse_expm1',
    'log',
    'log10',
    'log2',
    'log1p',

    # Powers
    'broadcast_power',
    '_sparse_square',
    'reciprocal',
    '_RDivScalar',
    '_rdiv_scalar',
    'rsqrt',
    'rcbrt',
    '_Power',
    '_PowerScalar',
    '_power',
    '_power_scalar',
    '_RPowerScalar',
    '_rpower_scalar',
    'linalg_sumlogdiag',
    '_Hypot',
    '_HypotScalar',
    '_hypot',
    '_hypot_scalar',
    'broadcast_hypot',
    '_square_sum',
    '_contrib_hawkesll',

    # Reductions
    'sum',
    'sum_axis',
    'nansum',
    'prod',
    'nanprod',
    'mean',
    'norm',
    'softmin',
    'khatri_rao',
    'moments',

    # Misc
    'gamma',
    'gammaln',
    '_linalg_gelqf',
    '_linalg_gemm',
    '_linalg_gemm2',
    '_linalg_potrf',
    '_linalg_potri',
    '_linalg_sumlogdiag',
    '_linalg_syevd',
    '_linalg_syrk',
    '_linalg_trmm',
    '_linalg_trsm',
    '_linalg_makediag',
    '_linalg_extractdiag',
    '_linalg_maketrian',
    '_linalg_extracttrian',
    '_linalg_inverse',
    '_linalg_det',
    '_linalg_slogdet',
    'linalg_syrk',
    'linalg_potrf',
    'linalg_potri',
    'linalg_gemm2',
    'linalg_gemm',
    'linalg_gelqf',
    'linalg_trmm',
    'linalg_trsm',
    'linalg_makediag',
    'linalg_extractdiag',
    'linalg_maketrian',
    'linalg_extracttrian',
    'linalg_inverse',
    'linalg_det',
    'linalg_slogdet',
    '_NDArray',
    '_Native',
    '_contrib_count_sketch',
    '_contrib_SyncBatchNorm',
    '_contrib_fft',
    '_sparse_gamma',
    '_sparse_gammaln',
    '_sparse_log',
    '_sparse_log10',
    '_sparse_log1p',
    '_sparse_log2',
    '_sparse_make_loss',
    '_sparse_mean',
    '_sparse_norm',
    '_sparse_rsqrt',
    'argsort',
    'topk',

    # Neural network
    'SoftmaxOutput',
    'softmax',
    'Softmax',
    'log_softmax',
    'InstanceNorm',
    'LayerNorm',
    'GroupNorm',
    'L2Normalization',
    'SoftmaxActivation',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    '_sparse_LinearRegressionOutput',
    '_sparse_LogisticRegressionOutput',
    '_sparse_MAERegressionOutput',
    'SVMOutput',
    'softmax_cross_entropy',
    'smooth_l1',
    'MakeLoss',
    'make_loss',
    'Custom',
    'CTCLoss',
    '_contrib_CTCLoss',
    '_contrib_ctc_loss',
    'ctc_loss',
    '_contrib_DeformableConvolution',
    '_contrib_DeformablePSROIPooling',
    ]
509
# Operators that have to be cast to FP32 only for some values of one of
# their parameters. Each entry is a 3-tuple of
# (operator name, parameter name, list of values that force FP32).
CONDITIONAL_FP32_FUNCS = [
    (
        'Activation',      # operator
        'act_type',        # parameter that is inspected
        ['softrelu'],      # values that require FP32
    ),
    (
        'LeakyReLU',
        'act_type',
        ['elu', 'selu'],
    ),
]
516
# Functions with multiple inputs that need all of their inputs to have
# the same type.
# NOTE(review): as the list name suggests, these inputs are presumably cast
# to the widest type among them -- confirm against the AMP converter.
WIDEST_TYPE_CASTS = [
    '_Plus',
    '_plus',
    '_Minus',
    '_sub',
    '_Mul',
    '_Div',
    '_div',
    '_scatter_elemwise_div',
    '_Mod',
    '_Not_Equal',
    '_Equal',
    '_equal',
    '_Greater',
    '_greater',
    '_Greater_Equal',
    '_greater_equal',
    '_Lesser',
    '_Lesser_Equal',
    '_lesser',
    '_lesser_equal',
    '_Logical_And',
    '_Logical_Or',
    '_Logical_Xor',
    '_logical_and',
    '_logical_or',
    '_logical_xor',
    '_maximum',
    '_minimum',
    '_minus',
    '_mod',
    '_mul',
    '_not_equal',
    'Correlation',
    'ElementWiseSum',
    '_sparse_ElementWiseSum',
    'add_n',
    '_sparse_add_n',
    'batch_dot',
    'broadcast_add',
    'broadcast_plus',
    'broadcast_div',
    'broadcast_equal',
    'broadcast_greater',
    'broadcast_greater_equal',
    'broadcast_lesser',
    'broadcast_lesser_equal',
    'broadcast_logical_and',
    'broadcast_logical_or',
    'broadcast_logical_xor',
    'broadcast_maximum',
    'broadcast_minimum',
    'broadcast_minus',
    'broadcast_mod',
    'broadcast_not_equal',
    'broadcast_sub',
    'dot',
    'elemwise_add',
    'elemwise_div',
    'elemwise_mul',
    'elemwise_sub',
    'stack',
    '_Maximum',
    '_Minimum',
    '_contrib_MultiBoxDetection',
    '_contrib_MultiBoxPrior',
    '_contrib_MultiBoxTarget',
    '_contrib_MultiProposal',
    '_contrib_PSROIPooling',
    '_contrib_Proposal',
    '_contrib_ROIAlign',
    '_contrib_box_iou',
    '_contrib_box_nms',
    '_contrib_box_non_maximum_suppression',
    '_contrib_dgl_adjacency',
    '_contrib_dgl_csr_neighbor_non_uniform_sample',
    '_contrib_dgl_csr_neighbor_uniform_sample',
    '_contrib_dgl_graph_compact',
    '_contrib_dgl_subgraph',
    '_contrib_edge_id',
    'where',
    '_sparse_where',
    '_sparse_broadcast_add',
    '_sparse_broadcast_div',
    '_sparse_broadcast_minus',
    '_sparse_broadcast_mul',
    '_sparse_broadcast_plus',
    '_sparse_broadcast_sub',
    '_sparse_dot',
    '_sparse_elemwise_add',
    '_sparse_elemwise_div',
    '_sparse_elemwise_mul',
    '_sparse_elemwise_sub',
    '_sparse_sum',

    # random_pdf_* operators, listed under both their public and
    # underscore-prefixed names.
    'random_pdf_gamma',
    'random_pdf_exponential',
    'random_pdf_uniform',
    'random_pdf_negative_binomial',
    'random_pdf_generalized_negative_binomial',
    'random_pdf_dirichlet',
    'random_pdf_normal',
    'random_pdf_poisson',
    '_random_pdf_gamma',
    '_random_pdf_exponential',
    '_random_pdf_uniform',
    '_random_pdf_negative_binomial',
    '_random_pdf_generalized_negative_binomial',
    '_random_pdf_dirichlet',
    '_random_pdf_normal',
    '_random_pdf_poisson',
    ]
631
# Loss/output-layer operators.
# NOTE(review): how the AMP converter treats these operators is not visible
# in this file -- confirm with the code that reads this list.
LOSS_OUTPUT_FUNCTIONS = [
    'SoftmaxOutput',
    'LinearRegressionOutput', 'LogisticRegressionOutput',
    'MAERegressionOutput',
]
638