/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file np_elemwise_broadcast_op.h
 * \brief Function definitions of elementwise and broadcast operators
 */
#ifndef MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_
#define MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_

#include <algorithm>
#include <vector>
#include <string>

#include "../tensor/elemwise_binary_broadcast_op.h"
#include "../tensor/elemwise_binary_scalar_op.h"

namespace mxnet {
namespace op {

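/*!
 * \brief Report an unsupported dtype combination for an operator.
 * Note: despite the name, this aborts via LOG(FATAL) rather than just printing.
 */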
inline void PrintErrorMessage(const std::string& op_name, const int dtype1, const int dtype2) {
  LOG(FATAL) << "Operator " << op_name << " does not support combination of "
             << mshadow::dtype_string(dtype1) << " with " << mshadow::dtype_string(dtype2)
             << " yet...";
}

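/*!
 * \brief Elementwise compute for two floating-point inputs of different dtypes.
 * Assumes lhs is the operand whose dtype matches the output (checked below).
 * Supported widening pairs are float16 rhs with float32/float64 lhs, and
 * float32 rhs with float64 lhs; anything else aborts.
 */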
template<typename xpu, typename OP>
void MixedAllRealBinaryElemwiseCompute(const std::string& op_name,
                                       const OpContext& ctx,
                                       const TBlob& lhs,
                                       const TBlob& rhs,
                                       const TBlob& out,
                                       const OpReqType req) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, DType, {
    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
                         + DataType<DType>::kLanes - 1) / DataType<DType>::kLanes;
    if (size == 0) return;

    switch (lhs.type_flag_) {
      case mshadow::kFloat32:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
                s, size, out.dptr<float>(), rhs.dptr<mshadow::half::half_t>(),
                lhs.dptr<float>());
          });
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      case mshadow::kFloat64:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
                s, size, out.dptr<double>(), rhs.dptr<mshadow::half::half_t>(),
                lhs.dptr<double>());
          });
        } else if (rhs.type_flag_ == mshadow::kFloat32) {
          MXNET_ASSIGN_REQ_SWITCH(req, Req, {
            Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
                s, size, out.dptr<double>(), rhs.dptr<float>(),
                lhs.dptr<double>());
          });
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      default:
      {
        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        break;
      }
    }
  });
}

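/*!
 * \brief Elementwise compute where lhs is floating point and rhs is an
 * integer type. The output dtype must match the floating-point (lhs) dtype.
 */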
template<typename xpu, typename OP>
void MixedIntRealBinaryElemwiseCompute(const OpContext& ctx,
                                       const TBlob& lhs,
                                       const TBlob& rhs,
                                       const TBlob& out,
                                       const OpReqType req) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, FType, {
    const size_t size = (ElemwiseBinaryOp::minthree(out.Size(), lhs.Size(), rhs.Size())
                         + DataType<FType>::kLanes - 1) / DataType<FType>::kLanes;
    if (size == 0) return;

    MXNET_INT_TYPE_SWITCH(rhs.type_flag_, IType, {
      MXNET_ASSIGN_REQ_SWITCH(req, Req, {
        Kernel<mxnet_op::op_with_req<OP, Req>, xpu>::Launch(
            s, size, out.dptr<FType>(), rhs.dptr<IType>(),
            lhs.dptr<FType>());
      });
    });
  });
}

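/*!
 * \brief Elementwise dispatch for mixed-dtype inputs.
 * Arguments are swapped so the operand matching the output dtype is always
 * passed as lhs; LOP and ROP are the operator variants for each argument order.
 */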
template<typename xpu, typename LOP, typename ROP>
void MixedBinaryElemwiseCompute(const nnvm::NodeAttrs& attrs,
                                const OpContext& ctx,
                                const std::vector<TBlob>& inputs,
                                const std::vector<OpReqType>& req,
                                const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];
  if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
    if (lhs.type_flag_ == out.type_flag_) {
      MixedAllRealBinaryElemwiseCompute<xpu, ROP>(attrs.op->name, ctx, lhs, rhs, out, req[0]);
    } else {
      MixedAllRealBinaryElemwiseCompute<xpu, LOP>(attrs.op->name, ctx, rhs, lhs, out, req[0]);
    }
  } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
    if (lhs.type_flag_ == out.type_flag_) {
      MixedIntRealBinaryElemwiseCompute<xpu, ROP>(ctx, lhs, rhs, out, req[0]);
    } else {
      MixedIntRealBinaryElemwiseCompute<xpu, LOP>(ctx, rhs, lhs, out, req[0]);
    }
  } else {
    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
  }
}

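/*!
 * \brief Broadcast compute for two floating-point inputs of different dtypes.
 * Assumes lhs dtype equals the output dtype; supports the same widening pairs
 * as the elementwise variant above.
 */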
template<typename xpu, typename OP>
void MixedAllRealBinaryBroadcastCompute(const std::string& op_name,
                                        const OpContext& ctx,
                                        const TBlob& lhs,
                                        const TBlob& rhs,
                                        const TBlob& out,
                                        const OpReqType req,
                                        const int ndim,
                                        const mxnet::TShape& new_oshape,
                                        const mxnet::TShape& new_lshape,
                                        const mxnet::TShape& new_rshape) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(lhs.type_flag_, out.type_flag_);

  Stream<xpu> *s = ctx.get_stream<xpu>();

  BROADCAST_NDIM_SWITCH(ndim, NDim, {
    mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
    mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
    mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
    switch (lhs.type_flag_) {
      case mshadow::kFloat32:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
              template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
                                rhs.dptr<mshadow::half::half_t>(), lhs.dptr<float>(),
                                out.dptr<float>());
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      case mshadow::kFloat64:
      {
        if (rhs.type_flag_ == mshadow::kFloat16) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
              template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
                                rhs.dptr<mshadow::half::half_t>(), lhs.dptr<double>(),
                                out.dptr<double>());
        } else if (rhs.type_flag_ == mshadow::kFloat32) {
          mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, OP>, xpu>::
              template LaunchEx(s, new_oshape.Size(), req, rstride, lstride, oshape,
                                rhs.dptr<float>(), lhs.dptr<double>(),
                                out.dptr<double>());
        } else {
          PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        }
        break;
      }
      default:
      {
        PrintErrorMessage(op_name, lhs.type_flag_, rhs.type_flag_);
        break;
      }
    }
  });
}
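/*!
 * \brief Broadcast dispatch for mixed-dtype inputs.
 * float/float pairs widen in place; int/float pairs launch a mixed-type
 * kernel; int/int pairs cast the input whose dtype differs from the output
 * into requested temporary space, then run the ordinary broadcast kernel.
 */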
template<typename xpu, typename OP, typename LOP, typename ROP>
void MixedBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  mxnet::TShape new_lshape, new_rshape, new_oshape;
  int ndim = BinaryBroadcastShapeCompact(lhs.shape_, rhs.shape_, out.shape_,
                                         &new_lshape, &new_rshape, &new_oshape);
  if (!ndim) {
    MixedBinaryElemwiseCompute<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
  } else {
    mshadow::Stream<xpu> *s = ctx.get_stream<xpu>();
    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
      if (lhs.type_flag_ == out.type_flag_) {
        MixedAllRealBinaryBroadcastCompute<xpu, ROP>(
            attrs.op->name, ctx, lhs, rhs, out, req[0], ndim, new_oshape, new_lshape, new_rshape);
      } else {
        MixedAllRealBinaryBroadcastCompute<xpu, LOP>(
            attrs.op->name, ctx, rhs, lhs, out, req[0], ndim, new_oshape, new_rshape, new_lshape);
      }
    } else if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
      CHECK(lhs.type_flag_ == out.type_flag_ || rhs.type_flag_ == out.type_flag_)
          << "One of the input dtypes must match the output dtype";
      BROADCAST_NDIM_SWITCH(ndim, NDim, {
        mshadow::Shape<NDim> oshape = new_oshape.get<NDim>();
        mshadow::Shape<NDim> lstride = mxnet_op::calc_stride(new_lshape.get<NDim>());
        mshadow::Shape<NDim> rstride = mxnet_op::calc_stride(new_rshape.get<NDim>());
        if (lhs.type_flag_ == out.type_flag_) {
          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, LType, {
            MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, ROP>, xpu>::
                  template LaunchEx(s, new_oshape.Size(), req[0], rstride, lstride, oshape,
                                    rhs.dptr<RType>(), lhs.dptr<LType>(), out.dptr<LType>());
            });
          });
        } else {
          MSHADOW_REAL_TYPE_SWITCH(out.type_flag_, RType, {
            MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
              mxnet_op::Kernel<mxnet_op::binary_broadcast_kernel<NDim, LOP>, xpu>::
                  template LaunchEx(s, new_oshape.Size(), req[0], lstride, rstride, oshape,
                                    lhs.dptr<LType>(), rhs.dptr<RType>(), out.dptr<RType>());
            });
          });
        }
      });
    } else if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
      TBlob temp_tblob;
      if (lhs.type_flag_ == out.type_flag_) {
        MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
          Tensor<xpu, 1, LType> temp_tensor =
              ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
          temp_tblob = TBlob(temp_tensor);
        });
        CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
        BinaryBroadcastCompute<xpu, OP>(
            attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
      } else {
        MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
          Tensor<xpu, 1, RType> temp_tensor =
              ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
          temp_tblob = TBlob(temp_tensor);
        });
        CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
        BinaryBroadcastCompute<xpu, OP>(
            attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
      }
    } else {
      PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
    }
  }
}

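/*!
 * \brief Forward entry point for numpy binary broadcast operators.
 * Fast path: identical input dtypes go straight to BinaryBroadcastCompute;
 * otherwise the call falls through to the mixed-dtype dispatch.
 */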
template<typename xpu, typename OP, typename LOP, typename ROP>
void NumpyBinaryBroadcastCompute(const nnvm::NodeAttrs& attrs,
                                 const OpContext& ctx,
                                 const std::vector<TBlob>& inputs,
                                 const std::vector<OpReqType>& req,
                                 const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastCompute<xpu, OP>(attrs, ctx, inputs, req, outputs);
    return;
  }

  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}

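/*!
 * \brief Variant of NumpyBinaryBroadcastCompute that also accepts bool inputs.
 * Identical dtypes dispatch to BinaryBroadcastComputeWithBool; mixed non-float
 * dtype pairs cast the input whose dtype differs from the output before the
 * ordinary broadcast kernel; everything else goes to the mixed dispatch.
 */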
template<typename xpu, typename OP, typename LOP, typename ROP>
void NumpyBinaryBroadcastComputeWithBool(const nnvm::NodeAttrs& attrs,
                                         const OpContext& ctx,
                                         const std::vector<TBlob>& inputs,
                                         const std::vector<OpReqType>& req,
                                         const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 2U);
  CHECK_EQ(outputs.size(), 1U);

  const TBlob& lhs = inputs[0];
  const TBlob& rhs = inputs[1];
  const TBlob& out = outputs[0];

  if ((out.shape_.Size() == 0U) || (req[0] == kNullOp)) return;

  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastComputeWithBool<xpu, OP>(attrs, ctx, inputs, req, outputs);
    return;
  }
  if (!common::is_float(lhs.type_flag_) && !common::is_float(rhs.type_flag_)) {
    Stream<xpu> *s = ctx.get_stream<xpu>();
    TBlob temp_tblob;
    if (lhs.type_flag_ == out.type_flag_) {
      MXNET_INT_TYPE_SWITCH(lhs.type_flag_, LType, {
        Tensor<xpu, 1, LType> temp_tensor =
            ctx.requested[0].get_space_typed<xpu, 1, LType>(Shape1(rhs.Size()), s);
        temp_tblob = TBlob(temp_tensor);
      });
      CastCompute<xpu>(attrs, ctx, {rhs}, {kWriteTo}, {temp_tblob});
      BinaryBroadcastCompute<xpu, OP>(
          attrs, ctx, {lhs, temp_tblob.reshape(rhs.shape_)}, req, outputs);
    } else {
      MXNET_INT_TYPE_SWITCH(rhs.type_flag_, RType, {
        Tensor<xpu, 1, RType> temp_tensor =
            ctx.requested[0].get_space_typed<xpu, 1, RType>(Shape1(lhs.Size()), s);
        temp_tblob = TBlob(temp_tensor);
      });
      CastCompute<xpu>(attrs, ctx, {lhs}, {kWriteTo}, {temp_tblob});
      BinaryBroadcastCompute<xpu, OP>(
          attrs, ctx, {temp_tblob.reshape(lhs.shape_), rhs}, req, outputs);
    }
    return;
  }
  MixedBinaryBroadcastCompute<xpu, OP, LOP, ROP>(attrs, ctx, inputs, req, outputs);
}

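/*!
 * \brief Backward pass for mixed-dtype binary broadcast operators.
 * The input whose dtype differs from ograd is cast into temporary space so
 * the gradient can be computed entirely in ograd's dtype; when both inputs
 * are floating point, the temporary gradient is then cast back to that
 * input's dtype. Pure-integer input combinations abort, since no backward
 * computation is defined for them.
 */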
template<typename xpu, typename LOP, typename ROP>
void NumpyBinaryBackwardUseIn(const nnvm::NodeAttrs& attrs,
                              const OpContext& ctx,
                              const std::vector<TBlob>& inputs,
                              const std::vector<OpReqType>& req,
                              const std::vector<TBlob>& outputs) {
  using namespace mshadow;
  using namespace mxnet_op;
  CHECK_EQ(inputs.size(), 3U);
  CHECK_EQ(outputs.size(), 2U);

  const TBlob& lhs = inputs[1];
  const TBlob& rhs = inputs[2];
  if (lhs.type_flag_ == rhs.type_flag_) {
    BinaryBroadcastBackwardUseIn<xpu, LOP, ROP>(attrs, ctx, inputs, req, outputs);
    return;
  }

  const TBlob& ograd = inputs[0];
  const TBlob& lgrad = outputs[0];
  const TBlob& rgrad = outputs[1];

  if (common::is_float(lhs.type_flag_) || common::is_float(rhs.type_flag_)) {
    // If either input is a float, it has the same dtype as the output,
    // so two of the three tensors share a data type.
    Stream<xpu> *s = ctx.get_stream<xpu>();
    mxnet::TShape new_lshape, new_rshape, new_oshape;
    using namespace broadcast;
    const bool need_bc = BinaryBroadcastShapeCompact(lgrad.shape_, rgrad.shape_, ograd.shape_,
                                                     &new_lshape, &new_rshape, &new_oshape) != 0;

    // Prepare all the temporary memory
    size_t workspace_size_l = 0, workspace_size_r = 0;
    TBlob temp_tblob;  // The TBlob for casted input data
    TBlob temp_igrad;  // The TBlob for casted grad results
    size_t tensor_size = (lgrad.type_flag_ != ograd.type_flag_) ? lgrad.Size() : rgrad.Size();
    Tensor<xpu, 1, char> workspace;

    MSHADOW_TYPE_SWITCH(ograd.type_flag_, OType, {
      if (need_bc) {
        BROADCAST_NDIM_SWITCH(new_oshape.ndim(), ndim, {
          workspace_size_l = ReduceWorkspaceSize<ndim, OType>(
              s, new_lshape, req[0], new_oshape, new_lshape, new_rshape);
          workspace_size_r = ReduceWorkspaceSize<ndim, OType>(
              s, new_rshape, req[1], new_oshape, new_lshape, new_rshape);
        });
      }
      size_t workspace_size = std::max(workspace_size_l, workspace_size_r);
      size_t cast_tensor_size = tensor_size * sizeof(OType);
      // Allocate the temporary memories now
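      // Layout of the single requested allocation (temp_space):
      //   [ casted input    : cast_tensor_size bytes
      //   | casted igrad    : cast_tensor_size bytes
      //   | reduce workspace: workspace_size bytes ]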
      Tensor<xpu, 1, char> temp_space =
          ctx.requested[0].get_space_typed<xpu, 1, char>(
              Shape1(workspace_size + cast_tensor_size * 2), s);
      // Tensor for temp_tblob
      Tensor<xpu, 1, OType> temp_tblob_tensor(
          reinterpret_cast<OType*>(temp_space.dptr_),
          Shape1(tensor_size), s);
      // Tensor for temp_igrad
      Tensor<xpu, 1, OType> temp_igrad_tensor(
          reinterpret_cast<OType*>(temp_space.dptr_) + tensor_size,
          Shape1(tensor_size), s);
      temp_tblob =
          TBlob(temp_tblob_tensor)
              .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
      temp_igrad =
          TBlob(temp_igrad_tensor)
              .reshape(((lgrad.type_flag_ != ograd.type_flag_) ? lhs.shape_ : rhs.shape_));
      if (temp_igrad.Size() != 0) {
        Kernel<set_zero, xpu>::Launch(s, temp_igrad.Size(), temp_igrad.dptr<OType>());
      }
      workspace =
          Tensor<xpu, 1, char>(temp_space.dptr_ + 2 * cast_tensor_size, Shape1(workspace_size), s);
    });

    // Cast the input whose dtype differs from ograd into temp_tblob
    CastCompute<xpu>(
        attrs, ctx, {((lgrad.type_flag_ != ograd.type_flag_) ? lhs : rhs)}, {kWriteTo}, {temp_tblob});

    if (!need_bc) {
      if (lhs.type_flag_ != ograd.type_flag_) {
        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
            attrs, ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad});
      } else {
        ElemwiseBinaryOp::BackwardUseIn<xpu, LOP, ROP>(
            attrs, ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad});
      }
    } else {
      if (lhs.type_flag_ != ograd.type_flag_) {
        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
                ctx, {ograd, temp_tblob, rhs}, {kWriteTo, req[1]}, {temp_igrad, rgrad},
                workspace, new_lshape, new_rshape, new_oshape);
          });
        });
      } else {
        MSHADOW_TYPE_SWITCH(ograd.type_flag_, DType, {
          BROADCAST_NDIM_SWITCH(new_oshape.ndim(), NDim, {
            BinaryBroadcastBackwardUseInImplWithWorkspace<xpu, NDim, DType, LOP, ROP>(
                ctx, {ograd, lhs, temp_tblob}, {req[0], kWriteTo}, {lgrad, temp_igrad},
                workspace, new_lshape, new_rshape, new_oshape);
          });
        });
      }
    }

    // If both inputs are floating point, cast the temporary igrad back to the
    // dtype of the input that differs from ograd
    if (common::is_float(lhs.type_flag_) && common::is_float(rhs.type_flag_)) {
      if (lhs.type_flag_ != ograd.type_flag_) {
        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[0]}, {lgrad});
      } else {
        CastCompute<xpu>(attrs, ctx, {temp_igrad}, {req[1]}, {rgrad});
      }
    }
  } else {
    // Both inputs are integer types; no backward computation is supported
    // for this combination.
    PrintErrorMessage(attrs.op->name, lhs.type_flag_, rhs.type_flag_);
  }
}

}  // namespace op
}  // namespace mxnet
#endif  // MXNET_OPERATOR_NUMPY_NP_ELEMWISE_BROADCAST_OP_H_