/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp"

#ifdef _WIN32
static const bool is_windows = true;
#else
static const bool is_windows = false;
#endif

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

// Convert between vector register lengths.
static inline Xmm make_xmm(const Xmm &v) {
    return Xmm(v.getIdx());
}

// Load vector register data for x, y or A.
void jit_avx_gemv_t_f32_kern::v_load(
        const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) {
    switch (nelems) {
        case 1: vmovss(make_xmm(dst), src); break;
        case 2: vmovsd(make_xmm(dst), src); break;
        case 4: vmovups(make_xmm(dst), src); break;
        default:
            assert(nelems >= 8);
            vmovups(dst, src);
            break;
    }
}

// Store vector register data for x, y or A.
void jit_avx_gemv_t_f32_kern::v_store(
        const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) {
    switch (nelems) {
        case 1: vmovss(dst, make_xmm(src)); break;
        case 2: vmovsd(dst, make_xmm(src)); break;
        case 4: vmovups(dst, make_xmm(src)); break;
        default:
            assert(nelems >= 8);
            vmovups(dst, src);
            break;
    }
}

// Perform Hadamard product of 2 vectors and accumulate.
// Use FMA instruction, otherwise emulate.
void jit_avx_gemv_t_f32_kern::dot_product(
        const Xmm &dst, const Xmm &src1, const Xmm &src2) {
    if (is_avx2_)
        vfmadd231ps(dst, src1, src2);
    else {
        vmulps(scratch_, src1, src2);
        vaddps(dst, dst, scratch_);
    }
}

// Inner loop.
void jit_avx_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) {
    if ((unroll_m > M_UNROLL_) || (unroll_n > N_UNROLL_) || (unroll_m < 0)
            || (unroll_n < 0))
        return;

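    // Number of 8-float YMM vectors needed to cover unroll_m elements.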
    int um_vecs = (unroll_m + 7) >> 3;

    // Load x.
    for (int i = 0; i < um_vecs; i++) {
        auto x_mem = ptr[XO_ + size_ * (8 * i - offset_x_)];
        v_load(x_regs_[i], x_mem, unroll_m);
    }
    add(XO_, size_ * unroll_m);

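    // LDA3 = 3 * lda (in bytes), used to address the fourth column of A.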
    Reg64 LDA3 = rax;
    lea(LDA3, ptr[LDA_ + LDA_ * 2]);

    // Load A
    for (int j = 0; j < unroll_n; j++) {
        for (int i = 0; i < um_vecs; i++) {
            Ymm a = a_regs_[i][j];

            decltype(LDA_ * j) lda_mult = (j == 3) ? LDA3 : LDA_ * j;

            auto a_mem = ptr[AO_ + lda_mult + size_ * (8 * i - offset_a_)];
            v_load(a, a_mem, unroll_m);
        }
    }

    lea(AO_, ptr[AO_ + size_ * unroll_m]);

    for (int j = 0; j < unroll_n; j++) {
        Ymm acc = acc_[j];

        for (int i = 0; i < um_vecs; i++) {
            dot_product(acc, x_regs_[i], a_regs_[i][j]);
        }
    }
}

// Outer loop.
void jit_avx_gemv_t_f32_kern::outerloop(
        int unroll_x, int unroll_y, Label *&cur_outerloop_label) {
    if ((unroll_x > M_UNROLL_) || (unroll_y > N_UNROLL_) || (unroll_y < 0)
            || (unroll_x < 0))
        return;

    Label label_m_loop, label_n_loop, label_m_remainder_loops[5];

    L(*cur_outerloop_label);
    cur_outerloop_label++;
    if (unroll_y >= N_UNROLL_) {
        mov(I_, N_);
        cmp(I_, unroll_y);
        jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
    } else {
        test(I_, unroll_y);
        jle(*cur_outerloop_label, T_NEAR);
    }

    L_aligned(label_n_loop);
    {

        mov(YO_, Y_);
        lea(Y_, ptr[YO_ + INCY_ * unroll_y]);

        mov(AO_, A_);
        lea(A_, ptr[AO_ + LDA_ * unroll_y]);

        mov(XO_, X_);

        for (int i = 0; i < unroll_y; i++) {
            auto acc = acc_[i];
            vxorps(acc, acc, acc);
        }

        mov(J_, M_);
        cmp(J_, unroll_x);
        jl(label_m_remainder_loops[0], T_NEAR);

        L_aligned(label_m_loop);
        {
            innerloop(unroll_x, unroll_y);
            sub(J_, unroll_x);
            cmp(J_, unroll_x);
            jge(label_m_loop, T_NEAR);
        }

        align(16);

        // Update y.
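        // Horizontally reduce each 8-wide accumulator to a single scalar sum.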
        for (int j = 0; j < unroll_y; j++) {
            Ymm acc = acc_[j];

            vhaddps(acc, acc, acc);
            vperm2f128(scratch_, acc, acc, 0x1);
            vaddps(acc, acc, scratch_);
            vhaddps(acc, acc, acc);
        }
        for (int j = 0; j < unroll_y; j++) {
            // TODO Handle negative increments
            Ymm y = y_regs_[j];
            Ymm acc = acc_[j];

            imul(YO2_, INCY_, j);
            lea(YO2_, ptr[YO_ + YO2_]);
            auto y_mem = ptr[YO2_];

            v_load(y, y_mem, 1);

            if (is_avx2_) {
                vfmadd231ss(make_xmm(y), make_xmm(alpha_), make_xmm(acc));
            } else {
                vmulps(make_xmm(scratch_), make_xmm(alpha_), make_xmm(acc));
                vaddps(make_xmm(y), make_xmm(y), make_xmm(scratch_));
            }

            v_store(y_mem, y, 1);
        }

        int label_idx = 0;
        for (int ux = 8; ux > 0; ux >>= 1) {
            L(label_m_remainder_loops[label_idx++]);
            if (unroll_x > ux) {
                test(J_, ux);
                jle(label_m_remainder_loops[label_idx], T_NEAR);

                for (int i = 0; i < unroll_y; i++) {
                    auto acc = acc_[i];
                    vxorps(acc, acc, acc);
                }

                innerloop(ux, unroll_y);

                align(16);

                // Update y.
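                // Horizontally reduce each 8-wide accumulator to a scalar sum.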
                for (int j = 0; j < unroll_y; j++) {
                    Ymm acc = acc_[j];

                    vhaddps(acc, acc, acc);
                    vperm2f128(scratch_, acc, acc, 0x1);
                    vaddps(acc, acc, scratch_);
                    vhaddps(acc, acc, acc);
                }
                for (int j = 0; j < unroll_y; j++) {
                    // TODO Handle negative increments
                    Ymm y = y_regs_[j];
                    Ymm acc = acc_[j];

                    imul(YO2_, INCY_, j);
                    lea(YO2_, ptr[YO_ + YO2_]);
                    auto y_mem = ptr[YO2_];

                    v_load(y, y_mem, 1);

                    if (is_avx2_) {
                        vfmadd231ss(
                                make_xmm(y), make_xmm(alpha_), make_xmm(acc));
                    } else {
                        vmulps(make_xmm(scratch_), make_xmm(alpha_),
                                make_xmm(acc));
                        vaddps(make_xmm(y), make_xmm(y), make_xmm(scratch_));
                    }

                    v_store(y_mem, y, 1);
                }
            }
        }
        L(label_m_remainder_loops[label_idx]);

        if (unroll_y >= N_UNROLL_) {
            sub(I_, unroll_y);
            cmp(I_, unroll_y);
            jge(label_n_loop);
        }
    }

    align(16);
}

void jit_avx_gemv_t_f32_kern::generate() {
    // Prologue
    preamble();

    movss(make_xmm(alpha_), qword[ALPHA_]);

    if (is_windows) {
        mov(LDA_, arg_lda_);
        mov(X_, arg_x_);
    }

    mov(Y_, arg_y_);
    mov(INCY_, arg_incy_);

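    // Pre-advance the A and x pointers by offset_a_/offset_x_ elements;
    // the inner loop subtracts these offsets back in its displacements.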
    sub(A_, -offset_a_ * size_);
    sub(X_, -offset_x_ * size_);

    mov(M_, qword[M_]);
    mov(N_, qword[N_]);
    mov(LDA_, qword[LDA_]);
    mov(INCY_, qword[INCY_]);

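    // Convert lda and incy from element counts to byte strides.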
    lea(LDA_, ptr[LDA_ * size_]);
    lea(INCY_, ptr[INCY_ * size_]);

    Label outerloop_labels[4];
    Label *cur_outerloop_label = &outerloop_labels[0];

    // Main n loop.
    outerloop(M_UNROLL_, N_UNROLL_, cur_outerloop_label);

    // n remainder loops.
    for (int un = 2; un > 0; un >>= 1)
        if (N_UNROLL_ > un) outerloop(M_UNROLL_, un, cur_outerloop_label);

    L(*cur_outerloop_label);

    // Epilogue.
    postamble();
}

// Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy)
jit_avx_gemv_t_f32_kern::jit_avx_gemv_t_f32_kern()
    : jit_generator(nullptr, 100000)
    , arg_lda_(0)
    , arg_x_(0)
    , arg_incx_(0)
    , arg_y_(0)
    , arg_incy_(0) {

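    // FMA is emitted only when AVX2 is available; otherwise dot_product()
    // falls back to separate multiply and add instructions.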
    is_avx2_ = mayiuse(avx2);

    // Assign integer registers
    M_ = abi_param1;
    N_ = abi_param2;
    ALPHA_ = abi_param3;
    A_ = abi_param4;
    LDA_ = is_windows ? rdi : r8;
    X_ = is_windows ? rsi : r9;
    INCY_ = r10;
    Y_ = r11;

    J_ = r12;
    I_ = r13;

    AO_ = r14;
    XO_ = r15;

    YO_ = rbx;
    YO2_ = rbp;

    // Assign vector registers
    for (int i = 0; i < (N_UNROLL_); i++)
        y_regs_[i] = Ymm(i);

    int rn = 0;
    for (int i = 0; i < (M_UNROLL_ >> 3); i++)
        for (int j = 0; j < N_UNROLL_; j++)
            a_regs_[i][j] = Ymm(rn++);

    x_regs_[0] = ymm8;
    x_regs_[1] = ymm9;

    alpha_ = ymm10;
    scratch_ = ymm11;

    for (int i = 0; i < (N_UNROLL_); i++)
        acc_[i] = Ymm(12 + i);

    // Assign stack variables.
    auto args_offset = get_size_of_abi_save_regs() + 8 + (is_windows ? 48 : 0);

    arg_lda_ = ptr[rsp + (args_offset - 16)];
    arg_x_ = ptr[rsp + (args_offset - 8)];
    arg_incx_ = ptr[rsp + (args_offset + 0)]; // Assumed 1 for A transpose.
    arg_y_ = ptr[rsp + (args_offset + 8)];
    arg_incy_ = ptr[rsp + (args_offset + 16)]; // Assumed 1 for A non-transpose.
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl