1 /*******************************************************************************
2 * Copyright 2021 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include "utils/dims.hpp"
18 
get_broadcast_mask(int i_input) const19 int prb_vdims_t::get_broadcast_mask(int i_input) const {
20     int broadcast_mask = 0;
21     for (int d = 0; d < ndims; ++d)
22         broadcast_mask += dst_dims[d] == vdims[i_input][d] ? (1 << d) : 0;
23     return broadcast_mask;
24 }
25 
26 // returns dims with current @p off values using actual values from @p dims
off2dims_idx(const dims_t & dims,int64_t off)27 dims_t off2dims_idx(const dims_t &dims, int64_t off) {
28     dims_t dims_idx;
29     dims_idx.reserve(dims.size());
30 
31     for (int i = (int)dims.size() - 1; i >= 0; --i) {
32         dims_idx.insert(dims_idx.begin(), off % dims[i]);
33         off /= dims[i];
34     }
35     assert(off == 0);
36     return dims_idx;
37 }
38 
dims2str(const dims_t & dims)39 std::string dims2str(const dims_t &dims) {
40     std::string s;
41     if (dims.empty()) return s;
42 
43     s += std::to_string(dims[0]);
44     for (auto d = dims.begin() + 1; d != dims.end(); d++)
45         s += "x" + std::to_string(*d);
46 
47     return s;
48 }
49 
vdims2str(const vdims_t & vdims)50 std::string vdims2str(const vdims_t &vdims) {
51     std::string s;
52     if (vdims.empty()) return s;
53 
54     s += dims2str(vdims[0]);
55     for (auto it = vdims.begin() + 1; it != vdims.end(); it++) {
56         const auto &dims = *it;
57         s += ":" + dims2str(dims);
58     }
59     return s;
60 }
61 
operator <<(std::ostream & s,const prb_dims_t & prb_dims)62 std::ostream &operator<<(std::ostream &s, const prb_dims_t &prb_dims) {
63     s << dims2str(prb_dims.dims);
64     if (!prb_dims.name.empty()) s << "_n" << prb_dims.name;
65     return s;
66 }
67 
operator <<(std::ostream & s,const prb_vdims_t & prb_vdims)68 std::ostream &operator<<(std::ostream &s, const prb_vdims_t &prb_vdims) {
69     s << vdims2str(prb_vdims.vdims);
70     if (!prb_vdims.name.empty()) s << "_n" << prb_vdims.name;
71     return s;
72 }
73