!--------------------------------------------------------------------------------------------------!
!   CP2K: A general program to perform molecular dynamics simulations                              !
!   Copyright (C) 2000 - 2019  CP2K developers group                                               !
!--------------------------------------------------------------------------------------------------!

! **************************************************************************************************
!> \brief Basic linear algebra operations for full matrices: a GEMM front end that dispatches
!>        either to cp_fm_gemm (ScaLAPACK PDGEMM) or to dbcsr_multiply.
!> \par History
!>      08.2002 split out of qs_blacs [fawzi]
!> \author Fawzi Mohamed
! **************************************************************************************************
MODULE cp_gemm_interface
   USE cp_dbcsr_operations,             ONLY: copy_dbcsr_to_fm_bc,&
                                              copy_fm_to_dbcsr_bc
   USE cp_fm_basic_linalg,              ONLY: cp_fm_gemm
   USE cp_fm_types,                     ONLY: cp_fm_get_info,&
                                              cp_fm_get_mm_type,&
                                              cp_fm_type
   USE dbcsr_api,                       ONLY: dbcsr_multiply,&
                                              dbcsr_release,&
                                              dbcsr_type
   USE input_constants,                 ONLY: do_dbcsr,&
                                              do_pdgemm
   USE kinds,                           ONLY: dp
   USE message_passing,                 ONLY: mp_min
   USE string_utilities,                ONLY: uppercase
#include "./base/base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'cp_gemm_interface'

   PUBLIC :: cp_gemm

CONTAINS

