1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file np_broadcast_reduce_op_index.cc
22  * \brief CPU Implementation of broadcast and reduce functions based on index.
23  */
24 #include "./np_broadcast_reduce_op.h"
25 
26 namespace mxnet {
27 namespace op {
28 
NumpyReduceAxisShape(const nnvm::NodeAttrs & attrs,std::vector<TShape> * in_attrs,std::vector<TShape> * out_attrs)29 bool NumpyReduceAxisShape(const nnvm::NodeAttrs& attrs,
30                           std::vector<TShape> *in_attrs,
31                           std::vector<TShape> *out_attrs) {
32   CHECK_EQ(in_attrs->size(), 1U);
33   CHECK_EQ(out_attrs->size(), 1U);
34   if (!shape_is_known(in_attrs->at(0))) {
35     return false;
36   }
37   const ReduceAxisParam& param = nnvm::get<ReduceAxisParam>(attrs.parsed);
38   dmlc::optional<mxnet::Tuple<int>> axes;
39   if (param.axis.has_value()) {
40     mxnet::Tuple<int> t({param.axis.value()});
41     axes = dmlc::optional<mxnet::Tuple<int>>(t);
42   }
43   SHAPE_ASSIGN_CHECK(*out_attrs, 0,
44                      NumpyReduceAxesShapeImpl((*in_attrs)[0], axes, param.keepdims));
45   return shape_is_known(out_attrs->at(0));
46 }
47 
ArgMinMaxType(const nnvm::NodeAttrs & attrs,std::vector<int> * in_attrs,std::vector<int> * out_attrs)48 bool ArgMinMaxType(const nnvm::NodeAttrs& attrs,
49                    std::vector<int> *in_attrs,
50                    std::vector<int> *out_attrs) {
51   CHECK_EQ(in_attrs->size(), 1U);
52   CHECK_EQ(out_attrs->size(), 1U);
53   CHECK_NE(in_attrs->at(0), -1);
54   TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kInt64);
55   return out_attrs->at(0) != -1;
56 }
57 
// Registration of the numpy-compatible argmax operator: computes int64
// indices of the maximum values along the configured axis (see
// ReduceAxisParam for axis/keepdims options).
NNVM_REGISTER_OP(_npi_argmax)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
.add_argument("data", "NDArray-or-Symbol", "The input")
.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::maximum>)
// Index-valued output: not differentiable, so the gradient is all zeros.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());
68 
// Registration of the numpy-compatible argmin operator: identical to
// _npi_argmax above except it searches with the minimum reducer.
NNVM_REGISTER_OP(_npi_argmin)
.set_num_inputs(1)
.set_num_outputs(1)
.set_attr_parser(ParamParser<ReduceAxisParam>)
.set_attr<mxnet::FInferShape>("FInferShape", NumpyReduceAxisShape)
.set_attr<nnvm::FInferType>("FInferType", ArgMinMaxType)
.add_argument("data", "NDArray-or-Symbol", "The input")
.set_attr<FCompute>("FCompute<cpu>", NumpySearchAxisCompute<cpu, mshadow::red::minimum>)
// Index-valued output: not differentiable, so the gradient is all zeros.
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.add_arguments(ReduceAxisParam::__FIELDS__());
79 
80 }  // namespace op
81 }  // namespace mxnet
82