1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file pdf_op.h
22  * \brief Operators for computing the pdf of random distributions.
23  */
24 #ifndef MXNET_OPERATOR_RANDOM_PDF_OP_H_
25 #define MXNET_OPERATOR_RANDOM_PDF_OP_H_
26 
27 #include <mxnet/operator_util.h>
28 #include <vector>
29 #include <algorithm>
30 #include "../mshadow_op.h"
31 #include "../mxnet_op.h"
32 #include "../operator_common.h"
33 #include "../elemwise_op_common.h"
34 #include "../special_functions-inl.h"
35 #include "../tensor/broadcast_reduce_op.h"
36 
37 namespace mxnet {
38 namespace op {
39 
40 template<typename DType>
ceph_psi(DType val)41 MSHADOW_XINLINE static DType ceph_psi(DType val) { return special_functions::cephes::psi(val); }
42 template<>
ceph_psi(mshadow::half::half_t val)43 MSHADOW_XINLINE mshadow::half::half_t ceph_psi(mshadow::half::half_t val) {
44     return special_functions::cephes::psi<float>(val);
45 }
46 
47 template<bool logpdf>
48 struct PDF_Uniform {
49   template<typename DType, typename IType1, typename IType2>
MapPDF_Uniform50   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
51                                   DType *out, IType1 *sample, IType2 *lower, IType2 *upper) {
52     const index_t index(start / sample_size);
53     const DType l(lower[index]), h(upper[index]);
54     const index_t end = start + length;
55     for (index_t i = start; i < end; ++i) {
56         // No check whether sample is in the support.
57         out[i] = logpdf ? -DType(log(h - l)) : DType(1.0) / (h - l);
58     }
59   }
60 };
61 
62 template<bool logpdf>
63 struct PDF_Uniform_Grad {
64   template<typename DType, typename IType1, typename IType2>
MapPDF_Uniform_Grad65   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
66                   DType *out, IType1 *sample, IType2 *lower, IType2 *upper,
67                   DType *grad_out, IType1 *grad_sample, IType2 *grad_lower, IType2 *grad_upper) {
68     const index_t index(start / sample_size);
69     const DType l(lower[index]), h(upper[index]);
70 
71     const index_t end = start + length;
72     for (index_t i = start; i < end; ++i) {
73         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
74         grad_lower[i]  = scaling / (h - l);
75         grad_upper[i]  = scaling / (l - h);
76         KERNEL_ASSIGN(grad_sample[i], req, 0);
77     }
78   }
79 };
80 
81 template<bool logpdf>
82 struct PDF_Normal {
83   template<typename DType, typename IType1, typename IType2>
MapPDF_Normal84   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
85                                   DType *out, IType1 *sample, IType2 *loc, IType2 *scale) {
86     const index_t index(start / sample_size);
87     const DType u(loc[index]), s(scale[index]), sq(s * s);
88     const DType normalizer(sqrt(2.0 * mxnet_op::PI) * s);
89 
90     const index_t end = start + length;
91     for (index_t i = start; i < end; ++i) {
92         const DType x(sample[i]);
93         const DType exponent((DType(-0.5) * (x - u) * (x - u)) / (sq));
94         out[i] = logpdf ? exponent - log(normalizer) : exp(exponent) / normalizer;
95     }
96   }
97 };
98 
99 template<bool logpdf>
100 struct PDF_Normal_Grad {
101   template<typename DType, typename IType1, typename IType2>
MapPDF_Normal_Grad102   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
103                   DType *out, IType1 *sample, IType2 *loc, IType2 *scale,
104                   DType *grad_out, IType1 *grad_sample, IType2 *grad_loc, IType2 *grad_scale) {
105     const index_t index(start / sample_size);
106     const DType u(loc[index]), s(scale[index]), s_squared(s * s), s_cubed(s_squared * s);
107 
108     const index_t end = start + length;
109     for (index_t i = start; i < end; ++i) {
110         const DType x(sample[i]);
111         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
112         grad_loc[i]    = scaling * (x - u) / s_squared;
113         grad_scale[i]  = scaling * ((x - u) * (x - u) - s_squared) / s_cubed;
114         KERNEL_ASSIGN(grad_sample[i], req, scaling * (u - x) / s_squared);
115     }
116   }
117 };
118 
119 template<bool logpdf>
120 struct PDF_Gamma {
121   template<typename DType, typename IType1, typename IType2>
MapPDF_Gamma122   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
123                                   DType *out, IType1 *sample, IType2 *alpha, IType2 *beta) {
124     const index_t index(start / sample_size);
125     const DType a(alpha[index]), b(beta[index]), lgamma_a(lgamma(a)), a_log_b(a * log(b));
126 
127     const index_t end = start + length;
128     for (index_t i = start; i < end; ++i) {
129         const DType x(sample[i]);
130         const DType lpdf(a_log_b + (a - 1) * log(x) - b * x - lgamma_a);
131         out[i] = logpdf ? lpdf : DType(exp(lpdf));
132     }
133   }
134 };
135 
136 template<bool logpdf>
137 struct PDF_Gamma_Grad {
138   template<typename DType, typename IType1, typename IType2>
MapPDF_Gamma_Grad139   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
140                   DType *out, IType1 *sample, IType2 *alpha, IType2 *beta,
141                   DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha, IType2 *grad_beta) {
142     const index_t index(start / sample_size);
143     const DType a(alpha[index]), b(beta[index]), log_b(log(b)), ceph_psi_a(ceph_psi(a));
144 
145     const index_t end = start + length;
146     for (index_t i = start; i < end; ++i) {
147         const DType x(sample[i]);
148         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
149         grad_alpha[i]  = scaling * (log_b + log(x) - ceph_psi_a);
150         grad_beta[i]   = scaling * (a / b - x);
151         KERNEL_ASSIGN(grad_sample[i], req, scaling * ((a - 1) / x - b));
152     }
153   }
154 };
155 
156 template<bool logpdf>
157 struct PDF_Exponential {
158   template<typename DType, typename IType1, typename IType2>
MapPDF_Exponential159   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
160                                   DType *out, IType1 *sample, IType2 *lambda) {
161     const index_t index(start / sample_size);
162     const DType l(lambda[index]), log_l(log(l));
163 
164     const index_t end = start + length;
165     for (index_t i = start; i < end; ++i) {
166         const DType x(sample[i]);
167         out[i] = logpdf ? log_l - l * x : l * exp(-l * x);
168     }
169   }
170 };
171 
172 template<bool logpdf>
173 struct PDF_Exponential_Grad {
174   template<typename DType, typename IType1, typename IType2>
MapPDF_Exponential_Grad175   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
176                   DType *out, IType1 *sample, IType2 *lambda,
177                   DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
178     const index_t index(start / sample_size);
179     const DType l(lambda[index]);
180 
181     const index_t end = start + length;
182     for (index_t i = start; i < end; ++i) {
183         const DType x(sample[i]);
184         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
185         grad_lambda[i] = scaling * (DType(1) / l - x);
186         KERNEL_ASSIGN(grad_sample[i], req, -scaling * l);
187     }
188   }
189 };
190 
191 template<bool logpdf>
192 struct PDF_Poisson {
193   template<typename DType, typename IType1, typename IType2>
MapPDF_Poisson194   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
195                                   DType *out, IType1 *sample, IType2 *lambda) {
196     const index_t index(start / sample_size);
197     const DType l(lambda[index]), log_l(log(l));
198 
199     const index_t end = start + length;
200     for (index_t i = start; i < end; ++i) {
201         const DType x(sample[i]);
202         const DType lpdf((x * log_l - lgamma(x + 1)) - l);
203         out[i] = logpdf ? lpdf  : DType(exp(lpdf));
204     }
205   }
206 };
207 
208 template<bool logpdf>
209 struct PDF_Poisson_Grad {
210   template<typename DType, typename IType1, typename IType2>
MapPDF_Poisson_Grad211   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
212                   DType *out, IType1 *sample, IType2 *lambda,
213                   DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
214     const index_t index(start / sample_size);
215     const DType l(lambda[index]);
216 
217     const index_t end = start + length;
218     for (index_t i = start; i < end; ++i) {
219         const DType x(sample[i]);
220         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
221         grad_lambda[i] = scaling * (x / l - DType(1));
222         KERNEL_ASSIGN(grad_sample[i], req, 0);
223     }
224   }
225 };
226 
227 
228 template<bool logpdf>
229 struct PDF_NegativeBinomial {
230   template<typename DType, typename IType1, typename IType2>
MapPDF_NegativeBinomial231   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
232                                   DType *out, IType1 *sample, IType2 *limit, IType2 *prob) {
233     const index_t index(start / sample_size);
234     const DType l(limit[index]), p(prob[index]), lgamma_l(lgamma(l));
235 
236     const index_t end = start + length;
237     for (index_t i = start; i < end; ++i) {
238         const DType x(sample[i]);
239         const DType lpdf((lgamma(x + l) - lgamma(x + 1) - lgamma_l) + l * log(p) + x * log(1 - p));
240         out[i] = logpdf ? lpdf : DType(exp(lpdf));
241     }
242   }
243 
244   template<typename DType>
LPDFPDF_NegativeBinomial245   MSHADOW_XINLINE static DType LPDF(DType l, DType p, DType x) {
246     // Note that "p" is the failure and not the success probability.
247     return (lgamma(x + l) - lgamma(x + 1) - lgamma(l)) + l * log(p) + x * log(1 - p);
248   }
249 };
250 
251 template<bool logpdf>
252 struct PDF_NegativeBinomial_Grad {
253   template<typename DType, typename IType1, typename IType2>
MapPDF_NegativeBinomial_Grad254   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
255                   DType *out, IType1 *sample, IType2 *limit, IType2 *prob,
256                   DType *grad_out, IType1 *grad_sample, IType2 *grad_limit, IType2 *grad_prob) {
257     const index_t index(start / sample_size);
258     const index_t end = start + length;
259     for (index_t i = start; i < end; ++i) {
260         DType grad_l(0), grad_p(0);
261         LPDF_GRAD(DType(limit[index]), DType(prob[index]),
262                   DType(sample[i]), out[i],
263                   grad_out[i], &grad_l, &grad_p);
264         grad_limit[i]  = grad_l;
265         grad_prob[i]   = grad_p;
266         KERNEL_ASSIGN(grad_sample[i], req, 0);
267     }
268   }
269 
270   template<typename DType>
LPDF_GRADPDF_NegativeBinomial_Grad271   MSHADOW_XINLINE static void LPDF_GRAD(DType l, DType p, DType x, DType o, DType grad_o,
272                                         DType* grad_l, DType* grad_p) {
273     const DType scaling(grad_o * (logpdf ? DType(1) : o));
274     *grad_l = scaling * ((ceph_psi(x + l) - ceph_psi(l)) + log(p));
275     *grad_p = scaling * (l / p - x / (1 - p));
276   }
277 };
278 
279 template<bool logpdf>
280 struct PDF_GeneralizedNegativeBinomial {
281   template<typename DType, typename IType1, typename IType2>
MapPDF_GeneralizedNegativeBinomial282   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
283                                   DType *out, IType1 *sample, IType2 *mu, IType2 *alpha) {
284     const index_t index(start / sample_size);
285 
286     // Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
287     const DType limit(1.0 / alpha[index]), prob(1.0 / (mu[index]*alpha[index]+1.0));
288 
289     const index_t end = start + length;
290     for (index_t i = start; i < end; ++i) {
291         const DType lpdf(PDF_NegativeBinomial<logpdf>::LPDF(limit, prob, DType(sample[i])));
292         out[i] = logpdf ? lpdf : DType(exp(lpdf));
293     }
294   }
295 };
296 
297 template<bool logpdf>
298 struct PDF_GeneralizedNegativeBinomial_Grad {
299   template<typename DType, typename IType1, typename IType2>
MapPDF_GeneralizedNegativeBinomial_Grad300   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
301                   DType *out, IType1 *sample, IType2 *mu, IType2 *alpha,
302                   DType *grad_out, IType1 *grad_sample, IType2 *grad_mu, IType2 *grad_alpha) {
303     const index_t index(start / sample_size);
304     const DType fmu(mu[index]), falpha(alpha[index]), den(fmu * falpha + 1.0);
305 
306     // Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
307     const DType limit(1.0 / falpha), prob(1.0 / (fmu * falpha + 1.0));
308 
309     const index_t end = start + length;
310     for (index_t i = start; i < end; ++i) {
311         // Grad returned as d_limit, d_prob
312         DType grad_l(0), grad_p(0);
313         PDF_NegativeBinomial_Grad<logpdf>::LPDF_GRAD(limit, prob,
314             DType(sample[i]), out[i],
315             grad_out[i], &grad_l, &grad_p);
316         grad_mu[i]     = -grad_p * falpha / (den * den);
317         grad_alpha[i]  = -grad_l / (falpha * falpha) - grad_p * fmu / (den * den);
318         KERNEL_ASSIGN(grad_sample[i], req, 0);
319     }
320   }
321 };
322 
323 template<bool logpdf>
324 struct PDF_Dirichlet {
325   template<typename DType, typename IType1, typename IType2>
MapPDF_Dirichlet326   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, index_t k,
327                                   DType *out, IType1 *sample, IType2 *alpha) {
328     const index_t index(start / sample_size);
329     const index_t end = start + length;
330     for (index_t i = start; i < end; ++i) {
331         const IType1 *cur_sample = sample + i * k;
332         const IType2 *cur_alpha  = alpha + index * k;
333         DType sum_alpha(0), sum_lgamma(0), sum_sample(0);
334         for (index_t j = 0; j < k; ++j) {
335           sum_alpha  += cur_alpha[j];
336           sum_lgamma += lgamma(cur_alpha[j]);
337           sum_sample += (cur_alpha[j]-1) * log(cur_sample[j]);
338         }
339         DType lpdf(sum_sample + (lgamma(sum_alpha) - sum_lgamma));
340         out[i] = logpdf ? lpdf : DType(exp(lpdf));
341     }
342   }
343 };
344 
345 
346 template<bool logpdf>
347 struct PDF_Dirichlet_Grad {
348   template<typename DType, typename IType1, typename IType2>
MapPDF_Dirichlet_Grad349   MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
350                   OpReqType req, index_t k,
351                   DType *out, IType1 *sample, IType2 *alpha,
352                   DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha
353                   ) {
354     const index_t index(start / sample_size);
355     const index_t end = start + length;
356 
357     for (index_t i = start; i < end; ++i) {
358         // Digamma function
359         const IType1 *cur_sample = sample + i * k;
360         const IType2 *cur_alpha = alpha + index * k;
361 
362         const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
363         DType sum_alpha(0);
364         for (index_t j = 0; j < k; ++j) {
365           sum_alpha += cur_alpha[j];
366         }
367         const DType psi_sum(ceph_psi(sum_alpha));
368 
369         for (index_t j = 0; j < k; ++j) {
370           size_t grad_alpha_index = i%sample_size + sample_size * (j + k * index);
371           size_t grad_sample_index = i * k + j;
372 
373           // order grad_alpha differently to allow efficient reduction at the end.
374           grad_alpha[grad_alpha_index] =
375             scaling * (log(cur_sample[j]) + (psi_sum - ceph_psi(cur_alpha[j])));
376           KERNEL_ASSIGN(grad_sample[grad_sample_index],
377             req, scaling * (cur_alpha[j]-1) / cur_sample[j]);
378         }
379     }
380   }
381 };
382 
/*!
 * \brief Parameters shared by the pdf operators.
 */
