1# Licensed to the Apache Software Foundation (ASF) under one
2# or more contributor license agreements.  See the NOTICE file
3# distributed with this work for additional information
4# regarding copyright ownership.  The ASF licenses this file
5# to you under the Apache License, Version 2.0 (the
6# "License"); you may not use this file except in compliance
7# with the License.  You may obtain a copy of the License at
8#
9#   http://www.apache.org/licenses/LICENSE-2.0
10#
11# Unless required by applicable law or agreed to in writing,
12# software distributed under the License is distributed on an
13# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14# KIND, either express or implied.  See the License for the
15# specific language governing permissions and limitations
16# under the License.
17
18use strict;
19use warnings;
20use Scalar::Util qw(blessed);
21use Test::More 'no_plan';
22use AI::MXNet qw(mx);
23use AI::MXNet::TestUtils qw(zip assert enumerate same rand_shape_2d rand_shape_3d
24    rand_sparse_ndarray random_arrays almost_equal rand_ndarray randint allclose dies_ok);
25use AI::MXNet::Base qw(pones pzeros pdl product rand_sparse);
# Set both MXNet verbosity switches to 0 for the duration of this script
# (presumably to keep fallback / subgraph logging out of the test output).
@ENV{qw(MXNET_STORAGE_FALLBACK_LOG_VERBOSE MXNET_SUBGRAPH_VERBOSE)} = (0, 0);
28
29
# Returns an NDArray of ones with the given shape, converted to the
# requested storage type via tostype().
sub sparse_nd_ones
{
    my ($shape, $stype) = @_;
    my $dense_ones = mx->nd->ones($shape);
    return $dense_ones->tostype($stype);
}
35
# Checks that mx->nd->elemwise_add agrees with plain PDL addition for a
# pair of dense ('default') operands and a pair of 'row_sparse' operands
# sharing the same random 2-d shape.
sub test_sparse_nd_elemwise_add
{
    # $shapes : array ref of shapes, one per operand
    # $stypes : array ref of storage types ('row_sparse' or 'default'), one per operand
    # $f      : operator applied to the two NDArrays
    # $g      : reference implementation applied to the two PDL copies
    my $check_sparse_nd_elemwise_binary = sub {
        my ($shapes, $stypes, $f, $g) = @_;
        # generate inputs
        my @nds;
        for my $i (0 .. $#{ $stypes })
        {
            my $stype = $stypes->[$i];
            my $nd;
            if($stype eq 'row_sparse')
            {
                # rand_sparse_ndarray returns a list; only the first element is used
                ($nd) = rand_sparse_ndarray($shapes->[$i], $stype);
            }
            elsif($stype eq 'default')
            {
                $nd = mx->nd->array(random_arrays($shapes->[$i]), dtype => 'float32');
            }
            else
            {
                # fail with a diagnostic instead of a bare "Died"
                die "unsupported storage type '$stype' in elemwise binary check";
            }
            push @nds, $nd;
        }
        # check result
        my $test = $f->($nds[0], $nds[1]);
        ok(almost_equal($test->aspdl, $g->($nds[0]->aspdl, $nds[1]->aspdl)));
    };
    # The loop below runs exactly $num_repeats times (three rounds, the same
    # number the previous 0..2 range produced).
    my $num_repeats = 3;
    my $g = sub { $_[0] + $_[1] };
    my $op = sub { mx->nd->elemwise_add(@_) };
    for (1 .. $num_repeats)
    {
        my $shape = rand_shape_2d();
        $shape = [$shape, $shape];
        $check_sparse_nd_elemwise_binary->($shape, ['default', 'default'], $op, $g);
        $check_sparse_nd_elemwise_binary->($shape, ['row_sparse', 'row_sparse'], $op, $g);
    }
}

test_sparse_nd_elemwise_add();
76
# Copies random NDArrays between storage types (and to the current context)
# and checks that no element differs after the copy.
sub test_sparse_nd_copy
{
    my $check_sparse_nd_copy = sub {
        my ($from_stype, $to_stype, $shape) = @_;
        my $source = rand_ndarray($shape, $from_stype);
        # copy to ctx
        my $ctx_copy = $source->copyto(AI::MXNet::Context->current_ctx);
        # copy to stype
        my $target = rand_ndarray($shape, $to_stype);
        $source->copyto($target);
        for my $copy ($ctx_copy, $target)
        {
            # count of mismatching elements must be zero
            ok(($source->aspdl != $copy->aspdl)->abs->sum == 0);
        }
    };
    my $shape = rand_shape_2d();
    my $shape_3d = rand_shape_3d();
    # [from storage type, to storage type, shape], executed in this order
    my @cases = (
        (map { ([$_, 'default', $shape], ['default', $_, $shape]) } 'row_sparse', 'csr'),
        ['row_sparse', 'row_sparse', $shape_3d],
        ['row_sparse', 'default',    $shape_3d],
        ['default',    'row_sparse', $shape_3d],
    );
    for my $case (@cases)
    {
        $check_sparse_nd_copy->(@{ $case });
    }
}

test_sparse_nd_copy();
103
# Basic sanity checks on a random row_sparse NDArray: one auxiliary array,
# int64 indices and the expected storage type.
sub test_sparse_nd_basic
{
    my $dims = rand_shape_2d();
    my ($rsp) = rand_sparse_ndarray($dims, 'row_sparse');
    ok($rsp->_num_aux == 1);
    ok($rsp->indices->dtype eq 'int64');
    ok($rsp->stype eq 'row_sparse');
}

test_sparse_nd_basic();
118
# Assigns NDArrays, PDL objects and a scalar into zero-initialised sparse
# NDArrays with `.=` and checks the contents afterwards.
sub test_sparse_nd_setitem
{
    my $check_sparse_nd_setitem = sub {
        my ($stype, $shape, $dst) = @_;
        my $target = mx->nd->zeros($shape, stype => $stype);
        $target .= $dst;
        # Normalise the assigned value: PDL -> NDArray, then any reference -> PDL;
        # a plain scalar is compared as is.
        my $expected = $dst;
        if(blessed($dst) && $dst->isa('PDL'))
        {
            $expected = mx->nd->array($dst);
        }
        if(ref $expected)
        {
            $expected = $expected->aspdl;
        }
        ok(($target->aspdl == $expected)->all);
    };

    my $shape = rand_shape_2d();
    foreach my $stype (qw(row_sparse csr))
    {
        # ndarray assignment
        foreach my $src_stype ('default', $stype)
        {
            $check_sparse_nd_setitem->($stype, $shape, rand_ndarray($shape, $src_stype));
        }
        # PDL assignment (PDL dimension order is the reverse of the NDArray shape)
        $check_sparse_nd_setitem->($stype, $shape, pones(reverse @{ $shape }));
    }
    # scalar assigned to row_sparse NDArray
    $check_sparse_nd_setitem->('row_sparse', $shape, 2);
}

test_sparse_nd_setitem();
142
# Compares slices of csr NDArrays against the equivalent slices of their
# dense (PDL / default NDArray) counterparts.
sub test_sparse_nd_slice
{
    my $dims = [randint(2, 10), randint(2, 10)];
    my ($csr) = rand_sparse_ndarray($dims, 'csr');
    my $dense = $csr->aspdl;
    my $rows = $dims->[0];
    my $row_from = randint(0, $rows - 1);
    my $row_to = randint($row_from + 1, $rows);

    # each entry: [row range given to the NDArray slice, row range given to the PDL slice]
    my @row_cases = (
        [[$row_from, $row_to],         [$row_from, $row_to]],
        [[$row_from - $rows, $row_to], [$row_from, $row_to]],
        [[$row_from, $rows - 1],       [$row_from, $rows - 1]],
        [[0, $row_to],                 [0, $row_to]],
    );
    for my $case (@row_cases)
    {
        my ($nd_range, $pdl_range) = @{ $case };
        ok(same($csr->slice($nd_range)->aspdl, $dense->slice('X', $pdl_range)));
    }

    my $col_from = randint(0, $dims->[1] - 1);
    my $col_to = randint($col_from + 1, $dims->[1]);
    # slices the same row/column window out of a sparse NDArray and out of a
    # dense NDArray built from its PDL copy, then compares the two
    my $compare_window = sub {
        my ($sparse_nd, $dense_pdl) = @_;
        my $got = $sparse_nd->slice(begin=>[$row_from, $col_from], end=>[$row_to, $col_to]);
        my $want = mx->nd->array($dense_pdl)->slice(begin=>[$row_from, $col_from], end=>[$row_to, $col_to]);
        ok(same($want->aspdl, $got->aspdl));
    };
    $compare_window->($csr, $dense);

    # repeat the row and window checks on an all-zero csr array
    my $empty_csr = mx->nd->sparse->zeros('csr', $dims);
    my $empty_dense = $empty_csr->aspdl;
    ok(same($empty_csr->slice([$row_from, $row_to])->aspdl, $empty_dense->slice('X', [$row_from, $row_to])));
    $compare_window->($empty_csr, $empty_dense);

    my $check_slice_nd_csr_fallback = sub {
        my ($fb_dims) = @_;
        my ($fb_csr) = rand_sparse_ndarray($fb_dims, 'csr');
        my $fb_dense = $fb_csr->aspdl;
        my $from = randint(0, $fb_dims->[0] - 1);
        my $to = randint($from + 1, $fb_dims->[0]);

        # non-trivial step should fallback to dense slice op
        my $strided = sub { (begin=>[$from], end=>[$to + 1], step=>[2]) };
        my $got = $fb_csr->slice($strided->());
        my $want = mx->nd->array($fb_dense)->slice($strided->());
        ok(same($want->aspdl, $got->aspdl));
    };
    $check_slice_nd_csr_fallback->([randint(2, 10), randint(1, 10)]);
}

test_sparse_nd_slice();
186
# `==` on sparse NDArrays: zeros vs ones is false everywhere, 0 vs zeros is
# true everywhere.
sub test_sparse_nd_equal
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $zeros_nd = mx->nd->zeros($dims, stype => $storage);
        my $ones_nd = sparse_nd_ones($dims, $storage);

        my $cmp = ($zeros_nd == $ones_nd);
        ok(($cmp->aspdl == pzeros(@pdl_dims))->all);

        $cmp = (0 == $zeros_nd);
        ok(($cmp->aspdl == pones(@pdl_dims))->all);
    }
}

