from __future__ import absolute_import, print_function, division
import numpy as np

from theano import Op, Apply
from six import StringIO

try:
    import pygpu
    from pygpu import gpuarray
except ImportError:
    pass

from .basic_ops import (as_gpuarray_variable, GpuKernelBase, Kernel, gpuarray_helper_inc_dir,
                        infer_context_name)
from .type import GpuArrayType
from .fp16_help import work_dtype, load_w, write_w


class GpuCrossentropySoftmaxArgmax1HotWithBias(GpuKernelBase, Op):
    """
    Implement CrossentropySoftmaxArgmax1HotWithBias on the gpu.

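    Given pre-activations ``x``, a bias ``b`` and integer targets ``y_idx``,
    the fused kernel below computes, roughly, the following NumPy sketch
    (using the module-level ``np``; the real kernel subtracts the row max
    before exponentiating, for numerical stability)::

        z = x + b
        sm = np.exp(z) / np.exp(z).sum(axis=1, keepdims=True)
        am = z.argmax(axis=1)
        nll = -np.log(sm[np.arange(z.shape[0]), y_idx])
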
    """
    nin = 3
    nout = 3
    __props__ = ()
    _f16_ok = True

    def make_node(self, x, b, y_idx):
        ctx_name = infer_context_name(x, b, y_idx)
        x = as_gpuarray_variable(x, ctx_name)
        b = as_gpuarray_variable(b, ctx_name)
        y_idx = as_gpuarray_variable(y_idx, ctx_name)
        nll = GpuArrayType(x.type.dtype,
                           y_idx.type.broadcastable,
                           context_name=ctx_name)()
        sm = x.type()
        am = y_idx.type()
        return Apply(self, [x, b, y_idx], [nll, sm, am])

    def c_headers(self):
        return ['<numpy_compat.h>', '<gpuarray/types.h>', 'gpuarray_helper.h']

    def c_header_dirs(self):
        return [gpuarray_helper_inc_dir()]

    def gpu_kernels(self, node, nodename):
        dtype_x = node.inputs[0].dtype
        dtype_b = node.inputs[1].dtype
        dtype_y_idx = node.inputs[2].dtype
        work_x = work_dtype(dtype_x)
        work_b = work_dtype(dtype_b)
        load_x = load_w(dtype_x)
        load_b = load_w(dtype_b)
        write_x = write_w(dtype_x)
        write_b = write_w(dtype_b)
        flags = Kernel.get_flags(dtype_x, dtype_b, dtype_y_idx)
        type_x = gpuarray.dtype_to_ctype(dtype_x)
        type_b = gpuarray.dtype_to_ctype(dtype_b)
        work_x = gpuarray.dtype_to_ctype(work_x)
        type_y_idx = gpuarray.dtype_to_ctype(dtype_y_idx)
        kname = "k_xent_sm_1hot_bias"
        k_var = "k_xent_sm_1hot_bias_" + nodename
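        # Suffix for the math functions used in the kernel below: CUDA needs
        # the single-precision names (fmaxf, expf, logf) unless we work in
        # float64, while other backends are assumed to expose only the
        # generic, overloaded names.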
        if node.inputs[0].type.context.kind != b'cuda':
            f = ''
        else:
            f = '' if dtype_x == 'float64' else 'f'
        params = [
            gpuarray.SIZE, gpuarray.SIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE
        ]
        sio = StringIO()
        print("""#include "cluda.h"

        KERNEL void %(kname)s(const ga_size M, const ga_size N,
            GLOBAL_MEM const %(type_x)s* x_data, const ga_size offset_x, const ga_ssize xs0, const ga_ssize xs1,
            GLOBAL_MEM const %(type_b)s* b, const ga_size offset_b, const ga_ssize bs0,
            GLOBAL_MEM const %(type_y_idx)s* y_idx_data, const ga_size offset_y_idx, const ga_ssize y_idxs0,
            GLOBAL_MEM %(type_x)s* nll_data, const ga_size offset_nll, const ga_ssize nlls0,
            GLOBAL_MEM %(type_x)s* sm_data, const ga_size offset_sm, const ga_ssize sms0, const ga_ssize sms1,
            GLOBAL_MEM %(type_y_idx)s* am_data, const ga_size offset_am, const ga_ssize ams0 GA_DECL_SHARED_PARAM(%(work_x)s, per_thread_values))
        {
          x_data = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x_data)+offset_x);
          b = (GLOBAL_MEM const %(type_b)s *)(((GLOBAL_MEM char *)b)+offset_b);
          y_idx_data = (GLOBAL_MEM const %(type_y_idx)s *)(((GLOBAL_MEM char *)y_idx_data)+offset_y_idx);
          nll_data = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)nll_data)+offset_nll);
          sm_data = (GLOBAL_MEM %(type_x)s *)(((GLOBAL_MEM char *)sm_data)+offset_sm);
          am_data = (GLOBAL_MEM %(type_y_idx)s *)(((GLOBAL_MEM char *)am_data)+offset_am);
          for (ga_int row = GID_0; row < M; row += GDIM_0){
            GLOBAL_MEM const %(type_x)s* x = x_data + xs0 * row;
            GLOBAL_MEM %(type_x)s* sm = sm_data + sms0 * row;
            GA_DECL_SHARED_BODY(%(work_x)s, per_thread_values);
            LOCAL_MEM %(work_x)s row_max, sum, sum_inv;
            LOCAL_MEM ga_int row_max_threadIdx;
            %(work_x)s per_thread_row_max, per_thread_sum;
            ga_int per_thread_row_max_j;
            // COMPUTE ROW MAX AND ARGMAX
            // compute separate per-thread maximums and argmaxes
            per_thread_row_max = NAN;
            per_thread_row_max_j = 0;
            for (ga_int j = LID_0; j < N; j += LDIM_0)
            {
              %(work_x)s row_ij = %(load_x)s(x[j * xs1]) + %(load_b)s(b[j * bs0]);
              per_thread_row_max_j = (row_ij > per_thread_row_max) ? j : per_thread_row_max_j;
              per_thread_row_max = fmax%(f)s(row_ij, per_thread_row_max);
            }
            per_thread_values[LID_0] = per_thread_row_max;
            local_barrier();
            if (LID_0 == 0) {
              row_max = NAN;
              row_max_threadIdx = 0;
              for (ga_int j = 0; j < LDIM_0; j++)
              {
                %(work_x)s per_thread_max = per_thread_values[j];
                row_max_threadIdx = (per_thread_max > row_max) ? j : row_max_threadIdx;
                row_max = fmax%(f)s(per_thread_max, row_max);
              }
            }
            local_barrier();
            // The thread with the highest max writes out which of its
            // values was the winner.
            if (LID_0 == row_max_threadIdx) am_data[row * ams0] = per_thread_row_max_j;
            // COMPUTE SOFTMAX
            per_thread_sum = 0.0;
            for (ga_int j = LID_0; j < N; j += LDIM_0)
            {
              %(work_x)s row_ij = %(load_x)s(x[j * xs1]) + %(load_b)s(b[j * bs0]);
              %(work_x)s sm_ij = exp%(f)s(row_ij - row_max);
              per_thread_sum += sm_ij;
              sm[j * sms1] = %(write_x)s(sm_ij);
            }
            per_thread_values[LID_0] = per_thread_sum;
            local_barrier();
            if (LID_0 == 0) {
              sum = 0.0;
              for (ga_int j = 0; j < LDIM_0; j++) {
                sum += per_thread_values[j];
              }
              sum_inv = 1.0 / sum;
            }
            local_barrier();
            for (ga_int j = LID_0; j < N; j += LDIM_0) {
              sm[j * sms1] = %(write_x)s(%(load_x)s(sm[j * sms1]) * sum_inv);
            }
            if (LID_0 == 0) {
              const %(type_y_idx)s y_idx = (ga_int)y_idx_data[row * y_idxs0];
              if ((y_idx >= N || y_idx < 0)) {
                // y_idx is out of range: write a 0 NLL so the problem can be
                // detected on the host side.
                nll_data[row * nlls0] = %(write_x)s(0.0);
              } else {
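                // nll = -log(softmax(x + b)[y_idx]) rewritten without reading
                // the stored softmax value: -(x[y] + b[y]) + row_max + log(sum).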
                nll_data[row * nlls0] = %(write_x)s(
                   - %(load_x)s(x[y_idx * xs1])
                   - %(load_b)s(b[y_idx * bs0])
                   + row_max + log%(f)s(sum));
              }
            }
          }
        }
        """ % locals(), file=sio)

        return [Kernel(code=sio.getvalue(), name=kname, params=params,
                       flags=flags, objvar=k_var)]

    def c_code(self, node, nodename, inp, out, sub):
        itemsize_x = np.dtype(node.inputs[0].dtype).itemsize
        worksize_x = np.dtype(work_dtype(node.inputs[0].dtype)).itemsize
        itemsize_b = np.dtype(node.inputs[1].dtype).itemsize
        itemsize_y_idx = np.dtype(node.inputs[2].dtype).itemsize
        itemsize_nll = np.dtype(node.outputs[0].dtype).itemsize
        itemsize_sm = np.dtype(node.outputs[1].dtype).itemsize
        itemsize_am = np.dtype(node.outputs[2].dtype).itemsize
        x, b, y_idx = inp
        nll, sm, am = out
        fail = sub['fail']
        ctx = sub['params']
        k_var = "k_xent_sm_1hot_bias_%(nodename)s" % locals()
        err_check = """
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "gpuarray error: %(k_var)s: %%s.",
                             GpuKernel_error(&%(k_var)s, err));
                %(fail)s;
            }
        """ % locals()
        sio = StringIO()
        print("""
        if (PyGpuArray_DIMS(%(x)s)[0] !=
            PyGpuArray_DIMS(%(y_idx)s)[0])
        {
            PyErr_SetString(PyExc_ValueError,
                            "dimension mismatch in x,y_idx arguments");
            %(fail)s;
        }
        if (PyGpuArray_DIMS(%(x)s)[1] != PyGpuArray_DIMS(%(b)s)[0])
        {
            PyErr_SetString(PyExc_ValueError,
                            "dimension mismatch in x,b arguments");
            %(fail)s;
        }
        if (theano_prep_output(&%(nll)s, 1, PyGpuArray_DIMS(%(y_idx)s), %(x)s->ga.typecode, GA_C_ORDER, %(ctx)s)) %(fail)s
        if (theano_prep_output(&%(sm)s, 2, PyGpuArray_DIMS(%(x)s), %(x)s->ga.typecode, GA_C_ORDER, %(ctx)s)) %(fail)s
        if (theano_prep_output(&%(am)s, 1, PyGpuArray_DIMS(%(y_idx)s), %(y_idx)s->ga.typecode, GA_C_ORDER, %(ctx)s)) %(fail)s
        {
            size_t n_blocks = std::min(PyGpuArray_DIM(%(x)s, 0), (size_t)4096);
            size_t n_threads = std::min(PyGpuArray_DIM(%(x)s, 1), (size_t)256);
            size_t n_shared = n_threads * %(worksize_x)s;
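            // One work-dtype slot per thread; the kernel reuses this shared
            // buffer for both the per-row max and the per-row sum reductions.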
            // TODO: launch more threads per row and do parallel sum and max reductions
            int err = k_xent_sm_1hot_bias_call(
                1, &n_blocks, &n_threads, n_shared,
                PyGpuArray_DIMS(%(x)s)[0],
                PyGpuArray_DIMS(%(x)s)[1],
                %(x)s->ga.data, %(x)s->ga.offset,
                PyGpuArray_STRIDE(%(x)s, 0) / %(itemsize_x)s,
                PyGpuArray_STRIDE(%(x)s, 1) / %(itemsize_x)s,
                %(b)s->ga.data, %(b)s->ga.offset,
                PyGpuArray_STRIDE(%(b)s, 0) / %(itemsize_b)s,
                %(y_idx)s->ga.data, %(y_idx)s->ga.offset,
                PyGpuArray_STRIDE(%(y_idx)s, 0) / %(itemsize_y_idx)s,
                %(nll)s->ga.data, %(nll)s->ga.offset,
                PyGpuArray_STRIDE(%(nll)s, 0) / %(itemsize_nll)s,
                %(sm)s->ga.data, %(sm)s->ga.offset,
                PyGpuArray_STRIDE(%(sm)s, 0) / %(itemsize_sm)s,
                PyGpuArray_STRIDE(%(sm)s, 1) / %(itemsize_sm)s,
                %(am)s->ga.data, %(am)s->ga.offset,
                PyGpuArray_STRIDE(%(am)s, 0) / %(itemsize_am)s);
            %(err_check)s
        }
        """ % locals(), file=sio)
        return sio.getvalue()

    def c_code_cache_version(self):
        return (14,)


gpu_crossentropy_softmax_argmax_1hot_with_bias = GpuCrossentropySoftmaxArgmax1HotWithBias()


class GpuCrossentropySoftmax1HotWithBiasDx(GpuKernelBase, Op):
    """
    Implement CrossentropySoftmax1HotWithBiasDx on the gpu.

    Gradient wrt x of the CrossentropySoftmax1Hot Op.

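    A rough NumPy sketch of what the kernel computes (``dnll`` broadcasts
    over rows when it is a scalar or has a single element)::

        dx = dnll[:, None] * sm
        dx[np.arange(sm.shape[0]), y_idx] -= dnll
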
    """
    nin = 3
    nout = 1
    __props__ = ()
    _f16_ok = True

    def make_node(self, dnll, sm, y_idx):
        ctx_name = infer_context_name(dnll, sm, y_idx)
        dnll = as_gpuarray_variable(dnll, ctx_name)
        sm = as_gpuarray_variable(sm, ctx_name)
        y_idx = as_gpuarray_variable(y_idx, ctx_name)
        return Apply(self, [dnll, sm, y_idx], [sm.type()])

    def c_code_cache_version(self):
        return (14,)

    def c_headers(self):
        return ['<numpy_compat.h>', '<gpuarray/types.h>']

    def c_code(self, node, nodename, inp, out, sub):
        typecode_dx = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
        itemsize_dnll = np.dtype(node.inputs[0].dtype).itemsize
        itemsize_sm = np.dtype(node.inputs[1].dtype).itemsize
        itemsize_y_idx = np.dtype(node.inputs[2].dtype).itemsize
        itemsize_dx = np.dtype(node.outputs[0].dtype).itemsize
        dtype_dnll = node.inputs[0].dtype
        dtype_sm = node.inputs[1].dtype
        dtype_y_idx = node.inputs[2].dtype
        dtype_dx = node.outputs[0].dtype
        type_intp = gpuarray.dtype_to_ctype(np.intp)
        dnll, sm, y_idx = inp
        dx, = out
        fail = sub['fail']
        ctx = sub['params']
        k_var = "kCrossEntropySoftmax1HotWithBiasDx_" + nodename
        err_check = """
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError,
                             "gpuarray error: %(k_var)s: %%s.",
                             GpuKernel_error(&%(k_var)s, err));
                %(fail)s;
            }
        """ % locals()
        return """
        // Get `dnll.shape[0]` or set it to zero if `dnll` is a scalar.
        const ssize_t %(dnll)s_dims0 = (PyGpuArray_NDIM(%(dnll)s) > 0 ?
                                        PyGpuArray_DIMS(%(dnll)s)[0] :
                                        (ssize_t) 0);
        // Get `dnll.strides[0]` and set it to zero if `dnll` is a scalar
        // or a vector with just one element.
        const ssize_t %(dnll)s_strides0 = (%(dnll)s_dims0 > 1 ?
                                           PyGpuArray_STRIDES(%(dnll)s)[0] :
                                           (ssize_t) 0);
        if ((PyGpuArray_NDIM(%(dnll)s) > 1)
            || (PyGpuArray_NDIM(%(sm)s) != 2)
            || (PyGpuArray_NDIM(%(y_idx)s) != 1))
        {
            PyErr_SetString(PyExc_ValueError, "rank error");
            %(fail)s;
        }
        if (%(dnll)s_dims0 !=
            PyGpuArray_DIMS(%(sm)s)[0] && %(dnll)s_dims0 > 1)
        {
            PyErr_Format(PyExc_ValueError,
                         "dnll.shape[0] == %%i, but sm.shape[0] == %%i",
                         %(dnll)s_dims0,
                         PyGpuArray_DIMS(%(sm)s)[0]);
            %(fail)s;
        }
        if (%(dnll)s_dims0 !=
            PyGpuArray_DIMS(%(y_idx)s)[0] && %(dnll)s_dims0 > 1)
        {
            PyErr_SetString(PyExc_ValueError,
                            "dnll.shape[0] != y_idx.shape[0]");
            %(fail)s;
        }
        if (PyGpuArray_DIMS(%(sm)s)[0] !=
            PyGpuArray_DIMS(%(y_idx)s)[0])
        {
            PyErr_SetString(PyExc_ValueError,
                            "sm.shape[0] != y_idx.shape[0]");
            %(fail)s;
        }
        if ((NULL == %(dx)s)
            || (PyGpuArray_DIMS(%(dx)s)[0] !=
                PyGpuArray_DIMS(%(sm)s)[0])
            || (PyGpuArray_DIMS(%(dx)s)[1] !=
                PyGpuArray_DIMS(%(sm)s)[1]))
        {
            Py_XDECREF(%(dx)s);
            %(dx)s = pygpu_empty(2, PyGpuArray_DIMS(%(sm)s),
                                 %(typecode_dx)s, GA_C_ORDER,
                                 %(ctx)s, Py_None);
            if (!%(dx)s) {
                %(fail)s
            }
        }
        {
            size_t n_blocks[3] = {std::min(PyGpuArray_DIMS(%(dx)s)[0], (size_t)256), 1, 1};
            size_t threads_per_block[3] = {std::min(PyGpuArray_DIMS(%(dx)s)[1], (size_t)256), 1, 1};
            ssize_t stride_DNLL0 = %(dnll)s_strides0 / %(itemsize_dnll)s;
            ssize_t stride_SM0 = PyGpuArray_STRIDES(%(sm)s)[0] / %(itemsize_sm)s;
            ssize_t stride_SM1 = PyGpuArray_STRIDES(%(sm)s)[1] / %(itemsize_sm)s;
            ssize_t stride_YIDX0 = PyGpuArray_STRIDES(%(y_idx)s)[0] / %(itemsize_y_idx)s;
            ssize_t stride_DX0 = PyGpuArray_STRIDES(%(dx)s)[0] / %(itemsize_dx)s;
            ssize_t stride_DX1 = PyGpuArray_STRIDES(%(dx)s)[1] / %(itemsize_dx)s;
            void *kernel_params[] = {
                (void *)&PyGpuArray_DIMS(%(dx)s)[0],
                (void *)&PyGpuArray_DIMS(%(dx)s)[1],
                (void *)%(dnll)s->ga.data, (void *)&%(dnll)s->ga.offset,
                (void *)&stride_DNLL0,
                (void *)%(sm)s->ga.data, (void *)&%(sm)s->ga.offset,
                (void *)&stride_SM0, (void *)&stride_SM1,
                (void *)%(y_idx)s->ga.data, (void *)&%(y_idx)s->ga.offset,
                (void *)&stride_YIDX0,
                (void *)%(dx)s->ga.data, (void *)&%(dx)s->ga.offset,
                (void *)&stride_DX0, (void *)&stride_DX1};
            int err = GpuKernel_call(&%(k_var)s, 3, n_blocks, threads_per_block, 0, kernel_params);
            %(err_check)s
        }
        assert(%(dx)s);
        """ % locals()

    def gpu_kernels(self, node, nodename):
        dtype_dnll = node.inputs[0].dtype
        dtype_sm = node.inputs[1].dtype
        dtype_y_idx = node.inputs[2].dtype
        dtype_dx = node.outputs[0].dtype
        work_dnll = work_dtype(dtype_dnll)
        load_dnll = load_w(dtype_dnll)
        load_sm = load_w(dtype_sm)
        write_dx = write_w(dtype_dx)
        flags = Kernel.get_flags(dtype_dnll, dtype_sm, dtype_y_idx, dtype_dx)
        wtype_dnll = gpuarray.dtype_to_ctype(work_dnll)
        type_dnll = gpuarray.dtype_to_ctype(dtype_dnll)
        type_sm = gpuarray.dtype_to_ctype(dtype_sm)
        type_y_idx = gpuarray.dtype_to_ctype(dtype_y_idx)
        type_dx = gpuarray.dtype_to_ctype(dtype_dx)
        kname = "kCrossEntropySoftmax1HotWithBiasDx"
        k_var = "kCrossEntropySoftmax1HotWithBiasDx_" + nodename
        params = [
            gpuarray.SIZE, gpuarray.SIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
        ]
        sio = StringIO()
        print("""#include "cluda.h"

        KERNEL void %(kname)s(
           const ga_size N, const ga_size K,
           GLOBAL_MEM const %(type_dnll)s* dnll, const ga_size offset_dnll, const ga_ssize dnll_s0,
           GLOBAL_MEM const %(type_sm)s* sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1,
           GLOBAL_MEM const %(type_y_idx)s* y_idx, const ga_size offset_y_idx, const ga_ssize y_idx_s0,
           GLOBAL_MEM %(type_dx)s* dx, const ga_size offset_dx, const ga_ssize dx_s0, const ga_ssize dx_s1)
        {
            dnll = (GLOBAL_MEM const %(type_dnll)s *)(((GLOBAL_MEM char *)dnll)+offset_dnll);
            sm = (GLOBAL_MEM const %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
            y_idx = (GLOBAL_MEM const %(type_y_idx)s *)(((GLOBAL_MEM char *)y_idx)+offset_y_idx);
            dx = (GLOBAL_MEM %(type_dx)s *)(((GLOBAL_MEM char *)dx)+offset_dx);
            for (ga_int i = GID_0; i < N; i += GDIM_0)
            {
                %(wtype_dnll)s dnll_i = %(load_dnll)s(dnll[i * dnll_s0]);
                %(type_y_idx)s y_i = y_idx[i * y_idx_s0];
                for (ga_int j = LID_0; j < K; j += LDIM_0)
                {
                    if (y_i == j)
                    {
                        dx[i * dx_s0 + j * dx_s1] =
                            %(write_dx)s(dnll_i *
                              (%(load_sm)s(sm[i * sm_s0 + j * sm_s1]) - 1.0));
                    }
                    else
                    {
                        dx[i * dx_s0 + j * dx_s1] =
                            %(write_dx)s(dnll_i *
                              %(load_sm)s(sm[i * sm_s0 + j * sm_s1]));
                    }
                }
            }
        }
        """ % locals(), file=sio)
        return [Kernel(code=sio.getvalue(), name=kname, params=params,
                       flags=flags, objvar=k_var)]


gpu_crossentropy_softmax_1hot_with_bias_dx = GpuCrossentropySoftmax1HotWithBiasDx()


class GpuSoftmax(GpuKernelBase, Op):
    """
    Implement Softmax on the gpu.

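    Rough NumPy equivalent of the kernels below (they subtract the row max
    before exponentiating, for numerical stability)::

        e = np.exp(x - x.max(axis=1, keepdims=True))
        sm = e / e.sum(axis=1, keepdims=True)
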
    """
    __props__ = ()
    _f16_ok = True

    def make_node(self, x):
        x = as_gpuarray_variable(x, infer_context_name(x))
        return Apply(self, [x], [x.type()])

    def infer_shape(self, node, shape):
        return shape

    def c_code_cache_version(self):
        return (17,)

    def c_headers(self):
        return ['<numpy_compat.h>', '<gpuarray/types.h>']

    def c_code(self, node, nodename, inp, out, sub):
        dtype_x = node.inputs[0].dtype
        work_x = work_dtype(dtype_x)
        dtype_z = node.outputs[0].dtype
        itemsize_x = np.dtype(dtype_x).itemsize
        itemsize_z = np.dtype(dtype_z).itemsize
        typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
        x, = inp
        z, = out
        fail = sub['fail']
        ctx = sub['params']
        err_check = """
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError, fmt_str, msg);
                %(fail)s;
            }
        """ % locals()
        return """
        if (PyGpuArray_NDIM(%(x)s) != 2)
        {
            PyErr_SetString(PyExc_ValueError, "rank error");
            %(fail)s;
        }
        if ((NULL == %(z)s) ||
            (PyGpuArray_DIMS(%(z)s)[0] !=
             PyGpuArray_DIMS(%(x)s)[0]) ||
            (PyGpuArray_DIMS(%(z)s)[1] !=
             PyGpuArray_DIMS(%(x)s)[1]))
        {
            Py_XDECREF(%(z)s);
            %(z)s = pygpu_empty(2, PyGpuArray_DIMS(%(x)s),
                                %(typecode)s, GA_C_ORDER,
                                %(ctx)s, Py_None);
            if (!%(z)s) {
                %(fail)s
            }
        }
        {
            size_t n_blocks[3] = {std::min(PyGpuArray_DIMS(%(x)s)[0], (size_t)(32 * 1024)), 1, 1};
            // TODO: detect the maximum number of threads per block.
            size_t threads_per_block[3] = {std::min(PyGpuArray_DIMS(%(x)s)[1], (size_t)256), 1, 1}; // TODO: Read GA_CTX_PROP_MAXLSIZE0
            size_t shmem_sz = PyGpuArray_DIMS(%(x)s)[1] *
                                     2 * sizeof(npy_%(work_x)s);
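            // Enough shared memory for kSoftmax to cache the row twice
            // (buf and buf2), each N work-dtype elements wide.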
            ssize_t stride_X0 = PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s;
            ssize_t stride_X1 = PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s;
            ssize_t stride_Z0 = PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s;
            ssize_t stride_Z1 = PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s;
            const char *fmt_str, *msg;
            void *kernel_params[] = {
                (void *)&PyGpuArray_DIMS(%(x)s)[0],
                (void *)&PyGpuArray_DIMS(%(x)s)[1],
                (void *)%(x)s->ga.data, (void *)&%(x)s->ga.offset,
                (void *)&stride_X0, (void *)&stride_X1,
                (void *)%(z)s->ga.data, (void *)&%(z)s->ga.offset,
                (void *)&stride_Z0, (void *)&stride_Z1};
            int err = GA_NO_ERROR;
            if (PyGpuArray_DIMS(%(x)s)[0] > 0)
            {
              // These thresholds are based on older GPUs so that the kernels
              // stay compatible with as much hardware as possible.
              // TODO: read the information from the card.
              if(shmem_sz < (32 * 1024 - 500)){
                err = GpuKernel_call(&kSoftmax_%(nodename)s, 3,
                                     n_blocks, threads_per_block, shmem_sz,
                                     kernel_params);
                fmt_str = "gpuarray error: kSoftmax_%(nodename)s: %%s";
                msg = GpuKernel_error(&kSoftmax_%(nodename)s, err);
              }else{
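                // The row does not fit in shared memory: fall back to the
                // kernel that keeps one value per thread and re-reads x.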
                err = GpuKernel_call(&kSoftmax_fixed_shared%(nodename)s, 3,
                                     n_blocks, threads_per_block,
                                     threads_per_block[0] * sizeof(npy_%(work_x)s),
                                     kernel_params);
                fmt_str = "gpuarray error: kSoftmax_fixed_shared%(nodename)s: %%s";
                msg = GpuKernel_error(&kSoftmax_fixed_shared%(nodename)s, err);
              }
              %(err_check)s
            }
        }
        assert(%(z)s);
        """ % locals()

    def gpu_kernels(self, node, nodename):
        dtype_x = node.inputs[0].dtype
        dtype_sm = node.outputs[0].dtype
        load_x = load_w(dtype_x)
        write_sm = write_w(node.outputs[0].dtype)
        work_sm = work_dtype(dtype_sm)
        flags = Kernel.get_flags(dtype_x, dtype_sm)
        type_x = gpuarray.dtype_to_ctype(dtype_x)
        type_sm = gpuarray.dtype_to_ctype(dtype_sm)
        type_acc = gpuarray.dtype_to_ctype(work_sm)

        ctype = gpuarray.dtype_to_ctype(work_sm)

        params = [
            gpuarray.SIZE, gpuarray.SIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE
        ]
        kernels = []
        kname = "kSoftmax"
        k_var = "kSoftmax_" + nodename
        code = """#include "cluda.h"

        KERNEL void %(kname)s (const ga_size M, const ga_size N,
                               GLOBAL_MEM const %(type_x)s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
                               GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
        {
            GA_DECL_SHARED_BODY(%(type_acc)s, buf);
            LOCAL_MEM_ARG %(type_acc)s * buf2 = buf + N;
            x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
            sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
            for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0) {
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0) {
                    buf[tx] = %(load_x)s(x[blockIDX * sx0 + tx * sx1]);
                    buf2[tx] = buf[tx];
                }
                local_barrier();
                {
                    // This function trashes buf[1..GA_WARP_SIZE],
                    // leaving the reduction result in buf[0].
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
                        {
                            buf[LID_0] = max(buf[LID_0], buf[i]);
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
                        local_barrier();
                    }
                }
                %(ctype)s row_max = buf[0];
                local_barrier();
                for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
                    buf[__i] = exp(buf2[__i] - row_max);
                    buf2[__i] = buf[__i];
                }
                local_barrier();
                {
                    // This function trashes buf[1..GA_WARP_SIZE],
                    // leaving the reduction result in buf[0].
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
                        {
                            buf[LID_0] = buf[LID_0] + buf[i];
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
                        local_barrier();
                    }
                }
                %(ctype)s row_sum = buf[0];
                local_barrier();
                for(ga_int __i=LID_0; __i<N; __i+=LDIM_0) {
                    buf[__i] = buf2[__i] / row_sum;
                }
                local_barrier();
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0) {
                    sm[blockIDX * sm_s0 + tx * sm_s1] = %(write_sm)s(buf[tx]);
                }
                local_barrier();
            }
        }
        """ % locals()
        kernels.append(Kernel(code=code, name=kname, params=params,
                              flags=flags, objvar=k_var))
        kname = "kSoftmax_fixed_shared"
        k_var = "kSoftmax_fixed_shared" + nodename
        code = """#include "cluda.h"

        KERNEL void %(kname)s (const ga_size M, const ga_size N,
                               GLOBAL_MEM const %(type_x)s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
                               GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
        {
            GA_DECL_SHARED_BODY(%(type_acc)s, buf);
            x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
            sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
            for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
                GLOBAL_MEM const %(type_x)s *x_ptr = &x[blockIDX * sx0];
                GLOBAL_MEM %(type_sm)s *sm_ptr = &sm[blockIDX * sm_s0];
                {
                    // This function trashes buf[1..n_threads],
                    // leaving the reduction result in buf[0].
                    %(ctype)s red = %(load_x)s(x_ptr[LID_0 * sx1]);
                    #pragma unroll 16
                    for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
                        red = max(red, %(load_x)s(x_ptr[i * sx1]));
                    }
                    buf[LID_0] = red;
                    local_barrier();
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
                            buf[LID_0] = max(buf[LID_0], buf[i]);
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
                        local_barrier();
                    }
                }
                %(ctype)s row_max = buf[0];
                local_barrier();
                {
                    // This function trashes buf[1..n_threads],
                    // leaving the reduction result in buf[0].
                    %(ctype)s red = exp(%(load_x)s(x_ptr[LID_0 * sx1]) - row_max);
                    #pragma unroll 16
                    for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
                        red = red + exp(%(load_x)s(x_ptr[i * sx1]) - row_max);
                    }
                    buf[LID_0] = red;
                    local_barrier();
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
                            buf[LID_0] = buf[LID_0] + buf[i];
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
                        local_barrier();
                    }
                }
                %(ctype)s row_sum = buf[0];
                local_barrier();
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
                    sm_ptr[tx * sm_s1] = %(write_sm)s(exp(%(load_x)s(x_ptr[tx * sx1]) - row_max) / row_sum);
                }
                local_barrier();
            }
        }
        """ % locals()
        kernels.append(Kernel(code=code, name=kname, params=params,
                              flags=flags, objvar=k_var))
        return kernels


gpu_softmax = GpuSoftmax()


class GpuSoftmaxWithBias(GpuKernelBase, Op):
    """
    Implement SoftmaxWithBias on the gpu.

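    Rough NumPy equivalent of the kernels below (the row max of ``x + b``
    is subtracted before exponentiating, for numerical stability)::

        z = x + b
        e = np.exp(z - z.max(axis=1, keepdims=True))
        sm = e / e.sum(axis=1, keepdims=True)
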
    """
    nin = 2
    nout = 1
    __props__ = ()
    _f16_ok = True

    def make_node(self, x, b):
        ctx_name = infer_context_name(x, b)
        x = as_gpuarray_variable(x, ctx_name)
        b = as_gpuarray_variable(b, ctx_name)
        return Apply(self, [x, b], [x.type()])

    def infer_shape(self, node, shape):
        return [shape[0]]

    def c_code_cache_version(self):
        return (16,)

    def c_headers(self):
        return ['<numpy_compat.h>', '<gpuarray/types.h>']

    def c_code(self, node, nodename, inp, out, sub):
        dtype_x = node.inputs[0].dtype
        dtype_b = node.inputs[1].dtype
        dtype_z = node.outputs[0].dtype
        work_x = work_dtype(dtype_x)
        itemsize_x = np.dtype(dtype_x).itemsize
        itemsize_b = np.dtype(dtype_b).itemsize
        itemsize_z = np.dtype(dtype_z).itemsize
        typecode = pygpu.gpuarray.dtype_to_typecode(node.outputs[0].dtype)
        x, b = inp
        z, = out
        fail = sub['fail']
        ctx = sub['params']
        err_check = """
            if (err != GA_NO_ERROR) {
                PyErr_Format(PyExc_RuntimeError, fmt_str, msg);
                %(fail)s;
            }
        """ % locals()
        return """
        if (PyGpuArray_NDIM(%(x)s) != 2)
        {
            PyErr_SetString(PyExc_ValueError, "rank error input");
            %(fail)s;
        }
        if (PyGpuArray_NDIM(%(b)s) != 1)
        {
            PyErr_SetString(PyExc_ValueError, "rank error for the bias");
            %(fail)s;
        }
        if ((PyGpuArray_DIMS(%(x)s)[1] !=
            PyGpuArray_DIMS(%(b)s)[0]))
        {
            PyErr_Format(PyExc_ValueError,
                         "number of columns in x (%%ld)"
                         " does not match length of b (%%ld)",
                         (long int)PyGpuArray_DIMS(%(x)s)[1],
                         (long int)PyGpuArray_DIMS(%(b)s)[0]);
            %(fail)s;
        }
        if ((NULL == %(z)s)
            || (PyGpuArray_DIMS(%(z)s)[0] !=
                PyGpuArray_DIMS(%(x)s)[0])
            || (PyGpuArray_DIMS(%(z)s)[1] !=
                PyGpuArray_DIMS(%(x)s)[1]))
        {
            Py_XDECREF(%(z)s);
            %(z)s = pygpu_empty(2, PyGpuArray_DIMS(%(x)s),
                                %(typecode)s, GA_C_ORDER,
                                %(ctx)s, Py_None);
            if (!%(z)s) {
                %(fail)s
            }
        }
        {
            size_t n_blocks[3] = {std::min(PyGpuArray_DIMS(%(x)s)[0], (size_t)(32*1024)), 1, 1};
            // TODO: detect the maximum number of threads per block.
            size_t threads_per_block[3] = {std::min(PyGpuArray_DIMS(%(x)s)[1], (size_t)256), 1, 1}; // TODO: Read GA_CTX_PROP_MAXLSIZE0
            size_t shmem_sz = PyGpuArray_DIMS(%(x)s)[1] *
                                     2 * sizeof(npy_%(work_x)s);
            ssize_t stride_X0 = PyGpuArray_STRIDES(%(x)s)[0] / %(itemsize_x)s;
            ssize_t stride_X1 = PyGpuArray_STRIDES(%(x)s)[1] / %(itemsize_x)s;
            ssize_t stride_B0 = PyGpuArray_STRIDES(%(b)s)[0] / %(itemsize_b)s;
            ssize_t stride_Z0 = PyGpuArray_STRIDES(%(z)s)[0] / %(itemsize_z)s;
            ssize_t stride_Z1 = PyGpuArray_STRIDES(%(z)s)[1] / %(itemsize_z)s;
            const char *fmt_str, *msg;
            void *kernel_params[] = {
                (void *)&PyGpuArray_DIMS(%(x)s)[0],
                (void *)&PyGpuArray_DIMS(%(x)s)[1],
                (void *)%(x)s->ga.data, (void *)&%(x)s->ga.offset,
                (void *)&stride_X0, (void *)&stride_X1,
                (void *)%(b)s->ga.data, (void *)&%(b)s->ga.offset,
                (void *)&stride_B0,
                (void *)%(z)s->ga.data, (void *)&%(z)s->ga.offset,
                (void *)&stride_Z0, (void *)&stride_Z1};
            int err = GA_NO_ERROR;
            if (PyGpuArray_DIMS(%(x)s)[0] > 0)
            {
              if(shmem_sz < (32 * 1024 - 500)){
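                // Same strategy as GpuSoftmax: use the row-caching kernel when
                // two row-sized work buffers fit in shared memory, otherwise
                // fall back to the fixed-shared-memory kernel below.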
                err = GpuKernel_call(&kSoftmaxWithBias_%(nodename)s, 3,
                                     n_blocks, threads_per_block, shmem_sz,
                                     kernel_params);
                fmt_str = "gpuarray error: kSoftmaxWithBias_%(nodename)s: %%s";
                msg = GpuKernel_error(&kSoftmaxWithBias_%(nodename)s, err);
              }else{
                err = GpuKernel_call(&kSoftmaxWithBias_fixed_shared%(nodename)s,
                                     3, n_blocks, threads_per_block,
                                     threads_per_block[0] * sizeof(npy_%(work_x)s),
                                     kernel_params);
                fmt_str = "gpuarray error: kSoftmaxWithBias_fixed_shared%(nodename)s: %%s";
                msg = GpuKernel_error(&kSoftmaxWithBias_fixed_shared%(nodename)s, err);
              }
              %(err_check)s
            }
        }
        assert(%(z)s);
        """ % locals()

    def gpu_kernels(self, node, nodename):
        dtype_x = node.inputs[0].dtype
        dtype_b = node.inputs[1].dtype
        dtype_sm = node.outputs[0].dtype
        load_x = load_w(node.inputs[0].dtype)
        load_b = load_w(node.inputs[1].dtype)
        write_sm = write_w(node.outputs[0].dtype)
        work_sm = work_dtype(node.outputs[0].dtype)
        flags = Kernel.get_flags(dtype_x, dtype_b, dtype_sm)
        type_x = gpuarray.dtype_to_ctype(dtype_x)
        type_b = gpuarray.dtype_to_ctype(dtype_b)
        type_sm = gpuarray.dtype_to_ctype(dtype_sm)
        type_acc = gpuarray.dtype_to_ctype(work_sm)

        ctype = gpuarray.dtype_to_ctype(work_sm)

        params = [
            gpuarray.SIZE, gpuarray.SIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE,
            gpuarray.GpuArray, gpuarray.SIZE, gpuarray.SSIZE, gpuarray.SSIZE,
        ]
        kernels = []
        kname = "kSoftmaxWithBias"
        k_var = "kSoftmaxWithBias_" + nodename
        code = """#include "cluda.h"

        KERNEL void %(kname)s (const ga_size M, const ga_size N,
                       GLOBAL_MEM const %(type_x)s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
                       GLOBAL_MEM const %(type_b)s * b, const ga_size offset_b, const ga_ssize sb0,
                       GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
        {
            GA_DECL_SHARED_BODY(%(type_acc)s, buf);
            LOCAL_MEM_ARG %(type_acc)s * buf2 = buf + N;
            x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
            b = (GLOBAL_MEM const %(type_b)s *)(((GLOBAL_MEM char *)b)+offset_b);
            sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
            for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
                    buf[tx] = %(load_x)s(x[blockIDX * sx0 + tx * sx1]);
                    buf[tx] += %(load_b)s(b[tx * sb0]);
                    buf2[tx] = buf[tx];
                }
                local_barrier();
                {
                    // This function trashes buf[1..GA_WARP_SIZE],
                    // leaving the reduction result in buf[0].
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
                        {
                            buf[LID_0] = max(buf[LID_0], buf[i]);
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
                        local_barrier();
                    }
                }
                %(ctype)s row_max = buf[0];
                local_barrier();
                for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
                    buf[__i] = exp(buf2[__i] - row_max);
                    buf2[__i] = buf[__i];
                }
                local_barrier();
                {
                    // This function trashes buf[1..GA_WARP_SIZE],
                    // leaving the reduction result in buf[0].
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < N; i += GA_WARP_SIZE)
                        {
                            buf[LID_0] = buf[LID_0] + buf[i];
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
                        local_barrier();
                    }
                }
                %(ctype)s row_sum = buf[0];
                local_barrier();
                for(ga_int __i=LID_0; __i<N; __i+=LDIM_0){
                    buf[__i] = buf2[__i] / row_sum;
                }
                local_barrier();
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
                    sm[blockIDX * sm_s0 + tx * sm_s1] = %(write_sm)s(buf[tx]);
                }
                local_barrier();
            }
        }
        """ % locals()
        kernels.append(Kernel(code=code, name=kname, params=params,
                              flags=flags, objvar=k_var))
        kname = "kSoftmaxWithBias_fixed_shared"
        k_var = "kSoftmaxWithBias_fixed_shared" + nodename
        code = """#include "cluda.h"

        KERNEL void %(kname)s (const ga_size M, const ga_size N,
                       GLOBAL_MEM const %(type_x)s * x, const ga_size offset_x, const ga_ssize sx0, const ga_ssize sx1,
                       GLOBAL_MEM const %(type_b)s * b, const ga_size offset_b, const ga_ssize sb0,
                       GLOBAL_MEM %(type_sm)s * sm, const ga_size offset_sm, const ga_ssize sm_s0, const ga_ssize sm_s1 GA_DECL_SHARED_PARAM(%(type_acc)s, buf))
        {
            GA_DECL_SHARED_BODY(%(type_acc)s, buf);
            x = (GLOBAL_MEM const %(type_x)s *)(((GLOBAL_MEM char *)x)+offset_x);
            b = (GLOBAL_MEM const %(type_b)s *)(((GLOBAL_MEM char *)b)+offset_b);
            sm = (GLOBAL_MEM %(type_sm)s *)(((GLOBAL_MEM char *)sm)+offset_sm);
            for (ga_int blockIDX = GID_0; blockIDX < M; blockIDX += GDIM_0){
                GLOBAL_MEM const %(type_x)s *x_ptr = &x[blockIDX * sx0];
                GLOBAL_MEM %(type_sm)s *sm_ptr = &sm[blockIDX * sm_s0];
                {
                    // This function trashes buf[1..n_threads],
                    // leaving the reduction result in buf[0].
                    %(ctype)s red = %(load_x)s(x_ptr[LID_0 * sx1]) + %(load_b)s(b[LID_0 * sb0]);
                    #pragma unroll 16
                    for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
                        red = max(red, %(load_x)s(x_ptr[i * sx1]) + %(load_b)s(b[i * sb0]));
                    }
                    buf[LID_0] = red;
                    local_barrier();
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
                            buf[LID_0] = max(buf[LID_0], buf[i]);
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = max(buf[LID_0], buf[LID_0+_n]);
                        local_barrier();
                    }
                }
                %(ctype)s row_max = buf[0];
                local_barrier();
                {
                    // This function trashes buf[1..n_threads],
                    // leaving the reduction result in buf[0].
                    %(ctype)s red = exp(%(load_x)s(x_ptr[LID_0 * sx1]) + %(load_b)s(b[LID_0 * sb0]) - row_max);
                    #pragma unroll 16
                    for (ga_int i = LID_0 + LDIM_0; i<N; i += LDIM_0) {
                        red = red + exp(%(load_x)s(x_ptr[i * sx1]) + %(load_b)s(b[i * sb0]) - row_max);
                    }
                    buf[LID_0] = red;
                    local_barrier();
                    if (LID_0 < GA_WARP_SIZE) {
                        for (ga_int i = LID_0 + GA_WARP_SIZE; i < LDIM_0; i += GA_WARP_SIZE) {
                            buf[LID_0] = buf[LID_0] + buf[i];
                        }
                    }
                    local_barrier();
                    // reduce so that thread 0 (LID_0 == 0) holds the full reduction
                    for (ga_uint _n = GA_WARP_SIZE / 2; _n > 0; _n /= 2) {
                        if (LID_0 < _n && LID_0 + _n < N)
                            buf[LID_0] = buf[LID_0] + buf[LID_0+_n];
                        local_barrier();
                    }
                }
                %(ctype)s row_sum = buf[0];
                local_barrier();
                for (ga_int tx = LID_0; tx< N; tx += LDIM_0){
                    sm_ptr[tx * sm_s1] = %(write_sm)s(exp(%(load_x)s(x_ptr[tx * sx1]) + %(load_b)s(b[tx * sb0]) - row_max) / row_sum);
                }
                local_barrier();
            }
        }
        """ % locals()
        kernels.append(Kernel(code=code, name=kname, params=params,
                              flags=flags, objvar=k_var))
        return kernels


gpu_softmax_with_bias = GpuSoftmaxWithBias()
