function test06 (A,B,fulltests,method_list)
%TEST06 test GrB_mxm on all semirings
%
% Usage: test06(A)
%        test06(A,B)
%        test06(A,B,fulltests)
%        test06(A,B,fulltests,method_list)
%
% with no input, a small 10-by-10 matrix is used.  If A is a scalar, it is a
% matrix id number from the SuiteSparse collection; otherwise A is the sparse
% matrix to use in the test.  If B is omitted or empty it defaults to 2*A
% (when A is also empty, the default test problem supplies its own B).
% A must be square and B must be the same size as A, since A*B, A'*B, A*B'
% and A'*B' are all computed.
%
% fulltests (default 1): if true, every combination of multiply operator,
% add operator and type is tried; otherwise only one semiring is tried.
%
% method_list (default 0:3): the saxpy/dot methods to test:
%   1: hash, 2: gustavson, 3: dot, anything else: default (auto).
%
% Results are compared with GB_spec_mxm only when n < 200.  When n > 500,
% MATLAB timings and GraphBLAS speedups are printed for each semiring.

% SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
% SPDX-License-Identifier: Apache-2.0

fprintf ('test06: GrB_mxm on all semirings\n') ;

[binops, ~, add_ops, types, ~, ~] = GB_spec_opsall ;
% mult_ops = binops.positional ;
mult_ops = binops.all ;
types = types.all ;

if (nargin < 3)
    fprintf ('\n-------------- GrB_mxm on all semirings\n') ;
    fulltests = 1 ;
end

if (nargin < 2)
    B = [ ] ;
end

if (nargin < 4)
    method_list = 0:3 ;
end

%-------------------------------------------------------------------------------
% select the test matrices A and B
%-------------------------------------------------------------------------------

rng ('default') ;
if (nargin < 1 || isempty (A))
    % default problem: two 10-by-10 blocks of the ibm32a pattern, with
    % random values between -10 and 10.  Any B passed in is replaced.
    % w = load ('../Demo/Matrix/west0067') ;
    % A = sparse (w (:,1)+1, w (:,2)+1, w (:,3)) ;
    w = load ('../Demo/Matrix/ibm32a') ;
    nz = size (w,1) ;
    n = max (max (w (:,1:2))) + 1 ;
    A = sparse (w (:,1)+1, w (:,2)+1, 20 * rand (nz,1) - 10, n, n) ;
    B = A (16:25, 16:25) .* rand (10) ;
    A = A (1:10,1:10) ;
elseif (isscalar (A))
    % A is a SuiteSparse collection id (no semicolon: the problem is displayed)
    Prob = ssget (A)
    A = Prob.A ;
    clear Prob
    A (1,2) = 1 ;
end

if (isempty (B))
    % default B for a user-supplied (or ssget) matrix A.  Previously this was
    % only done for the ssget case, so test06(A) with a matrix A failed at A*B.
    B = 2*A ;

%   n = size (A,2) ;
%   p = randperm (n) ;
%   B = sparse (1:n, p, ones (n,1)) ;
%   p = randperm (n) ;
%   B = B + sparse (1:n, p, ones (n,1)) ;
%   p = randperm (n) ;
%   B = B + sparse (1:n, p, ones (n,1)) ;
%   e = ones (n,1) ;
%   % B = spdiags([e -2*e e], -1:1, n, n) ;
%   % B = spdiags([-2*e e], 0:1, n, n) ;
%   % B = speye(n);
%   T = A ; A = B ; B = T ;

end

assert (issparse (A)) ;

[m, n] = size (A) ;
assert (m == n) ;
% A*B, A'*B, A*B' and A'*B' are all computed, so B must be n-by-n
assert (isequal (size (B), [n n])) ;
Cin = sparse (m, n) ;

%-------------------------------------------------------------------------------
% construct the mask: a random half of the pattern of A plus random entries
%-------------------------------------------------------------------------------

rng ('default') ;
[i, j, ~] = find (A) ;
nz = nnz (A) ;
p = randperm (nz, floor (nz/2)) ;

Mask = sparse (i (p), j (p), ones (length (p),1), m, n) + ...
       spones (sprandn (m, n, 1/n)) ;

%-------------------------------------------------------------------------------
% MATLAB baseline timings, used for the speedups printed when n > 500
%-------------------------------------------------------------------------------

tic
C = A*B ;
tm1 = toc ;

% Mask = spones (A) ;
% Mask = sparse (i (p), j (p), ones (length (p),1), m, n) ;

tic
C = A'*B ;
tm2 = toc ;

tic
C = A*B' ;
tm3 = toc ;

tic
C = A'*B' ;
tm5 = toc ;

if (n > 500)
    fprintf ('MATLAB time: %g %g %g %g\n', tm1, tm2, tm3, tm5) ;
    fprintf ('with mask:\n') ;
end

tic
C = (A*B) .* Mask ;
tmm1 = toc ;

