function test10
%TEST10 test GrB_apply
%
% Compares GB_mex_apply and GB_mex_apply_maskalias against the MATLAB
% reference GB_spec_apply, using test10_compare, for every unary operator,
% operator type, and input type listed by GB_spec_opsall.  Each combination
% is tried with and without a mask, with and without the 'plus'
% accumulator, with and without transposing the input, and over several
% sparsity/orientation settings of A (and B), C, and the mask.

% SuiteSparse:GraphBLAS, Timothy A. Davis, (c) 2017-2021, All Rights Reserved.
% SPDX-License-Identifier: Apache-2.0

fprintf ('\ntest10: GrB_apply tests\n') ;

[~, unary_ops, ~, types, ~, ~] = GB_spec_opsall ;
types = types.all ;
unary_ops = unary_ops.all ;

rng ('default') ;

m = 8 ;
n = 4 ;
dt = struct ('inp0', 'tran') ;      % descriptor: transpose the input
dr = struct ('outp', 'replace') ;   % descriptor: replace option on the output

for k1 = 1:length(types)
    atype = types {k1} ;
    fprintf ('\n%s: ', atype) ;

    Mask = GB_random_mask (m, n, 0.5, true, false) ;
    Cin = GB_spec_random (m, n, 0.3, 100, atype) ;
    % pattern of Cin; the reference spec uses it as the mask when modeling
    % the aliased C == mask tests done by GB_mex_apply_maskalias
    Cmask = spones (GB_mex_cast (full (Cin.matrix), Cin.class)) ;

    % inputs for most operators: A is m-by-n, B is n-by-m (B is only used
    % with the transpose descriptor dt)
    A = GB_spec_random (m, n, 0.3, 100, atype) ;
    B = GB_spec_random (n, m, 0.3, 100, atype) ;

    A_matrix = A.matrix ;
    B_matrix = B.matrix ;

    % The copies below restrict the input values to the domain of specific
    % operators, so that the spec and the mex results can be compared.

    % for pow, sqrt, log, log10, log2, gammaln, lgamma (domain is [0,inf])
    A_pos_matrix = abs (A_matrix) ;
    B_pos_matrix = abs (B_matrix) ;

    % for asin, acos, atanh (domain is [-1,1]): entries with |x| > 1 become 1
    A_1_matrix = A_matrix ;
    B_1_matrix = B_matrix ;
    A_1_matrix (abs (A_matrix) > 1) = 1 ;
    B_1_matrix (abs (B_matrix) > 1) = 1 ;

    % for acosh, asech (domain is [1,inf]): nonzero entries below 1 become 1
    % (zeros are left unchanged)
    A_1inf_matrix = A_matrix ;
    B_1inf_matrix = B_matrix ;
    A_1inf_matrix (A_matrix < 1 & A_matrix ~= 0) = 1 ;
    B_1inf_matrix (B_matrix < 1 & B_matrix ~= 0) = 1 ;

    % for log1p (domain is [-1,inf]): entries below -1 become 1
    A_n1inf_matrix = A_matrix ;
    B_n1inf_matrix = B_matrix ;
    A_n1inf_matrix (A_matrix < -1) = 1 ;
    B_n1inf_matrix (B_matrix < -1) = 1 ;

    % for tanh and similar operators: the domain is [-inf,inf], but rounding
    % the result to an integer type fails when |x| is large, so entries
    % with |x| > 5 become 5
    A_5_matrix = A_matrix ;
    B_5_matrix = B_matrix ;
    A_5_matrix (abs (A_matrix) > 5) = 5 ;
    B_5_matrix (abs (B_matrix) > 5) = 5 ;

    % for gamma: the domain is [-inf,inf], but gamma is not defined for the
    % negative integers, and rounding to integers fails for large x, so
    % nonzero entries are clamped to the range [0.1, 5] (zeros unchanged)
    A_pos5_matrix = A_matrix ;
    B_pos5_matrix = B_matrix ;
    A_pos5_matrix (A_matrix <= 0.1 & A_matrix ~= 0) = 0.1 ;
    B_pos5_matrix (B_matrix <= 0.1 & B_matrix ~= 0) = 0.1 ;
    A_pos5_matrix (A_matrix > 5) = 5 ;
    B_pos5_matrix (B_matrix > 5) = 5 ;

    % do longer tests for a few types: these also try hypersparse formats
    % and both orientations (csc and csr)
    longer_tests = isequal (atype, 'double') || isequal (atype, 'int64') ;
    if (longer_tests)
        hrange = [0 1] ;
        crange = [0 1] ;
    else
        hrange = 0 ;
        crange = 1 ;
    end

    for k2 = 1:length(unary_ops)
        op.opname = unary_ops {k2} ;
        if (longer_tests)
            fprintf ('\n') ;
        end
        fprintf (' %s', op.opname) ;

        for k3 = 1:length(types)
            op.optype = types {k3} ;

            if (ispc && contains (op.opname, 'asin') && contains (op.optype, 'complex'))
                % casin and casinf are broken on Windows
                fprintf (' (skipped)') ;
                continue ;
            end

            % skip operator/type combinations rejected by GB_spec_operator
            try
                [opname, optype, ~, ~, ~] = GB_spec_operator (op) ;
            catch
                continue
            end
            fprintf ('.') ;

            A.matrix = A_matrix ;
            B.matrix = B_matrix ;

            switch (opname)
                % domain is ok, but limit it to avoid integer typecast
                % failures from O(eps) errors, or overflow to inf
                case { 'tanh', 'exp', 'sin', 'cos', 'tan', ...
                    'sinh', 'cosh', 'asin', 'acos', 'acosh', 'asinh', ...
                    'atanh', 'exp2', 'expm1', 'carg', 'atan' }
                    A.matrix = A_5_matrix ;
                    B.matrix = B_5_matrix ;
                case { 'tgamma' }
                    A.matrix = A_pos5_matrix ;
                    B.matrix = B_pos5_matrix ;
                otherwise
                    % no change
            end

            if (~contains (optype, 'complex'))
                % for real operators, avoid complex results by restricting
                % the input to the real domain of the operator.  tanh and
                % exp need no case here: the switch above already limited
                % their inputs to |x| <= 5.
                switch (opname)
                    case { 'pow', 'sqrt', 'log', 'log10', 'log2', ...
                        'gammaln', 'lgamma' }
                        A.matrix = A_pos_matrix ;
                        B.matrix = B_pos_matrix ;
                    case { 'asin', 'acos', 'atanh' }
                        A.matrix = A_1_matrix ;
                        B.matrix = B_1_matrix ;
                    case { 'acosh', 'asech' }
                        A.matrix = A_1inf_matrix ;
                        B.matrix = B_1inf_matrix ;
                    case 'log1p'
                        A.matrix = A_n1inf_matrix ;
                        B.matrix = B_n1inf_matrix ;
                    otherwise
                        % no change
                end
            end

            % single precision is compared with a looser tolerance than
            % double; for all other types tol stays 0
            tol = 0 ;
            if (contains (optype, 'single') || contains (atype, 'single'))
                tol = 1e-5 ;
            elseif (contains (optype, 'double') || contains (atype, 'double'))
                tol = 1e-12 ;
            end

            for A_sparsity = [hrange 2]

                % A_sparsity: 0 = sparse, 1 = hypersparse, 2 = bitmap.
                % B always uses the same format as A.
                if (A_sparsity == 0)
                    A_is_hyper = 0 ;
                    A_sparsity_control = 2 ;    % sparse
                elseif (A_sparsity == 1)
                    A_is_hyper = 1 ;
                    A_sparsity_control = 1 ;    % hypersparse
                else
                    A_is_hyper = 0 ;
                    A_sparsity_control = 4 ;    % bitmap
                end

                for A_is_csc = crange

                    if (longer_tests)
                        fprintf ('.') ;
                    end

                    % all combinations of the format of C and of the Mask
                    for C_is_hyper = hrange
                    for C_is_csc   = crange
                    for M_is_hyper = hrange
                    for M_is_csc   = crange

                        A.is_csc    = A_is_csc ; A.is_hyper    = A_is_hyper ;
                        B.is_csc    = A_is_csc ; B.is_hyper    = A_is_hyper ;
                        Cin.is_csc  = C_is_csc ; Cin.is_hyper  = C_is_hyper ;
                        Mask.is_csc = M_is_csc ; Mask.is_hyper = M_is_hyper ;

                        A.sparsity = A_sparsity_control ;
                        B.sparsity = A_sparsity_control ;

                        % no mask
                        C1 = GB_spec_apply (Cin, [], [], op, A, []) ;
                        C2 = GB_mex_apply  (Cin, [], [], op, A, []) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with mask
                        C1 = GB_spec_apply (Cin, Mask, [], op, A, []) ;
                        C2 = GB_mex_apply  (Cin, Mask, [], op, A, []) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with C == mask, and outp = replace
                        C1 = GB_spec_apply (Cin, Cmask, [], op, A, dr) ;
                        C2 = GB_mex_apply_maskalias (Cin, [], op, A, dr) ;
                        test10_compare (op, C1, C2, tol) ;

                        % no mask, transpose
                        C1 = GB_spec_apply (Cin, [], [], op, B, dt) ;
                        C2 = GB_mex_apply  (Cin, [], [], op, B, dt) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with mask, transpose
                        C1 = GB_spec_apply (Cin, Mask, [], op, B, dt) ;
                        C2 = GB_mex_apply  (Cin, Mask, [], op, B, dt) ;
                        test10_compare (op, C1, C2, tol) ;

                        switch (opname)
                            % the results from these operators must be
                            % checked before they are summed with the accum
                            % operator, so skip the remaining (accum) tests.
                            % continue applies to the innermost for loop.
                            case { 'acos', 'asin', 'atan', 'acosh', 'asinh', 'atanh' }
                                continue ;
                        end

                        % no mask, with accum
                        C1 = GB_spec_apply (Cin, [], 'plus', op, A, []) ;
                        C2 = GB_mex_apply  (Cin, [], 'plus', op, A, []) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with mask and accum
                        C1 = GB_spec_apply (Cin, Mask, 'plus', op, A, []) ;
                        C2 = GB_mex_apply  (Cin, Mask, 'plus', op, A, []) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with C == mask and accum, and outp = replace
                        C1 = GB_spec_apply (Cin, Cmask, 'plus', op, A, dr) ;
                        C2 = GB_mex_apply_maskalias (Cin, 'plus', op, A, dr) ;
                        test10_compare (op, C1, C2, tol) ;

                        % no mask, with accum, transpose
                        C1 = GB_spec_apply (Cin, [], 'plus', op, B, dt) ;
                        C2 = GB_mex_apply  (Cin, [], 'plus', op, B, dt) ;
                        test10_compare (op, C1, C2, tol) ;

                        % with mask and accum, transpose
                        C1 = GB_spec_apply (Cin, Mask, 'plus', op, B, dt) ;
                        C2 = GB_mex_apply  (Cin, Mask, 'plus', op, B, dt) ;
                        test10_compare (op, C1, C2, tol) ;

                    end
                    end
                    end
                    end
                end
            end
        end
    end
    fprintf ('\n') ;

end

fprintf ('\ntest10: all tests passed\n') ;