1 #ifndef _TBLIS_INTERNAL_1T_DPD_UTIL_HPP_
2 #define _TBLIS_INTERNAL_1T_DPD_UTIL_HPP_
3
4 #include "util/basic_types.h"
5 #include "util/tensor.hpp"
6 #include "internal/1t/dense/add.hpp"
7 #include "internal/3t/dpd/mult.hpp"
8
9 namespace tblis
10 {
11 namespace internal
12 {
13
14 class irrep_iterator
15 {
16 protected:
17 const unsigned irrep_;
18 const unsigned irrep_bits_;
19 const unsigned irrep_mask_;
20 viterator<0> it_;
21
22 public:
irrep_iterator(unsigned irrep,unsigned nirrep,unsigned ndim)23 irrep_iterator(unsigned irrep, unsigned nirrep, unsigned ndim)
24 : irrep_(irrep), irrep_bits_(__builtin_popcount(nirrep-1)),
25 irrep_mask_ (nirrep-1), it_(irrep_vector(ndim ? ndim-1 : 0, nirrep)) {}
26
next()27 bool next()
28 {
29 return it_.next();
30 }
31
nblock() const32 unsigned nblock() const
33 {
34 return 1u << (irrep_bits_*it_.dimension());
35 }
36
block(unsigned b)37 void block(unsigned b)
38 {
39 irrep_vector irreps(it_.dimension());
40
41 for (unsigned i = 0;i < it_.dimension();i++)
42 {
43 irreps[i] = b & irrep_mask_;
44 b >>= irrep_bits_;
45 }
46
47 it_.position(irreps);
48 }
49
reset()50 void reset()
51 {
52 it_.reset();
53 }
54
irrep(unsigned dim)55 unsigned irrep(unsigned dim)
56 {
57 TBLIS_ASSERT(dim <= it_.dimension());
58
59 if (dim == 0)
60 {
61 unsigned irr0 = irrep_;
62 for (unsigned irr : it_.position()) irr0 ^= irr;
63 return irr0;
64 }
65
66 return it_.position()[dim-1];
67 }
68 };
69
70 template <typename T, typename U>
block_to_full(const communicator & comm,const config & cfg,const dpd_varray_view<T> & A,varray<U> & A2)71 void block_to_full(const communicator& comm, const config& cfg,
72 const dpd_varray_view<T>& A, varray<U>& A2)
73 {
74 unsigned nirrep = A.num_irreps();
75 unsigned ndim_A = A.dimension();
76
77 len_vector len_A(ndim_A);
78 matrix<len_type> off_A({ndim_A, nirrep});
79 for (unsigned i = 0;i < ndim_A;i++)
80 {
81 for (unsigned irrep = 0;irrep < nirrep;irrep++)
82 {
83 off_A[i][irrep] = len_A[i];
84 len_A[i] += A.length(i, irrep);
85 }
86 }
87
88 if (comm.master()) A2.reset(len_A);
89 comm.barrier();
90
91 A.for_each_block(
92 [&](const varray_view<T>& local_A, const irrep_vector& irreps_A)
93 {
94 auto data_A2 = A2.data();
95 for (unsigned i = 0;i < ndim_A;i++)
96 data_A2 += off_A[i][irreps_A[i]]*A2.stride(i);
97
98 add<U>(comm, cfg, {}, {}, local_A.lengths(),
99 1, false, local_A.data(), {}, local_A.strides(),
100 0, false, data_A2, {}, A2.strides());
101 });
102 }
103
104 template <typename T, typename U>
full_to_block(const communicator & comm,const config & cfg,const varray<U> & A2,const dpd_varray_view<T> & A)105 void full_to_block(const communicator& comm, const config& cfg,
106 const varray<U>& A2, const dpd_varray_view<T>& A)
107 {
108 unsigned nirrep = A.num_irreps();
109 unsigned ndim_A = A.dimension();
110
111 matrix<len_type> off_A({ndim_A, nirrep});
112 for (unsigned i = 0;i < ndim_A;i++)
113 {
114 len_type off = 0;
115 for (unsigned irrep = 0;irrep < nirrep;irrep++)
116 {
117 off_A[i][irrep] = off;
118 off += A.length(i, irrep);
119 }
120 }
121
122 A.for_each_block(
123 [&](const varray_view<T>& local_A, const irrep_vector& irreps_A)
124 {
125 auto data_A2 = A2.data();
126 for (unsigned i = 0;i < ndim_A;i++)
127 data_A2 += off_A[i][irreps_A[i]]*A2.stride(i);
128
129 add<U>(comm, cfg, {}, {}, local_A.lengths(),
130 1, false, data_A2, {}, A2.strides(),
131 0, false, local_A.data(), {}, local_A.strides());
132 });
133 }
134
135 template <unsigned I, size_t N>
dense_total_lengths_and_strides_helper(std::array<len_vector,N> &,std::array<stride_vector,N> &)136 void dense_total_lengths_and_strides_helper(std::array<len_vector,N>&,
137 std::array<stride_vector,N>&) {}
138
139 template <unsigned I, size_t N, typename Array, typename... Args>
dense_total_lengths_and_strides_helper(std::array<len_vector,N> & len,std::array<stride_vector,N> & stride,const Array & A,const dim_vector &,const Args &...args)140 void dense_total_lengths_and_strides_helper(std::array<len_vector,N>& len,
141 std::array<stride_vector,N>& stride,
142 const Array& A,
143 const dim_vector&, const Args&... args)
144 {
145 unsigned ndim = A.permutation().size();
146 unsigned nirrep = A.num_irreps();
147
148 len[I].resize(ndim);
149 stride[I].resize(ndim);
150
151 for (unsigned j = 0;j < ndim;j++)
152 {
153 for (unsigned irrep = 0;irrep < nirrep;irrep++)
154 len[I][j] += A.length(j, irrep);
155 }
156
157 auto iperm = detail::inverse_permutation(A.permutation());
158 stride[I][iperm[0]] = 1;
159 for (unsigned j = 1;j < ndim;j++)
160 {
161 stride[I][iperm[j]] = stride[I][iperm[j-1]] * len[I][iperm[j-1]];
162 }
163
164 dense_total_lengths_and_strides_helper<I+1>(len, stride, args...);
165 }
166
167 template <size_t N, typename... Args>
dense_total_lengths_and_strides(std::array<len_vector,N> & len,std::array<stride_vector,N> & stride,const Args &...args)168 void dense_total_lengths_and_strides(std::array<len_vector,N>& len,
169 std::array<stride_vector,N>& stride,
170 const Args&... args)
171 {
172 dense_total_lengths_and_strides_helper<0>(len, stride, args...);
173 }
174
175 template <typename T>
is_block_empty(const dpd_varray_view<T> & A,const irrep_vector & irreps)176 bool is_block_empty(const dpd_varray_view<T>& A, const irrep_vector& irreps)
177 {
178 unsigned irrep = 0;
179
180 for (unsigned i = 0;i < A.dimension();i++)
181 {
182 irrep ^= irreps[i];
183 if (!A.length(i, irreps[i])) return true;
184 }
185
186 return irrep != A.irrep();
187 }
188
/**
 * Recursion terminator for the variadic overload: with no vectors left to
 * update, simply hand back @p irrep.
 *
 * @param dim    Unused here; kept so both overloads share a call shape.
 * @param irrep  Value returned unchanged.
 */
inline unsigned assign_irrep(unsigned dim, unsigned irrep)
{
    static_cast<void>(dim);
    return irrep;
}
193
194 template <typename... Args>
assign_irrep(unsigned dim,unsigned irrep,irrep_vector & irreps,const dim_vector & idx,Args &...args)195 unsigned assign_irrep(unsigned dim, unsigned irrep,
196 irrep_vector& irreps,
197 const dim_vector& idx,
198 Args&... args)
199 {
200 irreps[idx[dim]] = irrep;
201 return assign_irrep(dim, irrep, args...);
202 }
203
204 template <typename... Args>
assign_irreps(unsigned ndim,unsigned irrep,unsigned nirrep,stride_type block,Args &...args)205 void assign_irreps(unsigned ndim, unsigned irrep, unsigned nirrep,
206 stride_type block, Args&... args)
207 {
208 unsigned mask = nirrep-1;
209 unsigned shift = (nirrep>1) + (nirrep>2) + (nirrep>4);
210
211 unsigned irrep0 = irrep;
212 for (unsigned i = 1;i < ndim;i++)
213 {
214 irrep0 ^= assign_irrep(i, block & mask, args...);
215 block >>= shift;
216 }
217 if (ndim) assign_irrep(0, irrep0, args...);
218 }
219
220 }
221 }
222
223 #endif
224