#ifndef STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#define STAN_MATH_OPENCL_KERNELS_CATEGORICAL_LOGIT_GLM_LPMF_HPP
#ifdef STAN_OPENCL

#include <stan/math/opencl/kernel_cl.hpp>
#include <string>

namespace stan {
namespace math {
namespace opencl_kernels {

// \cond
static const std::string categorical_logit_glm_kernel_code = STRINGIFY(
    // \endcond
    /** \ingroup opencl_kernels
     * GPU implementation of Generalized Linear Model (GLM)
     * with categorical distribution and logit (softmax) link function.
     *
     * Must be run with at least N_instances threads and local size equal to
     * LOCAL_SIZE_.
     * @param[out] logp_global partially summed log probability (1 value per
     * work group)
     * @param[out] exp_lin_global exponentiation of sum of alpha and matrix
     * product of x and beta
     * @param[out] inv_sum_exp_lin_global inverse of rowwise sum of \c
     * exp_lin_global
     * @param[out] neg_softmax_lin_global negated softmax of sum of alpha and
     * matrix product of x and beta
     * @param[out] alpha_derivative derivative wrt alpha
     * @param[in] y_global a scalar or vector of classes.
     * @param[in] x_beta_global product of design matrix and weight matrix
     * @param[in] alpha_global intercept (in log odds)
     * @param N_instances number of instances
     * @param N_attributes number of attributes
     * @param N_classes number of classes
     * @param is_y_vector 0 or 1 - whether y is a vector (alternatively it is a
     * scalar)
     * @param need_alpha_derivative interpreted as boolean - whether
     * alpha_derivative needs to be computed
     * @param need_neg_softmax_lin_global interpreted as boolean - whether
     * neg_softmax_lin_global needs to be computed
     */
    __kernel void categorical_logit_glm(
        __global double* logp_global, __global double* exp_lin_global,
        __global double* inv_sum_exp_lin_global,
        __global double* neg_softmax_lin_global,
        __global double* alpha_derivative, const __global int* y_global,
        const __global double* x_beta_global,
        const __global double* alpha_global, const int N_instances,
        const int N_attributes, const int N_classes, const int is_y_vector,
        const int need_alpha_derivative,
        const int need_neg_softmax_lin_global) {
      const int gid = get_global_id(0);
      const int lid = get_local_id(0);
      const int lsize = get_local_size(0);
      const int wg_id = get_group_id(0);

      __local double local_storage[LOCAL_SIZE_];

      double logp = 0;
      double inv_sum_exp_lin;
      int class_idx = -1;
      // Most calculations only happen for relevant data within next if.
      // Exceptions are reductions between threads that need barriers.
      if (gid < N_instances) {
        // Rowwise maximum of the linear predictor, so the softmax below is
        // computed in a numerically stable (log-sum-exp shifted) way.
        double lin_max = -INFINITY;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          if (lin > lin_max) {
            lin_max = lin;
          }
        }
        double sum_exp_lin = 0;
        for (int i = 0; i < N_classes; i++) {
          double lin = x_beta_global[i * N_instances + gid] + alpha_global[i];
          double exp_lin = exp(lin - lin_max);
          sum_exp_lin += exp_lin;
          exp_lin_global[i * N_instances + gid] = exp_lin;
        }
        inv_sum_exp_lin = 1 / sum_exp_lin;
        inv_sum_exp_lin_global[gid] = inv_sum_exp_lin;

        class_idx = y_global[gid * is_y_vector] - 1;
        // Valid classes are 1..N_classes, i.e. class_idx in [0, N_classes).
        // The upper bound must be exclusive: class_idx == N_classes would
        // read x_beta_global and alpha_global out of bounds below.
        if (class_idx < 0 || class_idx >= N_classes) {
          logp = NAN;
        } else {
          logp = log(inv_sum_exp_lin) - lin_max
                 + x_beta_global[class_idx * N_instances + gid]
                 + alpha_global[class_idx];
        }
      }
      barrier(CLK_GLOBAL_MEM_FENCE);
      if (need_alpha_derivative || need_neg_softmax_lin_global) {
        for (int i = 0; i < N_classes; i++) {
          double neg_softmax_lin = 0;
          if (gid < N_instances) {
            int idx = i * N_instances + gid;
            neg_softmax_lin = -exp_lin_global[idx] * inv_sum_exp_lin;
            if (need_neg_softmax_lin_global) {
              neg_softmax_lin_global[idx] = neg_softmax_lin;
            }
          }
          if (need_alpha_derivative) {
            // Tree reduction over the work group; each group emits one
            // partial derivative per class, summed on the host afterwards.
            local_storage[lid] = neg_softmax_lin + (class_idx == i);
            barrier(CLK_LOCAL_MEM_FENCE);
            for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
                 step /= REDUCTION_STEP_SIZE) {
              if (lid < step) {
                for (int r = 1; r < REDUCTION_STEP_SIZE; r++) {
                  local_storage[lid] += local_storage[lid + step * r];
                }
              }
              barrier(CLK_LOCAL_MEM_FENCE);
            }
            if (lid == 0) {
              alpha_derivative[i + wg_id * N_classes] = local_storage[0];
            }
            barrier(CLK_LOCAL_MEM_FENCE);
          }
        }
      }
      // Sum logp, calculated by different threads.
      // Since we can't sum between different work groups, we emit one number
      // per work group. These must be summed on CPU for final result.
      local_storage[lid] = logp;
      barrier(CLK_LOCAL_MEM_FENCE);
      for (int step = lsize / REDUCTION_STEP_SIZE; step > 0;
           step /= REDUCTION_STEP_SIZE) {
        if (lid < step) {
          for (int r = 1; r < REDUCTION_STEP_SIZE; r++) {
            local_storage[lid] += local_storage[lid + step * r];
          }
        }
        barrier(CLK_LOCAL_MEM_FENCE);
      }
      if (lid == 0) {
        logp_global[wg_id] = local_storage[0];
      }
    }
    // \cond
    );
// \endcond

/** \ingroup opencl_kernels
 * See the docs for \link kernels/categorical_logit_glm_lpmf.hpp
 * categorical_logit_glm() \endlink
 */
const kernel_cl<out_buffer, out_buffer, out_buffer, out_buffer, out_buffer,
                in_buffer, in_buffer, in_buffer, int, int, int, int, int, int>
    categorical_logit_glm("categorical_logit_glm",
                          {categorical_logit_glm_kernel_code},
                          {{"REDUCTION_STEP_SIZE", 4}, {"LOCAL_SIZE_", 64}});

// \cond
static const std::string categorical_logit_glm_beta_derivative_kernel_code
    = STRINGIFY(
        // \endcond
        /** \ingroup opencl_kernels
         * Calculates derivative wrt beta.
         *
         * Must be run with global size of local_size*N_attributes.
         * @param[in,out] beta_derivative derivative wrt beta
         * @param temp temporary workspace of size global_size*N_classes
         * @param[in] y a scalar or vector of classes
         * @param[in] x design matrix
         * @param N_instances number of instances
         * @param N_attributes number of attributes
         * @param N_classes number of classes
         * @param is_y_vector 0 or 1 - whether y is a vector (alternatively it
         * is a scalar)
         */
        __kernel void categorical_logit_glm_beta_derivative(
            __global double* beta_derivative, __global double* temp,
            const __global int* y, const __global double* x,
            const int N_instances, const int N_attributes, const int N_classes,
            const int is_y_vector) {
          const int gid = get_global_id(0);
          const int lid = get_local_id(0);
          const int lsize = get_local_size(0);
          const int wg_id = get_group_id(0);

          // Zero this thread's slice of the workspace before accumulating.
          for (int i = 0; i < N_classes; i++) {
            temp[gid * N_classes + i] = 0;
          }
          // Each work group handles one attribute (wg_id); threads stride
          // over instances, bucketing x by observed class.
          // NOTE(review): assumes y values lie in 1..N_classes (presumably
          // validated on the host) - an invalid y would index temp out of
          // bounds here; confirm against the caller.
          for (int i = lid; i < N_instances; i += lsize) {
            int pos = y[i * is_y_vector] - 1;
            temp[gid * N_classes + pos] += x[wg_id * N_instances + i];
          }
          barrier(CLK_GLOBAL_MEM_FENCE);
          // Reduce the per-thread partials of this work group into the
          // derivative for attribute wg_id.
          for (int i = lid; i < N_classes; i += lsize) {
            double res = 0;
            for (int j = 0; j < lsize; j++) {
              res += temp[(wg_id * lsize + j) * N_classes + i];
            }
            beta_derivative[i * N_attributes + wg_id] += res;
          }
        }
        // \cond
    );  // NOLINT
// \endcond

/** \ingroup opencl_kernels
 * See the docs for \link kernels/categorical_logit_glm_lpmf.hpp
 * categorical_logit_glm_beta_derivative() \endlink
 */
const kernel_cl<in_out_buffer, in_out_buffer, in_buffer, in_buffer, int, int,
                int, int>
    categorical_logit_glm_beta_derivative(
        "categorical_logit_glm_beta_derivative",
        {categorical_logit_glm_beta_derivative_kernel_code});

}  // namespace opencl_kernels

}  // namespace math
}  // namespace stan
#endif
#endif