1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
/* Selectors for the `type` argument of time_Chol_u(). */
#define FLA_ALG_REFERENCE 0   /* tested only in the residual check; the reference
                                 routine itself is chosen by variant == 0 */
#define FLA_ALG_BLOCKED   1   /* FLA_Chol_u_blk_varN, driven by a control tree */
#define FLA_ALG_UNBLOCKED 2   /* FLA_Chol_u_unb_varN */
#define FLA_ALG_UNB_OPT   3   /* FLA_Chol_u_opt_varN */
17 
18 
/* Times one implementation of the upper Cholesky factorization of A and
   reports the best time, the residual norm of a solve with the computed
   factor, and the GFLOPS rate.  Defined below. */
void time_Chol_u(
                  int variant, int type, int nrepeats, int n, int nb_alg,
                  FLA_Obj A, FLA_Obj b, FLA_Obj b_orig, FLA_Obj norm,
                  double *dtime, double *diff, double *gflops );
23 
24 
/*
 * time_Chol_u
 *
 * Times one implementation of the Cholesky factorization A = U^H U (upper
 * triangle) and checks the computed factor by solving A x = b with it.
 *
 * variant  - 0 selects REF_Chol_u(); 1, 2, 3 select FLAME variant 1, 2, 3.
 *            Any other value factors nothing and prints "trouble".
 * type     - for variants 1..3: FLA_ALG_UNBLOCKED (FLA_Chol_u_unb_varN),
 *            FLA_ALG_UNB_OPT (FLA_Chol_u_opt_varN) or FLA_ALG_BLOCKED
 *            (FLA_Chol_u_blk_varN); any other value prints "trouble".
 *            Ignored for variant 0.
 * nrepeats - number of timed runs; the minimum time is reported.
 * n        - unused; the flop count is taken from FLA_Obj_length( A ).
 * nb_alg   - block size used to build the blocked variants' control tree.
 * A        - matrix to factor.  Assumed Hermitian positive definite with its
 *            upper triangle referenced (TODO confirm with the caller).
 *            Restored to its input value on return.
 * b        - right-hand side; overwritten by the solve, restored on return.
 * b_orig   - presumably a copy of b (verify against the caller); used as the
 *            residual workspace and restored on return.
 * norm     - 1x1 object that receives the residual 2-norm.
 * dtime    - out: fastest of the nrepeats timings, as measured by FLA_Clock().
 * diff     - out: || A_input * x - b_orig ||_2, where x solves U^H U x = b.
 * gflops   - out: (m^3 / 3) / dtime / 1e9 with m = length of A, times 4 for
 *            complex data.
 */
