1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file np_choice_op.cu
22 * \brief Operator for random subset sampling
23 */
24
25 #include <thrust/execution_policy.h>
26 #include <thrust/sort.h>
27 #include <thrust/swap.h>
28 #include "./np_choice_op.h"
29
namespace mxnet {
namespace op {

/*!
 * \brief GPU specialization of _sort: reorders `key` so that the largest
 *        key comes first, carrying each element of `data` along with the
 *        key it is paired with (thrust::sort_by_key semantics).
 * \param key    Raw pointer to `length` float sort keys; wrapped as a Thrust
 *               device pointer below, so it is expected to point to device
 *               memory. Reordered in place.
 * \param data   Raw pointer to `length` int64 values permuted alongside the
 *               keys; also treated as device memory. Reordered in place.
 * \param length Number of elements in each of `key` and `data`.
 */
template <>
void _sort<gpu>(float* key, int64_t* data, index_t length) {
  // Tag the raw pointers as device memory so thrust dispatches to its device backend.
  const thrust::device_ptr<float> keys_begin = thrust::device_pointer_cast(key);
  const thrust::device_ptr<float> keys_end = keys_begin + length;
  const thrust::device_ptr<int64_t> values_begin = thrust::device_pointer_cast(data);

  // thrust::greater<float> as the comparator yields a descending order of keys.
  // NOTE(review): no execution policy / stream is passed, so thrust runs this on
  // its default stream rather than any stream the calling operator may use —
  // confirm the caller's synchronization accounts for that.
  thrust::sort_by_key(keys_begin, keys_end, values_begin, thrust::greater<float>());
}

// Hook the GPU instantiation of NumpyChoiceForward up as the forward compute
// function of the _npi_choice operator.
NNVM_REGISTER_OP(_npi_choice)
.set_attr<FCompute>("FCompute<gpu>", NumpyChoiceForward<gpu>);

}  // namespace op
}  // namespace mxnet
46