1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file pdf_op.h
22 * \brief Operators for computing the pdf of random distributions.
23 */
24 #ifndef MXNET_OPERATOR_RANDOM_PDF_OP_H_
25 #define MXNET_OPERATOR_RANDOM_PDF_OP_H_
26
27 #include <mxnet/operator_util.h>
28 #include <vector>
29 #include <algorithm>
30 #include "../mshadow_op.h"
31 #include "../mxnet_op.h"
32 #include "../operator_common.h"
33 #include "../elemwise_op_common.h"
34 #include "../special_functions-inl.h"
35 #include "../tensor/broadcast_reduce_op.h"
36
37 namespace mxnet {
38 namespace op {
39
40 template<typename DType>
ceph_psi(DType val)41 MSHADOW_XINLINE static DType ceph_psi(DType val) { return special_functions::cephes::psi(val); }
42 template<>
ceph_psi(mshadow::half::half_t val)43 MSHADOW_XINLINE mshadow::half::half_t ceph_psi(mshadow::half::half_t val) {
44 return special_functions::cephes::psi<float>(val);
45 }
46
47 template<bool logpdf>
48 struct PDF_Uniform {
49 template<typename DType, typename IType1, typename IType2>
MapPDF_Uniform50 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
51 DType *out, IType1 *sample, IType2 *lower, IType2 *upper) {
52 const index_t index(start / sample_size);
53 const DType l(lower[index]), h(upper[index]);
54 const index_t end = start + length;
55 for (index_t i = start; i < end; ++i) {
56 // No check whether sample is in the support.
57 out[i] = logpdf ? -DType(log(h - l)) : DType(1.0) / (h - l);
58 }
59 }
60 };
61
62 template<bool logpdf>
63 struct PDF_Uniform_Grad {
64 template<typename DType, typename IType1, typename IType2>
MapPDF_Uniform_Grad65 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
66 DType *out, IType1 *sample, IType2 *lower, IType2 *upper,
67 DType *grad_out, IType1 *grad_sample, IType2 *grad_lower, IType2 *grad_upper) {
68 const index_t index(start / sample_size);
69 const DType l(lower[index]), h(upper[index]);
70
71 const index_t end = start + length;
72 for (index_t i = start; i < end; ++i) {
73 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
74 grad_lower[i] = scaling / (h - l);
75 grad_upper[i] = scaling / (l - h);
76 KERNEL_ASSIGN(grad_sample[i], req, 0);
77 }
78 }
79 };
80
81 template<bool logpdf>
82 struct PDF_Normal {
83 template<typename DType, typename IType1, typename IType2>
MapPDF_Normal84 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
85 DType *out, IType1 *sample, IType2 *loc, IType2 *scale) {
86 const index_t index(start / sample_size);
87 const DType u(loc[index]), s(scale[index]), sq(s * s);
88 const DType normalizer(sqrt(2.0 * mxnet_op::PI) * s);
89
90 const index_t end = start + length;
91 for (index_t i = start; i < end; ++i) {
92 const DType x(sample[i]);
93 const DType exponent((DType(-0.5) * (x - u) * (x - u)) / (sq));
94 out[i] = logpdf ? exponent - log(normalizer) : exp(exponent) / normalizer;
95 }
96 }
97 };
98
99 template<bool logpdf>
100 struct PDF_Normal_Grad {
101 template<typename DType, typename IType1, typename IType2>
MapPDF_Normal_Grad102 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
103 DType *out, IType1 *sample, IType2 *loc, IType2 *scale,
104 DType *grad_out, IType1 *grad_sample, IType2 *grad_loc, IType2 *grad_scale) {
105 const index_t index(start / sample_size);
106 const DType u(loc[index]), s(scale[index]), s_squared(s * s), s_cubed(s_squared * s);
107
108 const index_t end = start + length;
109 for (index_t i = start; i < end; ++i) {
110 const DType x(sample[i]);
111 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
112 grad_loc[i] = scaling * (x - u) / s_squared;
113 grad_scale[i] = scaling * ((x - u) * (x - u) - s_squared) / s_cubed;
114 KERNEL_ASSIGN(grad_sample[i], req, scaling * (u - x) / s_squared);
115 }
116 }
117 };
118
119 template<bool logpdf>
120 struct PDF_Gamma {
121 template<typename DType, typename IType1, typename IType2>
MapPDF_Gamma122 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
123 DType *out, IType1 *sample, IType2 *alpha, IType2 *beta) {
124 const index_t index(start / sample_size);
125 const DType a(alpha[index]), b(beta[index]), lgamma_a(lgamma(a)), a_log_b(a * log(b));
126
127 const index_t end = start + length;
128 for (index_t i = start; i < end; ++i) {
129 const DType x(sample[i]);
130 const DType lpdf(a_log_b + (a - 1) * log(x) - b * x - lgamma_a);
131 out[i] = logpdf ? lpdf : DType(exp(lpdf));
132 }
133 }
134 };
135
136 template<bool logpdf>
137 struct PDF_Gamma_Grad {
138 template<typename DType, typename IType1, typename IType2>
MapPDF_Gamma_Grad139 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
140 DType *out, IType1 *sample, IType2 *alpha, IType2 *beta,
141 DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha, IType2 *grad_beta) {
142 const index_t index(start / sample_size);
143 const DType a(alpha[index]), b(beta[index]), log_b(log(b)), ceph_psi_a(ceph_psi(a));
144
145 const index_t end = start + length;
146 for (index_t i = start; i < end; ++i) {
147 const DType x(sample[i]);
148 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
149 grad_alpha[i] = scaling * (log_b + log(x) - ceph_psi_a);
150 grad_beta[i] = scaling * (a / b - x);
151 KERNEL_ASSIGN(grad_sample[i], req, scaling * ((a - 1) / x - b));
152 }
153 }
154 };
155
156 template<bool logpdf>
157 struct PDF_Exponential {
158 template<typename DType, typename IType1, typename IType2>
MapPDF_Exponential159 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
160 DType *out, IType1 *sample, IType2 *lambda) {
161 const index_t index(start / sample_size);
162 const DType l(lambda[index]), log_l(log(l));
163
164 const index_t end = start + length;
165 for (index_t i = start; i < end; ++i) {
166 const DType x(sample[i]);
167 out[i] = logpdf ? log_l - l * x : l * exp(-l * x);
168 }
169 }
170 };
171
172 template<bool logpdf>
173 struct PDF_Exponential_Grad {
174 template<typename DType, typename IType1, typename IType2>
MapPDF_Exponential_Grad175 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
176 DType *out, IType1 *sample, IType2 *lambda,
177 DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
178 const index_t index(start / sample_size);
179 const DType l(lambda[index]);
180
181 const index_t end = start + length;
182 for (index_t i = start; i < end; ++i) {
183 const DType x(sample[i]);
184 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
185 grad_lambda[i] = scaling * (DType(1) / l - x);
186 KERNEL_ASSIGN(grad_sample[i], req, -scaling * l);
187 }
188 }
189 };
190
191 template<bool logpdf>
192 struct PDF_Poisson {
193 template<typename DType, typename IType1, typename IType2>
MapPDF_Poisson194 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
195 DType *out, IType1 *sample, IType2 *lambda) {
196 const index_t index(start / sample_size);
197 const DType l(lambda[index]), log_l(log(l));
198
199 const index_t end = start + length;
200 for (index_t i = start; i < end; ++i) {
201 const DType x(sample[i]);
202 const DType lpdf((x * log_l - lgamma(x + 1)) - l);
203 out[i] = logpdf ? lpdf : DType(exp(lpdf));
204 }
205 }
206 };
207
208 template<bool logpdf>
209 struct PDF_Poisson_Grad {
210 template<typename DType, typename IType1, typename IType2>
MapPDF_Poisson_Grad211 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
212 DType *out, IType1 *sample, IType2 *lambda,
213 DType *grad_out, IType1 *grad_sample, IType2 *grad_lambda) {
214 const index_t index(start / sample_size);
215 const DType l(lambda[index]);
216
217 const index_t end = start + length;
218 for (index_t i = start; i < end; ++i) {
219 const DType x(sample[i]);
220 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
221 grad_lambda[i] = scaling * (x / l - DType(1));
222 KERNEL_ASSIGN(grad_sample[i], req, 0);
223 }
224 }
225 };
226
227
228 template<bool logpdf>
229 struct PDF_NegativeBinomial {
230 template<typename DType, typename IType1, typename IType2>
MapPDF_NegativeBinomial231 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
232 DType *out, IType1 *sample, IType2 *limit, IType2 *prob) {
233 const index_t index(start / sample_size);
234 const DType l(limit[index]), p(prob[index]), lgamma_l(lgamma(l));
235
236 const index_t end = start + length;
237 for (index_t i = start; i < end; ++i) {
238 const DType x(sample[i]);
239 const DType lpdf((lgamma(x + l) - lgamma(x + 1) - lgamma_l) + l * log(p) + x * log(1 - p));
240 out[i] = logpdf ? lpdf : DType(exp(lpdf));
241 }
242 }
243
244 template<typename DType>
LPDFPDF_NegativeBinomial245 MSHADOW_XINLINE static DType LPDF(DType l, DType p, DType x) {
246 // Note that "p" is the failure and not the success probability.
247 return (lgamma(x + l) - lgamma(x + 1) - lgamma(l)) + l * log(p) + x * log(1 - p);
248 }
249 };
250
251 template<bool logpdf>
252 struct PDF_NegativeBinomial_Grad {
253 template<typename DType, typename IType1, typename IType2>
MapPDF_NegativeBinomial_Grad254 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
255 DType *out, IType1 *sample, IType2 *limit, IType2 *prob,
256 DType *grad_out, IType1 *grad_sample, IType2 *grad_limit, IType2 *grad_prob) {
257 const index_t index(start / sample_size);
258 const index_t end = start + length;
259 for (index_t i = start; i < end; ++i) {
260 DType grad_l(0), grad_p(0);
261 LPDF_GRAD(DType(limit[index]), DType(prob[index]),
262 DType(sample[i]), out[i],
263 grad_out[i], &grad_l, &grad_p);
264 grad_limit[i] = grad_l;
265 grad_prob[i] = grad_p;
266 KERNEL_ASSIGN(grad_sample[i], req, 0);
267 }
268 }
269
270 template<typename DType>
LPDF_GRADPDF_NegativeBinomial_Grad271 MSHADOW_XINLINE static void LPDF_GRAD(DType l, DType p, DType x, DType o, DType grad_o,
272 DType* grad_l, DType* grad_p) {
273 const DType scaling(grad_o * (logpdf ? DType(1) : o));
274 *grad_l = scaling * ((ceph_psi(x + l) - ceph_psi(l)) + log(p));
275 *grad_p = scaling * (l / p - x / (1 - p));
276 }
277 };
278
279 template<bool logpdf>
280 struct PDF_GeneralizedNegativeBinomial {
281 template<typename DType, typename IType1, typename IType2>
MapPDF_GeneralizedNegativeBinomial282 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
283 DType *out, IType1 *sample, IType2 *mu, IType2 *alpha) {
284 const index_t index(start / sample_size);
285
286 // Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
287 const DType limit(1.0 / alpha[index]), prob(1.0 / (mu[index]*alpha[index]+1.0));
288
289 const index_t end = start + length;
290 for (index_t i = start; i < end; ++i) {
291 const DType lpdf(PDF_NegativeBinomial<logpdf>::LPDF(limit, prob, DType(sample[i])));
292 out[i] = logpdf ? lpdf : DType(exp(lpdf));
293 }
294 }
295 };
296
297 template<bool logpdf>
298 struct PDF_GeneralizedNegativeBinomial_Grad {
299 template<typename DType, typename IType1, typename IType2>
MapPDF_GeneralizedNegativeBinomial_Grad300 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, OpReqType req,
301 DType *out, IType1 *sample, IType2 *mu, IType2 *alpha,
302 DType *grad_out, IType1 *grad_sample, IType2 *grad_mu, IType2 *grad_alpha) {
303 const index_t index(start / sample_size);
304 const DType fmu(mu[index]), falpha(alpha[index]), den(fmu * falpha + 1.0);
305
306 // Reparameterize with limit = 1 / alpha, prob = 1 / (mu * alpha + 1)
307 const DType limit(1.0 / falpha), prob(1.0 / (fmu * falpha + 1.0));
308
309 const index_t end = start + length;
310 for (index_t i = start; i < end; ++i) {
311 // Grad returned as d_limit, d_prob
312 DType grad_l(0), grad_p(0);
313 PDF_NegativeBinomial_Grad<logpdf>::LPDF_GRAD(limit, prob,
314 DType(sample[i]), out[i],
315 grad_out[i], &grad_l, &grad_p);
316 grad_mu[i] = -grad_p * falpha / (den * den);
317 grad_alpha[i] = -grad_l / (falpha * falpha) - grad_p * fmu / (den * den);
318 KERNEL_ASSIGN(grad_sample[i], req, 0);
319 }
320 }
321 };
322
323 template<bool logpdf>
324 struct PDF_Dirichlet {
325 template<typename DType, typename IType1, typename IType2>
MapPDF_Dirichlet326 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size, index_t k,
327 DType *out, IType1 *sample, IType2 *alpha) {
328 const index_t index(start / sample_size);
329 const index_t end = start + length;
330 for (index_t i = start; i < end; ++i) {
331 const IType1 *cur_sample = sample + i * k;
332 const IType2 *cur_alpha = alpha + index * k;
333 DType sum_alpha(0), sum_lgamma(0), sum_sample(0);
334 for (index_t j = 0; j < k; ++j) {
335 sum_alpha += cur_alpha[j];
336 sum_lgamma += lgamma(cur_alpha[j]);
337 sum_sample += (cur_alpha[j]-1) * log(cur_sample[j]);
338 }
339 DType lpdf(sum_sample + (lgamma(sum_alpha) - sum_lgamma));
340 out[i] = logpdf ? lpdf : DType(exp(lpdf));
341 }
342 }
343 };
344
345
346 template<bool logpdf>
347 struct PDF_Dirichlet_Grad {
348 template<typename DType, typename IType1, typename IType2>
MapPDF_Dirichlet_Grad349 MSHADOW_XINLINE static void Map(index_t start, index_t length, index_t sample_size,
350 OpReqType req, index_t k,
351 DType *out, IType1 *sample, IType2 *alpha,
352 DType *grad_out, IType1 *grad_sample, IType2 *grad_alpha
353 ) {
354 const index_t index(start / sample_size);
355 const index_t end = start + length;
356
357 for (index_t i = start; i < end; ++i) {
358 // Digamma function
359 const IType1 *cur_sample = sample + i * k;
360 const IType2 *cur_alpha = alpha + index * k;
361
362 const DType scaling(grad_out[i]*(logpdf ? DType(1) : out[i]));
363 DType sum_alpha(0);
364 for (index_t j = 0; j < k; ++j) {
365 sum_alpha += cur_alpha[j];
366 }
367 const DType psi_sum(ceph_psi(sum_alpha));
368
369 for (index_t j = 0; j < k; ++j) {
370 size_t grad_alpha_index = i%sample_size + sample_size * (j + k * index);
371 size_t grad_sample_index = i * k + j;
372
373 // order grad_alpha differently to allow efficient reduction at the end.
374 grad_alpha[grad_alpha_index] =
375 scaling * (log(cur_sample[j]) + (psi_sum - ceph_psi(cur_alpha[j])));
376 KERNEL_ASSIGN(grad_sample[grad_sample_index],
377 req, scaling * (cur_alpha[j]-1) / cur_sample[j]);
378 }
379 }
380 }
381 };
382
383 struct PdfParam : public dmlc::Parameter<PdfParam> {
384 bool is_log;
DMLC_DECLARE_PARAMETERPdfParam385 DMLC_DECLARE_PARAMETER(PdfParam) {
386 DMLC_DECLARE_FIELD(is_log).set_default(false)
387 .describe("If set, compute the density of the log-probability instead of the probability.");
388 }
389 };
390
391 template<bool vparm = false>
PdfOpShape(const nnvm::NodeAttrs & attrs,std::vector<TShape> * in_attrs,std::vector<TShape> * out_attrs)392 inline bool PdfOpShape(const nnvm::NodeAttrs& attrs,
393 std::vector<TShape>* in_attrs,
394 std::vector<TShape>* out_attrs) {
395 CHECK_GT(in_attrs->size(), 1)
396 << "pdf operator takes at least 2 arguments (" << in_attrs->size() << " given)";
397 CHECK_EQ(out_attrs->size(), 1);
398 // All inputs must be defined in order to infer output shape.
399 if ( std::all_of((*in_attrs).begin(),
400 (*in_attrs).end(),
401 [](const TShape& s){ return s.ndim() > 0; }) ) {
402 // Tensors of distribution parameters must have same shape.
403 for (size_t i = 2; i < in_attrs->size(); ++i) {
404 SHAPE_ASSIGN_CHECK(*in_attrs, i, (*in_attrs)[i - 1]);
405 }
406 // Tensors of distribution parameters must match leftmost subshape of samples.
407 CHECK_LE((*in_attrs)[1].ndim(), (*in_attrs)[0].ndim())
408 << "dimension of input samples (" << (*in_attrs)[0].ndim()
409 << ") must be at least dimension of distribution parameters ("<< (*in_attrs)[1].ndim() << ")";
410 TShape tshape((*in_attrs)[0].begin(), (*in_attrs)[0].begin() + (*in_attrs)[1].ndim());
411 if (vparm) {
412 *(tshape.end() - 1) = *((*in_attrs)[0].end() - 1);
413 }
414 for (size_t i = 1; i < in_attrs->size(); ++i) {
415 SHAPE_ASSIGN_CHECK(*in_attrs, i, tshape);
416 }
417 // Output shape must equal input tensor of samples except for last dimension if we are
418 // dealing with samples that are itself vectors. Be aware of the special case where we
419 // are dealing with a single vector sample.
420 if (vparm && ((*in_attrs)[0].ndim() == 1)) {
421 // Special case where we are dealing with a single vector sample.
422 SHAPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::Shape1(1));
423 } else {
424 TShape oshape((*in_attrs)[0].begin(), (*in_attrs)[0].end() - (vparm ? 1 : 0));
425 SHAPE_ASSIGN_CHECK(*out_attrs, 0, oshape);
426 }
427 return true;
428 }
429 return false;
430 }
431
432 template<typename OP>
433 struct LaunchExWrapper {
434 template<typename ...Args>
MapLaunchExWrapper435 MSHADOW_XINLINE static void Map(const index_t start, const index_t length,
436 const index_t sample_size, Args... args) {
437 // Apply the operator to the sample in strides of sample_size, so that
438 // the operators can assume that their distribution parameters are constant.
439 index_t i = start;
440
441 // Get aligned
442 const index_t align_step = sample_size - (i % sample_size);
443 const index_t first_stride = length > align_step ? align_step : length;
444 OP::Map(i, first_stride, sample_size, args...);
445 i += first_stride;
446
447 const index_t end = start + length - sample_size;
448 for (; i < end; i += sample_size) {
449 OP::Map(i, sample_size, sample_size, args...);
450 }
451
452 // Last stride might not be aligned either
453 const index_t last_stride = start + length - i;
454 if (last_stride > 0) { // Don't overstep even if length <= sample_size
455 OP::Map(i, last_stride, sample_size, args...);
456 }
457 }
458 };
459
/*!
 * \brief Dispatches a pdf kernel over its input blobs. It is specialized below on the number
 * of distribution parameters (NumParams) and on whether samples and parameters are vectors
 * (VectorParam).
 */
