1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file sequence_mask.cu
 * \brief GPU kernels and operator factory for the SequenceMask operator.
23  * \author Sebastian Bodenstein
24 */
25 
26 #include "./sequence_mask-inl.h"
27 
28 namespace mxnet {
29 namespace op {
30 
31 // (seqlen, batch, rest) case
32 template <int req>
33 struct SequenceMask0GPUKernel {
34   template <typename DType, typename IType>
Mapmxnet::op::SequenceMask0GPUKernel35   MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
36                                   index_t max_s_len, index_t batch_size,
37                                   index_t restsize, DType value) {
38     index_t batch = i / restsize % batch_size;
39     const index_t seqpos = static_cast<int>(idx[batch]);
40     index_t seq = i / restsize / batch_size;
41     if (seq >= seqpos) {
42       KERNEL_ASSIGN(in[i], req, value);
43     }
44   }
45 };
46 
47 // (batch, seqlen, rest) case
48 template <int req>
49 struct SequenceMask1GPUKernel {
50   template <typename DType, typename IType>
Mapmxnet::op::SequenceMask1GPUKernel51   MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
52                                   index_t max_s_len, index_t batch_size,
53                                   index_t restsize, DType value) {
54     index_t batch = i / restsize / max_s_len;
55     const index_t seqpos = static_cast<int>(idx[batch]);
56     index_t seq = i / restsize % max_s_len;
57     if (seq >= seqpos) {
58       KERNEL_ASSIGN(in[i], req, value);
59     }
60   }
61 };
62 
63 template<typename DType, typename IType>
SequenceMaskExec(const mshadow::Tensor<gpu,3,DType> & data,const mshadow::Tensor<gpu,1,IType> & indices,const OpReqType req,mshadow::Stream<gpu> * const s,int axis,DType val)64 void SequenceMaskExec(
65        const mshadow::Tensor<gpu, 3, DType> &data,
66        const mshadow::Tensor<gpu, 1, IType> &indices,
67        const OpReqType req, mshadow::Stream<gpu> *const s,
68        int axis, DType val) {
69   using namespace mshadow;
70   using namespace mshadow::expr;
71   using namespace mxnet_op;
72 
73   index_t batch = indices.size(0);
74   index_t max_seq_len = data.size(axis);
75   index_t restsize = data.size(2);
76 
77   MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
78     if (axis == 1) {
79       Kernel<SequenceMask1GPUKernel<req_type>, gpu>::Launch(
80         s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
81         val);
82     } else {
83       Kernel<SequenceMask0GPUKernel<req_type>, gpu>::Launch(
84         s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
85         val);
86     }
87   });
88 }
89 
CreateOp(SequenceMaskParam param,int dtype,int itype)90 template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
91   Operator *op = nullptr;
92   MSHADOW_TYPE_SWITCH(dtype, DType, {
93       MSHADOW_TYPE_SWITCH(itype, IType, {
94           op = new SequenceMaskOp<gpu, DType, IType>(param);
95         });
96     });
97   return op;
98 }
99 
100 }  // namespace op
101 }  // namespace mxnet
102