# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

# coding: utf-8
"""Lists of functions whitelisted/blacklisted for automatic mixed precision in symbol API.

Every entry is an operator name given as a string. The casting categories
(FP16_FUNCS, FP16_FP32_FUNCS, FP32_FUNCS, CONDITIONAL_FP32_FUNCS and
WIDEST_TYPE_CASTS) describe mutually exclusive treatments, so an operator
name should appear in only one of them.
"""

# Functions that should be cast to lower precision (float16).
FP16_FUNCS = [
    'Convolution',
    'Deconvolution',
    'FullyConnected',
    'RNN',
    ]

# Functions that should not be cast, either because
# they are irrelevant (not used in the network itself
# like image transformations or optimizers) or they
# are dtype neutral (can work in both fp16 and fp32).
#
# Binary ops that need identical input dtypes belong in WIDEST_TYPE_CASTS,
# and _sparse_* variants should follow the classification of their dense
# counterparts (e.g. '_sparse_sinh' goes with 'sinh' in FP32_FUNCS).
FP16_FP32_FUNCS = [
    'BatchNorm',
    'BatchNorm_v1',
    'BilinearSampler',
    'BlockGrad',
    'Cast',
    'cast',
    'cast_storage',
    'Crop',
    'Dropout',
    'Embedding',
    '_sparse_Embedding',
    # NOTE(review): dense 'FullyConnected' is in FP16_FUNCS; confirm that
    # leaving the sparse variant uncast is intentional.
    '_sparse_FullyConnected',
    'Flatten',
    'GridGenerator',
    'Pad',
    'Pooling',
    'Pooling_v1',
    'ROIPooling',
    'Reshape',
    'SequenceLast',
    'SequenceMask',
    'SequenceReverse',
    'SliceChannel',
    'SpatialTransformer',
    'SwapAxis',
    'UpSampling',
    '_CachedOp',
    '_CrossDeviceCopy',
    '_CustomFunction',
    '_DivScalar',
    '_EqualScalar',
    '_GreaterScalar',
    '_GreaterEqualScalar',
    '_LesserScalar',
    '_LesserEqualScalar',
    '_LogicalAndScalar',
    '_LogicalOrScalar',
    '_LogicalXorScalar',
    '_MaximumScalar',
    '_MinimumScalar',
    '_MinusScalar',
    '_ModScalar',
    '_MulScalar',
    '_NoGradient',
    '_NotEqualScalar',
    '_PlusScalar',
    '_RMinusScalar',
    '_RModScalar',
    '_adamw_update',
    '_arange',
    '_broadcast_backward',
    '_cond',
    '_contrib_AdaptiveAvgPooling2D',
    '_contrib_BilinearResize2D',
    '_contrib_SparseEmbedding',
    '_contrib_bipartite_matching',
    '_contrib_dequantize',
    '_contrib_div_sqrt_dim',
    '_contrib_boolean_mask',
    '_contrib_getnnz',
    '_contrib_gradientmultiplier',
    '_contrib_group_adagrad_update',
    '_contrib_ifft',
    '_contrib_index_array',
    '_contrib_index_copy',
    '_contrib_quadratic',
    '_contrib_quantize',
    '_contrib_quantize_v2',
    '_contrib_quantized_concat',
    '_contrib_quantized_conv',
    '_contrib_quantized_flatten',
    '_contrib_quantized_fully_connected',
    '_contrib_quantized_pooling',
    '_contrib_quantized_elemwise_add',
    '_contrib_quantized_act',
    '_image_crop',
    '_linspace',
    '_contrib_requantize',
    '_copy',
    '_copyto',
    '_crop_assign',
    '_crop_assign_scalar',
    '_cvcopyMakeBorder',
    '_cvimdecode',
    '_cvimread',
    '_cvimresize',
    '_div_scalar',
    '_equal_scalar',
    '_eye',
    '_foreach',
    '_while_loop',
    '_full',
    '_grad_add',
    '_greater_scalar',
    '_greater_equal_scalar',
    '_histogram',
    '_identity_with_attr_like_rhs',
    '_image_adjust_lighting',
    '_image_flip_left_right',
    '_image_flip_top_bottom',
    '_image_normalize',
    '_image_random_brightness',
    '_image_random_color_jitter',
    '_image_random_contrast',
    '_image_random_flip_left_right',
    '_image_random_flip_top_bottom',
    '_image_random_hue',
    '_image_random_lighting',
    '_image_random_saturation',
    '_image_resize',
    '_image_to_tensor',
    '_imdecode',
    '_lesser_scalar',
    '_lesser_equal_scalar',
    '_logical_and_scalar',
    '_logical_or_scalar',
    '_logical_xor_scalar',
    '_maximum_scalar',
    '_minimum_scalar',
    '_minus_scalar',
    '_mod_scalar',
    '_mp_adamw_update',
    '_mul_scalar',
    '_not_equal_scalar',
    '_onehot_encode',
    '_ones',
    '_plus_scalar',
    '_random_exponential',
    '_random_exponential_like',
    '_random_gamma',
    '_random_gamma_like',
    '_random_generalized_negative_binomial',
    '_random_generalized_negative_binomial_like',
    '_random_negative_binomial',
    '_random_negative_binomial_like',
    '_random_normal',
    '_random_normal_like',
    '_random_poisson',
    '_random_poisson_like',
    '_random_randint',
    '_random_uniform',
    '_random_uniform_like',
    '_ravel_multi_index',
    '_rminus_scalar',
    '_rmod_scalar',
    '_rnn_param_concat',
    '_sample_exponential',
    '_sample_gamma',
    '_sample_generalized_negative_binomial',
    '_sample_multinomial',
    '_sample_negative_binomial',
    '_sample_normal',
    '_sample_poisson',
    '_sample_uniform',
    '_sample_unique_zipfian',
    '_scatter_minus_scalar',
    '_scatter_plus_scalar',
    '_scatter_set_nd',
    '_set_value',
    '_shuffle',
    '_slice_assign',
    '_slice_assign_scalar',
    '_sparse_abs',
    '_sparse_adagrad_update',
    '_sparse_adam_update',
    '_sparse_arccosh',
    '_sparse_arcsinh',
    '_sparse_arctan',
    '_sparse_cast_storage',
    '_sparse_cbrt',
    '_sparse_ceil',
    '_sparse_clip',
    '_sparse_cos',
    '_sparse_degrees',
    '_sparse_fix',
    '_sparse_floor',
    '_sparse_ftrl_update',
    '_sparse_negative',
    '_sparse_radians',
    '_sparse_relu',
    '_sparse_retain',
    '_sparse_rint',
    '_sparse_round',
    '_sparse_sgd_mom_update',
    '_sparse_sgd_update',
    '_sparse_sigmoid',
    '_sparse_sign',
    '_sparse_sin',
    '_sparse_slice',
    '_sparse_sqrt',
    '_sparse_stop_gradient',
    '_sparse_tanh',
    '_sparse_trunc',
    '_sparse_zeros_like',
    '_split_v2',
    '_split_v2_backward',
    '_unravel_index',
    '_zeros',
    '_zeros_without_dtype',
    'abs',
    'adam_update',
    'all_finite',
    'amp_cast',
    'amp_multicast',
    'arccosh',
    'arcsinh',
    'arctan',
    'argmax',
    'argmax_channel',
    'argmin',
    'batch_take',
    'broadcast_axes',
    'broadcast_axis',
    'broadcast_like',
    'broadcast_to',
    'cbrt',
    'ceil',
    'choose_element_0index',
    'clip',
    'cos',
    'crop',
    'degrees',
    'depth_to_space',
    'diag',
    'erf',
    'expand_dims',
    'fill_element_0index',
    'fix',
    'flatten',
    'flip',
    'floor',
    'ftml_update',
    'ftrl_update',
    'gather_nd',
    'hard_sigmoid',
    'identity',
    'logical_not',
    'max_axis',
    'max',
    'min',
    'min_axis',
    'mp_sgd_mom_update',
    'mp_sgd_update',
    'multi_all_finite',
    'multi_mp_sgd_mom_update',
    'multi_mp_sgd_update',
    'multi_sgd_mom_update',
    'multi_sgd_update',
    'negative',
    'normal',
    'one_hot',
    'ones_like',
    'pad',
    'pick',
    'radians',
    'random_exponential',
    'random_gamma',
    'random_generalized_negative_binomial',
    'random_negative_binomial',
    'random_normal',
    'random_poisson',
    'random_randint',
    'random_uniform',
    'ravel_multi_index',
    'relu',
    'repeat',
    'reshape',
    'reshape_like',
    'reverse',
    'rint',
    'rmsprop_update',
    'rmspropalex_update',
    'round',
    'sample_exponential',
    'sample_gamma',
    'sample_generalized_negative_binomial',
    'sample_multinomial',
    'sample_negative_binomial',
    'sample_normal',
    'sample_poisson',
    'sample_uniform',
    'scatter_nd',
    'sgd_mom_update',
    'sgd_update',
    'shape_array',
    'shuffle',
    'sigmoid',
    'sign',
    'signsgd_update',
    'signum_update',
    'sin',
    'size_array',
    'slice',
    'slice_axis',
    'slice_like',
    'softsign',
    'sort',
    'space_to_depth',
    'split',
    'sqrt',
    'squeeze',
    'stop_gradient',
    'swapaxes',
    'take',
    'tanh',
    'tile',
    'transpose',
    'trunc',
    'uniform',
    'unravel_index',
    'zeros_like',
    '_sg_mkldnn_conv',
    '_sg_mkldnn_fully_connected',
    'CuDNNBatchNorm',
    '_TensorRT',
    ]