struct PdfParam : public dmlc::Parameter<PdfParam> {
  // When true, PdfOpForward/PdfOpBackward instantiate the kernels with
  // logpdf = true (log-density); otherwise with logpdf = false (density).
  bool is_log;
  DMLC_DECLARE_PARAMETER(PdfParam) {
    // NOTE(review): the description says "density of the log-probability"; when set,
    // the kernels actually output the log of the density (log-pdf).
    DMLC_DECLARE_FIELD(is_log).set_default(false)
    .describe("If set, compute the density of the log-probability instead of the probability.");
  }
};
390 
391 template<bool vparm = false>
PdfOpShape(const nnvm::NodeAttrs & attrs,std::vector<TShape> * in_attrs,std::vector<TShape> * out_attrs)392 inline bool PdfOpShape(const nnvm::NodeAttrs& attrs,
393                        std::vector<TShape>* in_attrs,
394                        std::vector<TShape>* out_attrs) {
395   CHECK_GT(in_attrs->size(), 1)
396     << "pdf operator takes at least 2 arguments (" << in_attrs->size() << " given)";
397   CHECK_EQ(out_attrs->size(), 1);
398   // All inputs must be defined in order to infer output shape.
399   if ( std::all_of((*in_attrs).begin(),
400                    (*in_attrs).end(),
401                    [](const TShape& s){ return s.ndim() > 0; }) ) {
402     // Tensors of distribution parameters must have same shape.
403     for (size_t i = 2; i < in_attrs->size(); ++i) {
404       SHAPE_ASSIGN_CHECK(*in_attrs, i, (*in_attrs)[i - 1]);
405     }
406     // Tensors of distribution parameters must match leftmost subshape of samples.
407     CHECK_LE((*in_attrs)[1].ndim(), (*in_attrs)[0].ndim())
408       << "dimension of input samples (" << (*in_attrs)[0].ndim()
409       << ") must be at least dimension of distribution parameters ("<< (*in_attrs)[1].ndim() << ")";
410     TShape tshape((*in_attrs)[0].begin(), (*in_attrs)[0].begin() + (*in_attrs)[1].ndim());
411     if (vparm) {
412       *(tshape.end() - 1) = *((*in_attrs)[0].end() - 1);
413     }
414     for (size_t i = 1; i < in_attrs->size(); ++i) {
415       SHAPE_ASSIGN_CHECK(*in_attrs, i, tshape);
416     }
417     // Output shape must equal input tensor of samples except for last dimension if we are
418     // dealing with samples that are itself vectors. Be aware of the special case where we
419     // are dealing with a single vector sample.
420     if (vparm && ((*in_attrs)[0].ndim() == 1)) {
421       // Special case where we are dealing with a single vector sample.
422       SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1));
423     } else {
424       TShape oshape((*in_attrs)[0].begin(), (*in_attrs)[0].end() - (vparm ? 1 : 0));
425       SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
426     }
427     return true;
428   }
429   return false;
430 }
431 
432 template<typename OP>
433 struct LaunchExWrapper {
434   template<typename ...Args>
MapLaunchExWrapper435   MSHADOW_XINLINE static void Map(const index_t start, const index_t length,
436                                   const index_t sample_size, Args... args) {
437     // Apply the operator to the sample in strides of sample_size, so that
438     // the operators can assume that their distribution parameters are constant.
439     index_t i = start;
440 
441     // Get aligned
442     const index_t align_step = sample_size - (i % sample_size);
443     const index_t first_stride = length > align_step ? align_step : length;
444     OP::Map(i, first_stride, sample_size, args...);
445     i += first_stride;
446 
447     const index_t end = start + length - sample_size;
448     for (; i < end; i += sample_size) {
449       OP::Map(i, sample_size, sample_size, args...);
450     }
451 
452     // Last stride might not be aligned either
453     const index_t last_stride = start + length - i;
454     if (last_stride > 0) {  // Don't overstep even if length <= sample_size
455       OP::Map(i, last_stride, sample_size, args...);
456     }
457   }
458 };
459 
/*!
 * \brief Launches the forward kernel \p pdf. The launch is selected by the number
 *        of distribution parameters (pnum) and by whether the distribution is
 *        parametrized by a vector (vparm; the Dirichlet case). Only the partial
 *        specializations defined in this header exist.
 */