template<typename Device, typename DType, typename Kernel, int NumParams,
         bool VectorParam = false>
struct PdfCaller;
462
463 template<typename xpu, typename DType, typename pdf>
464 struct PdfCaller<xpu, DType, pdf, 1, false> {
465 static void op(const std::vector<TBlob>& inputs,
466 const std::vector<TBlob>& outputs,
467 mshadow::Stream<xpu> *s) {
468 CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
469 CHECK_EQ(inputs[0].Size()%outputs[0].Size(), 0);
470 index_t num_samples(inputs[0].Size() / inputs[1].Size());
471 mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(), num_samples,
472 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
473 }
474 };
475
476 template<typename xpu, typename DType, typename pdf>
477 struct PdfCaller<xpu, DType, pdf, 1, true> {
478 static void op(const std::vector<TBlob>& inputs,
479 const std::vector<TBlob>& outputs,
480 mshadow::Stream<xpu> *s) {
481 CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
482 CHECK_EQ(inputs[0].Size()%outputs[0].Size(), 0);
483 index_t num_samples(inputs[0].Size() / inputs[1].Size());
484 index_t sample_size(inputs[0].Size() / outputs[0].Size());
485
486 // Covers distributions parametrized by a vector of parameters (Dirichlet distribution).
487 mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(),
488 num_samples, sample_size,
489 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(), inputs[1].dptr<DType>());
490 }
491 };
492
493 template<typename xpu, typename DType, typename pdf>
494 struct PdfCaller<xpu, DType, pdf, 2, false> {
495 static void op(const std::vector<TBlob>& inputs,
496 const std::vector<TBlob>& outputs,
497 mshadow::Stream<xpu> *s) {
498 CHECK_EQ(inputs[0].Size()%inputs[1].Size(), 0);
499 CHECK_EQ(inputs[0].Size(), outputs[0].Size());
500 index_t num_samples(inputs[0].Size() / inputs[1].Size());
501 mxnet_op::Kernel<LaunchExWrapper<pdf>, xpu>::LaunchEx(s, outputs[0].Size(), num_samples,
502 outputs[0].dptr<DType>(), inputs[0].dptr<DType>(),
503 inputs[1].dptr<DType>(), inputs[2].dptr<DType>());
504 }
505 };
506
507 template<typename xpu, template<bool> class pdf, int pnum, bool vparm>
508 void PdfOpForward(const nnvm::NodeAttrs& attrs,
509 const OpContext& ctx,
510 const std::vector<TBlob>& inputs,
511 const std::vector<OpReqType>& req,
512 const std::vector<TBlob>& outputs) {
513 CHECK_NE(req[0], kAddTo);
514 CHECK_EQ(inputs.size(), pnum + 1);
515 CHECK_EQ(outputs.size(), 1);
516 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
517 const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
518 MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
519 if ( param.is_log ) {
520 PdfCaller<xpu, DType, pdf<true>, pnum, vparm>::op(inputs, outputs, s);
521 } else {
522 PdfCaller<xpu, DType, pdf<false>, pnum, vparm>::op(inputs, outputs, s);
523 }
524 });
525 }
526
527
/*!
 * \brief Dispatches a pdf gradient kernel. It is specialized below on the number of
 * distribution parameters (pnum) and on vector-valued parameters (vparm). vparm is a bool,
 * matching PdfCaller and the bool forwarded by PdfOpBackward. It was previously declared as int.
 */
