1!--------------------------------------------------------------------------------------------------! 2! Copyright (C) by the DBCSR developers group - All rights reserved ! 3! This file is part of the DBCSR library. ! 4! ! 5! For information on the license, see the LICENSE file. ! 6! For further information please visit https://dbcsr.cp2k.org ! 7! SPDX-License-Identifier: GPL-2.0+ ! 8!--------------------------------------------------------------------------------------------------! 9 10MODULE dbcsr_tensor 11 !! DBCSR tensor framework for block-sparse tensor contraction. 12 !! Representation of n-rank tensors as DBCSR tall-and-skinny matrices. 13 !! Support for arbitrary redistribution between different representations. 14 !! Support for arbitrary tensor contractions 15 !! \todo implement checks and error messages 16 17#:include "dbcsr_tensor.fypp" 18#:set maxdim = maxrank 19#:set ndims = range(2,maxdim+1) 20 21 USE dbcsr_allocate_wrap, ONLY: & 22 allocate_any 23 USE dbcsr_array_list_methods, ONLY: & 24 get_arrays, reorder_arrays, get_ith_array, array_list, array_sublist, check_equal, array_eq_i, & 25 create_array_list, destroy_array_list, sizes_of_arrays 26 USE dbcsr_api, ONLY: & 27 dbcsr_type, dbcsr_iterator_type, dbcsr_iterator_blocks_left, & 28 dbcsr_iterator_next_block, dbcsr_iterator_start, dbcsr_iterator_stop, & 29 dbcsr_transpose, dbcsr_no_transpose, dbcsr_scalar, dbcsr_put_block, & 30 ${uselist(dtype_float_param)}$, dbcsr_clear, & 31 dbcsr_release, dbcsr_desymmetrize, dbcsr_has_symmetry 32 USE dbcsr_tas_types, ONLY: & 33 dbcsr_tas_split_info 34 USE dbcsr_tas_base, ONLY: & 35 dbcsr_tas_copy, dbcsr_tas_finalize, dbcsr_tas_get_data_type, dbcsr_tas_get_info, dbcsr_tas_info 36 USE dbcsr_tas_mm, ONLY: & 37 dbcsr_tas_multiply, dbcsr_tas_batched_mm_init, dbcsr_tas_batched_mm_finalize, dbcsr_tas_result_index, & 38 dbcsr_tas_batched_mm_complete, dbcsr_tas_set_batched_state 39 USE dbcsr_tensor_block, ONLY: & 40 dbcsr_t_iterator_type, dbcsr_t_get_block, 
dbcsr_t_put_block, dbcsr_t_iterator_start, & 41 dbcsr_t_iterator_blocks_left, dbcsr_t_iterator_stop, dbcsr_t_iterator_next_block, & 42 ndims_iterator, dbcsr_t_reserve_blocks, block_nd, destroy_block 43 USE dbcsr_tensor_index, ONLY: & 44 dbcsr_t_get_mapping_info, nd_to_2d_mapping, dbcsr_t_inverse_order, permute_index, get_nd_indices_tensor, & 45 ndims_mapping_row, ndims_mapping_column, ndims_mapping 46 USE dbcsr_tensor_types, ONLY: & 47 dbcsr_t_create, dbcsr_t_get_data_type, dbcsr_t_type, ndims_tensor, dims_tensor, & 48 dbcsr_t_distribution_type, dbcsr_t_distribution, dbcsr_t_nd_mp_comm, dbcsr_t_destroy, & 49 dbcsr_t_distribution_destroy, dbcsr_t_distribution_new_expert, dbcsr_t_get_stored_coordinates, & 50 blk_dims_tensor, dbcsr_t_hold, dbcsr_t_pgrid_type, mp_environ_pgrid, dbcsr_t_filter, & 51 dbcsr_t_clear, dbcsr_t_finalize, dbcsr_t_get_num_blocks, dbcsr_t_scale, & 52 dbcsr_t_get_num_blocks_total, dbcsr_t_get_info, ndims_matrix_row, ndims_matrix_column, & 53 dbcsr_t_max_nblks_local, dbcsr_t_default_distvec, dbcsr_t_contraction_storage, dbcsr_t_nblks_total, & 54 dbcsr_t_distribution_new, dbcsr_t_copy_contraction_storage, dbcsr_t_pgrid_destroy 55 USE dbcsr_kinds, ONLY: & 56 ${uselist(dtype_float_prec)}$, default_string_length, int_8, dp 57 USE dbcsr_mpiwrap, ONLY: & 58 mp_environ, mp_max, mp_sum, mp_comm_free, mp_cart_create, mp_sync 59 USE dbcsr_toollib, ONLY: & 60 sort 61 USE dbcsr_tensor_reshape, ONLY: & 62 dbcsr_t_reshape 63 USE dbcsr_tas_split, ONLY: & 64 dbcsr_tas_mp_comm, rowsplit, colsplit, dbcsr_tas_info_hold, dbcsr_tas_release_info, default_nsplit_accept_ratio, & 65 default_pdims_accept_ratio, dbcsr_tas_create_split 66 USE dbcsr_data_types, ONLY: & 67 dbcsr_scalar_type 68 USE dbcsr_tensor_split, ONLY: & 69 dbcsr_t_split_copyback, dbcsr_t_make_compatible_blocks, dbcsr_t_crop 70 USE dbcsr_tensor_io, ONLY: & 71 dbcsr_t_write_tensor_info, dbcsr_t_write_tensor_dist, prep_output_unit, dbcsr_t_write_split_info 72 USE dbcsr_dist_operations, ONLY: & 73 checker_tr 
   USE dbcsr_toollib, ONLY: &
      swap

#include "base/dbcsr_base_uses.f90"

   IMPLICIT NONE
   PRIVATE
   CHARACTER(len=*), PARAMETER, PRIVATE :: moduleN = 'dbcsr_tensor'

   PUBLIC :: &
      dbcsr_t_contract, &
      dbcsr_t_copy, &
      dbcsr_t_get_block, &
      dbcsr_t_get_stored_coordinates, &
      dbcsr_t_inverse_order, &
      dbcsr_t_iterator_blocks_left, &
      dbcsr_t_iterator_next_block, &
      dbcsr_t_iterator_start, &
      dbcsr_t_iterator_stop, &
      dbcsr_t_iterator_type, &
      dbcsr_t_put_block, &
      dbcsr_t_reserve_blocks, &
      dbcsr_t_copy_matrix_to_tensor, &
      dbcsr_t_copy_tensor_to_matrix, &
      dbcsr_t_contract_index, &
      dbcsr_t_batched_contract_init, &
      dbcsr_t_batched_contract_finalize

