1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file np_choice_op.cc
22 * \brief Operator for random subset sampling
23 */
24
25 #include "./np_choice_op.h"
26 #include <algorithm>
27
28 namespace mxnet {
29 namespace op {
30
31 template <>
_sort(float * key,int64_t * data,index_t length)32 void _sort<cpu>(float* key, int64_t* data, index_t length) {
33 std::sort(data, data + length,
34 [key](int64_t const& i, int64_t const& j) -> bool {
35 return key[i] > key[j];
36 });
37 }
38
39 DMLC_REGISTER_PARAMETER(NumpyChoiceParam);
40
41 NNVM_REGISTER_OP(_npi_choice)
42 .describe("random choice")
43 .set_num_inputs(
__anonb7f66b0c0202(const nnvm::NodeAttrs& attrs) 44 [](const nnvm::NodeAttrs& attrs) {
45 int num_input = 0;
46 const NumpyChoiceParam& param = nnvm::get<NumpyChoiceParam>(attrs.parsed);
47 if (param.weighted) num_input += 1;
48 if (!param.a.has_value()) num_input += 1;
49 return num_input;
50 })
51 .set_num_outputs(1)
52 .set_attr<nnvm::FListInputNames>(
53 "FListInputNames",
__anonb7f66b0c0302(const NodeAttrs& attrs) 54 [](const NodeAttrs& attrs) {
55 int num_input = 0;
56 const NumpyChoiceParam& param =
57 nnvm::get<NumpyChoiceParam>(attrs.parsed);
58 if (param.weighted) num_input += 1;
59 if (!param.a.has_value()) num_input += 1;
60 if (num_input == 0) return std::vector<std::string>();
61 if (num_input == 1) return std::vector<std::string>{"input1"};
62 return std::vector<std::string>{"input1", "input2"};
63 })
64 .set_attr_parser(ParamParser<NumpyChoiceParam>)
65 .set_attr<mxnet::FInferShape>("FInferShape", NumpyChoiceOpShape)
66 .set_attr<nnvm::FInferType>("FInferType", NumpyChoiceOpType)
67 .set_attr<FResourceRequest>("FResourceRequest",
__anonb7f66b0c0402(const nnvm::NodeAttrs& attrs) 68 [](const nnvm::NodeAttrs& attrs) {
69 return std::vector<ResourceRequest>{
70 ResourceRequest::kRandom,
71 ResourceRequest::kTempSpace};
72 })
73 .set_attr<FCompute>("FCompute<cpu>", NumpyChoiceForward<cpu>)
74 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
75 .add_argument("input1", "NDArray-or-Symbol", "Source input")
76 .add_argument("input2", "NDArray-or-Symbol", "Source input")
77 .add_arguments(NumpyChoiceParam::__FIELDS__());
78
79 } // namespace op
80 } // namespace mxnet
81