template<typename xpu, typename DType, typename pdfgrad, int pnum, bool vparm = false>
struct PdfGradCaller;
530
531 template<typename xpu, typename DType, typename pdfgrad>
532 struct PdfGradCaller<xpu, DType, pdfgrad, 1, false> {
533 static void op(const std::vector<TBlob>& inputs,
534 const std::vector<OpReqType>& req,
535 const std::vector<TBlob>& grads,
536 mshadow::Stream<xpu> *s) {
537 index_t num_samples(inputs[1].Size() / inputs[2].Size());
538 mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(),
539 num_samples, req[0],
540 inputs[3].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
541 inputs[0].dptr<DType>(), grads[0].dptr<DType>(), grads[1].dptr<DType>());
542 }
543 };
544
545 template<typename xpu, typename DType, typename pdfgrad>
546 struct PdfGradCaller<xpu, DType, pdfgrad, 1, true> {
547 static void op(const std::vector<TBlob>& inputs,
548 const std::vector<OpReqType>& req,
549 const std::vector<TBlob>& grads,
550 mshadow::Stream<xpu> *s) {
551 index_t num_samples(inputs[1].Size() / inputs[2].Size());
552 index_t sample_size(inputs[1].Size() / inputs[0].Size());
553 mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(), num_samples,
554 req[0], sample_size,
555 inputs[3].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
556 inputs[0].dptr<DType>(), grads[0].dptr<DType>(), grads[1].dptr<DType>());
557 }
558 };
559
560 template<typename xpu, typename DType, typename pdfgrad>
561 struct PdfGradCaller<xpu, DType, pdfgrad, 2, false> {
562 static void op(const std::vector<TBlob>& inputs,
563 const std::vector<OpReqType>& req,
564 const std::vector<TBlob>& grads,
565 mshadow::Stream<xpu> *s) {
566 index_t num_samples(inputs[1].Size() / inputs[2].Size());
567 mxnet_op::Kernel<LaunchExWrapper<pdfgrad>, xpu>::LaunchEx(s, inputs[0].Size(),
568 num_samples, req[0],
569 inputs[4].dptr<DType>(), inputs[1].dptr<DType>(), inputs[2].dptr<DType>(),
570 inputs[3].dptr<DType>(), inputs[0].dptr<DType>(),
571 grads[0].dptr<DType>(), grads[1].dptr<DType>(), grads[2].dptr<DType>());
572 }
573 };
574
575 template<typename xpu, template<bool> class pdfgrad, int pnum, bool vparm>
576 void PdfOpBackward(const nnvm::NodeAttrs& attrs,
577 const OpContext& ctx,
578 const std::vector<TBlob>& inputs,
579 const std::vector<OpReqType>& req,
580 const std::vector<TBlob>& outputs) {
581 using namespace mshadow;
582 CHECK_EQ(inputs.size(), pnum + 3);
583 CHECK_EQ(outputs.size(), pnum + 1);
584 mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
585 const PdfParam& param = nnvm::get<PdfParam>(attrs.parsed);
586 const size_t N(outputs[1].Size());
587 const TShape src_shape(Shape2(N, outputs[0].Size() / N)), dst_shape(Shape2(N, 1));
588 // Inputs to PdfOpBackward: grad, samples, parm1, parm2, pdf.
589 MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
590 const size_t red_work_size(broadcast::ReduceWorkspaceSize<2, DType>(
591 s, dst_shape, kAddTo, src_shape));
592 const size_t tmp_size(outputs[0].Size() * pnum * sizeof(DType) + red_work_size);
593 Tensor<xpu, 1, char> tmp_space =
594 ctx.requested[0].get_space_typed<xpu, 1, char>(Shape1(tmp_size), s);
595 std::vector<TBlob> grads = {outputs[0]};
596 grads.push_back(TBlob(tmp_space.dptr_, outputs[0].shape_,
597 outputs[1].dev_mask(), outputs[1].type_flag_, outputs[1].dev_id()));
598 if (pnum == 2) {
599 grads.push_back(TBlob(tmp_space.dptr_ + outputs[0].Size() * sizeof(DType), outputs[0].shape_,
600 outputs[2].dev_mask(), outputs[2].type_flag_, outputs[2].dev_id()));
601 }
602 if (param.is_log) {
603 PdfGradCaller<xpu, DType, pdfgrad<true>, pnum, vparm>::op(inputs, req, grads, s);
604 } else {
605 PdfGradCaller<xpu, DType, pdfgrad<false>, pnum, vparm>::op(inputs, req, grads, s);
606 }
607 Tensor<xpu, 1, char> red_work(
608 tmp_space.dptr_ + pnum * outputs[0].Size() * sizeof(DType), Shape1(red_work_size), s);
609 broadcast::Reduce<red::sum, 2, DType, op::mshadow_op::identity>(
610 s, outputs[1].reshape(dst_shape), req[1], red_work, grads[1].reshape(src_shape));
611 if (pnum == 2) {
612 broadcast::Reduce<red::sum, 2, DType, op::mshadow_op::identity>(
613 s, outputs[2].reshape(dst_shape), req[2], red_work, grads[2].reshape(src_shape));
614 }
615 });
616 }
617
618 } // namespace op
619 } // namespace mxnet
620
621 #endif // MXNET_OPERATOR_RANDOM_PDF_OP_H_
622