CONTAINS

   SUBROUTINE dbcsr_t_copy(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      !! Copy tensor data.
      !! Redistributes tensor data according to distributions of target and source tensor.
      !! Permutes tensor index according to `order` argument (if present).
      !! Source and target tensor formats are arbitrary as long as the following requirements are met:
      !! * source and target tensors have the same rank and the same sizes in each dimension in terms
      !!   of tensor elements (block sizes don't need to be the same).
      !!   If `order` argument is present, sizes must match after index permutation.
      !! OR
      !! * target tensor is not yet created, in this case an exact copy of source tensor is returned.

      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
         !! Source
         !! Target
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: order
         !! Permutation of target tensor index. Exact same convention as order argument of RESHAPE intrinsic
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
         !! summation: add to (instead of overwrite) existing data in tensor_out
         !! move_data: memory optimization, tensor_in may be emptied on return
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: bounds
         !! crop tensor data: start and end index for each tensor dimension
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
         !! output unit for logging (passed through to the expert routine)
      INTEGER :: handle

      ! synchronize before starting the "dbcsr_t_total" timer
      ! (NOTE(review): presumably so the timer does not absorb load imbalance
      ! from preceding work -- confirm)
      CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
      CALL timeset("dbcsr_t_total", handle)

      ! make sure that it is safe to use dbcsr_t_copy during a batched contraction
      CALL dbcsr_tas_batched_mm_complete(tensor_in%matrix_rep, warn=.TRUE.)
      CALL dbcsr_tas_batched_mm_complete(tensor_out%matrix_rep, warn=.TRUE.)

      ! all actual copy/redistribution work is delegated to the expert routine
      CALL dbcsr_t_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      CALL mp_sync(tensor_in%pgrid%mp_comm_2d)
      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_copy_expert(tensor_in, tensor_out, order, summation, bounds, move_data, unit_nr)
      !! expert routine for copying a tensor. For internal use only.
      !! Builds (at most) three intermediate source tensors -- cropped, block-compatible,
      !! index-permuted -- and then dispatches to the cheapest copy path available.
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_in, tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: order
      LOGICAL, INTENT(IN), OPTIONAL :: summation, move_data
      INTEGER, DIMENSION(2, ndims_tensor(tensor_in)), &
         INTENT(IN), OPTIONAL                        :: bounds
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr

      ! intermediate chain: in_tmp_1 (cropped) -> in_tmp_2 (block-compatible)
      ! -> in_tmp_3 (index-permuted); out_tmp_1 is the block-compatible target
      TYPE(dbcsr_t_type), POINTER :: in_tmp_1, in_tmp_2, &
                                     in_tmp_3, out_tmp_1
      INTEGER :: handle, unit_nr_prv
      INTEGER, DIMENSION(:), ALLOCATABLE :: map1_in_1, map1_in_2, map2_in_1, map2_in_2

      ! NOTE(review): routineN reuses the public name 'dbcsr_t_copy', so timings of
      ! this expert routine are reported under the public entry point -- confirm intended
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy'
      LOGICAL :: dist_compatible_tas, dist_compatible_tensor, &
                 summation_prv, new_in_1, new_in_2, &
                 new_in_3, new_out_1, block_compatible, &
                 move_prv
      TYPE(array_list) :: blk_sizes_in

      CALL timeset(routineN, handle)

      DBCSR_ASSERT(tensor_out%valid)

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (PRESENT(move_data)) THEN
         move_prv = move_data
      ELSE
         move_prv = .FALSE.
      ENDIF

      ! new_* flags track which intermediates were freshly allocated here
      ! (and hence must be destroyed/deallocated before returning)
      dist_compatible_tas = .FALSE.
      dist_compatible_tensor = .FALSE.
      block_compatible = .FALSE.
      new_in_1 = .FALSE.
      new_in_2 = .FALSE.
      new_in_3 = .FALSE.
      new_out_1 = .FALSE.

      IF (PRESENT(summation)) THEN
         summation_prv = summation
      ELSE
         summation_prv = .FALSE.
      ENDIF

      ! step 1: optionally crop the source to the requested bounds;
      ! once data lives in a fresh intermediate it may always be moved (move_prv = .TRUE.)
      IF (PRESENT(bounds)) THEN
         ALLOCATE (in_tmp_1)
         CALL dbcsr_t_crop(tensor_in, in_tmp_1, bounds=bounds, move_data=move_prv)
         new_in_1 = .TRUE.
         move_prv = .TRUE.
      ELSE
         in_tmp_1 => tensor_in
      ENDIF

      ! step 2: compare block sizes of source and target (after permutation, if any)
      IF (PRESENT(order)) THEN
         CALL reorder_arrays(in_tmp_1%blk_sizes, blk_sizes_in, order=order)
         block_compatible = check_equal(blk_sizes_in, tensor_out%blk_sizes)
      ELSE
         block_compatible = check_equal(in_tmp_1%blk_sizes, tensor_out%blk_sizes)
      ENDIF

      ! step 3: if block sizes differ, split blocks on both sides to a common blocking
      IF (.NOT. block_compatible) THEN
         ALLOCATE (in_tmp_2, out_tmp_1)
         CALL dbcsr_t_make_compatible_blocks(in_tmp_1, tensor_out, in_tmp_2, out_tmp_1, order=order, &
                                             nodata2=.NOT. summation_prv, move_data=move_prv)
         new_in_2 = .TRUE.; new_out_1 = .TRUE.
         move_prv = .TRUE.
      ELSE
         in_tmp_2 => in_tmp_1
         out_tmp_1 => tensor_out
      ENDIF

      ! step 4: optionally permute the source index
      IF (PRESENT(order)) THEN
         ALLOCATE (in_tmp_3)
         CALL dbcsr_t_permute_index(in_tmp_2, in_tmp_3, order)
         new_in_3 = .TRUE.
      ELSE
         in_tmp_3 => in_tmp_2
      ENDIF

      ! retrieve the nd-index -> 2d-matrix mappings of (prepared) source and target
      ALLOCATE (map1_in_1(ndims_matrix_row(in_tmp_3)))
      ALLOCATE (map1_in_2(ndims_matrix_column(in_tmp_3)))
      CALL dbcsr_t_get_mapping_info(in_tmp_3%nd_index, map1_2d=map1_in_1, map2_2d=map1_in_2)

      ALLOCATE (map2_in_1(ndims_matrix_row(out_tmp_1)))
      ALLOCATE (map2_in_2(ndims_matrix_column(out_tmp_1)))
      CALL dbcsr_t_get_mapping_info(out_tmp_1%nd_index, map1_2d=map2_in_1, map2_2d=map2_in_2)

      ! identical row AND column maps -> candidate for direct TAS-matrix copy;
      ! identical concatenated maps only -> candidate for tensor copy without communication
      IF (.NOT. PRESENT(order)) THEN
         IF (array_eq_i(map1_in_1, map2_in_1) .AND. array_eq_i(map1_in_2, map2_in_2)) THEN
            dist_compatible_tas = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ELSEIF (array_eq_i([map1_in_1, map1_in_2], [map2_in_1, map2_in_2])) THEN
            dist_compatible_tensor = check_equal(in_tmp_3%nd_dist, out_tmp_1%nd_dist)
         ENDIF
      ENDIF

      ! dispatch: TAS copy (cheapest) / local tensor copy / general reshape (communicates)
      IF (dist_compatible_tas) THEN
         CALL dbcsr_tas_copy(out_tmp_1%matrix_rep, in_tmp_3%matrix_rep, summation)
         IF (move_prv) CALL dbcsr_t_clear(in_tmp_3)
      ELSEIF (dist_compatible_tensor) THEN
         CALL dbcsr_t_copy_nocomm(in_tmp_3, out_tmp_1, summation)
         IF (move_prv) CALL dbcsr_t_clear(in_tmp_3)
      ELSE
         CALL dbcsr_t_reshape(in_tmp_3, out_tmp_1, summation, move_data=move_prv)
      ENDIF

      ! clean up any intermediates that were allocated above
      IF (new_in_1) THEN
         CALL dbcsr_t_destroy(in_tmp_1)
         DEALLOCATE (in_tmp_1)
      ENDIF

      IF (new_in_2) THEN
         CALL dbcsr_t_destroy(in_tmp_2)
         DEALLOCATE (in_tmp_2)
      ENDIF

      IF (new_in_3) THEN
         CALL dbcsr_t_destroy(in_tmp_3)
         DEALLOCATE (in_tmp_3)
      ENDIF

      ! if the target was re-blocked, copy results back into the caller's tensor
      IF (new_out_1) THEN
         IF (unit_nr_prv /= 0) THEN
            CALL dbcsr_t_write_tensor_dist(out_tmp_1, unit_nr)
         ENDIF
         CALL dbcsr_t_split_copyback(out_tmp_1, tensor_out, summation)
         CALL dbcsr_t_destroy(out_tmp_1)
         DEALLOCATE (out_tmp_1)
      ENDIF

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE dbcsr_t_copy_nocomm(tensor_in, tensor_out, summation)
      !! copy without communication, requires that both tensors have same process grid and distribution

      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
         !! Whether to sum matrices b = a + b
      TYPE(dbcsr_t_iterator_type) :: iter
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: ind_nd
      INTEGER :: blk
      TYPE(block_nd) :: blk_data
      LOGICAL :: found

      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_nocomm'
      INTEGER :: handle

      CALL timeset(routineN, handle)
      DBCSR_ASSERT(tensor_out%valid)

      ! without summation the target is wiped first (plain overwrite semantics)
      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_t_clear(tensor_out)
      ELSE
         CALL dbcsr_t_clear(tensor_out)
      ENDIF

      CALL dbcsr_t_reserve_blocks(tensor_in, tensor_out)

      ! iterate over local blocks of the source and put each one into the target;
      ! valid because both tensors share process grid and distribution (see above)
      CALL dbcsr_t_iterator_start(iter, tensor_in)
      DO WHILE (dbcsr_t_iterator_blocks_left(iter))
         CALL dbcsr_t_iterator_next_block(iter, ind_nd, blk)
         CALL dbcsr_t_get_block(tensor_in, ind_nd, blk_data, found)
         DBCSR_ASSERT(found)
         CALL dbcsr_t_put_block(tensor_out, ind_nd, blk_data, summation=summation)
         CALL destroy_block(blk_data)
      ENDDO
      CALL dbcsr_t_iterator_stop(iter)

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_copy_matrix_to_tensor(matrix_in, tensor_out, summation)
      !! copy matrix to tensor.

      TYPE(dbcsr_type), TARGET, INTENT(IN) :: matrix_in
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
         !! tensor_out = tensor_out + matrix_in
      TYPE(dbcsr_type), POINTER :: matrix_in_desym

      INTEGER, DIMENSION(2) :: ind_2d
      REAL(KIND=real_8), ALLOCATABLE, DIMENSION(:, :) :: block_arr
      REAL(KIND=real_8), DIMENSION(:, :), POINTER :: block
      TYPE(dbcsr_iterator_type) :: iter
      LOGICAL :: tr

      INTEGER :: handle
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_matrix_to_tensor'

      CALL timeset(routineN, handle)
      DBCSR_ASSERT(tensor_out%valid)

      NULLIFY (block)

      ! tensors carry no symmetry, so a symmetric matrix is expanded to full storage first
      IF (dbcsr_has_symmetry(matrix_in)) THEN
         ALLOCATE (matrix_in_desym)
         CALL dbcsr_desymmetrize(matrix_in, matrix_in_desym)
      ELSE
         matrix_in_desym => matrix_in
      ENDIF

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_t_clear(tensor_out)
      ELSE
         CALL dbcsr_t_clear(tensor_out)
      ENDIF

      CALL dbcsr_t_reserve_blocks(matrix_in_desym, tensor_out)

      ! copy each local matrix block into the tensor via a freshly allocated buffer
      ! (NOTE(review): the iterator's transpose flag `tr` is ignored -- presumably
      ! desymmetrized blocks are delivered non-transposed; confirm)
      CALL dbcsr_iterator_start(iter, matrix_in_desym)
      DO WHILE (dbcsr_iterator_blocks_left(iter))
         CALL dbcsr_iterator_next_block(iter, ind_2d(1), ind_2d(2), block, tr)
         CALL allocate_any(block_arr, source=block)
         CALL dbcsr_t_put_block(tensor_out, ind_2d, SHAPE(block_arr), block_arr, summation=summation)
         DEALLOCATE (block_arr)
      ENDDO
      CALL dbcsr_iterator_stop(iter)

      ! release the temporary desymmetrized copy, if one was created
      IF (dbcsr_has_symmetry(matrix_in)) THEN
         CALL dbcsr_release(matrix_in_desym)
         DEALLOCATE (matrix_in_desym)
      ENDIF

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE dbcsr_t_copy_tensor_to_matrix(tensor_in, matrix_out, summation)
      !! copy tensor to matrix

      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_type), INTENT(INOUT) :: matrix_out
      LOGICAL, INTENT(IN), OPTIONAL :: summation
         !! matrix_out = matrix_out + tensor_in
      TYPE(dbcsr_t_iterator_type) :: iter
      INTEGER :: blk, handle
      INTEGER, DIMENSION(2) :: ind_2d
      REAL(KIND=real_8), DIMENSION(:, :), ALLOCATABLE :: block
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_copy_tensor_to_matrix'
      LOGICAL :: found

      CALL timeset(routineN, handle)

      IF (PRESENT(summation)) THEN
         IF (.NOT. summation) CALL dbcsr_clear(matrix_out)
      ELSE
         CALL dbcsr_clear(matrix_out)
      ENDIF

      CALL dbcsr_t_reserve_blocks(tensor_in, matrix_out)

      CALL dbcsr_t_iterator_start(iter, tensor_in)
      DO WHILE (dbcsr_t_iterator_blocks_left(iter))
         CALL dbcsr_t_iterator_next_block(iter, ind_2d, blk)
         ! for a symmetric target, skip blocks whose mirror image is stored instead
         IF (dbcsr_has_symmetry(matrix_out) .AND. checker_tr(ind_2d(1), ind_2d(2))) CYCLE

         CALL dbcsr_t_get_block(tensor_in, ind_2d, block, found)
         DBCSR_ASSERT(found)

         ! lower-triangle blocks of a symmetric matrix are stored transposed in the upper triangle
         IF (dbcsr_has_symmetry(matrix_out) .AND. ind_2d(1) > ind_2d(2)) THEN
            CALL dbcsr_put_block(matrix_out, ind_2d(2), ind_2d(1), TRANSPOSE(block), summation=summation)
         ELSE
            CALL dbcsr_put_block(matrix_out, ind_2d(1), ind_2d(2), block, summation=summation)
         ENDIF
         DEALLOCATE (block)
      ENDDO
      CALL dbcsr_t_iterator_stop(iter)

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE dbcsr_t_contract(alpha, tensor_1, tensor_2, beta, tensor_3, &
                               contract_1, notcontract_1, &
                               contract_2, notcontract_2, &
                               map_1, map_2, &
                               bounds_1, bounds_2, bounds_3, &
                               optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, &
                               filter_eps, flop, move_data, retain_sparsity, unit_nr, log_verbose)
      !! Contract tensors by multiplying matrix representations.
      !! tensor_3(map_1, map_2) := alpha * tensor_1(notcontract_1, contract_1)
      !! * tensor_2(contract_2, notcontract_2)
      !! + beta * tensor_3(map_1, map_2)
      !!
      !! @note
      !! note 1: block sizes of the corresponding indices need to be the same in all tensors.
      !!
      !! note 2: for best performance the tensors should have been created in matrix layouts
      !! compatible with the contraction, e.g. tensor_1 should have been created with either
      !! map1_2d == contract_1 and map2_2d == notcontract_1 or map1_2d == notcontract_1 and
      !! map2_2d == contract_1 (the same with tensor_2 and contract_2 / notcontract_2 and with
      !! tensor_3 and map_1 / map_2).
      !! Furthermore the two largest tensors involved in the contraction should map both to either
      !! tall or short matrices: the largest matrix dimension should be "on the same side"
      !! and should have identical distribution (which is always the case if the distributions were
      !! obtained with dbcsr_t_default_distvec).
      !!
      !! note 3: if the same tensor occurs in multiple contractions, a different tensor object should
      !! be created for each contraction and the data should be copied between the tensors by use of
      !! dbcsr_t_copy. If the same tensor object is used in multiple contractions, matrix layouts are
      !! not compatible for all contractions (see note 2).
      !!
      !! note 4: automatic optimizations are enabled by using the feature of batched contraction, see
      !! dbcsr_t_batched_contract_init, dbcsr_t_batched_contract_finalize. The arguments bounds_1,
      !! bounds_2, bounds_3 give the index ranges of the batches.
      !! @endnote

      TYPE(dbcsr_scalar_type), INTENT(IN) :: alpha
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_1
         !! first tensor (in)
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_2
         !! second tensor (in)
      TYPE(dbcsr_scalar_type), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
         !! indices of tensor_1 to contract
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
         !! indices of tensor_2 to contract (1:1 with contract_1)
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
         !! which indices of tensor_3 map to non-contracted indices of tensor_1 (1:1 with notcontract_1)
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
         !! which indices of tensor_3 map to non-contracted indices of tensor_2 (1:1 with notcontract_2)
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
         !! indices of tensor_1 not to contract
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
         !! indices of tensor_2 not to contract
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_3
         !! contracted tensor (out)
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
         !! bounds corresponding to contract_1 AKA contract_2: start and end index of an index range over
         !! which to contract. For use in batched contraction.
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
         !! bounds corresponding to notcontract_1: start and end index of an index range.
         !! For use in batched contraction.
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
         !! bounds corresponding to notcontract_2: start and end index of an index range.
         !! For use in batched contraction.
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
         !! Whether distribution should be optimized internally. In the current implementation this guarantees optimal parameters
         !! only for dense matrices.
      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_1
         !! Optionally return optimal process grid for tensor_1. This can be used to choose optimal process grids for subsequent
         !! tensor contractions with tensors of similar shape and sparsity. Under some conditions, pgrid_opt_1 can not be returned,
         !! in this case the pointer is not associated.
      TYPE(dbcsr_t_pgrid_type), INTENT(OUT), &
         POINTER, OPTIONAL :: pgrid_opt_2
         !! Optionally return optimal process grid for tensor_2.
496 TYPE(dbcsr_t_pgrid_type), INTENT(OUT), & 497 POINTER, OPTIONAL :: pgrid_opt_3 498 !! Optionally return optimal process grid for tensor_3. 499 REAL(KIND=real_8), INTENT(IN), OPTIONAL :: filter_eps 500 !! As in DBCSR mm 501 INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop 502 !! As in DBCSR mm 503 LOGICAL, INTENT(IN), OPTIONAL :: move_data 504 !! memory optimization: transfer data such that tensor_1 and tensor_2 are empty on return 505 LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity 506 !! enforce the sparsity pattern of the existing tensor_3; default is no 507 INTEGER, OPTIONAL, INTENT(IN) :: unit_nr 508 !! output unit for logging 509 !! set it to -1 on ranks that should not write (and any valid unit number on ranks that should write output) 510 !! if 0 on ALL ranks, no output is written 511 LOGICAL, INTENT(IN), OPTIONAL :: log_verbose 512 !! verbose logging (for testing only) 513 514 INTEGER :: handle 515 516 CALL mp_sync(tensor_1%pgrid%mp_comm_2d) 517 CALL timeset("dbcsr_t_total", handle) 518 CALL dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, & 519 contract_1, notcontract_1, & 520 contract_2, notcontract_2, & 521 map_1, map_2, & 522 bounds_1=bounds_1, & 523 bounds_2=bounds_2, & 524 bounds_3=bounds_3, & 525 optimize_dist=optimize_dist, & 526 pgrid_opt_1=pgrid_opt_1, & 527 pgrid_opt_2=pgrid_opt_2, & 528 pgrid_opt_3=pgrid_opt_3, & 529 filter_eps=filter_eps, & 530 flop=flop, & 531 move_data=move_data, & 532 retain_sparsity=retain_sparsity, & 533 unit_nr=unit_nr, & 534 log_verbose=log_verbose) 535 CALL mp_sync(tensor_1%pgrid%mp_comm_2d) 536 CALL timestop(handle) 537 538 END SUBROUTINE 539 540 SUBROUTINE dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, & 541 contract_1, notcontract_1, & 542 contract_2, notcontract_2, & 543 map_1, map_2, & 544 bounds_1, bounds_2, bounds_3, & 545 optimize_dist, pgrid_opt_1, pgrid_opt_2, pgrid_opt_3, & 546 filter_eps, flop, move_data, retain_sparsity, & 547 nblks_local, result_index, unit_nr, 
log_verbose) 548 !! expert routine for tensor contraction. For internal use only. 549 TYPE(dbcsr_scalar_type), INTENT(IN) :: alpha 550 TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_1 551 TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_2 552 TYPE(dbcsr_scalar_type), INTENT(IN) :: beta 553 INTEGER, DIMENSION(:), INTENT(IN) :: contract_1 554 INTEGER, DIMENSION(:), INTENT(IN) :: contract_2 555 INTEGER, DIMENSION(:), INTENT(IN) :: map_1 556 INTEGER, DIMENSION(:), INTENT(IN) :: map_2 557 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1 558 INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2 559 TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_3 560 INTEGER, DIMENSION(2, SIZE(contract_1)), & 561 INTENT(IN), OPTIONAL :: bounds_1 562 INTEGER, DIMENSION(2, SIZE(notcontract_1)), & 563 INTENT(IN), OPTIONAL :: bounds_2 564 INTEGER, DIMENSION(2, SIZE(notcontract_2)), & 565 INTENT(IN), OPTIONAL :: bounds_3 566 LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist 567 TYPE(dbcsr_t_pgrid_type), INTENT(OUT), & 568 POINTER, OPTIONAL :: pgrid_opt_1 569 TYPE(dbcsr_t_pgrid_type), INTENT(OUT), & 570 POINTER, OPTIONAL :: pgrid_opt_2 571 TYPE(dbcsr_t_pgrid_type), INTENT(OUT), & 572 POINTER, OPTIONAL :: pgrid_opt_3 573 REAL(KIND=real_8), INTENT(IN), OPTIONAL :: filter_eps 574 INTEGER(KIND=int_8), INTENT(OUT), OPTIONAL :: flop 575 LOGICAL, INTENT(IN), OPTIONAL :: move_data 576 LOGICAL, INTENT(IN), OPTIONAL :: retain_sparsity 577 INTEGER, INTENT(OUT), OPTIONAL :: nblks_local 578 !! number of local blocks on this MPI rank 579 INTEGER, DIMENSION(dbcsr_t_max_nblks_local(tensor_3), ndims_tensor(tensor_3)), & 580 OPTIONAL, INTENT(OUT) :: result_index 581 !! get indices of non-zero tensor blocks for tensor_3 without actually performing contraction 582 !! 
this is an estimate based on block norm multiplication 583 INTEGER, OPTIONAL, INTENT(IN) :: unit_nr 584 LOGICAL, INTENT(IN), OPTIONAL :: log_verbose 585 586 TYPE(dbcsr_t_type), POINTER :: tensor_contr_1, tensor_contr_2, tensor_contr_3 587 TYPE(dbcsr_t_type), TARGET :: tensor_algn_1, tensor_algn_2, tensor_algn_3 588 TYPE(dbcsr_t_type), POINTER :: tensor_crop_1, tensor_crop_2 589 TYPE(dbcsr_t_type), POINTER :: tensor_small, tensor_large 590 591 INTEGER(int_8), DIMENSION(:, :), ALLOCATABLE :: result_index_2d 592 LOGICAL :: assert_stmt, tensors_remapped 593 INTEGER :: data_type, max_mm_dim, max_tensor, mp_comm, & 594 iblk, nblk, unit_nr_prv, ref_tensor, mp_comm_opt, & 595 handle 596 INTEGER, DIMENSION(SIZE(contract_1)) :: contract_1_mod 597 INTEGER, DIMENSION(SIZE(notcontract_1)) :: notcontract_1_mod 598 INTEGER, DIMENSION(SIZE(contract_2)) :: contract_2_mod 599 INTEGER, DIMENSION(SIZE(notcontract_2)) :: notcontract_2_mod 600 INTEGER, DIMENSION(SIZE(map_1)) :: map_1_mod 601 INTEGER, DIMENSION(SIZE(map_2)) :: map_2_mod 602 CHARACTER(LEN=1) :: trans_1, trans_2, trans_3 603 LOGICAL :: new_1, new_2, new_3, move_data_1, move_data_2 604 INTEGER :: ndims1, ndims2, ndims3 605 INTEGER :: occ_1, occ_2 606 INTEGER, DIMENSION(:), ALLOCATABLE :: dims1, dims2, dims3 607 608 CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_contract' 609 CHARACTER(LEN=1), DIMENSION(:), ALLOCATABLE :: indchar1, indchar2, indchar3, indchar1_mod, & 610 indchar2_mod, indchar3_mod 611 CHARACTER(LEN=1), DIMENSION(15) :: alph = & 612 ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n', 'o'] 613 INTEGER, DIMENSION(2, ndims_tensor(tensor_1)) :: bounds_t1 614 INTEGER, DIMENSION(2, ndims_tensor(tensor_2)) :: bounds_t2 615 LOGICAL :: do_crop_1, do_crop_2, do_write_3, nodata_3, do_batched, pgrid_changed, & 616 pgrid_changed_any, do_change_pgrid(2) 617 TYPE(dbcsr_tas_split_info) :: split_opt, split, split_opt_avg 618 INTEGER, DIMENSION(2) :: pdims_2d_opt, pdims_2d, pcoord_2d, pdims_sub, 
pdims_sub_opt 619 LOGICAL, DIMENSION(2) :: periods_2d 620 REAL(real_8) :: pdim_ratio, pdim_ratio_opt 621 622 NULLIFY (tensor_contr_1, tensor_contr_2, tensor_contr_3, tensor_crop_1, tensor_crop_2, & 623 tensor_small) 624 625 CALL timeset(routineN, handle) 626 627 DBCSR_ASSERT(tensor_1%valid) 628 DBCSR_ASSERT(tensor_2%valid) 629 DBCSR_ASSERT(tensor_3%valid) 630 631 assert_stmt = SIZE(contract_1) .EQ. SIZE(contract_2) 632 DBCSR_ASSERT(assert_stmt) 633 634 assert_stmt = SIZE(map_1) .EQ. SIZE(notcontract_1) 635 DBCSR_ASSERT(assert_stmt) 636 637 assert_stmt = SIZE(map_2) .EQ. SIZE(notcontract_2) 638 DBCSR_ASSERT(assert_stmt) 639 640 assert_stmt = SIZE(notcontract_1) + SIZE(contract_1) .EQ. ndims_tensor(tensor_1) 641 DBCSR_ASSERT(assert_stmt) 642 643 assert_stmt = SIZE(notcontract_2) + SIZE(contract_2) .EQ. ndims_tensor(tensor_2) 644 DBCSR_ASSERT(assert_stmt) 645 646 assert_stmt = SIZE(map_1) + SIZE(map_2) .EQ. ndims_tensor(tensor_3) 647 DBCSR_ASSERT(assert_stmt) 648 649 assert_stmt = dbcsr_t_get_data_type(tensor_1) .EQ. dbcsr_t_get_data_type(tensor_2) 650 DBCSR_ASSERT(assert_stmt) 651 652 unit_nr_prv = prep_output_unit(unit_nr) 653 654 IF (PRESENT(flop)) flop = 0 655 IF (PRESENT(result_index)) result_index = 0 656 IF (PRESENT(nblks_local)) nblks_local = 0 657 658 IF (PRESENT(move_data)) THEN 659 move_data_1 = move_data 660 move_data_2 = move_data 661 ELSE 662 move_data_1 = .FALSE. 663 move_data_2 = .FALSE. 664 ENDIF 665 666 nodata_3 = .TRUE. 667 IF (PRESENT(retain_sparsity)) THEN 668 IF (retain_sparsity) nodata_3 = .FALSE. 669 ENDIF 670 671 CALL dbcsr_t_map_bounds_to_tensors(tensor_1, tensor_2, & 672 contract_1, notcontract_1, & 673 contract_2, notcontract_2, & 674 bounds_t1, bounds_t2, & 675 bounds_1=bounds_1, bounds_2=bounds_2, bounds_3=bounds_3, & 676 do_crop_1=do_crop_1, do_crop_2=do_crop_2) 677 678 IF (do_crop_1) THEN 679 ALLOCATE (tensor_crop_1) 680 CALL dbcsr_t_crop(tensor_1, tensor_crop_1, bounds_t1, move_data=move_data_1) 681 move_data_1 = .TRUE. 
682 ELSE 683 tensor_crop_1 => tensor_1 684 ENDIF 685 686 IF (do_crop_2) THEN 687 ALLOCATE (tensor_crop_2) 688 CALL dbcsr_t_crop(tensor_2, tensor_crop_2, bounds_t2, move_data=move_data_2) 689 move_data_2 = .TRUE. 690 ELSE 691 tensor_crop_2 => tensor_2 692 ENDIF 693 694 ! shortcut for empty tensors 695 ! this is needed to avoid unnecessary work in case user contracts different portions of a 696 ! tensor consecutively to save memory 697 mp_comm = tensor_crop_1%pgrid%mp_comm_2d 698 occ_1 = dbcsr_t_get_num_blocks(tensor_crop_1) 699 CALL mp_max(occ_1, mp_comm) 700 occ_2 = dbcsr_t_get_num_blocks(tensor_crop_2) 701 CALL mp_max(occ_2, mp_comm) 702 703 IF (occ_1 == 0 .OR. occ_2 == 0) THEN 704 CALL dbcsr_t_scale(tensor_3, beta) 705 IF (do_crop_1) THEN 706 CALL dbcsr_t_destroy(tensor_crop_1) 707 DEALLOCATE (tensor_crop_1) 708 ENDIF 709 IF (do_crop_2) THEN 710 CALL dbcsr_t_destroy(tensor_crop_2) 711 DEALLOCATE (tensor_crop_2) 712 ENDIF 713 714 CALL timestop(handle) 715 RETURN 716 ENDIF 717 718 IF (unit_nr_prv /= 0) THEN 719 IF (unit_nr_prv > 0) THEN 720 WRITE (unit_nr_prv, '(A)') repeat("-", 80) 721 WRITE (unit_nr_prv, '(A,1X,A,1X,A,1X,A,1X,A,1X,A)') "DBCSR TENSOR CONTRACTION:", & 722 TRIM(tensor_crop_1%name), 'x', TRIM(tensor_crop_2%name), '=', TRIM(tensor_3%name) 723 WRITE (unit_nr_prv, '(A)') repeat("-", 80) 724 ENDIF 725 CALL dbcsr_t_write_tensor_info(tensor_crop_1, unit_nr_prv, full_info=log_verbose) 726 CALL dbcsr_t_write_tensor_dist(tensor_crop_1, unit_nr_prv) 727 CALL dbcsr_t_write_tensor_info(tensor_crop_2, unit_nr_prv, full_info=log_verbose) 728 CALL dbcsr_t_write_tensor_dist(tensor_crop_2, unit_nr_prv) 729 ENDIF 730 731 data_type = dbcsr_t_get_data_type(tensor_crop_1) 732 733 ! 
align tensor index with data, tensor data is not modified 734 ndims1 = ndims_tensor(tensor_crop_1) 735 ndims2 = ndims_tensor(tensor_crop_2) 736 ndims3 = ndims_tensor(tensor_3) 737 ALLOCATE (indchar1(ndims1), indchar1_mod(ndims1)) 738 ALLOCATE (indchar2(ndims2), indchar2_mod(ndims2)) 739 ALLOCATE (indchar3(ndims3), indchar3_mod(ndims3)) 740 741 ! labeling tensor index with letters 742 743 indchar1([notcontract_1, contract_1]) = alph(1:ndims1) ! arb. choice 744 indchar2(notcontract_2) = alph(ndims1 + 1:ndims1 + SIZE(notcontract_2)) ! arb. choice 745 indchar2(contract_2) = indchar1(contract_1) 746 indchar3(map_1) = indchar1(notcontract_1) 747 indchar3(map_2) = indchar2(notcontract_2) 748 749 IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_crop_1, indchar1, & 750 tensor_crop_2, indchar2, & 751 tensor_3, indchar3, unit_nr_prv) 752 IF (unit_nr_prv > 0) THEN 753 WRITE (unit_nr_prv, '(T2,A)') "aligning tensor index with data" 754 ENDIF 755 756 CALL align_tensor(tensor_crop_1, contract_1, notcontract_1, & 757 tensor_algn_1, contract_1_mod, notcontract_1_mod, indchar1, indchar1_mod) 758 759 CALL align_tensor(tensor_crop_2, contract_2, notcontract_2, & 760 tensor_algn_2, contract_2_mod, notcontract_2_mod, indchar2, indchar2_mod) 761 762 CALL align_tensor(tensor_3, map_1, map_2, & 763 tensor_algn_3, map_1_mod, map_2_mod, indchar3, indchar3_mod) 764 765 IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_algn_1, indchar1_mod, & 766 tensor_algn_2, indchar2_mod, & 767 tensor_algn_3, indchar3_mod, unit_nr_prv) 768 769 ALLOCATE (dims1(ndims1)) 770 ALLOCATE (dims2(ndims2)) 771 ALLOCATE (dims3(ndims3)) 772 773 ! ideally we should consider block sizes and occupancy to measure tensor sizes but current solution should work for most 774 ! cases and is more elegant. 
Note that we can not easily consider occupancy since it is unknown for result tensor 775 CALL blk_dims_tensor(tensor_crop_1, dims1) 776 CALL blk_dims_tensor(tensor_crop_2, dims2) 777 CALL blk_dims_tensor(tensor_3, dims3) 778 779 max_mm_dim = MAXLOC([PRODUCT(INT(dims1(notcontract_1), int_8)), & 780 PRODUCT(INT(dims1(contract_1), int_8)), & 781 PRODUCT(INT(dims2(notcontract_2), int_8))], DIM=1) 782 max_tensor = MAXLOC([PRODUCT(INT(dims1, int_8)), PRODUCT(INT(dims2, int_8)), PRODUCT(INT(dims3, int_8))], DIM=1) 783 SELECT CASE (max_mm_dim) 784 CASE (1) 785 IF (unit_nr_prv > 0) THEN 786 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 3; small tensor: 2" 787 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices" 788 ENDIF 789 CALL index_linked_sort(contract_1_mod, contract_2_mod) 790 CALL index_linked_sort(map_2_mod, notcontract_2_mod) 791 SELECT CASE (max_tensor) 792 CASE (1) 793 CALL index_linked_sort(notcontract_1_mod, map_1_mod) 794 CASE (3) 795 CALL index_linked_sort(map_1_mod, notcontract_1_mod) 796 CASE DEFAULT 797 DBCSR_ABORT("should not happen") 798 END SELECT 799 800 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_3, tensor_contr_1, tensor_contr_3, & 801 contract_1_mod, notcontract_1_mod, map_2_mod, map_1_mod, & 802 trans_1, trans_3, new_1, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, & 803 move_data_1=move_data_1, unit_nr=unit_nr_prv) 804 805 CALL reshape_mm_small(tensor_algn_2, contract_2_mod, notcontract_2_mod, tensor_contr_2, trans_2, & 806 new_2, move_data=move_data_2, unit_nr=unit_nr_prv) 807 808 SELECT CASE (ref_tensor) 809 CASE (1) 810 tensor_large => tensor_contr_1 811 CASE (2) 812 tensor_large => tensor_contr_3 813 END SELECT 814 tensor_small => tensor_contr_2 815 816 CASE (2) 817 IF (unit_nr_prv > 0) THEN 818 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 1, 2; small tensor: 3" 819 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices" 820 ENDIF 821 822 CALL index_linked_sort(notcontract_1_mod, map_1_mod) 823 CALL 
index_linked_sort(notcontract_2_mod, map_2_mod) 824 SELECT CASE (max_tensor) 825 CASE (1) 826 CALL index_linked_sort(contract_1_mod, contract_2_mod) 827 CASE (2) 828 CALL index_linked_sort(contract_2_mod, contract_1_mod) 829 CASE DEFAULT 830 DBCSR_ABORT("should not happen") 831 END SELECT 832 833 CALL reshape_mm_compatible(tensor_algn_1, tensor_algn_2, tensor_contr_1, tensor_contr_2, & 834 notcontract_1_mod, contract_1_mod, notcontract_2_mod, contract_2_mod, & 835 trans_1, trans_2, new_1, new_2, ref_tensor, optimize_dist=optimize_dist, & 836 move_data_1=move_data_1, move_data_2=move_data_2, unit_nr=unit_nr_prv) 837 CALL invert_transpose_flag(trans_1) 838 839 CALL reshape_mm_small(tensor_algn_3, map_1_mod, map_2_mod, tensor_contr_3, trans_3, & 840 new_3, nodata=nodata_3, unit_nr=unit_nr_prv) 841 842 SELECT CASE (ref_tensor) 843 CASE (1) 844 tensor_large => tensor_contr_1 845 CASE (2) 846 tensor_large => tensor_contr_2 847 END SELECT 848 tensor_small => tensor_contr_3 849 850 CASE (3) 851 IF (unit_nr_prv > 0) THEN 852 WRITE (unit_nr_prv, '(T2,A)') "large tensors: 2, 3; small tensor: 1" 853 WRITE (unit_nr_prv, '(T2,A)') "sorting contraction indices" 854 ENDIF 855 CALL index_linked_sort(map_1_mod, notcontract_1_mod) 856 CALL index_linked_sort(contract_2_mod, contract_1_mod) 857 SELECT CASE (max_tensor) 858 CASE (2) 859 CALL index_linked_sort(notcontract_2_mod, map_2_mod) 860 CASE (3) 861 CALL index_linked_sort(map_2_mod, notcontract_2_mod) 862 CASE DEFAULT 863 DBCSR_ABORT("should not happen") 864 END SELECT 865 866 CALL reshape_mm_compatible(tensor_algn_2, tensor_algn_3, tensor_contr_2, tensor_contr_3, & 867 contract_2_mod, notcontract_2_mod, map_1_mod, map_2_mod, & 868 trans_2, trans_3, new_2, new_3, ref_tensor, nodata2=nodata_3, optimize_dist=optimize_dist, & 869 move_data_1=move_data_2, unit_nr=unit_nr_prv) 870 871 CALL invert_transpose_flag(trans_2) 872 CALL invert_transpose_flag(trans_3) 873 874 CALL reshape_mm_small(tensor_algn_1, notcontract_1_mod, 
contract_1_mod, tensor_contr_1, & 875 trans_1, new_1, move_data=move_data_1, unit_nr=unit_nr_prv) 876 877 SELECT CASE (ref_tensor) 878 CASE (1) 879 tensor_large => tensor_contr_2 880 CASE (2) 881 tensor_large => tensor_contr_3 882 END SELECT 883 tensor_small => tensor_contr_1 884 885 END SELECT 886 887 IF (unit_nr_prv /= 0) CALL dbcsr_t_print_contraction_index(tensor_contr_1, indchar1_mod, & 888 tensor_contr_2, indchar2_mod, & 889 tensor_contr_3, indchar3_mod, unit_nr_prv) 890 IF (unit_nr_prv /= 0) THEN 891 IF (new_1) CALL dbcsr_t_write_tensor_info(tensor_contr_1, unit_nr_prv, full_info=log_verbose) 892 IF (new_1) CALL dbcsr_t_write_tensor_dist(tensor_contr_1, unit_nr_prv) 893 IF (new_2) CALL dbcsr_t_write_tensor_info(tensor_contr_2, unit_nr_prv, full_info=log_verbose) 894 IF (new_2) CALL dbcsr_t_write_tensor_dist(tensor_contr_2, unit_nr_prv) 895 ENDIF 896 897 IF (.NOT. PRESENT(result_index)) THEN 898 CALL dbcsr_tas_multiply(trans_1, trans_2, trans_3, alpha, & 899 tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, & 900 beta, & 901 tensor_contr_3%matrix_rep, filter_eps=filter_eps, flop=flop, & 902 unit_nr=unit_nr_prv, log_verbose=log_verbose, & 903 split_opt=split_opt, & 904 move_data_a=move_data_1, move_data_b=move_data_2, retain_sparsity=retain_sparsity) 905 ELSE 906 907 CALL dbcsr_tas_result_index(trans_1, trans_2, trans_3, tensor_contr_1%matrix_rep, tensor_contr_2%matrix_rep, & 908 tensor_contr_3%matrix_rep, filter_eps=filter_eps, blk_ind=result_index_2d) 909 910 nblk = SIZE(result_index_2d, 1) 911 IF (PRESENT(nblks_local)) nblks_local = nblk 912 IF (SIZE(result_index, 1) < nblk) THEN 913 CALL dbcsr_abort(__LOCATION__, & 914 "allocated size of `result_index` is too small. 
This error occurs due to a high load imbalance of distributed tensor data.") 915 ENDIF 916 917 DO iblk = 1, nblk 918 result_index(iblk, :) = get_nd_indices_tensor(tensor_contr_3%nd_index_blk, result_index_2d(iblk, :)) 919 ENDDO 920 921 IF (new_1) THEN 922 CALL dbcsr_t_destroy(tensor_contr_1) 923 DEALLOCATE (tensor_contr_1) 924 ENDIF 925 IF (new_2) THEN 926 CALL dbcsr_t_destroy(tensor_contr_2) 927 DEALLOCATE (tensor_contr_2) 928 ENDIF 929 IF (new_3) THEN 930 CALL dbcsr_t_destroy(tensor_contr_3) 931 DEALLOCATE (tensor_contr_3) 932 ENDIF 933 IF (do_crop_1) THEN 934 CALL dbcsr_t_destroy(tensor_crop_1) 935 DEALLOCATE (tensor_crop_1) 936 ENDIF 937 IF (do_crop_2) THEN 938 CALL dbcsr_t_destroy(tensor_crop_2) 939 DEALLOCATE (tensor_crop_2) 940 ENDIF 941 942 CALL dbcsr_t_destroy(tensor_algn_1) 943 CALL dbcsr_t_destroy(tensor_algn_2) 944 CALL dbcsr_t_destroy(tensor_algn_3) 945 946 CALL timestop(handle) 947 RETURN 948 ENDIF 949 950 IF (PRESENT(pgrid_opt_1)) THEN 951 IF (.NOT. new_1) THEN 952 ALLOCATE (pgrid_opt_1) 953 pgrid_opt_1 = opt_pgrid(tensor_1, split_opt) 954 ENDIF 955 ENDIF 956 957 IF (PRESENT(pgrid_opt_2)) THEN 958 IF (.NOT. new_2) THEN 959 ALLOCATE (pgrid_opt_2) 960 pgrid_opt_2 = opt_pgrid(tensor_2, split_opt) 961 ENDIF 962 ENDIF 963 964 IF (PRESENT(pgrid_opt_3)) THEN 965 IF (.NOT. new_3) THEN 966 ALLOCATE (pgrid_opt_3) 967 pgrid_opt_3 = opt_pgrid(tensor_3, split_opt) 968 ENDIF 969 ENDIF 970 971 do_batched = tensor_small%matrix_rep%do_batched > 0 972 973 tensors_remapped = .FALSE. 974 IF (new_1 .OR. new_2 .OR. new_3) tensors_remapped = .TRUE. 975 976 IF (tensors_remapped .AND. do_batched) THEN 977 CALL dbcsr_warn(__LOCATION__, & 978 "Internal process grid optimization disabled because tensors are not in contraction-compatible format") 979 ENDIF 980 981 CALL mp_environ(tensor_large%pgrid%mp_comm_2d, 2, pdims_2d, pcoord_2d, periods_2d) 982 983 ! optimize process grid during batched contraction 984 do_change_pgrid(:) = .FALSE. 985 IF ((.NOT. tensors_remapped) .AND. 
do_batched) THEN 986 ASSOCIATE (storage => tensor_small%contraction_storage) 987 DBCSR_ASSERT(storage%static) 988 split = dbcsr_tas_info(tensor_large%matrix_rep) 989 do_change_pgrid(:) = & 990 update_contraction_storage(storage, split_opt, split) 991 992 IF (ANY(do_change_pgrid)) THEN 993 mp_comm_opt = dbcsr_tas_mp_comm(tensor_small%pgrid%mp_comm_2d, split_opt%split_rowcol, NINT(storage%nsplit_avg)) 994 CALL dbcsr_tas_create_split(split_opt_avg, mp_comm_opt, split_opt%split_rowcol, & 995 NINT(storage%nsplit_avg), own_comm=.TRUE.) 996 CALL mp_environ(split_opt_avg%mp_comm, 2, pdims_2d_opt, pcoord_2d, periods_2d) 997 ENDIF 998 999 END ASSOCIATE 1000 1001 IF (do_change_pgrid(1) .AND. .NOT. do_change_pgrid(2)) THEN 1002 ! check if new grid has better subgrid, if not there is no need to change process grid 1003 CALL mp_environ(split_opt_avg%mp_comm_group, 2, pdims_sub_opt, pcoord_2d, periods_2d) 1004 CALL mp_environ(split%mp_comm_group, 2, pdims_sub, pcoord_2d, periods_2d) 1005 1006 pdim_ratio = MAXVAL(REAL(pdims_sub, real_8))/MINVAL(pdims_sub) 1007 pdim_ratio_opt = MAXVAL(REAL(pdims_sub_opt, real_8))/MINVAL(pdims_sub_opt) 1008 IF (pdim_ratio/pdim_ratio_opt <= default_pdims_accept_ratio**2) THEN 1009 do_change_pgrid(1) = .FALSE. 1010 CALL dbcsr_tas_release_info(split_opt_avg) 1011 ENDIF 1012 ENDIF 1013 ENDIF 1014 1015 IF (unit_nr_prv /= 0) THEN 1016 do_write_3 = .TRUE. 1017 IF (tensor_contr_3%matrix_rep%do_batched > 0) THEN 1018 IF (tensor_contr_3%matrix_rep%mm_storage%batched_out) do_write_3 = .FALSE. 1019 ENDIF 1020 IF (do_write_3) THEN 1021 CALL dbcsr_t_write_tensor_info(tensor_contr_3, unit_nr_prv, full_info=log_verbose) 1022 CALL dbcsr_t_write_tensor_dist(tensor_contr_3, unit_nr_prv) 1023 ENDIF 1024 ENDIF 1025 1026 IF (new_3) THEN 1027 ! need redistribute if we created new tensor for tensor 3 1028 CALL dbcsr_t_scale(tensor_algn_3, beta) 1029 CALL dbcsr_t_copy_expert(tensor_contr_3, tensor_algn_3, summation=.TRUE., move_data=.TRUE.) 
         IF (PRESENT(filter_eps)) CALL dbcsr_t_filter(tensor_algn_3, filter_eps)
         ! tensor_3 automatically has correct data because tensor_algn_3 contains a matrix
         ! pointer to data of tensor_3
      ENDIF

      ! transfer contraction storage
      CALL dbcsr_t_copy_contraction_storage(tensor_contr_1, tensor_1)
      CALL dbcsr_t_copy_contraction_storage(tensor_contr_2, tensor_2)
      CALL dbcsr_t_copy_contraction_storage(tensor_contr_3, tensor_3)

      IF (unit_nr_prv /= 0) THEN
         IF (new_3 .AND. do_write_3) CALL dbcsr_t_write_tensor_info(tensor_3, unit_nr_prv, full_info=log_verbose)
         IF (new_3 .AND. do_write_3) CALL dbcsr_t_write_tensor_dist(tensor_3, unit_nr_prv)
      ENDIF

      ! release the aligned views; the contracted/cropped copies are released below
      CALL dbcsr_t_destroy(tensor_algn_1)
      CALL dbcsr_t_destroy(tensor_algn_2)
      CALL dbcsr_t_destroy(tensor_algn_3)

      IF (do_crop_1) THEN
         CALL dbcsr_t_destroy(tensor_crop_1)
         DEALLOCATE (tensor_crop_1)
      ENDIF

      IF (do_crop_2) THEN
         CALL dbcsr_t_destroy(tensor_crop_2)
         DEALLOCATE (tensor_crop_2)
      ENDIF

      IF (new_1) THEN
         CALL dbcsr_t_destroy(tensor_contr_1)
         DEALLOCATE (tensor_contr_1)
      ENDIF
      IF (new_2) THEN
         CALL dbcsr_t_destroy(tensor_contr_2)
         DEALLOCATE (tensor_contr_2)
      ENDIF
      IF (new_3) THEN
         CALL dbcsr_t_destroy(tensor_contr_3)
         DEALLOCATE (tensor_contr_3)
      ENDIF

      IF (PRESENT(move_data)) THEN
         IF (move_data) THEN
            ! caller requested data transfer: input tensors may be empty on return
            CALL dbcsr_t_clear(tensor_1)
            CALL dbcsr_t_clear(tensor_2)
         ENDIF
      ENDIF

      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
         WRITE (unit_nr_prv, '(A)') "TENSOR CONTRACTION DONE"
         WRITE (unit_nr_prv, '(A)') repeat("-", 80)
      ENDIF

      ! optionally migrate the two "large" tensors of this contraction to the
      ! optimized 2d process grid determined above (batched contraction only)
      IF (ANY(do_change_pgrid)) THEN
         pgrid_changed_any = .FALSE.
         SELECT CASE (max_mm_dim)
         CASE (1)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbcsr_t_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
               CALL dbcsr_t_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
            ENDIF
            IF (pgrid_changed_any) THEN
               IF (tensor_2%matrix_rep%do_batched == 3) THEN
                  ! set flag that process grid has been optimized to make sure that no grid optimizations are done
                  ! in TAS multiply algorithm
                  CALL dbcsr_tas_batched_mm_complete(tensor_2%matrix_rep)
               ENDIF
            ENDIF
         CASE (2)
            IF (ALLOCATED(tensor_1%contraction_storage) .AND. ALLOCATED(tensor_2%contraction_storage)) THEN
               CALL dbcsr_t_change_pgrid_2d(tensor_1, tensor_1%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
               CALL dbcsr_t_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
            ENDIF
            IF (pgrid_changed_any) THEN
               IF (tensor_3%matrix_rep%do_batched == 3) THEN
                  CALL dbcsr_tas_batched_mm_complete(tensor_3%matrix_rep)
               ENDIF
            ENDIF
         CASE (3)
            IF (ALLOCATED(tensor_2%contraction_storage) .AND. ALLOCATED(tensor_3%contraction_storage)) THEN
               CALL dbcsr_t_change_pgrid_2d(tensor_2, tensor_2%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
               CALL dbcsr_t_change_pgrid_2d(tensor_3, tensor_3%pgrid%mp_comm_2d, pdims=pdims_2d_opt, &
                                            nsplit=split_opt_avg%ngroup, dimsplit=split_opt_avg%split_rowcol, &
                                            pgrid_changed=pgrid_changed, &
                                            unit_nr=unit_nr_prv)
               IF (pgrid_changed) pgrid_changed_any = .TRUE.
            ENDIF
            IF (pgrid_changed_any) THEN
               IF (tensor_1%matrix_rep%do_batched == 3) THEN
                  CALL dbcsr_tas_batched_mm_complete(tensor_1%matrix_rep)
               ENDIF
            ENDIF
         END SELECT
         CALL dbcsr_tas_release_info(split_opt_avg)
      ENDIF

      IF ((.NOT. tensors_remapped) .AND. do_batched) THEN
         ! freeze TAS process grids if tensor grids were optimized
         CALL dbcsr_tas_set_batched_state(tensor_1%matrix_rep, opt_grid=.TRUE.)
         CALL dbcsr_tas_set_batched_state(tensor_2%matrix_rep, opt_grid=.TRUE.)
         CALL dbcsr_tas_set_batched_state(tensor_3%matrix_rep, opt_grid=.TRUE.)
      ENDIF

      CALL dbcsr_tas_release_info(split_opt)

      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE align_tensor(tensor_in, contract_in, notcontract_in, &
      !! align tensor index with data
                           tensor_out, contract_out, notcontract_out, indp_in, indp_out)
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_in, notcontract_in
      TYPE(dbcsr_t_type), INTENT(OUT) :: tensor_out
      INTEGER, DIMENSION(SIZE(contract_in)), &
         INTENT(OUT) :: contract_out
      INTEGER, DIMENSION(SIZE(notcontract_in)), &
         INTENT(OUT) :: notcontract_out
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(IN) :: indp_in
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_in)), INTENT(OUT) :: indp_out
      ! align(i) is the new position of original dimension i after alignment
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: align

      CALL dbcsr_t_align_index(tensor_in, tensor_out, order=align)
      ! remap contraction / non-contraction index lists and index labels
      ! into the aligned ordering
      contract_out = align(contract_in)
      notcontract_out = align(notcontract_in)
      indp_out(align) = indp_in

   END SUBROUTINE

   SUBROUTINE reshape_mm_compatible(tensor1, tensor2, tensor1_out, tensor2_out, ind1_free, ind1_linked, &
                                    ind2_free, ind2_linked, trans1, trans2, new1, new2, ref_tensor, &
                                    nodata1, nodata2, move_data_1, &
                                    move_data_2, optimize_dist, unit_nr)
      !! Prepare tensor for contraction: redistribute to a 2d format which can be contracted by
      !! matrix multiplication. This routine reshapes the two largest of the three tensors. Redistribution
      !! is avoided if tensors already in a consistent layout.

      TYPE(dbcsr_t_type), TARGET, INTENT(INOUT) :: tensor1
         !! tensor 1 in
      TYPE(dbcsr_t_type), TARGET, INTENT(INOUT) :: tensor2
         !! tensor 2 in
      TYPE(dbcsr_t_type), POINTER, INTENT(OUT) :: tensor1_out, tensor2_out
         !! tensor 1 out
         !! tensor 2 out
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_free, ind2_free
         !! indices of tensor 1 that are "free" (not linked to any index of tensor 2)
      INTEGER, DIMENSION(:), INTENT(IN) :: ind1_linked, ind2_linked
         !! indices of tensor 1 that are linked to indices of tensor 2
         !! 1:1 correspondence with ind1_linked
      CHARACTER(LEN=1), INTENT(OUT) :: trans1, trans2
         !! transpose flag of matrix rep. of tensor 1
         !! transpose flag of matrix rep. tensor 2
      LOGICAL, INTENT(OUT) :: new1, new2
         !! whether a new tensor 1 was created
         !! whether a new tensor 2 was created
      INTEGER, INTENT(OUT) :: ref_tensor
      LOGICAL, INTENT(IN), OPTIONAL :: nodata1, nodata2
         !! don't copy data of tensor 1
         !! don't copy data of tensor 2
      LOGICAL, INTENT(INOUT), OPTIONAL :: move_data_1, move_data_2
         !! memory optimization: transfer data s.t. tensor1 may be empty on return
         !! memory optimization: transfer data s.t. tensor2 may be empty on return
      LOGICAL, INTENT(IN), OPTIONAL :: optimize_dist
         !! experimental: optimize distribution
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
         !! output unit
      INTEGER :: compat1, compat1_old, compat2, compat2_old, &
                 comm_2d, unit_nr_prv
      TYPE(array_list) :: dist_list
      INTEGER, DIMENSION(:), ALLOCATABLE :: mp_dims
      TYPE(dbcsr_t_distribution_type) :: dist_in
      INTEGER(KIND=int_8) :: nblkrows, nblkcols
      LOGICAL :: optimize_dist_prv
      INTEGER, DIMENSION(ndims_tensor(tensor1)) :: dims1
      INTEGER, DIMENSION(ndims_tensor(tensor2)) :: dims2

      NULLIFY (tensor1_out, tensor2_out)

      unit_nr_prv = prep_output_unit(unit_nr)

      CALL blk_dims_tensor(tensor1, dims1)
      CALL blk_dims_tensor(tensor2, dims2)

      ! the larger of the two tensors (by total block count) is the reference;
      ! the other one is redistributed to match it
      IF (PRODUCT(int(dims1, int_8)) .GE. PRODUCT(int(dims2, int_8))) THEN
         ref_tensor = 1
      ELSE
         ref_tensor = 2
      ENDIF

      IF (PRESENT(optimize_dist)) THEN
         optimize_dist_prv = optimize_dist
      ELSE
         optimize_dist_prv = .FALSE.
