1!--------------------------------------------------------------------------------------------------! 2! Copyright (C) by the DBCSR developers group - All rights reserved ! 3! This file is part of the DBCSR library. ! 4! ! 5! For information on the license, see the LICENSE file. ! 6! For further information please visit https://dbcsr.cp2k.org ! 7! SPDX-License-Identifier: GPL-2.0+ ! 8!--------------------------------------------------------------------------------------------------! 9 10MODULE dbcsr_tas_mm 11 !! Matrix multiplication for tall-and-skinny matrices. This uses the k-split (non-recursive) CARMA 12 !! algorithm that is communication-optimal as long as the two smaller dimensions have 13 !! the same size. 14 !! Submatrices are obtained by splitting a dimension of the process grid. Multiplication of 15 !! submatrices uses DBCSR Cannon algorithm. Due to unknown sparsity pattern of result matrix, parameters 16 !! (group sizes and process grid dimensions) can not be derived from matrix dimensions and need to be 17 !! set manually. 

#:include "dbcsr_tas.fypp"

   USE dbcsr_data_methods, ONLY: &
      dbcsr_scalar_zero, dbcsr_scalar
   USE dbcsr_data_types, ONLY: &
      dbcsr_scalar_type, dbcsr_type_real_8, dbcsr_type_real_4, dbcsr_type_complex_8, dbcsr_type_complex_4
   USE dbcsr_multiply_api, ONLY: dbcsr_multiply
   USE dbcsr_tas_base, ONLY: &
      dbcsr_tas_create, dbcsr_tas_destroy, dbcsr_tas_distribution_destroy, dbcsr_tas_distribution_new, &
      dbcsr_tas_get_data_type, dbcsr_tas_info, dbcsr_tas_nblkcols_total, &
      dbcsr_tas_nblkrows_total, dbcsr_tas_filter, dbcsr_tas_get_info, dbcsr_tas_iterator_blocks_left, &
      dbcsr_tas_get_nze_total, dbcsr_tas_reserve_blocks, dbcsr_tas_iterator_start, dbcsr_tas_iterator_next_block, &
      dbcsr_tas_iterator_stop, dbcsr_tas_copy, dbcsr_tas_get_block_p, dbcsr_tas_clear, dbcsr_tas_get_num_blocks
   USE dbcsr_tas_types, ONLY: &
      dbcsr_tas_distribution_type, dbcsr_tas_split_info, dbcsr_tas_type, dbcsr_tas_iterator
   USE dbcsr_tas_global, ONLY: &
      dbcsr_tas_dist_cyclic, dbcsr_tas_dist_arb, dbcsr_tas_distribution, dbcsr_tas_dist_arb_default, &
      dbcsr_tas_rowcol_data, dbcsr_tas_blk_size_one, dbcsr_tas_default_distvec
   USE dbcsr_tas_reshape_ops, ONLY: &
      dbcsr_tas_merge, dbcsr_tas_replicate, dbcsr_tas_reshape
   USE dbcsr_tas_split, ONLY: &
      rowsplit, colsplit, dbcsr_tas_get_split_info, dbcsr_tas_create_split, dbcsr_tas_mp_comm, &
      dbcsr_tas_release_info, accept_pgrid_dims, dbcsr_tas_info_hold, default_nsplit_accept_ratio_1
   USE dbcsr_tas_util, ONLY: &
      swap, invert_transpose_flag, array_eq, dbcsr_mp_environ
   USE dbcsr_types, ONLY: &
      dbcsr_no_transpose, dbcsr_transpose, dbcsr_type, dbcsr_distribution_obj, dbcsr_mp_obj, &
      dbcsr_type_no_symmetry
   USE dbcsr_kinds, ONLY: &
      int_8, real_8, real_4, default_string_length
   USE dbcsr_mpiwrap, ONLY: &
      mp_comm_compare, mp_environ, mp_sum, mp_comm_free, mp_cart_create
   USE dbcsr_operations, ONLY: &
      dbcsr_scale, dbcsr_get_info, dbcsr_copy, dbcsr_clear, dbcsr_add
   USE dbcsr_tas_io, ONLY: &
      dbcsr_tas_write_dist, dbcsr_tas_write_matrix_info, dbcsr_tas_write_split_info, prep_output_unit
   USE dbcsr_work_operations, ONLY: dbcsr_create, dbcsr_finalize
   USE dbcsr_transformations, ONLY: dbcsr_redistribute
   USE dbcsr_dist_methods, ONLY: dbcsr_distribution_new
   USE dbcsr_methods, ONLY: &
      dbcsr_mp_release, dbcsr_release, dbcsr_distribution_release, dbcsr_get_nze
#include "base/dbcsr_base_uses.f90"

   IMPLICIT NONE
   PRIVATE

   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_tas_mm'

   PUBLIC :: &
      dbcsr_tas_multiply, &
      dbcsr_tas_batched_mm_init, &
      dbcsr_tas_batched_mm_finalize, &
      dbcsr_tas_result_index

