function test107 (fulltest)
%TEST107 user-defined terminal monoid
%
% Usage:
%   test107          % quick test (the default): n = 1000, few thread counts
%   test107 (true)   % full test: n = 6000 and a longer list of thread counts
%
% Reduces a sparse matrix to a scalar with the built-in 'max' operator
% (GB_mex_reduce_to_scalar) and with user-defined terminal monoids
% (GB_mex_reduce_terminal), asserts that both agree with MATLAB's max, and
% prints timings for several thread counts.  Later sections overwrite
% A(n,1) with 1, inf, 2, and nan in turn, and a final section checks the
% built-in 'plus' reduction against MATLAB's sum.

% SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
% SPDX-License-Identifier: Apache-2.0

fprintf ('test107: reduce with built-in and  user-defined terminal monoids\n') ;

rng ('default') ;

% The thread count is changed below; remember it so it can be restored at
% the end.  (Not named "save", which would shadow the built-in function.)
nthreads_save = nthreads_get ;
if (nargin < 1)
    fulltest = false ;
end
nthreads_max = GB_mex_omp_max_threads ;
if (fulltest)
    nthreads_list = [1 2 4 8 16 20 40 64 160] ;
else
    nthreads_list = [1 nthreads_max 2*nthreads_max] ;
end

% create a matrix with entries [0..2]
A = 2 * sparse (rand (4)) ;
s = full (max (max (A))) ;

% NOTE(review): the second argument of GB_mex_reduce_terminal appears to
% select the terminal value of the user-defined monoid -- confirm in the mexFunction
c = GB_mex_reduce_terminal (A, 2) ;
assert (c == s) ;

% now add the terminal value somewhere
A (1,2) = 2 ;
s = full (max (max (A))) ;
c = GB_mex_reduce_terminal (A, 2) ;
assert (c == s) ;
clear A

ntrials = 10 ;

%-------------------------------------------------------------------------------
% big matrix ...
fprintf ('\nbig matrix, no early exit\n') ;
if (fulltest)
    n = 6000 ;
else
    n = 1000 ;
end
A = sparse (rand (n)) ;

tic
for trial = 1:ntrials
    s = full (max (max (A))) ;
end
tm = toc ;
fprintf ('MATLAB max: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    % unlike the sections below, this one also runs with 2*nthreads_max
    if (nthreads > 2*nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        c1 = GB_mex_reduce_to_scalar (0, [ ], 'max', A) ;
    end
    tg = toc ;
    assert (s == c1) ;
    % nthreads_list always starts with 1, so t1 is set before it is used
    if (nthreads == 1)
        t1 = tg ;
    end

    fprintf ('nthreads %3d built-in      %g  speedup %g\n', nthreads, tg, t1/tg) ;

    tic
    for trial = 1:ntrials
        c2 = GB_mex_reduce_terminal (A, 1) ;    % user-defined
    end
    t2 = toc ;
    fprintf ('nthreads %3d %g\n', nthreads, t2) ;
    assert (s == c2) ;

    tic
    for trial = 1:ntrials
        c3 = GB_mex_reduce_terminal (A, 2) ;    % user-defined
    end
    t3 = toc ;
    fprintf ('nthreads %3d %g\n', nthreads, t3) ;
    assert (s == c3) ;

end

%-------------------------------------------------------------------------------
fprintf ('\nbig matrix, with early exit\n') ;

A (n,1) = 1 ;

tic
for trial = 1:ntrials
    s = full (max (max (A))) ;
end
tm = toc ;
fprintf ('MATLAB max: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    if (nthreads > nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        c1 = GB_mex_reduce_to_scalar (0, [ ], 'max', A) ;
    end
    t1 = toc ;
    fprintf ('nthreads %3d built-in      %g\n', nthreads, t1) ;
    tic
    for trial = 1:ntrials
        c2 = GB_mex_reduce_terminal (A, 1) ;    % user-defined
    end
    t2 = toc ;
    fprintf ('nthreads %3d %g\n', nthreads, t2) ;
    assert (s == c1) ;
    assert (s == c2) ;
end

%-------------------------------------------------------------------------------
fprintf ('\nbig matrix, with inf \n') ;

A (n,1) = inf ;

tic
for trial = 1:ntrials
    s = full (max (max (A))) ;
end
tm = toc ;
fprintf ('MATLAB max: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    if (nthreads > nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        c1 = GB_mex_reduce_to_scalar (0, [ ], 'max', A) ;
    end
    t1 = toc ;
    fprintf ('nthreads %3d built-in      %g\n', nthreads, t1) ;
    tic
    for trial = 1:ntrials
        c2 = GB_mex_reduce_terminal (A, inf) ;
    end
    t2 = toc ;
    fprintf ('nthreads %3d %g\n', nthreads, t2) ;
    assert (s == c1) ;
    assert (s == c2) ;
end

%-------------------------------------------------------------------------------
fprintf ('\nbig matrix, with 2 \n') ;

A (n,1) = 2 ;

tic
for trial = 1:ntrials
    s = full (max (max (A))) ;
end
tm = toc ;
fprintf ('MATLAB max: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    if (nthreads > nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        c1 = GB_mex_reduce_to_scalar (0, [ ], 'max', A) ;
    end
    t1 = toc ;
    fprintf ('nthreads %3d built-in      %g\n', nthreads, t1) ;
    tic
    for trial = 1:ntrials
        c2 = GB_mex_reduce_terminal (A, 2) ;
    end
    t2 = toc ;
    fprintf ('nthreads %3d %g\n', nthreads, t2) ;
    assert (s == c1) ;
    assert (s == c2) ;
end

%-------------------------------------------------------------------------------
% Only the built-in max is checked here.  MATLAB's max ignores NaN, so s is
% the largest non-NaN entry; the assert requires GraphBLAS to agree.
fprintf ('\nbig matrix, with nan\n') ;

A (n,1) = nan ;

tic
for trial = 1:ntrials
    s = full (max (max (A))) ;
end
tm = toc ;
fprintf ('MATLAB max: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    if (nthreads > nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        c1 = GB_mex_reduce_to_scalar (0, [ ], 'max', A) ;
    end
    t1 = toc ;
    fprintf ('nthreads %3d built-in      %g\n', nthreads, t1) ;
    assert (s == c1) ;
end

%-------------------------------------------------------------------------------
fprintf ('\nsum\n') ;

A (n,1) = 1 ;

tic
for trial = 1:ntrials
    ss = full (sum (sum (A))) ;
end
tm = toc ;
fprintf ('MATLAB sum: %g\n', tm) ;
for nthreads = nthreads_list
    fprintf ('\n') ;
    if (nthreads > nthreads_max)
        break ;
    end
    nthreads_set (nthreads) ;
    tic
    for trial = 1:ntrials
        cc = GB_mex_reduce_to_scalar (0, [ ], 'plus', A) ;
    end
    t1 = toc ;
    fprintf ('nthreads %3d built-in      %g\n', nthreads, t1) ;
    % Check the sum for every thread count.  A relative tolerance is used
    % rather than equality, to allow for roundoff differences between the
    % two summations.
    err = abs (ss - cc) / ss ;
    assert (err < 1e-12) ;
end

% display the relative error from the last thread count
err

nthreads_set (nthreads_save) ;
fprintf ('test107: all tests passed\n') ;
248
249