test_sparse_nd_equal();
202
# `!=` on sparse NDArrays: zeros vs ones is true everywhere, 0 vs zeros is
# false everywhere.
sub test_sparse_nd_not_equal
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $zeros_nd = mx->nd->zeros($dims, stype => $storage);
        my $ones_nd = sparse_nd_ones($dims, $storage);

        my $cmp = ($zeros_nd != $ones_nd);
        ok(($cmp->aspdl == pones(@pdl_dims))->all);

        $cmp = (0 != $zeros_nd);
        ok(($cmp->aspdl == pzeros(@pdl_dims))->all);
    }
}

test_sparse_nd_not_equal();
218
# `>` between sparse NDArrays and between a sparse NDArray and a scalar.
sub test_sparse_nd_greater
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $x = mx->nd->zeros($dims, stype => $storage);
        my $y = sparse_nd_ones($dims, $storage);

        # [comparison, 1 if it must hold everywhere / 0 if nowhere]; the
        # comparisons are evaluated lazily, in this order
        my @cases = (
            [sub { $x > $y }, 0],
            [sub { $y > 0 },  1],
            [sub { 0 > $y },  0],
        );
        for my $case (@cases)
        {
            my ($compare, $holds) = @{ $case };
            my $z = $compare->();
            my $expected = $holds ? pones(@pdl_dims) : pzeros(@pdl_dims);
            ok(($z->aspdl == $expected)->all);
        }
    }
}

test_sparse_nd_greater();
236
# `>=` between sparse NDArrays and between a sparse NDArray and a scalar.
sub test_sparse_nd_greater_equal
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $x = mx->nd->zeros($dims, stype => $storage);
        my $y = sparse_nd_ones($dims, $storage);

        # [comparison, 1 if it must hold everywhere / 0 if nowhere]; the
        # comparisons are evaluated lazily, in this order
        my @cases = (
            [sub { $x >= $y }, 0],
            [sub { $y >= 0 },  1],
            [sub { 0 >= $y },  0],
            [sub { $y >= 1 },  1],
        );
        for my $case (@cases)
        {
            my ($compare, $holds) = @{ $case };
            my $z = $compare->();
            my $expected = $holds ? pones(@pdl_dims) : pzeros(@pdl_dims);
            ok(($z->aspdl == $expected)->all);
        }
    }
}

