1 /*******************************************************************************
2 * Copyright 2020-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "oneapi/dnnl/dnnl.h"
18 
19 #include "c_types_map.hpp"
20 #include "utils.hpp"
21 
dnnl_reduction_desc_init(dnnl_reduction_desc_t * desc,dnnl_alg_kind_t alg_kind,const dnnl_memory_desc_t * src_desc,const dnnl_memory_desc_t * dst_desc,float p,float eps)22 dnnl_status_t dnnl_reduction_desc_init(dnnl_reduction_desc_t *desc,
23         dnnl_alg_kind_t alg_kind, const dnnl_memory_desc_t *src_desc,
24         const dnnl_memory_desc_t *dst_desc, float p, float eps) {
25     using namespace dnnl::impl;
26     using namespace dnnl::impl::status;
27     using namespace dnnl::impl::utils;
28     using namespace dnnl::impl::alg_kind;
29 
30     bool args_ok = !any_null(desc, src_desc, dst_desc)
31             && src_desc->format_kind != format_kind::any
32             && one_of(alg_kind, reduction_max, reduction_min, reduction_sum,
33                     reduction_mul, reduction_mean, reduction_norm_lp_max,
34                     reduction_norm_lp_sum, reduction_norm_lp_power_p_max,
35                     reduction_norm_lp_power_p_sum)
36             && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max,
37                                    reduction_norm_lp_sum,
38                                    reduction_norm_lp_power_p_max,
39                                    reduction_norm_lp_power_p_sum),
40                     p >= 1.0f)
41             && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max,
42                                    reduction_norm_lp_sum,
43                                    reduction_norm_lp_power_p_max,
44                                    reduction_norm_lp_power_p_sum),
45                     one_of(src_desc->data_type, data_type::f32, data_type::bf16,
46                             data_type::f16));
47     if (!args_ok) return invalid_arguments;
48 
49     if (src_desc->ndims != dst_desc->ndims) return invalid_arguments;
50 
51     for (auto d = 0; d < src_desc->ndims; ++d) {
52         const auto dst_dim_d = dst_desc->dims[d];
53         if (!one_of(dst_dim_d, 1, src_desc->dims[d])) return invalid_arguments;
54     }
55 
56     // reduction primitive doesn't support identity operation
57     if (array_cmp(src_desc->dims, dst_desc->dims, src_desc->ndims))
58         return invalid_arguments;
59 
60     if (src_desc->format_kind != format_kind::blocked
61             || !one_of(dst_desc->format_kind, format_kind::blocked,
62                     format_kind::any))
63         return invalid_arguments;
64 
65     if (src_desc->extra.flags != 0
66             || !IMPLICATION(dst_desc->format_kind == format_kind::blocked,
67                     dst_desc->extra.flags == 0))
68         return invalid_arguments;
69 
70     auto rd = reduction_desc_t();
71     rd.primitive_kind = primitive_kind::reduction;
72     rd.alg_kind = alg_kind;
73 
74     rd.src_desc = *src_desc;
75     rd.dst_desc = *dst_desc;
76 
77     rd.p = p;
78     rd.eps = eps;
79 
80     *desc = rd;
81     return success;
82 }
83