1 /*******************************************************************************
2 * Copyright 2016-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <assert.h>
18 #include "oneapi/dnnl/dnnl.h"
19
20 #include "c_types_map.hpp"
21 #include "type_helpers.hpp"
22 #include "utils.hpp"
23
24 using namespace dnnl::impl;
25 using namespace dnnl::impl::utils;
26 using namespace dnnl::impl::status;
27 using namespace dnnl::impl::prop_kind;
28 using namespace dnnl::impl::alg_kind;
29 using namespace dnnl::impl::types;
30
31 namespace {
lrn_desc_init(lrn_desc_t * lrn_desc,prop_kind_t prop_kind,alg_kind_t alg_kind,const memory_desc_t * data_desc,const memory_desc_t * diff_data_desc,dim_t local_size,float alpha,float beta,float k)32 status_t lrn_desc_init(lrn_desc_t *lrn_desc, prop_kind_t prop_kind,
33 alg_kind_t alg_kind, const memory_desc_t *data_desc,
34 const memory_desc_t *diff_data_desc, dim_t local_size, float alpha,
35 float beta, float k) {
36 bool args_ok = true && !any_null(lrn_desc, data_desc)
37 && one_of(alg_kind, lrn_within_channel, lrn_across_channels)
38 && one_of(prop_kind, forward_training, forward_inference,
39 backward_data)
40 && IMPLICATION(
41 prop_kind == backward_data, diff_data_desc != nullptr);
42 if (!args_ok) return invalid_arguments;
43
44 auto ld = lrn_desc_t();
45 ld.primitive_kind = primitive_kind::lrn;
46 ld.prop_kind = prop_kind;
47 ld.alg_kind = alg_kind;
48
49 const bool is_fwd = one_of(prop_kind, forward_training, forward_inference);
50
51 bool runtime_dims_or_strides
52 = memory_desc_wrapper(data_desc).has_runtime_dims_or_strides();
53 if (!is_fwd)
54 runtime_dims_or_strides = runtime_dims_or_strides
55 || memory_desc_wrapper(diff_data_desc)
56 .has_runtime_dims_or_strides();
57 if (runtime_dims_or_strides) return unimplemented;
58
59 ld.data_desc = *data_desc;
60 if (!is_fwd) ld.diff_data_desc = *diff_data_desc;
61
62 ld.local_size = local_size;
63 ld.lrn_alpha = alpha;
64 ld.lrn_beta = beta;
65 ld.lrn_k = k;
66
67 bool consistency = ld.data_desc.ndims >= 2;
68 if (consistency && ld.prop_kind == backward_data)
69 consistency = array_cmp(
70 ld.diff_data_desc.dims, ld.data_desc.dims, ld.data_desc.ndims);
71 if (!consistency) return invalid_arguments;
72
73 *lrn_desc = ld;
74 return success;
75 }
76 } // namespace
77
dnnl_lrn_forward_desc_init(lrn_desc_t * lrn_desc,prop_kind_t prop_kind,alg_kind_t alg_kind,const memory_desc_t * data_desc,dim_t local_size,float alpha,float beta,float k)78 status_t dnnl_lrn_forward_desc_init(lrn_desc_t *lrn_desc, prop_kind_t prop_kind,
79 alg_kind_t alg_kind, const memory_desc_t *data_desc, dim_t local_size,
80 float alpha, float beta, float k) {
81 if (!one_of(prop_kind, forward_training, forward_inference))
82 return invalid_arguments;
83 return lrn_desc_init(lrn_desc, prop_kind, alg_kind, data_desc, nullptr,
84 local_size, alpha, beta, k);
85 }
86
dnnl_lrn_backward_desc_init(lrn_desc_t * lrn_desc,alg_kind_t alg_kind,const memory_desc_t * diff_data_desc,const memory_desc_t * data_desc,dim_t local_size,float alpha,float beta,float k)87 status_t dnnl_lrn_backward_desc_init(lrn_desc_t *lrn_desc, alg_kind_t alg_kind,
88 const memory_desc_t *diff_data_desc, const memory_desc_t *data_desc,
89 dim_t local_size, float alpha, float beta, float k) {
90 return lrn_desc_init(lrn_desc, backward_data, alg_kind, data_desc,
91 diff_data_desc, local_size, alpha, beta, k);
92 }
93
94 // vim: et ts=4 sw=4 cindent cino+=l0,\:4,N-s
95