# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

package AI::MXNet::CudaModule;
use strict;
use warnings;
use AI::MXNet::NS;
use AI::MXNet::Base;
use Mouse;
use AI::MXNet::Function::Parameters;

our %DTYPE_CPP_TO_STR = qw(
    float    float32
    double   float64
    __half   float16
    uint8_t  uint8
    int      int32
    int32_t  int32
    int8_t   int8
    char     int8
    int64_t  int64
);

=head1 NAME

    AI::MXNet::CudaModule - Interface to the runtime CUDA kernel compilation module.
=cut

=head1 DESCRIPTION

    Interface to the runtime CUDA kernel compilation module.
    Compile and run CUDA code from Perl.

    In CUDA 7.5, you need to prepend your kernel definitions
    with 'extern "C"' to avoid name mangling::

        $source = '
        extern "C" __global__ void axpy(const float *x, float *y, float alpha) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            y[i] += alpha * x[i];
        }
        ';
        $module = mx->rtc->CudaModule($source);
        $func = $module->get_kernel("axpy", "const float *x, float *y, float alpha");
        $x = mx->nd->ones([10], ctx=>mx->gpu(0));
        $y = mx->nd->zeros([10], ctx=>mx->gpu(0));
        $func->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
        print $y->aspdl;

    Starting from CUDA 8.0, you can instead export functions by name.
    This also allows you to use templates::

        my $source = '
        template<typename DType>
        __global__ void axpy(const DType *x, DType *y, DType alpha) {
            int i = threadIdx.x + blockIdx.x * blockDim.x;
            y[i] += alpha * x[i];
        }
        ';
        $module = mx->rtc->CudaModule($source, exports=>['axpy<float>', 'axpy<double>']);
        $func32 = $module->get_kernel("axpy<float>", "const float *x, float *y, float alpha");
        $x = mx->nd->ones([10], dtype=>'float32', ctx=>mx->gpu(0));
        $y = mx->nd->zeros([10], dtype=>'float32', ctx=>mx->gpu(0));
        $func32->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
        print $y->aspdl;

        $func64 = $module->get_kernel("axpy<double>", "const double *x, double *y, double alpha");
        $x = mx->nd->ones([10], dtype=>'float64', ctx=>mx->gpu(0));
        $y = mx->nd->zeros([10], dtype=>'float64', ctx=>mx->gpu(0));
        $func64->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1]);
        print $y->aspdl;

    Parameters
    ----------
    source : Str
        Complete source code.
    options : Str|ArrayRef[Str]
        Compiler flags. For example, use "-I/usr/local/cuda/include" to
        add CUDA headers to the include path.
    exports : Str|ArrayRef[Str]
        Names of the kernels to export.
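
    For example, a module can be compiled with extra compiler flags and
    explicit exports (a minimal sketch; the include path and export names
    are illustrative placeholders)::

        my $module = mx->rtc->CudaModule(
            $source,
            options => ['-I/usr/local/cuda/include'],
            exports => ['axpy<float>', 'axpy<double>']
        );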
=cut

has 'source' => (is => 'rw', isa => 'Str', required => 1);
has [qw/options exports/] => (is => 'rw', isa => 'Str|ArrayRef[Str]', default => sub { [] });
has 'handle' => (is => 'rw', isa => 'CudaModuleHandle');
around BUILDARGS => \&AI::MXNet::Base::process_arguments;
method python_constructor_arguments() { ['source', 'options', 'exports'] }

sub BUILD
{
    my $self = shift;
    # accept a plain scalar or an array ref for both options and exports
    $self->options([$self->options]) unless ref $self->options;
    $self->exports([$self->exports]) unless ref $self->exports;
    my $handle = check_call(
                    AI::MXNetCAPI::RtcCudaModuleCreate(
                        $self->source,
                        scalar(@{ $self->options }),
                        $self->options,
                        scalar(@{ $self->exports }),
                        $self->exports
                    )
    );
    $self->handle($handle);
}

sub DEMOLISH
{
    check_call(AI::MXNetCAPI::RtcCudaModuleFree(shift->handle));
}

=head2 get_kernel

        Get a CUDA kernel from the compiled module.

        Parameters
        ----------
        $name : Str
            String name of the kernel.
        $signature : Str
            Function signature for the kernel. For example, if a kernel is
            declared as::

                extern "C" __global__ void axpy(const float *x, double *y, int alpha)

            then its signature should be::

                const float *x, double *y, int alpha

            or::

                const float *, double *, int

            Note that `*` in the signature marks an argument as an array and
            `const` marks it as a constant (input) array.

        Returns
        -------
        AI::MXNet::CudaKernel
            The CUDA kernel that can be launched on GPUs.
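
        For example (a minimal sketch, reusing the $module compiled from
        the first source shown in the DESCRIPTION above)::

            my $func = $module->get_kernel(
                "axpy", "const float *x, float *y, float alpha"
            );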
=cut

method get_kernel(Str $name, Str $signature)
{
    my @is_ndarray;
    my @is_const;
    my @dtypes;
    # matches "(const) type (*) (name)", e.g. "const float *x"
    my $pattern = qr/^\s*(const)?\s*([\w_]+)\s*(\*)?\s*([\w_]+)?\s*$/;
    $signature =~ s/\s+/ /g;
    my @args = split(/,/, $signature);
    for my $arg (@args)
    {
        if(not $arg =~ $pattern or $2 eq 'const')
        {
            confess(
                "Invalid function prototype \"$arg\". Must be in the ".
                'form of "(const) type (*) (name)"'
            );
        }
        push @is_const, $1 ? 1 : 0;
        my $dtype = $2;
        push @is_ndarray, $3 ? 1 : 0;
        if(not exists $DTYPE_CPP_TO_STR{$dtype})
        {
            my $types = join(',', sort keys %DTYPE_CPP_TO_STR);
            confess("Unsupported kernel argument type $arg. Supported types are: $types.");
        }
        push @dtypes, DTYPE_STR_TO_MX->{$DTYPE_CPP_TO_STR{$dtype}};
    }

    my $handle = check_call(
        AI::MXNetCAPI::RtcCudaKernelCreate(
            $self->handle,
            $name,
            scalar(@dtypes),
            \@is_ndarray,
            \@is_const,
            \@dtypes
        )
    );
    return AI::MXNet::CudaKernel->new($handle, $name, \@is_ndarray, \@dtypes);
}

__PACKAGE__->AI::MXNet::NS::register('AI::MXNet');

package AI::MXNet::CudaKernel;
use Mouse;
use AI::MXNet::Base;

=head1 NAME

    AI::MXNet::CudaKernel - Constructs CUDA kernel.
=cut

=head1 DESCRIPTION

    Constructs a CUDA kernel.
    Intended to be created only by calling AI::MXNet::CudaModule->get_kernel.
=cut

has [qw/handle name is_ndarray dtypes/] => (is => 'rw');
around BUILDARGS => sub {
    my ($orig, $class, $handle, $name, $is_ndarray, $dtypes) = @_;
    return $class->$orig(handle => $handle, name => $name, is_ndarray => $is_ndarray, dtypes => $dtypes);
};

sub BUILD
{
    my $self = shift;
    # store dtypes as strings, e.g. 'float32', rather than MX enum values
    $self->dtypes([map { DTYPE_MX_TO_STR->{$_} } @{ $self->dtypes }]);
}

sub DEMOLISH
{
    check_call(AI::MXNetCAPI::RtcCudaKernelFree(shift->handle));
}

=head2 launch

        Launch the CUDA kernel.

        Parameters
        ----------
        $args : ArrayRef[AI::MXNet::NDArray|Num]
            List of arguments for the kernel. NDArrays are expected for pointer
            types (e.g. `float*`, `double*`), while numbers are expected for
            non-pointer types (e.g. `int`, `float`).
        $ctx : AI::MXNet::Context
            The context to launch the kernel on. Must be a GPU context.
        $grid_dims : array ref of 3 integers (CudaKernelShape)
            Grid dimensions for the CUDA kernel.
        $block_dims : array ref of 3 integers (CudaKernelShape)
            Block dimensions for the CUDA kernel.
        $shared_mem=0 : Int, optional
            Size in bytes of dynamically allocated shared memory. Defaults to 0.
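
        For example (a minimal sketch, reusing $func, $x, and $y from the
        DESCRIPTION above): launch one block of 10 threads with no dynamic
        shared memory::

            $func->launch([$x, $y, 3.0], mx->gpu(0), [1, 1, 1], [10, 1, 1], 0);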
=cut

method launch(
    ArrayRef[AI::MXNet::NDArray|Num] $args,
    AI::MXNet::Context $ctx,
    CudaKernelShape $grid_dims,
    CudaKernelShape $block_dims,
    Int $shared_mem=0
)
{
    assert(($ctx->device_type eq 'gpu'), "Cuda kernel can only be launched on GPU");
    confess("CudaKernel(${\ $self->name }) expects ".scalar(@{$self->dtypes})." arguments but got ".scalar(@$args).".")
        unless (@{ $args } == @{ $self->dtypes });
    my @void_args;
    enumerate(sub {
        my ($i, $arg, $is_nd, $dtype) = @_;
        if($is_nd)
        {
            confess("The $i-th argument is expected to be an NDArray but got [$arg]")
                unless blessed $arg;
            push @void_args, $arg->handle;
        }
        else
        {
            my $perl_pack_type = DTYPE_MX_TO_PERL->{$dtype};
            my $packed_arg;
            # special handling for float16: pack the value as a uint16 ('S')
            # holding the IEEE 754 half-precision bit pattern
            if($perl_pack_type eq 'S')
            {
                $packed_arg = pack("S", AI::MXNetCAPI::_float_to_half($arg));
            }
            else
            {
                $packed_arg = pack($perl_pack_type, $arg);
            }
            push @void_args, $packed_arg;
        }
    }, $args, $self->is_ndarray, $self->dtypes);
    check_call(
        AI::MXNetCAPI::RtcCudaKernelCall(
            $self->handle,
            $ctx->device_id,
            \@void_args,
            @{ $grid_dims },
            @{ $block_dims },
            $shared_mem
        )
    );
}

1;