CONTAINS

   RECURSIVE SUBROUTINE dbcsr_tas_multiply(transa, transb, transc, alpha, matrix_a, matrix_b, beta, matrix_c, &
                                           optimize_dist, split_opt, filter_eps, flop, move_data_a, &
                                           move_data_b, retain_sparsity, simple_split, result_index, unit_nr, log_verbose)
      !! tall-and-skinny matrix-matrix multiplication. Undocumented dummy arguments are identical to
      !! arguments of dbcsr_multiply (see dbcsr_mm, dbcsr_multiply_generic).
      !!
      !! Overview of the algorithm (k-split, non-recursive CARMA):
      !! 1) determine the largest of the three multiplication dimensions (m, k, n) and a split factor;
      !! 2) reshape the two matrices sharing the largest dimension so that this dimension is split
      !!    over process groups, and replicate the third (smallest) matrix over the groups;
      !! 3) perform an ordinary DBCSR multiplication per group (optionally on an optimized sub-grid);
      !! 4) merge/reshape the result back into the layout of matrix_c.

      CHARACTER(LEN=1), INTENT(IN)                       :: transa, transb, transc
      TYPE(dbcsr_scalar_type), INTENT(IN)                :: alpha, beta
      TYPE(dbcsr_tas_type), TARGET, &
         INTENT(INOUT)                                   :: matrix_a, matrix_b, matrix_c
      LOGICAL, INTENT(IN), OPTIONAL                      :: optimize_dist
         !! Whether distribution should be optimized internally. In the current implementation this guarantees optimal parameters
         !! only for dense matrices.
      TYPE(dbcsr_tas_split_info), INTENT(OUT), &
         POINTER, OPTIONAL                               :: split_opt
         !! optionally return split info containing optimal grid and split parameters. This can be used to choose optimal process
         !! grids for subsequent matrix multiplications with matrices of similar shape and sparsity. Under some conditions,
         !! split_opt cannot be returned, in this case the pointer is not associated
      REAL(KIND=real_8), INTENT(IN), OPTIONAL            :: filter_eps
      INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL         :: flop
      LOGICAL, INTENT(IN), OPTIONAL                      :: move_data_a, move_data_b, simple_split, retain_sparsity
         !! memory optimization: move data to matrix_c such that matrix_a is empty on return
         !! memory optimization: move data to matrix_c such that matrix_b is empty on return
         !! for internal use only
      INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(OUT), OPTIONAL :: result_index
         !! if present (on the default, non-simple-split path): only the block index of the result
         !! matrix is computed via dbcsr_tas_result_index and the routine returns early without multiplying
      INTEGER, OPTIONAL, INTENT(IN)                      :: unit_nr
         !! unit number for logging output
      LOGICAL, OPTIONAL, INTENT(IN)                      :: log_verbose
         !! only for testing: verbose output

      ! *_rs: matrices reshaped to a compatible split; *_rep: matrices replicated over process groups
      TYPE(dbcsr_tas_type), POINTER                      :: matrix_b_rs, matrix_a_rs, matrix_c_rs, &
                                                            matrix_c_rep, matrix_b_rep, matrix_a_rep

      REAL(KIND=real_8)                                  :: filter_eps_prv
      INTEGER(KIND=int_8), DIMENSION(2)                  :: dims_a, dims_b, dims_c
      INTEGER, DIMENSION(2)                              :: pdims, pcoord, pcoord_sub, pdims_sub
      INTEGER(KIND=int_8), DIMENSION(3)                  :: dims
      INTEGER :: max_mm_dim, data_type, mp_comm, comm_tmp, &
                 handle, handle2, unit_nr_prv, nsplit, nsplit_opt, numproc, numproc_sub, iproc, &
                 mp_comm_group, mp_comm_mm, split_rc, split_a, split_b, split_c, &
                 mp_comm_opt, batched_repl, max_mm_dim_batched, nsplit_batched
      CHARACTER(LEN=1)                                   :: tr_case, transa_prv, transb_prv, transc_prv
      TYPE(dbcsr_scalar_type)                            :: zero
      LOGICAL :: new_a, new_b, new_c, simple_split_prv, opt_pgrid, &
                 move_a, move_b, do_batched, simple_split_save, &
                 nodata_3
      TYPE(dbcsr_tas_split_info)                         :: info, info_a, info_b, info_c
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_tas_multiply', &
                                     routineP = moduleN//':'//routineN
      INTEGER(KIND=int_8)                                :: nze_a, nze_b, nze_c
      TYPE(dbcsr_type), POINTER                          :: matrix_a_mm, matrix_b_mm, matrix_c_mm

      CALL timeset(routineN, handle)

      NULLIFY (matrix_b_rs, matrix_a_rs, matrix_c_rs, matrix_a_mm, matrix_b_mm, matrix_c_mm)

      unit_nr_prv = prep_output_unit(unit_nr)

      ! simple_split bypasses the occupancy-based optimization of layout and split factor below;
      ! it is forced if any of the matrices requires a strict (fixed) split
      IF (PRESENT(simple_split)) THEN
         simple_split_prv = simple_split
      ELSE
         simple_split_prv = .FALSE.

         info_a = dbcsr_tas_info(matrix_a); info_b = dbcsr_tas_info(matrix_b); info_c = dbcsr_tas_info(matrix_c)
         IF (info_a%strict_split .OR. info_b%strict_split .OR. info_c%strict_split) simple_split_prv = .TRUE.
      ENDIF

      ! nodata_3 controls whether the reshaped/replicated copy of matrix_c is created without data;
      ! data must be kept if the caller wants to retain the sparsity pattern of matrix_c
      nodata_3 = .TRUE.
      IF (PRESENT(retain_sparsity)) THEN
         IF (retain_sparsity) nodata_3 = .FALSE.
      ENDIF

      ! get prestored info for multiplication strategy in case of batched mm:
      ! do_batched is set if any operand is in a batched-mm phase; batched_repl records which
      ! operand (1 = a, 2 = b, 3 = c) already has a stored replicated copy from a previous call
      batched_repl = 0
      do_batched = .FALSE.
      IF (matrix_a%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_a%do_batched == 2) THEN
            DBCSR_ASSERT(batched_repl == 0)
            batched_repl = 1
            CALL dbcsr_tas_get_split_info( &
               dbcsr_tas_info(matrix_a%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            DBCSR_ASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 3
         ENDIF
      ENDIF

      IF (matrix_b%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_b%do_batched == 2) THEN
            DBCSR_ASSERT(batched_repl == 0)
            batched_repl = 2
            CALL dbcsr_tas_get_split_info( &
               dbcsr_tas_info(matrix_b%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            DBCSR_ASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 1
         ENDIF
      ENDIF

      IF (matrix_c%do_batched > 0) THEN
         do_batched = .TRUE.
         IF (matrix_c%do_batched == 2) THEN
            DBCSR_ASSERT(batched_repl == 0)
            batched_repl = 3
            CALL dbcsr_tas_get_split_info( &
               dbcsr_tas_info(matrix_c%mm_storage%store_batched_repl), &
               nsplit=nsplit_batched)
            DBCSR_ASSERT(nsplit_batched > 0)
            max_mm_dim_batched = 2
         ENDIF
      ENDIF

      ! when reusing a stored replicated matrix, temporarily disable the occupancy-based
      ! parameter derivation (the stored split parameters are reused instead); the original
      ! simple_split setting is restored further below
      IF (batched_repl > 0) THEN
         simple_split_save = simple_split_prv
         simple_split_prv = .TRUE.
      ENDIF

      move_a = .FALSE.
      move_b = .FALSE.

      IF (PRESENT(move_data_a)) move_a = move_data_a
      IF (PRESENT(move_data_b)) move_b = move_data_b

      IF (.NOT. dbcsr_tas_get_data_type(matrix_a) .EQ. dbcsr_tas_get_data_type(matrix_b)) THEN
         DBCSR_ABORT("matrices must have same datatype")
      ENDIF

      data_type = dbcsr_tas_get_data_type(matrix_a)

      transa_prv = transa; transb_prv = transb; transc_prv = transc

      dims_a = [dbcsr_tas_nblkrows_total(matrix_a), dbcsr_tas_nblkcols_total(matrix_a)]
      dims_b = [dbcsr_tas_nblkrows_total(matrix_b), dbcsr_tas_nblkcols_total(matrix_b)]
      dims_c = [dbcsr_tas_nblkrows_total(matrix_c), dbcsr_tas_nblkcols_total(matrix_c)]

      IF (unit_nr_prv .GT. 0) THEN
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBCSR TAS MATRIX MULTIPLICATION:", &
            TRIM(matrix_a%matrix%name), 'x', TRIM(matrix_b%matrix%name), '=', TRIM(matrix_c%matrix%name)
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
      ENDIF
      IF (do_batched) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "BATCHED PROCESSING OF MATMUL"
            IF (batched_repl > 0) THEN
               WRITE (unit_nr_prv, "(T4,A,T80,I1)") "reusing replicated matrix:", batched_repl
            ENDIF
         ENDIF
      ENDIF

      ! from here on, dims_a / dims_b refer to the effective (possibly transposed) shapes
      IF (transa_prv .EQ. dbcsr_transpose) THEN
         CALL swap(dims_a)
      ENDIF

      IF (transb_prv .EQ. dbcsr_transpose) THEN
         CALL swap(dims_b)
      ENDIF

      dims_c = [dims_a(1), dims_b(2)]

      IF (.NOT. (dims_a(2) .EQ. dims_b(1))) THEN
         DBCSR_ABORT("inconsistent matrix dimensions")
      ENDIF

      ! the three multiplication dimensions: dims = [m, k, n] in block counts
      dims(:) = [dims_a(1), dims_a(2), dims_b(2)]

      tr_case = ''

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, "(T2,A, 1X, I12, 1X, I12, 1X, I12)") "mm dims:", dims(1), dims(2), dims(3)
      ENDIF

      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_a), mp_comm=mp_comm)
      CALL mp_environ(numproc, iproc, mp_comm)

      ! derive optimal matrix layout and split factor from occupancies
      IF (.NOT. simple_split_prv) THEN
         nze_a = dbcsr_tas_get_nze_total(matrix_a)
         nze_b = dbcsr_tas_get_nze_total(matrix_b)
         CALL dbcsr_tas_result_index(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, &
                                     blk_ind=result_index, nze=nze_c, retain_sparsity=retain_sparsity)
         ! if only the result index was requested, we are done
         IF (PRESENT(result_index)) THEN
            CALL timestop(handle)
            RETURN
         ENDIF

         ! max_mm_dim is the position of the largest of (m, k, n); the largest dimension is split
         max_mm_dim = MAXLOC(dims, 1)
         nsplit = split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numproc)
         nsplit_opt = nsplit

         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "MM PARAMETERS"
            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. number of matrix elements per CPU of result matrix:", &
               (nze_c + numproc - 1)/numproc

            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
         ENDIF

      ELSEIF (batched_repl > 0) THEN
         ! reuse split parameters stored with the replicated matrix; restore caller's setting
         nsplit = nsplit_batched
         nsplit_opt = nsplit
         max_mm_dim = max_mm_dim_batched
         simple_split_prv = simple_split_save
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "MM PARAMETERS"
            WRITE (unit_nr_prv, "(T4,A,T68,I13)") "Est. optimal split factor:", nsplit
         ENDIF

      ELSE
         ! simple split: keep the existing split factor (0 = "do not change")
         nsplit = 0
         max_mm_dim = MAXLOC(dims, 1)
      ENDIF

      ! reshape matrices to the optimal layout and split factor.
      ! case 1: m largest -> split rows of A and C, replicate B;
      ! case 2: k largest -> split cols of A / rows of B, replicate C (partial sums merged later);
      ! case 3: n largest -> split cols of B and C, replicate A.
      split_a = rowsplit; split_b = rowsplit; split_c = rowsplit
      SELECT CASE (max_mm_dim)
      CASE (1)

         split_a = rowsplit; split_c = rowsplit
         CALL reshape_mm_compatible(matrix_a, matrix_c, matrix_a_rs, matrix_c_rs, &
                                    new_a, new_c, transa_prv, transc_prv, optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_a, split_rc_2=split_c, &
                                    nodata2=nodata_3, comm_new=comm_tmp, &
                                    move_data_1=move_a, unit_nr=unit_nr_prv)

         info = dbcsr_tas_info(matrix_a_rs)
         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         ! bring the small matrix B onto the same grid (skipped if a batched copy exists)
         new_b = .FALSE.
         IF (matrix_b%do_batched <= 1) THEN
            ALLOCATE (matrix_b_rs)
            CALL reshape_mm_small(mp_comm, matrix_b, matrix_b_rs, transb_prv == dbcsr_transpose, transb_prv, move_data=move_b)
            new_b = .TRUE.
         ENDIF

         tr_case = transa_prv

         IF (unit_nr_prv > 0) THEN
            IF (tr_case == 'N') THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "| x + = |"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "--T x + = --T"
            ENDIF
         ENDIF

      CASE (2)

         split_a = colsplit; split_b = rowsplit
         CALL reshape_mm_compatible(matrix_a, matrix_b, matrix_a_rs, matrix_b_rs, new_a, new_b, transa_prv, transb_prv, &
                                    optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_a, split_rc_2=split_b, &
                                    comm_new=comm_tmp, &
                                    move_data_1=move_a, move_data_2=move_b, unit_nr=unit_nr_prv)

         info = dbcsr_tas_info(matrix_a_rs)
         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         ! bring the small matrix C onto the same grid, reusing/storing a batched copy if applicable
         IF (matrix_c%do_batched <= 1) THEN
            ALLOCATE (matrix_c_rs)
            CALL reshape_mm_small(mp_comm, matrix_c, matrix_c_rs, transc_prv == dbcsr_transpose, transc_prv, nodata=nodata_3)
            IF (matrix_c%do_batched == 1) THEN
               matrix_c%mm_storage%store_batched => matrix_c_rs
            ENDIF
         ELSEIF (matrix_c%do_batched == 2) THEN
            matrix_c_rs => matrix_c%mm_storage%store_batched
         ENDIF

         new_c = matrix_c%do_batched == 0
         tr_case = transa_prv

         IF (unit_nr_prv > 0) THEN
            IF (tr_case == 'N') THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "-- x --T = +"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "|T x | = +"
            ENDIF
         ENDIF

      CASE (3)

         split_b = colsplit; split_c = colsplit
         CALL reshape_mm_compatible(matrix_b, matrix_c, matrix_b_rs, matrix_c_rs, new_b, new_c, transb_prv, &
                                    transc_prv, optimize_dist=optimize_dist, &
                                    nsplit=nsplit, &
                                    opt_nsplit=batched_repl == 0, &
                                    split_rc_1=split_b, split_rc_2=split_c, &
                                    nodata2=nodata_3, comm_new=comm_tmp, &
                                    move_data_1=move_b, unit_nr=unit_nr_prv)
         info = dbcsr_tas_info(matrix_b_rs)
         CALL dbcsr_tas_get_split_info(info, split_rowcol=split_rc, mp_comm=mp_comm)

         ! bring the small matrix A onto the same grid (skipped if a batched copy exists)
         new_a = .FALSE.
         IF (matrix_a%do_batched <= 1) THEN
            ALLOCATE (matrix_a_rs)
            CALL reshape_mm_small(mp_comm, matrix_a, matrix_a_rs, transa_prv == dbcsr_transpose, transa_prv, move_data=move_a)
            new_a = .TRUE.
         ENDIF

         tr_case = transb_prv

         IF (unit_nr_prv > 0) THEN
            IF (tr_case == 'N') THEN
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x -- = --"
            ELSE
               WRITE (unit_nr_prv, "(T2,A, 1X, A)") "mm case:", "+ x |T = |T"
            ENDIF
         ENDIF

      END SELECT

      CALL dbcsr_tas_get_split_info(info, nsplit=nsplit, mp_comm=mp_comm, mp_comm_group=mp_comm_group)

      CALL mp_environ(numproc, pdims, pcoord, mp_comm)
      CALL mp_environ(numproc_sub, pdims_sub, pcoord_sub, mp_comm_group)

      ! decide whether the per-group process grid should be replaced by a better-shaped one
      IF (.NOT. simple_split_prv) THEN
         opt_pgrid = .NOT. accept_pgrid_dims(pdims_sub, relative=.TRUE.)
      ELSE
         opt_pgrid = .FALSE.
      ENDIF

      IF (PRESENT(filter_eps)) THEN
         filter_eps_prv = filter_eps
      ELSE
         filter_eps_prv = 0.0_real_8
      ENDIF

      IF (unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2, A)") "SPLIT / PARALLELIZATION INFO"
         ENDIF
         CALL dbcsr_tas_write_split_info(info, unit_nr_prv)
         IF (ASSOCIATED(matrix_a_rs)) CALL dbcsr_tas_write_matrix_info(matrix_a_rs, unit_nr_prv, full_info=log_verbose)
         IF (ASSOCIATED(matrix_b_rs)) CALL dbcsr_tas_write_matrix_info(matrix_b_rs, unit_nr_prv, full_info=log_verbose)
         IF (ASSOCIATED(matrix_c_rs)) CALL dbcsr_tas_write_matrix_info(matrix_c_rs, unit_nr_prv, full_info=log_verbose)
         IF (unit_nr_prv > 0) THEN
            IF (opt_pgrid) THEN
               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "Yes"
            ELSE
               WRITE (unit_nr_prv, "(T4, A, 1X, A)") "Change process grid:", "No"
            ENDIF
         ENDIF
      ENDIF

      zero = dbcsr_scalar_zero(data_type)

      ! create a 2d Cartesian communicator per process group for the DBCSR (Cannon) multiplication;
      ! pdims = 0 presumably lets the wrapper choose the grid dimensions (TODO confirm mp_cart_create semantics)
      pdims = 0
      CALL mp_cart_create(mp_comm_group, 2, pdims, pcoord, mp_comm_mm)

      ! Convert DBCSR submatrices to optimized process grids and multiply
      SELECT CASE (max_mm_dim)
      CASE (1)
         ! replicate B over the process groups (or reuse the batched replicated copy)
         IF (matrix_b%do_batched <= 1) THEN
            ALLOCATE (matrix_b_rep)
            CALL dbcsr_tas_replicate(matrix_b_rs%matrix, dbcsr_tas_info(matrix_a_rs), matrix_b_rep, move_data=.TRUE.)
            IF (matrix_b%do_batched == 1) THEN
               matrix_b%mm_storage%store_batched_repl => matrix_b_rep
               matrix_b%do_batched = 2
            ENDIF
         ELSEIF (matrix_b%do_batched == 2) THEN
            matrix_b_rep => matrix_b%mm_storage%store_batched_repl
         ENDIF

         IF (new_b) THEN
            CALL dbcsr_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         ENDIF
         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_a_rs, unit_nr_prv)
            CALL dbcsr_tas_write_dist(matrix_b_rep, unit_nr_prv, full_info=log_verbose)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)

         IF (new_a .AND. opt_pgrid) THEN
            CALL dbcsr_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         ENDIF
         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rep%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, &
                                   move_data=matrix_b%do_batched == 0)

         IF (opt_pgrid .AND. matrix_b%do_batched == 0) THEN
            CALL dbcsr_tas_destroy(matrix_b_rep)
            DEALLOCATE (matrix_b_rep)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         ! the transposed case is computed as (B^T x A)^T by swapping operands
         SELECT CASE (tr_case)
         CASE (dbcsr_no_transpose)
            CALL timeset(routineN//"_mm_1N", handle2)

            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
                                matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle2)
         CASE (dbcsr_transpose)
            CALL timeset(routineN//"_mm_1T", handle2)
            CALL dbcsr_multiply(transa=dbcsr_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
                                matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)

            CALL timestop(handle2)
         END SELECT

         ! if an optimized grid was used, the result must be redistributed back and summed
         IF (opt_pgrid) THEN
            CALL dbcsr_release(matrix_a_mm)
            CALL dbcsr_release(matrix_b_mm)
            IF (.NOT. new_c) THEN
               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, alpha=beta)
            ELSE
               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix)
            ENDIF

            CALL dbcsr_release(matrix_c_mm)

            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
         ELSE
            IF (new_a) CALL dbcsr_tas_destroy(matrix_a_rs)
            IF (new_a) DEALLOCATE (matrix_a_rs)
            IF (matrix_b%do_batched == 0) THEN
               CALL dbcsr_tas_destroy(matrix_b_rep)
               DEALLOCATE (matrix_b_rep)
            ENDIF
         ENDIF

         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)

         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_c_rs, unit_nr_prv)
         ENDIF

      CASE (2)
         ! replicate C over the process groups (or reuse the batched replicated copy);
         ! each group computes a partial sum over its slice of k, merged after the multiplication
         IF (matrix_c%do_batched <= 1) THEN
            ALLOCATE (matrix_c_rep)
            CALL dbcsr_tas_replicate(matrix_c_rs%matrix, dbcsr_tas_info(matrix_a_rs), matrix_c_rep, nodata=nodata_3)
            IF (matrix_c%do_batched == 1) THEN
               matrix_c%mm_storage%store_batched_repl => matrix_c_rep
               matrix_c%do_batched = 2
            ENDIF
         ELSEIF (matrix_c%do_batched == 2) THEN
            matrix_c_rep => matrix_c%mm_storage%store_batched_repl
         ENDIF

         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_a_rs, unit_nr_prv)
            CALL dbcsr_tas_write_dist(matrix_b_rs, unit_nr_prv)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rs%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, move_data=move_a)
         IF (new_a .AND. opt_pgrid) THEN
            CALL dbcsr_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)
         IF (new_b .AND. opt_pgrid) THEN
            CALL dbcsr_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rep%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         CALL timeset(routineN//"_mm_2", handle2)
         ! NOTE: threshold divided by nsplit since nsplit partial results are summed when merging
         CALL dbcsr_multiply(transa=transa_prv, transb=transb_prv, alpha=alpha, matrix_a=matrix_a_mm, &
                             matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                             filter_eps=filter_eps_prv/REAL(nsplit, KIND=real_8), retain_sparsity=retain_sparsity, flop=flop)
         CALL timestop(handle2)

         IF (opt_pgrid) THEN
            CALL dbcsr_release(matrix_a_mm)
            CALL dbcsr_release(matrix_b_mm)
            CALL redistribute_and_sum(matrix_c_mm, matrix_c_rep%matrix, alpha=beta)
            CALL dbcsr_release(matrix_c_mm)

            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
         ELSE
            IF (new_a) CALL dbcsr_tas_destroy(matrix_a_rs)
            IF (new_a) DEALLOCATE (matrix_a_rs)
            IF (new_b) CALL dbcsr_tas_destroy(matrix_b_rs)
            IF (new_b) DEALLOCATE (matrix_b_rs)
         ENDIF

         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_c_rep, unit_nr_prv, full_info=log_verbose)
         ENDIF

         IF (matrix_c%do_batched == 0) THEN
            CALL dbcsr_tas_merge(matrix_c_rs%matrix, matrix_c_rep, move_data=.TRUE.)
         ELSE
            matrix_c%mm_storage%batched_out = .TRUE. ! postpone merging submatrices to dbcsr_tas_batched_mm_finalize
         ENDIF

         IF (matrix_c%do_batched == 0) THEN
            CALL dbcsr_tas_destroy(matrix_c_rep)
            DEALLOCATE (matrix_c_rep)
         ENDIF
         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)
      CASE (3)
         ! replicate A over the process groups (or reuse the batched replicated copy)
         IF (matrix_a%do_batched <= 1) THEN
            ALLOCATE (matrix_a_rep)
            CALL dbcsr_tas_replicate(matrix_a_rs%matrix, dbcsr_tas_info(matrix_b_rs), matrix_a_rep, move_data=.TRUE.)
            IF (matrix_a%do_batched == 1) THEN
               matrix_a%mm_storage%store_batched_repl => matrix_a_rep
               matrix_a%do_batched = 2
            ENDIF
         ELSEIF (matrix_a%do_batched == 2) THEN
            matrix_a_rep => matrix_a%mm_storage%store_batched_repl
         ENDIF

         IF (new_a) THEN
            CALL dbcsr_tas_destroy(matrix_a_rs)
            DEALLOCATE (matrix_a_rs)
         ENDIF
         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_a_rep, unit_nr_prv, full_info=log_verbose)
            CALL dbcsr_tas_write_dist(matrix_b_rs, unit_nr_prv)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_a_rep%matrix, matrix_a_mm, optimize_pgrid=opt_pgrid, &
                                   move_data=matrix_a%do_batched == 0)

         IF (opt_pgrid .AND. matrix_a%do_batched == 0) THEN
            CALL dbcsr_tas_destroy(matrix_a_rep)
            DEALLOCATE (matrix_a_rep)
         ENDIF

         CALL convert_to_new_pgrid(mp_comm_mm, matrix_b_rs%matrix, matrix_b_mm, optimize_pgrid=opt_pgrid, move_data=move_b)

         IF (new_b .AND. opt_pgrid) THEN
            CALL dbcsr_tas_destroy(matrix_b_rs)
            DEALLOCATE (matrix_b_rs)
         ENDIF
         CALL convert_to_new_pgrid(mp_comm_mm, matrix_c_rs%matrix, matrix_c_mm, nodata=nodata_3, optimize_pgrid=opt_pgrid)

         ! the transposed case is computed as (B x A^T)^T by swapping operands
         SELECT CASE (tr_case)
         CASE (dbcsr_no_transpose)
            CALL timeset(routineN//"_mm_3N", handle2)
            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_no_transpose, alpha=alpha, &
                                matrix_a=matrix_a_mm, matrix_b=matrix_b_mm, beta=beta, matrix_c=matrix_c_mm, &
                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle2)
         CASE (dbcsr_transpose)
            CALL timeset(routineN//"_mm_3T", handle2)
            CALL dbcsr_multiply(transa=dbcsr_no_transpose, transb=dbcsr_transpose, alpha=alpha, &
                                matrix_a=matrix_b_mm, matrix_b=matrix_a_mm, beta=beta, matrix_c=matrix_c_mm, &
                                filter_eps=filter_eps_prv, retain_sparsity=retain_sparsity, flop=flop)
            CALL timestop(handle2)
         END SELECT

         IF (opt_pgrid) THEN
            CALL dbcsr_release(matrix_a_mm)
            CALL dbcsr_release(matrix_b_mm)

            IF (.NOT. new_c) THEN
               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix, alpha=beta)
            ELSE
               CALL redistribute_and_sum(matrix_c_mm, matrix_c_rs%matrix)
            ENDIF

            CALL dbcsr_release(matrix_c_mm)

            DEALLOCATE (matrix_a_mm, matrix_b_mm, matrix_c_mm)
         ELSE
            IF (new_b) CALL dbcsr_tas_destroy(matrix_b_rs)
            IF (new_b) DEALLOCATE (matrix_b_rs)
            IF (matrix_a%do_batched == 0) THEN
               CALL dbcsr_tas_destroy(matrix_a_rep)
               DEALLOCATE (matrix_a_rep)
            ENDIF
         ENDIF

         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c_rs, filter_eps)

         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_tas_write_dist(matrix_c_rs, unit_nr_prv)
         ENDIF
      END SELECT

      CALL mp_comm_free(mp_comm_mm)

      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_c), mp_comm=mp_comm)

      IF (PRESENT(split_opt) .AND. .NOT. simple_split_prv) THEN
         ! ideally we should rederive the split factor from the actual sparsity of C, but
         ! due to parameter beta, we can not get the sparsity of AxB from DBCSR if not new_c
         ALLOCATE (split_opt)
         mp_comm_opt = dbcsr_tas_mp_comm(mp_comm, split_rc, nsplit_opt)
         CALL dbcsr_tas_create_split(split_opt, mp_comm_opt, split_rc, nsplit_opt, own_comm=.TRUE.)
      ENDIF

      ! transform the result back into the original layout of matrix_c (summation accounts for beta);
      ! in batched mode this is postponed, so the scaling/transpose parameters are stored instead
      IF (new_c) THEN
         CALL dbcsr_scale(matrix_c%matrix, beta)
         CALL dbcsr_tas_reshape(matrix_c_rs, matrix_c, summation=.TRUE., transposed=transc_prv /= transc, &
                                move_data=.TRUE.)
         CALL dbcsr_tas_destroy(matrix_c_rs)
         DEALLOCATE (matrix_c_rs)
         IF (PRESENT(filter_eps)) CALL dbcsr_tas_filter(matrix_c, filter_eps)
      ELSEIF (matrix_c%do_batched > 0) THEN
         IF (matrix_c%mm_storage%batched_out) THEN
            matrix_c%mm_storage%batched_beta = beta
            matrix_c%mm_storage%batched_trans = transc_prv /= transc
         ENDIF
      ENDIF

      IF (PRESENT(move_data_a)) THEN
         IF (move_data_a) CALL dbcsr_tas_clear(matrix_a)
      ENDIF
      IF (PRESENT(move_data_b)) THEN
         IF (move_data_b) CALL dbcsr_tas_clear(matrix_b)
      ENDIF

      ! report the average flop count per process
      IF (PRESENT(flop)) THEN
         CALL mp_sum(flop, mp_comm)
         flop = (flop + numproc - 1)/numproc
      ENDIF

      IF (PRESENT(optimize_dist)) THEN
         IF (optimize_dist) CALL mp_comm_free(comm_tmp)
      ENDIF
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "TAS MATRIX MULTIPLICATION DONE"
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
      ENDIF

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE redistribute_and_sum(matrix_in, matrix_out, alpha)
      !! Redistribute matrix_in into the distribution of matrix_out and add it to matrix_out
      !! (matrix_out := matrix_out + matrix_in, with matrix_out optionally prescaled by alpha
      !! via the alpha_scalar argument of dbcsr_add).
      TYPE(dbcsr_type), INTENT(IN) :: matrix_in
      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
      TYPE(dbcsr_scalar_type), INTENT(IN), OPTIONAL :: alpha
         !! optional scalar forwarded as alpha_scalar to dbcsr_add
      TYPE(dbcsr_type) :: matrix_tmp

      ! temporary with the target's distribution, filled by redistribution, then summed in
      CALL dbcsr_create(matrix_tmp, matrix_out)
      CALL dbcsr_redistribute(matrix_in, matrix_tmp)
      CALL dbcsr_add(matrix_out, matrix_tmp, alpha_scalar=alpha)
      CALL dbcsr_release(matrix_tmp)

   END SUBROUTINE

   SUBROUTINE reshape_mm_small(mp_comm, matrix_in, matrix_out, transposed, trans, nodata, move_data)
      !! Make sure that smallest matrix involved in a multiplication is not split and bring it to
      !! the same process grid as the other 2 matrices.

      INTEGER, INTENT(IN) :: mp_comm
         !! communicator that defines Cartesian topology
      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix_in
      TYPE(dbcsr_tas_type), INTENT(OUT) :: matrix_out
      LOGICAL, INTENT(IN) :: transposed
         !! Whether matrix_out should be transposed
      CHARACTER(LEN=1), INTENT(INOUT) :: trans
         !! update transpose flag for DBCSR mm according to 'transposed' argument
      LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
         !! Data of matrix_in should not be copied to matrix_out
         !! memory optimization: move data such that matrix_in is empty on return.

      INTEGER :: numnodes
      INTEGER(KIND=int_8), DIMENSION(2) :: dims
      INTEGER, DIMENSION(2) :: pdims, pcoord
      TYPE(dbcsr_tas_dist_arb) :: new_row_dist, new_col_dist
      TYPE(dbcsr_tas_distribution_type) :: dist
      LOGICAL :: nodata_prv
      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_small', &
                                     routineP = moduleN//':'//routineN
      INTEGER :: handle

      CALL timeset(routineN, handle)

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      ENDIF

      ! if the output is created transposed, the DBCSR transpose flag must be flipped
      IF (transposed) THEN
         SELECT CASE (trans)
         CASE (dbcsr_transpose)
            trans = dbcsr_no_transpose
         CASE (dbcsr_no_transpose)
            trans = dbcsr_transpose
         END SELECT
      ENDIF

      CALL mp_environ(numnodes, pdims, pcoord, mp_comm)

      dims = [dbcsr_tas_nblkrows_total(matrix_in), dbcsr_tas_nblkcols_total(matrix_in)]

      IF (transposed) CALL swap(dims)

      ! create the non-split (nosplit=.TRUE.) target matrix with a default arbitrary distribution;
      ! row/column block sizes are exchanged in the transposed case
      IF (.NOT. transposed) THEN
         new_row_dist = dbcsr_tas_dist_arb_default(pdims(1), dims(1), matrix_in%row_blk_size)
         new_col_dist = dbcsr_tas_dist_arb_default(pdims(2), dims(2), matrix_in%col_blk_size)
         CALL dbcsr_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist, dbcsr_tas_get_data_type(matrix_in), &
                               matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      ELSE
         ! transposed target: rows take the input's column block sizes and vice versa
         new_row_dist = dbcsr_tas_dist_arb_default(pdims(1), dims(1), matrix_in%col_blk_size)
         new_col_dist = dbcsr_tas_dist_arb_default(pdims(2), dims(2), matrix_in%row_blk_size)
         CALL dbcsr_tas_distribution_new(dist, mp_comm, new_row_dist, new_col_dist, nosplit=.TRUE.)
         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist, dbcsr_tas_get_data_type(matrix_in), &
                               matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)

      ENDIF
      ! copy (or move) the data into the new layout unless the caller asked for structure only
      IF (.NOT. nodata_prv) CALL dbcsr_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE reshape_mm_compatible(matrix1_in, matrix2_in, matrix1_out, matrix2_out, new1, new2, trans1, trans2, &
                                    optimize_dist, nsplit, opt_nsplit, split_rc_1, split_rc_2, nodata1, nodata2, &
                                    move_data_1, move_data_2, comm_new, unit_nr)
      !! Reshape either matrix1 or matrix2 to make sure that their process grids are compatible with
      !! the same split factor.

      TYPE(dbcsr_tas_type), TARGET, &
         INTENT(INOUT) :: matrix1_in, matrix2_in
      TYPE(dbcsr_tas_type), POINTER, INTENT(OUT) :: matrix1_out, matrix2_out
      LOGICAL, INTENT(OUT) :: new1, new2
         !! Whether matrix1_out is a new matrix or simply pointing to matrix1_in
         !! Whether matrix2_out is a new matrix or simply pointing to matrix2_in
      CHARACTER(LEN=1), INTENT(INOUT) :: trans1, trans2
         !! transpose flag of matrix1_in for multiplication
         !! transpose flag of matrix2_in for multiplication
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
         !! experimental: optimize matrix splitting and distribution
      INTEGER, INTENT(IN), OPTIONAL :: nsplit
         !! Optimal split factor (set to 0 if split factor should not be changed)
      LOGICAL, INTENT(IN), OPTIONAL :: opt_nsplit
      INTEGER, INTENT(INOUT) :: split_rc_1, split_rc_2
         !! Whether to split rows or columns for matrix 1
         !! Whether to split rows or columns for matrix 2
      INTEGER, INTENT(OUT), OPTIONAL :: comm_new
         !! returns the new communicator only if optimize_dist
      LOGICAL, OPTIONAL, INTENT(IN) :: nodata1, nodata2
         !! Don't copy matrix data from matrix1_in to matrix1_out
         !! Don't copy matrix data from matrix2_in to matrix2_out
      LOGICAL, OPTIONAL, INTENT(INOUT) :: move_data_1, move_data_2
         !! memory optimization: move data such that matrix1_in may be empty on return.
         !! memory optimization: move data such that matrix2_in may be empty on return.
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
         !! output unit

      INTEGER(KIND=int_8), DIMENSION(2) :: dims1, dims2, dims_ref
      INTEGER(KIND=int_8) :: d1, d2
      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_compatible', &
                                     routineP = moduleN//':'//routineN
      INTEGER :: handle, mp_comm, numnodes, unit_nr_prv, &
                 nsplit_prv, ref, split_rc_ref
      INTEGER, DIMENSION(2) :: pcoord, pdims
      LOGICAL :: optimize_dist_prv, trans1_newdist, trans2_newdist
      TYPE(dbcsr_tas_dist_cyclic) :: row_dist_1, col_dist_1, row_dist_2, col_dist_2
      TYPE(dbcsr_tas_distribution_type) :: dist_1, dist_2
      TYPE(dbcsr_tas_split_info) :: split_info
      INTEGER(KIND=int_8) :: nze1, nze2
      LOGICAL :: nodata1_prv, nodata2_prv

      CALL timeset(routineN, handle)
      new1 = .FALSE.; new2 = .FALSE.

      IF (PRESENT(nodata1)) THEN
         nodata1_prv = nodata1
      ELSE
         nodata1_prv = .FALSE.
      ENDIF

      IF (PRESENT(nodata2)) THEN
         nodata2_prv = nodata2
      ELSE
         nodata2_prv = .FALSE.
      ENDIF

      unit_nr_prv = prep_output_unit(unit_nr)

      NULLIFY (matrix1_out, matrix2_out)

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .FALSE.
      ENDIF

      dims1 = [dbcsr_tas_nblkrows_total(matrix1_in), dbcsr_tas_nblkcols_total(matrix1_in)]
      dims2 = [dbcsr_tas_nblkrows_total(matrix2_in), dbcsr_tas_nblkcols_total(matrix2_in)]
      nze1 = dbcsr_tas_get_nze_total(matrix1_in)
      nze2 = dbcsr_tas_get_nze_total(matrix2_in)

      ! For a transposed operand the split dimension refers to the transposed matrix:
      ! toggle between the two values rowsplit/colsplit (1 <-> 2).
      IF (trans1 == dbcsr_transpose) split_rc_1 = MOD(split_rc_1, 2) + 1

      IF (trans2 == dbcsr_transpose) split_rc_2 = MOD(split_rc_2, 2) + 1

      ! The matrix with more non-zero entries serves as the reference;
      ! the other one is the candidate for redistribution.
      IF (nze1 >= nze2) THEN
         ref = 1
         split_rc_ref = split_rc_1
         dims_ref = dims1
      ELSE
         ref = 2
         split_rc_ref = split_rc_2
         dims_ref = dims2
      ENDIF

      IF (PRESENT(nsplit)) THEN
         nsplit_prv = nsplit
      ELSE
         nsplit_prv = 0
      ENDIF

      IF (optimize_dist_prv) THEN
         ! comm_new is written below in the optimize branch, so it must be present.
         DBCSR_ASSERT(PRESENT(comm_new))
      ENDIF

      IF ((.NOT. optimize_dist_prv) .AND. dist_compatible(matrix1_in, matrix2_in, split_rc_1, split_rc_2)) THEN
         ! Distributions are already compatible: at most adapt the split factors.
         CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
                           move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)
         CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
                           move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "No redistribution of", TRIM(matrix1_in%matrix%name), &
               "and", TRIM(matrix2_in%matrix%name)
            IF (new1) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix1_in%matrix%name), ": Yes"
            ELSE
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix1_in%matrix%name), ": No"
            ENDIF
            IF (new2) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix2_in%matrix%name), ": Yes"
            ELSE
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A)") "Change split factor of", TRIM(matrix2_in%matrix%name), ": No"
            ENDIF
         ENDIF
      ELSE

         IF (optimize_dist_prv) THEN
            IF (unit_nr_prv > 0) THEN
               WRITE (unit_nr_prv, "(T2,A,1X,A,1X,A,1X,A)") "Optimizing distribution of", TRIM(matrix1_in%matrix%name), &
                  "and", TRIM(matrix2_in%matrix%name)
            ENDIF

            ! The new split below is always a row split; a matrix that was
            ! column-split is stored transposed instead (flag inverted).
            trans1_newdist = (split_rc_1 == colsplit)
            trans2_newdist = (split_rc_2 == colsplit)

            IF (trans1_newdist) THEN
               CALL swap(dims1)
               CALL invert_transpose_flag(trans1)
            ENDIF

            IF (trans2_newdist) THEN
               CALL swap(dims2)
               CALL invert_transpose_flag(trans2)
            ENDIF

            IF (nsplit_prv == 0) THEN
               ! Default split factor: ratio of split dimension to the other
               ! dimension of the reference matrix, rounded up.
               SELECT CASE (split_rc_ref)
               CASE (rowsplit)
                  d1 = dims_ref(1)
                  d2 = dims_ref(2)
               CASE (colsplit)
                  d1 = dims_ref(2)
                  d2 = dims_ref(1)
               END SELECT
               nsplit_prv = INT((d1 - 1)/d2 + 1)
            ENDIF

            DBCSR_ASSERT(nsplit_prv > 0)

            CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix1_in), mp_comm=mp_comm)
            comm_new = dbcsr_tas_mp_comm(mp_comm, rowsplit, nsplit_prv)
            CALL dbcsr_tas_create_split(split_info, comm_new, rowsplit, nsplit_prv)

            CALL mp_environ(numnodes, pdims, pcoord, comm_new)

            ! use a very simple cyclic distribution that may not be load balanced if block
            ! sizes are not equal. However we can not use arbitrary distributions
            ! for large dimensions since this would require storing distribution vectors as arrays
            ! which can not be stored for large dimensions.
            row_dist_1 = dbcsr_tas_dist_cyclic(1, pdims(1), dims1(1))
            col_dist_1 = dbcsr_tas_dist_cyclic(1, pdims(2), dims1(2))

            row_dist_2 = dbcsr_tas_dist_cyclic(1, pdims(1), dims2(1))
            col_dist_2 = dbcsr_tas_dist_cyclic(1, pdims(2), dims2(2))

            CALL dbcsr_tas_distribution_new(dist_1, comm_new, row_dist_1, col_dist_1, split_info=split_info)
            CALL dbcsr_tas_distribution_new(dist_2, comm_new, row_dist_2, col_dist_2, split_info=split_info)
            CALL dbcsr_tas_release_info(split_info)

            ! Swap row/column block sizes in the create calls when the matrix is
            ! stored transposed on the new grid.
            ALLOCATE (matrix1_out)
            IF (.NOT. trans1_newdist) THEN
               CALL dbcsr_tas_create(matrix1_out, matrix1_in%matrix%name, dist_1, dbcsr_tas_get_data_type(matrix1_in), &
                                     matrix1_in%row_blk_size, matrix1_in%col_blk_size, own_dist=.TRUE.)

            ELSE
               CALL dbcsr_tas_create(matrix1_out, matrix1_in%matrix%name, dist_1, dbcsr_tas_get_data_type(matrix1_in), &
                                     matrix1_in%col_blk_size, matrix1_in%row_blk_size, own_dist=.TRUE.)
            ENDIF

            ALLOCATE (matrix2_out)
            IF (.NOT. trans2_newdist) THEN
               CALL dbcsr_tas_create(matrix2_out, matrix2_in%matrix%name, dist_2, dbcsr_tas_get_data_type(matrix2_in), &
                                     matrix2_in%row_blk_size, matrix2_in%col_blk_size, own_dist=.TRUE.)
            ELSE
               CALL dbcsr_tas_create(matrix2_out, matrix2_in%matrix%name, dist_2, dbcsr_tas_get_data_type(matrix2_in), &
                                     matrix2_in%col_blk_size, matrix2_in%row_blk_size, own_dist=.TRUE.)
            ENDIF

            IF (.NOT. nodata1_prv) CALL dbcsr_tas_reshape(matrix1_in, matrix1_out, transposed=trans1_newdist, move_data=move_data_1)
            IF (.NOT. nodata2_prv) CALL dbcsr_tas_reshape(matrix2_in, matrix2_out, transposed=trans2_newdist, move_data=move_data_2)
            new1 = .TRUE.
            new2 = .TRUE.

         ELSE
            SELECT CASE (ref)
            CASE (1)
               ! Matrix 1 is the reference: keep it (modulo split change) and
               ! redistribute matrix 2 to match its grid/split.
               IF (unit_nr_prv > 0) THEN
                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", TRIM(matrix2_in%matrix%name)
               ENDIF

               CALL change_split(matrix1_in, matrix1_out, nsplit_prv, split_rc_1, new1, &
                                 move_data=move_data_1, nodata=nodata1, opt_nsplit=opt_nsplit)

               ALLOCATE (matrix2_out)
               CALL reshape_mm_template(matrix1_out, matrix2_in, matrix2_out, trans2, split_rc_2, &
                                        nodata=nodata2, move_data=move_data_2)
               new2 = .TRUE.
            CASE (2)
               ! Matrix 2 is the reference: keep it (modulo split change) and
               ! redistribute matrix 1 to match its grid/split.
               IF (unit_nr_prv > 0) THEN
                  WRITE (unit_nr_prv, "(T2,A,1X,A)") "Redistribution of", TRIM(matrix1_in%matrix%name)
               ENDIF

               CALL change_split(matrix2_in, matrix2_out, nsplit_prv, split_rc_2, new2, &
                                 move_data=move_data_2, nodata=nodata2, opt_nsplit=opt_nsplit)

               ALLOCATE (matrix1_out)
               CALL reshape_mm_template(matrix2_out, matrix1_in, matrix1_out, trans1, split_rc_1, &
                                        nodata=nodata1, move_data=move_data_1)
               new1 = .TRUE.
            END SELECT
         ENDIF
      ENDIF

      ! A freshly created output matrix owns its data, so the caller may always
      ! move (rather than copy) it in subsequent operations.
      IF (PRESENT(move_data_1) .AND. new1) move_data_1 = .TRUE.
      IF (PRESENT(move_data_2) .AND. new2) move_data_2 = .TRUE.

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE change_split(matrix_in, matrix_out, nsplit, split_rowcol, is_new, opt_nsplit, move_data, nodata)
      !! Change split factor without redistribution

      TYPE(dbcsr_tas_type), TARGET, &
         INTENT(INOUT)                           :: matrix_in
      TYPE(dbcsr_tas_type), POINTER, INTENT(OUT) :: matrix_out
      INTEGER, INTENT(IN)                        :: nsplit
         !! new split factor, set to 0 to not change split of matrix_in
      INTEGER, INTENT(IN)                        :: split_rowcol
         !! split rows or columns
      LOGICAL, INTENT(OUT)                       :: is_new
         !! whether matrix_out is new or a pointer to matrix_in
      LOGICAL, INTENT(IN), OPTIONAL              :: opt_nsplit
         !! whether nsplit should be optimized for current process grid
      LOGICAL, INTENT(IN), OPTIONAL              :: nodata
         !! Data of matrix_in should not be copied to matrix_out
      LOGICAL, INTENT(INOUT), OPTIONAL           :: move_data
         !! memory optimization: move data such that matrix_in is empty on return.
      INTEGER :: &
         mp_comm, split_rc, nsplit_old, handle, data_type, nsplit_new, nsplit_prv
      TYPE(dbcsr_tas_split_info) :: split_info
      CHARACTER(len=default_string_length) :: name
      TYPE(dbcsr_tas_distribution_type) :: dist
      LOGICAL :: nodata_prv
      CLASS(dbcsr_tas_distribution), ALLOCATABLE :: rdist, cdist
      CLASS(dbcsr_tas_rowcol_data), ALLOCATABLE :: rbsize, cbsize
      CHARACTER(LEN=*), PARAMETER :: routineN = 'change_split', &
                                     routineP = moduleN//':'//routineN
      NULLIFY (matrix_out)

      is_new = .TRUE.

      CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_in), mp_comm=mp_comm, &
                                    split_rowcol=split_rc, nsplit=nsplit_old)

      IF (nsplit == 0) THEN
         IF (split_rowcol == split_rc) THEN
            ! Nothing to change: alias the input matrix.
            matrix_out => matrix_in
            is_new = .FALSE.
            RETURN
         ELSE
            nsplit_prv = 1
         ENDIF
      ELSE
         nsplit_prv = nsplit
      ENDIF

      CALL timeset(routineN, handle)

      nodata_prv = .FALSE.
      IF (PRESENT(nodata)) nodata_prv = nodata

      CALL dbcsr_tas_get_info(matrix_in, data_type=data_type, name=name, &
                              row_blk_size=rbsize, col_blk_size=cbsize, &
                              proc_row_dist=rdist, proc_col_dist=cdist)

      CALL dbcsr_tas_create_split(split_info, mp_comm, split_rowcol, nsplit_prv, opt_nsplit=opt_nsplit)

      ! The actually realized split factor may differ from the requested one;
      ! if it turns out unchanged there is nothing to do.
      CALL dbcsr_tas_get_split_info(split_info, nsplit=nsplit_new)

      IF (nsplit_old == nsplit_new .AND. split_rc == split_rowcol) THEN
         matrix_out => matrix_in
         is_new = .FALSE.
         CALL dbcsr_tas_release_info(split_info)
         CALL timestop(handle)
         RETURN
      ENDIF

      CALL dbcsr_tas_distribution_new(dist, mp_comm, rdist, cdist, &
                                      split_info=split_info)

      CALL dbcsr_tas_release_info(split_info)

      ALLOCATE (matrix_out)
      CALL dbcsr_tas_create(matrix_out, name, dist, &
                            data_type, &
                            rbsize, cbsize, own_dist=.TRUE.)

      IF (.NOT. nodata_prv) CALL dbcsr_tas_copy(matrix_out, matrix_in)

      IF (PRESENT(move_data)) THEN
         IF (.NOT. nodata_prv) THEN
            ! The copy above duplicated the data, so matrix_in can be cleared and
            ! subsequent operations may move the new matrix's data.
            IF (move_data) CALL dbcsr_tas_clear(matrix_in)
            move_data = .TRUE.
         ENDIF
      ENDIF

      CALL timestop(handle)
   END SUBROUTINE

   FUNCTION dist_compatible(mat_a, mat_b, split_rc_a, split_rc_b, unit_nr)
      !! Check whether matrices have same distribution and same split.
      TYPE(dbcsr_tas_type), INTENT(IN) :: mat_a, mat_b
      INTEGER, INTENT(IN) :: split_rc_a, split_rc_b
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      LOGICAL :: dist_compatible

      INTEGER :: res, same_local_rowcols, split_check
      TYPE(dbcsr_tas_split_info) :: info_a, info_b
      INTEGER :: unit_nr_prv, numproc
      INTEGER, DIMENSION(2) :: pdims_a, pdims_b, pcoord_a, pcoord_b

      unit_nr_prv = prep_output_unit(unit_nr)

      dist_compatible = .FALSE.

      ! Both matrices must be split along the requested dimension.
      info_a = dbcsr_tas_info(mat_a)
      info_b = dbcsr_tas_info(mat_b)
      CALL dbcsr_tas_get_split_info(info_a, split_rowcol=split_check)
      IF (split_check /= split_rc_a) RETURN
      CALL dbcsr_tas_get_split_info(info_b, split_rowcol=split_check)
      IF (split_check /= split_rc_b) RETURN

      ! check if communicators are equivalent (global process grid and subgrids)
      ! Note: mpi_comm_compare is not sufficient since this does not compare associated Cartesian grids.
      ! It's sufficient to check dimensions of global grid and rank order of subgrids.
      CALL mp_environ(numproc, pdims_a, pcoord_a, info_a%mp_comm)
      CALL mp_environ(numproc, pdims_b, pcoord_b, info_b%mp_comm)
      IF (.NOT. array_eq(pdims_a, pdims_b)) THEN
         RETURN
      ENDIF

      CALL mp_comm_compare(info_a%mp_comm_group, info_b%mp_comm_group, res)
      IF (res .GT. 1) THEN
         RETURN
      ENDIF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, *) "mp comm compatible"
      ENDIF

      IF (mat_a%dist%info%split_rowcol == mat_b%dist%info%split_rowcol) THEN

         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, *) "split compatible"
         ENDIF

         ! Collective check that the locally held rows/columns agree on all ranks.
         same_local_rowcols = MERGE(1, 0, array_eq(mat_a%dist%local_rowcols, mat_b%dist%local_rowcols))
         CALL mp_sum(same_local_rowcols, info_a%mp_comm)

         IF (same_local_rowcols == numproc) THEN
            IF (unit_nr_prv > 0) THEN
               WRITE (unit_nr_prv, *) "local rowcols compatible"
            ENDIF
            dist_compatible = .TRUE.
         ELSE
            IF (unit_nr_prv > 0) THEN
               WRITE (unit_nr_prv, *) "local rowcols A", mat_a%dist%local_rowcols
               WRITE (unit_nr_prv, *) "local rowcols B", mat_b%dist%local_rowcols
            ENDIF
         ENDIF
      ENDIF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, *) "is compatible?", dist_compatible
      ENDIF

   END FUNCTION

   SUBROUTINE reshape_mm_template(template, matrix_in, matrix_out, trans, split_rc, nodata, move_data)
      !! Reshape matrix_in s.t. it has same process grid, distribution and split as template
      TYPE(dbcsr_tas_type), INTENT(IN) :: template
      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix_in
      TYPE(dbcsr_tas_type), INTENT(OUT) :: matrix_out
      CHARACTER(LEN=1), INTENT(INOUT) :: trans
      INTEGER, INTENT(IN) :: split_rc
      LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
      CLASS(dbcsr_tas_distribution), ALLOCATABLE :: row_dist, col_dist

      TYPE(dbcsr_tas_distribution_type) :: dist_new
      TYPE(dbcsr_tas_split_info) :: info_template, info_matrix
      INTEGER :: mp_comm, dim_split_template, dim_split_matrix, &
                 numnodes, handle
      INTEGER, DIMENSION(2) :: pcoord, pdims
      LOGICAL :: nodata_prv, transposed
      CHARACTER(LEN=*), PARAMETER :: routineN = 'reshape_mm_template', &
                                     routineP = moduleN//':'//routineN

      CALL timeset(routineN, handle)

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      ENDIF

      info_template = dbcsr_tas_info(template)
      info_matrix = dbcsr_tas_info(matrix_in)

      dim_split_template = info_template%split_rowcol
      dim_split_matrix = split_rc

      ! If template and matrix are split along different dimensions, matrix_in is
      ! stored transposed on the template grid and the multiplication transpose
      ! flag is inverted accordingly.
      transposed = dim_split_template .NE. dim_split_matrix
      IF (transposed) THEN
         SELECT CASE (trans)
         CASE (dbcsr_transpose)
            trans = dbcsr_no_transpose
         CASE (dbcsr_no_transpose)
            trans = dbcsr_transpose
         END SELECT
      ENDIF

      CALL mp_environ(numnodes, pdims, pcoord, info_template%mp_comm)

      ! Reuse the template's distribution along its split dimension; derive a
      ! default distribution for the other dimension from matrix_in's block sizes.
      SELECT CASE (dim_split_template)
      CASE (1)
         IF (.NOT. transposed) THEN
            ALLOCATE (row_dist, source=template%dist%row_dist)
            ALLOCATE (col_dist, source=dbcsr_tas_dist_arb_default(pdims(2), matrix_in%nblkcols, matrix_in%col_blk_size))
         ELSE
            ALLOCATE (row_dist, source=template%dist%row_dist)
            ALLOCATE (col_dist, source=dbcsr_tas_dist_arb_default(pdims(2), matrix_in%nblkrows, matrix_in%row_blk_size))
         ENDIF
      CASE (2)
         IF (.NOT. transposed) THEN
            ALLOCATE (row_dist, source=dbcsr_tas_dist_arb_default(pdims(1), matrix_in%nblkrows, matrix_in%row_blk_size))
            ALLOCATE (col_dist, source=template%dist%col_dist)
         ELSE
            ALLOCATE (row_dist, source=dbcsr_tas_dist_arb_default(pdims(1), matrix_in%nblkcols, matrix_in%col_blk_size))
            ALLOCATE (col_dist, source=template%dist%col_dist)
         ENDIF
      END SELECT

      CALL dbcsr_tas_get_split_info(info_template, mp_comm=mp_comm)
      CALL dbcsr_tas_distribution_new(dist_new, mp_comm, row_dist, col_dist, split_info=info_template)
      IF (.NOT. transposed) THEN
         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist_new, dbcsr_tas_get_data_type(matrix_in), &
                               matrix_in%row_blk_size, matrix_in%col_blk_size, own_dist=.TRUE.)
      ELSE
         CALL dbcsr_tas_create(matrix_out, matrix_in%matrix%name, dist_new, dbcsr_tas_get_data_type(matrix_in), &
                               matrix_in%col_blk_size, matrix_in%row_blk_size, own_dist=.TRUE.)
      ENDIF

      IF (.NOT. nodata_prv) CALL dbcsr_tas_reshape(matrix_in, matrix_out, transposed=transposed, move_data=move_data)

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE dbcsr_tas_result_index(transa, transb, transc, matrix_a, matrix_b, matrix_c, filter_eps, &
                                     unit_nr, blk_ind, nze, retain_sparsity)
      !! Estimate sparsity pattern of C resulting from A x B = C by multiplying the block norms of A and B
      !! Same dummy arguments as dbcsr_tas_multiply
      CHARACTER(LEN=1), INTENT(IN) :: transa, transb, transc
      TYPE(dbcsr_tas_type), INTENT(INOUT), TARGET :: matrix_a, matrix_b, matrix_c
      TYPE(dbcsr_tas_type), POINTER :: matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm
      REAL(KIND=real_8), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE, INTENT(OUT), OPTIONAL :: blk_ind
      LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity
      INTEGER(int_8), INTENT(OUT), OPTIONAL :: nze

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_tas_result_index', &
                                     routineP = moduleN//':'//routineN
      LOGICAL :: retain_sparsity_prv
      INTEGER :: bn, row_size, col_size, handle, iblk, mp_comm, nblk
      INTEGER(int_8) :: row, col
      TYPE(dbcsr_tas_iterator) :: iter

      CALL timeset(routineN, handle)

      IF (PRESENT(retain_sparsity)) THEN
         retain_sparsity_prv = retain_sparsity
      ELSE
         retain_sparsity_prv = .FALSE.
      ENDIF

      IF (.NOT. retain_sparsity_prv) THEN
         ! Build 1x1-block norm matrices of A, B and of C (C without data) and
         ! multiply them to obtain the sparsity pattern of the result.
         ALLOCATE (matrix_a_bnorm, matrix_b_bnorm, matrix_c_bnorm)
         CALL create_block_norms_matrix(matrix_a, matrix_a_bnorm)
         CALL create_block_norms_matrix(matrix_b, matrix_b_bnorm)
         CALL create_block_norms_matrix(matrix_c, matrix_c_bnorm, nodata=.TRUE.)

         CALL dbcsr_tas_multiply(transa, transb, transc, dbcsr_scalar(1.0_real_8), matrix_a_bnorm, &
                                 matrix_b_bnorm, dbcsr_scalar(0.0_real_8), matrix_c_bnorm, &
                                 filter_eps=filter_eps, move_data_a=.TRUE., move_data_b=.TRUE., &
                                 simple_split=.TRUE., unit_nr=unit_nr)
         CALL dbcsr_tas_destroy(matrix_a_bnorm)
         CALL dbcsr_tas_destroy(matrix_b_bnorm)

         DEALLOCATE (matrix_a_bnorm, matrix_b_bnorm)
      ELSE
         ! Use the existing sparsity pattern of C directly.
         matrix_c_bnorm => matrix_c
      ENDIF

      nblk = dbcsr_tas_get_num_blocks(matrix_c_bnorm)
      IF (PRESENT(blk_ind)) ALLOCATE (blk_ind(nblk, 2))

      ! Collect local block indices and (optionally) count non-zero elements.
      CALL dbcsr_tas_iterator_start(iter, matrix_c_bnorm)
      IF (PRESENT(nze)) nze = 0
      DO iblk = 1, nblk
         CALL dbcsr_tas_iterator_next_block(iter, row, col, bn)
         row_size = matrix_c%row_blk_size%data(row)
         col_size = matrix_c%col_blk_size%data(col)
         IF (PRESENT(nze)) nze = nze + row_size*col_size
         IF (PRESENT(blk_ind)) blk_ind(iblk, :) = [row, col]
      ENDDO
      CALL dbcsr_tas_iterator_stop(iter)

      IF (PRESENT(nze)) THEN
         ! Sum the per-rank element counts over all processes.
         CALL dbcsr_tas_get_split_info(dbcsr_tas_info(matrix_a), mp_comm=mp_comm)
         CALL mp_sum(nze, mp_comm)
      ENDIF

      IF (.NOT. retain_sparsity_prv) THEN
         CALL dbcsr_tas_destroy(matrix_c_bnorm)
         DEALLOCATE (matrix_c_bnorm)
      ENDIF

      CALL timestop(handle)

   END SUBROUTINE

   FUNCTION split_factor_estimate(max_mm_dim, nze_a, nze_b, nze_c, numnodes) RESULT(nsplit)
      !! Estimate optimal split factor for AxB=C from occupancies (number of non-zero elements)
      !! This estimate is based on the minimization of communication volume whereby
      !! the communication of CARMA n-split step and CANNON-multiplication of submatrices are
      !! considered.
      !! \result estimated split factor

      INTEGER, INTENT(IN) :: max_mm_dim
      INTEGER(KIND=int_8), INTENT(IN) :: nze_a, nze_b, nze_c
         !! number of non-zeroes in A
         !! number of non-zeroes in B
         !! number of non-zeroes in C
      INTEGER, INTENT(IN) :: numnodes
         !! number of MPI ranks
      INTEGER :: nsplit, nsplit_comm, nsplit_memory
      INTEGER(KIND=int_8) :: max_nze, min_nze

      ! min_nze: occupancy of the matrix spanned by the two small dimensions;
      ! max_nze: largest occupancy of the remaining two matrices.
      ! The MAX(..., 1) guards protect against division by zero for empty matrices.
      SELECT CASE (max_mm_dim)
      CASE (1)
         min_nze = MAX(nze_b, 1_int_8)
         max_nze = MAX(MAXVAL([nze_a, nze_c]), 1_int_8)
      CASE (2)
         min_nze = MAX(nze_c, 1_int_8)
         max_nze = MAX(MAXVAL([nze_a, nze_b]), 1_int_8)
      CASE (3)
         min_nze = MAX(nze_a, 1_int_8)
         max_nze = MAX(MAXVAL([nze_b, nze_c]), 1_int_8)
      CASE DEFAULT
         DBCSR_ABORT("")
      END SELECT

      nsplit_comm = NINT((REAL(nze_a + nze_b, real_8)/(2*min_nze))**(2._real_8/3)*REAL(numnodes, real_8)**(1._real_8/3))
      IF (nsplit_comm == 0) nsplit_comm = 1

      ! nsplit_memory protects against excess memory usage
      ! actual split factor may be up to default_nsplit_accept_ratio_1 times larger, so the largest nsplit
      ! that fits into memory used by A or B needs to be divided by this factor
      nsplit_memory = CEILING(REAL((max_nze - 1)/min_nze + 1, real_8)/default_nsplit_accept_ratio_1)

      nsplit = MIN(nsplit_comm, nsplit_memory)

   END FUNCTION

   SUBROUTINE create_block_norms_matrix(matrix_in, matrix_out, nodata)
      !! Create a matrix with block sizes one that contains the block norms of matrix_in
      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix_in
      TYPE(dbcsr_tas_type), INTENT(OUT) :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: nodata
      TYPE(dbcsr_tas_blk_size_one) :: row_blk_size, col_blk_size
      TYPE(dbcsr_tas_iterator) :: iter
      INTEGER(KIND=int_8) :: row, column, nblkrows, nblkcols
      CHARACTER(len=default_string_length) :: name
      INTEGER :: data_type