1246 ENDIF 1247 1248 compat1 = compat_map(tensor1%nd_index, ind1_linked) 1249 compat2 = compat_map(tensor2%nd_index, ind2_linked) 1250 compat1_old = compat1 1251 compat2_old = compat2 1252 1253 IF (unit_nr_prv > 0) THEN 1254 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1%name), ":" 1255 SELECT CASE (compat1) 1256 CASE (0) 1257 WRITE (unit_nr_prv, '(A)') "Not compatible" 1258 CASE (1) 1259 WRITE (unit_nr_prv, '(A)') "Normal" 1260 CASE (2) 1261 WRITE (unit_nr_prv, '(A)') "Transposed" 1262 END SELECT 1263 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2%name), ":" 1264 SELECT CASE (compat2) 1265 CASE (0) 1266 WRITE (unit_nr_prv, '(A)') "Not compatible" 1267 CASE (1) 1268 WRITE (unit_nr_prv, '(A)') "Normal" 1269 CASE (2) 1270 WRITE (unit_nr_prv, '(A)') "Transposed" 1271 END SELECT 1272 ENDIF 1273 1274 new1 = .FALSE. 1275 new2 = .FALSE. 1276 1277 IF (compat1 == 0 .OR. optimize_dist_prv) THEN 1278 new1 = .TRUE. 1279 ENDIF 1280 1281 IF (compat2 == 0 .OR. optimize_dist_prv) THEN 1282 new2 = .TRUE. 1283 ENDIF 1284 1285 IF (ref_tensor == 1) THEN ! tensor 1 is reference and tensor 2 is reshaped compatible with tensor 1 1286 IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape 1287 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor1%name) 1288 nblkrows = PRODUCT(INT(dims1(ind1_linked), KIND=int_8)) 1289 nblkcols = PRODUCT(INT(dims1(ind1_free), KIND=int_8)) 1290 comm_2d = dbcsr_tas_mp_comm(tensor1%pgrid%mp_comm_2d, nblkrows, nblkcols) 1291 ALLOCATE (tensor1_out) 1292 CALL dbcsr_t_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=comm_2d, & 1293 nodata=nodata1, move_data=move_data_1) 1294 CALL mp_comm_free(comm_2d) 1295 compat1 = 1 1296 ELSE 1297 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name) 1298 tensor1_out => tensor1 1299 ENDIF 1300 IF (compat2 == 0 .OR. 
optimize_dist_prv) THEN ! tensor 2 is not contraction compatible --> reshape 1301 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", & 1302 TRIM(tensor2%name), "compatible with", TRIM(tensor1%name) 1303 dist_in = dbcsr_t_distribution(tensor1_out) 1304 dist_list = array_sublist(dist_in%nd_dist, ind1_linked) 1305 IF (compat1 == 1) THEN ! linked index is first 2d dimension 1306 ! get distribution of linked index, tensor 2 must adopt this distribution 1307 ! get grid dimensions of linked index 1308 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid))) 1309 CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims) 1310 ALLOCATE (tensor2_out) 1311 CALL dbcsr_t_remap(tensor2, ind2_linked, ind2_free, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, & 1312 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata2, move_data=move_data_2) 1313 ELSEIF (compat1 == 2) THEN ! linked index is second 2d dimension 1314 ! get distribution of linked index, tensor 2 must adopt this distribution 1315 ! get grid dimensions of linked index 1316 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid))) 1317 CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims) 1318 ALLOCATE (tensor2_out) 1319 CALL dbcsr_t_remap(tensor2, ind2_free, ind2_linked, tensor2_out, comm_2d=dist_in%pgrid%mp_comm_2d, & 1320 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata2, move_data=move_data_2) 1321 ELSE 1322 DBCSR_ABORT("should not happen") 1323 ENDIF 1324 compat2 = compat1 1325 ELSE 1326 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name) 1327 tensor2_out => tensor2 1328 ENDIF 1329 ELSE ! tensor 2 is reference and tensor 1 is reshaped compatible with tensor 2 1330 IF (compat2 == 0 .OR. optimize_dist_prv) THEN ! 
tensor 2 is not contraction compatible --> reshape 1331 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor2%name) 1332 nblkrows = PRODUCT(INT(dims2(ind2_linked), KIND=int_8)) 1333 nblkcols = PRODUCT(INT(dims2(ind2_free), KIND=int_8)) 1334 comm_2d = dbcsr_tas_mp_comm(tensor2%pgrid%mp_comm_2d, nblkrows, nblkcols) 1335 ALLOCATE (tensor2_out) 1336 CALL dbcsr_t_remap(tensor2, ind2_linked, ind2_free, tensor2_out, nodata=nodata2, move_data=move_data_2) 1337 CALL mp_comm_free(comm_2d) 1338 compat2 = 1 1339 ELSE 1340 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor2%name) 1341 tensor2_out => tensor2 1342 ENDIF 1343 IF (compat1 == 0 .OR. optimize_dist_prv) THEN ! tensor 1 is not contraction compatible --> reshape 1344 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A,1X,A,1X,A)') "Redistribution of", TRIM(tensor1%name), & 1345 "compatible with", TRIM(tensor2%name) 1346 dist_in = dbcsr_t_distribution(tensor2_out) 1347 dist_list = array_sublist(dist_in%nd_dist, ind2_linked) 1348 IF (compat2 == 1) THEN 1349 ALLOCATE (mp_dims(ndims_mapping_row(dist_in%pgrid%nd_index_grid))) 1350 CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims1_2d=mp_dims) 1351 ALLOCATE (tensor1_out) 1352 CALL dbcsr_t_remap(tensor1, ind1_linked, ind1_free, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, & 1353 dist1=dist_list, mp_dims_1=mp_dims, nodata=nodata1, move_data=move_data_1) 1354 ELSEIF (compat2 == 2) THEN 1355 ALLOCATE (mp_dims(ndims_mapping_column(dist_in%pgrid%nd_index_grid))) 1356 CALL dbcsr_t_get_mapping_info(dist_in%pgrid%nd_index_grid, dims2_2d=mp_dims) 1357 ALLOCATE (tensor1_out) 1358 CALL dbcsr_t_remap(tensor1, ind1_free, ind1_linked, tensor1_out, comm_2d=dist_in%pgrid%mp_comm_2d, & 1359 dist2=dist_list, mp_dims_2=mp_dims, nodata=nodata1, move_data=move_data_1) 1360 ELSE 1361 DBCSR_ABORT("should not happen") 1362 ENDIF 1363 compat1 = compat2 1364 ELSE 1365 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, 
'(T2,A,1X,A)') "No redistribution of", TRIM(tensor1%name) 1366 tensor1_out => tensor1 1367 ENDIF 1368 ENDIF 1369 1370 SELECT CASE (compat1) 1371 CASE (1) 1372 trans1 = dbcsr_no_transpose 1373 CASE (2) 1374 trans1 = dbcsr_transpose 1375 CASE DEFAULT 1376 DBCSR_ABORT("should not happen") 1377 END SELECT 1378 1379 SELECT CASE (compat2) 1380 CASE (1) 1381 trans2 = dbcsr_no_transpose 1382 CASE (2) 1383 trans2 = dbcsr_transpose 1384 CASE DEFAULT 1385 DBCSR_ABORT("should not happen") 1386 END SELECT 1387 1388 IF (unit_nr_prv > 0) THEN 1389 IF (compat1 .NE. compat1_old) THEN 1390 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor1_out%name), ":" 1391 SELECT CASE (compat1) 1392 CASE (0) 1393 WRITE (unit_nr_prv, '(A)') "Not compatible" 1394 CASE (1) 1395 WRITE (unit_nr_prv, '(A)') "Normal" 1396 CASE (2) 1397 WRITE (unit_nr_prv, '(A)') "Transposed" 1398 END SELECT 1399 ENDIF 1400 IF (compat2 .NE. compat2_old) THEN 1401 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor2_out%name), ":" 1402 SELECT CASE (compat2) 1403 CASE (0) 1404 WRITE (unit_nr_prv, '(A)') "Not compatible" 1405 CASE (1) 1406 WRITE (unit_nr_prv, '(A)') "Normal" 1407 CASE (2) 1408 WRITE (unit_nr_prv, '(A)') "Transposed" 1409 END SELECT 1410 ENDIF 1411 ENDIF 1412 1413 IF (new1 .AND. PRESENT(move_data_1)) move_data_1 = .TRUE. 1414 IF (new2 .AND. PRESENT(move_data_2)) move_data_2 = .TRUE. 1415 1416 END SUBROUTINE 1417 1418 SUBROUTINE reshape_mm_small(tensor_in, ind1, ind2, tensor_out, trans, new, nodata, move_data, unit_nr) 1419 !! Prepare tensor for contraction: redistribute to a 2d format which can be contracted by 1420 !! matrix multiplication. This routine reshapes the smallest of the three tensors. 1421 1422 TYPE(dbcsr_t_type), TARGET, INTENT(INOUT) :: tensor_in 1423 !! tensor in 1424 INTEGER, DIMENSION(:), INTENT(IN) :: ind1, ind2 1425 !! index that should be mapped to first matrix dimension 1426 !! 
index that should be mapped to second matrix dimension 1427 TYPE(dbcsr_t_type), POINTER, INTENT(OUT) :: tensor_out 1428 !! tensor out 1429 CHARACTER(LEN=1), INTENT(OUT) :: trans 1430 !! transpose flag of matrix rep. 1431 LOGICAL, INTENT(OUT) :: new 1432 !! whether a new tensor was created for tensor_out 1433 LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data 1434 !! don't copy tensor data 1435 !! memory optimization: transfer data s.t. tensor_in may be empty on return 1436 INTEGER, INTENT(IN), OPTIONAL :: unit_nr 1437 !! output unit 1438 INTEGER :: compat1, compat2, compat1_old, compat2_old, unit_nr_prv 1439 LOGICAL :: nodata_prv 1440 1441 NULLIFY (tensor_out) 1442 IF (PRESENT(nodata)) THEN 1443 nodata_prv = nodata 1444 ELSE 1445 nodata_prv = .FALSE. 1446 ENDIF 1447 1448 unit_nr_prv = prep_output_unit(unit_nr) 1449 1450 new = .FALSE. 1451 compat1 = compat_map(tensor_in%nd_index, ind1) 1452 compat2 = compat_map(tensor_in%nd_index, ind2) 1453 compat1_old = compat1; compat2_old = compat2 1454 IF (unit_nr_prv > 0) THEN 1455 WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_in%name), ":" 1456 IF (compat1 == 1 .AND. compat2 == 2) THEN 1457 WRITE (unit_nr_prv, '(A)') "Normal" 1458 ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN 1459 WRITE (unit_nr_prv, '(A)') "Transposed" 1460 ELSE 1461 WRITE (unit_nr_prv, '(A)') "Not compatible" 1462 ENDIF 1463 ENDIF 1464 IF (compat1 == 0 .or. compat2 == 0) THEN ! index mapping not compatible with contract index 1465 1466 IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "Redistribution of", TRIM(tensor_in%name) 1467 1468 ALLOCATE (tensor_out) 1469 CALL dbcsr_t_remap(tensor_in, ind1, ind2, tensor_out, nodata=nodata, move_data=move_data) 1470 CALL dbcsr_t_copy_contraction_storage(tensor_in, tensor_out) 1471 compat1 = 1 1472 compat2 = 2 1473 new = .TRUE. 
      ELSE
         IF (unit_nr_prv > 0) WRITE (unit_nr_prv, '(T2,A,1X,A)') "No redistribution of", TRIM(tensor_in%name)
         tensor_out => tensor_in
      ENDIF

      ! translate the (possibly updated) compatibility into a transpose flag
      IF (compat1 == 1 .AND. compat2 == 2) THEN
         trans = dbcsr_no_transpose
      ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
         trans = dbcsr_transpose
      ELSE
         DBCSR_ABORT("this should not happen")
      ENDIF

      IF (unit_nr_prv > 0) THEN
         IF (compat1_old .NE. compat1 .OR. compat2_old .NE. compat2) THEN
            WRITE (unit_nr_prv, '(T2,A,1X,A,A,1X)', advance='no') "compatibility of", TRIM(tensor_out%name), ":"
            IF (compat1 == 1 .AND. compat2 == 2) THEN
               WRITE (unit_nr_prv, '(A)') "Normal"
            ELSEIF (compat1 == 2 .AND. compat2 == 1) THEN
               WRITE (unit_nr_prv, '(A)') "Transposed"
            ELSE
               WRITE (unit_nr_prv, '(A)') "Not compatible"
            ENDIF
         ENDIF
      ENDIF

   END SUBROUTINE

   FUNCTION update_contraction_storage(storage, split_opt, split) RESULT(do_change_pgrid)
      !! update contraction storage that keeps track of process grids during a batched contraction
      !! and decide if tensor process grid needs to be optimized.
      !! do_change_pgrid(1): subgrid dimensions too unbalanced; do_change_pgrid(2): split factor drifted
      TYPE(dbcsr_t_contraction_storage), INTENT(INOUT) :: storage
      TYPE(dbcsr_tas_split_info), INTENT(IN) :: split_opt
         !! optimized TAS process grid
      TYPE(dbcsr_tas_split_info), INTENT(IN) :: split
         !! current TAS process grid
      INTEGER, DIMENSION(2) :: pdims_opt, coor, pdims, pdims_sub
      LOGICAL, DIMENSION(2) :: periods
      LOGICAL, DIMENSION(2) :: do_change_pgrid
      REAL(kind=real_8) :: change_criterion, pdims_ratio
      INTEGER :: nsplit_opt, nsplit

      DBCSR_ASSERT(ALLOCATED(split_opt%ngroup_opt))
      nsplit_opt = split_opt%ngroup_opt
      nsplit = split%ngroup

      CALL mp_environ(split_opt%mp_comm, 2, pdims_opt, coor, periods)
      CALL mp_environ(split%mp_comm, 2, pdims, coor, periods)

      storage%ibatch = storage%ibatch + 1

      ! running average of the optimal split factor over all batches seen so far
      storage%nsplit_avg = (storage%nsplit_avg*REAL(storage%ibatch - 1, real_8) + REAL(nsplit_opt, real_8)) &
                           /REAL(storage%ibatch, real_8)

      SELECT CASE (split_opt%split_rowcol)
      CASE (rowsplit)
         pdims_ratio = REAL(pdims(1), real_8)/pdims(2)
      CASE (colsplit)
         pdims_ratio = REAL(pdims(2), real_8)/pdims(1)
      END SELECT

      do_change_pgrid(:) = .FALSE.

      ! check for process grid dimensions
      CALL mp_environ(split%mp_comm_group, 2, pdims_sub, coor, periods)
      change_criterion = MAXVAL(REAL(pdims_sub, real_8))/MINVAL(pdims_sub)
      IF (change_criterion > default_pdims_accept_ratio**2) do_change_pgrid(1) = .TRUE.

      ! check for split factor
      change_criterion = MAX(REAL(nsplit, real_8)/storage%nsplit_avg, REAL(storage%nsplit_avg, real_8)/nsplit)
      IF (change_criterion > default_nsplit_accept_ratio) do_change_pgrid(2) = .TRUE.

   END FUNCTION

   FUNCTION compat_map(nd_index, compat_ind)
      !! Check if 2d index is compatible with tensor index.
      !! Returns 1 if compat_ind equals the row mapping, 2 if it equals the
      !! column mapping, 0 if neither.
      TYPE(nd_to_2d_mapping), INTENT(IN) :: nd_index
      INTEGER, DIMENSION(:), INTENT(IN) :: compat_ind
      INTEGER, DIMENSION(ndims_mapping_row(nd_index)) :: map1
      INTEGER, DIMENSION(ndims_mapping_column(nd_index)) :: map2
      INTEGER :: compat_map

      CALL dbcsr_t_get_mapping_info(nd_index, map1_2d=map1, map2_2d=map2)

      compat_map = 0
      IF (array_eq_i(map1, compat_ind)) THEN
         compat_map = 1
      ELSEIF (array_eq_i(map2, compat_ind)) THEN
         compat_map = 2
      ENDIF

   END FUNCTION

   SUBROUTINE invert_transpose_flag(trans_flag)
      !! Toggle a DBCSR transpose flag (normal <-> transposed)
      CHARACTER(LEN=1), INTENT(INOUT) :: trans_flag

      IF (trans_flag == dbcsr_transpose) THEN
         trans_flag = dbcsr_no_transpose
      ELSEIF (trans_flag == dbcsr_no_transpose) THEN
         trans_flag = dbcsr_transpose
      ENDIF
   END SUBROUTINE

   SUBROUTINE index_linked_sort(ind_ref, ind_dep)
      !! Sort ind_ref ascending and apply the same permutation to ind_dep
      !! (keeps the 1:1 correspondence between the two index lists)
      INTEGER, DIMENSION(:), INTENT(INOUT) :: ind_ref, ind_dep
      INTEGER, DIMENSION(SIZE(ind_ref)) :: sort_indices

      CALL sort(ind_ref, SIZE(ind_ref), sort_indices)
      ind_dep(:) = ind_dep(sort_indices)

   END SUBROUTINE

   FUNCTION opt_pgrid(tensor, tas_split_info)
      !! Derive an optimized nd process grid for a tensor from a TAS split info
      TYPE(dbcsr_t_type), INTENT(IN) :: tensor
      TYPE(dbcsr_tas_split_info), INTENT(IN) :: tas_split_info
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
      TYPE(dbcsr_t_pgrid_type) :: opt_pgrid
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims

      CALL dbcsr_t_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
      CALL blk_dims_tensor(tensor, dims)
      opt_pgrid = dbcsr_t_nd_mp_comm(tas_split_info%mp_comm, map1, map2, tdims=dims)

      ! the new pgrid takes a reference on the split info
      ALLOCATE (opt_pgrid%tas_split_info, SOURCE=tas_split_info)
      CALL dbcsr_tas_info_hold(opt_pgrid%tas_split_info)
   END FUNCTION

   SUBROUTINE dbcsr_t_remap(tensor_in, map1_2d, map2_2d, tensor_out, comm_2d, dist1, dist2, &
                            mp_dims_1, mp_dims_2, name, nodata, move_data)
      !! Copy tensor to tensor with modified index mapping

      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      INTEGER, DIMENSION(:), INTENT(IN) :: map1_2d, map2_2d
         !! new index mapping
         !! new index mapping
      TYPE(dbcsr_t_type), INTENT(OUT) :: tensor_out
      CHARACTER(len=*), INTENT(IN), OPTIONAL :: name
      LOGICAL, INTENT(IN), OPTIONAL :: nodata, move_data
      INTEGER, INTENT(IN), OPTIONAL :: comm_2d
      TYPE(array_list), INTENT(IN), OPTIONAL :: dist1, dist2
      INTEGER, DIMENSION(SIZE(map1_2d)), OPTIONAL :: mp_dims_1
      INTEGER, DIMENSION(SIZE(map2_2d)), OPTIONAL :: mp_dims_2
      CHARACTER(len=default_string_length) :: name_tmp
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("blk_sizes")}$, &
                                            ${varlist("nd_dist")}$
      TYPE(dbcsr_t_distribution_type) :: dist
      INTEGER :: comm_2d_prv, handle, i
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: pdims, myploc
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_remap'
      LOGICAL :: nodata_prv
      TYPE(dbcsr_t_pgrid_type) :: comm_nd

      CALL timeset(routineN, handle)

      IF (PRESENT(name)) THEN
         name_tmp = name
      ELSE
         name_tmp = tensor_in%name
      ENDIF
      ! an explicit distribution for a 2d dimension requires the matching grid dims
      IF (PRESENT(dist1)) THEN
         DBCSR_ASSERT(PRESENT(mp_dims_1))
      ENDIF

      IF (PRESENT(dist2)) THEN
         DBCSR_ASSERT(PRESENT(mp_dims_2))
      ENDIF

      IF (PRESENT(comm_2d)) THEN
         comm_2d_prv = comm_2d
      ELSE
         comm_2d_prv = tensor_in%pgrid%mp_comm_2d
      ENDIF

      comm_nd = dbcsr_t_nd_mp_comm(comm_2d_prv, map1_2d, map2_2d, dims1_nd=mp_dims_1, dims2_nd=mp_dims_2)
      CALL mp_environ_pgrid(comm_nd, pdims, myploc)

