1 /*******************************************************************************
2 * Copyright 2019-2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "primitive_desc.hpp"
18 #include "type_helpers.hpp"
19 #include "utils.hpp"
20 
21 #include "dnnl_thread.hpp"
22 #include "engine.hpp"
23 #include "primitive_hashing.hpp"
24 
25 namespace dnnl {
26 namespace impl {
27 namespace primitive_hashing {
28 
// Builds a primitive-cache key from: the normalized primitive kind (see
// get_pkind), the op_desc / attr pointers as given (they are not deep-copied
// here), the implementation-iterator offset, a copy of the hint memory
// descriptors, the max thread count at construction time and the engine
// identity. The id of the constructing thread is recorded too.
//
// NOTE(review): the member-initializer order below presumably mirrors the
// member declaration order in primitive_hashing.hpp; keep them in sync.
key_t::key_t(const engine_t *engine, const op_desc_t *op_desc,
        const primitive_attr_t *attr, int pd_iterator_offset,
        const std::vector<memory_desc_t> &hint_mds)
    : primitive_kind_(get_pkind(op_desc->kind))
    , op_desc_(op_desc)
    , attr_(attr)
    , pd_iterator_offset_(pd_iterator_offset)
    // Captured so that operator== can tell apart keys created under
    // different max-thread settings.
    , impl_nthr_(dnnl_get_max_threads())
    , hint_mds_(hint_mds)
#ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE
    // Engine identified through its id object.
    , engine_id_(engine->engine_id())
#else
    // Engine identified by kind, runtime kind and device id.
    , engine_kind_(engine->kind())
    , runtime_kind_(engine->runtime_kind())
    , device_id_(engine->device_id())
#endif
    // Recorded only; operator== does not compare it.
    , thread_id_(std::this_thread::get_id()) {
}
47 
key_t(const primitive_desc_t * pd,const engine_t * engine)48 key_t::key_t(const primitive_desc_t *pd, const engine_t *engine)
49     : key_t(engine, pd->op_desc(), pd->attr(), pd->pd_iterator_offset(),
50             pd->hint_mds(false /* is_hint */)) {}
51 
get_pkind(primitive_kind_t pkind)52 primitive_kind_t key_t::get_pkind(primitive_kind_t pkind) {
53     switch (pkind) {
54         case primitive_kind::softmax:
55         case primitive_kind::logsoftmax: return primitive_kind::softmax;
56         default: return pkind;
57     }
58 }
59 
operator ==(const key_t & rhs) const60 bool key_t::operator==(const key_t &rhs) const {
61     DNNL_SHORT_CIRCUIT_SELF_COMPARISON(rhs);
62     // clang-format off
63     bool ret = true
64         // Less expensive comparisons come first
65         && primitive_kind_ == rhs.primitive_kind_
66 #ifdef DNNL_USE_RT_OBJECTS_IN_PRIMITIVE_CACHE
67         && engine_id_ == rhs.engine_id_
68 #else
69         && engine_kind_ == rhs.engine_kind_
70         && runtime_kind_ == rhs.runtime_kind_
71         && device_id_ == rhs.device_id_
72 #endif
73         && hint_mds_.size() == rhs.hint_mds_.size()
74         && pd_iterator_offset_ == rhs.pd_iterator_offset_
75         && impl_nthr_ == rhs.impl_nthr_
76         && (*attr_) == (*rhs.attr_);
77 
78     if (!ret) return false;
79 
80 #define CASE(pkind) \
81     case primitive_kind::pkind: \
82         ret = cast_to_desc<pkind##_desc_t>(op_desc_) \
83                 == cast_to_desc<pkind##_desc_t>(rhs.op_desc_); \
84         break;
85 
86         switch ((int)primitive_kind_) {
87             CASE(batch_normalization)
88             CASE(binary)
89             CASE(concat)
90             CASE(convolution)
91             CASE(deconvolution)
92             CASE(eltwise)
93             CASE(gemm)
94             CASE(inner_product)
95             CASE(layer_normalization)
96             CASE(lrn)
97             CASE(matmul)
98             CASE(pooling)
99             CASE(pooling_v2)
100             CASE(prelu)
101             CASE(reduction)
102             CASE(reorder)
103             CASE(resampling)
104             CASE(rnn)
105             CASE(shuffle)
106             CASE(softmax)
107             CASE(sum)
108             CASE(zero_pad)
109             default: assert(!"unknown primitive kind");
110         }
111 #undef CASE
112     // clang-format on
113 
114     if (!ret) return false;
115 
116     for (size_t i = 0; i < hint_mds_.size(); ++i)
117         if (hint_mds_[i] != rhs.hint_mds_[i]) return false;
118 
119     return true;
120 }
121 
122 // Combine hash of each memory_desc_t data member
get_md_hash(const memory_desc_t & md)123 size_t get_md_hash(const memory_desc_t &md) {
124     size_t seed = 0;
125     seed = get_array_hash(seed, md.dims, md.ndims);
126     seed = hash_combine(seed, static_cast<size_t>(md.data_type));
127     seed = get_array_hash(seed, md.padded_dims, md.ndims);
128     seed = get_array_hash(seed, md.padded_offsets, md.ndims);
129     seed = hash_combine(seed, md.offset0);
130     seed = hash_combine(seed, static_cast<size_t>(md.format_kind));
131     // format desc
132     switch (md.format_kind) {
133         case format_kind::undef:
134         case format_kind::any: break;
135         case format_kind::blocked:
136             for (int i = 0; i < md.ndims; i++) {
137                 if (md.dims[i] == 1 && md.padded_dims[i] == 1) continue;
138                 seed = hash_combine(seed, md.format_desc.blocking.strides[i]);
139             }
140             seed = hash_combine(seed, md.format_desc.blocking.inner_nblks);
141             seed = get_array_hash(seed, md.format_desc.blocking.inner_blks,
142                     md.format_desc.blocking.inner_nblks);
143             seed = get_array_hash(seed, md.format_desc.blocking.inner_idxs,
144                     md.format_desc.blocking.inner_nblks);
145             break;
146         case format_kind::wino:
147             seed = hash_combine(seed,
148                     static_cast<size_t>(md.format_desc.wino_desc.wino_format));
149             seed = hash_combine(seed, md.format_desc.wino_desc.r);
150             seed = hash_combine(seed, md.format_desc.wino_desc.alpha);
151             seed = hash_combine(seed, md.format_desc.wino_desc.ic);
152             seed = hash_combine(seed, md.format_desc.wino_desc.oc);
153             seed = hash_combine(seed, md.format_desc.wino_desc.ic_block);
154             seed = hash_combine(seed, md.format_desc.wino_desc.oc_block);
155             seed = hash_combine(seed, md.format_desc.wino_desc.ic2_block);
156             seed = hash_combine(seed, md.format_desc.wino_desc.oc2_block);
157             seed = hash_combine(seed, md.format_desc.wino_desc.adj_scale);
158             seed = hash_combine(seed, md.format_desc.wino_desc.size);
159             break;
160         case format_kind::rnn_packed:
161             seed = hash_combine(seed,
162                     static_cast<size_t>(md.format_desc.rnn_packed_desc.format));
163             seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n_parts);
164             seed = hash_combine(seed, md.format_desc.rnn_packed_desc.n);
165             seed = hash_combine(seed, md.format_desc.rnn_packed_desc.ldb);
166             {
167                 int n_parts = md.format_desc.rnn_packed_desc.n_parts;
168                 seed = get_array_hash(
169                         seed, md.format_desc.rnn_packed_desc.parts, n_parts);
170                 seed = get_array_hash(seed,
171                         md.format_desc.rnn_packed_desc.part_pack_size, n_parts);
172                 seed = get_array_hash(seed,
173                         md.format_desc.rnn_packed_desc.pack_part, n_parts);
174             }
175             seed = hash_combine(
176                     seed, md.format_desc.rnn_packed_desc.offset_compensation);
177             seed = hash_combine(seed, md.format_desc.rnn_packed_desc.size);
178             break;
179         default: assert(!"unknown format_kind");
180     }
181 
182     if (md.extra.flags != dnnl_memory_extra_flag_none) {
183         seed = hash_combine(seed, md.extra.flags);
184         if (md.extra.flags
185                 & (dnnl_memory_extra_flag_compensation_conv_s8s8
186                         | dnnl_memory_extra_flag_rnn_u8s8_compensation)) {
187             seed = hash_combine(seed, md.extra.compensation_mask);
188         }
189 
190         if (md.extra.flags & dnnl_memory_extra_flag_scale_adjust) {
191             seed = hash_combine(seed, md.extra.scale_adjust);
192         }
193 
194         if (md.extra.flags
195                 & dnnl_memory_extra_flag_compensation_conv_asymmetric_src) {
196             seed = hash_combine(seed, md.extra.asymm_compensation_mask);
197         }
198     }
199     // Combined hash for a memory descriptor
200     return seed;
201 }
202 
203 // Combine hash of each primitive_attr_t data member
get_attr_hash(const primitive_attr_t & attr)204 size_t get_attr_hash(const primitive_attr_t &attr) {
205     size_t seed = 0;
206     // scratchpad_mode
207     seed = hash_combine(seed, static_cast<size_t>(attr.scratchpad_mode_));
208 
209     if (!attr.output_scales_.has_default_values()) {
210         // output_scales: mask
211         seed = hash_combine(seed, attr.output_scales_.mask_);
212         // output_scales: count
213         seed = hash_combine(seed, attr.output_scales_.count_);
214         // output_scales: scales[:]
215         seed = get_array_hash(
216                 seed, attr.output_scales_.scales_, attr.output_scales_.count_);
217     } else if (!attr.scales_.has_default_values()) {
218         // go through scales for all arguments
219         for (const auto &p : attr.scales_.scales_) {
220             seed = hash_combine(seed, p.second.mask_);
221             seed = hash_combine(seed, p.second.count_);
222             seed = get_array_hash(seed, p.second.scales_, p.second.count_);
223         }
224     }
225     // zero_points
226     for (int arg : {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST})
227         if (!attr.zero_points_.has_default_values(arg)) {
228             dim_t count = 0;
229             int mask = 0;
230             const int *zero_points = nullptr;
231             attr.zero_points_.get(arg, &count, &mask, &zero_points);
232             // zero_points: count
233             seed = hash_combine(seed, count);
234             // zero_points: mask
235             seed = hash_combine(seed, mask);
236             // zero_points: zero_points[:]
237             seed = get_array_hash(seed, zero_points, count);
238         }
239     // post_ops: entry[:]
240     for (int i = 0; i < attr.post_ops_.len(); i++) {
241         const auto &entry = attr.post_ops_.entry_[i];
242         switch (entry.kind) {
243             case primitive_kind::eltwise:
244                 seed = hash_combine(
245                         seed, static_cast<size_t>(entry.eltwise.alg));
246                 seed = hash_combine(seed, entry.eltwise.scale);
247                 seed = hash_combine(seed, entry.eltwise.alpha);
248                 seed = hash_combine(seed, entry.eltwise.beta);
249                 break;
250             case primitive_kind::sum:
251                 seed = hash_combine(seed, entry.sum.scale);
252                 seed = hash_combine(seed, static_cast<size_t>(entry.sum.dt));
253                 break;
254             case primitive_kind::convolution:
255                 seed = hash_combine(
256                         seed, static_cast<size_t>(entry.depthwise_conv.stride));
257                 seed = hash_combine(
258                         seed, static_cast<size_t>(entry.depthwise_conv.wei_dt));
259                 seed = hash_combine(seed,
260                         static_cast<size_t>(entry.depthwise_conv.bias_dt));
261                 seed = hash_combine(
262                         seed, static_cast<size_t>(entry.depthwise_conv.dst_dt));
263                 if (entry.depthwise_conv.scales) {
264                     seed = hash_combine(seed, entry.depthwise_conv.mask);
265                     seed = hash_combine(seed, entry.depthwise_conv.count);
266                     seed = get_array_hash(seed, entry.depthwise_conv.scales,
267                             entry.depthwise_conv.count);
268                 }
269                 break;
270             case primitive_kind::binary:
271                 seed = hash_combine(
272                         seed, static_cast<size_t>(entry.binary.alg));
273                 seed = hash_combine(seed, get_md_hash(entry.binary.src1_desc));
274                 break;
275             default: assert(!"unknown post_op");
276         }
277     }
278     // rnn_data_qparams: scale, shift
279     seed = hash_combine(seed, attr.rnn_data_qparams_.scale_);
280     seed = hash_combine(seed, attr.rnn_data_qparams_.shift_);
281     if (!attr.rnn_weights_qparams_.has_default_values()) {
282         // rnn_weights_qparams: mask
283         seed = hash_combine(seed, attr.rnn_weights_qparams_.mask_);
284         // rnn_weights_qparams: count
285         seed = hash_combine(seed, attr.rnn_weights_qparams_.count_);
286         // rnn_weights_qparams: scales[:]
287         seed = get_array_hash(seed, attr.rnn_weights_qparams_.scales_,
288                 attr.rnn_weights_qparams_.count_);
289     }
290     // Combined hash for attributes
291     return seed;
292 }
293 
294 // Functions that compute hash for different op_descs
get_desc_hash(const concat_desc_t & desc)295 size_t get_desc_hash(const concat_desc_t &desc) {
296     size_t seed = 0;
297     // Kinds
298     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
299     // Memory descriptors
300     seed = hash_combine(seed, get_md_hash(*desc.dst_md));
301     // N
302     seed = hash_combine(seed, desc.n);
303     // Concat dimension
304     seed = hash_combine(seed, desc.concat_dimension);
305     // Array of mds
306     seed = get_array_hash(seed, desc.src_mds, desc.n);
307     // Combined hash for concat desc
308     return seed;
309 }
310 
get_desc_hash(const batch_normalization_desc_t & desc)311 size_t get_desc_hash(const batch_normalization_desc_t &desc) {
312     size_t seed = 0;
313     // Kinds
314     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
315     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
316     // Memory descriptors
317     seed = hash_combine(seed, get_md_hash(desc.data_desc));
318     seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
319     seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc));
320     seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc));
321     seed = hash_combine(seed, get_md_hash(desc.stat_desc));
322     // Epsilon
323     seed = hash_combine(seed, desc.batch_norm_epsilon);
324     // Flags
325     seed = hash_combine(seed, desc.flags);
326     // Combined hash for batch normalization desc
327     return seed;
328 }
329 
get_desc_hash(const binary_desc_t & desc)330 size_t get_desc_hash(const binary_desc_t &desc) {
331     size_t seed = 0;
332     // Kinds
333     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
334     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
335     // Memory descriptors
336     seed = hash_combine(seed, get_md_hash(desc.src_desc[0]));
337     seed = hash_combine(seed, get_md_hash(desc.src_desc[1]));
338     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
339     // Combined hash for binary op desc
340     return seed;
341 }
342 
343 // (De-)Convolution
get_desc_hash(const convolution_desc_t & desc)344 size_t get_desc_hash(const convolution_desc_t &desc) {
345     size_t seed = 0;
346     // Kinds
347     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
348     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
349     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
350     // Memory descriptors
351     seed = hash_combine(seed, get_md_hash(desc.src_desc));
352     seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
353     seed = hash_combine(seed, get_md_hash(desc.weights_desc));
354     seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
355     seed = hash_combine(seed, get_md_hash(desc.bias_desc));
356     seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
357     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
358     seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
359     // Strides, dilates, padding
360     seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
361     seed = get_array_hash(seed, desc.dilates, DNNL_MAX_NDIMS);
362     seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
363     seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
364     // Accumulator type
365     seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
366     // Combined hash for (de-)convolution desc
367     return seed;
368 }
369 
370 // Eltwise
get_desc_hash(const eltwise_desc_t & desc)371 size_t get_desc_hash(const eltwise_desc_t &desc) {
372     size_t seed = 0;
373     // Kinds
374     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
375     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
376     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
377     // Memory descriptors
378     seed = hash_combine(seed, get_md_hash(desc.data_desc));
379     seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
380     // Alpha, beta
381     seed = hash_combine(seed, desc.alpha);
382     seed = hash_combine(seed, desc.beta);
383     // Combined hash for eltwise desc
384     return seed;
385 }
386 
get_desc_hash(const gemm_desc_t & desc)387 size_t get_desc_hash(const gemm_desc_t &desc) {
388     size_t seed = 0;
389     // Kinds
390     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
391     seed = hash_combine(seed, get_md_hash(desc.a_desc));
392     seed = hash_combine(seed, get_md_hash(desc.b_desc));
393     seed = hash_combine(seed, get_md_hash(desc.c_desc));
394     seed = hash_combine(seed, get_md_hash(desc.bias_desc));
395     // Accumulator type
396     seed = hash_combine(seed, static_cast<size_t>(desc.acc_type));
397     // Combined hash for gemm desc
398     return seed;
399 }
400 
get_desc_hash(const inner_product_desc_t & desc)401 size_t get_desc_hash(const inner_product_desc_t &desc) {
402     size_t seed = 0;
403     // Kinds
404     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
405     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
406     // Memory descriptors
407     seed = hash_combine(seed, get_md_hash(desc.src_desc));
408     seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
409     seed = hash_combine(seed, get_md_hash(desc.weights_desc));
410     seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
411     seed = hash_combine(seed, get_md_hash(desc.bias_desc));
412     seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
413     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
414     seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
415     // Accumulator type
416     seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
417     // Combined hash for inner_product desc
418     return seed;
419 }
420 
421 // Layer normalization
get_desc_hash(const layer_normalization_desc_t & desc)422 size_t get_desc_hash(const layer_normalization_desc_t &desc) {
423     size_t seed = 0;
424     // Kinds
425     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
426     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
427     // Memory descriptors
428     seed = hash_combine(seed, get_md_hash(desc.data_desc));
429     seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
430     seed = hash_combine(seed, get_md_hash(desc.data_scaleshift_desc));
431     seed = hash_combine(seed, get_md_hash(desc.diff_data_scaleshift_desc));
432     seed = hash_combine(seed, get_md_hash(desc.stat_desc));
433     // Epsilon
434     seed = hash_combine(seed, desc.layer_norm_epsilon);
435     // Flags
436     seed = hash_combine(seed, desc.flags);
437     // Combined hash for layer_normalization desc
438     return seed;
439 }
440 
get_desc_hash(const lrn_desc_t & desc)441 size_t get_desc_hash(const lrn_desc_t &desc) {
442     size_t seed = 0;
443     // Kinds
444     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
445     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
446     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
447     // Memory descriptors
448     seed = hash_combine(seed, get_md_hash(desc.data_desc));
449     seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
450     // Local size
451     seed = hash_combine(seed, desc.local_size);
452     // Alpha, beta
453     seed = hash_combine(seed, desc.lrn_alpha);
454     seed = hash_combine(seed, desc.lrn_beta);
455     // k
456     seed = hash_combine(seed, desc.lrn_k);
457     // Combined hash for lrn desc
458     return seed;
459 }
460 
get_desc_hash(const matmul_desc_t & desc)461 size_t get_desc_hash(const matmul_desc_t &desc) {
462     size_t seed = 0;
463     // Kinds
464     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
465     // Memory descriptors
466     seed = hash_combine(seed, get_md_hash(desc.src_desc));
467     seed = hash_combine(seed, get_md_hash(desc.weights_desc));
468     seed = hash_combine(seed, get_md_hash(desc.bias_desc));
469     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
470     // Accumulator type
471     seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
472     // Combined hash for matmul op desc
473     return seed;
474 }
475 
get_desc_hash(const pooling_desc_t & desc)476 size_t get_desc_hash(const pooling_desc_t &desc) {
477     size_t seed = 0;
478     // Kinds
479     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
480     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
481     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
482     // Memory descriptors
483     seed = hash_combine(seed, get_md_hash(desc.src_desc));
484     seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
485     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
486     seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
487     // Strides, dilates, padding
488     seed = get_array_hash(seed, desc.strides, DNNL_MAX_NDIMS);
489     seed = get_array_hash(seed, desc.kernel, DNNL_MAX_NDIMS);
490     seed = get_array_hash(seed, desc.padding[0], DNNL_MAX_NDIMS);
491     seed = get_array_hash(seed, desc.padding[1], DNNL_MAX_NDIMS);
492     // Accumulator type
493     seed = hash_combine(seed, static_cast<size_t>(desc.accum_data_type));
494     // Combined hash for pooling desc
495     return seed;
496 }
497 
get_desc_hash(const pooling_v2_desc_t & desc)498 size_t get_desc_hash(const pooling_v2_desc_t &desc) {
499     const auto &v1_desc = *reinterpret_cast<const pooling_desc_t *>(&desc);
500     size_t seed = get_desc_hash(v1_desc);
501     seed = get_array_hash(seed, desc.dilation, DNNL_MAX_NDIMS);
502     return seed;
503 }
504 
get_desc_hash(const prelu_desc_t & desc)505 size_t get_desc_hash(const prelu_desc_t &desc) {
506     size_t seed = 0;
507     // Kinds
508     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
509     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
510     // Memory descriptors
511     seed = hash_combine(seed, get_md_hash(desc.data_desc));
512     seed = hash_combine(seed, get_md_hash(desc.diff_data_desc));
513     seed = hash_combine(seed, get_md_hash(desc.weights_desc));
514     seed = hash_combine(seed, get_md_hash(desc.diff_weights_desc));
515     // Combined hash for pooling desc
516     return seed;
517 }
518 
get_desc_hash(const reduction_desc_t & desc)519 size_t get_desc_hash(const reduction_desc_t &desc) {
520     size_t seed = 0;
521     // Kinds
522     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
523     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
524     // Memory descriptors
525     seed = hash_combine(seed, get_md_hash(desc.src_desc));
526     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
527     // P, eps
528     seed = hash_combine(seed, desc.p);
529     seed = hash_combine(seed, desc.eps);
530     // Combined hash for reduction desc
531     return seed;
532 }
533 
get_desc_hash(const reorder_desc_t & desc)534 size_t get_desc_hash(const reorder_desc_t &desc) {
535     size_t seed = 0;
536     // Kinds
537     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
538     // Memory descriptors
539     seed = hash_combine(seed, get_md_hash(*desc.src_md));
540     seed = hash_combine(seed, get_md_hash(*desc.dst_md));
541     // Kinds of source and destination engines
542     seed = hash_combine(seed, static_cast<size_t>(desc.src_engine_kind));
543     seed = hash_combine(seed, static_cast<size_t>(desc.dst_engine_kind));
544     seed = hash_combine(seed, desc.is_cross_engine);
545     // Combined hash for reorder desc
546     return seed;
547 }
548 
get_desc_hash(const resampling_desc_t & desc)549 size_t get_desc_hash(const resampling_desc_t &desc) {
550     size_t seed = 0;
551     // Kinds
552     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
553     seed = hash_combine(seed, static_cast<size_t>(desc.alg_kind));
554     // Memory descriptors
555     seed = hash_combine(seed, get_md_hash(desc.src_desc));
556     seed = hash_combine(seed, get_md_hash(desc.diff_src_desc));
557     seed = hash_combine(seed, get_md_hash(desc.dst_desc));
558     seed = hash_combine(seed, get_md_hash(desc.diff_dst_desc));
559     // Factors
560     seed = get_array_hash(seed, desc.factors, DNNL_MAX_NDIMS);
561     // Combined hash for resampling op desc
562     return seed;
563 }
564 
get_desc_hash(const rnn_desc_t & desc)565 size_t get_desc_hash(const rnn_desc_t &desc) {
566     size_t seed = 0;
567     // Kinds
568     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
569     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
570     seed = hash_combine(seed, static_cast<size_t>(desc.cell_kind));
571     seed = hash_combine(seed, static_cast<size_t>(desc.direction));
572     // Memory descriptors
573     seed = hash_combine(seed, get_md_hash(desc.src_layer_desc));
574     seed = hash_combine(seed, get_md_hash(desc.src_iter_desc));
575     seed = hash_combine(seed, get_md_hash(desc.src_iter_c_desc));
576     seed = hash_combine(seed, get_md_hash(desc.weights_layer_desc));
577     seed = hash_combine(seed, get_md_hash(desc.weights_iter_desc));
578     seed = hash_combine(seed, get_md_hash(desc.bias_desc));
579     seed = hash_combine(seed, get_md_hash(desc.dst_layer_desc));
580     seed = hash_combine(seed, get_md_hash(desc.dst_iter_desc));
581     seed = hash_combine(seed, get_md_hash(desc.dst_iter_c_desc));
582     seed = hash_combine(seed, get_md_hash(desc.weights_peephole_desc));
583     seed = hash_combine(seed, get_md_hash(desc.weights_projection_desc));
584     seed = hash_combine(seed, get_md_hash(desc.diff_src_layer_desc));
585     seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_desc));
586     seed = hash_combine(seed, get_md_hash(desc.diff_src_iter_c_desc));
587     seed = hash_combine(seed, get_md_hash(desc.diff_weights_layer_desc));
588     seed = hash_combine(seed, get_md_hash(desc.diff_weights_iter_desc));
589     seed = hash_combine(seed, get_md_hash(desc.diff_bias_desc));
590     seed = hash_combine(seed, get_md_hash(desc.diff_dst_layer_desc));
591     seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_desc));
592     seed = hash_combine(seed, get_md_hash(desc.diff_dst_iter_c_desc));
593     seed = hash_combine(seed, get_md_hash(desc.diff_weights_peephole_desc));
594     seed = hash_combine(seed, get_md_hash(desc.diff_weights_projection_desc));
595     // Flags
596     seed = hash_combine(seed, desc.flags);
597     // Activation kind
598     seed = hash_combine(seed, static_cast<size_t>(desc.activation_kind));
599     // Alpha, beta
600     seed = hash_combine(seed, desc.alpha);
601     seed = hash_combine(seed, desc.beta);
602     // Combined hash for rnn desc
603     return seed;
604 }
605 
606 // Shuffle
get_desc_hash(const shuffle_desc_t & desc)607 size_t get_desc_hash(const shuffle_desc_t &desc) {
608     size_t seed = 0;
609     // Kinds
610     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
611     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
612     // Memory descriptors
613     seed = hash_combine(seed, get_md_hash(desc.data_desc));
614     // Axis
615     seed = hash_combine(seed, desc.axis);
616     // Groupe size
617     seed = hash_combine(seed, desc.group_size);
618     // Combined hash for shuffle desc
619     return seed;
620 }
621 
get_desc_hash(const softmax_desc_t & desc)622 size_t get_desc_hash(const softmax_desc_t &desc) {
623     size_t seed = 0;
624     // Kinds
625     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
626     seed = hash_combine(seed, static_cast<size_t>(desc.prop_kind));
627     // Memory descriptors
628     seed = hash_combine(seed, get_md_hash(desc.data_desc));
629     seed = hash_combine(seed, get_md_hash(desc.diff_desc));
630     // Axis
631     seed = hash_combine(seed, desc.softmax_axis);
632     // Combined hash for softmax desc
633     return seed;
634 }
635 
get_desc_hash(const sum_desc_t & desc)636 size_t get_desc_hash(const sum_desc_t &desc) {
637     size_t seed = 0;
638     // Kinds
639     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
640     // Memory descriptors
641     seed = hash_combine(seed, get_md_hash(*desc.dst_md));
642     // N
643     seed = hash_combine(seed, desc.n);
644     // Scales
645     if (desc.scales) { seed = get_array_hash(seed, desc.scales, desc.n); }
646     // Array of mds
647     seed = get_array_hash(seed, desc.src_mds, desc.n);
648     // Combined hash for sum desc
649     return seed;
650 }
651 
get_desc_hash(const zero_pad_desc_t & desc)652 size_t get_desc_hash(const zero_pad_desc_t &desc) {
653     size_t seed = 0;
654     // Kinds
655     seed = hash_combine(seed, static_cast<size_t>(desc.primitive_kind));
656     return seed;
657 }
658 
659 } // namespace primitive_hashing
660 } // namespace impl
661 } // namespace dnnl
662