1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file np_choice_op.cc
22  * \brief Operator for random subset sampling
23  */
24 
25 #include "./np_choice_op.h"
26 #include <algorithm>
27 
28 namespace mxnet {
29 namespace op {
30 
31 template <>
_sort(float * key,int64_t * data,index_t length)32 void _sort<cpu>(float* key, int64_t* data, index_t length) {
33   std::sort(data, data + length,
34             [key](int64_t const& i, int64_t const& j) -> bool {
35               return key[i] > key[j];
36             });
37 }
38 
39 DMLC_REGISTER_PARAMETER(NumpyChoiceParam);
40 
41 NNVM_REGISTER_OP(_npi_choice)
42 .describe("random choice")
43 .set_num_inputs(
__anonb7f66b0c0202(const nnvm::NodeAttrs& attrs) 44   [](const nnvm::NodeAttrs& attrs) {
45     int num_input = 0;
46     const NumpyChoiceParam& param = nnvm::get<NumpyChoiceParam>(attrs.parsed);
47     if (param.weighted) num_input += 1;
48     if (!param.a.has_value()) num_input += 1;
49     return num_input;
50 })
51 .set_num_outputs(1)
52 .set_attr<nnvm::FListInputNames>(
53     "FListInputNames",
__anonb7f66b0c0302(const NodeAttrs& attrs) 54     [](const NodeAttrs& attrs) {
55       int num_input = 0;
56       const NumpyChoiceParam& param =
57           nnvm::get<NumpyChoiceParam>(attrs.parsed);
58       if (param.weighted) num_input += 1;
59       if (!param.a.has_value()) num_input += 1;
60       if (num_input == 0) return std::vector<std::string>();
61       if (num_input == 1) return std::vector<std::string>{"input1"};
62       return std::vector<std::string>{"input1", "input2"};
63 })
64 .set_attr_parser(ParamParser<NumpyChoiceParam>)
65 .set_attr<mxnet::FInferShape>("FInferShape", NumpyChoiceOpShape)
66 .set_attr<nnvm::FInferType>("FInferType", NumpyChoiceOpType)
67 .set_attr<FResourceRequest>("FResourceRequest",
__anonb7f66b0c0402(const nnvm::NodeAttrs& attrs) 68   [](const nnvm::NodeAttrs& attrs) {
69     return std::vector<ResourceRequest>{
70         ResourceRequest::kRandom,
71         ResourceRequest::kTempSpace};
72   })
73 .set_attr<FCompute>("FCompute<cpu>", NumpyChoiceForward<cpu>)
74 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
75 .add_argument("input1", "NDArray-or-Symbol", "Source input")
76 .add_argument("input2", "NDArray-or-Symbol", "Source input")
77 .add_arguments(NumpyChoiceParam::__FIELDS__());
78 
79 }  // namespace op
80 }  // namespace mxnet
81