1 #include "util.hpp"
2 #include "add.hpp"
3 #include "scale.hpp"
4 #include "set.hpp"
5 #include "internal/1t/dense/add.hpp"
6 
7 namespace tblis
8 {
9 namespace internal
10 {
11 
12 template <typename T>
add_full(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)13 void add_full(const communicator& comm, const config& cfg,
14               T alpha, bool conj_A, const indexed_varray_view<const T>& A,
15               const dim_vector& idx_A_A,
16               const dim_vector& idx_A_AB,
17                                     const indexed_varray_view<      T>& B,
18               const dim_vector& idx_B_B,
19               const dim_vector& idx_B_AB)
20 {
21     varray<T> A2, B2;
22 
23     comm.broadcast(
24     [&](varray<T>& A2, varray<T>& B2)
25     {
26         block_to_full(comm, cfg, A, A2);
27         block_to_full(comm, cfg, B, B2);
28 
29         auto len_A = stl_ext::select_from(A2.lengths(), idx_A_A);
30         auto len_B = stl_ext::select_from(B2.lengths(), idx_B_B);
31         auto len_AB = stl_ext::select_from(A2.lengths(), idx_A_AB);
32         auto stride_A_A = stl_ext::select_from(A2.strides(), idx_A_A);
33         auto stride_B_B = stl_ext::select_from(B2.strides(), idx_B_B);
34         auto stride_A_AB = stl_ext::select_from(A2.strides(), idx_A_AB);
35         auto stride_B_AB = stl_ext::select_from(B2.strides(), idx_B_AB);
36 
37         add(comm, cfg, len_A, len_B, len_AB,
38             alpha, conj_A, A2.data(), stride_A_A, stride_A_AB,
39              T(0),  false, B2.data(), stride_B_B, stride_B_AB);
40 
41         full_to_block(comm, cfg, B2, B);
42     },
43     A2, B2);
44 }
45 
46 template <typename T>
trace_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_AB)47 void trace_block(const communicator& comm, const config& cfg,
48                  T alpha, bool conj_A, const indexed_varray_view<const T>& A,
49                  const dim_vector& idx_A_A,
50                  const dim_vector& idx_A_AB,
51                                        const indexed_varray_view<      T>& B,
52                  const dim_vector& idx_B_AB)
53 {
54     index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
55     index_group<1> group_A(A, idx_A_A);
56 
57     group_indices<T, 2> indices_A(A, group_AB, 0, group_A, 0);
58     group_indices<T, 1> indices_B(B, group_AB, 1);
59     auto nidx_A = indices_A.size();
60     auto nidx_B = indices_B.size();
61 
62     stride_type idx = 0;
63     stride_type idx_A = 0;
64     stride_type idx_B = 0;
65 
66     comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*
67                                    stl_ext::prod(group_A.dense_len)*inout_ratio,
68     [&](communicator::deferred_task_set& tasks)
69     {
70         for_each_match<true, false>(idx_A, nidx_A, indices_A, 0,
71                                     idx_B, nidx_B, indices_B, 0,
72         [&](stride_type next_A)
73         {
74             if (indices_B[idx_B].factor == T(0)) return;
75 
76             tasks.visit(idx++,
77             [&,idx_A,idx_B,next_A](const communicator& subcomm)
78             {
79                 stride_type off_A_AB, off_B_AB;
80                 get_local_offset(indices_A[idx_A].idx[0], group_AB,
81                                  off_A_AB, 0, off_B_AB, 1);
82 
83                 auto data_B = B.data(0) + indices_B[idx_B].offset + off_B_AB;
84 
85                 for (auto local_idx_A = idx_A;local_idx_A < next_A;local_idx_A++)
86                 {
87                     auto factor = alpha*indices_A[local_idx_A].factor*
88                                         indices_B[idx_B].factor;
89                     if (factor == T(0)) continue;
90 
91                     auto data_A = A.data(0) + indices_A[local_idx_A].offset + off_A_AB;
92 
93                     add(subcomm, cfg, group_A.dense_len, {}, group_AB.dense_len,
94                         factor, conj_A, data_A, group_A.dense_stride[0],
95                                                 group_AB.dense_stride[0],
96                           T(1),  false, data_B, {}, group_AB.dense_stride[1]);
97                 }
98             });
99         });
100     });
101 }
102 
103 template <typename T>
replicate_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)104 void replicate_block(const communicator& comm, const config& cfg,
105                      T alpha, bool conj_A, const indexed_varray_view<const T>& A,
106                      const dim_vector& idx_A_AB,
107                                            const indexed_varray_view<      T>& B,
108                      const dim_vector& idx_B_B,
109                      const dim_vector& idx_B_AB)
110 {
111     index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
112     index_group<1> group_B(B, idx_B_B);
113 
114     group_indices<T, 1> indices_A(A, group_AB, 0);
115     group_indices<T, 2> indices_B(B, group_AB, 1, group_B, 0);
116     auto nidx_A = indices_A.size();
117     auto nidx_B = indices_B.size();
118 
119     stride_type idx = 0;
120     stride_type idx_A = 0;
121     stride_type idx_B = 0;
122 
123     len_vector dense_len_B = group_AB.dense_len + group_B.dense_len;
124     stride_vector dense_stride_B = group_AB.dense_stride[1] + group_B.dense_stride[0];
125 
126     comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*
127                                    stl_ext::prod(group_B.dense_len)*inout_ratio,
128     [&](communicator::deferred_task_set& tasks)
129     {
130         for_each_match<false, true>(idx_A, nidx_A, indices_A, 0,
131                                    idx_B, nidx_B, indices_B, 0,
132         [&](stride_type next_B)
133         {
134             for (auto local_idx_B = idx_B;local_idx_B < next_B;local_idx_B++)
135             {
136                 auto factor = alpha*indices_A[idx_A].factor*
137                                     indices_B[local_idx_B].factor;
138                 if (factor == T(0)) continue;
139 
140                 tasks.visit(idx++,
141                 [&,idx_A,local_idx_B,factor](const communicator& subcomm)
142                 {
143                     stride_type off_A_AB, off_B_AB;
144                     get_local_offset(indices_A[idx_A].idx[0], group_AB,
145                                      off_A_AB, 0, off_B_AB, 1);
146 
147                     auto data_A = A.data(0) + indices_A[idx_A].offset + off_A_AB;
148                     auto data_B = B.data(0) + indices_B[local_idx_B].offset + off_B_AB;
149                     add(subcomm, cfg, {}, group_B.dense_len, group_AB.dense_len,
150                         factor, conj_A, data_A, {}, group_AB.dense_stride[0],
151                           T(1),  false, data_B, group_B.dense_stride[0],
152                                                 group_AB.dense_stride[1]);
153                 });
154             }
155         });
156     });
157 }
158 
159 template <typename T>
transpose_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_AB)160 void transpose_block(const communicator& comm, const config& cfg,
161                      T alpha, bool conj_A, const indexed_varray_view<const T>& A,
162                      const dim_vector& idx_A_AB,
163                                            const indexed_varray_view<      T>& B,
164                      const dim_vector& idx_B_AB)
165 {
166     index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
167 
168     group_indices<T, 1> indices_A(A, group_AB, 0);
169     group_indices<T, 1> indices_B(B, group_AB, 1);
170     auto nidx_A = indices_A.size();
171     auto nidx_B = indices_B.size();
172 
173     stride_type idx = 0;
174     stride_type idx_A = 0;
175     stride_type idx_B = 0;
176 
177     comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*inout_ratio,
178     [&](communicator::deferred_task_set& tasks)
179     {
180         for_each_match<false, false>(idx_A, nidx_A, indices_A, 0,
181                                     idx_B, nidx_B, indices_B, 0,
182         [&]
183         {
184             auto factor = alpha*indices_A[idx_A].factor*indices_B[idx_B].factor;
185             if (factor == T(0)) return;
186 
187             tasks.visit(idx++,
188             [&,idx_A,idx_B,factor](const communicator& subcomm)
189             {
190                 stride_type off_A_AB, off_B_AB;
191                 get_local_offset(indices_A[idx_A].idx[0], group_AB,
192                                  off_A_AB, 0, off_B_AB, 1);
193 
194                 auto data_A = A.data(0) + indices_A[idx_A].offset + off_A_AB;
195                 auto data_B = B.data(0) + indices_B[idx_B].offset + off_B_AB;
196 
197                 add(subcomm, cfg, {}, {}, group_AB.dense_len,
198                     factor, conj_A, data_A, {}, group_AB.dense_stride[0],
199                       T(1),  false, data_B, {}, group_AB.dense_stride[1]);
200             });
201         });
202     });
203 }
204 
205 template <typename T>
add(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,T beta,bool conj_B,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)206 void add(const communicator& comm, const config& cfg,
207          T alpha, bool conj_A, const indexed_varray_view<const T>& A,
208          const dim_vector& idx_A_A,
209          const dim_vector& idx_A_AB,
210          T  beta, bool conj_B, const indexed_varray_view<      T>& B,
211          const dim_vector& idx_B_B,
212          const dim_vector& idx_B_AB)
213 {
214     if (beta == T(0))
215     {
216         set(comm, cfg, T(0), B, range(B.dimension()));
217     }
218     else if (beta != T(1) || (is_complex<T>::value && conj_B))
219     {
220         scale(comm, cfg, beta, conj_B, B, range(B.dimension()));
221     }
222 
223     if (dpd_impl == FULL)
224     {
225         add_full(comm, cfg,
226                  alpha, conj_A, A, idx_A_A, idx_A_AB,
227                              B, idx_B_B, idx_B_AB);
228     }
229     else if (!idx_A_A.empty())
230     {
231         trace_block(comm, cfg,
232                     alpha, conj_A, A, idx_A_A, idx_A_AB,
233                                    B, idx_B_AB);
234     }
235     else if (!idx_B_B.empty())
236     {
237         replicate_block(comm, cfg,
238                         alpha, conj_A, A, idx_A_AB,
239                                        B, idx_B_B, idx_B_AB);
240     }
241     else
242     {
243         transpose_block(comm, cfg,
244                         alpha, conj_A, A, idx_A_AB,
245                                        B, idx_B_AB);
246     }
247 }
248 
249 #define FOREACH_TYPE(T) \
250 template void add(const communicator& comm, const config& cfg, \
251                   T alpha, bool conj_A, const indexed_varray_view<const T>& A, \
252                   const dim_vector& idx_A, \
253                   const dim_vector& idx_A_AB, \
254                   T  beta, bool conj_B, const indexed_varray_view<      T>& B, \
255                   const dim_vector& idx_B, \
256                   const dim_vector& idx_B_AB);
257 #include "configs/foreach_type.h"
258 
259 }
260 }
261