#ifndef STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
#include <string>

namespace stan {
namespace math {
namespace opencl_kernels {

// \cond
static const std::string categorical_logit_glm_kernel_code = STRINGIFY(
    // \endcond
    /** \ingroup opencl_kernels
     * GPU implementation of the Generalized Linear Model (GLM)
     * with categorical distribution and logit (softmax) link function.
     *
     * Must be run with at least N_instances threads and local size equal to
     * LOCAL_SIZE_.
     * @param[out] logp_global partially summed log probability (1 value per
     * work group)
     * @param[out] exp_lin_global exponentiation of the sum of alpha and the
     * matrix product of x and beta
     * @param[out] inv_sum_exp_lin_global inverse of the rowwise sum of \c
     * exp_lin_global
     * @param[out] neg_softmax_lin_global negated softmax of the sum of alpha
     * and the matrix product of x and beta
     * @param[out] alpha_derivative derivative wrt alpha
     * @param[in] y_global a scalar or vector of classes
     * @param[in] x_beta_global product of the design matrix and weight matrix
     * @param[in] alpha_global intercept (in log odds)
     * @param N_instances number of instances
     * @param N_attributes number of attributes
     * @param N_classes number of classes
     * @param is_y_vector 0 or 1 - whether y is a vector (alternatively it is
     * a scalar)
     * @param need_alpha_derivative interpreted as boolean - whether
     * alpha_derivative needs to be computed
     * @param need_neg_softmax_lin_global interpreted as boolean - whether
     * neg_softmax_lin_global needs to be computed
     */
    __kernel void categorical_logit_glm(
        __global double* logp_global, __global double* exp_lin_global,
        __global double* inv_sum_exp_lin_global,
        __global double* neg_softmax_lin_global,
        __global double* alpha_derivative, const __global int* y_global,
        const __global double* x_beta_global,
        const __global double* alpha_global, const int N_instances,
        const int N_attributes, const int N_classes, const int is_y_vector,
        const int need_alpha_derivative,
        const int need_neg_softmax_lin_global) {
      const int gid = get_global_id(0);
      const int lid = get_local_id(0);
      const int lsize = get_local_size(0);
      const int wg_id = get_group_id(0);
      const int ngroups = get_num_groups(0);

      __local double local_storage[LOCAL_SIZE_];

      double logp = 0;
      double inv_sum_exp_lin;
      int class_idx = -1;
      // Most calculations only happen for relevant data within the next if.
      // Exceptions are reductions between threads that need barriers.
      if (gid < N_instances) {
        // log-sum-exp trick: find the maximum linear predictor first so the
        // exponentials below cannot overflow
        double lin_max = -INFINITY;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          if (lin > lin_max) {
            lin_max = lin;
          }
        }
        double sum_exp_lin = 0;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          double exp_lin = exp(lin - lin_max);
          sum_exp_lin += exp_lin;
          exp_lin_global[i * N_instances + gid] = exp_lin;
        }
        inv_sum_exp_lin = 1 / sum_exp_lin;
        inv_sum_exp_lin_global[gid] = inv_sum_exp_lin;

        class_idx = y_global[gid * is_y_vector] - 1;
        if (class_idx < 0 || class_idx >= N_classes) {
          logp = NAN;
        } else {
          logp = log(inv_sum_exp_lin) - lin_max
                 + x_beta_global[class_idx * N_instances + gid]
                 + alpha_global[class_idx];
        }
      }
      // make the writes to exp_lin_global visible to the reads below
      barrier(CLK_GLOBAL_MEM_FENCE);
      if (need_alpha_derivative || need_neg_softmax_lin_global) {
        for (int i = 0; i < N_classes; i++) {
          double neg_softmax_lin = 0;
          if (gid < N_instances) {
            int idx = i * N_instances + gid;
            neg_softmax_lin = -exp_lin_global[idx] * inv_sum_exp_lin;
            if (need_neg_softmax_lin_global) {
              neg_softmax_lin_global[idx] = neg_softmax_lin;
            }
          }
          if (need_alpha_derivative) {
            // tree reduction over the work group (j avoids shadowing the
            // class index i, which is still needed below)
            local_storage[lid] = neg_softmax_lin + (class_idx == i);
            barrier(CLK_LOCAL_MEM_FENCE);
            for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
                 step /= REDUCTION_STEP_SIZE) {
              if (lid < step) {
                for (int j = 1; j < REDUCTION_STEP_SIZE; j++) {
                  local_storage[lid] += local_storage[lid + step * j];
                }
              }
              barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (lid == 0) {
              alpha_derivative[i + wg_id * N_classes] = local_storage[0];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
          }
        }
      }
      // Sum logp, calculated by different threads.
      // Since we can't sum between different work groups, we emit one number
      // per work group. These must be summed on the CPU for the final result.
      local_storage[lid] = logp;
      barrier(CLK_LOCAL_MEM_FENCE);
      for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
           step /= REDUCTION_STEP_SIZE) {
        if (lid < step) {
          for (int i = 1; i < REDUCTION_STEP_SIZE; i++) {
            local_storage[lid] += local_storage[lid + step * i];
          }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
      }
      if (lid == 0) {
        logp_global[wg_id] = local_storage[0];
      }
    }
    // \cond
);
// \endcond
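
/* Note on the in-group reductions above: they form a REDUCTION_STEP_SIZE-ary
 * tree. With LOCAL_SIZE_ = 64 and REDUCTION_STEP_SIZE = 4 (set below), the
 * step sizes are 16, 4, 1, so each thread with lid < step folds in 3
 * neighboring elements per level, and after log(64)/log(4) = 3 levels
 * local_storage[0] holds the sum of all 64 values in the work group.
 */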
145 
146 /** \ingroup opencl_kernels
147  * See the docs for \link kernels/categorical_logit_glm_lpmf.hpp
148  * categorical_logit_glm() \endlink
149  */
150 const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, out_buffer,
151                 in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
152     categorical_logit_glm("categorical_logit_glm",
153                           {categorical_logit_glm_kernel_code},
154                           {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});
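
/* A sketch of a typical launch (hypothetical buffer names; assumes
 * kernel_cl's usual call signature of global NDRange, local NDRange, then
 * the kernel arguments, and that the global size is N_instances rounded up
 * to a multiple of LOCAL_SIZE_):
 *
 *   opencl_kernels::categorical_logit_glm(
 *       cl::NDRange(global_size), cl::NDRange(64), logp_cl, exp_lin_cl,
 *       inv_sum_exp_lin_cl, neg_softmax_lin_cl, alpha_derivative_cl, y_cl,
 *       x_beta_cl, alpha_cl, N_instances, N_attributes, N_classes,
 *       is_y_vector, need_alpha_derivative, need_neg_softmax_lin);
 *
 * The per-work-group partial sums in logp_cl must then be summed on the
 * host, e.g. via sum(from_matrix_cl(logp_cl)), to obtain the final log
 * probability.
 */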

// \cond
static const std::string categorical_logit_glm_beta_derivative_kernel_code
    = STRINGIFY(
        // \endcond
        /** \ingroup opencl_kernels
         * Calculates the derivative wrt beta.
         *
         * Must be run with a global size of local_size * N_attributes.
         * @param[in,out] beta_derivative derivative wrt beta
         * @param temp temporary workspace of size global_size * N_classes
         * @param[in] y a scalar or vector of classes
         * @param[in] x design matrix
         * @param N_instances number of instances
         * @param N_attributes number of attributes
         * @param N_classes number of classes
         * @param is_y_vector 0 or 1 - whether y is a vector (alternatively it
         * is a scalar)
         */
        __kernel void categorical_logit_glm_beta_derivative(
            __global double* beta_derivative, __global double* temp,
            const __global int* y, const __global double* x,
            const int N_instances, const int N_attributes, const int N_classes,
            const int is_y_vector) {
          const int gid = get_global_id(0);
          const int lid = get_local_id(0);
          const int lsize = get_local_size(0);
          const int wg_id = get_group_id(0);

          // each thread zeroes its own slice of the workspace
          for (int i = 0; i < N_classes; i++) {
            temp[gid * N_classes + i] = 0;
          }
          // accumulate attribute values by class; work group wg_id handles
          // attribute wg_id
          for (int i = lid; i < N_instances; i += lsize) {
            int pos = y[i * is_y_vector] - 1;
            temp[gid * N_classes + pos] += x[wg_id * N_instances + i];
          }
          barrier(CLK_GLOBAL_MEM_FENCE);
          // sum the per-thread partials and add them to the beta derivative
          for (int i = lid; i < N_classes; i += lsize) {
            double res = 0;
            for (int j = 0; j < lsize; j++) {
              res += temp[(wg_id * lsize + j) * N_classes + i];
            }
            beta_derivative[i * N_attributes + wg_id] += res;
          }
        }
        // \cond
    );  // NOLINT
// \endcond

/** \ingroup opencl_kernels
 * See the docs for \link kernels/categorical_logit_glm_lpmf.hpp
 * categorical_logit_glm_beta_derivative() \endlink
 */
const kernel_cl<in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int,
                int, int>
    categorical_logit_glm_beta_derivative(
        "categorical_logit_glm_beta_derivative",
        {categorical_logit_glm_beta_derivative_kernel_code});
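
/* A sketch of a typical launch (hypothetical buffer names; assumes the same
 * kernel_cl call convention as above and that temp_cl holds
 * local_size * N_attributes * N_classes doubles):
 *
 *   opencl_kernels::categorical_logit_glm_beta_derivative(
 *       cl::NDRange(local_size * N_attributes), cl::NDRange(local_size),
 *       beta_derivative_cl, temp_cl, y_cl, x_cl, N_instances, N_attributes,
 *       N_classes, is_y_vector);
 *
 * One work group is launched per attribute, matching the "global size of
 * local_size * N_attributes" requirement documented above.
 */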

}  // namespace opencl_kernels
}  // namespace math
}  // namespace stan
#endif
#endif