# Functions that have to be cast to FP32 due to possible
# overflows (or other loss of accuracy in fp16).
FP32_FUNCS = [
    'Convolution_v1',
    'IdentityAttachKLSparseReg',
    'arccos',
    '_sparse_arccos',
    'arcsin',
    'cosh',
    '_sparse_cosh',
    'erfinv',
    'sinh',
    '_sparse_sinh',  # sparse variant of 'sinh'; was misfiled as dtype neutral
    'tan',
    '_sparse_tan',
    'arctanh',
    '_sparse_arcsin',
    '_sparse_arctanh',
    '_contrib_MultiBoxDetection',
    '_contrib_MultiBoxPrior',
    '_contrib_MultiBoxTarget',

    # Exponents
    'exp',
    'expm1',
    '_sparse_exp',
    '_sparse_expm1',
    'log',
    'log10',
    'log2',
    'log1p',

    # Powers
    'broadcast_power',
    'square',
    '_sparse_square',
    'reciprocal',
    '_RDivScalar',
    '_rdiv_scalar',
    'rsqrt',
    'rcbrt',
    '_Power',
    '_PowerScalar',
    '_power',
    '_power_scalar',
    '_RPowerScalar',
    '_rpower_scalar',
    'linalg_sumlogdiag',
    '_Hypot',
    '_HypotScalar',
    '_hypot',
    '_hypot_scalar',
    'broadcast_hypot',
    '_square_sum',
    '_contrib_hawkesll',

    # Reductions
    'sum',
    '_sparse_sum',  # single-input reduction like 'sum'; was misfiled as a widest-type cast
    'sum_axis',
    'nansum',
    'prod',
    'nanprod',
    'mean',
    'norm',
    'softmin',
    'khatri_rao',
    'moments',

    # Misc
    'gamma',
    'gammaln',
    '_linalg_gelqf',
    '_linalg_gemm',
    '_linalg_gemm2',
    '_linalg_potrf',
    '_linalg_potri',
    '_linalg_sumlogdiag',
    '_linalg_syevd',
    '_linalg_syrk',
    '_linalg_trmm',
    '_linalg_trsm',
    '_linalg_makediag',
    '_linalg_extractdiag',
    '_linalg_maketrian',
    '_linalg_extracttrian',
    '_linalg_inverse',
    '_linalg_det',
    '_linalg_slogdet',
    'linalg_syrk',
    'linalg_potrf',
    'linalg_potri',
    'linalg_gemm2',
    'linalg_gemm',
    'linalg_gelqf',
    'linalg_trmm',
    'linalg_trsm',
    'linalg_makediag',
    'linalg_extractdiag',
    'linalg_maketrian',
    'linalg_extracttrian',
    'linalg_inverse',
    'linalg_det',
    'linalg_slogdet',
    '_NDArray',
    '_Native',
    '_contrib_count_sketch',
    '_contrib_SyncBatchNorm',
    '_contrib_fft',
    '_sparse_gamma',
    '_sparse_gammaln',
    '_sparse_log',
    '_sparse_log10',
    '_sparse_log1p',
    '_sparse_log2',
    '_sparse_make_loss',
    '_sparse_mean',
    '_sparse_norm',
    '_sparse_rsqrt',
    'argsort',
    'topk',

    # Neural network
    'SoftmaxOutput',
    'softmax',
    'Softmax',
    'log_softmax',
    'InstanceNorm',
    'LayerNorm',
    'GroupNorm',
    'L2Normalization',
    'LRN',
    'SoftmaxActivation',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    '_sparse_LinearRegressionOutput',
    '_sparse_LogisticRegressionOutput',
    '_sparse_MAERegressionOutput',
    'SVMOutput',
    'softmax_cross_entropy',
    'smooth_l1',
    'MakeLoss',
    'make_loss',
    'Custom',
    'CTCLoss',
    '_contrib_CTCLoss',
    '_contrib_ctc_loss',
    'ctc_loss',
    '_contrib_DeformableConvolution',
    '_contrib_DeformablePSROIPooling',
    ]

