1 #include "add.hpp"
2 #include "reduce.hpp"
3 #include "scale.hpp"
4 #include "shift.hpp"
5
6 #include "internal/1m/add.hpp"
7
8 #include "util/tensor.hpp"
9
10 namespace tblis
11 {
12 namespace internal
13 {
14
15 template <typename T>
add(const communicator & comm,const config & cfg,const len_vector & len_A_,const len_vector & len_B_,const len_vector & len_AB_,T alpha,bool conj_A,const T * A,const stride_vector & stride_A_,const stride_vector & stride_A_AB_,T beta,bool conj_B,T * B,const stride_vector & stride_B_,const stride_vector & stride_B_AB_)16 void add(const communicator& comm, const config& cfg,
17 const len_vector& len_A_,
18 const len_vector& len_B_,
19 const len_vector& len_AB_,
20 T alpha, bool conj_A, const T* A,
21 const stride_vector& stride_A_,
22 const stride_vector& stride_A_AB_,
23 T beta, bool conj_B, T* B,
24 const stride_vector& stride_B_,
25 const stride_vector& stride_B_AB_)
26 {
27 auto perm_A = detail::sort_by_stride(stride_A_);
28 auto perm_B = detail::sort_by_stride(stride_B_);
29 auto perm_AB = detail::sort_by_stride(stride_B_AB_, stride_A_AB_);
30
31 auto len_A = stl_ext::permuted(len_A_, perm_A);
32 auto len_B = stl_ext::permuted(len_B_, perm_B);
33 auto len_AB = stl_ext::permuted(len_AB_, perm_AB);
34
35 auto stride_A = stl_ext::permuted(stride_A_, perm_A);
36 auto stride_B = stl_ext::permuted(stride_B_, perm_B);
37 auto stride_A_AB = stl_ext::permuted(stride_A_AB_, perm_AB);
38 auto stride_B_AB = stl_ext::permuted(stride_B_AB_, perm_AB);
39
40 len_type n_AB = stl_ext::prod(len_AB);
41 len_type n_A = stl_ext::prod(len_A);
42 len_type n_B = stl_ext::prod(len_B);
43
44 if (n_AB == 0 || n_B == 0) return;
45
46 if (n_A == 0)
47 {
48 scale(comm, cfg, len_B, beta, conj_B, B, stride_B);
49 return;
50 }
51
52 //
53 // Scalar intermediate
54 //
55 if (n_AB == 1)
56 {
57 if (n_A > 1)
58 {
59 T sum;
60 len_type idx;
61 reduce(comm, cfg, REDUCE_SUM, len_A, A, stride_A, sum, idx);
62
63 if (comm.master())
64 {
65 if (beta == T(0))
66 {
67 *B = alpha*(conj_A ? conj(sum) : sum);
68 }
69 else
70 {
71 *B = alpha*(conj_A ? conj(sum) : sum) +
72 beta*(conj_B ? conj( *B) : *B);
73 }
74 }
75 }
76 else if (n_B > 1)
77 {
78 shift(comm, cfg, len_B, alpha*(conj_A ? conj(*A) : *A),
79 beta, conj_B, B, stride_B);
80 }
81 else if (comm.master())
82 {
83 if (beta == T(0))
84 {
85 *B = alpha*(conj_A ? conj(*A) : *A);
86 }
87 else
88 {
89 *B = alpha*(conj_A ? conj(*A) : *A) +
90 beta*(conj_B ? conj(*B) : *B);
91 }
92 }
93
94 comm.barrier();
95 return;
96 }
97
98 if (n_A > 1)
99 {
100 //TODO sum (reduce?) ukr
101 //TODO fused ukr
102
103 comm.distribute_over_threads(n_AB,
104 [&](len_type n_min, len_type n_max)
105 {
106 auto A1 = A;
107 auto B1 = B;
108
109 viterator<1> iter_A(len_A, stride_A);
110 viterator<2> iter_AB(len_AB, stride_A_AB, stride_B_AB);
111 iter_AB.position(n_min, A1, B1);
112
113 for (len_type i = n_min;i < n_max;i++)
114 {
115 iter_AB.next(A1, B1);
116
117 T sum_A = T();
118 while (iter_A.next(A1)) sum_A += *A1;
119 sum_A = alpha*(conj_A ? conj(sum_A) : sum_A);
120
121 if (beta == T(0)) *B1 = sum_A;
122 else *B1 = sum_A + beta*(conj_B ? conj(*B1) : *B1);
123 }
124 });
125 }
126 else if (n_B > 1)
127 {
128 //TODO replicate ukr
129 //TODO fused ukr
130
131 comm.distribute_over_threads(n_AB,
132 [&](len_type n_min, len_type n_max)
133 {
134 auto A1 = A;
135 auto B1 = B;
136
137 viterator<1> iter_B(len_B, stride_B);
138 viterator<2> iter_AB(len_AB, stride_A_AB, stride_B_AB);
139 iter_AB.position(n_min, A1, B1);
140
141 for (len_type i = n_min;i < n_max;i++)
142 {
143 iter_AB.next(A1, B1);
144
145 T tmp_A = alpha*(conj_A ? conj(*A1) : *A1);
146
147 if (beta == T(0))
148 {
149 while (iter_B.next(B1)) *B1 = tmp_A;
150 }
151 else
152 {
153 TBLIS_SPECIAL_CASE(conj_B,
154 while (iter_B.next(B1))
155 *B1 = tmp_A + beta*(conj_B ? conj(*B1) : *B1);
156 )
157 }
158 }
159 });
160 }
161 else
162 {
163 unsigned unit_A_AB = 0;
164 unsigned unit_B_AB = 0;
165
166 for (unsigned i = 1;i < len_AB.size();i++)
167 {
168 if (len_AB[i] == 1) continue;
169 if (stride_A_AB[i] == 1 && unit_A_AB == 0) unit_A_AB = i;
170 if (stride_B_AB[i] == 1 && unit_B_AB == 0) unit_B_AB = i;
171 }
172
173 if (unit_A_AB == unit_B_AB)
174 {
175 len_type n0 = len_AB[unit_A_AB];
176 len_vector len1 = len_AB;
177 len1.erase(len1.begin()+unit_A_AB);
178 len_type n1 = stl_ext::prod(len1);
179
180 stride_type stride_A0 = stride_A_AB[unit_A_AB];
181 stride_vector stride_A1 = stride_A_AB;
182 stride_A1.erase(stride_A1.begin()+unit_A_AB);
183
184 stride_type stride_B0 = stride_B_AB[unit_A_AB];
185 stride_vector stride_B1 = stride_B_AB;
186 stride_B1.erase(stride_B1.begin()+unit_A_AB);
187
188 comm.distribute_over_threads(n0, n1,
189 [&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
190 {
191 auto A1 = A;
192 auto B1 = B;
193
194 viterator<2> iter_AB(len1, stride_A1, stride_B1);
195 iter_AB.position(n1_min, A1, B1);
196
197 A1 += n0_min*stride_A0;
198 B1 += n0_min*stride_B0;
199
200 for (len_type i = n1_min;i < n1_max;i++)
201 {
202 iter_AB.next(A1, B1);
203 cfg.add_ukr.call<T>(n0_max-n0_min,
204 alpha, conj_A, A1, stride_A0,
205 beta, conj_B, B1, stride_B0);
206 }
207 });
208 }
209 else
210 {
211 // So that the two erase()'s work correctly
212 if (unit_A_AB < unit_B_AB)
213 std::swap(unit_A_AB, unit_B_AB);
214
215 len_type m0 = len_AB[unit_A_AB];
216 len_type n0 = len_AB[unit_B_AB];
217 len_vector len1 = len_AB;
218 len1.erase(len1.begin()+unit_A_AB);
219 len1.erase(len1.begin()+unit_B_AB);
220 len_type mn1 = stl_ext::prod(len1);
221
222 stride_type stride_A_m = stride_A_AB[unit_A_AB];
223 stride_type stride_A_n = stride_A_AB[unit_B_AB];
224 stride_vector stride_A1 = stride_A_AB;
225 stride_A1.erase(stride_A1.begin()+unit_A_AB);
226 stride_A1.erase(stride_A1.begin()+unit_B_AB);
227
228 stride_type stride_B_m = stride_B_AB[unit_A_AB];
229 stride_type stride_B_n = stride_B_AB[unit_B_AB];
230 stride_vector stride_B1 = stride_B_AB;
231 stride_B1.erase(stride_B1.begin()+unit_A_AB);
232 stride_B1.erase(stride_B1.begin()+unit_B_AB);
233
234 unsigned nt_mn1, nt_mn;
235 std::tie(nt_mn1, nt_mn) = partition_2x2(comm.num_threads(), mn1, m0*n0);
236
237 auto subcomm = comm.gang(TCI_EVENLY, nt_mn1);
238
239 subcomm.distribute_over_gangs(mn1,
240 [&](len_type mn1_min, len_type mn1_max)
241 {
242 auto A1 = A;
243 auto B1 = B;
244
245 viterator<2> iter_AB(len1, stride_A1, stride_B1);
246 iter_AB.position(mn1_min, A1, B1);
247
248 for (len_type i = mn1_min;i < mn1_max;i++)
249 {
250 iter_AB.next(A1, B1);
251
252 add(subcomm, cfg, m0, n0,
253 alpha, conj_A, A1, stride_A_m, stride_A_n,
254 beta, conj_B, B1, stride_B_m, stride_B_n);
255 }
256 });
257 }
258 }
259
260 comm.barrier();
261 }
262
263 #define FOREACH_TYPE(T) \
264 template void add(const communicator& comm, const config& cfg, \
265 const len_vector& len_A, \
266 const len_vector& len_B, \
267 const len_vector& len_AB, \
268 T alpha, bool conj_A, const T* A, \
269 const stride_vector& stride_A, \
270 const stride_vector& stride_A_AB, \
271 T beta, bool conj_B, T* B, \
272 const stride_vector& stride_B, \
273 const stride_vector& stride_B_AB);
274 #include "configs/foreach_type.h"
275
276 }
277 }
278