#:for dparam, dtype, dsuffix in dtype_float_list
      ${dtype}$, DIMENSION(:, :), POINTER :: block_get_${dsuffix}$
      ${dtype}$, DIMENSION(:, :), POINTER :: block_put_${dsuffix}$
#:endfor
      LOGICAL :: tr, nodata_prv, found

      DBCSR_ASSERT(matrix_in%valid)

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      ENDIF

      CALL dbcsr_tas_get_info(matrix_in, data_type=data_type, name=name, &
                              nblkrows_total=nblkrows, nblkcols_total=nblkcols)

      row_blk_size = dbcsr_tas_blk_size_one(nblkrows)
      col_blk_size = dbcsr_tas_blk_size_one(nblkcols)

      ! not sure if assumption that same distribution can be taken still holds
      CALL dbcsr_tas_create(matrix_out, name, matrix_in%dist, &
                            data_type, &
                            row_blk_size, col_blk_size)

      IF (.NOT. nodata_prv) THEN
         ! Same sparsity pattern as matrix_in, with one element per block holding
         ! the Frobenius norm of the corresponding input block.
         CALL dbcsr_tas_reserve_blocks(matrix_in, matrix_out)

         CALL dbcsr_tas_iterator_start(iter, matrix_in)

         DO WHILE (dbcsr_tas_iterator_blocks_left(iter))