tic
C = (A'*B) .* Mask ;
tmm2 = toc ;

tic
C = (A*B') .* Mask ;
tmm3 = toc ;

tic
C = (A'*B') .* Mask ;
tmm5 = toc ;

if (n > 500)
    fprintf ('MATLAB time: %g %g %g %g\n', tmm1, tmm2, tmm3, tmm5) ;
end

%-------------------------------------------------------------------------------
% descriptors for A*B, A'*B, A*B', A'*B'
%-------------------------------------------------------------------------------

dnn = struct ;
dtn = struct ( 'inp0', 'tran' ) ;
dnt = struct ( 'inp1', 'tran' ) ;
dtt = struct ( 'inp0', 'tran', 'inp1', 'tran' ) ;

n_semirings = 0 ;

if (fulltests)
    k1_list = 1:length (mult_ops) ;
    k2_list = 1:length (add_ops) ;
    k3_list = 1:length (types) ;
else
    % just use plus-times-double semiring
    % NOTE(review): these indices depend on the ordering of the lists returned
    % by GB_spec_opsall; confirm they still select plus, times, and double.
    k1_list = 4 ;
    k2_list = 3 ;
    k3_list = 11 ;
end

%-------------------------------------------------------------------------------
% test each semiring with each method
%-------------------------------------------------------------------------------

for k1 = k1_list % 1:length(mult_ops)
    mulop = mult_ops {k1} ;
    if (n <= 500)
        fprintf ('\n%s', mulop) ;
    end

    for k2 = k2_list % 1:length(add_ops)
        addop = add_ops {k2} ;

        for k3 = k3_list % 1:length (types)
            semiring_type = types {k3} ;
            if (n <= 500)
               fprintf ('.') ;
            end

            semiring.multiply = mulop ;
            semiring.add = addop ;
            semiring.class = semiring_type ;

            % create the semiring.  some are not valid because the or,and,xor,eq
            % monoids can only be used when z is boolean for z=mult(x,y).
            % Invalid combinations raise an error here and are skipped.
            try
                [mult_op add_op id] = GB_spec_semiring (semiring) ;
                [mult_opname mult_optype ztype xtype ytype] = ...
                    GB_spec_operator (mult_op) ;
                [ add_opname  add_optype] = GB_spec_operator (add_op) ;
                identity = GB_spec_identity (semiring.add, add_optype) ;
            catch me
                if (~isempty (strfind (me.message, 'gotcha')))
                    semiring
                end
                continue
            end

            n_semirings = n_semirings + 1 ;

            for method = method_list % 0:3

                [algo, label] = mxm_method (method) ;
                if (n > 500)
                    fprintf ('%3d ', n_semirings) ;
                    fprintf ('[%6s %6s %8s] : ', mulop, addop, semiring_type) ;
                    fprintf ('%s', label) ;
                end

                % without a mask, the dot method is only tried when n < 1000
                if (isequal (algo, 'dot'))
                    ok = (n < 1000) ;
                else
                    ok = true ;
                end

                dnn.axb = algo ;
                dnt.axb = algo ;
                dtn.axb = algo ;
                dtt.axb = algo ;

                % C = A*B, A'*B, A*B', A'*B', no mask
                t = nan (1,4) ;
                if (ok)
                    t (1) = run_mxm (Cin, [ ], semiring, A, B, dnn, id, n) ;
                    t (2) = run_mxm (Cin, [ ], semiring, A, B, dtn, id, n) ;
                    t (3) = run_mxm (Cin, [ ], semiring, A, B, dnt, id, n) ;
                    t (4) = run_mxm (Cin, [ ], semiring, A, B, dtt, id, n) ;
                end

                if (n > 500)
                    % one numeric conversion per speedup (the old format had
                    % extra %s fields that consumed the speedups as strings)
                    fprintf ('speedups %10.4f %10.4f %10.4f %10.4f ', ...
                        tm1/t(1), tm2/t(2), tm3/t(3), tm5/t(4)) ;
                end

                % C<Mask> = A*B, A'*B, A*B', A'*B' (run for every method)
                t (1) = run_mxm (Cin, Mask, semiring, A, B, dnn, id, n) ;
                t (2) = run_mxm (Cin, Mask, semiring, A, B, dtn, id, n) ;
                t (3) = run_mxm (Cin, Mask, semiring, A, B, dnt, id, n) ;
                t (4) = run_mxm (Cin, Mask, semiring, A, B, dtt, id, n) ;

                if (n > 500)
                    fprintf ('speedups %10.4f %10.4f %10.4f %10.4f ', ...
                        tmm1/t(1), tmm2/t(2), tmm3/t(3), tmm5/t(4)) ;
                    fprintf ('\n') ;
                end

            end
        end
    end
end

% n_semirings

if (fulltests)
    fprintf ('\ntest06: all tests passed\n') ;
end

%-------------------------------------------------------------------------------
% run_mxm: one GB_mex_mxm test, optionally checked against GB_spec_mxm
%-------------------------------------------------------------------------------

function t = run_mxm (Cin, Mask, semiring, A, B, desc, id, n)
%RUN_MXM compute C<Mask>=A*B (per desc) with GB_mex_mxm, check if n < 200
% Returns the value of grbresults after the GB_mex_mxm call, which test06
% uses as the GraphBLAS time.  Pass Mask = [ ] for no mask.  The result is
% compared with GB_spec_mxm (using identity id) only when n < 200.
C1 = GB_mex_mxm  (Cin, Mask, [ ], semiring, A, B, desc) ;
t = grbresults ;
if (n < 200)
    C2 = GB_spec_mxm (Cin, Mask, [ ], semiring, A, B, desc) ;
    GB_spec_compare (C1, C2, id) ;
end

%-------------------------------------------------------------------------------
% mxm_method: map a method number to the descriptor axb setting and a label
%-------------------------------------------------------------------------------

function [algo, label] = mxm_method (method)
%MXM_METHOD 1: hash, 2: gustavson, 3: dot, anything else: default (auto)
if (method == 1)
    algo = 'hash' ;
    label = 'hash ' ;
elseif (method == 2)
    algo = 'gustavson' ;
    label = 'g/s  ' ;
elseif (method == 3)
    algo = 'dot' ;
    label = 'dot  ' ;
else
    algo = 'default' ;
    label = 'auto ' ;
end
328
329