1 #ifndef _TBLIS_UTIL_TENSOR_HPP_
2 #define _TBLIS_UTIL_TENSOR_HPP_
3
4 #include <initializer_list>
5 #include <string>
6
7 #include "util/basic_types.h"
8
9 #include "external/stl_ext/include/algorithm.hpp"
10 #include "external/stl_ext/include/type_traits.hpp"
11 #include "external/stl_ext/include/vector.hpp"
12
13 namespace MArray
14 {
15 template <typename T, size_t N>
operator +(const short_vector<T,N> & lhs,const short_vector<T,N> & rhs)16 short_vector<T,N> operator+(const short_vector<T,N>& lhs,
17 const short_vector<T,N>& rhs)
18 {
19 short_vector<T,N> res;
20 res.reserve(lhs.size() + rhs.size());
21 res.insert(res.end(), lhs.begin(), lhs.end());
22 res.insert(res.end(), rhs.begin(), rhs.end());
23 return res;
24 }
25 }
26
27 namespace tblis
28 {
29
30 namespace detail
31 {
32
free_idx(label_vector idx)33 inline label_type free_idx(label_vector idx)
34 {
35 if (idx.empty()) return 0;
36
37 stl_ext::sort(idx);
38
39 if (idx[0] > 0) return 0;
40
41 for (unsigned i = 1;i < idx.size();i++)
42 {
43 if (idx[i] > idx[i-1]+1) return idx[i-1]+1;
44 }
45
46 return idx.back()+1;
47 }
48
free_idx(const label_vector & idx_A,const label_vector & idx_B)49 inline label_type free_idx(const label_vector& idx_A,
50 const label_vector& idx_B)
51 {
52 return free_idx(stl_ext::union_of(idx_A, idx_B));
53 }
54
free_idx(const label_vector & idx_A,const label_vector & idx_B,const label_vector & idx_C)55 inline label_type free_idx(const label_vector& idx_A,
56 const label_vector& idx_B,
57 const label_vector& idx_C)
58 {
59 return free_idx(stl_ext::union_of(idx_A, idx_B, idx_C));
60 }
61
62 template <typename T>
relative_permutation(const T & a,const T & b)63 dim_vector relative_permutation(const T& a, const T& b)
64 {
65 dim_vector perm; perm.reserve(a.size());
66
67 for (auto& e : b)
68 {
69 for (unsigned i = 0;i < a.size();i++)
70 {
71 if (a[i] == e) perm.push_back(i);
72 }
73 }
74
75 return perm;
76 }
77
78 struct sort_by_idx_helper
79 {
80 const label_type* idx;
81
sort_by_idx_helpertblis::detail::sort_by_idx_helper82 sort_by_idx_helper(const label_type* idx_) : idx(idx_) {}
83
operator ()tblis::detail::sort_by_idx_helper84 bool operator()(unsigned i, unsigned j) const
85 {
86 return idx[i] < idx[j];
87 }
88 };
89
sort_by_idx(const label_type * idx)90 inline sort_by_idx_helper sort_by_idx(const label_type* idx)
91 {
92 return sort_by_idx_helper(idx);
93 }
94
95 template <unsigned N>
96 struct sort_by_stride_helper
97 {
98 std::array<const stride_vector*, N> strides;
99
sort_by_stride_helpertblis::detail::sort_by_stride_helper100 sort_by_stride_helper(std::initializer_list<const stride_vector*> ilist)
101 {
102 TBLIS_ASSERT(ilist.size() == N);
103 std::copy_n(ilist.begin(), N, strides.begin());
104 }
105
operator ()tblis::detail::sort_by_stride_helper106 bool operator()(unsigned i, unsigned j) const
107 {
108 auto min_i = (*strides[0])[i];
109 auto min_j = (*strides[0])[j];
110
111 for (size_t k = 1;k < N;k++)
112 {
113 min_i = std::min(min_i, (*strides[k])[i]);
114 min_j = std::min(min_j, (*strides[k])[j]);
115 }
116
117 if (min_i < min_j) return true;
118 if (min_i > min_j) return false;
119
120 for (size_t k = 0;k < N;k++)
121 {
122 auto s_i = (*strides[k])[i];
123 auto s_j = (*strides[k])[j];
124 if (s_i < s_j) return true;
125 if (s_i > s_j) return false;
126 }
127
128 return false;
129 }
130 };
131
/** Base case for an empty argument list: reports a size of 0. */
inline size_t check_sizes() { return 0; }

/**
 * Return the size of the first argument, asserting that every remaining
 * argument has the same size.
 *
 * @param arg   First container; its size() is the value returned.
 * @param args  Further containers whose sizes are checked via TBLIS_ASSERT.
 */
template <typename T, typename... Ts>
size_t check_sizes(const T& arg, const Ts&... args)
{
    size_t first_size = arg.size();

    // Only compare when there is a tail; the empty overload returns 0, which
    // would otherwise be compared against a real size.
    if (sizeof...(Ts))
    {
        TBLIS_ASSERT(first_size == check_sizes(args...));
    }

    return first_size;
}
141
142 template <typename... Strides>
sort_by_stride(const Strides &...strides)143 dim_vector sort_by_stride(const Strides&... strides)
144 {
145 dim_vector idx = range(static_cast<unsigned>(check_sizes(strides...)));
146 std::sort(idx.begin(), idx.end(), sort_by_stride_helper<sizeof...(Strides)>{&strides...});
147 return idx;
148 }
149
150 template <typename T>
are_congruent_along(const varray_view<const T> & A,const varray_view<const T> & B,unsigned dim)151 bool are_congruent_along(const varray_view<const T>& A,
152 const varray_view<const T>& B, unsigned dim)
153 {
154 if (A.dimension() < B.dimension()) swap(A, B);
155
156 unsigned ndim = A.dimension();
157 auto sA = A.strides().begin();
158 auto sB = B.strides().begin();
159 auto lA = A.lengths().begin();
160 auto lB = B.lengths().begin();
161
162 if (B.dimension() == ndim)
163 {
164 if (!std::equal(sA, sA+ndim, sB)) return false;
165 if (!std::equal(lA, lA+dim, lB)) return false;
166 if (!std::equal(lA+dim+1, lA+ndim, lB+dim+1)) return false;
167 }
168 else if (B.dimension() == ndim-1)
169 {
170 if (!std::equal(sA, sA+dim, sB)) return false;
171 if (!std::equal(sA+dim+1, sA+ndim, sB+dim)) return false;
172 if (!std::equal(lA, lA+dim, lB)) return false;
173 if (!std::equal(lA+dim+1, lA+ndim, lB+dim)) return false;
174 }
175 else
176 {
177 return false;
178 }
179
180 return true;
181 }
182
are_compatible(const len_vector & len_A,const stride_vector & stride_A,const len_vector & len_B,const stride_vector & stride_B)183 inline bool are_compatible(const len_vector& len_A,
184 const stride_vector& stride_A,
185 const len_vector& len_B,
186 const stride_vector& stride_B)
187 {
188 TBLIS_ASSERT(len_A.size() == stride_A.size());
189 auto dims_A = detail::sort_by_stride(stride_A);
190 auto len_Ar = stl_ext::permuted(len_A, dims_A);
191 auto stride_Ar = stl_ext::permuted(stride_A, dims_A);
192
193 TBLIS_ASSERT(len_B.size() == stride_B.size());
194 auto dims_B = detail::sort_by_stride(stride_B);
195 auto len_Br = stl_ext::permuted(len_B, dims_B);
196 auto stride_Br = stl_ext::permuted(stride_B, dims_B);
197
198 if (stl_ext::prod(len_Ar) != stl_ext::prod(len_Br))
199 return false;
200
201 viterator<> it_A(len_Ar, stride_Ar);
202 viterator<> it_B(len_Br, stride_Br);
203
204 stride_type off_A = 0, off_B = 0;
205 while (it_A.next(off_A) + it_B.next(off_B))
206 if (off_A != off_B) return false;
207
208 return true;
209 }
210
211 template <typename T>
are_compatible(const varray_view<const T> & A,const varray_view<const T> & B)212 bool are_compatible(const varray_view<const T>& A,
213 const varray_view<const T>& B)
214 {
215 return A.data() == B.data() &&
216 are_compatible(A.lengths(), A.strides(),
217 B.lengths(), B.strides());
218 }
219
/**
 * Recursive helper: on construction, swaps element I of the two tuples and
 * then constructs the helper for element I+1, until I reaches N.
 */
template <size_t I, size_t N, typename... Strides>
struct swap_strides_helper
{
    swap_strides_helper(std::tuple<Strides&...>& strides,
                        std::tuple<Strides...>& oldstrides)
    {
        auto& current = std::get<I>(strides);
        auto& previous = std::get<I>(oldstrides);
        current.swap(previous);

        // Constructing the next helper handles the remaining elements.
        swap_strides_helper<I+1, N, Strides...> rest(strides, oldstrides);
        (void)rest;
    }
};

/** Terminating case (I == N): nothing left to swap. */
template <size_t N, typename... Strides>
struct swap_strides_helper<N, N, Strides...>
{
    swap_strides_helper(std::tuple<Strides&...>&,
                        std::tuple<Strides...>&) {}
};
237
238 template <typename... Strides>
swap_strides(std::tuple<Strides &...> & strides,std::tuple<Strides...> & oldstrides)239 void swap_strides(std::tuple<Strides&...>& strides,
240 std::tuple<Strides...>& oldstrides)
241 {
242 swap_strides_helper<0, sizeof...(Strides), Strides...>(strides, oldstrides);
243 }
244
245 template <size_t I, size_t N, typename... Strides>
246 struct are_contiguous_helper
247 {
operator ()tblis::detail::are_contiguous_helper248 bool operator()(std::tuple<Strides...>& strides,
249 const len_vector& lengths,
250 unsigned i, unsigned im1)
251 {
252 return std::get<I>(strides)[i] == std::get<I>(strides)[im1]*lengths[im1] &&
253 are_contiguous_helper<I+1, N, Strides...>()(strides, lengths, i, im1);
254 }
255 };
256
257 template <size_t N, typename... Strides>
258 struct are_contiguous_helper<N, N, Strides...>
259 {
operator ()tblis::detail::are_contiguous_helper260 bool operator()(std::tuple<Strides...>&,
261 const len_vector&,
262 unsigned, unsigned)
263 {
264 return true;
265 }
266 };
267
268 template <typename... Strides>
are_contiguous(std::tuple<Strides...> & strides,const len_vector & lengths,unsigned i,unsigned im1)269 bool are_contiguous(std::tuple<Strides...>& strides,
270 const len_vector& lengths,
271 unsigned i, unsigned im1)
272 {
273 return are_contiguous_helper<0, sizeof...(Strides), Strides...>()(strides, lengths, i, im1);
274 }
275
/**
 * Recursive helper: on construction, appends element i of oldstrides' I-th
 * vector to strides' I-th vector, then constructs the helper for I+1.
 */
template <size_t I, size_t N, typename... Strides>
struct push_back_strides_helper
{
    push_back_strides_helper(std::tuple<Strides&...>& strides,
                             std::tuple<Strides...>& oldstrides, unsigned i)
    {
        auto& target = std::get<I>(strides);
        auto& source = std::get<I>(oldstrides);
        target.push_back(source[i]);

        // Constructing the next helper handles the remaining elements.
        push_back_strides_helper<I+1, N, Strides...> rest(strides, oldstrides, i);
        (void)rest;
    }
};

/** Terminating case (I == N): nothing left to append. */
template <size_t N, typename... Strides>
struct push_back_strides_helper<N, N, Strides...>
{
    push_back_strides_helper(std::tuple<Strides&...>&,
                             std::tuple<Strides...>&, unsigned) {}
};
293
294 template <typename... Strides>
push_back_strides(std::tuple<Strides &...> & strides,std::tuple<Strides...> & oldstrides,unsigned i)295 void push_back_strides(std::tuple<Strides&...>& strides,
296 std::tuple<Strides...>& oldstrides, unsigned i)
297 {
298 push_back_strides_helper<0, sizeof...(Strides), Strides...>(strides, oldstrides, i);
299 }
300
301 template <size_t I, size_t N, typename... Strides>
302 struct are_compatible_helper
303 {
operator ()tblis::detail::are_compatible_helper304 bool operator()(const len_vector& len_A,
305 const std::tuple<Strides...>& stride_A,
306 const len_vector& len_B,
307 const std::tuple<Strides&...>& stride_B)
308 {
309 return are_compatible(len_A, std::get<I>(stride_A),
310 len_B, std::get<I>(stride_B)) &&
311 are_compatible_helper<I+1, N, Strides...>()(len_A, stride_A,
312 len_B, stride_B);
313 }
314 };
315
316 template <size_t N, typename... Strides>
317 struct are_compatible_helper<N, N, Strides...>
318 {
operator ()tblis::detail::are_compatible_helper319 bool operator()(const len_vector&,
320 const std::tuple<Strides...>&,
321 const len_vector&,
322 const std::tuple<Strides&...>&)
323 {
324 return true;
325 }
326 };
327
328 template <typename... Strides>
are_compatible(const len_vector & len_A,const std::tuple<Strides...> & stride_A,const len_vector & len_B,const std::tuple<Strides &...> & stride_B)329 bool are_compatible(const len_vector& len_A,
330 const std::tuple<Strides...>& stride_A,
331 const len_vector& len_B,
332 const std::tuple<Strides&...>& stride_B)
333 {
334 return are_compatible_helper<0, sizeof...(Strides), Strides...>()(
335 len_A, stride_A, len_B, stride_B);
336 }
337
338 }
339
340 template <typename... Strides>
fold(len_vector & lengths,label_vector & idx,Strides &..._strides)341 void fold(len_vector& lengths, label_vector& idx,
342 Strides&... _strides)
343 {
344 std::tuple<Strides&...> strides(_strides...);
345
346 auto ndim = lengths.size();
347 auto inds = detail::sort_by_stride(std::get<0>(strides));
348
349 label_vector oldidx;
350 len_vector oldlengths;
351 std::tuple<Strides...> oldstrides;
352
353 oldidx.swap(idx);
354 oldlengths.swap(lengths);
355 detail::swap_strides(strides, oldstrides);
356
357 for (unsigned i = 0;i < ndim;i++)
358 {
359 if (i != 0 && detail::are_contiguous(oldstrides, oldlengths, inds[i], inds[i-1]))
360 {
361 lengths.back() *= oldlengths[inds[i]];
362 }
363 else
364 {
365 idx.push_back(oldidx[inds[i]]);
366 lengths.push_back(oldlengths[inds[i]]);
367 detail::push_back_strides(strides, oldstrides, inds[i]);
368 }
369 }
370
371 TBLIS_ASSERT(detail::are_compatible(oldlengths, oldstrides,
372 lengths, strides));
373 }
374
diagonal(unsigned & ndim,const len_type * len_in,const stride_type * stride_in,const label_type * idx_in,len_vector & len_out,stride_vector & stride_out,label_vector & idx_out)375 inline void diagonal(unsigned& ndim,
376 const len_type* len_in,
377 const stride_type* stride_in,
378 const label_type* idx_in,
379 len_vector& len_out,
380 stride_vector& stride_out,
381 label_vector& idx_out)
382 {
383 len_out.reserve(ndim);
384 stride_out.reserve(ndim);
385 idx_out.reserve(ndim);
386
387 dim_vector inds = range(ndim);
388 stl_ext::sort(inds, detail::sort_by_idx(idx_in));
389
390 unsigned ndim_in = ndim;
391
392 ndim = 0;
393 for (unsigned i = 0;i < ndim_in;i++)
394 {
395 if (i == 0 || idx_in[inds[i]] != idx_in[inds[i-1]])
396 {
397 if (len_in[inds[i]] != 1)
398 {
399 len_out.push_back(len_in[inds[i]]);
400 stride_out.push_back(stride_in[inds[i]]);
401 idx_out.push_back(idx_in[inds[i]]);
402 ndim++;
403 }
404 }
405 else if (len_in[inds[i]] != 1)
406 {
407 TBLIS_ASSERT(len_out[ndim-1] == len_in[inds[i]]);
408 if (len_in[inds[i]] != 1)
409 stride_out[ndim-1] += stride_in[inds[i]];
410 }
411 }
412 }
413
414 template <typename T>
matricize(varray_view<const T> A,matrix_view<const T> & AM,unsigned split)415 void matricize(varray_view<const T> A,
416 matrix_view<const T>& AM, unsigned split)
417 {
418 unsigned ndim = A.dimension();
419 TBLIS_ASSERT(split <= ndim);
420 if (ndim > 0 && A.stride(0) < A.stride(ndim-1))
421 {
422 for (unsigned i = 1;i < split;i++)
423 TBLIS_ASSERT(A.stride(i) == A.stride(i-1)*A.length(i-1));
424 for (unsigned i = split+1;i < ndim;i++)
425 TBLIS_ASSERT(A.stride(i) == A.stride(i-1)*A.length(i-1));
426 }
427 else
428 {
429 for (unsigned i = 0;i+1 < split;i++)
430 TBLIS_ASSERT(A.stride(i) == A.stride(i+1)*A.length(i+1));
431 for (unsigned i = split;i+1 < ndim;i++)
432 TBLIS_ASSERT(A.stride(i) == A.stride(i+1)*A.length(i+1));
433 }
434
435 len_type m = 1;
436 for (unsigned i = 0;i < split;i++)
437 {
438 m *= A.length(i);
439 }
440
441 len_type n = 1;
442 for (unsigned i = split;i < ndim;i++)
443 {
444 n *= A.length(i);
445 }
446
447 stride_type rs, cs;
448
449 if (ndim == 0)
450 {
451 rs = cs = 1;
452 }
453 else if (m == 1)
454 {
455 rs = n;
456 cs = 1;
457 }
458 else if (n == 1)
459 {
460 rs = 1;
461 cs = m;
462 }
463 else if (A.stride(0) < A.stride(ndim-1))
464 {
465 rs = (split == 0 ? 1 : A.stride( 0));
466 cs = (split == ndim ? m : A.stride(split));
467 }
468 else
469 {
470 rs = (split == 0 ? n : A.stride(split-1));
471 cs = (split == ndim ? 1 : A.stride( ndim-1));
472 }
473
474 AM.reset({m, n}, A.data(), {rs, cs});
475 }
476
477 template <typename T>
matricize(varray_view<T> A,matrix_view<T> & AM,unsigned split)478 void matricize(varray_view<T> A,
479 matrix_view<T>& AM, unsigned split)
480 {
481 matricize<T>(A, reinterpret_cast<matrix_view<const T>&>(AM), split);
482 }
483
unit_dim(const stride_vector & stride,const dim_vector & reorder)484 inline unsigned unit_dim(const stride_vector& stride, const dim_vector& reorder)
485 {
486 for (unsigned i = 0;i < reorder.size();i++)
487 if (stride[reorder[i]] == 1)
488 return i;
489
490 return reorder.size();
491 }
492
493 }
494
495 #endif
496