# Functions that have to be cast to FP32 only for
# some values of their parameters.
# Each entry is a tuple:
#   (operator name, parameter name, list of parameter values that force FP32)
CONDITIONAL_FP32_FUNCS = [
    ('Activation', 'act_type', ['softrelu']),
    ('LeakyReLU', 'act_type', ['elu', 'selu']),
    ]

# Functions with multiple inputs, that need the same
# type of all their inputs (inputs are cast to the widest
# input type).
WIDEST_TYPE_CASTS = [
    '_Plus',
    '_plus',
    '_add',  # same-dtype binary add, like '_plus'/'_sub'; was misfiled as dtype neutral
    '_Minus',
    '_sub',
    '_Mul',
    '_Div',
    '_div',
    '_scatter_elemwise_div',
    '_Mod',
    '_Not_Equal',
    '_Equal',
    '_equal',
    '_Greater',
    '_greater',
    '_Greater_Equal',
    '_greater_equal',
    '_Lesser',
    '_Lesser_Equal',
    '_lesser',
    '_lesser_equal',
    '_Logical_And',
    '_Logical_Or',
    '_Logical_Xor',
    '_logical_and',
    '_logical_or',
    '_logical_xor',
    '_maximum',
    '_minimum',
    '_minus',
    '_mod',
    '_mul',
    '_not_equal',
    'Concat',
    'concat',
    '_sparse_concat',  # sparse variant of 'concat'; was misfiled as dtype neutral
    'Correlation',
    'ElementWiseSum',
    '_sparse_ElementWiseSum',
    'add_n',
    '_sparse_add_n',
    'batch_dot',
    'broadcast_add',
    'broadcast_plus',
    'broadcast_div',
    'broadcast_equal',
    'broadcast_greater',
    'broadcast_greater_equal',
    'broadcast_lesser',
    'broadcast_lesser_equal',
    'broadcast_logical_and',
    'broadcast_logical_or',
    'broadcast_logical_xor',
    'broadcast_maximum',
    'broadcast_minimum',
    'broadcast_minus',
    'broadcast_mod',
    'broadcast_mul',
    'broadcast_not_equal',
    'broadcast_sub',
    'dot',
    'elemwise_add',
    'elemwise_div',
    'elemwise_mul',
    'elemwise_sub',
    'stack',
    '_Maximum',
    '_Minimum',
    '_contrib_MultiProposal',
    '_contrib_PSROIPooling',
    '_contrib_Proposal',
    '_contrib_ROIAlign',
    '_contrib_box_iou',
    '_contrib_box_nms',
    '_contrib_box_non_maximum_suppression',
    '_contrib_dgl_adjacency',
    '_contrib_dgl_csr_neighbor_non_uniform_sample',
    '_contrib_dgl_csr_neighbor_uniform_sample',
    '_contrib_dgl_graph_compact',
    '_contrib_dgl_subgraph',
    '_contrib_edge_id',
    '_contrib_interleaved_matmul_encdec_qk',
    '_contrib_interleaved_matmul_encdec_valatt',
    '_contrib_interleaved_matmul_selfatt_qk',
    '_contrib_interleaved_matmul_selfatt_valatt',
    'where',
    '_sparse_where',
    '_sparse_broadcast_add',
    '_sparse_broadcast_div',
    '_sparse_broadcast_minus',
    '_sparse_broadcast_mul',
    '_sparse_broadcast_plus',
    '_sparse_broadcast_sub',
    '_sparse_dot',
    '_sparse_elemwise_add',
    '_sparse_elemwise_div',
    '_sparse_elemwise_mul',
    '_sparse_elemwise_sub',

    # random_pdf_* operators
    'random_pdf_gamma',
    'random_pdf_exponential',
    'random_pdf_uniform',
    'random_pdf_negative_binomial',
    'random_pdf_generalized_negative_binomial',
    'random_pdf_dirichlet',
    'random_pdf_normal',
    'random_pdf_poisson',
    '_random_pdf_gamma',
    '_random_pdf_exponential',
    '_random_pdf_uniform',
    '_random_pdf_negative_binomial',
    '_random_pdf_generalized_negative_binomial',
    '_random_pdf_dirichlet',
    '_random_pdf_normal',
    '_random_pdf_poisson',
    ]

# Loss/output operators, tracked separately from the casting categories.
# Every entry here is also listed in FP32_FUNCS, so these operators are also
# cast to FP32. NOTE(review): this list is presumably consumed by the AMP
# loss-scaling logic; confirm against the caller.
LOSS_OUTPUT_FUNCTIONS = [
    'SoftmaxOutput',
    'LinearRegressionOutput',
    'LogisticRegressionOutput',
    'MAERegressionOutput',
    ]