1 #include "util.hpp"
2 #include "add.hpp"
3 #include "scale.hpp"
4 #include "set.hpp"
5 #include "internal/1t/dense/add.hpp"
6
7 namespace tblis
8 {
9 namespace internal
10 {
11
12 template <typename T>
add_full(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)13 void add_full(const communicator& comm, const config& cfg,
14 T alpha, bool conj_A, const indexed_varray_view<const T>& A,
15 const dim_vector& idx_A_A,
16 const dim_vector& idx_A_AB,
17 const indexed_varray_view< T>& B,
18 const dim_vector& idx_B_B,
19 const dim_vector& idx_B_AB)
20 {
21 varray<T> A2, B2;
22
23 comm.broadcast(
24 [&](varray<T>& A2, varray<T>& B2)
25 {
26 block_to_full(comm, cfg, A, A2);
27 block_to_full(comm, cfg, B, B2);
28
29 auto len_A = stl_ext::select_from(A2.lengths(), idx_A_A);
30 auto len_B = stl_ext::select_from(B2.lengths(), idx_B_B);
31 auto len_AB = stl_ext::select_from(A2.lengths(), idx_A_AB);
32 auto stride_A_A = stl_ext::select_from(A2.strides(), idx_A_A);
33 auto stride_B_B = stl_ext::select_from(B2.strides(), idx_B_B);
34 auto stride_A_AB = stl_ext::select_from(A2.strides(), idx_A_AB);
35 auto stride_B_AB = stl_ext::select_from(B2.strides(), idx_B_AB);
36
37 add(comm, cfg, len_A, len_B, len_AB,
38 alpha, conj_A, A2.data(), stride_A_A, stride_A_AB,
39 T(0), false, B2.data(), stride_B_B, stride_B_AB);
40
41 full_to_block(comm, cfg, B2, B);
42 },
43 A2, B2);
44 }
45
46 template <typename T>
trace_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_AB)47 void trace_block(const communicator& comm, const config& cfg,
48 T alpha, bool conj_A, const indexed_varray_view<const T>& A,
49 const dim_vector& idx_A_A,
50 const dim_vector& idx_A_AB,
51 const indexed_varray_view< T>& B,
52 const dim_vector& idx_B_AB)
53 {
54 index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
55 index_group<1> group_A(A, idx_A_A);
56
57 group_indices<T, 2> indices_A(A, group_AB, 0, group_A, 0);
58 group_indices<T, 1> indices_B(B, group_AB, 1);
59 auto nidx_A = indices_A.size();
60 auto nidx_B = indices_B.size();
61
62 stride_type idx = 0;
63 stride_type idx_A = 0;
64 stride_type idx_B = 0;
65
66 comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*
67 stl_ext::prod(group_A.dense_len)*inout_ratio,
68 [&](communicator::deferred_task_set& tasks)
69 {
70 for_each_match<true, false>(idx_A, nidx_A, indices_A, 0,
71 idx_B, nidx_B, indices_B, 0,
72 [&](stride_type next_A)
73 {
74 if (indices_B[idx_B].factor == T(0)) return;
75
76 tasks.visit(idx++,
77 [&,idx_A,idx_B,next_A](const communicator& subcomm)
78 {
79 stride_type off_A_AB, off_B_AB;
80 get_local_offset(indices_A[idx_A].idx[0], group_AB,
81 off_A_AB, 0, off_B_AB, 1);
82
83 auto data_B = B.data(0) + indices_B[idx_B].offset + off_B_AB;
84
85 for (auto local_idx_A = idx_A;local_idx_A < next_A;local_idx_A++)
86 {
87 auto factor = alpha*indices_A[local_idx_A].factor*
88 indices_B[idx_B].factor;
89 if (factor == T(0)) continue;
90
91 auto data_A = A.data(0) + indices_A[local_idx_A].offset + off_A_AB;
92
93 add(subcomm, cfg, group_A.dense_len, {}, group_AB.dense_len,
94 factor, conj_A, data_A, group_A.dense_stride[0],
95 group_AB.dense_stride[0],
96 T(1), false, data_B, {}, group_AB.dense_stride[1]);
97 }
98 });
99 });
100 });
101 }
102
103 template <typename T>
replicate_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)104 void replicate_block(const communicator& comm, const config& cfg,
105 T alpha, bool conj_A, const indexed_varray_view<const T>& A,
106 const dim_vector& idx_A_AB,
107 const indexed_varray_view< T>& B,
108 const dim_vector& idx_B_B,
109 const dim_vector& idx_B_AB)
110 {
111 index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
112 index_group<1> group_B(B, idx_B_B);
113
114 group_indices<T, 1> indices_A(A, group_AB, 0);
115 group_indices<T, 2> indices_B(B, group_AB, 1, group_B, 0);
116 auto nidx_A = indices_A.size();
117 auto nidx_B = indices_B.size();
118
119 stride_type idx = 0;
120 stride_type idx_A = 0;
121 stride_type idx_B = 0;
122
123 len_vector dense_len_B = group_AB.dense_len + group_B.dense_len;
124 stride_vector dense_stride_B = group_AB.dense_stride[1] + group_B.dense_stride[0];
125
126 comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*
127 stl_ext::prod(group_B.dense_len)*inout_ratio,
128 [&](communicator::deferred_task_set& tasks)
129 {
130 for_each_match<false, true>(idx_A, nidx_A, indices_A, 0,
131 idx_B, nidx_B, indices_B, 0,
132 [&](stride_type next_B)
133 {
134 for (auto local_idx_B = idx_B;local_idx_B < next_B;local_idx_B++)
135 {
136 auto factor = alpha*indices_A[idx_A].factor*
137 indices_B[local_idx_B].factor;
138 if (factor == T(0)) continue;
139
140 tasks.visit(idx++,
141 [&,idx_A,local_idx_B,factor](const communicator& subcomm)
142 {
143 stride_type off_A_AB, off_B_AB;
144 get_local_offset(indices_A[idx_A].idx[0], group_AB,
145 off_A_AB, 0, off_B_AB, 1);
146
147 auto data_A = A.data(0) + indices_A[idx_A].offset + off_A_AB;
148 auto data_B = B.data(0) + indices_B[local_idx_B].offset + off_B_AB;
149 add(subcomm, cfg, {}, group_B.dense_len, group_AB.dense_len,
150 factor, conj_A, data_A, {}, group_AB.dense_stride[0],
151 T(1), false, data_B, group_B.dense_stride[0],
152 group_AB.dense_stride[1]);
153 });
154 }
155 });
156 });
157 }
158
159 template <typename T>
transpose_block(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_AB,const indexed_varray_view<T> & B,const dim_vector & idx_B_AB)160 void transpose_block(const communicator& comm, const config& cfg,
161 T alpha, bool conj_A, const indexed_varray_view<const T>& A,
162 const dim_vector& idx_A_AB,
163 const indexed_varray_view< T>& B,
164 const dim_vector& idx_B_AB)
165 {
166 index_group<2> group_AB(A, idx_A_AB, B, idx_B_AB);
167
168 group_indices<T, 1> indices_A(A, group_AB, 0);
169 group_indices<T, 1> indices_B(B, group_AB, 1);
170 auto nidx_A = indices_A.size();
171 auto nidx_B = indices_B.size();
172
173 stride_type idx = 0;
174 stride_type idx_A = 0;
175 stride_type idx_B = 0;
176
177 comm.do_tasks_deferred(nidx_B, stl_ext::prod(group_AB.dense_len)*inout_ratio,
178 [&](communicator::deferred_task_set& tasks)
179 {
180 for_each_match<false, false>(idx_A, nidx_A, indices_A, 0,
181 idx_B, nidx_B, indices_B, 0,
182 [&]
183 {
184 auto factor = alpha*indices_A[idx_A].factor*indices_B[idx_B].factor;
185 if (factor == T(0)) return;
186
187 tasks.visit(idx++,
188 [&,idx_A,idx_B,factor](const communicator& subcomm)
189 {
190 stride_type off_A_AB, off_B_AB;
191 get_local_offset(indices_A[idx_A].idx[0], group_AB,
192 off_A_AB, 0, off_B_AB, 1);
193
194 auto data_A = A.data(0) + indices_A[idx_A].offset + off_A_AB;
195 auto data_B = B.data(0) + indices_B[idx_B].offset + off_B_AB;
196
197 add(subcomm, cfg, {}, {}, group_AB.dense_len,
198 factor, conj_A, data_A, {}, group_AB.dense_stride[0],
199 T(1), false, data_B, {}, group_AB.dense_stride[1]);
200 });
201 });
202 });
203 }
204
205 template <typename T>
add(const communicator & comm,const config & cfg,T alpha,bool conj_A,const indexed_varray_view<const T> & A,const dim_vector & idx_A_A,const dim_vector & idx_A_AB,T beta,bool conj_B,const indexed_varray_view<T> & B,const dim_vector & idx_B_B,const dim_vector & idx_B_AB)206 void add(const communicator& comm, const config& cfg,
207 T alpha, bool conj_A, const indexed_varray_view<const T>& A,
208 const dim_vector& idx_A_A,
209 const dim_vector& idx_A_AB,
210 T beta, bool conj_B, const indexed_varray_view< T>& B,
211 const dim_vector& idx_B_B,
212 const dim_vector& idx_B_AB)
213 {
214 if (beta == T(0))
215 {
216 set(comm, cfg, T(0), B, range(B.dimension()));
217 }
218 else if (beta != T(1) || (is_complex<T>::value && conj_B))
219 {
220 scale(comm, cfg, beta, conj_B, B, range(B.dimension()));
221 }
222
223 if (dpd_impl == FULL)
224 {
225 add_full(comm, cfg,
226 alpha, conj_A, A, idx_A_A, idx_A_AB,
227 B, idx_B_B, idx_B_AB);
228 }
229 else if (!idx_A_A.empty())
230 {
231 trace_block(comm, cfg,
232 alpha, conj_A, A, idx_A_A, idx_A_AB,
233 B, idx_B_AB);
234 }
235 else if (!idx_B_B.empty())
236 {
237 replicate_block(comm, cfg,
238 alpha, conj_A, A, idx_A_AB,
239 B, idx_B_B, idx_B_AB);
240 }
241 else
242 {
243 transpose_block(comm, cfg,
244 alpha, conj_A, A, idx_A_AB,
245 B, idx_B_AB);
246 }
247 }
248
249 #define FOREACH_TYPE(T) \
250 template void add(const communicator& comm, const config& cfg, \
251 T alpha, bool conj_A, const indexed_varray_view<const T>& A, \
252 const dim_vector& idx_A, \
253 const dim_vector& idx_A_AB, \
254 T beta, bool conj_B, const indexed_varray_view< T>& B, \
255 const dim_vector& idx_B, \
256 const dim_vector& idx_B_AB);
257 #include "configs/foreach_type.h"
258
259 }
260 }
261