#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
         CALL get_arrays(tensor_in%blk_sizes, ${varlist("blk_sizes", nmax=ndim)}$)
      ENDIF
#:endfor

#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
#:for idim in range(1, ndim+1)
         IF (PRESENT(dist1)) THEN
            IF (ANY(map1_2d == ${idim}$)) THEN
               i = MINLOC(map1_2d, dim=1, mask=map1_2d == ${idim}$) ! i is location of idim in map1_2d
               CALL get_ith_array(dist1, i, nd_dist_${idim}$)
            ENDIF
         ENDIF

         IF (PRESENT(dist2)) THEN
            IF (ANY(map2_2d == ${idim}$)) THEN
               i = MINLOC(map2_2d, dim=1, mask=map2_2d == ${idim}$) ! i is location of idim in map2_2d
               CALL get_ith_array(dist2, i, nd_dist_${idim}$)
            ENDIF
         ENDIF

         ! any dimension not covered by dist1/dist2 falls back to the default distribution
         IF (.NOT. ALLOCATED(nd_dist_${idim}$)) THEN
            ALLOCATE (nd_dist_${idim}$ (SIZE(blk_sizes_${idim}$)))
            CALL dbcsr_t_default_distvec(SIZE(blk_sizes_${idim}$), pdims(${idim}$), blk_sizes_${idim}$, nd_dist_${idim}$)
         ENDIF