#:for dparam, dtype, dsuffix in dtype_float_list
            IF (data_type == ${dparam}$) THEN
               CALL dbcsr_tas_iterator_next_block(iter, row, column, block_get_${dsuffix}$, tr)
               CALL dbcsr_tas_get_block_p(matrix_out, row, column, block_put_${dsuffix}$, tr, found)
               DBCSR_ASSERT(found)
               block_put_${dsuffix}$ (1, 1) = SQRT(SUM(block_get_${dsuffix}$**2)) ! norm2 works only for real
            ENDIF
#:endfor
         ENDDO
         CALL dbcsr_tas_iterator_stop(iter)
      ENDIF

   END SUBROUTINE

   SUBROUTINE convert_to_new_pgrid(mp_comm_cart, matrix_in, matrix_out, move_data, nodata, optimize_pgrid)
      !! Convert a DBCSR matrix to a new process grid

      INTEGER, INTENT(IN) :: mp_comm_cart
         !! new process grid
      TYPE(dbcsr_type), INTENT(INOUT), TARGET :: matrix_in
      TYPE(dbcsr_type), INTENT(OUT), POINTER :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: move_data, nodata
         !! memory optimization: move data such that matrix_in is empty on return.
         !! Data of matrix_in should not be copied to matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_pgrid
         !! Whether to change process grid

      INTEGER :: &
         nbrows, nbcols, data_type, nproc, handle
      INTEGER, DIMENSION(2) :: pdims, pcoord
      INTEGER, DIMENSION(:), POINTER, CONTIGUOUS :: row_dist, col_dist, rbsize, rcsize
      TYPE(dbcsr_distribution_obj) :: dist, dist_old
      TYPE(dbcsr_mp_obj) :: mp_obj
      CHARACTER(len=default_string_length) :: name
      LOGICAL :: nodata_prv, optimize_pgrid_prv
      CHARACTER(LEN=*), PARAMETER :: routineN = 'convert_to_new_pgrid', &
                                     routineP = moduleN//':'//routineN

      NULLIFY (row_dist, col_dist, rbsize, rcsize)

      IF (PRESENT(optimize_pgrid)) THEN
         optimize_pgrid_prv = optimize_pgrid
      ELSE
         optimize_pgrid_prv = .TRUE.
      ENDIF

      IF (.NOT. optimize_pgrid_prv) THEN
         ! Keep the original grid: just alias the input matrix.
         matrix_out => matrix_in
         RETURN
      ENDIF

      CALL timeset(routineN, handle)

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      ENDIF

      ALLOCATE (matrix_out)

      CALL dbcsr_get_info(matrix_in, nblkrows_total=nbrows, nblkcols_total=nbcols, &
                          row_blk_size=rbsize, col_blk_size=rcsize, &
                          data_type=data_type, distribution=dist_old, name=name)
      CALL mp_environ(nproc, pdims, pcoord, mp_comm_cart)

      ! Build default distribution vectors on the new grid; the arrays are handed
      ! over to the distribution (reuse_arrays) and must not be freed here.
      ALLOCATE (row_dist(nbrows), col_dist(nbcols))
      CALL dbcsr_tas_default_distvec(nbrows, pdims(1), rbsize, row_dist)
      CALL dbcsr_tas_default_distvec(nbcols, pdims(2), rcsize, col_dist)

      mp_obj = dbcsr_mp_environ(mp_comm_cart)
      CALL dbcsr_distribution_new(dist, mp_obj, row_dist, col_dist, reuse_arrays=.TRUE.)
      CALL dbcsr_mp_release(mp_obj)

      CALL dbcsr_create(matrix_out, name, dist, dbcsr_type_no_symmetry, rbsize, rcsize, data_type=data_type)
      CALL dbcsr_distribution_release(dist)

      IF (.NOT. nodata_prv) THEN
         CALL dbcsr_redistribute(matrix_in, matrix_out)
         IF (PRESENT(move_data)) THEN
            ! Data now lives in matrix_out; drop it from matrix_in if requested.
            IF (move_data) CALL dbcsr_clear(matrix_in)
         ENDIF
      ENDIF

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_tas_batched_mm_init(matrix)
      !! Enable batched multiplication mode for the given matrix by allocating
      !! the storage that accumulates intermediate results.
      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix
      matrix%do_batched = 1
      ALLOCATE (matrix%mm_storage)
      matrix%mm_storage%batched_out = .FALSE.
   END SUBROUTINE

   SUBROUTINE dbcsr_tas_batched_mm_finalize(matrix)
      !! Finalize batched multiplication mode: flush accumulated result data back
      !! into the matrix and release the batched storage.
      TYPE(dbcsr_tas_type), INTENT(INOUT) :: matrix

      ! No-op unless the matrix is in batched mode.
      IF (matrix%do_batched == 0) RETURN
      ASSOCIATE (storage => matrix%mm_storage)
         IF (storage%batched_out) THEN
            ! Merge replicated partial results, rescale the target by the stored
            ! beta, and add the accumulated batched result back (with summation).
            CALL dbcsr_tas_merge(storage%store_batched%matrix, storage%store_batched_repl, move_data=.TRUE.)
            CALL dbcsr_scale(matrix%matrix, storage%batched_beta)
            CALL dbcsr_tas_reshape(storage%store_batched, matrix, summation=.TRUE., &
                                   transposed=storage%batched_trans, move_data=.TRUE.)
            CALL dbcsr_tas_destroy(storage%store_batched)
            DEALLOCATE (storage%store_batched)
         ENDIF

         ! Release the replicated intermediate, if any was created.
         IF (ASSOCIATED(storage%store_batched_repl)) THEN
            CALL dbcsr_tas_destroy(storage%store_batched_repl)
            DEALLOCATE (storage%store_batched_repl)
         ENDIF

         storage%batched_out = .FALSE.
      END ASSOCIATE

      ! Back to normal (non-batched) multiplication mode.
      DEALLOCATE (matrix%mm_storage)
      matrix%do_batched = 0

   END SUBROUTINE

END MODULE