1 #ifndef _TBLIS_UTIL_TENSOR_HPP_
2 #define _TBLIS_UTIL_TENSOR_HPP_
3 
4 #include <initializer_list>
5 #include <string>
6 
7 #include "util/basic_types.h"
8 
9 #include "external/stl_ext/include/algorithm.hpp"
10 #include "external/stl_ext/include/type_traits.hpp"
11 #include "external/stl_ext/include/vector.hpp"
12 
13 namespace MArray
14 {
15     template <typename T, size_t N>
operator +(const short_vector<T,N> & lhs,const short_vector<T,N> & rhs)16     short_vector<T,N> operator+(const short_vector<T,N>& lhs,
17                                 const short_vector<T,N>& rhs)
18     {
19         short_vector<T,N> res;
20         res.reserve(lhs.size() + rhs.size());
21         res.insert(res.end(), lhs.begin(), lhs.end());
22         res.insert(res.end(), rhs.begin(), rhs.end());
23         return res;
24     }
25 }
26 
27 namespace tblis
28 {
29 
30 namespace detail
31 {
32 
free_idx(label_vector idx)33 inline label_type free_idx(label_vector idx)
34 {
35     if (idx.empty()) return 0;
36 
37     stl_ext::sort(idx);
38 
39     if (idx[0] > 0) return 0;
40 
41     for (unsigned i = 1;i < idx.size();i++)
42     {
43         if (idx[i] > idx[i-1]+1) return idx[i-1]+1;
44     }
45 
46     return idx.back()+1;
47 }
48 
free_idx(const label_vector & idx_A,const label_vector & idx_B)49 inline label_type free_idx(const label_vector& idx_A,
50                            const label_vector& idx_B)
51 {
52     return free_idx(stl_ext::union_of(idx_A, idx_B));
53 }
54 
free_idx(const label_vector & idx_A,const label_vector & idx_B,const label_vector & idx_C)55 inline label_type free_idx(const label_vector& idx_A,
56                            const label_vector& idx_B,
57                            const label_vector& idx_C)
58 {
59     return free_idx(stl_ext::union_of(idx_A, idx_B, idx_C));
60 }
61 
62 template <typename T>
relative_permutation(const T & a,const T & b)63 dim_vector relative_permutation(const T& a, const T& b)
64 {
65     dim_vector perm; perm.reserve(a.size());
66 
67     for (auto& e : b)
68     {
69         for (unsigned i = 0;i < a.size();i++)
70         {
71             if (a[i] == e) perm.push_back(i);
72         }
73     }
74 
75     return perm;
76 }
77 
78 struct sort_by_idx_helper
79 {
80     const label_type* idx;
81 
sort_by_idx_helpertblis::detail::sort_by_idx_helper82     sort_by_idx_helper(const label_type* idx_) : idx(idx_) {}
83 
operator ()tblis::detail::sort_by_idx_helper84     bool operator()(unsigned i, unsigned j) const
85     {
86         return idx[i] < idx[j];
87     }
88 };
89 
sort_by_idx(const label_type * idx)90 inline sort_by_idx_helper sort_by_idx(const label_type* idx)
91 {
92     return sort_by_idx_helper(idx);
93 }
94 
95 template <unsigned N>
96 struct sort_by_stride_helper
97 {
98     std::array<const stride_vector*, N> strides;
99 
sort_by_stride_helpertblis::detail::sort_by_stride_helper100     sort_by_stride_helper(std::initializer_list<const stride_vector*> ilist)
101     {
102         TBLIS_ASSERT(ilist.size() == N);
103         std::copy_n(ilist.begin(), N, strides.begin());
104     }
105 
operator ()tblis::detail::sort_by_stride_helper106     bool operator()(unsigned i, unsigned j) const
107     {
108         auto min_i = (*strides[0])[i];
109         auto min_j = (*strides[0])[j];
110 
111         for (size_t k = 1;k < N;k++)
112         {
113             min_i = std::min(min_i, (*strides[k])[i]);
114             min_j = std::min(min_j, (*strides[k])[j]);
115         }
116 
117         if (min_i < min_j) return true;
118         if (min_i > min_j) return false;
119 
120         for (size_t k = 0;k < N;k++)
121         {
122             auto s_i = (*strides[k])[i];
123             auto s_j = (*strides[k])[j];
124             if (s_i < s_j) return true;
125             if (s_i > s_j) return false;
126         }
127 
128         return false;
129     }
130 };
131 
/** Terminating overload for an empty argument list: reports a size of 0. */
inline size_t check_sizes()
{
    return 0;
}
133 
/**
 * Return the size of the first argument, asserting (via TBLIS_ASSERT) that
 * every following argument has the same size.
 *
 * @param arg   First container; its size() is the value returned.
 * @param args  Remaining containers, checked recursively.
 */
template <typename T, typename... Ts>
size_t check_sizes(const T& arg, const Ts&... args)
{
    size_t first_size = arg.size();

    // Only compare when there is something left to compare against; with no
    // remaining arguments the recursion would reach the overload returning 0.
    if (sizeof...(Ts))
    {
        TBLIS_ASSERT(first_size == check_sizes(args...));
    }

    return first_size;
}
141 
142 template <typename... Strides>
sort_by_stride(const Strides &...strides)143 dim_vector sort_by_stride(const Strides&... strides)
144 {
145     dim_vector idx = range(static_cast<unsigned>(check_sizes(strides...)));
146     std::sort(idx.begin(), idx.end(), sort_by_stride_helper<sizeof...(Strides)>{&strides...});
147     return idx;
148 }
149 
150 template <typename T>
are_congruent_along(const varray_view<const T> & A,const varray_view<const T> & B,unsigned dim)151 bool are_congruent_along(const varray_view<const T>& A,
152                          const varray_view<const T>& B, unsigned dim)
153 {
154     if (A.dimension() < B.dimension()) swap(A, B);
155 
156     unsigned ndim = A.dimension();
157     auto sA = A.strides().begin();
158     auto sB = B.strides().begin();
159     auto lA = A.lengths().begin();
160     auto lB = B.lengths().begin();
161 
162     if (B.dimension() == ndim)
163     {
164         if (!std::equal(sA, sA+ndim, sB)) return false;
165         if (!std::equal(lA, lA+dim, lB)) return false;
166         if (!std::equal(lA+dim+1, lA+ndim, lB+dim+1)) return false;
167     }
168     else if (B.dimension() == ndim-1)
169     {
170         if (!std::equal(sA, sA+dim, sB)) return false;
171         if (!std::equal(sA+dim+1, sA+ndim, sB+dim)) return false;
172         if (!std::equal(lA, lA+dim, lB)) return false;
173         if (!std::equal(lA+dim+1, lA+ndim, lB+dim)) return false;
174     }
175     else
176     {
177         return false;
178     }
179 
180     return true;
181 }
182 
are_compatible(const len_vector & len_A,const stride_vector & stride_A,const len_vector & len_B,const stride_vector & stride_B)183 inline bool are_compatible(const len_vector& len_A,
184                            const stride_vector& stride_A,
185                            const len_vector& len_B,
186                            const stride_vector& stride_B)
187 {
188     TBLIS_ASSERT(len_A.size() == stride_A.size());
189     auto dims_A = detail::sort_by_stride(stride_A);
190     auto len_Ar = stl_ext::permuted(len_A, dims_A);
191     auto stride_Ar = stl_ext::permuted(stride_A, dims_A);
192 
193     TBLIS_ASSERT(len_B.size() == stride_B.size());
194     auto dims_B = detail::sort_by_stride(stride_B);
195     auto len_Br = stl_ext::permuted(len_B, dims_B);
196     auto stride_Br = stl_ext::permuted(stride_B, dims_B);
197 
198     if (stl_ext::prod(len_Ar) != stl_ext::prod(len_Br))
199         return false;
200 
201     viterator<> it_A(len_Ar, stride_Ar);
202     viterator<> it_B(len_Br, stride_Br);
203 
204     stride_type off_A = 0, off_B = 0;
205     while (it_A.next(off_A) + it_B.next(off_B))
206         if (off_A != off_B) return false;
207 
208     return true;
209 }
210 
211 template <typename T>
are_compatible(const varray_view<const T> & A,const varray_view<const T> & B)212 bool are_compatible(const varray_view<const T>& A,
213                     const varray_view<const T>& B)
214 {
215     return A.data() == B.data() &&
216         are_compatible(A.lengths(), A.strides(),
217                        B.lengths(), B.strides());
218 }
219 
/**
 * Compile-time loop used by swap_strides: constructing an instance swaps
 * element I of the two tuples and then continues with element I+1, until
 * I reaches N.
 */
template <size_t I, size_t N, typename... Strides>
struct swap_strides_helper
{
    swap_strides_helper(std::tuple<Strides&...>& strides,
                        std::tuple<Strides...>& oldstrides)
    {
        auto& current = std::get<I>(strides);
        auto& saved = std::get<I>(oldstrides);
        current.swap(saved);

        // Recurse by constructing the helper for the next tuple element.
        swap_strides_helper<I+1, N, Strides...> next(strides, oldstrides);
        (void)next;
    }
};

/** End of the recursion (I == N): nothing left to swap. */
template <size_t N, typename... Strides>
struct swap_strides_helper<N, N, Strides...>
{
    swap_strides_helper(std::tuple<Strides&...>&,
                        std::tuple<Strides...>&)
    {}
};
237 
238 template <typename... Strides>
swap_strides(std::tuple<Strides &...> & strides,std::tuple<Strides...> & oldstrides)239 void swap_strides(std::tuple<Strides&...>& strides,
240                   std::tuple<Strides...>& oldstrides)
241 {
242     swap_strides_helper<0, sizeof...(Strides), Strides...>(strides, oldstrides);
243 }
244 
245 template <size_t I, size_t N, typename... Strides>
246 struct are_contiguous_helper
247 {
operator ()tblis::detail::are_contiguous_helper248     bool operator()(std::tuple<Strides...>& strides,
249                     const len_vector& lengths,
250                     unsigned i, unsigned im1)
251     {
252         return std::get<I>(strides)[i] == std::get<I>(strides)[im1]*lengths[im1] &&
253             are_contiguous_helper<I+1, N, Strides...>()(strides, lengths, i, im1);
254     }
255 };
256 
257 template <size_t N, typename... Strides>
258 struct are_contiguous_helper<N, N, Strides...>
259 {
operator ()tblis::detail::are_contiguous_helper260     bool operator()(std::tuple<Strides...>&,
261                     const len_vector&,
262                     unsigned, unsigned)
263     {
264         return true;
265     }
266 };
267 
268 template <typename... Strides>
are_contiguous(std::tuple<Strides...> & strides,const len_vector & lengths,unsigned i,unsigned im1)269 bool are_contiguous(std::tuple<Strides...>& strides,
270                     const len_vector& lengths,
271                     unsigned i, unsigned im1)
272 {
273     return are_contiguous_helper<0, sizeof...(Strides), Strides...>()(strides, lengths, i, im1);
274 }
275 
/**
 * Compile-time loop used by push_back_strides: constructing an instance
 * appends element i of the I-th container in oldstrides to the I-th container
 * referenced by strides, then continues with I+1 until I reaches N.
 */
template <size_t I, size_t N, typename... Strides>
struct push_back_strides_helper
{
    push_back_strides_helper(std::tuple<Strides&...>& strides,
                             std::tuple<Strides...>& oldstrides, unsigned i)
    {
        auto& source = std::get<I>(oldstrides);
        auto& target = std::get<I>(strides);
        target.push_back(source[i]);

        // Recurse by constructing the helper for the next tuple element.
        push_back_strides_helper<I+1, N, Strides...> next(strides, oldstrides, i);
        (void)next;
    }
};

/** End of the recursion (I == N): nothing left to append. */
template <size_t N, typename... Strides>
struct push_back_strides_helper<N, N, Strides...>
{
    push_back_strides_helper(std::tuple<Strides&...>&,
                             std::tuple<Strides...>&, unsigned)
    {}
};
293 
294 template <typename... Strides>
push_back_strides(std::tuple<Strides &...> & strides,std::tuple<Strides...> & oldstrides,unsigned i)295 void push_back_strides(std::tuple<Strides&...>& strides,
296                        std::tuple<Strides...>& oldstrides, unsigned i)
297 {
298     push_back_strides_helper<0, sizeof...(Strides), Strides...>(strides, oldstrides, i);
299 }
300 
301 template <size_t I, size_t N, typename... Strides>
302 struct are_compatible_helper
303 {
operator ()tblis::detail::are_compatible_helper304     bool operator()(const len_vector& len_A,
305                     const std::tuple<Strides...>& stride_A,
306                     const len_vector& len_B,
307                     const std::tuple<Strides&...>& stride_B)
308     {
309         return are_compatible(len_A, std::get<I>(stride_A),
310                               len_B, std::get<I>(stride_B)) &&
311             are_compatible_helper<I+1, N, Strides...>()(len_A, stride_A,
312                                                         len_B, stride_B);
313     }
314 };
315 
316 template <size_t N, typename... Strides>
317 struct are_compatible_helper<N, N, Strides...>
318 {
operator ()tblis::detail::are_compatible_helper319     bool operator()(const len_vector&,
320                     const std::tuple<Strides...>&,
321                     const len_vector&,
322                     const std::tuple<Strides&...>&)
323     {
324         return true;
325     }
326 };
327 
328 template <typename... Strides>
are_compatible(const len_vector & len_A,const std::tuple<Strides...> & stride_A,const len_vector & len_B,const std::tuple<Strides &...> & stride_B)329 bool are_compatible(const len_vector& len_A,
330                     const std::tuple<Strides...>& stride_A,
331                     const len_vector& len_B,
332                     const std::tuple<Strides&...>& stride_B)
333 {
334     return are_compatible_helper<0, sizeof...(Strides), Strides...>()(
335         len_A, stride_A, len_B, stride_B);
336 }
337 
338 }
339 
340 template <typename... Strides>
fold(len_vector & lengths,label_vector & idx,Strides &..._strides)341 void fold(len_vector& lengths, label_vector& idx,
342           Strides&... _strides)
343 {
344     std::tuple<Strides&...> strides(_strides...);
345 
346     auto ndim = lengths.size();
347     auto inds = detail::sort_by_stride(std::get<0>(strides));
348 
349     label_vector oldidx;
350     len_vector oldlengths;
351     std::tuple<Strides...> oldstrides;
352 
353     oldidx.swap(idx);
354     oldlengths.swap(lengths);
355     detail::swap_strides(strides, oldstrides);
356 
357     for (unsigned i = 0;i < ndim;i++)
358     {
359         if (i != 0 && detail::are_contiguous(oldstrides, oldlengths, inds[i], inds[i-1]))
360         {
361             lengths.back() *= oldlengths[inds[i]];
362         }
363         else
364         {
365             idx.push_back(oldidx[inds[i]]);
366             lengths.push_back(oldlengths[inds[i]]);
367             detail::push_back_strides(strides, oldstrides, inds[i]);
368         }
369     }
370 
371     TBLIS_ASSERT(detail::are_compatible(oldlengths, oldstrides,
372                                         lengths, strides));
373 }
374 
diagonal(unsigned & ndim,const len_type * len_in,const stride_type * stride_in,const label_type * idx_in,len_vector & len_out,stride_vector & stride_out,label_vector & idx_out)375 inline void diagonal(unsigned& ndim,
376                      const len_type* len_in,
377                      const stride_type* stride_in,
378                      const label_type* idx_in,
379                      len_vector& len_out,
380                      stride_vector& stride_out,
381                      label_vector& idx_out)
382 {
383     len_out.reserve(ndim);
384     stride_out.reserve(ndim);
385     idx_out.reserve(ndim);
386 
387     dim_vector inds = range(ndim);
388     stl_ext::sort(inds, detail::sort_by_idx(idx_in));
389 
390     unsigned ndim_in = ndim;
391 
392     ndim = 0;
393     for (unsigned i = 0;i < ndim_in;i++)
394     {
395         if (i == 0 || idx_in[inds[i]] != idx_in[inds[i-1]])
396         {
397             if (len_in[inds[i]] != 1)
398             {
399                 len_out.push_back(len_in[inds[i]]);
400                 stride_out.push_back(stride_in[inds[i]]);
401                 idx_out.push_back(idx_in[inds[i]]);
402                 ndim++;
403             }
404         }
405         else if (len_in[inds[i]] != 1)
406         {
407             TBLIS_ASSERT(len_out[ndim-1] == len_in[inds[i]]);
408             if (len_in[inds[i]] != 1)
409                 stride_out[ndim-1] += stride_in[inds[i]];
410         }
411     }
412 }
413 
414 template <typename T>
matricize(varray_view<const T> A,matrix_view<const T> & AM,unsigned split)415 void matricize(varray_view<const T>  A,
416                matrix_view<const T>& AM, unsigned split)
417 {
418     unsigned ndim = A.dimension();
419     TBLIS_ASSERT(split <= ndim);
420     if (ndim > 0 && A.stride(0) < A.stride(ndim-1))
421     {
422         for (unsigned i = 1;i < split;i++)
423             TBLIS_ASSERT(A.stride(i) == A.stride(i-1)*A.length(i-1));
424         for (unsigned i = split+1;i < ndim;i++)
425             TBLIS_ASSERT(A.stride(i) == A.stride(i-1)*A.length(i-1));
426     }
427     else
428     {
429         for (unsigned i = 0;i+1 < split;i++)
430             TBLIS_ASSERT(A.stride(i) == A.stride(i+1)*A.length(i+1));
431         for (unsigned i = split;i+1 < ndim;i++)
432             TBLIS_ASSERT(A.stride(i) == A.stride(i+1)*A.length(i+1));
433     }
434 
435     len_type m = 1;
436     for (unsigned i = 0;i < split;i++)
437     {
438         m *= A.length(i);
439     }
440 
441     len_type n = 1;
442     for (unsigned i = split;i < ndim;i++)
443     {
444         n *= A.length(i);
445     }
446 
447     stride_type rs, cs;
448 
449     if (ndim == 0)
450     {
451         rs = cs = 1;
452     }
453     else if (m == 1)
454     {
455         rs = n;
456         cs = 1;
457     }
458     else if (n == 1)
459     {
460         rs = 1;
461         cs = m;
462     }
463     else if (A.stride(0) < A.stride(ndim-1))
464     {
465         rs = (split ==    0 ? 1 : A.stride(    0));
466         cs = (split == ndim ? m : A.stride(split));
467     }
468     else
469     {
470         rs = (split ==    0 ? n : A.stride(split-1));
471         cs = (split == ndim ? 1 : A.stride( ndim-1));
472     }
473 
474     AM.reset({m, n}, A.data(), {rs, cs});
475 }
476 
477 template <typename T>
matricize(varray_view<T> A,matrix_view<T> & AM,unsigned split)478 void matricize(varray_view<T>  A,
479                matrix_view<T>& AM, unsigned split)
480 {
481     matricize<T>(A, reinterpret_cast<matrix_view<const T>&>(AM), split);
482 }
483 
unit_dim(const stride_vector & stride,const dim_vector & reorder)484 inline unsigned unit_dim(const stride_vector& stride, const dim_vector& reorder)
485 {
486     for (unsigned i = 0;i < reorder.size();i++)
487         if (stride[reorder[i]] == 1)
488             return i;
489 
490     return reorder.size();
491 }
492 
493 }
494 
495 #endif
496