/*******************************************************************************
* Copyright 2019-2020 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "cpu/x64/cpu_isa_traits.hpp"
#include "cpu/x64/jit_generator.hpp"

#include "cpu/x64/gemm/f32/jit_avx_gemv_t_f32_kern.hpp"

#ifdef _WIN32
static const bool is_windows = true;
#else
static const bool is_windows = false;
#endif

namespace dnnl {
namespace impl {
namespace cpu {
namespace x64 {

using namespace Xbyak;

// Convert between vector register lengths.
static inline Xmm make_xmm(const Xmm &v) {
    return Xmm(v.getIdx());
}

// Load vector register data for x, y or A.
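// Partial loads (1, 2 or 4 elements) use the narrower xmm forms so that
// tail iterations never read memory beyond the last valid element.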
void jit_avx_gemv_t_f32_kern::v_load(
        const Xbyak::Xmm &dst, const Xbyak::Address &src, int nelems) {
    switch (nelems) {
        case 1: vmovss(make_xmm(dst), src); break;
        case 2: vmovsd(make_xmm(dst), src); break;
        case 4: vmovups(make_xmm(dst), src); break;
        default:
            assert(nelems >= 8);
            vmovups(dst, src);
            break;
    }
}

// Store vector register data for x, y or A.
void jit_avx_gemv_t_f32_kern::v_store(
        const Xbyak::Address &dst, const Xbyak::Xmm &src, int nelems) {
    switch (nelems) {
        case 1: vmovss(dst, make_xmm(src)); break;
        case 2: vmovsd(dst, make_xmm(src)); break;
        case 4: vmovups(dst, make_xmm(src)); break;
        default:
            assert(nelems >= 8);
            vmovups(dst, src);
            break;
    }
}

// Perform Hadamard product of 2 vectors and accumulate.
// Use FMA instruction, otherwise emulate.
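// Computes dst += src1 * src2 elementwise: one vfmadd231ps on AVX2, or
// vmulps into scratch_ followed by vaddps when only AVX is available.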
void jit_avx_gemv_t_f32_kern::dot_product(
        const Xmm &dst, const Xmm &src1, const Xmm &src2) {
    if (is_avx2_)
        vfmadd231ps(dst, src1, src2);
    else {
        vmulps(scratch_, src1, src2);
        vaddps(dst, dst, scratch_);
    }
}

// Inner loop.
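// Computes an unroll_m x unroll_n tile of the transposed-A GEMV: loads
// unroll_m elements of x, the matching slice of each of the unroll_n
// columns of A, and accumulates the per-column partial dot products into
// acc_. XO_ and AO_ are advanced past the tile as a side effect.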
void jit_avx_gemv_t_f32_kern::innerloop(int unroll_m, int unroll_n) {
    if ((unroll_m > M_UNROLL_) || (unroll_n > N_UNROLL_) || (unroll_m < 0)
            || (unroll_n < 0))
        return;

    int um_vecs = (unroll_m + 7) >> 3;

    // Load x.
    for (int i = 0; i < um_vecs; i++) {
        auto x_mem = ptr[XO_ + size_ * (8 * i - offset_x_)];
        v_load(x_regs_[i], x_mem, unroll_m);
    }
    add(XO_, size_ * unroll_m);

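    // Scaled-index addressing only supports scales of 1, 2, 4 and 8, so
    // the stride of column 3 is materialized separately as LDA3 = 3 * lda.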
    Reg64 LDA3 = rax;
    lea(LDA3, ptr[LDA_ + LDA_ * 2]);

    // Load A.
    for (int j = 0; j < unroll_n; j++) {
        for (int i = 0; i < um_vecs; i++) {
            Ymm a = a_regs_[i][j];

            decltype(LDA_ * j) lda_mult = (j == 3) ? LDA3 : LDA_ * j;

            auto a_mem = ptr[AO_ + lda_mult + size_ * (8 * i - offset_a_)];
            v_load(a, a_mem, unroll_m);
        }
    }

    lea(AO_, ptr[AO_ + size_ * unroll_m]);

    for (int j = 0; j < unroll_n; j++) {
        Ymm acc = acc_[j];

        for (int i = 0; i < um_vecs; i++) {
            dot_product(acc, x_regs_[i], a_regs_[i][j]);
        }
    }
}

// Outer loop.
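// Emits one specialization of the column (n) loop for a given unroll_y.
// Each iteration accumulates unroll_y dot products; the row (m) loop runs
// in unroll_x chunks, with leftover rows handled by the power-of-two
// remainder loops. If fewer than unroll_y columns remain, control jumps
// ahead to the label of the next, smaller specialization.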
void jit_avx_gemv_t_f32_kern::outerloop(
        int unroll_x, int unroll_y, Label *&cur_outerloop_label) {
    if ((unroll_x > M_UNROLL_) || (unroll_y > N_UNROLL_) || (unroll_y < 0)
            || (unroll_x < 0))
        return;

    Label label_m_loop, label_n_loop, label_m_remainder_loops[5];

    L(*cur_outerloop_label);
    cur_outerloop_label++;
    if (unroll_y >= N_UNROLL_) {
        mov(I_, N_);
        cmp(I_, unroll_y);
        jl(*cur_outerloop_label, T_NEAR); // Jump to next outerloop label.
    } else {
        test(I_, unroll_y);
        jle(*cur_outerloop_label, T_NEAR);
    }

    L_aligned(label_n_loop);
    {

        mov(YO_, Y_);
        lea(Y_, ptr[YO_ + INCY_ * unroll_y]);

        mov(AO_, A_);
        lea(A_, ptr[AO_ + LDA_ * unroll_y]);

        mov(XO_, X_);

        for (int i = 0; i < unroll_y; i++) {
            auto acc = acc_[i];
            vxorps(acc, acc, acc);
        }

        mov(J_, M_);
        cmp(J_, unroll_x);
        jl(label_m_remainder_loops[0], T_NEAR);

        L_aligned(label_m_loop);
        {
            innerloop(unroll_x, unroll_y);
            sub(J_, unroll_x);
            cmp(J_, unroll_x);
            jge(label_m_loop, T_NEAR);
        }

        align(16);

        // Update y.
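        // Reduce each 8-wide accumulator to a scalar: vhaddps sums
        // adjacent pairs within each 128-bit lane, vperm2f128 + vaddps
        // folds the upper lane onto the lower one, and a final vhaddps
        // leaves the total in element 0.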
        for (int j = 0; j < unroll_y; j++) {
            Ymm acc = acc_[j];

            vhaddps(acc, acc, acc);
            vperm2f128(scratch_, acc, acc, 0x1);
            vaddps(acc, acc, scratch_);
            vhaddps(acc, acc, acc);
        }
        for (int j = 0; j < unroll_y; j++) {
            // TODO Handle negative increments
            Ymm y = y_regs_[j];
            Ymm acc = acc_[j];

            imul(YO2_, INCY_, j);
            lea(YO2_, ptr[YO_ + YO2_]);
            auto y_mem = ptr[YO2_];

            v_load(y, y_mem, 1);

            if (is_avx2_) {
                vfmadd231ss(make_xmm(y), make_xmm(alpha_), make_xmm(acc));
            } else {
                vmulps(make_xmm(scratch_), make_xmm(alpha_), make_xmm(acc));
                vaddps(make_xmm(y), make_xmm(y), make_xmm(scratch_));
            }

            v_store(y_mem, y, 1);
        }

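        // m remainder: after the main loop, J_ < unroll_x, so its low bits
        // give the leftover row count; test(J_, ux) runs each power-of-two
        // chunk at most once.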
        int label_idx = 0;
        for (int ux = 8; ux > 0; ux >>= 1) {
            L(label_m_remainder_loops[label_idx++]);
            if (unroll_x > ux) {
                test(J_, ux);
                jle(label_m_remainder_loops[label_idx], T_NEAR);

                for (int i = 0; i < unroll_y; i++) {
                    auto acc = acc_[i];
                    vxorps(acc, acc, acc);
                }

                innerloop(ux, unroll_y);

                align(16);

                // Update y.
                for (int j = 0; j < unroll_y; j++) {
                    Ymm acc = acc_[j];

                    vhaddps(acc, acc, acc);
                    vperm2f128(scratch_, acc, acc, 0x1);
                    vaddps(acc, acc, scratch_);
                    vhaddps(acc, acc, acc);
                }
                for (int j = 0; j < unroll_y; j++) {
                    // TODO Handle negative increments
                    Ymm y = y_regs_[j];
                    Ymm acc = acc_[j];

                    imul(YO2_, INCY_, j);
                    lea(YO2_, ptr[YO_ + YO2_]);
                    auto y_mem = ptr[YO2_];

                    v_load(y, y_mem, 1);

                    if (is_avx2_) {
                        vfmadd231ss(
                                make_xmm(y), make_xmm(alpha_), make_xmm(acc));
                    } else {
                        vmulps(make_xmm(scratch_), make_xmm(alpha_),
                                make_xmm(acc));
                        vaddps(make_xmm(y), make_xmm(y), make_xmm(scratch_));
                    }

                    v_store(y_mem, y, 1);
                }
            }
        }
        L(label_m_remainder_loops[label_idx]);

        if (unroll_y >= N_UNROLL_) {
            sub(I_, unroll_y);
            cmp(I_, unroll_y);
            jge(label_n_loop);
        }
    }

    align(16);
}

void jit_avx_gemv_t_f32_kern::generate() {
    // Prologue.
    preamble();

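    // alpha is passed by address (see the signature comment below); load
    // the scalar into a vector register.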
    movss(make_xmm(alpha_), qword[ALPHA_]);

    if (is_windows) {
        mov(LDA_, arg_lda_);
        mov(X_, arg_x_);
    }

    mov(Y_, arg_y_);
    mov(INCY_, arg_incy_);

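    // Bias the A and x pointers upward; the addresses in innerloop()
    // subtract offset_a_/offset_x_ back out, presumably to keep effective
    // displacements within the signed 8-bit encoding range.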
    sub(A_, -offset_a_ * size_);
    sub(X_, -offset_x_ * size_);

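    // m, n, lda and incy are likewise passed by address; dereference them.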
    mov(M_, qword[M_]);
    mov(N_, qword[N_]);
    mov(LDA_, qword[LDA_]);
    mov(INCY_, qword[INCY_]);

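    // Convert lda and incy from element counts to byte strides.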
    lea(LDA_, ptr[LDA_ * size_]);
    lea(INCY_, ptr[INCY_ * size_]);

    Label outerloop_labels[4];
    Label *cur_outerloop_label = &outerloop_labels[0];

    // Main n loop.
    outerloop(M_UNROLL_, N_UNROLL_, cur_outerloop_label);

    // n remainder loops.
    for (int un = 2; un > 0; un >>= 1)
        if (N_UNROLL_ > un) outerloop(M_UNROLL_, un, cur_outerloop_label);

    L(*cur_outerloop_label);

    // Epilogue.
    postamble();
}

// Function signature: gemv(*m, *n, *alpha, *a, *lda, *x, *incx, *y, *incy)
jit_avx_gemv_t_f32_kern::jit_avx_gemv_t_f32_kern()
    : jit_generator(nullptr, 100000)
    , arg_lda_(0)
    , arg_x_(0)
    , arg_incx_(0)
    , arg_y_(0)
    , arg_incy_(0) {

    is_avx2_ = mayiuse(avx2);

    // Assign integer registers.
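    // m, n, alpha and a arrive in the first four ABI parameter registers
    // under both calling conventions. On SysV, lda and x arrive in r8/r9
    // as well; on Windows they are passed on the stack (loaded in
    // generate() from arg_lda_/arg_x_), leaving rdi/rsi free to hold them.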
    M_ = abi_param1;
    N_ = abi_param2;
    ALPHA_ = abi_param3;
    A_ = abi_param4;
    LDA_ = is_windows ? rdi : r8;
    X_ = is_windows ? rsi : r9;
    INCY_ = r10;
    Y_ = r11;

    J_ = r12;
    I_ = r13;

    AO_ = r14;
    XO_ = r15;

    YO_ = rbx;
    YO2_ = rbp;

    // Assign vector registers.
    for (int i = 0; i < (N_UNROLL_); i++)
        y_regs_[i] = Ymm(i);

    int rn = 0;
    for (int i = 0; i < (M_UNROLL_ >> 3); i++)
        for (int j = 0; j < N_UNROLL_; j++)
            a_regs_[i][j] = Ymm(rn++);

    x_regs_[0] = ymm8;
    x_regs_[1] = ymm9;

    alpha_ = ymm10;
    scratch_ = ymm11;

    for (int i = 0; i < (N_UNROLL_); i++)
        acc_[i] = Ymm(12 + i);

    // Assign stack variables.
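    // args_offset skips the registers saved by preamble() plus the return
    // address; on Windows it also skips the 32-byte shadow space and the
    // stack slots of the 5th and 6th arguments (lda, x), so the remaining
    // stack arguments land at the same relative offsets under both ABIs.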
    auto args_offset = get_size_of_abi_save_regs() + 8 + (is_windows ? 48 : 0);

    arg_lda_ = ptr[rsp + (args_offset - 16)];
    arg_x_ = ptr[rsp + (args_offset - 8)];
    arg_incx_ = ptr[rsp + (args_offset + 0)]; // Assumed 1 for A transpose.
    arg_y_ = ptr[rsp + (args_offset + 8)];
    arg_incy_ = ptr[rsp + (args_offset + 16)]; // Assumed 1 for A non-transpose.
}

} // namespace x64
} // namespace cpu
} // namespace impl
} // namespace dnnl