1 /*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20 /*!
21 * \file sequence_mask.cu
 * \brief GPU element-wise masking kernels and GPU operator factory for SequenceMask
23 * \author Sebastian Bodenstein
24 */
25
26 #include "./sequence_mask-inl.h"
27
28 namespace mxnet {
29 namespace op {
30
31 // (seqlen, batch, rest) case
32 template <int req>
33 struct SequenceMask0GPUKernel {
34 template <typename DType, typename IType>
Mapmxnet::op::SequenceMask0GPUKernel35 MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
36 index_t max_s_len, index_t batch_size,
37 index_t restsize, DType value) {
38 index_t batch = i / restsize % batch_size;
39 const index_t seqpos = static_cast<int>(idx[batch]);
40 index_t seq = i / restsize / batch_size;
41 if (seq >= seqpos) {
42 KERNEL_ASSIGN(in[i], req, value);
43 }
44 }
45 };
46
47 // (batch, seqlen, rest) case
48 template <int req>
49 struct SequenceMask1GPUKernel {
50 template <typename DType, typename IType>
Mapmxnet::op::SequenceMask1GPUKernel51 MSHADOW_XINLINE static void Map(int i, DType *in, const IType *idx,
52 index_t max_s_len, index_t batch_size,
53 index_t restsize, DType value) {
54 index_t batch = i / restsize / max_s_len;
55 const index_t seqpos = static_cast<int>(idx[batch]);
56 index_t seq = i / restsize % max_s_len;
57 if (seq >= seqpos) {
58 KERNEL_ASSIGN(in[i], req, value);
59 }
60 }
61 };
62
63 template<typename DType, typename IType>
SequenceMaskExec(const mshadow::Tensor<gpu,3,DType> & data,const mshadow::Tensor<gpu,1,IType> & indices,const OpReqType req,mshadow::Stream<gpu> * const s,int axis,DType val)64 void SequenceMaskExec(
65 const mshadow::Tensor<gpu, 3, DType> &data,
66 const mshadow::Tensor<gpu, 1, IType> &indices,
67 const OpReqType req, mshadow::Stream<gpu> *const s,
68 int axis, DType val) {
69 using namespace mshadow;
70 using namespace mshadow::expr;
71 using namespace mxnet_op;
72
73 index_t batch = indices.size(0);
74 index_t max_seq_len = data.size(axis);
75 index_t restsize = data.size(2);
76
77 MXNET_ASSIGN_REQ_SWITCH(req, req_type, {
78 if (axis == 1) {
79 Kernel<SequenceMask1GPUKernel<req_type>, gpu>::Launch(
80 s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
81 val);
82 } else {
83 Kernel<SequenceMask0GPUKernel<req_type>, gpu>::Launch(
84 s, data.shape_.Size(), data.dptr_, indices.dptr_, max_seq_len, batch, restsize,
85 val);
86 }
87 });
88 }
89
CreateOp(SequenceMaskParam param,int dtype,int itype)90 template <> Operator *CreateOp<gpu>(SequenceMaskParam param, int dtype, int itype) {
91 Operator *op = nullptr;
92 MSHADOW_TYPE_SWITCH(dtype, DType, {
93 MSHADOW_TYPE_SWITCH(itype, IType, {
94 op = new SequenceMaskOp<gpu, DType, IType>(param);
95 });
96 });
97 return op;
98 }
99
100 } // namespace op
101 } // namespace mxnet
102