void time_Chol_u(
                  int variant, int type, int nrepeats, int n, int nb_alg,
                  FLA_Obj A, FLA_Obj b, FLA_Obj b_orig, FLA_Obj norm,
                  double *dtime, double *diff, double *gflops )
{
  int
    irep;

  double
    dtime_save = 1.0e9,
    m_A;

  FLA_Obj
    A_save, b_save, b_orig_save;

  fla_blocksize_t*
    bp;
  fla_chol_t*
    cntl_chol_var;
  fla_chol_t*
    cntl_chol_unb;
  fla_syrk_t*
    cntl_syrk_blas;
  fla_herk_t*
    cntl_herk_blas;
  fla_trsm_t*
    cntl_trsm_blas;
  fla_gemm_t*
    cntl_gemm_blas;

  /* Control tree for the blocked variants.  Diagonal blocks are factored with
     FLA_UNB_OPT_VARIANT2.  The remaining suboperations are FLA_SUBPROBLEM
     leaves, presumably dispatched to the external BLAS (hence the *_blas
     names). */
  bp               = FLA_Blocksize_create( nb_alg, nb_alg, nb_alg, nb_alg );
  cntl_chol_unb    = FLA_Cntl_chol_obj_create( FLA_FLAT, FLA_UNB_OPT_VARIANT2, NULL, NULL, NULL, NULL, NULL, NULL );
  cntl_syrk_blas   = FLA_Cntl_syrk_obj_create( FLA_FLAT, FLA_SUBPROBLEM, NULL, NULL, NULL );
  cntl_herk_blas   = FLA_Cntl_herk_obj_create( FLA_FLAT, FLA_SUBPROBLEM, NULL, NULL, NULL );
  cntl_trsm_blas   = FLA_Cntl_trsm_obj_create( FLA_FLAT, FLA_SUBPROBLEM, NULL, NULL, NULL );
  cntl_gemm_blas   = FLA_Cntl_gemm_obj_create( FLA_FLAT, FLA_SUBPROBLEM, NULL, NULL );
  cntl_chol_var    = FLA_Cntl_chol_obj_create( FLA_FLAT, variant, bp,
                                               cntl_chol_unb,
                                               cntl_syrk_blas,
                                               cntl_herk_blas,
                                               cntl_trsm_blas,
                                               cntl_gemm_blas );

  /* Save the inputs so every repetition starts from the same A and the caller
     gets its objects back unchanged. */
  FLA_Obj_create_conf_to( FLA_NO_TRANSPOSE, A, &A_save );
  FLA_Obj_create_conf_to( FLA_NO_TRANSPOSE, b, &b_save );
  FLA_Obj_create_conf_to( FLA_NO_TRANSPOSE, b_orig, &b_orig_save );

  FLA_Copy_external( A, A_save );
  FLA_Copy_external( b, b_save );
  FLA_Copy_external( b_orig, b_orig_save );

  for ( irep = 0 ; irep < nrepeats; irep++ ){

    /* The factorization is in place, so restore A before each run. */
    FLA_Copy_external( A_save, A );

    *dtime = FLA_Clock();

    switch( variant ){

    case 0:
      REF_Chol_u( A );
      break;

    case 1:
      switch( type ){
      case FLA_ALG_UNBLOCKED:
        FLA_Chol_u_unb_var1( A );
        break;
      case FLA_ALG_UNB_OPT:
        FLA_Chol_u_opt_var1( A );
        break;
      case FLA_ALG_BLOCKED:
        FLA_Chol_u_blk_var1( A, cntl_chol_var );
        break;
      default:
        printf("trouble\n");
      }
      break;

    case 2:
      switch( type ){
      case FLA_ALG_UNBLOCKED:
        FLA_Chol_u_unb_var2( A );
        break;
      case FLA_ALG_UNB_OPT:
        FLA_Chol_u_opt_var2( A );
        break;
      case FLA_ALG_BLOCKED:
        FLA_Chol_u_blk_var2( A, cntl_chol_var );
        break;
      default:
        printf("trouble\n");
      }
      break;

    case 3:
      switch( type ){
      case FLA_ALG_UNBLOCKED:
        FLA_Chol_u_unb_var3( A );
        break;
      case FLA_ALG_UNB_OPT:
        FLA_Chol_u_opt_var3( A );
        break;
      case FLA_ALG_BLOCKED:
        FLA_Chol_u_blk_var3( A, cntl_chol_var );
        break;
      default:
        printf("trouble\n");
      }
      break;

    default:
      /* Unsupported variant: nothing is factored, so the timing and the
         residual below are meaningless.  Report it like the inner switches. */
      printf("trouble\n");
    }

    *dtime = FLA_Clock() - *dtime;
    dtime_save = min( *dtime, dtime_save );
  }

  FLA_Cntl_obj_free( cntl_chol_var );
  FLA_Cntl_obj_free( cntl_chol_unb );
  FLA_Cntl_obj_free( cntl_syrk_blas );
  FLA_Cntl_obj_free( cntl_herk_blas );
  FLA_Cntl_obj_free( cntl_trsm_blas );
  FLA_Cntl_obj_free( cntl_gemm_blas );
  FLA_Blocksize_free( bp );

  /* Check the factor by solving A x = b with A = U^H U.  First solve U^H y = b,
     then solve U x = y.  The Cholesky factor U has a non-unit diagonal, so both
     triangular solves must use FLA_NONUNIT_DIAG.  The same check applies to
     the reference and the FLAME implementations. */
  FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_CONJ_TRANSPOSE,
                     FLA_NONUNIT_DIAG, A, b );
  FLA_Trsv_external( FLA_UPPER_TRIANGULAR, FLA_NO_TRANSPOSE,
                     FLA_NONUNIT_DIAG, A, b );

  /* b_orig := A_save * x - b_orig, the residual of the solve. */
  FLA_Hemv_external( FLA_UPPER_TRIANGULAR,
                     FLA_ONE, A_save, b, FLA_MINUS_ONE, b_orig );

  FLA_Nrm2_external( b_orig, norm );
  FLA_Copy_object_to_buffer( FLA_NO_TRANSPOSE, 0, 0, norm,
                             1, 1, diff, 1, 1 );

  /* The Cholesky factorization costs about m^3/3 real flops.  Complex data is
     counted as 4x the real flop count. */
  m_A     = ( double ) FLA_Obj_length( A );
  *gflops = 1.0 / 3.0 * m_A * m_A * m_A / dtime_save / 1e9;

  if ( FLA_Obj_is_complex( A ) )
    *gflops *= 4.0;

  *dtime = dtime_save;

  /* Give the caller back its original objects. */
  FLA_Copy_external( A_save, A );
  FLA_Copy_external( b_save, b );
  FLA_Copy_external( b_orig_save, b_orig );

  FLA_Obj_free( &A_save );
  FLA_Obj_free( &b_save );
  FLA_Obj_free( &b_orig_save );
}
219 
220