#:endfor
         ! own_comm=.TRUE.: the distribution takes ownership of comm_nd
         CALL dbcsr_t_distribution_new_expert(dist, comm_nd, map1_2d, map2_2d, &
                                              ${varlist("nd_dist", nmax=ndim)}$, own_comm=.TRUE.)
      ENDIF
#:endfor

#:for ndim in ndims
      IF (ndims_tensor(tensor_in) == ${ndim}$) THEN
         CALL dbcsr_t_create(tensor_out, name_tmp, dist, &
                             map1_2d, map2_2d, dbcsr_tas_get_data_type(tensor_in%matrix_rep), &
                             ${varlist("blk_sizes", nmax=ndim)}$)
      ENDIF
#:endfor

      IF (PRESENT(nodata)) THEN
         nodata_prv = nodata
      ELSE
         nodata_prv = .FALSE.
      ENDIF

      IF (.NOT. nodata_prv) CALL dbcsr_t_copy_expert(tensor_in, tensor_out, move_data=move_data)
      CALL dbcsr_t_distribution_destroy(dist)

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_align_index(tensor_in, tensor_out, order)
      !! Align index with data

      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_t_type), INTENT(OUT) :: tensor_out
      INTEGER, DIMENSION(ndims_matrix_row(tensor_in)) :: map1_2d
      INTEGER, DIMENSION(ndims_matrix_column(tensor_in)) :: map2_2d
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(OUT), OPTIONAL :: order
         !! permutation resulting from alignment
      INTEGER, DIMENSION(ndims_tensor(tensor_in)) :: order_prv
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_align_index'
      INTEGER :: handle

      CALL timeset(routineN, handle)

      CALL dbcsr_t_get_mapping_info(tensor_in%nd_index_blk, map1_2d=map1_2d, map2_2d=map2_2d)
      ! permute so that the tensor index order follows the matrix (row, column) index order
      order_prv = dbcsr_t_inverse_order([map1_2d, map2_2d])
      CALL dbcsr_t_permute_index(tensor_in, tensor_out, order=order_prv)

      IF (PRESENT(order)) order = order_prv

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_permute_index(tensor_in, tensor_out, order)
      !! Create new tensor by reordering index, data is copied exactly (shallow copy)
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor_in
      TYPE(dbcsr_t_type), INTENT(OUT) :: tensor_out
      INTEGER, DIMENSION(ndims_tensor(tensor_in)), &
         INTENT(IN) :: order

      TYPE(nd_to_2d_mapping) :: nd_index_blk_rs, nd_index_rs
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_permute_index'
      INTEGER :: handle
      INTEGER :: ndims

      CALL timeset(routineN, handle)

      ndims = ndims_tensor(tensor_in)

      CALL permute_index(tensor_in%nd_index, nd_index_rs, order)
      CALL permute_index(tensor_in%nd_index_blk, nd_index_blk_rs, order)
      CALL permute_index(tensor_in%pgrid%nd_index_grid, tensor_out%pgrid%nd_index_grid, order)

      ! shallow copy: the matrix representation is shared, not owned, by tensor_out
      tensor_out%matrix_rep => tensor_in%matrix_rep
      tensor_out%owns_matrix = .FALSE.

      tensor_out%nd_index = nd_index_rs
      tensor_out%nd_index_blk = nd_index_blk_rs
      tensor_out%pgrid%mp_comm_2d = tensor_in%pgrid%mp_comm_2d
      IF (ALLOCATED(tensor_in%pgrid%tas_split_info)) THEN
         ALLOCATE (tensor_out%pgrid%tas_split_info, SOURCE=tensor_in%pgrid%tas_split_info)
      ENDIF
      ! both tensors share the same reference counter for the underlying data
      tensor_out%refcount => tensor_in%refcount
      CALL dbcsr_t_hold(tensor_out)

      ! permute all per-dimension metadata consistently with the new index order
      CALL reorder_arrays(tensor_in%blk_sizes, tensor_out%blk_sizes, order)
      CALL reorder_arrays(tensor_in%blk_offsets, tensor_out%blk_offsets, order)
      CALL reorder_arrays(tensor_in%nd_dist, tensor_out%nd_dist, order)
      CALL reorder_arrays(tensor_in%blks_local, tensor_out%blks_local, order)
      ALLOCATE (tensor_out%nblks_local(ndims))
      ALLOCATE (tensor_out%nfull_local(ndims))
      tensor_out%nblks_local(order) = tensor_in%nblks_local(:)
      tensor_out%nfull_local(order) = tensor_in%nfull_local(:)
      tensor_out%name = tensor_in%name
      tensor_out%valid = .TRUE.

      IF (ALLOCATED(tensor_in%contraction_storage)) THEN
         ALLOCATE (tensor_out%contraction_storage, SOURCE=tensor_in%contraction_storage)
         ! the sourced copy of batch_ranges is in the old order; rebuild it reordered
         CALL destroy_array_list(tensor_out%contraction_storage%batch_ranges)
         CALL reorder_arrays(tensor_in%contraction_storage%batch_ranges, tensor_out%contraction_storage%batch_ranges, order)
      ENDIF

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_contract_index(alpha, tensor_1, tensor_2, beta, tensor_3, &
                                     contract_1, notcontract_1, &
                                     contract_2, notcontract_2, &
                                     map_1, map_2, &
                                     bounds_1, bounds_2, bounds_3, &
                                     filter_eps, &
                                     nblks_local, result_index)
      !! get indices of non-zero tensor blocks for contraction result without actually
      !! performing contraction.
      !! this is an estimate based on block norm multiplication.
      !! See documentation of dbcsr_t_contract.
      TYPE(dbcsr_scalar_type), INTENT(IN) :: alpha
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_1
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_2
      TYPE(dbcsr_scalar_type), INTENT(IN) :: beta
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_2
      INTEGER, DIMENSION(:), INTENT(IN) :: map_1
      INTEGER, DIMENSION(:), INTENT(IN) :: map_2
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_1
      INTEGER, DIMENSION(:), INTENT(IN) :: notcontract_2
      TYPE(dbcsr_t_type), INTENT(INOUT), TARGET :: tensor_3
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      REAL(KIND=real_8), INTENT(IN), OPTIONAL :: filter_eps
      INTEGER, INTENT(OUT) :: nblks_local
         !! number of local blocks on this MPI rank
      INTEGER, DIMENSION(dbcsr_t_max_nblks_local(tensor_3), ndims_tensor(tensor_3)), &
         INTENT(OUT) :: result_index
         !! indices of local non-zero tensor blocks for tensor_3
         !! only the elements result_index(:nblks_local, :) are relevant (all others are set to 0)

      ! delegate to the expert driver; passing result_index selects index-only mode
      CALL dbcsr_t_contract_expert(alpha, tensor_1, tensor_2, beta, tensor_3, &
                                   contract_1, notcontract_1, &
                                   contract_2, notcontract_2, &
                                   map_1, map_2, &
                                   bounds_1=bounds_1, &
                                   bounds_2=bounds_2, &
                                   bounds_3=bounds_3, &
                                   filter_eps=filter_eps, &
                                   nblks_local=nblks_local, &
                                   result_index=result_index)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_map_bounds_to_tensors(tensor_1, tensor_2, &
                                            contract_1, notcontract_1, &
                                            contract_2, notcontract_2, &
                                            bounds_t1, bounds_t2, &
                                            bounds_1, bounds_2, bounds_3, &
                                            do_crop_1, do_crop_2)
      !! Map contraction bounds to bounds referring to tensor indices
      !! see dbcsr_t_contract for docu of dummy arguments

      TYPE(dbcsr_t_type), INTENT(IN) :: tensor_1, tensor_2
      INTEGER, DIMENSION(:), INTENT(IN) :: contract_1, contract_2, &
                                           notcontract_1, notcontract_2
      INTEGER, DIMENSION(2, ndims_tensor(tensor_1)), &
         INTENT(OUT) :: bounds_t1
         !! bounds mapped to tensor_1
      INTEGER, DIMENSION(2, ndims_tensor(tensor_2)), &
         INTENT(OUT) :: bounds_t2
         !! bounds mapped to tensor_2
      INTEGER, DIMENSION(2, SIZE(contract_1)), &
         INTENT(IN), OPTIONAL :: bounds_1
      INTEGER, DIMENSION(2, SIZE(notcontract_1)), &
         INTENT(IN), OPTIONAL :: bounds_2
      INTEGER, DIMENSION(2, SIZE(notcontract_2)), &
         INTENT(IN), OPTIONAL :: bounds_3
      LOGICAL, INTENT(OUT), OPTIONAL :: do_crop_1, do_crop_2
         !! whether tensor 1 should be cropped
         !! whether tensor 2 should be cropped
      LOGICAL, DIMENSION(2) :: do_crop

      do_crop = .FALSE.

      ! defaults: the full index range of each tensor
      bounds_t1(1, :) = 1
      CALL dbcsr_t_get_info(tensor_1, nfull_total=bounds_t1(2, :))

      bounds_t2(1, :) = 1
      CALL dbcsr_t_get_info(tensor_2, nfull_total=bounds_t2(2, :))

      ! bounds_1 restricts the contracted index, so it applies to both tensors
      IF (PRESENT(bounds_1)) THEN
         bounds_t1(:, contract_1) = bounds_1
         do_crop(1) = .TRUE.
         bounds_t2(:, contract_2) = bounds_1
         do_crop(2) = .TRUE.
      ENDIF

      IF (PRESENT(bounds_2)) THEN
         bounds_t1(:, notcontract_1) = bounds_2
         do_crop(1) = .TRUE.
      ENDIF

      IF (PRESENT(bounds_3)) THEN
         bounds_t2(:, notcontract_2) = bounds_3
         do_crop(2) = .TRUE.
      ENDIF

      IF (PRESENT(do_crop_1)) do_crop_1 = do_crop(1)
      IF (PRESENT(do_crop_2)) do_crop_2 = do_crop(2)

   END SUBROUTINE

   SUBROUTINE dbcsr_t_print_contraction_index(tensor_1, indchar1, tensor_2, indchar2, tensor_3, indchar3, unit_nr)
      !! print tensor contraction indices in a human readable way

      TYPE(dbcsr_t_type), INTENT(IN) :: tensor_1, tensor_2, tensor_3
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_1)), INTENT(IN) :: indchar1
         !! characters printed for index of tensor 1
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_2)), INTENT(IN) :: indchar2
         !! characters printed for index of tensor 2
      CHARACTER(LEN=1), DIMENSION(ndims_tensor(tensor_3)), INTENT(IN) :: indchar3
         !! characters printed for index of tensor 3
      INTEGER, INTENT(IN) :: unit_nr
         !! output unit
      INTEGER, DIMENSION(ndims_matrix_row(tensor_1)) :: map11
      INTEGER, DIMENSION(ndims_matrix_column(tensor_1)) :: map12
      INTEGER, DIMENSION(ndims_matrix_row(tensor_2)) :: map21
      INTEGER, DIMENSION(ndims_matrix_column(tensor_2)) :: map22
      INTEGER, DIMENSION(ndims_matrix_row(tensor_3)) :: map31
      INTEGER, DIMENSION(ndims_matrix_column(tensor_3)) :: map32
      INTEGER :: ichar1, ichar2, ichar3, unit_nr_prv

      unit_nr_prv = prep_output_unit(unit_nr)

      IF (unit_nr_prv /= 0) THEN
         CALL dbcsr_t_get_mapping_info(tensor_1%nd_index_blk, map1_2d=map11, map2_2d=map12)
         CALL dbcsr_t_get_mapping_info(tensor_2%nd_index_blk, map1_2d=map21, map2_2d=map22)
         CALL dbcsr_t_get_mapping_info(tensor_3%nd_index_blk, map1_2d=map31, map2_2d=map32)
      ENDIF

      ! only the rank owning the output unit (unit_nr_prv > 0) prints
      IF (unit_nr_prv > 0) THEN
         WRITE (unit_nr_prv, '(T2,A)') "INDEX INFO"
         WRITE (unit_nr_prv, '(T15,A)', advance='no') "tensor index: ("
         DO ichar1 = 1, SIZE(indchar1)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(ichar1)
         ENDDO
         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
         DO ichar2 = 1, SIZE(indchar2)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(ichar2)
         ENDDO
         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
         DO ichar3 = 1, SIZE(indchar3)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(ichar3)
         ENDDO
         WRITE (unit_nr_prv, '(A)') ")"

         ! same indices again, in matrix layout: (row indices | column indices)
         WRITE (unit_nr_prv, '(T15,A)', advance='no') "matrix index: ("
         DO ichar1 = 1, SIZE(map11)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map11(ichar1))
         ENDDO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar1 = 1, SIZE(map12)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar1(map12(ichar1))
         ENDDO
         WRITE (unit_nr_prv, '(A)', advance='no') ") x ("
         DO ichar2 = 1, SIZE(map21)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map21(ichar2))
         ENDDO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar2 = 1, SIZE(map22)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar2(map22(ichar2))
         ENDDO
         WRITE (unit_nr_prv, '(A)', advance='no') ") = ("
         DO ichar3 = 1, SIZE(map31)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map31(ichar3))
         ENDDO
         WRITE (unit_nr_prv, '(A1)', advance='no') "|"
         DO ichar3 = 1, SIZE(map32)
            WRITE (unit_nr_prv, '(A1)', advance='no') indchar3(map32(ichar3))
         ENDDO
         WRITE (unit_nr_prv, '(A)') ")"
      ENDIF

   END SUBROUTINE

   SUBROUTINE dbcsr_t_batched_contract_init(tensor, ${varlist("batch_range")}$)
      !! Initialize batched contraction for this tensor.
      !!
      !! Explanation: A batched contraction is a contraction performed in several consecutive steps by
      !! specification of bounds in dbcsr_t_contract. This can be used to reduce memory by a large factor.
      !! The routines dbcsr_t_batched_contract_init and dbcsr_t_batched_contract_finalize should be
      !! called to define the scope of a batched contraction as this enables important optimizations
      !! (adapting communication scheme to batches and adapting process grid to multiplication algorithm).
      !! The routines dbcsr_t_batched_contract_init and dbcsr_t_batched_contract_finalize must be called
      !! before the first and after the last contraction step on all 3 tensors.
      !!
      !! Requirements:
      !! - the tensors are in a compatible matrix layout (see documentation of `dbcsr_t_contract`, note 2 & 3).
      !!   If they are not, process grid optimizations are disabled and a warning is issued.
      !! - within the scope of a batched contraction, it is not allowed to access or change tensor data
      !!   except by calling the routines dbcsr_t_contract & dbcsr_t_copy.
      !! - the bounds affecting indices of the smallest tensor must not change in the course of a batched
      !!   contraction (todo: get rid of this requirement).
      !!
      !! Side effects:
      !! - the parallel layout (process grid and distribution) of all tensors may change. In order to
      !!   disable the process grid optimization including this side effect, call this routine only on the
      !!   smallest of the 3 tensors.
      !!
      !! @note
      !! Note 1: for an example of batched contraction see `examples/dbcsr_tensor_example.F`.
      !! (todo: the example is outdated and should be updated).
      !!
      !! Note 2: it is meaningful to use this feature if the contraction consists of one batch only
      !! but if multiple contractions involving the same 3 tensors are performed
      !! (batched_contract_init and batched_contract_finalize must then be called before/after each
      !! contraction call). The process grid is then optimized after the first contraction
      !! and future contraction may profit from this optimization.
      !! @endnote
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: ${varlist("batch_range")}$
         !! For internal load balancing optimizations, optionally specify the index ranges of
         !! batched contraction.
         !! batch_range_i refers to the ith tensor dimension and contains all block indices starting
         !! a new range. The size should be the number of ranges plus one, the last element being the
         !! block index plus one of the last block in the last range.
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: tdims
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range_prv")}$
      LOGICAL :: static_range

      CALL dbcsr_t_get_info(tensor, nblks_total=tdims)

      ! static_range = no user-supplied batch ranges on any dimension
      static_range = .TRUE.
