1 #include "add.hpp"
2 #include "reduce.hpp"
3 #include "scale.hpp"
4 #include "shift.hpp"
5 
6 #include "internal/1m/add.hpp"
7 
8 #include "util/tensor.hpp"
9 
10 namespace tblis
11 {
12 namespace internal
13 {
14 
15 template <typename T>
add(const communicator & comm,const config & cfg,const len_vector & len_A_,const len_vector & len_B_,const len_vector & len_AB_,T alpha,bool conj_A,const T * A,const stride_vector & stride_A_,const stride_vector & stride_A_AB_,T beta,bool conj_B,T * B,const stride_vector & stride_B_,const stride_vector & stride_B_AB_)16 void add(const communicator& comm, const config& cfg,
17          const len_vector& len_A_,
18          const len_vector& len_B_,
19          const len_vector& len_AB_,
20          T alpha, bool conj_A, const T* A,
21          const stride_vector& stride_A_,
22          const stride_vector& stride_A_AB_,
23          T  beta, bool conj_B,       T* B,
24          const stride_vector& stride_B_,
25          const stride_vector& stride_B_AB_)
26 {
27     auto perm_A = detail::sort_by_stride(stride_A_);
28     auto perm_B = detail::sort_by_stride(stride_B_);
29     auto perm_AB = detail::sort_by_stride(stride_B_AB_, stride_A_AB_);
30 
31     auto len_A = stl_ext::permuted(len_A_, perm_A);
32     auto len_B = stl_ext::permuted(len_B_, perm_B);
33     auto len_AB = stl_ext::permuted(len_AB_, perm_AB);
34 
35     auto stride_A = stl_ext::permuted(stride_A_, perm_A);
36     auto stride_B = stl_ext::permuted(stride_B_, perm_B);
37     auto stride_A_AB = stl_ext::permuted(stride_A_AB_, perm_AB);
38     auto stride_B_AB = stl_ext::permuted(stride_B_AB_, perm_AB);
39 
40     len_type n_AB = stl_ext::prod(len_AB);
41     len_type n_A = stl_ext::prod(len_A);
42     len_type n_B = stl_ext::prod(len_B);
43 
44     if (n_AB == 0 || n_B == 0) return;
45 
46     if (n_A == 0)
47     {
48         scale(comm, cfg, len_B, beta, conj_B, B, stride_B);
49         return;
50     }
51 
52     //
53     // Scalar intermediate
54     //
55     if (n_AB == 1)
56     {
57         if (n_A > 1)
58         {
59             T sum;
60             len_type idx;
61             reduce(comm, cfg, REDUCE_SUM, len_A, A, stride_A, sum, idx);
62 
63             if (comm.master())
64             {
65                 if (beta == T(0))
66                 {
67                     *B = alpha*(conj_A ? conj(sum) : sum);
68                 }
69                 else
70                 {
71                     *B = alpha*(conj_A ? conj(sum) : sum) +
72                           beta*(conj_B ? conj( *B) :  *B);
73                 }
74             }
75         }
76         else if (n_B > 1)
77         {
78             shift(comm, cfg, len_B, alpha*(conj_A ? conj(*A) : *A),
79                   beta, conj_B, B, stride_B);
80         }
81         else if (comm.master())
82         {
83             if (beta == T(0))
84             {
85                 *B = alpha*(conj_A ? conj(*A) : *A);
86             }
87             else
88             {
89                 *B = alpha*(conj_A ? conj(*A) : *A) +
90                       beta*(conj_B ? conj(*B) : *B);
91             }
92         }
93 
94         comm.barrier();
95         return;
96     }
97 
98     if (n_A > 1)
99     {
100         //TODO sum (reduce?) ukr
101         //TODO fused ukr
102 
103         comm.distribute_over_threads(n_AB,
104         [&](len_type n_min, len_type n_max)
105         {
106             auto A1 = A;
107             auto B1 = B;
108 
109             viterator<1> iter_A(len_A, stride_A);
110             viterator<2> iter_AB(len_AB, stride_A_AB, stride_B_AB);
111             iter_AB.position(n_min, A1, B1);
112 
113             for (len_type i = n_min;i < n_max;i++)
114             {
115                 iter_AB.next(A1, B1);
116 
117                 T sum_A = T();
118                 while (iter_A.next(A1)) sum_A += *A1;
119                 sum_A = alpha*(conj_A ? conj(sum_A) : sum_A);
120 
121                 if (beta == T(0)) *B1 = sum_A;
122                 else              *B1 = sum_A + beta*(conj_B ? conj(*B1) : *B1);
123             }
124         });
125     }
126     else if (n_B > 1)
127     {
128         //TODO replicate ukr
129         //TODO fused ukr
130 
131         comm.distribute_over_threads(n_AB,
132         [&](len_type n_min, len_type n_max)
133         {
134             auto A1 = A;
135             auto B1 = B;
136 
137             viterator<1> iter_B(len_B, stride_B);
138             viterator<2> iter_AB(len_AB, stride_A_AB, stride_B_AB);
139             iter_AB.position(n_min, A1, B1);
140 
141             for (len_type i = n_min;i < n_max;i++)
142             {
143                 iter_AB.next(A1, B1);
144 
145                 T tmp_A = alpha*(conj_A ? conj(*A1) : *A1);
146 
147                 if (beta == T(0))
148                 {
149                     while (iter_B.next(B1)) *B1 = tmp_A;
150                 }
151                 else
152                 {
153                     TBLIS_SPECIAL_CASE(conj_B,
154                     while (iter_B.next(B1))
155                         *B1 = tmp_A + beta*(conj_B ? conj(*B1) : *B1);
156                     )
157                 }
158             }
159         });
160     }
161     else
162     {
163         unsigned unit_A_AB = 0;
164         unsigned unit_B_AB = 0;
165 
166         for (unsigned i = 1;i < len_AB.size();i++)
167         {
168             if (len_AB[i] == 1) continue;
169             if (stride_A_AB[i] == 1 && unit_A_AB == 0) unit_A_AB = i;
170             if (stride_B_AB[i] == 1 && unit_B_AB == 0) unit_B_AB = i;
171         }
172 
173         if (unit_A_AB == unit_B_AB)
174         {
175             len_type n0 = len_AB[unit_A_AB];
176             len_vector len1 = len_AB;
177             len1.erase(len1.begin()+unit_A_AB);
178             len_type n1 = stl_ext::prod(len1);
179 
180             stride_type stride_A0 = stride_A_AB[unit_A_AB];
181             stride_vector stride_A1 = stride_A_AB;
182             stride_A1.erase(stride_A1.begin()+unit_A_AB);
183 
184             stride_type stride_B0 = stride_B_AB[unit_A_AB];
185             stride_vector stride_B1 = stride_B_AB;
186             stride_B1.erase(stride_B1.begin()+unit_A_AB);
187 
188             comm.distribute_over_threads(n0, n1,
189             [&](len_type n0_min, len_type n0_max, len_type n1_min, len_type n1_max)
190             {
191                 auto A1 = A;
192                 auto B1 = B;
193 
194                 viterator<2> iter_AB(len1, stride_A1, stride_B1);
195                 iter_AB.position(n1_min, A1, B1);
196 
197                 A1 += n0_min*stride_A0;
198                 B1 += n0_min*stride_B0;
199 
200                 for (len_type i = n1_min;i < n1_max;i++)
201                 {
202                     iter_AB.next(A1, B1);
203                     cfg.add_ukr.call<T>(n0_max-n0_min,
204                                         alpha, conj_A, A1, stride_A0,
205                                          beta, conj_B, B1, stride_B0);
206                 }
207             });
208         }
209         else
210         {
211             // So that the two erase()'s work correctly
212             if (unit_A_AB < unit_B_AB)
213                 std::swap(unit_A_AB, unit_B_AB);
214 
215             len_type m0 = len_AB[unit_A_AB];
216             len_type n0 = len_AB[unit_B_AB];
217             len_vector len1 = len_AB;
218             len1.erase(len1.begin()+unit_A_AB);
219             len1.erase(len1.begin()+unit_B_AB);
220             len_type mn1 = stl_ext::prod(len1);
221 
222             stride_type stride_A_m = stride_A_AB[unit_A_AB];
223             stride_type stride_A_n = stride_A_AB[unit_B_AB];
224             stride_vector stride_A1 = stride_A_AB;
225             stride_A1.erase(stride_A1.begin()+unit_A_AB);
226             stride_A1.erase(stride_A1.begin()+unit_B_AB);
227 
228             stride_type stride_B_m = stride_B_AB[unit_A_AB];
229             stride_type stride_B_n = stride_B_AB[unit_B_AB];
230             stride_vector stride_B1 = stride_B_AB;
231             stride_B1.erase(stride_B1.begin()+unit_A_AB);
232             stride_B1.erase(stride_B1.begin()+unit_B_AB);
233 
234             unsigned nt_mn1, nt_mn;
235             std::tie(nt_mn1, nt_mn) = partition_2x2(comm.num_threads(), mn1, m0*n0);
236 
237             auto subcomm = comm.gang(TCI_EVENLY, nt_mn1);
238 
239             subcomm.distribute_over_gangs(mn1,
240             [&](len_type mn1_min, len_type mn1_max)
241             {
242                 auto A1 = A;
243                 auto B1 = B;
244 
245                 viterator<2> iter_AB(len1, stride_A1, stride_B1);
246                 iter_AB.position(mn1_min, A1, B1);
247 
248                 for (len_type i = mn1_min;i < mn1_max;i++)
249                 {
250                     iter_AB.next(A1, B1);
251 
252                     add(subcomm, cfg, m0, n0,
253                         alpha, conj_A, A1, stride_A_m, stride_A_n,
254                          beta, conj_B, B1, stride_B_m, stride_B_n);
255                 }
256             });
257         }
258     }
259 
260     comm.barrier();
261 }
262 
263 #define FOREACH_TYPE(T) \
264 template void add(const communicator& comm, const config& cfg, \
265                   const len_vector& len_A, \
266                   const len_vector& len_B, \
267                   const len_vector& len_AB, \
268                   T alpha, bool conj_A, const T* A, \
269                   const stride_vector& stride_A, \
270                   const stride_vector& stride_A_AB, \
271                   T  beta, bool conj_B,       T* B, \
272                   const stride_vector& stride_B, \
273                   const stride_vector& stride_B_AB);
274 #include "configs/foreach_type.h"
275 
276 }
277 }
278