function test160
%TEST160 test GrB_mxm
%
% Runs a set of masked and unmasked C = A*B products with the plus-times
% semiring on double.  Each product is computed with GB_spec_mxm and with
% GB_mex_mxm, and the two results are checked with GB_spec_compare.  The
% last two cases also check GB_mex_mxm_generic.  The cases inside the loop
% are repeated once for each mask class listed in mtypes.

% SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
% SPDX-License-Identifier: Apache-2.0

rng ('default') ;

n = 100 ;

% n-by-n mask with random values and every position present in the pattern
Mask.matrix = (rand (n) > 0.5) ;
Mask.pattern = true (n) ;
mtypes = { 'int8', 'int16', 'int32', 'int64', 'double complex' } ;

% n-by-n random logical mask with column 1 removed (value and pattern)
Mask2 = GB_spec_random (n, n, 1, 0.01, 'logical') ;
Mask2.matrix = logical (Mask2.matrix) ;
Mask2.matrix (:,1) = false ;
Mask2.pattern (:,1) = false ;

semiring.add = 'plus' ;
semiring.multiply = 'times' ;
semiring.class = 'double' ;

% descriptors: mask options, optionally combined with the 'hash' axb method
dnn = struct ;
dnn_struct = struct ('mask', 'structural') ;
dnn_notM_struct = struct ('mask', 'structural complement') ;
dnn_notM = struct ('mask', 'complement') ;
dnn_notM_hash = struct ('mask', 'complement', 'axb', 'hash') ;
dnn_hash = struct ('axb', 'hash') ;

d = 0.01 ;

% A, B: random n-by-n with density d.  G is A with columns 1:2 set to 1,
% and B gets entries B(1:2,1) = 1.
A = GB_spec_random (n, n, d, 1, 'double') ;
G = A ;
G.matrix (:,1:2) = 1 ;
G.pattern (:,1:2) = true ;
B = GB_spec_random (n, n, d, 1, 'double') ;
B.matrix (1:2,1) = 1 ;
B.pattern (1:2,1) = true ;
b = GB_spec_random (n, 1, d, 1, 'double') ;
Cin = sparse (n, n) ;
cin = sparse (n, 1) ;
mask.matrix = (rand (n,1) > 0.5) ;
mask.pattern = true (n,1) ;

% H: all ones except H.matrix (1,1) = 0, with every position in H.pattern
H.matrix = sparse (ones (n,n)) ;
H.matrix (1,1) = 0 ;
H.pattern = sparse (true (n,n)) ;
% NOTE(review): a second assignment "H.matrix (1,1) = false" used to follow
% here; it was a no-op after the line above.  If the intent was to drop
% H(1,1) from the pattern, this should be H.pattern (1,1) = false -- confirm.
H.sparsity = 2 ;

% mask2: n-by-1 mask with only entry (1,1), which is true
mask2.matrix = sparse (false (n,1)) ;
mask2.matrix (1,1) = true ;
mask2.pattern = sparse (false (n,1)) ;
mask2.pattern (1,1) = true ;
x = GB_spec_random (n, 1, 0.5, 1, 'double') ;
x.sparsity = 2 ;
% y is not used below.  Its generation is kept in case GB_spec_random draws
% from the global random stream, so that K below stays the same.
y = GB_spec_random (n, 1, 0.02, 1, 'double') ;
y.sparsity = 2 ;

% K: 1000-by-2 with K(1:2,1:2) = pi; z: 2-by-1; maskz: only entry (1,1)
K = GB_spec_random (1000, 2, 0.1, 1, 'double') ;
K.matrix (1:2, 1:2) = pi ;
K.pattern (1:2, 1:2) = true ;
K.sparsity = 2 ;
z.matrix = rand (2,1) ;
maskz.matrix = sparse (false (1000,1)) ;
maskz.matrix (1,1) = true ;
maskz.pattern = sparse (false (1000,1)) ;
maskz.pattern (1,1) = true ;
maskz.class = 'logical' ;
cinz = sparse (1000, 1) ;


for k = 1:length (mtypes)

    fprintf ('%s ', mtypes {k}) ;
    Mask.class = mtypes {k} ;
    Mask2.class = mtypes {k} ;
    mask.class = mtypes {k} ;

    % A*B with Mask: C<M>, C<M,struct>, C<!M,struct>, C<!M>
    check_mxm (Cin, Mask, semiring, A, B, dnn) ;
    check_mxm (Cin, Mask, semiring, A, B, dnn_struct) ;
    check_mxm (Cin, Mask, semiring, A, B, dnn_notM_struct) ;
    check_mxm (Cin, Mask, semiring, A, B, dnn_notM) ;

    % G*b with the n-by-1 mask: C<M>, C<!M>, C<!M,struct>
    check_mxm (cin, mask, semiring, G, b, dnn) ;
    check_mxm (cin, mask, semiring, G, b, dnn_notM) ;
    check_mxm (cin, mask, semiring, G, b, dnn_notM_struct) ;

    % C<!Mask2> = A*B and C<!Mask2> = G*B
    check_mxm (Cin, Mask2, semiring, A, B, dnn_notM) ;
    check_mxm (Cin, Mask2, semiring, G, B, dnn_notM) ;

    % C<!mask2> = H*x, then C<!mask2> = G*x with the hash method
    check_mxm (cin, mask2, semiring, H, x, dnn_notM) ;
    check_mxm (cin, mask2, semiring, G, x, dnn_notM_hash) ;

    % C<!maskz> = K*z with z.sparsity = 4 and the hash method.  None of
    % these inputs change with k (maskz.class stays 'logical').
    z.sparsity = 4 ;
    check_mxm (cinz, maskz, semiring, K, z, dnn_notM_hash) ;

end

% C = K*z (no mask), z.sparsity = 2, hash method; also checks the generic kernel
z.sparsity = 2 ;
check_mxm (cinz, [ ], semiring, K, z, dnn_hash, true) ;

% C<!maskz> = K*z, z.sparsity = 2, hash method; also checks the generic kernel
check_mxm (cinz, maskz, semiring, K, z, dnn_notM_hash, true) ;

fprintf ('\ntest160: all tests passed\n') ;

%-------------------------------------------------------------------------------
% check_mxm: compare GB_spec_mxm against GB_mex_mxm for C<M> = A*B
%-------------------------------------------------------------------------------

function check_mxm (Cin, M, semiring, A, B, desc, with_generic)
% Computes C<M> = A*B with no accumulator, using both GB_spec_mxm and
% GB_mex_mxm, and checks the two results with GB_spec_compare.  If
% with_generic is given and true, also runs GB_mex_mxm_generic and compares
% its result against the GB_spec_mxm result.
C1 = GB_spec_mxm (Cin, M, [ ], semiring, A, B, desc) ;
C2 = GB_mex_mxm  (Cin, M, [ ], semiring, A, B, desc) ;
GB_spec_compare (C1, C2) ;
if (nargin > 6 && with_generic)
    C3 = GB_mex_mxm_generic (Cin, M, [ ], semiring, A, B, desc) ;
    GB_spec_compare (C1, C3) ;
end