1 /*
2
3 Copyright (C) 2014, The University of Texas at Austin
4
5 This file is part of libflame and is available under the 3-Clause
6 BSD license, which can be found in the LICENSE file at the top-level
7 directory, or at http://opensource.org/licenses/BSD-3-Clause
8
9 */
10
11 #include "FLAME.h"
12
13 #ifdef FLA_ENABLE_GPU
14
15 #include "cublas.h"
16
FLA_Symm_external_gpu(FLA_Side side,FLA_Uplo uplo,FLA_Obj alpha,FLA_Obj A,void * A_gpu,FLA_Obj B,void * B_gpu,FLA_Obj beta,FLA_Obj C,void * C_gpu)17 FLA_Error FLA_Symm_external_gpu( FLA_Side side, FLA_Uplo uplo, FLA_Obj alpha, FLA_Obj A, void* A_gpu, FLA_Obj B, void* B_gpu, FLA_Obj beta, FLA_Obj C, void* C_gpu )
18 {
19 FLA_Datatype datatype;
20 int m_C, n_C;
21 int ldim_A;
22 int ldim_B;
23 int ldim_C;
24 char blas_side;
25 char blas_uplo;
26
27 if ( FLA_Check_error_level() == FLA_FULL_ERROR_CHECKING )
28 FLA_Symm_check( side, uplo, alpha, A, B, beta, C );
29
30 if ( FLA_Obj_has_zero_dim( C ) ) return FLA_SUCCESS;
31
32 datatype = FLA_Obj_datatype( A );
33
34 ldim_A = FLA_Obj_length( A );
35
36 ldim_B = FLA_Obj_length( B );
37
38 m_C = FLA_Obj_length( C );
39 n_C = FLA_Obj_width( C );
40 ldim_C = FLA_Obj_length( C );
41
42 FLA_Param_map_flame_to_netlib_side( side, &blas_side );
43 FLA_Param_map_flame_to_netlib_uplo( uplo, &blas_uplo );
44
45
46 switch( datatype ){
47
48 case FLA_FLOAT:
49 {
50 float *buff_alpha = ( float * ) FLA_FLOAT_PTR( alpha );
51 float *buff_beta = ( float * ) FLA_FLOAT_PTR( beta );
52
53 cublasSsymm( blas_side,
54 blas_uplo,
55 m_C,
56 n_C,
57 *buff_alpha,
58 ( float * ) A_gpu, ldim_A,
59 ( float * ) B_gpu, ldim_B,
60 *buff_beta,
61 ( float * ) C_gpu, ldim_C );
62
63 break;
64 }
65
66 case FLA_DOUBLE:
67 {
68 double *buff_alpha = ( double * ) FLA_DOUBLE_PTR( alpha );
69 double *buff_beta = ( double * ) FLA_DOUBLE_PTR( beta );
70
71 cublasDsymm( blas_side,
72 blas_uplo,
73 m_C,
74 n_C,
75 *buff_alpha,
76 ( double * ) A_gpu, ldim_A,
77 ( double * ) B_gpu, ldim_B,
78 *buff_beta,
79 ( double * ) C_gpu, ldim_C );
80
81 break;
82 }
83
84 case FLA_COMPLEX:
85 {
86 cuComplex *buff_alpha = ( cuComplex * ) FLA_COMPLEX_PTR( alpha );
87 cuComplex *buff_beta = ( cuComplex * ) FLA_COMPLEX_PTR( beta );
88
89 cublasCsymm( blas_side,
90 blas_uplo,
91 m_C,
92 n_C,
93 *buff_alpha,
94 ( cuComplex * ) A_gpu, ldim_A,
95 ( cuComplex * ) B_gpu, ldim_B,
96 *buff_beta,
97 ( cuComplex * ) C_gpu, ldim_C );
98
99 break;
100 }
101
102 case FLA_DOUBLE_COMPLEX:
103 {
104 cuDoubleComplex *buff_alpha = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( alpha );
105 cuDoubleComplex *buff_beta = ( cuDoubleComplex * ) FLA_DOUBLE_COMPLEX_PTR( beta );
106
107 cublasZsymm( blas_side,
108 blas_uplo,
109 m_C,
110 n_C,
111 *buff_alpha,
112 ( cuDoubleComplex * ) A_gpu, ldim_A,
113 ( cuDoubleComplex * ) B_gpu, ldim_B,
114 *buff_beta,
115 ( cuDoubleComplex * ) C_gpu, ldim_C );
116
117 break;
118 }
119
120 }
121
122 return FLA_SUCCESS;
123 }
124
125 #endif
126