1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file np_broadcast_reduce_op_index.cc
22 * \brief CPU Implementation of broadcast and reduce functions based on index.
23 */
24 #include "./np_broadcast_reduce_op.h"
25
26 namespace mxnet {
27 namespace op {
28
NumpyReduceAxisShape(const nnvm::NodeAttrs & attrs,std::vector<TShape> * in_attrs,std::vector<TShape> * out_attrs)29 bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
30 std::vector<TShape> *in_attrs,
31 std::vector<TShape> *out_attrs) {
32 CHECK_EQ(in_attrs->size(), 1U);
33 CHECK_EQ(out_attrs->size(), 1U);
34 if (!shape_is_known(in_attrs->at(0))) {
35 return false;
36 }
37 const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
38 dmlc::optional<mxnet::Tuple<int>> axes;
39 if (param.axis.has_value()) {
40 mxnet::Tuple<int> t({param.axis.value()});
41 axes = dmlc::optional<mxnet::Tuple<int>>(t);
42 }
43 SHAPE_ASSIGN_CHECK(*out_attrs, 0,
44 NumpyReduceAxesShapeImpl((*in_attrs)[0], axes, param.keepdims));
45 return shape_is_known(out_attrs->at(0));
46 }
47
ArgMinMaxType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)48 bool ArgMinMaxType(const nnvm::NodeAttrs& attrs,
49 std::vector<int> *in_attrs,
50 std::vector<int> *out_attrs) {
51 CHECK_EQ(in_attrs->size(), 1U);
52 CHECK_EQ(out_attrs->size(), 1U);
53 CHECK_NE(in_attrs->at(0), -1);
54 TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
55 return out_attrs->at(0) != -1;
56 }
57
58 NNVM_REGISTER_OP(_npi_argmax)
59 .set_num_inputs(1)
60 .set_num_outputs(1)
61 .set_attr_parser(ParamParser<ReduceAxisParam>)
62 .set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
63 .set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
64 .add_argument("data", "NDArray-or-Symbol", "The input")
65 .set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::maximum>)
66 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
67 .add_arguments(ReduceAxisParam::__FIELDS__());
68
69 NNVM_REGISTER_OP(_npi_argmin)
70 .set_num_inputs(1)
71 .set_num_outputs(1)
72 .set_attr_parser(ParamParser<ReduceAxisParam>)
73 .set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
74 .set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
75 .add_argument("data", "NDArray-or-Symbol", "The input")
76 .set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::minimum>)
77 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
78 .add_arguments(ReduceAxisParam::__FIELDS__());
79
80 } // namespace op
81 } // namespace mxnet
82