function test53(fulltests)
%TEST53 test GrB_Matrix_extract
%
% Usage:
%   test53      % quick test: a few randomly chosen types and operators
%   test53 (1)  % exhaustive test: every type, accumulator and operator type
%
% For several random problems, compares GB_mex_Matrix_extract,
% GB_mex_Col_extract and GB_mex_Vector_extract against the MATLAB reference
% implementations GB_spec_*: without and with a mask, with an accumulator,
% and with A transposed via the descriptor.  The accumulator tests are
% repeated for every combination of A.is_hyper and A.is_csc.

% SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
% SPDX-License-Identifier: Apache-2.0

if (nargin < 1)
    fulltests = 0 ;
end

if (fulltests)
    fprintf ('\n==== test53: exhaustive test for GrB_Matrix_extract:\n') ;
else
    fprintf ('\n==== test53: quick test for GrB_Matrix_extract:\n') ;
end

[binops, ~, ~, types, ~, ~] = GB_spec_opsall ;
accum_ops = binops.all ;
types = types.all ;

% descriptor that transposes the input A; it never changes, so build it once
Dtran = struct ('inp0', 'tran') ;

% each row is one problem: [nrows ncols nz y1 y2], with values y2*rand+y1
problems = [
     10,    1,    7,  -5, 100
     10,    8,   40,  -5, 100
     10,   20,  100, -99, 200
    100,  200, 1000, -99, 200
     50,   50,  500,  -2,   3
    ] ;

