/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_elemwise_broadcast_op.h
 * \brief Function definition of elemwise and broadcast operators
 */
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

#include <algorithm>
#include <vector>
#include <string>

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

namespace mxnet {
namespace op {

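/*!
 * \brief Log a fatal error for an unsupported dtype combination in a
 *        mixed-precision binary operator.
 */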
inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
  LOG(FATAL) << "Operator " << op_name << " does not support combination of "
             << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
             << " yet...";
}

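/*!
 * \brief Elementwise kernel for two floating-point inputs of different dtypes.
 *        lhs must already share the (wider) dtype of out; the narrower rhs is
 *        promoted inside the kernel. Supported promotions: float16 with float32,
 *        float16 with float64, and float32 with float64.
 */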
template<typename xpu, typename OP>
void MixedAllRealBinaryElemwiseCompute(const std::string& op_name,
                                       const OpContext& ctx,
                                       const TBlob& lhs,
                                       const TBlob& rhs,
                                       const TBlob& out,
                                       const OpReqType req) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
      + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
    if (size == 0) return;

    switch (lhs.type_flag_) {
      case mshadow::kFloat32:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
              s, size, out.dptr<float>(), rhs.dptr<mshadow::half::half_t>(),
              lhs.dptr<float>());
          });
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      case mshadow::kFloat64:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
              s, size, out.dptr<double>(), rhs.dptr<mshadow::half::half_t>(),
              lhs.dptr<double>());
          });
        } else if (rhs.type_flag_ == mshadow::kFloat32) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
              s, size, out.dptr<double>(), rhs.dptr<float>(),
              lhs.dptr<double>());
          });
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      default:
      {
        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        break;
      }
    }
  });
}

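/*!
 * \brief Elementwise kernel for a floating-point lhs (same dtype as out) and an
 *        integer rhs, which is promoted to the output dtype inside the kernel.
 */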
template<typename xpu, typename OP>
void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx,
                                       const TBlob& lhs,
                                       const TBlob& rhs,
                                       const TBlob& out,
                                       const OpReqType req) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, FType, {
    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
      + DataType<FType>::kLanes - 1) / DataType<FType>::kLanes;
    if (size == 0) return;

    MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
        Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
          s, size, out.dptr<FType>(), rhs.dptr<IType>(),
          lhs.dptr<FType>());
      });
    });
  });
}

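/*!
 * \brief Elementwise dispatch for inputs of different dtypes. Delegates to the
 *        helpers above, swapping the operands (and selecting LOP instead of ROP)
 *        when rhs rather than lhs matches the output dtype.
 */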
template<typename xpu, typename LOP, typename ROP>
void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];
  if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
    if (lhs.type_flag_ == out.type_flag_) {
      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
    } else {
      MixedAllRealBinaryElemwiseCompute<xpu, LOP>(attrs.op->name, ctx, rhs, lhs, out, req[0]);
    }
  } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
    if (lhs.type_flag_ == out.type_flag_) {
      MixedIntRealBinaryElemwiseCompute<xpu, ROP>(ctx, lhs, rhs, out, req[0]);
    } else {
      MixedIntRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]);
    }
  } else {
    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
  }
}

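/*!
 * \brief Broadcast kernel for two floating-point inputs of different dtypes,
 *        following the same promotion rules as MixedAllRealBinaryElemwiseCompute.
 */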
template<typename xpu, typename OP>
void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
                                        const OpContext& ctx,
                                        const TBlob& lhs,
                                        const TBlob& rhs,
                                        const TBlob& out,
                                        const OpReqType req,
                                        const int ndim,
                                        const mxnet::TShape& new_oshape,
                                        const mxnet::TShape& new_lshape,
                                        const mxnet::TShape& new_rshape) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  BROADCAST_NDIM_SWITCH(ndim, NDim, {
    mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
    mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
    mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
    switch (lhs.type_flag_) {
      case mshadow::kFloat32:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
          rhs.dptr<mshadow::half::half_t>(), lhs.dptr<float>(), out.dptr<float>());
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      case mshadow::kFloat64:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
          rhs.dptr<mshadow::half::half_t>(), lhs.dptr<double>(), out.dptr<double>());
        } else if (rhs.type_flag_ == mshadow::kFloat32) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
          template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
          rhs.dptr<float>(), lhs.dptr<double>(), out.dptr<double>());
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      default:
      {
        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        break;
      }
    }
  });
}

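/*!
 * \brief Broadcast dispatch for inputs of different dtypes. Falls back to the
 *        elementwise path when the compacted shapes require no broadcasting;
 *        integer/integer combinations are handled by casting one operand to the
 *        output dtype in requested temporary space.
 */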
template<typename xpu, typename OP, typename LOP, typename ROP>
void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  mxnet::TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
                                         &new_lshape, &new_rshape, &new_oshape);
  if (!ndim) {
    MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
  } else {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
      if (lhs.type_flag_ == out.type_flag_) {
        MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
          attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape);
      } else {
        MixedAllRealBinaryBroadcastCompute<xpu, LOP>(
          attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape);
      }
    } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
      CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_)
        << "One of the input dtypes should be the same as the output dtype";
      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
        mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
        mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
        if (lhs.type_flag_ == out.type_flag_) {
          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, {
            MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, xpu>::
              template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape,
              rhs.dptr<RType>(), lhs.dptr<LType>(), out.dptr<LType>());
            });
          });
        } else {
          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, {
            MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, xpu>::
              template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
              lhs.dptr<LType>(), rhs.dptr<RType>(), out.dptr<RType>());
            });
          });
        }
      });
    } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
      TBlob temp_tblob;
      if (lhs.type_flag_ == out.type_flag_) {
        MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
          Tensor<xpu, 1, LType> temp_tensor =
            ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
          temp_tblob = TBlob(temp_tensor);
        });
        CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
        BinaryBroadcastCompute<xpu, OP>(
          attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
      } else {
        MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
          Tensor<xpu, 1, RType> temp_tensor =
            ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
          temp_tblob = TBlob(temp_tensor);
        });
        CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
        BinaryBroadcastCompute<xpu, OP>(
          attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
      }
    } else {
      PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
    }
  }
}

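/*!
 * \brief FCompute entry point for numpy binary broadcast operators: takes the fast
 *        same-dtype path when input dtypes match, otherwise the mixed-dtype path.
 *
 * A minimal registration sketch (illustrative only; actual registrations live in
 * the corresponding .cc/.cu files and may use different op names and kernels):
 *
 *   NNVM_REGISTER_OP(_npi_example_add)  // hypothetical op name
 *   .set_attr<FCompute>("FCompute<cpu>",
 *                       NumpyBinaryBroadcastCompute<cpu, op::mshadow_op::plus,
 *                                                   op::mshadow_op::plus,
 *                                                   op::mshadow_op::plus>);
 */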
template<typename xpu, typename OP, typename LOP, typename ROP>
void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
    return;
  }

  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}

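/*!
 * \brief Variant of NumpyBinaryBroadcastCompute whose same-dtype path also accepts
 *        boolean inputs; mixed integer inputs are handled here by casting one
 *        operand to the output dtype before the ordinary broadcast kernel.
 */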
template<typename xpu, typename OP, typename LOP, typename ROP>
void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
                                         const OpContext& ctx,
                                         const std::vector<TBlob>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
    return;
  }
  if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob temp_tblob;
    if (lhs.type_flag_ == out.type_flag_) {
      MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
        Tensor<xpu, 1, LType> temp_tensor =
          ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
        temp_tblob = TBlob(temp_tensor);
      });
      CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
      BinaryBroadcastCompute<xpu, OP>(
        attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
    } else {
      MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
        Tensor<xpu, 1, RType> temp_tensor =
          ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
        temp_tblob = TBlob(temp_tensor);
      });
      CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
      BinaryBroadcastCompute<xpu, OP>(
        attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
    }
    return;
  }
  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}

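/*!
 * \brief Backward pass for mixed-dtype binary broadcast operators. The operand whose
 *        dtype differs from ograd is first cast into a temporary buffer; its gradient
 *        is accumulated in ograd's dtype and, when both inputs are floating point,
 *        cast back to the corresponding gradient output afterwards.
 */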
template<typename xpu, typename LOP, typename ROP>
void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);

  const TBlob& lhs = inputs[1];
  const TBlob& rhs = inputs[2];
  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
    return;
  }

  const TBlob& ograd = inputs[0];
  const TBlob& lgrad = outputs[0];
  const TBlob& rgrad = outputs[1];

  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
    // At least one input is floating point and shares its dtype with the output,
    // so two of the three tensors involved have the same data type.
    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape new_lshape, new_rshape, new_oshape;
    using namespace broadcast;
    const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
                                                     &new_lshape, &new_rshape, &new_oshape) != 0;

    // Prepare all the temporary memory
    size_t workspace_size_l = 0, workspace_size_r = 0;
    TBlob temp_tblob;  // The TBlob for casted input data
    TBlob temp_igrad;  // The TBlob for casted grad results
    size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
    Tensor<xpu, 1, char> workspace;

    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
      if (need_bc) {
        BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
          workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
            s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
          workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
            s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
        });
      }
      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
      size_t cast_tensor_size = tensor_size * sizeof(OType);
      // Allocate the temporary memories now
      Tensor<xpu, 1, char> temp_space =
        ctx.requested[0].get_space_typed<xpu, 1, char>(
          Shape1(workspace_size + cast_tensor_size * 2), s);
      // Tensor for temp_tblob
      Tensor<xpu, 1, OType> temp_tblob_tensor(
                              reinterpret_cast<OType*>(temp_space.dptr_),
                              Shape1(tensor_size), s);
      // Tensor for temp_igrad
      Tensor<xpu, 1, OType> temp_igrad_tensor(
                              reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
                              Shape1(tensor_size), s);
      temp_tblob =
        TBlob(temp_tblob_tensor)
          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
      temp_igrad =
        TBlob(temp_igrad_tensor)
          .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
      if (temp_igrad.Size() != 0) {
        Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
      }
      workspace =
        Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
    });

    // Cast the input whose dtype differs from ograd into temp_tblob
    CastCompute<xpu>(
      attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});

    if (!need_bc) {
      if (lhs.type_flag_ != ograd.type_flag_) {
        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
          attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
      } else {
        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
          attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
      }
    } else {
      if (lhs.type_flag_ != ograd.type_flag_) {
        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
              ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
              workspace, new_lshape, new_rshape, new_oshape);
          });
        });
      } else {
        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
              ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
              workspace, new_lshape, new_rshape, new_oshape);
          });
        });
      }
    }

    // If both inputs are floating point, cast temp_igrad back to the gradient
    // of the input whose dtype differs from ograd's
    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
      if (lhs.type_flag_ != ograd.type_flag_) {
        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
      } else {
        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
      }
    }
  } else {
    // Both inputs have integer types; backward computation should not be
    // performed in this case.
    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
  }
}

}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_