1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file np_choice_op.cu
22  * \brief Operator for random subset sampling
23  */
24 
25 #include <thrust/execution_policy.h>
26 #include <thrust/sort.h>
27 #include <thrust/swap.h>
28 #include "./np_choice_op.h"
29 
30 namespace mxnet {
31 namespace op {
32 
33 template <>
_sort(float * key,int64_t * data,index_t length)34 void _sort<gpu>(float* key, int64_t* data, index_t length) {
35   thrust::device_ptr<float> dev_key(key);
36   thrust::device_ptr<int64_t> dev_data(data);
37   thrust::sort_by_key(dev_key, dev_key + length, dev_data,
38                       thrust::greater<float>());
39 }
40 
41 NNVM_REGISTER_OP(_npi_choice)
42 .set_attr<FCompute>("FCompute<gpu>", NumpyChoiceForward<gpu>);
43 
44 }  // namespace op
45 }  // namespace mxnet
46