! **************************************************************************************************
!> \brief GEMM-style product of full matrices,
!>        matrix_c = alpha*op(matrix_a)*op(matrix_b) + beta*matrix_c
!>        (semantics as implemented by cp_fm_gemm / dbcsr_multiply, to which the work is delegated).
!> \param transa transposition flag for matrix_a, forwarded unchanged to the backend;
!>        'T' or 'C' (any case) mark matrix_a as transposed for the distribution check
!> \param transb transposition flag for matrix_b, handled like transa
!> \param m number of rows of the result, forwarded to cp_fm_gemm only
!> \param n number of columns of the result, forwarded to cp_fm_gemm only
!> \param k inner dimension, forwarded to cp_fm_gemm and as last_k to dbcsr_multiply
!> \param alpha scaling factor of the product
!> \param matrix_a first factor
!> \param matrix_b second factor
!> \param beta scaling factor of the previous content of matrix_c
!> \param matrix_c result matrix, updated in place
!> \param a_first_col optional sub-matrix offset, forwarded to cp_fm_gemm
!> \param a_first_row optional sub-matrix offset, forwarded to cp_fm_gemm
!> \param b_first_col optional sub-matrix offset, forwarded to cp_fm_gemm
!> \param b_first_row optional sub-matrix offset, forwarded to cp_fm_gemm
!> \param c_first_col optional sub-matrix offset, forwarded to cp_fm_gemm
!> \param c_first_row optional sub-matrix offset, forwarded to cp_fm_gemm
!> \par Backend selection
!>      The default comes from cp_fm_get_mm_type(). PDGEMM is forced when any *_first_* offset
!>      is present, or when the per-process local row/column counts of op(A), op(B) and C do not
!>      line up. The final choice is made collective via mp_min over matrix_a's process group.
! **************************************************************************************************
   SUBROUTINE cp_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, &
                      matrix_c, a_first_col, a_first_row, b_first_col, b_first_row, &
                      c_first_col, c_first_row)
      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb
      INTEGER, INTENT(IN)                                :: m, n, k
      REAL(KIND=dp), INTENT(IN)                          :: alpha
      TYPE(cp_fm_type), POINTER                          :: matrix_a, matrix_b
      REAL(KIND=dp), INTENT(IN)                          :: beta
      TYPE(cp_fm_type), POINTER                          :: matrix_c
      INTEGER, INTENT(IN), OPTIONAL                      :: a_first_col, a_first_row, b_first_col, &
                                                            b_first_row, c_first_col, c_first_row

      CHARACTER(len=*), PARAMETER                        :: routineN = 'cp_gemm'

      CHARACTER(LEN=1)                                   :: my_trans
      INTEGER                                            :: default_multi, handle, handle1, my_multi
      INTEGER, DIMENSION(:), POINTER                     :: a_col_loc, a_row_loc, b_col_loc, &
                                                            b_row_loc, c_col_loc, c_row_loc
      TYPE(dbcsr_type)                                   :: a_db, b_db, c_db

      CALL timeset(routineN, handle)

      default_multi = cp_fm_get_mm_type()
      my_multi = default_multi

      ! Local row/column counts per process of every operand.
      CALL cp_fm_get_info(matrix_a, nrow_locals=a_row_loc, ncol_locals=a_col_loc)
      CALL cp_fm_get_info(matrix_b, nrow_locals=b_row_loc, ncol_locals=b_col_loc)
      CALL cp_fm_get_info(matrix_c, nrow_locals=c_row_loc, ncol_locals=c_col_loc)

      ! The DBCSR branch below does not forward the sub-matrix offsets, so their presence
      ! forces the cp_fm_gemm (PDGEMM) path.
      IF (PRESENT(a_first_row) .OR. PRESENT(a_first_col) .OR. &
          PRESENT(b_first_row) .OR. PRESENT(b_first_col) .OR. &
          PRESENT(c_first_row) .OR. PRESENT(c_first_col)) my_multi = do_pdgemm

      ! For a transposed operand swap rows and columns, so that the arrays describe op(A)/op(B).
      ! 'C' (conjugate transpose) is a plain transpose for these real matrices.
      my_trans = transa
      CALL uppercase(my_trans)
      IF (my_trans == 'T' .OR. my_trans == 'C') THEN
         CALL cp_fm_get_info(matrix_a, nrow_locals=a_col_loc, ncol_locals=a_row_loc)
      END IF

      my_trans = transb
      CALL uppercase(my_trans)
      IF (my_trans == 'T' .OR. my_trans == 'C') THEN
         CALL cp_fm_get_info(matrix_b, nrow_locals=b_col_loc, ncol_locals=b_row_loc)
      END IF

      ! Catch the special case that matrices have different blockings:
      ! ScaLAPACK can deal with it, but DBCSR doesn't like it.
      IF (my_multi /= do_pdgemm) THEN
         IF (.NOT. (same_local_sizes(a_row_loc, c_row_loc) .AND. &
                    same_local_sizes(b_col_loc, c_col_loc) .AND. &
                    same_local_sizes(a_col_loc, b_row_loc))) my_multi = do_pdgemm
      END IF

      ! IMPORTANT: do_pdgemm is the lowest value. If one process has it set, all of them must
      ! use PDGEMM, which the minimum over the group achieves. When PDGEMM is already the default,
      ! every process has it and no communication is needed.
      IF (default_multi /= do_pdgemm) CALL mp_min(my_multi, matrix_a%matrix_struct%para_env%group)

      SELECT CASE (my_multi)
      CASE (do_pdgemm)
         CALL timeset("cp_gemm_fm_gemm", handle1)
         CALL cp_fm_gemm(transa, transb, m, n, k, alpha, matrix_a, matrix_b, beta, matrix_c, &
                         a_first_col=a_first_col, &
                         a_first_row=a_first_row, &
                         b_first_col=b_first_col, &
                         b_first_row=b_first_row, &
                         c_first_col=c_first_col, &
                         c_first_row=c_first_row)
         CALL timestop(handle1)
      CASE (do_dbcsr)
         CALL timeset("cp_gemm_dbcsr_mm", handle1)
         CALL copy_fm_to_dbcsr_bc(matrix_a, a_db)
         CALL copy_fm_to_dbcsr_bc(matrix_b, b_db)
         CALL copy_fm_to_dbcsr_bc(matrix_c, c_db)

         ! NOTE(review): only k is forwarded here (as last_k). m and n are not, so this path
         ! presumably assumes C is exactly m x n; confirm against callers.
         CALL dbcsr_multiply(transa, transb, alpha, a_db, b_db, beta, c_db, last_k=k)

         CALL copy_dbcsr_to_fm_bc(c_db, matrix_c)
         CALL dbcsr_release(a_db)
         CALL dbcsr_release(b_db)
         CALL dbcsr_release(c_db)
         CALL timestop(handle1)
      END SELECT

      CALL timestop(handle)

   END SUBROUTINE cp_gemm

! **************************************************************************************************
!> \brief Checks whether two per-process local size arrays describe the same distribution.
!> \param sizes_1 local sizes of the first operand
!> \param sizes_2 local sizes of the second operand
!> \return .TRUE. if both arrays have the same length and identical entries
! **************************************************************************************************
   PURE FUNCTION same_local_sizes(sizes_1, sizes_2) RESULT(same)
      INTEGER, DIMENSION(:), INTENT(IN)                  :: sizes_1, sizes_2
      LOGICAL                                            :: same

      ! Fortran's .AND. does not short-circuit, so the length test must explicitly guard ALL(),
      ! which is only defined for conformable arrays.
      IF (SIZE(sizes_1) /= SIZE(sizes_2)) THEN
         same = .FALSE.
      ELSE
         same = ALL(sizes_1 == sizes_2)
      END IF

   END FUNCTION same_local_sizes

END MODULE cp_gemm_interface