1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <assert.h>
18 #include "oneapi/dnnl/dnnl.h"
19 
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23 
24 using namespace dnnl::impl;
25 using namespace dnnl::impl::utils;
26 using namespace dnnl::impl::status;
27 using namespace dnnl::impl::prop_kind;
28 using namespace dnnl::impl::alg_kind;
29 using namespace dnnl::impl::types;
30 
31 namespace {
lrn_desc_init(lrn_desc_t * lrn_desc,prop_kind_t prop_kind,alg_kind_t alg_kind,const memory_desc_t * data_desc,const memory_desc_t * diff_data_desc,dim_t local_size,float alpha,float beta,float k)32 status_t lrn_desc_init(lrn_desc_t *lrn_desc, prop_kind_t prop_kind,
33         alg_kind_t alg_kind, const memory_desc_t *data_desc,
34         const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
35         float beta, float k) {
36     bool args_ok = true && !any_null(lrn_desc, data_desc)
37             && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
38             && one_of(prop_kind, forward_training, forward_inference,
39                     backward_data)
40             && IMPLICATION(
41                     prop_kind == backward_data, diff_data_desc != nullptr);
42     if (!args_ok) return invalid_arguments;
43 
44     auto ld = lrn_desc_t();
45     ld.primitive_kind = primitive_kind::lrn;
46     ld.prop_kind = prop_kind;
47     ld.alg_kind = alg_kind;
48 
49     const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
50 
51     bool runtime_dims_or_strides
52             = memory_desc_wrapper(data_desc).has_runtime_dims_or_strides();
53     if (!is_fwd)
54         runtime_dims_or_strides = runtime_dims_or_strides
55                 || memory_desc_wrapper(diff_data_desc)
56                            .has_runtime_dims_or_strides();
57     if (runtime_dims_or_strides) return unimplemented;
58 
59     ld.data_desc = *data_desc;
60     if (!is_fwd) ld.diff_data_desc = *diff_data_desc;
61 
62     ld.local_size = local_size;
63     ld.lrn_alpha = alpha;
64     ld.lrn_beta = beta;
65     ld.lrn_k = k;
66 
67     bool consistency = ld.data_desc.ndims >= 2;
68     if (consistency && ld.prop_kind == backward_data)
69         consistency = array_cmp(
70                 ld.diff_data_desc.dims, ld.data_desc.dims, ld.data_desc.ndims);
71     if (!consistency) return invalid_arguments;
72 
73     *lrn_desc = ld;
74     return success;
75 }
76 } // namespace
77 
dnnl_lrn_forward_desc_init(lrn_desc_t * lrn_desc,prop_kind_t prop_kind,alg_kind_t alg_kind,const memory_desc_t * data_desc,dim_t local_size,float alpha,float beta,float k)78 status_t dnnl_lrn_forward_desc_init(lrn_desc_t *lrn_desc, prop_kind_t prop_kind,
79         alg_kind_t alg_kind, const memory_desc_t *data_desc, dim_t local_size,
80         float alpha, float beta, float k) {
81     if (!one_of(prop_kind, forward_training, forward_inference))
82         return invalid_arguments;
83     return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
84             local_size, alpha, beta, k);
85 }
86 
dnnl_lrn_backward_desc_init(lrn_desc_t * lrn_desc,alg_kind_t alg_kind,const memory_desc_t * diff_data_desc,const memory_desc_t * data_desc,dim_t local_size,float alpha,float beta,float k)87 status_t dnnl_lrn_backward_desc_init(lrn_desc_t *lrn_desc, alg_kind_t alg_kind,
88         const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
89         dim_t local_size, float alpha, float beta, float k) {
90     return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
91             diff_data_desc, local_size, alpha, beta, k);
92 }
93 
94 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
95