1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file allclose_op.cu
22 * \brief GPU Implementation of allclose op
23 * \author Andrei Ivanov
24 */
25 #include "./allclose_op-inl.h"
26 #include <cub/cub.cuh>
27
28 namespace mxnet {
29 namespace op {
30
/*!
 * \brief Ask CUB how many bytes of temporary storage cub::DeviceReduce::Min
 *        wants for a reduction over num_items elements of type T.
 * \tparam T element type used for the (null) input/output pointers of the query
 * \param s mshadow stream whose underlying CUDA stream is forwarded to CUB
 * \param num_items number of elements the reduction would process
 * \return the byte count CUB wrote into the temp-storage size argument
 */
template<typename T>
size_t GetAdditionalMemory(mshadow::Stream<gpu> *s, const int num_items) {
  size_t required_bytes = 0;
  // The temp-storage pointer is null, so per the CUB API this call only fills in
  // required_bytes; the typed null pointers exist solely to select T.
  cub::DeviceReduce::Min(nullptr, required_bytes,
                         static_cast<T *>(nullptr), static_cast<T *>(nullptr),
                         num_items, mshadow::Stream<gpu>::GetStream(s));
  return required_bytes;
}
40
/*!
 * \brief GPU specialization: size of the extra scratch space needed for the
 *        Min reduction over num_items values of INTERM_DATA_TYPE.
 * \param s mshadow GPU stream forwarded to the size query
 * \param num_items number of elements that will be reduced
 * \return byte count returned by GetAdditionalMemory<INTERM_DATA_TYPE>
 */
template<>
size_t GetAdditionalMemoryLogical<gpu>(mshadow::Stream<gpu> *s, const int num_items) {
  const size_t scratch_bytes = GetAdditionalMemory<INTERM_DATA_TYPE>(s, num_items);
  return scratch_bytes;
}
45
/*!
 * \brief GPU specialization: reduce num_items values at workMem to their minimum
 *        with cub::DeviceReduce::Min and store the result at outPntr.
 * \param s mshadow GPU stream; its CUDA stream is the one the reduction is issued on
 * \param workMem start of the input values; the address workMem + num_items is
 *        handed to CUB as its temporary storage
 * \param extraStorageBytes size in bytes passed to CUB for that temporary storage
 * \param num_items number of input values to reduce
 * \param outPntr destination pointer given to CUB for the reduction result
 *
 * NOTE(review): presumably the caller sizes workMem to hold num_items elements
 * plus extraStorageBytes of scratch - confirm against the caller.
 * NOTE(review): the status returned by the CUB call is not inspected here.
 */
template<>
void GetResultLogical<gpu>(mshadow::Stream<gpu> *s, INTERM_DATA_TYPE *workMem,
                           size_t extraStorageBytes, int num_items,
                           INTERM_DATA_TYPE *outPntr) {
  // CUB scratch area starts right behind the num_items input elements.
  INTERM_DATA_TYPE *const cub_scratch = workMem + num_items;
  const cudaStream_t cuda_stream = mshadow::Stream<gpu>::GetStream(s);
  cub::DeviceReduce::Min(cub_scratch, extraStorageBytes,
                         workMem, outPntr, num_items, cuda_stream);
}
53
/*!
 * \brief Attach AllClose<gpu> to the _contrib_allclose operator under the
 *        "FCompute<gpu>" attribute.
 */
NNVM_REGISTER_OP(_contrib_allclose)
  .set_attr<FCompute>("FCompute<gpu>", AllClose<gpu>);
56
57 } // namespace op
58 } // namespace mxnet
59