#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         IF (PRESENT(batch_range_${idim}$)) THEN
            CALL allocate_any(batch_range_prv_${idim}$, source=batch_range_${idim}$)
            static_range = .FALSE.
         ELSE
            ! default: a single batch spanning the full block index range
            ALLOCATE (batch_range_prv_${idim}$ (2))
            batch_range_prv_${idim}$ (1) = 1
            batch_range_prv_${idim}$ (2) = tdims(${idim}$) + 1
         ENDIF
      ENDIF
#:endfor

      ALLOCATE (tensor%contraction_storage)
      tensor%contraction_storage%static = static_range
      IF (static_range) THEN
         CALL dbcsr_tas_batched_mm_init(tensor%matrix_rep)
      ENDIF
      tensor%contraction_storage%nsplit_avg = 0.0_real_8
      tensor%contraction_storage%ibatch = 0

#:for ndim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) == ${ndim}$) THEN
         CALL create_array_list(tensor%contraction_storage%batch_ranges, ${ndim}$, &
                                ${varlist("batch_range_prv", nmax=ndim)}$)
      ENDIF
#:endfor

   END SUBROUTINE

   SUBROUTINE dbcsr_t_batched_contract_finalize(tensor, unit_nr)
      !! finalize batched contraction. This performs all communication that has been postponed in the
      !! contraction calls.
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      LOGICAL :: do_write
      INTEGER :: unit_nr_prv, handle

      CALL mp_sync(tensor%pgrid%mp_comm_2d)
      CALL timeset("dbcsr_t_total", handle)
      unit_nr_prv = prep_output_unit(unit_nr)

      do_write = .FALSE.

      IF (tensor%contraction_storage%static) THEN
         IF (tensor%matrix_rep%do_batched > 0) THEN
            IF (tensor%matrix_rep%mm_storage%batched_out) do_write = .TRUE.
         ENDIF
         CALL dbcsr_tas_batched_mm_finalize(tensor%matrix_rep)
      ENDIF

      IF (do_write .AND. unit_nr_prv /= 0) THEN
         IF (unit_nr_prv > 0) THEN
            WRITE (unit_nr_prv, "(T2,A)") &
               "FINALIZING BATCHED PROCESSING OF MATMUL"
         ENDIF
         CALL dbcsr_t_write_tensor_info(tensor, unit_nr_prv)
         CALL dbcsr_t_write_tensor_dist(tensor, unit_nr_prv)
      ENDIF

      ! leave the scope of the batched contraction
      CALL destroy_array_list(tensor%contraction_storage%batch_ranges)
      DEALLOCATE (tensor%contraction_storage)
      CALL mp_sync(tensor%pgrid%mp_comm_2d)
      CALL timestop(handle)

   END SUBROUTINE

   SUBROUTINE dbcsr_t_change_pgrid(tensor, pgrid, ${varlist("batch_range")}$, &
                                   nodata, pgrid_changed, unit_nr)
      !! change the process grid of a tensor
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
      TYPE(dbcsr_t_pgrid_type), INTENT(IN) :: pgrid
      INTEGER, DIMENSION(:), OPTIONAL, INTENT(IN) :: ${varlist("batch_range")}$
         !! For internal load balancing optimizations, optionally specify the index ranges of
         !! batched contraction.
         !! batch_range_i refers to the ith tensor dimension and contains all block indices starting
         !! a new range. The size should be the number of ranges plus one, the last element being the
         !! block index plus one of the last block in the last range.
      LOGICAL, INTENT(IN), OPTIONAL :: nodata
         !! optionally don't copy the tensor data (then tensor is empty on returned)
      LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      CHARACTER(LEN=*), PARAMETER :: routineN = 'dbcsr_t_change_pgrid'
      CHARACTER(default_string_length) :: name
      INTEGER :: handle
      INTEGER, ALLOCATABLE, DIMENSION(:) :: ${varlist("bs")}$, &
                                            ${varlist("dist")}$
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: pcoord, pcoord_ref, pdims, pdims_ref, &
                                                  tdims
      TYPE(dbcsr_t_type) :: t_tmp
      TYPE(dbcsr_t_distribution_type) :: dist
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, &
         DIMENSION(ndims_matrix_column(tensor)) :: map2
      LOGICAL, DIMENSION(ndims_tensor(tensor)) :: mem_aware
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: nbatch
      INTEGER :: ind1, ind2, batch_size, ibatch

      IF (PRESENT(pgrid_changed)) pgrid_changed = .FALSE.
      CALL mp_environ_pgrid(pgrid, pdims, pcoord)
      CALL mp_environ_pgrid(tensor%pgrid, pdims_ref, pcoord_ref)

      ! nothing to do if the new grid has identical dimensions and the same TAS split count
      IF (ALL(pdims == pdims_ref)) THEN
         IF (ALLOCATED(pgrid%tas_split_info) .AND. ALLOCATED(tensor%pgrid%tas_split_info)) THEN
            IF (pgrid%tas_split_info%ngroup == tensor%pgrid%tas_split_info%ngroup) THEN
               RETURN
            ENDIF
         ENDIF
      ENDIF

      CALL timeset(routineN, handle)

