/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file image_random-inl.h
 * \brief CPU and GPU implementations of the image to_tensor, normalize,
 *        random flip, color jitter, and random lighting operators.
 */
#ifndef MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_
#define MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_


#include <algorithm>
#include <cmath>
#include <limits>
#include <tuple>
#include <utility>
#include <vector>
#include "mxnet/base.h"
#include "../mxnet_op.h"
#include "../operator_common.h"
#if MXNET_USE_OPENCV
  #include <opencv2/opencv.hpp>
#endif  // MXNET_USE_OPENCV

namespace mxnet {
namespace op {
namespace image {

using namespace mshadow;

#if MXNET_USE_CUDA
// NOTE: Going through the generic kernel-launch/map path proved extremely
// costly here. Hence, we use dedicated CUDA kernels for these operators.
template<typename DType, typename T1, typename T2>
void ToTensorImplCUDA(mshadow::Stream<gpu> *s,
                      const T1 input,
                      const T2 output,
                      const int req,
                      const float normalize_factor);

template<typename DType>
void NormalizeImplCUDA(mshadow::Stream<gpu> *s,
                       const DType *input,
                       DType *output,
                       const int req,
                       const int N,
                       const int C,
                       const int H,
                       const int W,
                       const float mean_d0,
                       const float mean_d1,
                       const float mean_d2,
                       const float std_d0,
                       const float std_d1,
                       const float std_d2);

template<typename DType>
void NormalizeBackwardImplCUDA(mshadow::Stream<gpu> *s,
                               const DType *out_grad,
                               DType *in_grad,
                               const int req,
                               const int N,
                               const int C,
                               const int H,
                               const int W,
                               const float std_d0,
                               const float std_d1,
                               const float std_d2);
#endif  // MXNET_USE_CUDA

// Shape and Type inference for image to tensor operator
inline bool ToTensorShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  mxnet::TShape &shp = (*in_attrs)[0];
  if (!shape_is_known(shp)) return false;

  CHECK((shp.ndim() == 3) || (shp.ndim() == 4))
      << "Input image must have shape (height, width, channels), or "
      << "(N, height, width, channels) but got " << shp;
  if (shp.ndim() == 3) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({shp[2], shp[0], shp[1]}));
  } else if (shp.ndim() == 4) {
    SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({shp[0], shp[3], shp[1], shp[2]}));
  }

  return true;
}

inline bool ToTensorType(const nnvm::NodeAttrs& attrs,
                         std::vector<int> *in_attrs,
                         std::vector<int> *out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);
  TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kFloat32);
  return (*in_attrs)[0] != -1;
}

// Operator Implementation
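// ToTensor converts an interleaved HWC image into a planar CHW tensor and
// rescales values to [0, 1]: out[c][i] = in[i][c] / normalize_factor, so a
// uint8 pixel value of 255 maps to 1.0f. `length` is H*W and `step` is the
// flat offset of the current image within a batch.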
template<typename DType, int req>
inline void ToTensor(float* out_data, const DType* in_data,
                     const int length,
                     const int channels,
                     const float normalize_factor,
                     const int step) {
  // Microsoft Visual C++ compiler does not support omp collapse
  #ifdef _MSC_VER
    #pragma omp parallel for
  #else
    #pragma omp parallel for collapse(2)
  #endif  // _MSC_VER
  for (int c = 0; c < channels; ++c) {
      for (int i = 0; i < length; ++i) {
        KERNEL_ASSIGN(out_data[step + c*length + i], req,
                      (in_data[step + i*channels + c]) / normalize_factor);
      }
  }
}

inline void ToTensorImpl(const std::vector<TBlob> &inputs,
                         const std::vector<TBlob> &outputs,
                         const std::vector<OpReqType> &req,
                         const int length,
                         const int channel,
                         const float normalize_factor,
                         const int step) {
  MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      float* output = outputs[0].dptr<float>();
      DType* input = inputs[0].dptr<DType>();
      ToTensor<DType, req_type>(output, input, length, channel,
                                normalize_factor, step);
    });
  });
}

template<typename xpu>
void ToTensorOpForward(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  // We do not use a temp buffer when performing the operation.
  // Hence, this check is necessary.
  CHECK_EQ(req[0], kWriteTo)
    << "`to_tensor` does not support inplace updates";

  const float normalize_factor = 255.0f;

  if (std::is_same<xpu, gpu>::value) {
  #if MXNET_USE_CUDA
      mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
      MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
          if (inputs[0].ndim() == 3) {
            Tensor<gpu, 3, DType> input = inputs[0].get<gpu, 3, DType>(s);
            Tensor<gpu, 3, float> output = outputs[0].get<gpu, 3, float>(s);
            ToTensorImplCUDA<DType, Tensor<gpu, 3, DType>, Tensor<gpu, 3, float>>
            (s, input, output, req_type, normalize_factor);
          } else {
            Tensor<gpu, 4, DType> input = inputs[0].get<gpu, 4, DType>(s);
            Tensor<gpu, 4, float> output = outputs[0].get<gpu, 4, float>(s);
            ToTensorImplCUDA<DType, Tensor<gpu, 4, DType>, Tensor<gpu, 4, float>>
            (s, input, output, req_type, normalize_factor);
          }
        });
      });
  #else
    LOG(FATAL) << "Compile with USE_CUDA=1 to use ToTensor operator on GPU.";
  #endif  // MXNET_USE_CUDA
  } else if (inputs[0].ndim() == 3) {
    // 3D Input - (h, w, c)
    const int length = inputs[0].shape_[0] * inputs[0].shape_[1];
    const int channel = static_cast<int>(inputs[0].shape_[2]);
    const int step = 0;
    ToTensorImpl(inputs, outputs, req, length,
                 channel, normalize_factor, step);
  } else if (inputs[0].ndim() == 4) {
    // 4D input (n, h, w, c)
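    // Each image occupies channel*length contiguous elements, so image n
    // starts at flat offset n*step; images are processed in parallel below.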
    const int batch_size = inputs[0].shape_[0];
    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
    const int channel = static_cast<int>(inputs[0].shape_[3]);
    const int step = channel * length;

    #pragma omp parallel for
    for (auto n = 0; n < batch_size; ++n) {
      ToTensorImpl(inputs, outputs, req, length, channel,
                   normalize_factor, n*step);
    }
  }
}

struct NormalizeParam : public dmlc::Parameter<NormalizeParam> {
  mxnet::Tuple<float> mean;
  mxnet::Tuple<float> std;

  DMLC_DECLARE_PARAMETER(NormalizeParam) {
    DMLC_DECLARE_FIELD(mean)
    .set_default(mxnet::Tuple<float> {0.0f, 0.0f, 0.0f, 0.0f})
    .describe("Sequence of means for each channel. "
              "Default value is 0.");
    DMLC_DECLARE_FIELD(std)
    .set_default(mxnet::Tuple<float> {1.0f, 1.0f, 1.0f, 1.0f})
    .describe("Sequence of standard deviations for each channel. "
              "Default value is 1.");
  }
};

// Shape and Type inference for image Normalize operator

// Shape inference
inline bool NormalizeOpShape(const nnvm::NodeAttrs& attrs,
                          mxnet::ShapeVector *in_attrs,
                          mxnet::ShapeVector *out_attrs) {
  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

  const auto& dshape = (*in_attrs)[0];
  if (!shape_is_known(dshape)) return false;
  CHECK((dshape.ndim() == 3) || (dshape.ndim() == 4))
      << "Input tensor must have shape (channels, height, width), or "
      << "(N, channels, height, width), but got " << dshape;

  int nchannels = 0;
  if (dshape.ndim() == 3) {
    nchannels = dshape[0];
    CHECK(nchannels == 3 || nchannels == 1)
      << "The first dimension of input tensor must be the channel dimension with "
      << "either 1 or 3 elements, but got input with shape " << dshape;
  } else if (dshape.ndim() == 4) {
    nchannels = dshape[1];
    CHECK(nchannels == 3 || nchannels == 1)
      << "The second dimension of input tensor must be the channel dimension with "
      << "either 1 or 3 elements, but got input with shape " << dshape;
  }

  CHECK((param.mean.ndim() == 1) || (param.mean.ndim() == nchannels))
      << "Invalid mean for input with shape " << dshape
      << ". mean must have either 1 or " << nchannels
      << " elements, but got " << param.mean;
  CHECK(param.std.ndim() == 1 || param.std.ndim() == nchannels)
      << "Invalid std for input with shape " << dshape
      << ". std must have either 1 or " << nchannels
      << " elements, but got " << param.std;

  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
  return true;
}

// Type Inference
inline bool NormalizeOpType(const nnvm::NodeAttrs& attrs,
                          std::vector<int>* in_attrs,
                          std::vector<int>* out_attrs) {
  CHECK_EQ(in_attrs->size(), 1U);
  CHECK_EQ(out_attrs->size(), 1U);

  TYPE_ASSIGN_CHECK(*out_attrs, 0, in_attrs->at(0));
  TYPE_ASSIGN_CHECK(*in_attrs, 0, out_attrs->at(0));
  return out_attrs->at(0) != -1;
}

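// Normalize applies per-channel standardization on a CHW tensor:
// out[c][i] = (in[c][i] - mean[c]) / std[c]. For example, with mean 0.5 and
// std 0.5 an input in [0, 1] is mapped to [-1, 1].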
template<typename DType, int req>
inline void Normalize(DType* out_data,
                      const DType* in_data,
                      const int length,
                      const int channels,
                      const int step,
                      const std::vector<float> mean,
                      const std::vector<float> std) {
  // Microsoft Visual C++ compiler does not support omp collapse
  #ifdef _MSC_VER
    #pragma omp parallel for
  #else
    #pragma omp parallel for collapse(2)
  #endif  // _MSC_VER
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < length; ++i) {
      KERNEL_ASSIGN(out_data[step + c*length + i], req,
                    (in_data[step + c*length + i] - mean[c]) / std[c]);
    }
  }
}

inline void NormalizeImpl(const std::vector<TBlob> &inputs,
                          const std::vector<TBlob> &outputs,
                          const std::vector<OpReqType> &req,
                          const int length,
                          const int channels,
                          const int step,
                          const std::vector<float> mean,
                          const std::vector<float> std) {
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
      DType* input = inputs[0].dptr<DType>();
      DType* output = outputs[0].dptr<DType>();
      Normalize<DType, req_type>(output, input, length, channels, step,
                                 mean, std);
    });
  });
}

template<typename xpu>
void NormalizeOpForward(const nnvm::NodeAttrs &attrs,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
  CHECK_EQ(inputs.size(), 1U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);

  // Mean and std can each have either 1 or 3 elements.
  std::vector<float> mean(3);
  std::vector<float> std(3);
  if (param.mean.ndim() == 1) {
    mean[0] = mean[1] = mean[2] = param.mean[0];
  } else {
    mean[0] = param.mean[0];
    mean[1] = param.mean[1];
    mean[2] = param.mean[2];
  }

  if (param.std.ndim() == 1) {
    std[0] = std[1] = std[2] = param.std[0];
  } else {
    std[0] = param.std[0];
    std[1] = param.std[1];
    std[2] = param.std[2];
  }

  if (std::is_same<xpu, gpu>::value) {
    #if MXNET_USE_CUDA
      mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
      MSHADOW_TYPE_SWITCH(inputs[0].type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
          int N, C, H, W;
          DType *input = nullptr;
          DType *output = nullptr;
          if (inputs[0].ndim() == 3) {
            N = 1;
            C = static_cast<int>(inputs[0].shape_[0]);
            H = static_cast<int>(inputs[0].shape_[1]);
            W = static_cast<int>(inputs[0].shape_[2]);
            input = (inputs[0].get<gpu, 3, DType>(s)).dptr_;
            output = (outputs[0].get<gpu, 3, DType>(s)).dptr_;
          } else {
            N = static_cast<int>(inputs[0].shape_[0]);
            C = static_cast<int>(inputs[0].shape_[1]);
            H = static_cast<int>(inputs[0].shape_[2]);
            W = static_cast<int>(inputs[0].shape_[3]);
            input = (inputs[0].get<gpu, 4, DType>(s)).dptr_;
            output = (outputs[0].get<gpu, 4, DType>(s)).dptr_;
          }
          NormalizeImplCUDA<DType>(s, input, output, req_type,
                                   N, C, H, W,
                                   mean[0], mean[1], mean[2],
                                   std[0], std[1], std[2]);
        });
      });
    #else
      LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize operator on GPU.";
    #endif  // MXNET_USE_CUDA
  } else if (inputs[0].ndim() == 3) {
    // 3D input (c, h, w)
    const int length = inputs[0].shape_[1] * inputs[0].shape_[2];
    const int channel = static_cast<int>(inputs[0].shape_[0]);
    const int step = 0;
    NormalizeImpl(inputs, outputs, req, length, channel, step, mean, std);
  } else if (inputs[0].ndim() == 4) {
    // 4D input (n, c, h, w)
    const int batch_size = inputs[0].shape_[0];
    const int length = inputs[0].shape_[2] * inputs[0].shape_[3];
    const int channel = static_cast<int>(inputs[0].shape_[1]);
    const int step = channel * length;

    #pragma omp parallel for
    for (auto n = 0; n < batch_size; ++n) {
      NormalizeImpl(inputs, outputs, req, length, channel, n*step, mean, std);
    }
  }
}

// Backward function
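// Since out = (in - mean[c]) / std[c] is affine in the input, the gradient
// w.r.t. the input is just out_grad scaled by 1/std[c]; the mean drops out.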
template<typename DType, int req>
inline void NormalizeBackward(const DType* out_grad,
                              DType* in_grad,
                              const int length,
                              const int channels,
                              const int step,
                              const std::vector<float> std) {
  // Microsoft Visual C++ compiler does not support omp collapse
  #ifdef _MSC_VER
    #pragma omp parallel for
  #else
    #pragma omp parallel for collapse(2)
  #endif  // _MSC_VER
  for (int c = 0; c < channels; ++c) {
    for (int i = 0; i < length; ++i) {
      KERNEL_ASSIGN(in_grad[step + c*length + i], req,
                    out_grad[step + c*length + i] * (1.0 / std[c]));
    }
  }
}

inline void NormalizeBackwardImpl(const std::vector<TBlob> &inputs,
                                  const std::vector<TBlob> &outputs,
                                  const std::vector<OpReqType> &req,
                                  const int length,
                                  const int channels,
                                  const int step,
                                  const std::vector<float> std) {
    MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
      MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
        DType* out_grad = inputs[0].dptr<DType>();
        DType* in_grad = outputs[0].dptr<DType>();
        NormalizeBackward<DType, req_type>(out_grad, in_grad, length,
                                           channels, step, std);
      });
    });
}

template<typename xpu>
void NormalizeOpBackward(const nnvm::NodeAttrs &attrs,
                         const OpContext &ctx,
                         const std::vector<TBlob> &inputs,
                         const std::vector<OpReqType> &req,
                         const std::vector<TBlob> &outputs) {
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);
  CHECK_EQ(req.size(), 1U);

  const NormalizeParam &param = nnvm::get<NormalizeParam>(attrs.parsed);
  // Std can have either 1 or 3 elements.
  std::vector<float> std(3);
  if (param.std.ndim() == 1) {
    std[0] = std[1] = std[2] = param.std[0];
  } else {
    std[0] = param.std[0];
    std[1] = param.std[1];
    std[2] = param.std[2];
  }

  // Note: inputs[0] is out_grad
  const TBlob& in_data = inputs[1];

  if (std::is_same<xpu, gpu>::value) {
    #if MXNET_USE_CUDA
      mshadow::Stream<gpu> *s = ctx.get_stream<gpu>();
      MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
        MXNET_ASSIGN_REQ_SWITCH(req[0], req_type, {
          int N, C, H, W;
          DType *in_grad = nullptr;
          DType *out_grad = nullptr;
          if (in_data.ndim() == 3) {
            N = 1;
            C = static_cast<int>(in_data.shape_[0]);
            H = static_cast<int>(in_data.shape_[1]);
            W = static_cast<int>(in_data.shape_[2]);
            out_grad = (inputs[0].get<gpu, 3, DType>(s)).dptr_;
            in_grad = (outputs[0].get<gpu, 3, DType>(s)).dptr_;
          } else {
            N = static_cast<int>(in_data.shape_[0]);
            C = static_cast<int>(in_data.shape_[1]);
            H = static_cast<int>(in_data.shape_[2]);
            W = static_cast<int>(in_data.shape_[3]);
            out_grad = (inputs[0].get<gpu, 4, DType>(s)).dptr_;
            in_grad = (outputs[0].get<gpu, 4, DType>(s)).dptr_;
          }
          NormalizeBackwardImplCUDA<DType>(s, out_grad, in_grad, req_type,
                                           N, C, H, W,
                                           std[0], std[1], std[2]);
        });
      });
    #else
      LOG(FATAL) << "Compile with USE_CUDA=1 to use Normalize backward operator on GPU.";
    #endif  // MXNET_USE_CUDA
  } else if (in_data.ndim() == 3) {
    // 3D input (c, h, w)
    const int length = in_data.shape_[1] * in_data.shape_[2];
    const int channel = static_cast<int>(in_data.shape_[0]);
    const int step = 0;
    NormalizeBackwardImpl(inputs, outputs, req, length, channel, step, std);
  } else if (in_data.ndim() == 4) {
    // 4D input (n, c, h, w)
    const int batch_size = in_data.shape_[0];
    const int length = in_data.shape_[2] * in_data.shape_[3];
    const int channel = static_cast<int>(in_data.shape_[1]);
    const int step = channel * length;

    #pragma omp parallel for
    for (auto n = 0; n < batch_size; ++n) {
      NormalizeBackwardImpl(inputs, outputs, req, length, channel, n*step, std);
    }
  }
}

template<typename DType>
inline DType saturate_cast(const float& src) {
  return static_cast<DType>(src);
}

template<>
inline uint8_t saturate_cast(const float& src) {
  return std::min(std::max(src, 0.f), 255.f);
}
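// For example, saturate_cast<uint8_t>(300.2f) yields 255 and
// saturate_cast<uint8_t>(-4.f) yields 0, avoiding the wrap-around a plain
// static_cast would produce.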

inline bool ImageShape(const nnvm::NodeAttrs& attrs,
                       mxnet::ShapeVector *in_attrs,
                       mxnet::ShapeVector *out_attrs) {
  mxnet::TShape& dshape = (*in_attrs)[0];
  CHECK_EQ(dshape.ndim(), 3)
      << "Input image must have shape (height, width, channels), but got " << dshape;
  auto nchannels = dshape[dshape.ndim()-1];
  CHECK(nchannels == 3 || nchannels == 1)
      << "The last dimension of input image must be the channel dimension with "
      << "either 1 or 3 elements, but got input with shape " << dshape;
  SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape);
  return true;
}

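// FlipImpl reverses one axis of a row-major tensor by viewing its shape as
// (head, mid, tail): head is the product of the dimensions before `axis`,
// mid is the flipped extent, and tail the product after it. For an HWC image,
// axis 1 gives head=H, mid=W, tail=C (a left-right flip). Pairs are swapped
// through a temporary so the flip works in-place (src == dst); note that for
// src != dst and an odd mid, the centre slice is never written.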
template<typename DType, int axis>
void FlipImpl(const mxnet::TShape &shape, DType *src, DType *dst) {
  int head = 1, mid = shape[axis], tail = 1;
  for (int i = 0; i < axis; ++i) head *= shape[i];
  for (int i = axis+1; i < shape.ndim(); ++i) tail *= shape[i];

  for (int i = 0; i < head; ++i) {
    for (int j = 0; j < (mid >> 1); ++j) {
      int idx1 = (i*mid + j) * tail;
      int idx2 = idx1 + (mid-(j << 1)-1) * tail;
      for (int k = 0; k < tail; ++k, ++idx1, ++idx2) {
        DType tmp = src[idx1];
        dst[idx1] = src[idx2];
        dst[idx2] = tmp;
      }
    }
  }
}

inline void FlipLeftRight(const nnvm::NodeAttrs &attrs,
                   const OpContext &ctx,
                   const std::vector<TBlob> &inputs,
                   const std::vector<OpReqType> &req,
                   const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    FlipImpl<DType, 1>(inputs[0].shape_, inputs[0].dptr<DType>(),
                       outputs[0].dptr<DType>());
  });
}

inline void FlipTopBottom(const nnvm::NodeAttrs &attrs,
                   const OpContext &ctx,
                   const std::vector<TBlob> &inputs,
                   const std::vector<OpReqType> &req,
                   const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    FlipImpl<DType, 0>(inputs[0].shape_, inputs[0].dptr<DType>(),
                       outputs[0].dptr<DType>());
  });
}

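// The random flips toss a fair coin: with probability 0.5 the input is
// passed through unchanged (copied only when the op is not running
// in-place), otherwise the image is flipped.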
inline void RandomFlipLeftRight(
    const nnvm::NodeAttrs &attrs,
    const OpContext &ctx,
    const std::vector<TBlob> &inputs,
    const std::vector<OpReqType> &req,
    const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    if (std::bernoulli_distribution()(prnd->GetRndEngine())) {
      if (outputs[0].dptr_ != inputs[0].dptr_) {
        std::memcpy(outputs[0].dptr_, inputs[0].dptr_, inputs[0].Size() * sizeof(DType));
      }
    } else {
      FlipImpl<DType, 1>(inputs[0].shape_, inputs[0].dptr<DType>(),
                         outputs[0].dptr<DType>());
    }
  });
}

inline void RandomFlipTopBottom(
    const nnvm::NodeAttrs &attrs,
    const OpContext &ctx,
    const std::vector<TBlob> &inputs,
    const std::vector<OpReqType> &req,
    const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    if (std::bernoulli_distribution()(prnd->GetRndEngine())) {
      if (outputs[0].dptr_ != inputs[0].dptr_) {
        std::memcpy(outputs[0].dptr_, inputs[0].dptr_, inputs[0].Size() * sizeof(DType));
      }
    } else {
      FlipImpl<DType, 0>(inputs[0].shape_, inputs[0].dptr<DType>(),
                         outputs[0].dptr<DType>());
    }
  });
}

struct RandomEnhanceParam : public dmlc::Parameter<RandomEnhanceParam> {
  float min_factor;
  float max_factor;
  DMLC_DECLARE_PARAMETER(RandomEnhanceParam) {
    DMLC_DECLARE_FIELD(min_factor)
    .set_lower_bound(0.0)
    .describe("Minimum factor.");
    DMLC_DECLARE_FIELD(max_factor)
    .set_lower_bound(0.0)
    .describe("Maximum factor.");
  }
};

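// Brightness: each value is scaled by the sampled factor, out = alpha_b * in,
// then saturated to the output dtype's range.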
inline void AdjustBrightnessImpl(const float& alpha_b,
                                 const OpContext &ctx,
                                 const std::vector<TBlob> &inputs,
                                 const std::vector<OpReqType> &req,
                                 const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  int length = inputs[0].Size();

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    DType* output = outputs[0].dptr<DType>();
    DType* input = inputs[0].dptr<DType>();
    for (int l = 0; l < length; ++l) {
      float val = static_cast<float>(input[l]) * alpha_b;
      output[l] = saturate_cast<DType>(val);
    }
  });
}

inline void RandomBrightness(const nnvm::NodeAttrs &attrs,
                      const OpContext &ctx,
                      const std::vector<TBlob> &inputs,
                      const std::vector<OpReqType> &req,
                      const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
  float alpha_b = std::uniform_real_distribution<float>(
      param.min_factor, param.max_factor)(prnd->GetRndEngine());

  AdjustBrightnessImpl(alpha_b, ctx, inputs, req, outputs);
}

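// Contrast is adjusted by blending each pixel with the mean luminance of the
// whole image: out = alpha_c * in + (1 - alpha_c) * gray_mean, where
// gray_mean is computed with the ITU-R BT.601 luma weights (0.299, 0.587,
// 0.114). alpha_c = 1 leaves the image unchanged; alpha_c = 0 collapses it
// to a uniform gray.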
inline void AdjustContrastImpl(const float& alpha_c,
                               const OpContext &ctx,
                               const std::vector<TBlob> &inputs,
                               const std::vector<OpReqType> &req,
                               const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  static const float coef[] = { 0.299f, 0.587f, 0.114f };

  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
  int nchannels = inputs[0].shape_[2];

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    DType* output = outputs[0].dptr<DType>();
    DType* input = inputs[0].dptr<DType>();

    float sum = 0.f;
    if (nchannels > 1) {
      for (int l = 0; l < length; ++l) {
        for (int c = 0; c < 3; ++c) sum += input[l*3 + c] * coef[c];
      }
    } else {
      for (int l = 0; l < length; ++l) sum += input[l];
    }
    float gray_mean = sum / static_cast<float>(length);
    float beta = (1 - alpha_c) * gray_mean;

    for (int l = 0; l < length * nchannels; ++l) {
      float val = input[l] * alpha_c + beta;
      output[l] = saturate_cast<DType>(val);
    }
  });
}

inline void RandomContrast(const nnvm::NodeAttrs &attrs,
                           const OpContext &ctx,
                           const std::vector<TBlob> &inputs,
                           const std::vector<OpReqType> &req,
                           const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
  float alpha_c = std::uniform_real_distribution<float>(
      param.min_factor, param.max_factor)(prnd->GetRndEngine());

  AdjustContrastImpl(alpha_c, ctx, inputs, req, outputs);
}

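// Saturation is adjusted by blending each pixel with its own luma value:
// out = alpha_s * in + (1 - alpha_s) * gray. alpha_s = 0 yields a fully
// desaturated (grayscale) image; single-channel inputs are passed through.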
inline void AdjustSaturationImpl(const float& alpha_s,
                                 const OpContext &ctx,
                                 const std::vector<TBlob> &inputs,
                                 const std::vector<OpReqType> &req,
                                 const std::vector<TBlob> &outputs) {
  static const float coef[] = { 0.299f, 0.587f, 0.114f };

  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
  int nchannels = inputs[0].shape_[2];

  float alpha_o = 1.f - alpha_s;

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    DType* output = outputs[0].dptr<DType>();
    DType* input = inputs[0].dptr<DType>();

    if (nchannels == 1) {
      for (int l = 0; l < length; ++l) output[l] = input[l];
      return;
    }

    for (int l = 0; l < length; ++l) {
      float gray = 0.f;
      for (int c = 0; c < 3; ++c) {
        gray += input[l*3 + c] * coef[c];
      }
      gray *= alpha_o;
      for (int c = 0; c < 3; ++c) {
        float val = gray + input[l*3 + c] * alpha_s;
        output[l*3 + c] = saturate_cast<DType>(val);
      }
    }
  });
}

inline void RandomSaturation(const nnvm::NodeAttrs &attrs,
                             const OpContext &ctx,
                             const std::vector<TBlob> &inputs,
                             const std::vector<OpReqType> &req,
                             const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
  float alpha_s = std::uniform_real_distribution<float>(
      param.min_factor, param.max_factor)(prnd->GetRndEngine());

  AdjustSaturationImpl(alpha_s, ctx, inputs, req, outputs);
}

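// RGB -> HLS conversion (OpenCV-style, H in [0, 360), L and S in [0, 1]).
// The comparisons below act as 0/1 masks so every pixel takes the same
// branch-free arithmetic path, e.g.
//   s = (l < 0.5) * diff/(vmax+vmin) + (l >= 0.5) * diff/(2-vmax-vmin).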
inline void RGB2HLSConvert(const float& src_r,
                    const float& src_g,
                    const float& src_b,
                    float *dst_h,
                    float *dst_l,
                    float *dst_s) {
  float b = src_b / 255.f, g = src_g / 255.f, r = src_r / 255.f;
  float h = 0.f, s = 0.f, l;
  float vmin;
  float vmax;
  float diff;

  vmax = vmin = r;
  vmax = std::fmax(vmax, g);
  vmax = std::fmax(vmax, b);
  vmin = std::fmin(vmin, g);
  vmin = std::fmin(vmin, b);

  diff = vmax - vmin;
  l = (vmax + vmin) * 0.5f;

  if (diff > std::numeric_limits<float>::epsilon()) {
    s = (l < 0.5f) * diff / (vmax + vmin);
    s += (l >= 0.5f) * diff / (2.0f - vmax - vmin);

    diff = 60.f / diff;

    h = (vmax == r) * (g - b) * diff;
    h += (vmax != r && vmax == g) * ((b - r) * diff + 120.f);
    h += (vmax != r && vmax != g) * ((r - g) * diff + 240.f);
    h += (h < 0.f) * 360.f;
  }

  *dst_h = h;
  *dst_l = l;
  *dst_s = s;
}

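// HLS -> RGB: hue is divided into six 60-degree sectors. tab holds the
// channel extremes (p2 = max, p1 = min) and the two linear ramps between
// them; c_HlsSectorData selects which of the four tab entries feeds the
// b, g, and r outputs in each sector.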
inline void HLS2RGBConvert(const float& src_h,
                    const float& src_l,
                    const float& src_s,
                    float *dst_r,
                    float *dst_g,
                    float *dst_b) {
  static const int c_HlsSectorData[6][3] = {
    { 1, 3, 0 },
    { 1, 0, 2 },
    { 3, 0, 1 },
    { 0, 2, 1 },
    { 0, 1, 3 },
    { 2, 1, 0 }
  };

  float h = src_h, l = src_l, s = src_s;
  float b = l, g = l, r = l;

  if (s != 0) {
    float p2 = (l <= 0.5f) * l * (1 + s);
    p2 += (l > 0.5f) * (l + s - l * s);
    float p1 = 2 * l - p2;

    h *= 1.f / 60.f;

    if (h < 0) {
      do { h += 6; } while (h < 0);
    } else if (h >= 6) {
      do { h -= 6; } while (h >= 6);
    }

    int sector = static_cast<int>(h);

    h -= sector;

    float tab[4];
    tab[0] = p2;
    tab[1] = p1;
    tab[2] = p1 + (p2 - p1) * (1 - h);
    tab[3] = p1 + (p2 - p1) * h;

    b = tab[c_HlsSectorData[sector][0]];
    g = tab[c_HlsSectorData[sector][1]];
    r = tab[c_HlsSectorData[sector][2]];
  }

  *dst_b = b * 255.f;
  *dst_g = g * 255.f;
  *dst_r = r * 255.f;
}

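// Hue is adjusted by a round trip through HLS space: each pixel is converted
// RGB -> HLS, the hue channel is rotated by alpha * 360 degrees, and the
// result is converted back. Single-channel images are left untouched.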
inline void AdjustHueImpl(float alpha,
                   const OpContext &ctx,
                   const std::vector<TBlob> &inputs,
                   const std::vector<OpReqType> &req,
                   const std::vector<TBlob> &outputs) {
  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
  if (inputs[0].shape_[2] == 1) return;

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    DType* input = inputs[0].dptr<DType>();
    DType* output = outputs[0].dptr<DType>();

    for (int i = 0; i < length; ++i) {
      float h, l, s;
      float r = static_cast<float>(*(input++));
      float g = static_cast<float>(*(input++));
      float b = static_cast<float>(*(input++));
      RGB2HLSConvert(r, g, b, &h, &l, &s);
      h += alpha * 360.f;
      HLS2RGBConvert(h, l, s, &r, &g, &b);
      *(output++) = saturate_cast<DType>(r);
      *(output++) = saturate_cast<DType>(g);
      *(output++) = saturate_cast<DType>(b);
    }
  });
}

inline void RandomHue(const nnvm::NodeAttrs &attrs,
               const OpContext &ctx,
               const std::vector<TBlob> &inputs,
               const std::vector<OpReqType> &req,
               const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomEnhanceParam &param = nnvm::get<RandomEnhanceParam>(attrs.parsed);

  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);
  float alpha = std::uniform_real_distribution<float>(
      param.min_factor, param.max_factor)(prnd->GetRndEngine());

  AdjustHueImpl(alpha, ctx, inputs, req, outputs);
}

struct RandomColorJitterParam : public dmlc::Parameter<RandomColorJitterParam> {
  float brightness;
  float contrast;
  float saturation;
  float hue;
  DMLC_DECLARE_PARAMETER(RandomColorJitterParam) {
    DMLC_DECLARE_FIELD(brightness)
    .describe("How much to jitter brightness.");
    DMLC_DECLARE_FIELD(contrast)
    .describe("How much to jitter contrast.");
    DMLC_DECLARE_FIELD(saturation)
    .describe("How much to jitter saturation.");
    DMLC_DECLARE_FIELD(hue)
    .describe("How much to jitter hue.");
  }
};

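// Color jitter applies the brightness, contrast, saturation, and hue
// adjustments in a random order, each with a factor drawn from
// [1 - jitter, 1 + jitter] (or [-hue, hue] for hue). `flag` tracks whether an
// earlier transform has already written the output, so later transforms read
// from `outputs` instead of `inputs` and the adjustments compose.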
inline void RandomColorJitter(const nnvm::NodeAttrs &attrs,
                       const OpContext &ctx,
                       const std::vector<TBlob> &inputs,
                       const std::vector<OpReqType> &req,
                       const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomColorJitterParam &param = nnvm::get<RandomColorJitterParam>(attrs.parsed);
  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, real_t>(s);

  int order[4] = {0, 1, 2, 3};
  std::shuffle(order, order + 4, prnd->GetRndEngine());
  bool flag = false;

  for (int i = 0; i < 4; ++i) {
    switch (order[i]) {
      case 0:
        if (param.brightness > 0) {
          float alpha_b = 1.0 + std::uniform_real_distribution<float>(
              -param.brightness, param.brightness)(prnd->GetRndEngine());
          AdjustBrightnessImpl(alpha_b, ctx, flag ? outputs : inputs, req, outputs);
          flag = true;
        }
        break;
      case 1:
        if (param.contrast > 0) {
          float alpha_c = 1.0 + std::uniform_real_distribution<float>(
              -param.contrast, param.contrast)(prnd->GetRndEngine());
          AdjustContrastImpl(alpha_c, ctx, flag ? outputs : inputs, req, outputs);
          flag = true;
        }
        break;
      case 2:
        if (param.saturation > 0) {
          float alpha_s = 1.f + std::uniform_real_distribution<float>(
              -param.saturation, param.saturation)(prnd->GetRndEngine());
          AdjustSaturationImpl(alpha_s, ctx, flag ? outputs : inputs, req, outputs);
          flag = true;
        }
        break;
      case 3:
        if (param.hue > 0) {
          float alpha_h = std::uniform_real_distribution<float>(
              -param.hue, param.hue)(prnd->GetRndEngine());
          AdjustHueImpl(alpha_h, ctx, flag ? outputs : inputs, req, outputs);
          flag = true;
        }
        break;
    }
  }
}

struct AdjustLightingParam : public dmlc::Parameter<AdjustLightingParam> {
  mxnet::Tuple<float> alpha;
  DMLC_DECLARE_PARAMETER(AdjustLightingParam) {
    DMLC_DECLARE_FIELD(alpha)
    .describe("The lighting alphas for the R, G, B channels.");
  }
};

struct RandomLightingParam : public dmlc::Parameter<RandomLightingParam> {
  float alpha_std;
  DMLC_DECLARE_PARAMETER(RandomLightingParam) {
    DMLC_DECLARE_FIELD(alpha_std)
    .set_default(0.05)
    .describe("Level of the lighting noise.");
  }
};

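// AlexNet-style PCA lighting noise: `eig` holds the eigenvectors of the RGB
// pixel covariance, column-scaled by the corresponding eigenvalues, so the
// per-channel offset is eig * alpha and the same (pca_r, pca_g, pca_b) shift
// is added to every pixel.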
inline void AdjustLightingImpl(const mxnet::Tuple<float>& alpha,
                        const OpContext &ctx,
                        const std::vector<TBlob> &inputs,
                        const std::vector<OpReqType> &req,
                        const std::vector<TBlob> &outputs) {
  static const float eig[3][3] = {
      { 55.46 * -0.5675, 4.794 * 0.7192,  1.148 * 0.4009 },
      { 55.46 * -0.5808, 4.794 * -0.0045, 1.148 * -0.8140 },
      { 55.46 * -0.5836, 4.794 * -0.6948, 1.148 * 0.4203 }
    };

  int length = inputs[0].shape_[0] * inputs[0].shape_[1];
  int channels = inputs[0].shape_[2];
  if (channels == 1) return;

  float pca_r = eig[0][0] * alpha[0] + eig[0][1] * alpha[1] + eig[0][2] * alpha[2];
  float pca_g = eig[1][0] * alpha[0] + eig[1][1] * alpha[1] + eig[1][2] * alpha[2];
  float pca_b = eig[2][0] * alpha[0] + eig[2][1] * alpha[1] + eig[2][2] * alpha[2];

  MSHADOW_TYPE_SWITCH(outputs[0].type_flag_, DType, {
    DType* output = outputs[0].dptr<DType>();
    DType* input = inputs[0].dptr<DType>();

    for (int i = 0; i < length; i++) {
      int base_ind = 3 * i;
      float in_r = static_cast<float>(input[base_ind]);
      float in_g = static_cast<float>(input[base_ind + 1]);
      float in_b = static_cast<float>(input[base_ind + 2]);
      output[base_ind] = saturate_cast<DType>(in_r + pca_r);
      output[base_ind + 1] = saturate_cast<DType>(in_g + pca_g);
      output[base_ind + 2] = saturate_cast<DType>(in_b + pca_b);
    }
  });
}

inline void AdjustLighting(const nnvm::NodeAttrs &attrs,
                    const OpContext &ctx,
                    const std::vector<TBlob> &inputs,
                    const std::vector<OpReqType> &req,
                    const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const AdjustLightingParam &param = nnvm::get<AdjustLightingParam>(attrs.parsed);
  AdjustLightingImpl(param.alpha, ctx, inputs, req, outputs);
}

inline void RandomLighting(const nnvm::NodeAttrs &attrs,
                    const OpContext &ctx,
                    const std::vector<TBlob> &inputs,
                    const std::vector<OpReqType> &req,
                    const std::vector<TBlob> &outputs) {
  using namespace mshadow;
  const RandomLightingParam &param = nnvm::get<RandomLightingParam>(attrs.parsed);
  Stream<cpu> *s = ctx.get_stream<cpu>();
  Random<cpu> *prnd = ctx.requested[0].get_random<cpu, float>(s);
  std::normal_distribution<float> dist(0, param.alpha_std);
  float alpha_r = dist(prnd->GetRndEngine());
  float alpha_g = dist(prnd->GetRndEngine());
  float alpha_b = dist(prnd->GetRndEngine());
  AdjustLightingImpl({alpha_r, alpha_g, alpha_b}, ctx, inputs, req, outputs);
}

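// Registration helpers. MXNET_REGISTER_IMAGE_AUG_OP declares a unary,
// shape-preserving image op whose gradient is an identity copy;
// MXNET_REGISTER_IMAGE_RND_AUG_OP additionally requests a random resource
// for sampling the augmentation factor. A sketch of how a .cc file might use
// it (the attribute list is illustrative, not verbatim):
//
//   MXNET_REGISTER_IMAGE_RND_AUG_OP(_image_random_brightness)
//   .set_attr_parser(ParamParser<RandomEnhanceParam>)
//   .set_attr<FCompute>("FCompute<cpu>", RandomBrightness)
//   .add_arguments(RandomEnhanceParam::__FIELDS__());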
#define MXNET_REGISTER_IMAGE_AUG_OP(name)                                   \
  NNVM_REGISTER_OP(name)                                                    \
  .set_num_inputs(1)                                                        \
  .set_num_outputs(1)                                                       \
  .set_attr<nnvm::FInplaceOption>("FInplaceOption",                         \
    [](const NodeAttrs& attrs){                                             \
      return std::vector<std::pair<int, int> >{{0, 0}};                     \
    })                                                                      \
  .set_attr<mxnet::FInferShape>("FInferShape", ImageShape)                  \
  .set_attr<nnvm::FInferType>("FInferType", ElemwiseType<1, 1>)             \
  .set_attr<nnvm::FGradient>("FGradient", ElemwiseGradUseNone{ "_copy" })   \
  .add_argument("data", "NDArray-or-Symbol", "The input.")


#define MXNET_REGISTER_IMAGE_RND_AUG_OP(name)                               \
  MXNET_REGISTER_IMAGE_AUG_OP(name)                                         \
  .set_attr<FResourceRequest>("FResourceRequest",                           \
    [](const NodeAttrs& attrs) {                                            \
      return std::vector<ResourceRequest>{ResourceRequest::kRandom};        \
    })

}  // namespace image
}  // namespace op
}  // namespace mxnet

#endif  // MXNET_OPERATOR_IMAGE_IMAGE_RANDOM_INL_H_