% try several problems
for k0 = 1:size (problems, 1)

    % problem k0: nz random triplets for matrices of size nrows-by-ncols
    nrows = problems (k0,1) ;
    ncols = problems (k0,2) ;
    nz    = problems (k0,3) ;
    y1    = problems (k0,4) ;
    y2    = problems (k0,5) ;

    % create A
    rng ('default') ;
    clear A
    A.matrix = test53_random (nrows, ncols, nz, y1, y2) ;

    % create Cin; it has the same dimensions as A, so when A is transposed
    % by the descriptor, the transposed copy Cin2 must be used instead
    clear Cin Cin2
    [X, Y] = test53_random (nrows, ncols, nz, y1, y2) ;
    Cin.matrix = X ;
    Cin2.matrix = X' ;
    clear X

    clear Cempty Cempty2
    Cempty.matrix = sparse (nrows, ncols) ;
    Cempty2.matrix = Cempty.matrix' ;

    % create a boolean Mask with roughly the same density as A and Cin
    Mask = cast (sprandn (nrows, ncols, nz/(nrows*ncols)), 'logical') ;

    % the Mask cast to types {1} and types {11} (logical and double); a
    % sparse mask of each of those classes is tried with every accum test
    mask_types = [1 11] ;
    Masks = cell (1, numel (mask_types)) ;
    for k = 1:numel (mask_types)
        Masks {k} = cast (Mask, types {mask_types (k)}) ;
    end

    % A has ncols columns; with a single column, J cases 3 and 4 below
    % would only repeat the J = 1 of case 2, so only cases 1 and 2 are tried
    if (ncols == 1)
        k6_cases = 2 ;
    else
        k6_cases = 4 ;
    end

    % note: Y holds the values of Cin (the last matrix created)
    fprintf ('\nnrows: %d ncols %d nnz %d ymin %g ymax %g\n', ...
        nrows, ncols, nz, min (Y), max (Y)) ;

    if (fulltests)
        k1_list = 1:length (types) ;
    else
        k1_list = 11 ;
    end

    % try every type for A
    for k1 = k1_list
        atype = types {k1} ;
        A.class = atype ;
        Cempty.class = atype ;
        Cempty2.class = atype ;

        % C = A (:,:)
        C = test53_extract ('Matrix', Cempty, [ ], [ ], A, [ ], [ ], [ ], ...
            atype) ;
        if (isequal (C.class, 'double'))
            assert (isequal (C.matrix, A.matrix)) ;
        end

        % C = A (:,:)'
        C = test53_extract ('Matrix', Cempty2, [ ], [ ], A, [ ], [ ], ...
            Dtran, atype) ;
        if (isequal (C.class, 'double'))
            assert (isequal (C.matrix, A.matrix')) ;
        end

        % C<Mask> = A (:,:)
        C = test53_extract ('Matrix', Cempty, Mask, [ ], A, [ ], [ ], [ ], ...
            atype) ;
        if (isequal (C.class, 'double'))
            assert (isequal (C.matrix .* Mask, (A.matrix) .* Mask)) ;
        end

        % C<Mask'> = A (:,:)'
        C = test53_extract ('Matrix', Cempty2, Mask', [ ], A, [ ], [ ], ...
            Dtran, atype) ;
        if (isequal (C.class, 'double'))
            assert (isequal (C.matrix .* Mask', (A.matrix') .* Mask')) ;
        end

        if (fulltests)
            k2_list = 1:length (types) ;
        else
            k2_list = unique ([11 irand(2,length(types),1,1)]) ;
        end

        % try every type for Cin
        for k2 = k2_list
            cintype = types {k2} ;
            Cin2.class = cintype ;
            Cin.class = cintype ;

            fprintf ('%s', cintype) ;

            if (fulltests)
                k3_list = 1:length (accum_ops) ;
            else
                k3_list = unique ([1 5 irand(2,length(accum_ops),1,1)]) ;
            end

            % try every accumulator operator
            for k3 = k3_list
                op = accum_ops {k3} ;
                fprintf ('.') ;

                if (fulltests)
                    k4_list = 1:length (types) ;
                else
                    k4_list = unique ([11 irand(2,length(types),1,1)]) ;
                end

                % try every operator type
                for k4 = k4_list

                    clear accum
                    accum.opname = op ;
                    accum.optype = types {k4} ;

                    % skip positional operators, and any op/type pair that
                    % GB_spec_operator rejects
                    if (GB_spec_is_positional (accum))
                        continue ;
                    end
                    try
                        GB_spec_operator (accum) ;
                    catch
                        continue
                    end

                    % try several I's (1-based; [ ] means all rows)
                    for k5 = 1:4

                        switch (k5)
                            case 1
                                I = [ ] ;
                            case 2
                                I = uint64 (1 + floor (nrows/2)) ;
                                if (I+2 < nrows)
                                    I = [I I+2] ;
                                end
                            case 3
                                I = uint64 (randperm (nrows)) ;
                            case 4
                                I = uint64 (min (4, nrows-1)) ;
                        end
                        if (isempty (I))
                            ni = nrows ;
                        else
                            ni = length (I) ;
                        end

                        % try several J's (1-based; [ ] means all columns)
                        for k6 = 1:k6_cases

                            switch (k6)
                                case 1
                                    J = [ ] ;
                                case 2
                                    J = uint64 (1 + floor (ncols/2)) ;
                                    if (J+2 < ncols)
                                        J = [J J+2] ;
                                    end
                                case 3
                                    J = uint64 (randperm (ncols)) ;
                                case 4
                                    J = uint64 (1) ;
                            end
                            if (isempty (J))
                                nj = ncols ;
                            else
                                nj = length (J) ;
                            end

                            % Csub is ni-by-nj for A(I,J); Csub2 is
                            % nj-by-ni for A(J,I)' with A transposed
                            clear Csub Csub2
                            Csub.matrix  = Cin.matrix  (1:ni,1:nj) ;
                            Csub.class   = Cin.class ;
                            Csub2.matrix = Cin2.matrix (1:nj,1:ni) ;
                            Csub2.class  = Cin2.class ;

                            % masks to try: none ([ ]), then each typed Mask
                            Msubs = cell (1, 1 + numel (Masks)) ;
                            for k7 = 1:numel (Masks)
                                Mk = Masks {k7} ;
                                Msubs {k7+1} = Mk (1:ni,1:nj) ;
                            end

                            for A_is_hyper = 0:1
                                for A_is_csc = 0:1
                                    A.is_hyper = A_is_hyper ;
                                    A.is_csc   = A_is_csc   ;

                                    % GrB_Vector_extract is only tried when
                                    % A is a non-hypersparse CSC column
                                    % and J selects that single column
                                    A_is_vector = (ncols == 1 && ...
                                        isequal (J, 1) && A.is_csc && ...
                                        ~A.is_hyper) ;

                                    for km = 1:length (Msubs)
                                        M = Msubs {km} ;

                                        % C<M> = accum (Csub, A(I,J))
                                        test53_extract ('Matrix', Csub, M, ...
                                            accum, A, I, J, [ ], cintype) ;

                                        if (A_is_vector)
                                            % C<M> = accum (Csub, A(I,1))
                                            test53_extract ('Vector', Csub, ...
                                                M, accum, A, I, [ ], [ ], ...
                                                cintype) ;
                                        end

                                        if (length (J) == 1)
                                            % C<M> = accum (Csub, A(I,j))
                                            test53_extract ('Col', Csub, M, ...
                                                accum, A, I, J, [ ], cintype) ;
                                        end

                                        % C<M'> = accum (Csub2, A(J,I)')
                                        test53_extract ('Matrix', Csub2, M', ...
                                            accum, A, J, I, Dtran, cintype) ;

                                        if (length (I) == 1)
                                            % C<M'> = accum (Csub2, A(i,J)')
                                            test53_extract ('Col', Csub2, ...
                                                M', accum, A, J, I, Dtran, ...
                                                cintype) ;
                                        end
                                    end
                                end
                            end
                        end
                    end
                end
            end
        end
    end
end

fprintf ('\ntest53: all tests passed\n') ;

%-------------------------------------------------------------------------------
% test53_random: random sparse test matrix
%-------------------------------------------------------------------------------

function [X, Y] = test53_random (nrows, ncols, nz, y1, y2)
%TEST53_RANDOM nrows-by-ncols sparse matrix built from nz random triplets.
% Row and column indices come from irand; the values are Y = y2*rand(nz,1)+y1
% (Y is also returned).  Duplicate (i,j) pairs are summed by sparse.
I = irand (0, nrows-1, nz, 1) ;
J = irand (0, ncols-1, nz, 1) ;
Y = y2 * rand (nz, 1) + y1 ;
X = sparse (double (I)+1, double (J)+1, Y, nrows, ncols) ;

%-------------------------------------------------------------------------------
% test53_extract: run one GraphBLAS extract and compare it with its spec
%-------------------------------------------------------------------------------

function C = test53_extract (method, Cin, M, accum, A, I, J, desc, ctype)
%TEST53_EXTRACT compare GB_mex_<method>_extract with GB_spec_<method>_extract.
% method is 'Matrix', 'Col' or 'Vector'.  I and J are 1-based index lists
% ([ ] for all); the GB_mex_* function is given them 0-based.  J is ignored
% for 'Vector'.  Asserts that the result C is a valid sparse matrix of class
% ctype whose values equal the spec result (NaNs compare equal), and
% returns C so the caller can make further checks.
switch (method)
    case 'Matrix'
        C = GB_mex_Matrix_extract  (Cin, M, accum, A, I-1, J-1, desc) ;
        S = GB_spec_Matrix_extract (Cin, M, accum, A, I, J, desc) ;
    case 'Col'
        C = GB_mex_Col_extract  (Cin, M, accum, A, I-1, J-1, desc) ;
        S = GB_spec_Col_extract (Cin, M, accum, A, I, J, desc) ;
    case 'Vector'
        C = GB_mex_Vector_extract  (Cin, M, accum, A, I-1, desc) ;
        S = GB_spec_Vector_extract (Cin, M, accum, A, I, desc) ;
    otherwise
        error ('test53: unknown extract method %s', method) ;
end
assert (GB_spok (C.matrix*1) == 1) ;
assert (isequal (C.class, ctype)) ;
assert (isequal (C.class, S.class)) ;
assert (isequalwithequalnans (full (double (C.matrix)), double (S.matrix))) ;
414
415