template<typename xpu, typename DType, typename pdf, int pnum, bool vparm = false>
struct PdfCaller;
462 
463 template<typename xpu, typename DType, typename pdf>
464 struct PdfCaller<xpu, DType, pdf, 1, false> {
465   static void op(const std::vector<TBlob>& inputs,
466                  const std::vector<TBlob>& outputs,
467                  mshadow::Stream<xpu> *s) {
468     CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
469     CHECK_EQ(inputs[0].Size()%outputs[0].Size(), 0);
470     index_t num_samples(inputs[0].Size() / inputs[1].Size());
471     mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(), num_samples,
472                 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
473   }
474 };
475 
476 template<typename xpu, typename DType, typename pdf>
477 struct PdfCaller<xpu, DType, pdf, 1, true> {
478   static void op(const std::vector<TBlob>& inputs,
479                  const std::vector<TBlob>& outputs,
480                  mshadow::Stream<xpu> *s) {
481     CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
482     CHECK_EQ(inputs[0].Size()%outputs[0].Size(), 0);
483     index_t num_samples(inputs[0].Size() / inputs[1].Size());
484     index_t sample_size(inputs[0].Size() / outputs[0].Size());
485 
486     // Covers distributions parametrized by a vector of parameters (Dirichlet distribution).
487     mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(),
488                 num_samples, sample_size,
489                 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
490   }
491 };
492 
493 template<typename xpu, typename DType, typename pdf>
494 struct PdfCaller<xpu, DType, pdf, 2, false> {
495   static void op(const std::vector<TBlob>& inputs,
496                  const std::vector<TBlob>& outputs,
497                  mshadow::Stream<xpu> *s) {
498     CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
499     CHECK_EQ(inputs[0].Size(), outputs[0].Size());
500     index_t num_samples(inputs[0].Size() / inputs[1].Size());
501     mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(), num_samples,
502                 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
503                 inputs[1].dptr<DType>(), inputs[2].dptr<DType>());
504   }
505 };
506 
507 template<typename xpu, template<bool> class pdf, int pnum, bool vparm>
508 void PdfOpForward(const nnvm::NodeAttrs& attrs,
509                   const OpContext& ctx,
510                   const std::vector<TBlob>& inputs,
511                   const std::vector<OpReqType>& req,
512                   const std::vector<TBlob>& outputs) {
513   CHECK_NE(req[0], kAddTo);
514   CHECK_EQ(inputs.size(), pnum + 1);
515   CHECK_EQ(outputs.size(), 1);
516   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
517   const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
518   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
519     if ( param.is_log ) {
520       PdfCaller<xpu, DType, pdf<true>, pnum, vparm>::op(inputs, outputs, s);
521     } else {
522       PdfCaller<xpu, DType, pdf<false>, pnum, vparm>::op(inputs, outputs, s);
523     }
524   });
525 }
526 
527 
/*!
 * \brief Launches the backward kernel \p pdfgrad. The launch is selected by the
 *        number of distribution parameters (pnum) and by whether the distribution
 *        is parametrized by a vector (vparm). vparm is a bool, matching PdfCaller
 *        and the bool template argument passed by PdfOpBackward.
 */