test_sparse_nd_greater_equal();
256
# `<` between sparse NDArrays and between a sparse NDArray and a scalar.
sub test_sparse_nd_lesser
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $x = mx->nd->zeros($dims, stype => $storage);
        my $y = sparse_nd_ones($dims, $storage);

        # [comparison, 1 if it must hold everywhere / 0 if nowhere]; the
        # comparisons are evaluated lazily, in this order
        my @cases = (
            [sub { $y < $x }, 0],
            [sub { 0 < $y },  1],
            [sub { $y < 0 },  0],
        );
        for my $case (@cases)
        {
            my ($compare, $holds) = @{ $case };
            my $z = $compare->();
            my $expected = $holds ? pones(@pdl_dims) : pzeros(@pdl_dims);
            ok(($z->aspdl == $expected)->all);
        }
    }
}

test_sparse_nd_lesser();
274
# `<=` between sparse NDArrays and between a sparse NDArray and a scalar.
sub test_sparse_nd_lesser_equal
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $dims = rand_shape_2d();
        my @pdl_dims = reverse @{ $dims };
        my $x = mx->nd->zeros($dims, stype => $storage);
        my $y = sparse_nd_ones($dims, $storage);

        # [comparison, 1 if it must hold everywhere / 0 if nowhere]; the
        # comparisons are evaluated lazily, in this order
        my @cases = (
            [sub { $y <= $x }, 0],
            [sub { 0 <= $y },  1],
            [sub { $y <= 0 },  0],
            [sub { 1 <= $y },  1],
        );
        for my $case (@cases)
        {
            my ($compare, $holds) = @{ $case };
            my $z = $compare->();
            my $expected = $holds ? pones(@pdl_dims) : pzeros(@pdl_dims);
            ok(($z->aspdl == $expected)->all);
        }
    }
}