#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         mem_aware(${idim}$) = PRESENT(batch_range_${idim}$)
         IF (mem_aware(${idim}$)) nbatch(${idim}$) = SIZE(batch_range_${idim}$) - 1
      ENDIF
#:endfor

      CALL dbcsr_t_get_info(tensor, nblks_total=tdims, name=name)

#:for idim in range(1, maxdim+1)
      IF (ndims_tensor(tensor) >= ${idim}$) THEN
         ALLOCATE (bs_${idim}$ (dbcsr_t_nblks_total(tensor, ${idim}$)))
         CALL get_ith_array(tensor%blk_sizes, ${idim}$, bs_${idim}$)
         ALLOCATE (dist_${idim}$ (tdims(${idim}$)))
         dist_${idim}$ = 0
         IF (mem_aware(${idim}$)) THEN
            ! distribute each batch separately so that every batch is load balanced on its own
            DO ibatch = 1, nbatch(${idim}$)
               ind1 = batch_range_${idim}$ (ibatch)
               ind2 = batch_range_${idim}$ (ibatch + 1) - 1
               batch_size = ind2 - ind1 + 1
               CALL dbcsr_t_default_distvec(batch_size, pdims(${idim}$), &
                                            bs_${idim}$ (ind1:ind2), dist_${idim}$ (ind1:ind2))
            ENDDO
         ELSE
            CALL dbcsr_t_default_distvec(tdims(${idim}$), pdims(${idim}$), bs_${idim}$, dist_${idim}$)
         ENDIF
      ENDIF
