1 #ifndef _TBLIS_INTERNAL_1T_DPD_UTIL_HPP_
2 #define _TBLIS_INTERNAL_1T_DPD_UTIL_HPP_
3 
4 #include "util/basic_types.h"
5 #include "util/tensor.hpp"
6 #include "internal/1t/dense/add.hpp"
7 #include "internal/3t/dpd/mult.hpp"
8 
9 namespace tblis
10 {
11 namespace internal
12 {
13 
14 class irrep_iterator
15 {
16     protected:
17         const unsigned irrep_;
18         const unsigned irrep_bits_;
19         const unsigned irrep_mask_;
20         viterator<0> it_;
21 
22     public:
irrep_iterator(unsigned irrep,unsigned nirrep,unsigned ndim)23         irrep_iterator(unsigned irrep, unsigned nirrep, unsigned ndim)
24         : irrep_(irrep), irrep_bits_(__builtin_popcount(nirrep-1)),
25           irrep_mask_ (nirrep-1), it_(irrep_vector(ndim ? ndim-1 : 0, nirrep)) {}
26 
next()27         bool next()
28         {
29             return it_.next();
30         }
31 
nblock() const32         unsigned nblock() const
33         {
34             return 1u << (irrep_bits_*it_.dimension());
35         }
36 
block(unsigned b)37         void block(unsigned b)
38         {
39             irrep_vector irreps(it_.dimension());
40 
41             for (unsigned i = 0;i < it_.dimension();i++)
42             {
43                 irreps[i] = b & irrep_mask_;
44                 b >>= irrep_bits_;
45             }
46 
47             it_.position(irreps);
48         }
49 
reset()50         void reset()
51         {
52             it_.reset();
53         }
54 
irrep(unsigned dim)55         unsigned irrep(unsigned dim)
56         {
57             TBLIS_ASSERT(dim <= it_.dimension());
58 
59             if (dim == 0)
60             {
61                 unsigned irr0 = irrep_;
62                 for (unsigned irr : it_.position()) irr0 ^= irr;
63                 return irr0;
64             }
65 
66             return it_.position()[dim-1];
67         }
68 };
69 
70 template <typename T, typename U>
block_to_full(const communicator & comm,const config & cfg,const dpd_varray_view<T> & A,varray<U> & A2)71 void block_to_full(const communicator& comm, const config& cfg,
72                    const dpd_varray_view<T>& A, varray<U>& A2)
73 {
74     unsigned nirrep = A.num_irreps();
75     unsigned ndim_A = A.dimension();
76 
77     len_vector len_A(ndim_A);
78     matrix<len_type> off_A({ndim_A, nirrep});
79     for (unsigned i = 0;i < ndim_A;i++)
80     {
81         for (unsigned irrep = 0;irrep < nirrep;irrep++)
82         {
83             off_A[i][irrep] = len_A[i];
84             len_A[i] += A.length(i, irrep);
85         }
86     }
87 
88     if (comm.master()) A2.reset(len_A);
89     comm.barrier();
90 
91     A.for_each_block(
92     [&](const varray_view<T>& local_A, const irrep_vector& irreps_A)
93     {
94         auto data_A2 = A2.data();
95         for (unsigned i = 0;i < ndim_A;i++)
96             data_A2 += off_A[i][irreps_A[i]]*A2.stride(i);
97 
98         add<U>(comm, cfg, {}, {}, local_A.lengths(),
99                1, false, local_A.data(), {}, local_A.strides(),
100                0, false,        data_A2, {},      A2.strides());
101     });
102 }
103 
104 template <typename T, typename U>
full_to_block(const communicator & comm,const config & cfg,const varray<U> & A2,const dpd_varray_view<T> & A)105 void full_to_block(const communicator& comm, const config& cfg,
106                    const varray<U>& A2, const dpd_varray_view<T>& A)
107 {
108     unsigned nirrep = A.num_irreps();
109     unsigned ndim_A = A.dimension();
110 
111     matrix<len_type> off_A({ndim_A, nirrep});
112     for (unsigned i = 0;i < ndim_A;i++)
113     {
114         len_type off = 0;
115         for (unsigned irrep = 0;irrep < nirrep;irrep++)
116         {
117             off_A[i][irrep] = off;
118             off += A.length(i, irrep);
119         }
120     }
121 
122     A.for_each_block(
123     [&](const varray_view<T>& local_A, const irrep_vector& irreps_A)
124     {
125         auto data_A2 = A2.data();
126         for (unsigned i = 0;i < ndim_A;i++)
127             data_A2 += off_A[i][irreps_A[i]]*A2.stride(i);
128 
129         add<U>(comm, cfg, {}, {}, local_A.lengths(),
130                1, false,        data_A2, {},      A2.strides(),
131                0, false, local_A.data(), {}, local_A.strides());
132     });
133 }
134 
135 template <unsigned I, size_t N>
dense_total_lengths_and_strides_helper(std::array<len_vector,N> &,std::array<stride_vector,N> &)136 void dense_total_lengths_and_strides_helper(std::array<len_vector,N>&,
137                                             std::array<stride_vector,N>&) {}
138 
139 template <unsigned I, size_t N, typename Array, typename... Args>
dense_total_lengths_and_strides_helper(std::array<len_vector,N> & len,std::array<stride_vector,N> & stride,const Array & A,const dim_vector &,const Args &...args)140 void dense_total_lengths_and_strides_helper(std::array<len_vector,N>& len,
141                                             std::array<stride_vector,N>& stride,
142                                             const Array& A,
143                                             const dim_vector&, const Args&... args)
144 {
145     unsigned ndim = A.permutation().size();
146     unsigned nirrep = A.num_irreps();
147 
148     len[I].resize(ndim);
149     stride[I].resize(ndim);
150 
151     for (unsigned j = 0;j < ndim;j++)
152     {
153         for (unsigned irrep = 0;irrep < nirrep;irrep++)
154             len[I][j] += A.length(j, irrep);
155     }
156 
157     auto iperm = detail::inverse_permutation(A.permutation());
158     stride[I][iperm[0]] = 1;
159     for (unsigned j = 1;j < ndim;j++)
160     {
161         stride[I][iperm[j]] = stride[I][iperm[j-1]] * len[I][iperm[j-1]];
162     }
163 
164     dense_total_lengths_and_strides_helper<I+1>(len, stride, args...);
165 }
166 
167 template <size_t N, typename... Args>
dense_total_lengths_and_strides(std::array<len_vector,N> & len,std::array<stride_vector,N> & stride,const Args &...args)168 void dense_total_lengths_and_strides(std::array<len_vector,N>& len,
169                                      std::array<stride_vector,N>& stride,
170                                      const Args&... args)
171 {
172     dense_total_lengths_and_strides_helper<0>(len, stride, args...);
173 }
174 
175 template <typename T>
is_block_empty(const dpd_varray_view<T> & A,const irrep_vector & irreps)176 bool is_block_empty(const dpd_varray_view<T>& A, const irrep_vector& irreps)
177 {
178     unsigned irrep = 0;
179 
180     for (unsigned i = 0;i < A.dimension();i++)
181     {
182         irrep ^= irreps[i];
183         if (!A.length(i, irreps[i])) return true;
184     }
185 
186     return irrep != A.irrep();
187 }
188 
// Recursion terminator for assign_irrep: no (irreps, idx) pairs remain, so
// simply hand back the irrep that was being assigned. dim is unused here.
inline unsigned assign_irrep(unsigned dim, unsigned irrep)
{
    static_cast<void>(dim);
    return irrep;
}
193 
194 template <typename... Args>
assign_irrep(unsigned dim,unsigned irrep,irrep_vector & irreps,const dim_vector & idx,Args &...args)195 unsigned assign_irrep(unsigned dim, unsigned irrep,
196                       irrep_vector& irreps,
197                       const dim_vector& idx,
198                       Args&... args)
199 {
200     irreps[idx[dim]] = irrep;
201     return assign_irrep(dim, irrep, args...);
202 }
203 
204 template <typename... Args>
assign_irreps(unsigned ndim,unsigned irrep,unsigned nirrep,stride_type block,Args &...args)205 void assign_irreps(unsigned ndim, unsigned irrep, unsigned nirrep,
206                    stride_type block, Args&... args)
207 {
208     unsigned mask = nirrep-1;
209     unsigned shift = (nirrep>1) + (nirrep>2) + (nirrep>4);
210 
211     unsigned irrep0 = irrep;
212     for (unsigned i = 1;i < ndim;i++)
213     {
214         irrep0 ^= assign_irrep(i, block & mask, args...);
215         block >>= shift;
216     }
217     if (ndim) assign_irrep(0, irrep0, args...);
218 }
219 
220 }
221 }
222 
223 #endif
224