template<typename xpu, typename DType, typename pdfgrad, int pnum, bool vparm = false>
struct PdfGradCaller;
530 
531 template<typename xpu, typename DType, typename pdfgrad>
532 struct PdfGradCaller<xpu, DType, pdfgrad, 1, false> {
533   static void op(const std::vector<TBlob>& inputs,
534                  const std::vector<OpReqType>& req,
535                  const std::vector<TBlob>& grads,
536                  mshadow::Stream<xpu> *s) {
537     index_t num_samples(inputs[1].Size() / inputs[2].Size());
538     mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(),
539                 num_samples, req[0],
540                 inputs[3].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
541                 inputs[0].dptr<DType>(), grads[0].dptr<DType>(), grads[1].dptr<DType>());
542   }
543 };
544 
545 template<typename xpu, typename DType, typename pdfgrad>
546 struct PdfGradCaller<xpu, DType, pdfgrad, 1, true> {
547   static void op(const std::vector<TBlob>& inputs,
548                  const std::vector<OpReqType>& req,
549                  const std::vector<TBlob>& grads,
550                  mshadow::Stream<xpu> *s) {
551     index_t num_samples(inputs[1].Size() / inputs[2].Size());
552     index_t sample_size(inputs[1].Size() / inputs[0].Size());
553     mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(), num_samples,
554                 req[0], sample_size,
555                 inputs[3].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
556                 inputs[0].dptr<DType>(), grads[0].dptr<DType>(), grads[1].dptr<DType>());
557   }
558 };
559 
560 template<typename xpu, typename DType, typename pdfgrad>
561 struct PdfGradCaller<xpu, DType, pdfgrad, 2, false> {
562   static void op(const std::vector<TBlob>& inputs,
563                  const std::vector<OpReqType>& req,
564                  const std::vector<TBlob>& grads,
565                  mshadow::Stream<xpu> *s) {
566     index_t num_samples(inputs[1].Size() / inputs[2].Size());
567     mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(),
568                 num_samples, req[0],
569                 inputs[4].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
570                 inputs[3].dptr<DType>(), inputs[0].dptr<DType>(),
571                 grads[0].dptr<DType>(), grads[1].dptr<DType>(), grads[2].dptr<DType>());
572   }
573 };
574 
575 template<typename xpu, template<bool> class pdfgrad, int pnum, bool vparm>
576 void PdfOpBackward(const nnvm::NodeAttrs& attrs,
577                    const OpContext& ctx,
578                    const std::vector<TBlob>& inputs,
579                    const std::vector<OpReqType>& req,
580                    const std::vector<TBlob>& outputs) {
581   using namespace mshadow;
582   CHECK_EQ(inputs.size(), pnum + 3);
583   CHECK_EQ(outputs.size(), pnum + 1);
584   mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
585   const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
586   const size_t N(outputs[1].Size());
587   const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1));
588   // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
589   MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
590     const size_t red_work_size(broadcast::ReduceWorkspaceSize<2, DType>(
591             s, dst_shape, kAddTo, src_shape));
592     const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
593     Tensor<xpu, 1, char> tmp_space =
594             ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
595     std::vector<TBlob> grads = {outputs[0]};
596     grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_,
597                           outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id()));
598     if (pnum == 2) {
599       grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_,
600                             outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id()));
601     }
602     if (param.is_log) {
603       PdfGradCaller<xpu, DType, pdfgrad<true>, pnum, vparm>::op(inputs, req, grads, s);
604     } else {
605       PdfGradCaller<xpu, DType, pdfgrad<false>, pnum, vparm>::op(inputs, req, grads, s);
606     }
607     Tensor<xpu, 1, char> red_work(
608             tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s);
609     broadcast::Reduce<red::sum, 2, DType, op::mshadow_op::identity>(
610        s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape));
611     if (pnum == 2) {
612       broadcast::Reduce<red::sum, 2, DType, op::mshadow_op::identity>(
613        s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape));
614     }
615   });
616 }
617 
618 }  // namespace op
619 }  // namespace mxnet
620 
621 #endif  // MXNET_OPERATOR_RANDOM_PDF_OP_H_
622