1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file allclose_op.cu
22  * \brief GPU Implementation of allclose op
23  * \author Andrei Ivanov
24  */
25 #include "./allclose_op-inl.h"
26 #include <cub/cub.cuh>
27 
28 namespace mxnet {
29 namespace op {
30 
/*!
 * \brief Ask cub::DeviceReduce::Min how much temporary storage it needs
 *        for a reduction over num_items elements of type T.
 * \tparam T element type of the (hypothetical) input and output buffers
 * \param s mshadow stream whose CUDA stream is forwarded to cub
 * \param num_items number of elements the reduction would process
 * \return number of bytes reported by cub in its size-query call
 */
template<typename T>
size_t GetAdditionalMemory(mshadow::Stream<gpu> *s, const int num_items) {
  cudaStream_t cuda_stream = mshadow::Stream<gpu>::GetStream(s);
  size_t required_bytes = 0;
  // Size query only: the temp-storage pointer is null, so (per CUB's documented
  // convention) cub just fills in required_bytes. Typed null pointers are passed
  // for input/output solely to select the T* instantiation.
  // NOTE(review): the status returned by cub is not inspected here.
  cub::DeviceReduce::Min(nullptr, required_bytes,
                         static_cast<T *>(nullptr), static_cast<T *>(nullptr),
                         num_items, cuda_stream);
  return required_bytes;
}
40 
/*!
 * \brief GPU specialization: size of the extra workspace needed to reduce
 *        num_items values of INTERM_DATA_TYPE.
 * \param s mshadow stream forwarded to GetAdditionalMemory
 * \param num_items number of elements to be reduced
 * \return byte count returned by GetAdditionalMemory<INTERM_DATA_TYPE>
 */
template<>
size_t GetAdditionalMemoryLogical<gpu>(mshadow::Stream<gpu> *s, const int num_items) {
  const size_t extra_bytes = GetAdditionalMemory<INTERM_DATA_TYPE>(s, num_items);
  return extra_bytes;
}
45 
/*!
 * \brief GPU specialization: run cub::DeviceReduce::Min over the first
 *        num_items entries of workMem and store the minimum in outPntr.
 * \param s mshadow stream whose CUDA stream is handed to cub
 * \param workMem buffer holding the num_items input values; the region
 *        starting at workMem + num_items is given to cub as temp storage
 * \param extraStorageBytes size in bytes passed to cub for that temp storage
 * \param num_items number of input values to reduce
 * \param outPntr destination for the reduction result
 */
template<>
void GetResultLogical<gpu>(mshadow::Stream<gpu> *s, INTERM_DATA_TYPE *workMem,
                           size_t extraStorageBytes, int num_items, INTERM_DATA_TYPE *outPntr) {
  cudaStream_t cuda_stream = mshadow::Stream<gpu>::GetStream(s);
  // cub's scratch area sits directly behind the input values in the same buffer.
  INTERM_DATA_TYPE *scratch = workMem + num_items;
  // NOTE(review): the status returned by cub is not inspected here.
  cub::DeviceReduce::Min(scratch, extraStorageBytes, workMem, outPntr,
                         num_items, cuda_stream);
}
53 
// Attach the GPU compute function (AllClose<gpu>) to the _contrib_allclose
// operator under the "FCompute<gpu>" attribute.
NNVM_REGISTER_OP(_contrib_allclose)
  .set_attr<FCompute>("FCompute<gpu>", AllClose<gpu>);
56 
57 }  // namespace op
58 }  // namespace mxnet
59