1 /*
2 
3     Copyright (C) 2014, The University of Texas at Austin
4 
5     This file is part of libflame and is available under the 3-Clause
6     BSD license, which can be found in the LICENSE file at the top-level
7     directory, or at http://opensource.org/licenses/BSD-3-Clause
8 
9 */
10 
11 #include "FLAME.h"
12 
13 #ifdef FLA_ENABLE_GPU
14 
15 #include "cublas.h"
16 
FLA_Symm_external_gpu(FLA_Side side,FLA_Uplo uplo,FLA_Obj alpha,FLA_Obj A,void * A_gpu,FLA_Obj B,void * B_gpu,FLA_Obj beta,FLA_Obj C,void * C_gpu)17 FLA_Error FLA_Symm_external_gpu( FLA_Side side, FLA_Uplo uplo, FLA_Obj alpha, FLA_Obj A, void* A_gpu, FLA_Obj B, void* B_gpu, FLA_Obj beta, FLA_Obj C, void* C_gpu )
18 {
19   FLA_Datatype datatype;
20   int          m_C, n_C;
21   int          ldim_A;
22   int          ldim_B;
23   int          ldim_C;
24   char         blas_side;
25   char         blas_uplo;
26 
27   if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
28     FLA_Symm_check( side, uplo, alpha, A, B, beta, C );
29 
30   if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;
31 
32   datatype = FLA_Obj_datatype( A );
33 
34   ldim_A   = FLA_Obj_length( A );
35 
36   ldim_B   = FLA_Obj_length( B );
37 
38   m_C      = FLA_Obj_length( C );
39   n_C      = FLA_Obj_width( C );
40   ldim_C   = FLA_Obj_length( C );
41 
42   FLA_Param_map_flame_to_netlib_side( side, &blas_side );
43   FLA_Param_map_flame_to_netlib_uplo( uplo, &blas_uplo );
44 
45 
46   switch( datatype ){
47 
48   case FLA_FLOAT:
49   {
50     float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
51     float *buff_beta  = ( float * ) FLA_FLOAT_PTR( beta );
52 
53     cublasSsymm( blas_side,
54                  blas_uplo,
55                  m_C,
56                  n_C,
57                  *buff_alpha,
58                  ( float * ) A_gpu, ldim_A,
59                  ( float * ) B_gpu, ldim_B,
60                  *buff_beta,
61                  ( float * ) C_gpu, ldim_C );
62 
63     break;
64   }
65 
66   case FLA_DOUBLE:
67   {
68     double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
69     double *buff_beta  = ( double * ) FLA_DOUBLE_PTR( beta );
70 
71     cublasDsymm( blas_side,
72                  blas_uplo,
73                  m_C,
74                  n_C,
75                  *buff_alpha,
76                  ( double * ) A_gpu, ldim_A,
77                  ( double * ) B_gpu, ldim_B,
78                  *buff_beta,
79                  ( double * ) C_gpu, ldim_C );
80 
81     break;
82   }
83 
84   case FLA_COMPLEX:
85   {
86     cuComplex *buff_alpha = ( cuComplex * ) FLA_COMPLEX_PTR( alpha );
87     cuComplex *buff_beta  = ( cuComplex * ) FLA_COMPLEX_PTR( beta );
88 
89     cublasCsymm( blas_side,
90                  blas_uplo,
91                  m_C,
92                  n_C,
93                  *buff_alpha,
94                  ( cuComplex * ) A_gpu, ldim_A,
95                  ( cuComplex * ) B_gpu, ldim_B,
96                  *buff_beta,
97                  ( cuComplex * ) C_gpu, ldim_C );
98 
99     break;
100   }
101 
102   case FLA_DOUBLE_COMPLEX:
103   {
104     cuDoubleComplex *buff_alpha = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
105     cuDoubleComplex *buff_beta  = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );
106 
107     cublasZsymm( blas_side,
108                  blas_uplo,
109                  m_C,
110                  n_C,
111                  *buff_alpha,
112                  ( cuDoubleComplex * ) A_gpu, ldim_A,
113                  ( cuDoubleComplex * ) B_gpu, ldim_B,
114                  *buff_beta,
115                  ( cuDoubleComplex * ) C_gpu, ldim_C );
116 
117     break;
118   }
119 
120   }
121 
122   return FLA_SUCCESS;
123 }
124 
125 #endif
126