test_sparse_nd_lesser_equal();
294
# Applies arithmetic and comparison operators to pairs of sparse NDArrays
# whose shapes may differ by size-1 dimensions, and compares each result
# with the same operator applied to the underlying PDL data.
sub test_sparse_nd_binary
{
    # Number of random shape pairs tried per operator / storage type. The
    # loop below is driven by this value, matching the sibling scalar-op and
    # in-place tests.
    my $N = 3;
    # $fn    : binary operator, called once with PDL operands and once with NDArrays
    # $stype : storage type both NDArray operands are converted to
    my $check_binary = sub { my ($fn, $stype) = @_;
        for (1 .. $N)
        {
            my $ndim = 2;
            my $oshape = [map { randint(1, 6) } 1..$ndim];
            my $bdim = 2;
            my @lshape = @$oshape;
            # one for broadcast op, another for elemwise op
            my @rshape = @lshape[($ndim-$bdim)..$#lshape];
            for my $i (0..$bdim-1)
            {
                # a uniform draw picks, per dimension, whether the left operand,
                # the right operand, or neither is collapsed to size 1
                my $sep = mx->nd->random->uniform(0, 1)->asscalar;
                if($sep < 0.33)
                {
                    $lshape[$ndim-$i-1] = 1;
                }
                elsif($sep < 0.66)
                {
                    $rshape[$bdim-$i-1] = 1;
                }
            }
            my $lhs = mx->nd->random->uniform(0, 1, shape=>\@lshape)->aspdl;
            my $rhs = mx->nd->random->uniform(0, 1, shape=>\@rshape)->aspdl;
            my $lhs_nd = mx->nd->array($lhs)->tostype($stype);
            my $rhs_nd = mx->nd->array($rhs)->tostype($stype);
            ok(allclose($fn->($lhs, $rhs), $fn->($lhs_nd, $rhs_nd)->aspdl, 1e-4));
        }
    };
    for my $stype ('row_sparse', 'csr')
    {
        $check_binary->(sub { $_[0] +  $_[1] }, $stype);
        $check_binary->(sub { $_[0] -  $_[1] }, $stype);
        $check_binary->(sub { $_[0] *  $_[1] }, $stype);
        $check_binary->(sub { $_[0] /  $_[1] }, $stype);
        $check_binary->(sub { $_[0] ** $_[1] }, $stype);
        $check_binary->(sub { $_[0] >  $_[1] }, $stype);
        $check_binary->(sub { $_[0] <  $_[1] }, $stype);
        $check_binary->(sub { $_[0] >= $_[1] }, $stype);
        $check_binary->(sub { $_[0] <= $_[1] }, $stype);
        $check_binary->(sub { $_[0] == $_[1] }, $stype);
    }
}

test_sparse_nd_binary();
342
# Applies scalar-vs-array operators to sparse NDArrays and compares each
# result with the same operator applied to the underlying PDL data.
sub test_sparse_nd_binary_scalar_op
{
    my $rounds = 3;
    my $verify = sub {
        my ($op, $storage) = @_;
        foreach my $round (1 .. $rounds)
        {
            my @dims = map { randint(1, 6) } 1 .. 2;
            my $dense = mx->nd->random->normal(0, 1, shape => \@dims)->aspdl;
            my $sparse = mx->nd->array($dense)->tostype($storage);
            ok(allclose($op->($dense), $op->($sparse)->aspdl, 1e-4));
        }
    };
    # operators under test, applied in this order for every storage type
    my @ops = (
        sub { 1 +    $_[0] },
        sub { 1 -    $_[0] },
        sub { 1 *    $_[0] },
        sub { 1 /    $_[0] },
        sub { 2 **   $_[0] },
        sub { 1 >    $_[0] },
        sub { 0.5 >  $_[0] },
        sub { 0.5 <  $_[0] },
        sub { 0.5 >= $_[0] },
        sub { 0.5 <= $_[0] },
        sub { 0.5 == $_[0] },
        sub { $_[0] / 2    },
    );
    foreach my $storage (qw(row_sparse csr))
    {
        foreach my $op (@ops)
        {
            $verify->($op, $storage);
        }
    }
}

test_sparse_nd_binary_scalar_op();
374
# Checks the compound assignment operators `+=` and `*=` on sparse NDArrays
# against the same operators applied to the underlying PDL data.
sub test_sparse_nd_binary_iop
{
    my $rounds = 3;
    my $verify = sub {
        my ($op, $storage) = @_;
        foreach my $round (1 .. $rounds)
        {
            my @dims = map { randint(1, 6) } 1 .. 2;
            my $lhs = mx->nd->random->uniform(0, 1, shape => \@dims)->aspdl;
            my $rhs = mx->nd->random->uniform(0, 1, shape => \@dims)->aspdl;
            my $lhs_nd = mx->nd->array($lhs)->tostype($storage);
            my $rhs_nd = mx->nd->array($rhs)->tostype($storage);
            my $expected = $op->($lhs, $rhs);
            my $actual = $op->($lhs_nd, $rhs_nd)->aspdl;
            ok(allclose($expected, $actual, 1e-4));
        }
    };

    # each operator works on a lexical copy of its first argument and returns it
    my @inplace_ops = (
        sub { my ($acc, $operand) = @_; $acc += $operand; return $acc },
        sub { my ($acc, $operand) = @_; $acc *= $operand; return $acc },
    );
    foreach my $storage (qw(csr row_sparse))
    {
        foreach my $op (@inplace_ops)
        {
            $verify->($op, $storage);
        }
    }
}

test_sparse_nd_binary_iop();
417
# Checks unary minus on sparse NDArrays: the negated array matches the
# negated PDL data and the original array is left untouched.
sub test_sparse_nd_negate
{
    # $shape : shape of the random test data
    # $stype : storage type the NDArray is converted to
    my $check_sparse_nd_negate = sub { my ($shape, $stype) = @_;
        # use the shape supplied by the caller rather than drawing a new one,
        # so every storage type is exercised on the same shape
        my $npy = mx->nd->random->uniform(-10, 10, shape => $shape)->aspdl;
        my $arr = mx->nd->array($npy)->tostype($stype);
        ok(almost_equal($npy, $arr->aspdl));
        ok(almost_equal(-$npy, (-$arr)->aspdl));

        # a final check to make sure the negation (-) is not implemented
        # as inplace operation, so the contents of arr does not change after
        # we compute (-arr)
        ok(almost_equal($npy, $arr->aspdl));
    };
    my $shape = rand_shape_2d();
    my @stypes = ('csr', 'row_sparse');
    for my $stype (@stypes)
    {
        $check_sparse_nd_negate->($shape, $stype);
    }
}

test_sparse_nd_negate();
440
# Checks broadcast_to() on sparse NDArrays: random dimensions of the source
# are collapsed to size 1, the array is broadcast to the full target shape,
# and the result is compared with the source data.
sub test_sparse_nd_broadcast
{
    my $sample_num = 10; # TODO 1000
    # $stype : storage type of the array being broadcast
    my $test_broadcast_to = sub { my ($stype) = @_;
        for (1..$sample_num)
        {
            my $ndim = 2;
            my $target_shape = [map { randint(1, 11) } 1..$ndim];
            # Take a real copy: \@{ $target_shape } would yield the very same
            # array reference, so collapsing an axis below would also shrink
            # the target and the test would never actually broadcast.
            my $shape = [@{ $target_shape }];
            my $axis_flags = [map { randint(0, 2) } 1..$ndim];
            for my $axis (0 .. $#{ $axis_flags })
            {
                # flagged axes become size 1 in the source shape only
                $shape->[$axis] = 1 if $axis_flags->[$axis];
            }
            my $dat = mx->nd->random->uniform(0, 1, shape => $shape)->aspdl - 0.5;
            # NOTE(review): the subtraction below relies on PDL expanding the
            # size-1 dimensions of $pdl_ret to the target shape -- confirm.
            my $pdl_ret = $dat;
            my $ndarray = mx->nd->array($dat)->tostype($stype);
            my $ndarray_ret = $ndarray->broadcast_to($target_shape);
            ok((pdl($ndarray_ret->shape) == pdl($target_shape))->all);
            my $err = (($ndarray_ret->aspdl - $pdl_ret)**2)->avg;
            ok($err < 1E-8);
        }
    };
    my @stypes = ('csr', 'row_sparse');
    for my $stype (@stypes)
    {
        $test_broadcast_to->($stype);
    }
}

test_sparse_nd_broadcast();
476
# Checks that ->T on a sparse NDArray matches the PDL transpose of the data
# it was built from.
sub test_sparse_nd_transpose
{
    my $dense = mx->nd->random->uniform(-10, 10, shape => rand_shape_2d())->aspdl;
    foreach my $storage (qw(csr row_sparse))
    {
        my $sparse = mx->nd->array($dense)->tostype($storage);
        ok(almost_equal($dense->transpose, ($sparse->T)->aspdl));
    }
}

test_sparse_nd_transpose();
489
# Runs broadcast_add / sum with sparse inputs or outputs and checks the
# numeric results.
sub test_sparse_nd_storage_fallback
{
    # executed in this order, each with the same random shape
    my @checks = (
        # output fallback: dense operands, csr output array
        sub {
            my ($dims) = @_;
            my $ones = mx->nd->ones($dims);
            my $out = mx->nd->zeros($dims, stype=>'csr');
            mx->nd->broadcast_add($ones, $ones * 2, out=>$out);
            ok(($out->aspdl - 3)->sum == 0);
        },
        # input fallback: csr + row_sparse operands
        sub {
            my ($dims) = @_;
            my $ones = mx->nd->ones($dims);
            my $out = mx->nd->broadcast_add($ones->tostype('csr'), $ones->tostype('row_sparse'));
            ok(($out->aspdl - 2)->sum == 0);
        },
        # sum over an all-ones array equals the number of elements
        sub {
            my ($dims) = @_;
            my $ones = mx->nd->ones($dims);
            my $out = mx->nd->sum($ones);
            ok($out->asscalar == product(@{ $dims }));
        },
    );

    my $shape = rand_shape_2d();
    foreach my $check (@checks)
    {
        $check->($shape);
    }
}

test_sparse_nd_storage_fallback();
518
# astype('int32') on a float32 sparse NDArray must report dtype int32.
sub test_sparse_nd_astype
{
    foreach my $storage (qw(row_sparse csr))
    {
        my $float_nd = mx->nd->zeros(rand_shape_2d(), stype => $storage, dtype => 'float32');
        my $int_nd = $float_nd->astype('int32');
        ok($int_nd->dtype eq 'int32');
    }
}

test_sparse_nd_astype();
531
# Round-trips sparse NDArrays through Storable::freeze / Storable::thaw and
# checks that class and contents survive.
sub test_sparse_nd_storable
{
    my $rounds = 1;
    my ($max_rows, $max_cols) = (40, 40);
    # expected Perl class for each storage type
    my %class_for = (
        row_sparse => 'AI::MXNet::NDArray::RowSparse',
        csr        => 'AI::MXNet::NDArray::CSR',
    );
    foreach my $round (1 .. $rounds)
    {
        my $dims = rand_shape_2d($max_rows, $max_cols);
        foreach my $storage (qw(row_sparse csr))
        {
            foreach my $density (0, 0.5)
            {
                my ($original) = rand_sparse_ndarray($dims, $storage, density => $density);
                ok($original->isa($class_for{$storage}));
                my $frozen = Storable::freeze($original);
                my $restored = Storable::thaw($frozen);
                ok($restored->isa($class_for{$storage}));
                ok(same($original->aspdl, $restored->aspdl));
            }
        }
    }
}

test_sparse_nd_storable();
559
# Saves a list and then a hash of random NDArrays (dense and sparse) with
# mx->nd->save, loads them back and compares contents.
sub test_sparse_nd_save_load
{
    my $rounds = 1;
    my @storages = ('default', 'row_sparse', 'csr');
    # expected Perl class for each storage type
    my %class_for = (
        default    => 'AI::MXNet::NDArray',
        row_sparse => 'AI::MXNet::NDArray::RowSparse',
        csr        => 'AI::MXNet::NDArray::CSR',
    );
    my $array_count = 20;
    my @densities = (0, 0.5);
    my $fname = 'tmp_list.bin';
    foreach my $round (1 .. $rounds)
    {
        my @saved;
        foreach my $n (1 .. $array_count)
        {
            my $storage = $storages[randint(0, scalar(@storages))];
            my $dims = rand_shape_2d(40, 40);
            my $density = $densities[randint(0, scalar(@densities))];
            push @saved, rand_ndarray($dims, $storage, $density);
            ok($saved[-1]->isa($class_for{$storage}));
        }

        # list round trip
        mx->nd->save($fname, \@saved);
        my @loaded = @{ mx->nd->load($fname) };
        ok(@saved == @loaded);
        foreach my $i (0 .. $#saved)
        {
            ok(same($saved[$i]->aspdl, $loaded[$i]->aspdl));
        }

        # hash round trip, keyed by position
        my %saved_by_name;
        foreach my $i (0 .. $#saved)
        {
            $saved_by_name{"ndarray xx $i"} = $saved[$i];
        }
        mx->nd->save($fname, \%saved_by_name);
        my %loaded_by_name = %{ mx->nd->load($fname) };
        ok(keys(%saved_by_name) == keys(%loaded_by_name));
        foreach my $name (keys %saved_by_name)
        {
            ok(same($saved_by_name{$name}->aspdl, $loaded_by_name{$name}->aspdl));
        }
    }
    unlink $fname;
}

test_sparse_nd_save_load();
606
# Creates csr NDArrays from (data, indices, indptr) NDArrays, from COO
# triples and from PDL sparse objects, and checks the resulting arrays.
sub test_create_csr
{
    # build a csr matrix from the three component NDArrays of a random csr array
    my $check_create_csr_from_nd = sub {
        my ($shape, $density, $dtype) = @_;
        my $source = rand_ndarray($shape, 'csr', $density);
        # create data array with provided dtype and ctx
        my $data = mx->nd->array($source->data->aspdl, dtype=>$dtype);
        my $indptr = $source->indptr;
        my $indices = $source->indices;
        my $created = mx->nd->sparse->csr_matrix([$data, $indices, $indptr], shape=>$shape);
        ok($created->stype eq 'csr');
        ok(same($created->data->aspdl, $data->aspdl));
        ok(same($created->indptr->aspdl, $indptr->aspdl));
        ok(same($created->indices->aspdl, $indices->aspdl));
        # verify csr matrix dtype and ctx is consistent from the ones provided
        ok($created->dtype eq $dtype);
        ok($created->data->dtype eq $dtype);
        ok($created->context eq AI::MXNet::Context->current_ctx);
        my $copy = mx->nd->array($created);
        ok(same($copy->aspdl, $created->aspdl));
    };

    # build a csr matrix from the COO form (data, [row, col]) of a random csr array
    my $check_create_csr_from_coo = sub {
        my ($shape, $density, $dtype) = @_;
        my $source = rand_ndarray($shape, 'csr', $density);
        my $ccs = $source->aspdlccs;
        my $coo = $ccs->tocoo();
        my $created = mx->nd->sparse->csr_matrix([$coo->data, [$coo->row, $coo->col]], shape=>$shape, dtype=>$dtype);
        ok($created->stype eq 'csr');
        ok(same($created->data->aspdl, $ccs->data));
        ok(same($created->indptr->aspdl, $ccs->indptr));
        ok(same($created->indices->aspdl, $ccs->indices));
        my $copy = mx->nd->array($created);
        ok(same($copy->aspdl, $created->aspdl));
        # verify csr matrix dtype and ctx is consistent
        ok($created->dtype eq $dtype);
        ok($created->data->dtype eq $dtype);
        ok($created->context eq AI::MXNet::Context->current_ctx);
    };

    # build NDArrays from sparse inputs through the constructor $f
    my $check_create_csr_from_pdlccs = sub {
        my ($shape, $density, $f) = @_;
        # NOTE(review): this helper is defined but never invoked in this
        # closure; the checks below compare dense copies instead.
        my $assert_csr_almost_equal = sub {
            my ($nd, $sp) = @_;
            ok(almost_equal($nd->data->aspdl, $sp->data));
            ok(almost_equal($nd->indptr->aspdl, $sp->indptr));
            ok(almost_equal($nd->indices->aspdl, $sp->indices));
            my $sp_csr = $nd->aspdlccs;
            ok(almost_equal($sp_csr->data, $sp->data));
            ok(almost_equal($sp_csr->indptr, $sp->indptr));
            ok(almost_equal($sp_csr->indices, $sp->indices));
            ok($sp->dtype eq $sp_csr->dtype);
        };

        my $random_sp = rand_sparse($shape->[0], $shape->[1], $density);
        my $random_nd = $f->($random_sp);
        ok(almost_equal($random_nd->aspdl, $random_sp->todense));
        # non-canonical csr which contains duplicates and unsorted indices
        my $indptr = pdl([0, 2, 3, 7]);
        my $indices = pdl([0, 2, 2, 0, 1, 2, 1]);
        my $data = pdl([1, 2, 3, 4, 5, 6, 1]);
        my $non_canonical = mx->nd->sparse->csr_matrix([$data, $indices, $indptr], shape=>[3, 3], dtype=>$random_nd->dtype);
        my $via_ctor = $f->($non_canonical, dtype=>$random_nd->dtype);
        my $via_copy = $non_canonical->copy();
        ok(almost_equal($via_ctor->aspdl, $via_copy->aspdl));
    };

    my ($rows, $cols) = (20, 20);
    my $dtype = 'float64';
    # constructors exercised by the sparse-input check, in this order
    my @constructors = (
        sub { mx->nd->sparse->array(@_) },
        sub { mx->nd->array(@_) },
    );
    foreach my $density (0.5)
    {
        my $shape = [$rows, $cols];
        $check_create_csr_from_nd->($shape, $density, $dtype);
        $check_create_csr_from_coo->($shape, $density, $dtype);
        foreach my $ctor (@constructors)
        {
            $check_create_csr_from_pdlccs->($shape, $density, $ctor);
        }
    }
}

test_create_csr();
685
# Rebuilds row_sparse NDArrays from the (data, indices) pair of random
# row_sparse arrays of several densities and checks the result.
sub test_create_row_sparse
{
    my ($max_rows, $max_cols) = (50, 50);
    foreach my $density (0, 0.5, 1)
    {
        my $dims = rand_shape_2d($max_rows, $max_cols);
        my $source = rand_ndarray($dims, 'row_sparse', $density);
        my $data = $source->data;
        my $indices = $source->indices;
        my $created = mx->nd->sparse->row_sparse_array([$data, $indices], shape=>$dims);
        ok($created->stype eq 'row_sparse');
        ok(same($created->data->aspdl, $data->aspdl));
        ok(same($created->indices->aspdl, $indices->aspdl));
        my $copy = mx->nd->array($created);
        ok(same($copy->aspdl, $created->aspdl));
    }
}

test_create_row_sparse();
707
# Creates csr and row_sparse NDArrays without an explicit shape and checks
# the shape that gets inferred from the component arrays.
sub test_create_sparse_nd_infer_shape
{
    # $shape, $density : used to generate the random source array
    # $dtype           : dtype requested for the created array
    # Exceptions raised inside the check are tolerated, but they are reported
    # through note() instead of being discarded silently.
    my $check_create_csr_infer_shape = sub { my ($shape, $density, $dtype) = @_;
        eval {
            my $matrix = rand_ndarray($shape, 'csr', $density);
            my $data = $matrix->data;
            my $indptr = $matrix->indptr;
            my $indices = $matrix->indices;
            my $nd = mx->nd->sparse->csr_matrix([$data, $indices, $indptr], dtype=>$dtype);
            my ($num_rows, $num_cols) = @{ $nd->shape };
            ok($num_rows == @{ $indptr } - 1);
            ok($indices->shape->[0] > 0);
            # no column index may reach the inferred column count
            ok(($num_cols <= $indices)->aspdl->sum == 0);
            ok($nd->dtype eq $dtype);
            1;
        } or note("csr shape inference check aborted (density=$density): $@");
    };
    # Same contract as the csr check above, for row_sparse arrays.
    my $check_create_rsp_infer_shape = sub { my ($shape, $density, $dtype) = @_;
        eval {
            my $array = rand_ndarray($shape, 'row_sparse', $density);
            my $data = $array->data;
            my $indices = $array->indices;
            my $nd = mx->nd->sparse->row_sparse_array([$data, $indices], dtype=>$dtype);
            my $inferred_shape = $nd->shape;
            # every dimension after the first must match the data array
            is_deeply([@{ $inferred_shape }[1..@{ $inferred_shape }-1]], [@{ $data->shape }[1..@{ $data->shape }-1]]);
            ok($indices->ndim > 0);
            ok($nd->dtype eq $dtype);
            if($indices->shape->[0] > 0)
            {
                # no row index may reach the inferred row count
                ok(($inferred_shape->[0] <= $indices)->aspdl->sum == 0);
            }
            1;
        } or note("row_sparse shape inference check aborted (density=$density): $@");
    };

    my $dtype = 'int32';
    my $shape = rand_shape_2d();
    my $shape_3d = rand_shape_3d();
    my @densities = (0, 0.5, 1);
    for my $density (@densities)
    {
        $check_create_csr_infer_shape->($shape, $density, $dtype);
        $check_create_rsp_infer_shape->($shape, $density, $dtype);
        $check_create_rsp_infer_shape->($shape_3d, $density, $dtype);
    }
}

test_create_sparse_nd_infer_shape();
754
# Creates csr and row_sparse NDArrays from dense all-ones inputs (an NDArray
# and its PDL copy) and checks contents, dtype and context.
sub test_create_sparse_nd_from_dense
{
    my $verify = sub {
        my ($shape, $ctor, $dense, $dtype, $default_dtype, $ctx) = @_;
        # explicit shape / dtype / ctx
        my $explicit = $ctor->($dense, shape => $shape, dtype => $dtype, ctx => $ctx);
        ok(same($explicit->aspdl, pones(reverse @{ $shape })));
        ok($explicit->dtype eq $dtype);
        ok($explicit->context eq $ctx);
        # verify the default dtype inferred from dense arr
        my $defaulted = $ctor->($dense);
        ok($defaulted->dtype eq $default_dtype);
        ok($defaulted->context eq AI::MXNet::Context->current_ctx);
    };
    my $shape = rand_shape_2d();
    my $dtype = 'int32';
    my $src_dtype = 'float64';
    my $ctx = mx->cpu(1);
    my @dense_inputs = (
        mx->nd->ones($shape, dtype=>$src_dtype),
        mx->nd->ones($shape, dtype=>$src_dtype)->aspdl
    );
    # constructors under test, in this order
    my @constructors = (
        sub { mx->nd->sparse->csr_matrix(@_) },
        sub { mx->nd->sparse->row_sparse_array(@_) },
    );
    foreach my $ctor (@constructors)
    {
        foreach my $dense (@dense_inputs)
        {
            my $default_dtype = 'float32';
            if(blessed($dense))
            {
                $default_dtype = $dense->dtype;
            }
            $verify->($shape, $ctor, $dense, $dtype, $default_dtype, $ctx);
        }
    }
}

test_create_sparse_nd_from_dense();
786
# Creates csr and row_sparse NDArrays from already-sparse all-ones inputs
# and checks contents, dtype and context.
sub test_create_sparse_nd_from_sparse
{
    my $verify = sub {
        my ($shape, $ctor, $sparse_in, $dtype, $src_dtype, $ctx) = @_;
        # explicit shape / dtype / ctx
        my $explicit = $ctor->($sparse_in, shape => $shape, dtype=>$dtype, ctx=>$ctx);
        ok(same($explicit->aspdl, pones(reverse @{ $shape })));
        ok($explicit->dtype eq $dtype);
        ok($explicit->context eq $ctx);
        # verify the default dtype inferred from sparse arr
        my $defaulted = $ctor->($sparse_in);
        ok($defaulted->dtype eq $src_dtype);
        ok($defaulted->context eq AI::MXNet::Context->current_ctx);
    };

    my $shape = rand_shape_2d();
    my $src_dtype = 'float64';
    my $dtype = 'int32';
    my $ctx = mx->cpu(1);
    my $ones = mx->nd->ones($shape, dtype=>$src_dtype);
    my @csr_inputs = ($ones->tostype('csr'));
    my @rsp_inputs = ($ones->tostype('row_sparse'));
    push @csr_inputs, mx->nd->ones($shape, dtype=>$src_dtype)->aspdl->tocsr;
    # [constructor, inputs fed to it], processed in this order
    my @plan = (
        [sub { mx->nd->sparse->csr_matrix(@_) },       \@csr_inputs],
        [sub { mx->nd->sparse->row_sparse_array(@_) }, \@rsp_inputs],
    );
    foreach my $step (@plan)
    {
        my ($ctor, $inputs) = @{ $step };
        foreach my $sparse_in (@{ $inputs })
        {
            $verify->($shape, $ctor, $sparse_in, $dtype, $src_dtype, $ctx);
        }
    }
}

test_create_sparse_nd_from_sparse();
821
# Creates empty sparse NDArrays (via sparse->empty and via the csr /
# row_sparse constructors called with undef data) and checks storage type,
# dtype, context and all-zero contents.
sub test_create_sparse_nd_empty
{
    my $check_empty = sub {
        my ($shape, $stype) = @_;
        my $empty = mx->nd->sparse->empty($stype, $shape);
        ok($empty->stype eq $stype);
        ok(same($empty->aspdl, pzeros(reverse(@{ $shape }))));
    };

    # Shared body of the csr and row_sparse checks: $ctor is the constructor
    # under test and $stype the storage type it is expected to produce.
    my $check_ctor_empty = sub {
        my ($ctor, $stype, $shape, $dtype, $ctx) = @_;
        my $arr = $ctor->(undef, shape => $shape, dtype => $dtype, ctx => $ctx);
        ok($arr->stype eq $stype);
        ok($arr->dtype eq $dtype);
        ok($arr->context eq $ctx);
        ok(same($arr->aspdl, pzeros(reverse(@{ $shape }))));
        # check the default value for dtype and ctx
        $arr = $ctor->(undef, shape => $shape);
        ok($arr->dtype eq 'float32');
        ok($arr->context eq AI::MXNet::Context->current_ctx);
    };
    my $make_csr = sub { mx->nd->sparse->csr_matrix(@_) };
    my $make_rsp = sub { mx->nd->sparse->row_sparse_array(@_) };

    my $shape = rand_shape_2d();
    my $shape_3d = rand_shape_3d();
    my $dtype = 'int32';
    my $ctx = mx->cpu(1);
    foreach my $stype (qw(csr row_sparse))
    {
        $check_empty->($shape, $stype);
    }
    $check_ctor_empty->($make_csr, 'csr',        $shape,    $dtype, $ctx);
    $check_ctor_empty->($make_rsp, 'row_sparse', $shape,    $dtype, $ctx);
    $check_ctor_empty->($make_rsp, 'row_sparse', $shape_3d, $dtype, $ctx);
}

test_create_sparse_nd_empty();
869
# Exercises rand_sparse_ndarray with distribution => "powerlaw" on CSR
# arrays of several shapes and densities.
sub test_synthetic_dataset_generator
{
    # For every row in 1 .. $last_row, asserts that the quantity
    # indices[indptr[k] - 1] + 1 doubles from k = row to k = row + 1.
    # (The value is the last stored column index before offset indptr[k],
    # plus one; this test treats it as a per-row non-zero count.)
    my $assert_doubling = sub {
        my ($csr, $last_row) = @_;
        my $col_idx = $csr->indices->aspdl;
        my $row_ptr = $csr->indptr->aspdl;

        my $count_at = sub {
            my ($k) = @_;
            return $col_idx->at($row_ptr->at($k) - 1) + 1;
        };

        for my $row (1 .. $last_row)
        {
            my $current   = $count_at->($row);
            my $following = $count_at->($row + 1);
            ok($following == 2 * $current);
        }
    };

    # Builds one power-law CSR array; only the first value returned by
    # rand_sparse_ndarray is kept.
    my $make_powerlaw_csr = sub {
        my ($dims, $density) = @_;
        my ($nd) = rand_sparse_ndarray(
            $dims, "csr",
            density      => $density,
            distribution => "powerlaw",
        );
        return $nd;
    };

    my $wide   = $make_powerlaw_csr->([32, 10000],   0.01);
    my $tiny   = $make_powerlaw_csr->([5, 5],        0.5);
    my $huge   = $make_powerlaw_csr->([32, 1000000], 0.4);
    my $square = $make_powerlaw_csr->([1600, 1600],  0.5);

    # Density must be preserved: 32 * 10000 * 0.01 = 3200 stored values.
    ok($wide->data->len == 3200);

    $assert_doubling->($wide,   9);
    $assert_doubling->($tiny,   1);
    $assert_doubling->($huge,   4);
    $assert_doubling->($square, 6);
}

test_synthetic_dataset_generator();
904
# Checks that calling an operator as a class method on
# AI::MXNet::NDArray::Base and calling it as a method on the array itself
# give almost-equal results for sparse inputs.
sub test_sparse_nd_fluent
{
    # Draws a uniform random array of shape $dims (default [5, 17]), converts
    # it to $storage, runs $op both ways with the same named arguments and
    # compares the dense results.
    my $compare_call_styles = sub {
        my ($storage, $op, $named_args, $dims) = @_;
        $dims = [5, 17] unless defined $dims;
        my $input      = mx->nd->random->uniform(shape=>$dims)->tostype($storage);
        my $via_class  = AI::MXNet::NDArray::Base->$op($input, %$named_args);
        my $via_method = $input->$op(%$named_args);
        ok(almost_equal($via_class->aspdl, $via_method->aspdl));
    };

    # Operators tried on both storage types.
    my @shared_ops = ('zeros_like', 'square');
    # Operators tried on row_sparse only, in the order they are run.
    my @row_sparse_ops = (
        'round', 'rint', 'fix', 'floor', 'ceil', 'trunc',
        'abs', 'sign', 'sin', 'degrees', 'radians', 'expm1',
        'arcsin', 'arctan', 'tan', 'sinh', 'tanh',
        'arcsinh', 'arctanh', 'log1p', 'sqrt', 'relu',
    );

    # Each case is [storage type, operator name, named arguments].
    my @cases;
    for my $op (@shared_ops)
    {
        push @cases, ['csr', $op, {}];
    }
    for my $op (@shared_ops, @row_sparse_ops)
    {
        push @cases, ['row_sparse', $op, {}];
    }
    push @cases, ['csr', 'slice', {begin => [2, 5], end => [4, 7]}];
    push @cases, ['row_sparse', 'clip', {a_min => -0.25, a_max => 0.75}];
    for my $op ('sum', 'mean')
    {
        push @cases, ['csr', $op, {axis => 0}];
    }

    for my $case (@cases)
    {
        $compare_call_styles->(@$case);
    }
}

test_sparse_nd_fluent();
944
# Each closure below is handed to dies_ok, i.e. every one of these calls is
# expected to throw.
sub test_sparse_nd_exception
{
    my $dense = mx->nd->ones([2,2]);

    my @must_die = (
        # unknown keyword argument passed to retain
        sub { mx->nd->sparse->retain($dense, invalid_arg=>"garbage_value") },
        # source is a [2,2] NDArray but the requested shape is [3,2]
        sub { mx->nd->sparse->csr_matrix($dense, shape=>[3,2]) },
        # source is a piddle built from [2,2], requested shape is [3,2]
        sub { mx->nd->sparse->csr_matrix(pdl([2,2]), shape=>[3,2]) },
        sub { mx->nd->sparse->row_sparse_array(pdl([2,2]), shape=>[3,2]) },
        # storage type name that is neither of the ones used in this file
        sub { mx->nd->sparse->zeros("invalid_stype", [2,2]) },
    );

    for my $attempt (@must_die)
    {
        dies_ok($attempt);
    }
}

test_sparse_nd_exception();
956
# Runs check_format on sparse arrays: first on arrays that are expected to
# be accepted, then on hand-built CSR and row_sparse arrays whose auxiliary
# data is malformed and for which check_format is expected to die.
sub test_sparse_nd_check_format
{
    # Random and all-zero arrays: check_format is called directly, so an
    # exception here would abort the script.
    my $random_shape = rand_shape_2d();
    for my $storage ("csr", "row_sparse")
    {
        my ($random_nd) = rand_sparse_ndarray($random_shape, $storage);
        $random_nd->check_format();
        my $zero_nd = mx->nd->sparse->zeros($storage, $random_shape);
        $zero_nd->check_format();
    }

    # Malformed CSR inputs for a [3, 4] array holding three values.
    # Construction happens outside dies_ok; only check_format may throw.
    my $csr_shape  = [3, 4];
    my $csr_values = [7, 8, 9];
    my @bad_csr = (
        # [ indices, indptr ]
        # indptr jumps to 5 (more than the 3 stored values) and then decreases
        [[0, 2, 1], [0, 5, 2, 3]],
        # column indices of the first row (2, 1) are not in ascending order
        [[2, 1, 1], [0, 2, 2, 3]],
        # indptr ends with 4 although there are only 3 indices
        [[1, 2, 1], [0, 2, 2, 4]],
        # indptr contains a negative entry
        [[0, 2, 1], [0, -2, 2, 3]],
    );
    for my $bad (@bad_csr)
    {
        my ($col_idx, $row_ptr) = @$bad;
        my $csr = mx->nd->sparse->csr_matrix([$csr_values, $col_idx, $row_ptr], shape=>$csr_shape);
        dies_ok(sub { $csr->check_format });
    }

    # Malformed row_sparse inputs for a [3, 2] array holding two rows.
    my $rsp_shape = [3, 2];
    my $rsp_rows  = [[1, 2], [3, 4]];
    my @bad_row_ids = (
        [1, 4],     # row id 4 is outside a 3-row array
        [1, 0],     # row ids are not in ascending order
        [1, -2],    # negative row id
    );
    for my $row_ids (@bad_row_ids)
    {
        my $rsp = mx->nd->sparse->row_sparse_array([$rsp_rows, $row_ids], shape=>$rsp_shape);
        dies_ok(sub { $rsp->check_format });
    }
}

test_sparse_nd_check_format();
1007