1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include "primitive_desc.hpp"
18 #include "type_helpers.hpp"
19 #include "utils.hpp"
20
21 #include "dnnl_thread.hpp"
22 #include "engine.hpp"
23 #include "primitive_hashing.hpp"
24
25 namespace dnnl {
26 namespace impl {
27 namespace primitive_hashing {
28
key_t(const engine_t * engine,const op_desc_t * op_desc,const primitive_attr_t * attr,int pd_iterator_offset,const std::vector<memory_desc_t> & hint_mds)29 key_t::key_t(const engine_t *engine, const op_desc_t *op_desc,
30 const primitive_attr_t *attr, int pd_iterator_offset,
31 const std::vector<memory_desc_t> &hint_mds)
32 : primitive_kind_(get_pkind(op_desc->kind))
33 , op_desc_(op_desc)
34 , attr_(attr)
35 , pd_iterator_offset_(pd_iterator_offset)
36 , impl_nthr_(dnnl_get_max_threads())
37 , hint_mds_(hint_mds)
38 #ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE
39 , engine_id_(engine->engine_id())
40 #else
41 , engine_kind_(engine->kind())
42 , runtime_kind_(engine->runtime_kind())
43 , device_id_(engine->device_id())
44 #endif
45 , thread_id_(std::this_thread::get_id()) {
46 }
47
key_t(const primitive_desc_t * pd,const engine_t * engine)48 key_t::key_t(const primitive_desc_t *pd, const engine_t *engine)
49 : key_t(engine, pd->op_desc(), pd->attr(), pd->pd_iterator_offset(),
50 pd->hint_mds(false /* is_hint */)) {}
51
get_pkind(primitive_kind_t pkind)52 primitive_kind_t key_t::get_pkind(primitive_kind_t pkind) {
53 switch (pkind) {
54 case primitive_kind::softmax:
55 case primitive_kind::logsoftmax: return primitive_kind::softmax;
56 default: return pkind;
57 }
58 }
59
operator ==(const key_t & rhs) const60 bool key_t::operator==(const key_t &rhs) const {
61 DNNL_SHORT_CIRCUIT_SELF_COMPARISON(rhs);
62 // clang-format off
63 bool ret = true
64 // Less expensive comparisons come first
65 && primitive_kind_ == rhs.primitive_kind_
66 #ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE
67 && engine_id_ == rhs.engine_id_
68 #else
69 && engine_kind_ == rhs.engine_kind_
70 && runtime_kind_ == rhs.runtime_kind_
71 && device_id_ == rhs.device_id_
72 #endif
73 && hint_mds_.size() == rhs.hint_mds_.size()
74 && pd_iterator_offset_ == rhs.pd_iterator_offset_
75 && impl_nthr_ == rhs.impl_nthr_
76 && (*attr_) == (*rhs.attr_);
77
78 if (!ret) return false;
79
80 #define CASE(pkind) \
81 case primitive_kind::pkind: \
82 ret = cast_to_desc<pkind##_desc_t>(op_desc_) \
83 == cast_to_desc<pkind##_desc_t>(rhs.op_desc_); \
84 break;
85
86 switch ((int)primitive_kind_) {
87 CASE(batch_normalization)
88 CASE(binary)
89 CASE(concat)
90 CASE(convolution)
91 CASE(deconvolution)
92 CASE(eltwise)
93 CASE(gemm)
94 CASE(inner_product)
95 CASE(layer_normalization)
96 CASE(lrn)
97 CASE(matmul)
98 CASE(pooling)
99 CASE(pooling_v2)
100 CASE(prelu)
101 CASE(reduction)
102 CASE(reorder)
103 CASE(resampling)
104 CASE(rnn)
105 CASE(shuffle)
106 CASE(softmax)
107 CASE(sum)
108 CASE(zero_pad)
109 default: assert(!"unknown primitive kind");
110 }
111 #undef CASE
112 // clang-format on
113
114 if (!ret) return false;
115
116 for (size_t i = 0; i < hint_mds_.size(); ++i)
117 if (hint_mds_[i] != rhs.hint_mds_[i]) return false;
118
119 return true;
120 }
121
122 // Combine hash of each memory_desc_t data member
get_md_hash(const memory_desc_t & md)123 size_t get_md_hash(const memory_desc_t &md) {
124 size_t seed = 0;
125 seed = get_array_hash(seed, md.dims, md.ndims);
126 seed = hash_combine(seed, static_cast<size_t>(md.data_type));
127 seed = get_array_hash(seed, md.padded_dims, md.ndims);
128 seed = get_array_hash(seed, md.padded_offsets, md.ndims);
129 seed = hash_combine(seed, md.offset0);
130 seed = hash_combine(seed, static_cast<size_t>(md.format_kind));
131 // format desc
132 switch (md.format_kind) {
133 case format_kind::undef:
134 case format_kind::any: break;
135 case format_kind::blocked:
136 for (int i = 0; i < md.ndims; i++) {
137 if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
138 seed = hash_combine(seed, md.format_desc.blocking.strides[i]);
139 }
140 seed = hash_combine(seed, md.format_desc.blocking.inner_nblks);
141 seed = get_array_hash(seed, md.format_desc.blocking.inner_blks,
142 md.format_desc.blocking.inner_nblks);
143 seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs,
144 md.format_desc.blocking.inner_nblks);
145 break;
146 case format_kind::wino:
147 seed = hash_combine(seed,
148 static_cast<size_t>(md.format_desc.wino_desc.wino_format));
149 seed = hash_combine(seed, md.format_desc.wino_desc.r);
150 seed = hash_combine(seed, md.format_desc.wino_desc.alpha);
151 seed = hash_combine(seed, md.format_desc.wino_desc.ic);
152 seed = hash_combine(seed, md.format_desc.wino_desc.oc);
153 seed = hash_combine(seed, md.format_desc.wino_desc.ic_block);
154 seed = hash_combine(seed, md.format_desc.wino_desc.oc_block);
155 seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block);
156 seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block);
157 seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
158 seed = hash_combine(seed, md.format_desc.wino_desc.size);
159 break;
160 case format_kind::rnn_packed:
161 seed = hash_combine(seed,
162 static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
163 seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts);
164 seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
165 seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
166 {
167 int n_parts = md.format_desc.rnn_packed_desc.n_parts;
168 seed = get_array_hash(
169 seed, md.format_desc.rnn_packed_desc.parts, n_parts);
170 seed = get_array_hash(seed,
171 md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
172 seed = get_array_hash(seed,
173 md.format_desc.rnn_packed_desc.pack_part, n_parts);
174 }
175 seed = hash_combine(
176 seed, md.format_desc.rnn_packed_desc.offset_compensation);
177 seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
178 break;
179 default: assert(!"unknown format_kind");
180 }
181
182 if (md.extra.flags != dnnl_memory_extra_flag_none) {
183 seed = hash_combine(seed, md.extra.flags);
184 if (md.extra.flags
185 & (dnnl_memory_extra_flag_compensation_conv_s8s8
186 | dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
187 seed = hash_combine(seed, md.extra.compensation_mask);
188 }
189
190 if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
191 seed = hash_combine(seed, md.extra.scale_adjust);
192 }
193
194 if (md.extra.flags
195 & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
196 seed = hash_combine(seed, md.extra.asymm_compensation_mask);
197 }
198 }
199 // Combined hash for a memory descriptor
200 return seed;
201 }
202
203 // Combine hash of each primitive_attr_t data member
get_attr_hash(const primitive_attr_t & attr)204 size_t get_attr_hash(const primitive_attr_t &attr) {
205 size_t seed = 0;
206 // scratchpad_mode
207 seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
208
209 if (!attr.output_scales_.has_default_values()) {
210 // output_scales: mask
211 seed = hash_combine(seed, attr.output_scales_.mask_);
212 // output_scales: count
213 seed = hash_combine(seed, attr.output_scales_.count_);
214 // output_scales: scales[:]
215 seed = get_array_hash(
216 seed, attr.output_scales_.scales_, attr.output_scales_.count_);
217 } else if (!attr.scales_.has_default_values()) {
218 // go through scales for all arguments
219 for (const auto &p : attr.scales_.scales_) {
220 seed = hash_combine(seed, p.second.mask_);
221 seed = hash_combine(seed, p.second.count_);
222 seed = get_array_hash(seed, p.second.scales_, p.second.count_);
223 }
224 }
225 // zero_points
226 for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
227 if (!attr.zero_points_.has_default_values(arg)) {
228 dim_t count = 0;
229 int mask = 0;
230 const int *zero_points = nullptr;
231 attr.zero_points_.get(arg, &count, &mask, &zero_points);
232 // zero_points: count
233 seed = hash_combine(seed, count);
234 // zero_points: mask
235 seed = hash_combine(seed, mask);
236 // zero_points: zero_points[:]
237 seed = get_array_hash(seed, zero_points, count);
238 }
239 // post_ops: entry[:]
240 for (int i = 0; i < attr.post_ops_.len(); i++) {
241 const auto &entry = attr.post_ops_.entry_[i];
242 switch (entry.kind) {
243 case primitive_kind::eltwise:
244 seed = hash_combine(
245 seed, static_cast<size_t>(entry.eltwise.alg));
246 seed = hash_combine(seed, entry.eltwise.scale);
247 seed = hash_combine(seed, entry.eltwise.alpha);
248 seed = hash_combine(seed, entry.eltwise.beta);
249 break;
250 case primitive_kind::sum:
251 seed = hash_combine(seed, entry.sum.scale);
252 seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt));
253 break;
254 case primitive_kind::convolution:
255 seed = hash_combine(
256 seed, static_cast<size_t>(entry.depthwise_conv.stride));
257 seed = hash_combine(
258 seed, static_cast<size_t>(entry.depthwise_conv.wei_dt));
259 seed = hash_combine(seed,
260 static_cast<size_t>(entry.depthwise_conv.bias_dt));
261 seed = hash_combine(
262 seed, static_cast<size_t>(entry.depthwise_conv.dst_dt));
263 if (entry.depthwise_conv.scales) {
264 seed = hash_combine(seed, entry.depthwise_conv.mask);
265 seed = hash_combine(seed, entry.depthwise_conv.count);
266 seed = get_array_hash(seed, entry.depthwise_conv.scales,
267 entry.depthwise_conv.count);
268 }
269 break;
270 case primitive_kind::binary:
271 seed = hash_combine(
272 seed, static_cast<size_t>(entry.binary.alg));
273 seed = hash_combine(seed, get_md_hash(entry.binary.src1_desc));
274 break;
275 default: assert(!"unknown post_op");
276 }
277 }
278 // rnn_data_qparams: scale, shift
279 seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
280 seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
281 if (!attr.rnn_weights_qparams_.has_default_values()) {
282 // rnn_weights_qparams: mask
283 seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_);
284 // rnn_weights_qparams: count
285 seed = hash_combine(seed, attr.rnn_weights_qparams_.count_);
286 // rnn_weights_qparams: scales[:]
287 seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_,
288 attr.rnn_weights_qparams_.count_);
289 }
290 // Combined hash for attributes
291 return seed;
292 }
293
294 // Functions that compute hash for different op_descs
get_desc_hash(const concat_desc_t & desc)295 size_t get_desc_hash(const concat_desc_t &desc) {
296 size_t seed = 0;
297 // Kinds
298 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
299 // Memory descriptors
300 seed = hash_combine(seed, get_md_hash(*desc.dst_md));
301 // N
302 seed = hash_combine(seed, desc.n);
303 // Concat dimension
304 seed = hash_combine(seed, desc.concat_dimension);
305 // Array of mds
306 seed = get_array_hash(seed, desc.src_mds, desc.n);
307 // Combined hash for concat desc
308 return seed;
309 }
310
get_desc_hash(const batch_normalization_desc_t & desc)311 size_t get_desc_hash(const batch_normalization_desc_t &desc) {
312 size_t seed = 0;
313 // Kinds
314 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
315 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
316 // Memory descriptors
317 seed = hash_combine(seed, get_md_hash(desc.data_desc));
318 seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
319 seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc));
320 seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc));
321 seed = hash_combine(seed, get_md_hash(desc.stat_desc));
322 // Epsilon
323 seed = hash_combine(seed, desc.batch_norm_epsilon);
324 // Flags
325 seed = hash_combine(seed, desc.flags);
326 // Combined hash for batch normalization desc
327 return seed;
328 }
329
get_desc_hash(const binary_desc_t & desc)330 size_t get_desc_hash(const binary_desc_t &desc) {
331 size_t seed = 0;
332 // Kinds
333 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
334 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
335 // Memory descriptors
336 seed = hash_combine(seed, get_md_hash(desc.src_desc[0]));
337 seed = hash_combine(seed, get_md_hash(desc.src_desc[1]));
338 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
339 // Combined hash for binary op desc
340 return seed;
341 }
342
343 // (De-)Convolution
get_desc_hash(const convolution_desc_t & desc)344 size_t get_desc_hash(const convolution_desc_t &desc) {
345 size_t seed = 0;
346 // Kinds
347 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
348 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
349 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
350 // Memory descriptors
351 seed = hash_combine(seed, get_md_hash(desc.src_desc));
352 seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
353 seed = hash_combine(seed, get_md_hash(desc.weights_desc));
354 seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
355 seed = hash_combine(seed, get_md_hash(desc.bias_desc));
356 seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
357 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
358 seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
359 // Strides, dilates, padding
360 seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
361 seed = get_array_hash(seed, desc.dilates, DNNL_MAX_NDIMS);
362 seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
363 seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
364 // Accumulator type
365 seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
366 // Combined hash for (de-)convolution desc
367 return seed;
368 }
369
370 // Eltwise
get_desc_hash(const eltwise_desc_t & desc)371 size_t get_desc_hash(const eltwise_desc_t &desc) {
372 size_t seed = 0;
373 // Kinds
374 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
375 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
376 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
377 // Memory descriptors
378 seed = hash_combine(seed, get_md_hash(desc.data_desc));
379 seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
380 // Alpha, beta
381 seed = hash_combine(seed, desc.alpha);
382 seed = hash_combine(seed, desc.beta);
383 // Combined hash for eltwise desc
384 return seed;
385 }
386
get_desc_hash(const gemm_desc_t & desc)387 size_t get_desc_hash(const gemm_desc_t &desc) {
388 size_t seed = 0;
389 // Kinds
390 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
391 seed = hash_combine(seed, get_md_hash(desc.a_desc));
392 seed = hash_combine(seed, get_md_hash(desc.b_desc));
393 seed = hash_combine(seed, get_md_hash(desc.c_desc));
394 seed = hash_combine(seed, get_md_hash(desc.bias_desc));
395 // Accumulator type
396 seed = hash_combine(seed, static_cast<size_t>(desc.acc_type));
397 // Combined hash for gemm desc
398 return seed;
399 }
400
get_desc_hash(const inner_product_desc_t & desc)401 size_t get_desc_hash(const inner_product_desc_t &desc) {
402 size_t seed = 0;
403 // Kinds
404 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
405 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
406 // Memory descriptors
407 seed = hash_combine(seed, get_md_hash(desc.src_desc));
408 seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
409 seed = hash_combine(seed, get_md_hash(desc.weights_desc));
410 seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
411 seed = hash_combine(seed, get_md_hash(desc.bias_desc));
412 seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
413 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
414 seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
415 // Accumulator type
416 seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
417 // Combined hash for inner_product desc
418 return seed;
419 }
420
421 // Layer normalization
get_desc_hash(const layer_normalization_desc_t & desc)422 size_t get_desc_hash(const layer_normalization_desc_t &desc) {
423 size_t seed = 0;
424 // Kinds
425 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
426 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
427 // Memory descriptors
428 seed = hash_combine(seed, get_md_hash(desc.data_desc));
429 seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
430 seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc));
431 seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc));
432 seed = hash_combine(seed, get_md_hash(desc.stat_desc));
433 // Epsilon
434 seed = hash_combine(seed, desc.layer_norm_epsilon);
435 // Flags
436 seed = hash_combine(seed, desc.flags);
437 // Combined hash for layer_normalization desc
438 return seed;
439 }
440
get_desc_hash(const lrn_desc_t & desc)441 size_t get_desc_hash(const lrn_desc_t &desc) {
442 size_t seed = 0;
443 // Kinds
444 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
445 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
446 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
447 // Memory descriptors
448 seed = hash_combine(seed, get_md_hash(desc.data_desc));
449 seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
450 // Local size
451 seed = hash_combine(seed, desc.local_size);
452 // Alpha, beta
453 seed = hash_combine(seed, desc.lrn_alpha);
454 seed = hash_combine(seed, desc.lrn_beta);
455 // k
456 seed = hash_combine(seed, desc.lrn_k);
457 // Combined hash for lrn desc
458 return seed;
459 }
460
get_desc_hash(const matmul_desc_t & desc)461 size_t get_desc_hash(const matmul_desc_t &desc) {
462 size_t seed = 0;
463 // Kinds
464 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
465 // Memory descriptors
466 seed = hash_combine(seed, get_md_hash(desc.src_desc));
467 seed = hash_combine(seed, get_md_hash(desc.weights_desc));
468 seed = hash_combine(seed, get_md_hash(desc.bias_desc));
469 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
470 // Accumulator type
471 seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
472 // Combined hash for matmul op desc
473 return seed;
474 }
475
get_desc_hash(const pooling_desc_t & desc)476 size_t get_desc_hash(const pooling_desc_t &desc) {
477 size_t seed = 0;
478 // Kinds
479 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
480 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
481 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
482 // Memory descriptors
483 seed = hash_combine(seed, get_md_hash(desc.src_desc));
484 seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
485 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
486 seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
487 // Strides, dilates, padding
488 seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
489 seed = get_array_hash(seed, desc.kernel, DNNL_MAX_NDIMS);
490 seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
491 seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
492 // Accumulator type
493 seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
494 // Combined hash for pooling desc
495 return seed;
496 }
497
get_desc_hash(const pooling_v2_desc_t & desc)498 size_t get_desc_hash(const pooling_v2_desc_t &desc) {
499 const auto &v1_desc = *reinterpret_cast<const pooling_desc_t *>(&desc);
500 size_t seed = get_desc_hash(v1_desc);
501 seed = get_array_hash(seed, desc.dilation, DNNL_MAX_NDIMS);
502 return seed;
503 }
504
get_desc_hash(const prelu_desc_t & desc)505 size_t get_desc_hash(const prelu_desc_t &desc) {
506 size_t seed = 0;
507 // Kinds
508 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
509 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
510 // Memory descriptors
511 seed = hash_combine(seed, get_md_hash(desc.data_desc));
512 seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
513 seed = hash_combine(seed, get_md_hash(desc.weights_desc));
514 seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
515 // Combined hash for pooling desc
516 return seed;
517 }
518
get_desc_hash(const reduction_desc_t & desc)519 size_t get_desc_hash(const reduction_desc_t &desc) {
520 size_t seed = 0;
521 // Kinds
522 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
523 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
524 // Memory descriptors
525 seed = hash_combine(seed, get_md_hash(desc.src_desc));
526 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
527 // P, eps
528 seed = hash_combine(seed, desc.p);
529 seed = hash_combine(seed, desc.eps);
530 // Combined hash for reduction desc
531 return seed;
532 }
533
get_desc_hash(const reorder_desc_t & desc)534 size_t get_desc_hash(const reorder_desc_t &desc) {
535 size_t seed = 0;
536 // Kinds
537 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
538 // Memory descriptors
539 seed = hash_combine(seed, get_md_hash(*desc.src_md));
540 seed = hash_combine(seed, get_md_hash(*desc.dst_md));
541 // Kinds of source and destination engines
542 seed = hash_combine(seed, static_cast<size_t>(desc.src_engine_kind));
543 seed = hash_combine(seed, static_cast<size_t>(desc.dst_engine_kind));
544 seed = hash_combine(seed, desc.is_cross_engine);
545 // Combined hash for reorder desc
546 return seed;
547 }
548
get_desc_hash(const resampling_desc_t & desc)549 size_t get_desc_hash(const resampling_desc_t &desc) {
550 size_t seed = 0;
551 // Kinds
552 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
553 seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
554 // Memory descriptors
555 seed = hash_combine(seed, get_md_hash(desc.src_desc));
556 seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
557 seed = hash_combine(seed, get_md_hash(desc.dst_desc));
558 seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
559 // Factors
560 seed = get_array_hash(seed, desc.factors, DNNL_MAX_NDIMS);
561 // Combined hash for resampling op desc
562 return seed;
563 }
564
get_desc_hash(const rnn_desc_t & desc)565 size_t get_desc_hash(const rnn_desc_t &desc) {
566 size_t seed = 0;
567 // Kinds
568 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
569 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
570 seed = hash_combine(seed, static_cast<size_t>(desc.cell_kind));
571 seed = hash_combine(seed, static_cast<size_t>(desc.direction));
572 // Memory descriptors
573 seed = hash_combine(seed, get_md_hash(desc.src_layer_desc));
574 seed = hash_combine(seed, get_md_hash(desc.src_iter_desc));
575 seed = hash_combine(seed, get_md_hash(desc.src_iter_c_desc));
576 seed = hash_combine(seed, get_md_hash(desc.weights_layer_desc));
577 seed = hash_combine(seed, get_md_hash(desc.weights_iter_desc));
578 seed = hash_combine(seed, get_md_hash(desc.bias_desc));
579 seed = hash_combine(seed, get_md_hash(desc.dst_layer_desc));
580 seed = hash_combine(seed, get_md_hash(desc.dst_iter_desc));
581 seed = hash_combine(seed, get_md_hash(desc.dst_iter_c_desc));
582 seed = hash_combine(seed, get_md_hash(desc.weights_peephole_desc));
583 seed = hash_combine(seed, get_md_hash(desc.weights_projection_desc));
584 seed = hash_combine(seed, get_md_hash(desc.diff_src_layer_desc));
585 seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_desc));
586 seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_c_desc));
587 seed = hash_combine(seed, get_md_hash(desc.diff_weights_layer_desc));
588 seed = hash_combine(seed, get_md_hash(desc.diff_weights_iter_desc));
589 seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
590 seed = hash_combine(seed, get_md_hash(desc.diff_dst_layer_desc));
591 seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_desc));
592 seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_c_desc));
593 seed = hash_combine(seed, get_md_hash(desc.diff_weights_peephole_desc));
594 seed = hash_combine(seed, get_md_hash(desc.diff_weights_projection_desc));
595 // Flags
596 seed = hash_combine(seed, desc.flags);
597 // Activation kind
598 seed = hash_combine(seed, static_cast<size_t>(desc.activation_kind));
599 // Alpha, beta
600 seed = hash_combine(seed, desc.alpha);
601 seed = hash_combine(seed, desc.beta);
602 // Combined hash for rnn desc
603 return seed;
604 }
605
606 // Shuffle
get_desc_hash(const shuffle_desc_t & desc)607 size_t get_desc_hash(const shuffle_desc_t &desc) {
608 size_t seed = 0;
609 // Kinds
610 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
611 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
612 // Memory descriptors
613 seed = hash_combine(seed, get_md_hash(desc.data_desc));
614 // Axis
615 seed = hash_combine(seed, desc.axis);
616 // Groupe size
617 seed = hash_combine(seed, desc.group_size);
618 // Combined hash for shuffle desc
619 return seed;
620 }
621
get_desc_hash(const softmax_desc_t & desc)622 size_t get_desc_hash(const softmax_desc_t &desc) {
623 size_t seed = 0;
624 // Kinds
625 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
626 seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
627 // Memory descriptors
628 seed = hash_combine(seed, get_md_hash(desc.data_desc));
629 seed = hash_combine(seed, get_md_hash(desc.diff_desc));
630 // Axis
631 seed = hash_combine(seed, desc.softmax_axis);
632 // Combined hash for softmax desc
633 return seed;
634 }
635
get_desc_hash(const sum_desc_t & desc)636 size_t get_desc_hash(const sum_desc_t &desc) {
637 size_t seed = 0;
638 // Kinds
639 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
640 // Memory descriptors
641 seed = hash_combine(seed, get_md_hash(*desc.dst_md));
642 // N
643 seed = hash_combine(seed, desc.n);
644 // Scales
645 if (desc.scales) { seed = get_array_hash(seed, desc.scales, desc.n); }
646 // Array of mds
647 seed = get_array_hash(seed, desc.src_mds, desc.n);
648 // Combined hash for sum desc
649 return seed;
650 }
651
get_desc_hash(const zero_pad_desc_t & desc)652 size_t get_desc_hash(const zero_pad_desc_t &desc) {
653 size_t seed = 0;
654 // Kinds
655 seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
656 return seed;
657 }
658
659 } // namespace primitive_hashing
660 } // namespace impl
661 } // namespace dnnl
662