#:endfor

      CALL dbcsr_t_get_mapping_info(tensor%nd_index_blk, map1_2d=map1, map2_2d=map2)
#:for ndim in ndims
      IF (ndims_tensor(tensor) == ${ndim}$) THEN
         CALL dbcsr_t_distribution_new(dist, pgrid, ${varlist("dist", nmax=ndim)}$)
         ! NOTE(review): data type is hard-coded to dbcsr_type_real_8 here while dbcsr_t_remap
         ! derives it from the input tensor — confirm this is intended for non-real-8 tensors
         CALL dbcsr_t_create(t_tmp, name, dist, map1, map2, dbcsr_type_real_8, ${varlist("bs", nmax=ndim)}$)
      ENDIF
#:endfor
      CALL dbcsr_t_distribution_destroy(dist)

      IF (PRESENT(nodata)) THEN
         IF (.NOT. nodata) CALL dbcsr_t_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      ELSE
         CALL dbcsr_t_copy_expert(tensor, t_tmp, move_data=.TRUE.)
      ENDIF

      CALL dbcsr_t_copy_contraction_storage(tensor, t_tmp)

      ! replace the input tensor by its redistributed version
      CALL dbcsr_t_destroy(tensor)
      tensor = t_tmp

      IF (PRESENT(unit_nr)) THEN
         IF (unit_nr > 0) THEN
            WRITE (unit_nr, "(T2,A,1X,A)") "OPTIMIZED PGRID INFO FOR", TRIM(tensor%name)
            WRITE (unit_nr, "(T4,A,1X,3I6)") "process grid dimensions:", pdims
            CALL dbcsr_t_write_split_info(pgrid, unit_nr)
         ENDIF
      ENDIF

      IF (PRESENT(pgrid_changed)) pgrid_changed = .TRUE.

      CALL timestop(handle)
   END SUBROUTINE

   SUBROUTINE dbcsr_t_change_pgrid_2d(tensor, mp_comm, pdims, nodata, nsplit, dimsplit, pgrid_changed, unit_nr)
      !! map tensor to a new 2d process grid for the matrix representation.
      TYPE(dbcsr_t_type), INTENT(INOUT) :: tensor
      INTEGER, INTENT(IN) :: mp_comm
      INTEGER, DIMENSION(2), INTENT(IN), OPTIONAL :: pdims
      LOGICAL, INTENT(IN), OPTIONAL :: nodata
      INTEGER, INTENT(IN), OPTIONAL :: nsplit, dimsplit
      LOGICAL, INTENT(OUT), OPTIONAL :: pgrid_changed
      INTEGER, INTENT(IN), OPTIONAL :: unit_nr
      INTEGER, DIMENSION(ndims_matrix_row(tensor)) :: map1
      INTEGER, DIMENSION(ndims_matrix_column(tensor)) :: map2
      INTEGER, DIMENSION(ndims_tensor(tensor)) :: dims, nbatches
      TYPE(dbcsr_t_pgrid_type) :: pgrid
      INTEGER, DIMENSION(:), ALLOCATABLE :: ${varlist("batch_range")}$
      INTEGER, DIMENSION(:), ALLOCATABLE :: array
      INTEGER :: idim

      CALL dbcsr_t_get_mapping_info(tensor%pgrid%nd_index_grid, map1_2d=map1, map2_2d=map2)
      CALL blk_dims_tensor(tensor, dims)

      IF (ALLOCATED(tensor%contraction_storage)) THEN
         ASSOCIATE (batch_ranges => tensor%contraction_storage%batch_ranges)
            nbatches = sizes_of_arrays(tensor%contraction_storage%batch_ranges) - 1
            ! for good load balancing the process grid dimensions should be chosen adapted to the
            ! tensor dimensions. For batched contraction the tensor dimensions should be divided by
            ! the number of batches (number of index ranges).
            DO idim = 1, ndims_tensor(tensor)
               CALL get_ith_array(tensor%contraction_storage%batch_ranges, idim, array)
               ! effective extent covered by the batches of this dimension
               dims(idim) = array(nbatches(idim) + 1) - array(1)
               DEALLOCATE (array)
               dims(idim) = dims(idim)/nbatches(idim)
               IF (dims(idim) <= 0) dims(idim) = 1
            ENDDO
         END ASSOCIATE
      ENDIF

      pgrid = dbcsr_t_nd_mp_comm(mp_comm, map1, map2, pdims_2d=pdims, tdims=dims, nsplit=nsplit, dimsplit=dimsplit)
      IF (ALLOCATED(tensor%contraction_storage)) THEN
#:for ndim in range(1, maxdim+1)
         IF (ndims_tensor(tensor) == ${ndim}$) THEN
            CALL get_arrays(tensor%contraction_storage%batch_ranges, ${varlist("batch_range", nmax=ndim)}$)
            CALL dbcsr_t_change_pgrid(tensor, pgrid, ${varlist("batch_range", nmax=ndim)}$, &
                                      nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
         ENDIF
#:endfor
      ELSE
         CALL dbcsr_t_change_pgrid(tensor, pgrid, nodata=nodata, pgrid_changed=pgrid_changed, unit_nr=unit_nr)
      ENDIF
      CALL dbcsr_t_pgrid_destroy(pgrid)

   END SUBROUTINE

END MODULE