# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

import sys
import time
import traceback
import atexit
import pickle
from contextlib import contextmanager
from typing import Any

from ase.parallel import world as aseworld
import numpy as np

import gpaw
from .broadcast_imports import world
import _gpaw

MASTER = 0


def is_contiguous(*args, **kwargs):
    from gpaw.utilities import is_contiguous
    return is_contiguous(*args, **kwargs)


@contextmanager
def broadcast_exception(comm):
    """Make sure all ranks get a possible exception raised.

    This example::

        with broadcast_exception(world):
            if world.rank == 0:
                x = 1 / 0

    will raise ZeroDivisionError in the whole world.
    """
    # Each core will send -1 on success or its rank on failure.
    try:
        yield
    except Exception as ex:
        rank = comm.max(comm.rank)
        if rank == comm.rank:
            broadcast(ex, rank, comm)
        raise
    else:
        rank = comm.max(-1)
        # rank will now be the highest failing rank or -1
        if rank >= 0:
            raise broadcast(None, rank, comm)


class _Communicator:
    def __init__(self, comm, parent=None):
        """Construct a wrapper of the C-object for any MPI-communicator.

        Parameters:

        comm: MPI-communicator
            Communicator.

        Attributes:

        ============  ======================================================
        ``size``      Number of ranks in the MPI group.
        ``rank``      Number of this CPU in the MPI group.
        ``parent``    Parent MPI-communicator.
        ============  ======================================================
        """
        self.comm = comm
        self.size = comm.size
        self.rank = comm.rank
        self.parent = parent  # XXX check C-object against comm.parent?

    def new_communicator(self, ranks):
        """Create a new MPI communicator for a subset of ranks in a group.
        Must be called with identical arguments by all relevant processes.

        Note that a valid communicator is only returned to the processes
        which are included in the new group; other ranks get None returned.

        Parameters:

        ranks: ndarray (type int)
            List of integers of the ranks to include in the new group.
            Note that these ranks correspond to indices in the current
            group whereas the rank attribute in the new communicator
            corresponds to their respective index in the subset.

        """
        comm = self.comm.new_communicator(ranks)
        if comm is None:
            # This cpu is not in the new communicator:
            return None
        else:
            return _Communicator(comm, parent=self)

    def sum(self, a, root=-1):
        """Perform summation by MPI reduce operations of numerical data.

        Parameters:

        a: ndarray or value (type int, float or complex)
            Numerical data to sum over all ranks in the communicator group.
            If the data is a single value of type int, float or complex,
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the sum of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.
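
        Example (a sketch; ``density`` is any contiguous float array that
        exists on every rank)::

            # Scalar input: the global sum is returned on every rank.
            total = comm.sum(1.0)  # equals float(comm.size)
            # Array input: summed elementwise in-place; returns None.
            comm.sum(density)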

        """
        if isinstance(a, (int, float, complex)):
            return self.comm.sum(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float or tc == complex
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.sum(a, root)

    def product(self, a, root=-1):
        """Perform multiplication by MPI reduce operations of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to multiply across all ranks in the communicator
            group. NB: Find the global product from the local products.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the product
            of the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

        """
        if isinstance(a, (int, float)):
            return self.comm.product(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.product(a, root)

    def max(self, a, root=-1):
        """Find maximal value by an MPI reduce operation of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to find the maximum value of across all ranks in
            the communicator group. NB: Find global maximum from local max.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the max of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

        """
        if isinstance(a, (int, float)):
            return self.comm.max(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.max(a, root)

    def min(self, a, root=-1):
        """Find minimal value by an MPI reduce operation of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to find the minimal value of across all ranks in
            the communicator group. NB: Find global minimum from local min.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the min of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.
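
        Example (a sketch; every rank contributes one timing)::

            # The smallest value across all ranks, returned everywhere.
            fastest = comm.min(elapsed_seconds)
            # Arrays are reduced elementwise in-place; returns None.
            comm.min(my_array)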

        """
        if isinstance(a, (int, float)):
            return self.comm.min(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.min(a, root)

    def scatter(self, a, b, root):
        """Distribute data from one rank to all other processes in a group.

        Parameters:

        a: ndarray (ignored on all ranks different from root; use None)
            Source of the data to distribute, i.e. send buffer on root rank.
        b: ndarray
            Destination of the distributed data, i.e. local receive buffer.
            The size of this array multiplied by the number of processes in
            the group must match the size of the source array on the root.
        root: int
            Rank of the root process, from which the source data originates.

        The reverse operation is ``gather``.

        Example::

            # The master has all the interesting data. Distribute it.
            if comm.rank == 0:
                data = np.random.normal(size=N * comm.size)
            else:
                data = None
            mydata = np.empty(N, dtype=float)
            comm.scatter(data, mydata, 0)

            # .. which is equivalent to ..

            if comm.rank == 0:
                # Extract my part directly
                mydata[:] = data[0:N]
                # Distribute parts to the slaves
                for rank in range(1, comm.size):
                    buf = data[rank * N:(rank + 1) * N]
                    comm.send(buf, rank, tag=123)
            else:
                # Receive from the master
                comm.receive(mydata, 0, tag=123)

        """
        if self.rank == root:
            assert a.dtype == b.dtype
            assert a.size == self.size * b.size
            assert a.flags.contiguous
        assert b.flags.contiguous
        assert 0 <= root < self.size
        self.comm.scatter(a, b, root)

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        """All-to-all in a group.

        Parameters:

        sbuffer: ndarray
            Source of the data to distribute, i.e. send buffers on all ranks.
        scounts: ndarray
            Integer array of length equal to the group size, specifying the
            number of elements to send to each processor.
        sdispls: ndarray
            Integer array of length equal to the group size.  Entry j
            specifies the displacement relative to sbuffer from which to
            take the outgoing data destined for process j.
        rbuffer: ndarray
            Destination of the distributed data, i.e. local receive buffer.
        rcounts: ndarray
            Integer array of length equal to the group size, specifying the
            maximum number of elements that can be received from each
            processor.
        rdispls: ndarray
            Integer array of length equal to the group size.  Entry i
            specifies the displacement relative to rbuffer at which to place
            the incoming data from process i.
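
        Example (a sketch for ``comm.size == 2``: every rank sends one
        element to each rank, so all count arrays are ones)::

            sbuf = np.array([10.0 * comm.rank, 10.0 * comm.rank + 1])
            rbuf = np.empty(2)
            ones = np.ones(2, dtype=int)
            displs = np.arange(2)
            comm.alltoallv(sbuf, ones, displs, rbuf, ones, displs)
            # rank 0 ends up with [0.0, 10.0]; rank 1 with [1.0, 11.0]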

        """
        assert sbuffer.flags.contiguous
        assert scounts.flags.contiguous
        assert sdispls.flags.contiguous
        assert rbuffer.flags.contiguous
        assert rcounts.flags.contiguous
        assert rdispls.flags.contiguous
        assert sbuffer.dtype == rbuffer.dtype

        for arr in [scounts, sdispls, rcounts, rdispls]:
            assert arr.dtype == int, arr.dtype
            assert len(arr) == self.size

        assert np.all(0 <= sdispls)
        assert np.all(0 <= rdispls)
        assert np.all(sdispls + scounts <= sbuffer.size)
        assert np.all(rdispls + rcounts <= rbuffer.size)
        self.comm.alltoallv(sbuffer, scounts, sdispls,
                            rbuffer, rcounts, rdispls)

    def all_gather(self, a, b):
        """Gather data from all ranks onto all processes in a group.

        Parameters:

        a: ndarray
            Source of the data to gather, i.e. send buffer of this rank.
        b: ndarray
            Destination of the distributed data, i.e. receive buffer.
            The size of this array must match the size of the distributed
            source arrays multiplied by the number of processes in the group.

        Example::

            # All ranks have parts of interesting data. Gather on all ranks.
            mydata = np.random.normal(size=N)
            data = np.empty(N * comm.size, dtype=float)
            comm.all_gather(mydata, data)

            # .. which is equivalent to ..

            if comm.rank == 0:
                # Insert my part directly
                data[0:N] = mydata
                # Gather parts from the slaves
                buf = np.empty(N, dtype=float)
                for rank in range(1, comm.size):
                    comm.receive(buf, rank, tag=123)
                    data[rank * N:(rank + 1) * N] = buf
            else:
                # Send to the master
                comm.send(mydata, 0, tag=123)
            # Broadcast from master to all slaves
            comm.broadcast(data, 0)

        """
        assert a.flags.contiguous
        assert b.flags.contiguous
        assert b.dtype == a.dtype
        assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
                a.size * self.size == b.size)
        self.comm.all_gather(a, b)

    def gather(self, a, root, b=None):
        """Gather data from all ranks onto a single process in a group.

        Parameters:

        a: ndarray
            Source of the data to gather, i.e. send buffer of this rank.
        root: int
            Rank of the root process, on which the data is to be gathered.
        b: ndarray (ignored on all ranks different from root; default None)
            Destination of the distributed data, i.e. root's receive buffer.
            The size of this array must match the size of the distributed
            source arrays multiplied by the number of processes in the group.

        The reverse operation is ``scatter``.

        Example::

            # All ranks have parts of interesting data. Gather it on master.
            mydata = np.random.normal(size=N)
            if comm.rank == 0:
                data = np.empty(N * comm.size, dtype=float)
            else:
                data = None
            comm.gather(mydata, 0, data)

            # .. which is equivalent to ..

            if comm.rank == 0:
                # Extract my part directly
                data[0:N] = mydata
                # Gather parts from the slaves
                buf = np.empty(N, dtype=float)
                for rank in range(1, comm.size):
                    comm.receive(buf, rank, tag=123)
                    data[rank * N:(rank + 1) * N] = buf
            else:
                # Send to the master
                comm.send(mydata, 0, tag=123)

        """
        assert a.flags.contiguous
        assert 0 <= root < self.size
        if root == self.rank:
            assert b.flags.contiguous and b.dtype == a.dtype
            assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
                    a.size * self.size == b.size)
            self.comm.gather(a, root, b)
        else:
            assert b is None
            self.comm.gather(a, root)

    def broadcast(self, a, root):
        """Share data from a single process to all ranks in a group.

        Parameters:

        a: ndarray
            Data, i.e. send buffer on root rank, receive buffer elsewhere.
            Note that after the broadcast, all ranks have the same data.
        root: int
            Rank of the root process, from which the data is to be shared.

        Example::

            # All ranks have parts of interesting data. Take a given index.
            mydata[:] = np.random.normal(size=N)

            # Who has the element at global index 13? Everybody needs it!
            index = 13
            root, myindex = divmod(index, N)
            element = np.empty(1, dtype=float)
            if comm.rank == root:
                # This process has the requested element so extract it
                element[:] = mydata[myindex]

            # Broadcast from owner to everyone else
            comm.broadcast(element, root)

            # .. which is equivalent to ..

            if comm.rank == root:
                # We are root so send it to the other ranks
                for rank in range(comm.size):
                    if rank != root:
                        comm.send(element, rank, tag=123)
            else:
                # We don't have it so receive from root
                comm.receive(element, root, tag=123)

        """
        assert 0 <= root < self.size
        assert is_contiguous(a)
        self.comm.broadcast(a, root)

    def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        assert 0 <= src < self.size
        assert src != self.rank
        assert is_contiguous(b)
        return self.comm.sendreceive(a, dest, b, src, sendtag, recvtag)

    def send(self, a, dest, tag=123, block=True):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        if not block:
            pass  # assert sys.getrefcount(a) > 3
        return self.comm.send(a, dest, tag, block)

    def ssend(self, a, dest, tag=123):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        return self.comm.ssend(a, dest, tag)

    def receive(self, a, src, tag=123, block=True):
        assert 0 <= src < self.size
        assert src != self.rank
        assert is_contiguous(a)
        return self.comm.receive(a, src, tag, block)

    def test(self, request):
        """Test whether a non-blocking MPI operation has completed. A boolean
        is returned immediately and the request is not modified in any way.

        Parameters:

        request: MPI request
            Request e.g. returned from send/receive when block=False is used.

        """
        return self.comm.test(request)

    def testall(self, requests):
        """Test whether non-blocking MPI operations have completed. A boolean
        is returned immediately but requests may have been deallocated as a
        result, provided they have completed before or during this invocation.

        Parameters:

        requests: list
            List of MPI requests e.g. returned from send/receive when
            block=False is used.

        """
        return self.comm.testall(requests)  # may deallocate requests!

    def wait(self, request):
        """Wait for a non-blocking MPI operation to complete before returning.

        Parameters:

        request: MPI request
            Request e.g. returned from send/receive when block=False is used.

        """
        self.comm.wait(request)

    def waitall(self, requests):
        """Wait for non-blocking MPI operations to complete before returning.

        Parameters:

        requests: list
            List of MPI requests e.g. aggregated from returned requests of
            multiple send/receive calls where block=False was used.
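
        Example (a sketch of a non-blocking ring exchange; the buffers
        must stay alive until the requests have completed)::

            right = (comm.rank + 1) % comm.size
            left = (comm.rank - 1) % comm.size
            sbuf = np.array([float(comm.rank)])
            rbuf = np.empty(1)
            comm.waitall([comm.send(sbuf, right, tag=11, block=False),
                          comm.receive(rbuf, left, tag=11, block=False)])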

        """
        self.comm.waitall(requests)

    def abort(self, errcode):
        """Terminate MPI execution environment of all tasks in the group.
        This function only returns in the event of an error occurring.

        Parameters:

        errcode: int
            Error code to return to the invoking environment.

        """
        return self.comm.abort(errcode)

    def name(self):
        """Return the name of the processor as a string."""
        return self.comm.name()

    def barrier(self):
        """Block execution until all processes have reached this point."""
        self.comm.barrier()

    def compare(self, othercomm):
        """Compare communicator to other.

        Returns 'ident' if they are identical, 'congruent' if they are
        copies of each other, 'similar' if they are permutations of
        each other, and otherwise 'unequal'.

        This method corresponds to MPI_Comm_compare."""
        if isinstance(self.comm, SerialCommunicator):
            return self.comm.compare(othercomm.comm)  # argh!
        result = self.comm.compare(othercomm.get_c_object())
        assert result in ['ident', 'congruent', 'similar', 'unequal']
        return result

    def translate_ranks(self, other, ranks):
        """Translate ranks from communicator to other.

        ranks must be valid on this communicator. Returns ranks
        on other communicator corresponding to the same processes.
        Ranks that are not defined on the other communicator are
        assigned values of -1. (In contrast to MPI which would
        assign MPI_UNDEFINED).
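
        Example (a sketch; ``sub`` unites the two highest world ranks)::

            sub = world.new_communicator(np.array([world.size - 2,
                                                   world.size - 1]))
            if sub is not None:
                # sub ranks [0, 1] map back to their world ranks:
                worldranks = sub.translate_ranks(world, np.array([0, 1]))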

        """
        assert hasattr(other, 'translate_ranks'), \
            'Expected communicator, got %s' % other
        assert all(0 <= rank for rank in ranks)
        assert all(rank < self.size for rank in ranks)
        if isinstance(self.comm, SerialCommunicator):
            return self.comm.translate_ranks(other.comm, ranks)  # argh!
        otherranks = self.comm.translate_ranks(other.get_c_object(), ranks)
        assert all(-1 <= rank for rank in otherranks)
        assert ranks.dtype == otherranks.dtype
        return otherranks

    def get_members(self):
        """Return the subset of processes which are members of this MPI group
        in terms of the ranks they are assigned on the parent communicator.
        For the world communicator, this is all integers up to ``size``.

        Example::

            >>> world.rank, world.size  # doctest: +SKIP
            (3, 4)
            >>> world.get_members()  # doctest: +SKIP
            array([0, 1, 2, 3])
            >>> comm = world.new_communicator(np.array([2, 3]))  # doctest: +SKIP
            >>> comm.rank, comm.size  # doctest: +SKIP
            (1, 2)
            >>> comm.get_members()  # doctest: +SKIP
            array([2, 3])
            >>> comm.get_members()[comm.rank] == world.rank  # doctest: +SKIP
            True

        """
        return self.comm.get_members()

    def get_c_object(self):
        """Return the C-object wrapped by this debug interface.

        Whenever a communicator object is passed to C code, that object
        must be a proper C-object - *not* e.g. this debug wrapper.  For
        this reason, the C-communicator object has a get_c_object()
        implementation which returns itself; thus, always call
        comm.get_c_object() and pass the resulting object to the C code.
        """
        c_obj = self.comm.get_c_object()
        assert isinstance(c_obj, _gpaw.Communicator)
        return c_obj


# Serial communicator
class SerialCommunicator:
    size = 1
    rank = 0

    def __init__(self, parent=None):
        self.parent = parent

    def sum(self, array, root=-1):
        if isinstance(array, (int, float, complex)):
            return array

    def scatter(self, s, r, root):
        r[:] = s

    def min(self, value, root=-1):
        return value

    def max(self, value, root=-1):
        return value

    def broadcast(self, buf, root):
        pass

    def send(self, buff, root, tag=123, block=True):
        pass

    def barrier(self):
        pass

    def gather(self, a, root, b):
        b[:] = a

    def all_gather(self, a, b):
        b[:] = a

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        assert len(scounts) == 1
        assert len(sdispls) == 1
        assert len(rcounts) == 1
        assert len(rdispls) == 1
        assert len(sbuffer) == len(rbuffer)

        rbuffer[rdispls[0]:rdispls[0] + rcounts[0]] = \
            sbuffer[sdispls[0]:sdispls[0] + scounts[0]]

    def new_communicator(self, ranks):
        if self.rank not in ranks:
            return None
        comm = SerialCommunicator(parent=self)
        comm.size = len(ranks)
        return comm

    def test(self, request):
        return 1

    def testall(self, requests):
        return 1

    def wait(self, request):
        raise NotImplementedError('Calls to mpi wait should not happen in '
                                  'serial mode')

    def waitall(self, requests):
        if not requests:
            return
        raise NotImplementedError('Calls to mpi waitall should not happen in '
                                  'serial mode')

    def get_members(self):
        return np.array([0])

    def compare(self, other):
        if self == other:
            return 'ident'
        elif isinstance(other, SerialCommunicator):
            return 'congruent'
        else:
            raise NotImplementedError('Compare serial comm to other')

    def translate_ranks(self, other, ranks):
        if isinstance(other, SerialCommunicator):
            assert all(rank == 0 for rank in ranks) or gpaw.dry_run
            return np.zeros(len(ranks), dtype=int)
        raise NotImplementedError(
            'Translate non-trivial ranks with serial comm')

    def get_c_object(self):
        if gpaw.dry_run:
            return None  # won't actually be passed to C
        raise NotImplementedError('Should not get C-object for serial comm')


serial_comm = SerialCommunicator()

have_mpi = world is not None

if world is None:
    world = serial_comm

if gpaw.debug:
    serial_comm = _Communicator(serial_comm)  # type: ignore
    world = _Communicator(world)  # type: ignore

rank = world.rank
size = world.size
parallel = (size > 1)

if world.size != aseworld.size:
    raise RuntimeError('Please use "gpaw python" to run in parallel')


def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it.
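
    Example (a sketch; the object exists on the root rank only)::

        results = {'energy': -1.0} if world.rank == 0 else None
        results = broadcast(results, root=0, comm=world)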

    """
    if comm.rank == root:
        assert obj is not None
        b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    else:
        assert obj is None
        b = None
    b = broadcast_bytes(b, root, comm)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(b)


def broadcast_float(x, comm):
    array = np.array([x])
    comm.broadcast(array, 0)
    return array[0]


def synchronize_atoms(atoms, comm, tolerance=1e-8):
    """Synchronize atoms between multiple CPUs removing numerical noise.

    If the atoms differ significantly, raise ValueError on all ranks.
    The error object contains the ranks where the check failed.

    In debug mode, write atoms to files in case of failure."""

    if len(atoms) == 0:
        return

    if comm.rank == 0:
        src = (atoms.positions, atoms.cell, atoms.numbers, atoms.pbc)
    else:
        src = None

    # XXX replace with ase.cell.same_cell in the future
    # (if that function gets to exist)
    # def same_cell(cell1, cell2):
    #     return ((cell1 is None) == (cell2 is None) and
    #             (cell1 is None or (cell1 == cell2).all()))

    # Cell vectors should be compared with a tolerance like positions?
    def same_cell(cell1, cell2, tolerance=1e-8):
        return ((cell1 is None) == (cell2 is None) and
                (cell1 is None or (abs(cell1 - cell2).max() <= tolerance)))

    positions, cell, numbers, pbc = broadcast(src, root=0, comm=comm)
    ok = (len(positions) == len(atoms.positions) and
          (abs(positions - atoms.positions).max() <= tolerance) and
          (numbers == atoms.numbers).all() and
          same_cell(cell, atoms.cell) and
          (pbc == atoms.pbc).all())

    # We need to fail equally on all ranks to avoid trouble.  Thus
    # we use an array to gather check results from everyone.
    my_fail = np.array(not ok, dtype=bool)
    all_fail = np.zeros(comm.size, dtype=bool)
    comm.all_gather(my_fail, all_fail)

    if all_fail.any():
        if gpaw.debug:
            with open('synchronize_atoms_r%d.pckl' % comm.rank, 'wb') as fd:
                pickle.dump((atoms.positions, atoms.cell,
                             atoms.numbers, atoms.pbc,
                             positions, cell, numbers, pbc), fd)
        err_ranks = np.arange(comm.size)[all_fail]
        raise ValueError('Mismatch of Atoms objects. In debug '
                         'mode, atoms will be dumped to files.',
                         err_ranks)

    atoms.positions = positions
    atoms.cell = cell


def broadcast_string(string=None, root=0, comm=world):
    if comm.rank == root:
        string = string.encode()
    return broadcast_bytes(string, root, comm).decode()


def broadcast_bytes(b=None, root=0, comm=world):
    """Broadcast a bytes object across an MPI communicator and return it."""
    if comm.rank == root:
        assert isinstance(b, bytes)
        n = np.array(len(b), int)
    else:
        assert b is None
        n = np.zeros(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        b = np.frombuffer(b, np.int8)
    else:
        b = np.zeros(n, np.int8)
    comm.broadcast(b, root)
    return b.tobytes()


def broadcast_array(array: np.ndarray, *communicators) -> np.ndarray:
    """Broadcast np.ndarray across a sequence of MPI-communicators.
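
    Example (a sketch; assumes ``kpt_comm`` and ``domain_comm`` together
    span ``world``, with the data initially on the global root only)::

        data = np.zeros(3)
        if world.rank == 0:
            data[:] = np.arange(3.0)
        broadcast_array(data, kpt_comm, domain_comm)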

    """
    comms = list(communicators)
    while comms:
        comm = comms.pop()
        if all(c.rank == 0 for c in comms):
            comm.broadcast(array, 0)
    return array


def send(obj, rank: int, comm) -> None:
    """Send object to rank on the MPI communicator comm."""
    b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    comm.send(np.array(len(b)), rank)
    comm.send(np.frombuffer(b, np.int8).copy(), rank)


def receive(rank: int, comm) -> Any:
    """Receive object from rank on the MPI communicator comm."""
    n = np.array(0)
    comm.receive(n, rank)
    buf = np.empty(int(n), np.int8)
    comm.receive(buf, rank)
    return pickle.loads(buf.tobytes())


def send_string(string, rank, comm=world):
    b = string.encode()
    comm.send(np.array(len(b)), rank)
    comm.send(np.frombuffer(b, np.int8).copy(), rank)


def receive_string(rank, comm=world):
    n = np.array(0)
    comm.receive(n, rank)
    string = np.empty(n, np.int8)
    comm.receive(string, rank)
    return string.tobytes().decode()


def alltoallv_string(send_dict, comm=world):
    scounts = np.zeros(comm.size, dtype=int)
    sdispls = np.zeros(comm.size, dtype=int)
    stotal = 0
    for proc in range(comm.size):
        if proc in send_dict:
            data = np.frombuffer(send_dict[proc].encode(), np.int8)
            scounts[proc] = data.size
            sdispls[proc] = stotal
            stotal += scounts[proc]

    rcounts = np.zeros(comm.size, dtype=int)
    comm.alltoallv(scounts, np.ones(comm.size, dtype=int),
                   np.arange(comm.size, dtype=int),
                   rcounts, np.ones(comm.size, dtype=int),
                   np.arange(comm.size, dtype=int))
    rdispls = np.zeros(comm.size, dtype=int)
    rtotal = 0
    for proc in range(comm.size):
        rdispls[proc] = rtotal
        rtotal += rcounts[proc]
        # rtotal += rcounts[proc]  # CHECK: is this correct?

    sbuffer = np.zeros(stotal, dtype=np.int8)
    for proc in range(comm.size):
        sbuffer[sdispls[proc]:(sdispls[proc] + scounts[proc])] = (
            np.frombuffer(send_dict[proc].encode(), np.int8))

    rbuffer = np.zeros(rtotal, dtype=np.int8)
    comm.alltoallv(sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls)

    rdict = {}
    for proc in range(comm.size):
        i = rdispls[proc]
        rdict[proc] = rbuffer[i:i + rcounts[proc]].tobytes().decode()

    return rdict


def ibarrier(timeout=None, root=0, tag=123, comm=world):
    """Non-blocking barrier returning a list of requests to wait for.

    An optional time-out may be given, turning the call into a blocking
    barrier with an upper time limit, beyond which an exception is raised.
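
    Example (a sketch; ``do_local_work()`` is a stand-in for work that can
    overlap with the barrier handshake)::

        requests = ibarrier(comm=world)
        do_local_work()
        world.waitall(requests)  # returns once every rank has checked in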

    """
    requests = []
    byte = np.ones(1, dtype=np.int8)
    if comm.rank == root:
        # Root shakes hands with everybody else:
        for rank in range(comm.size):
            if rank == root:
                continue
            rbuf, sbuf = np.empty_like(byte), byte.copy()
            requests.append(comm.send(sbuf, rank, tag=2 * tag + 0,
                                      block=False))
            requests.append(comm.receive(rbuf, rank, tag=2 * tag + 1,
                                         block=False))
    else:
        rbuf, sbuf = np.empty_like(byte), byte
        requests.append(comm.receive(rbuf, root, tag=2 * tag + 0, block=False))
        requests.append(comm.send(sbuf, root, tag=2 * tag + 1, block=False))

    if comm.size == 1 or timeout is None:
        return requests

    t0 = time.time()
    while not comm.testall(requests):  # automatic clean-up upon success
        if time.time() - t0 > timeout:
            raise RuntimeError('MPI barrier timeout.')
    return []


def run(iterators):
    """Run through list of iterators one step at a time."""
    if not isinstance(iterators, list):
        # It's a single iterator - empty it:
        for i in iterators:
            pass
        return

    if len(iterators) == 0:
        return

    results = None
    while True:
        try:
            # Advance all iterators by one step:
            results = [next(it) for it in iterators]
        except StopIteration:
            # Return the results from the last complete step:
            return results


class Parallelization:
    def __init__(self, comm, nkpts):
        self.comm = comm
        self.size = comm.size
        self.nkpts = nkpts

        self.kpt = None
        self.domain = None
        self.band = None

        self.nclaimed = 1
        self.navail = comm.size

    def set(self, kpt=None, domain=None, band=None):
        if kpt is not None:
            self.kpt = kpt
        if domain is not None:
            self.domain = domain
        if band is not None:
            self.band = band

        nclaimed = 1
        for group, name in zip([self.kpt, self.domain, self.band],
                               ['k-point', 'domain', 'band']):
            if group is not None:
                assert group > 0, ('Bad: Only {} cores requested for '
                                   '{} parallelization'.format(group, name))
                if self.size % group != 0:
                    msg = ('Cannot parallelize as the '
                           'communicator size %d is not divisible by the '
                           'requested number %d of ranks for %s '
                           'parallelization' % (self.size, group, name))
                    raise ValueError(msg)
                nclaimed *= group
        navail = self.size // nclaimed

        assert self.size % nclaimed == 0
        assert self.size % navail == 0

        self.navail = navail
        self.nclaimed = nclaimed

    def get_communicator_sizes(self, kpt=None, domain=None, band=None):
        self.set(kpt=kpt, domain=domain, band=band)
        self.autofinalize()
        return self.kpt, self.domain, self.band

    def build_communicators(self, kpt=None, domain=None, band=None,
                            order='kbd'):
        """Construct communicators.

        Returns a communicator for k-points, domains, bands and
        k-points/bands.  The last one "unites" all ranks that are
        responsible for the same domain.

        The order must be a permutation of the characters 'kbd', each
        corresponding to a parallelization mode.  The last character
        signifies the communicator that will be assigned contiguous
        ranks, i.e. order='kbd' will yield contiguous domain ranks,
        whereas order='kdb' will yield contiguous band ranks.
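
        Example (a sketch for ``world.size == 8``: 2 k-point groups times
        4 domain ranks times 1 band group)::

            p = Parallelization(world, nkpts=2)
            comms = p.build_communicators(kpt=2, domain=4, band=1)
            kpt_comm = comms['k']     # size 2
            domain_comm = comms['d']  # size 4
            band_comm = comms['b']    # size 1
            # comms['D'] unites all ranks that share a domain.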

        """
        self.set(kpt=kpt, domain=domain, band=band)
        self.autofinalize()

        comm = self.comm
        rank = comm.rank
        communicators = {}
        parent_stride = self.size
        offset = 0

        groups = dict(k=self.kpt, b=self.band, d=self.domain)

        # Build communicators in hierarchical manner.
        # The ranks in the first group have the largest separation while
        # the ranks in the last group are next to each other.
        for name in order:
            group = groups[name]
            stride = parent_stride // group
            # First rank in this group
            r0 = rank % stride + offset
            # Last rank in this group
            r1 = r0 + stride * group
            ranks = np.arange(r0, r1, stride)
            communicators[name] = comm.new_communicator(ranks)
            parent_stride = stride
            # Offset for the next communicator
            offset += communicators[name].rank * stride

        # We want a communicator for kpts/bands, i.e. the complement of the
        # grid comm: a communicator uniting all cores with the same domain.
        c1, c2, c3 = [communicators[name] for name in order]
        allranks = [range(c1.size), range(c2.size), range(c3.size)]

        def get_communicator_complement(name):
            relevant_ranks = list(allranks)
            relevant_ranks[order.find(name)] = [communicators[name].rank]
            ranks = np.array([r3 + c3.size * (r2 + c2.size * r1)
                              for r1 in relevant_ranks[0]
                              for r2 in relevant_ranks[1]
                              for r3 in relevant_ranks[2]])
            return comm.new_communicator(ranks)

        # The communicator of all processes that share a domain, i.e.
        # the combination of k-point and band communicators.
        communicators['D'] = get_communicator_complement('d')
        # For each k-point comm rank, a communicator of all
        # band/domain ranks.  This is typically used with ScaLAPACK
        # and LCAO orbital stuff.
        communicators['K'] = get_communicator_complement('k')
        return communicators

    def autofinalize(self):
        if self.kpt is None:
            self.set(kpt=self.get_optimal_kpt_parallelization())
        if self.domain is None:
            self.set(domain=self.navail)
        if self.band is None:
            self.set(band=self.navail)

        if self.navail > 1:
            assignments = dict(kpt=self.kpt,
                               domain=self.domain,
                               band=self.band)
            raise gpaw.BadParallelization(
                f'All the CPUs must be used. Have {assignments} but '
                f'{self.navail} times more are available.')

    def get_optimal_kpt_parallelization(self, kptprioritypower=1.4):
        if self.domain and self.band:
            # Try to use all the CPUs for k-point parallelization
            ncpus = min(self.nkpts, self.navail)
            return ncpus
        ncpuvalues, wastevalues = self.find_kpt_parallelizations()
        scores = ((self.navail // ncpuvalues) *
                  ncpuvalues**kptprioritypower)**(1.0 - wastevalues)
        arg = np.argmax(scores)
        ncpus = ncpuvalues[arg]
        return ncpus

    def find_kpt_parallelizations(self):
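        """Find possible k-point parallelization sizes and their waste.

        A sketch of the bookkeeping: with ``nkpts = 3`` and ``navail = 4``,
        the divisor ``ncpus = 2`` needs ``ceil(3 / 2) = 2`` k-point
        iterations, i.e. an effort of ``2 * 2 = 4`` k-point slots for
        3 k-points, giving ``waste = 1 - 3 / 4 = 0.25``.
        """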
        nkpts = self.nkpts
        ncpuvalues = []
        wastevalues = []

        ncpus = nkpts
        while ncpus > 0:
            if self.navail % ncpus == 0:
                # Maximum number of k-points any rank must handle (ceil):
                nkptsmax = -(-nkpts // ncpus)
                effort = nkptsmax * ncpus
                efficiency = nkpts / float(effort)
                waste = 1.0 - efficiency
                wastevalues.append(waste)
                ncpuvalues.append(ncpus)
            ncpus -= 1
        return np.array(ncpuvalues), np.array(wastevalues)


def cleanup():
    error = getattr(sys, 'last_type', None)
    if error is not None:  # else: Python script completed or raise SystemExit
        if parallel and not (gpaw.dry_run > 1):
            sys.stdout.flush()
            sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. '
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages)
            time.sleep(10)
            world.abort(42)


def print_mpi_stack_trace(type, value, tb):
    """Format exceptions nicely when running in parallel.

    Use this function as an except hook.  Adds rank
    and line number to each line of the exception.  Lines will
    still be printed from different ranks in random order, but
    one can grep for a rank or run 'sort' on the output to obtain
    readable data."""

    exception_text = traceback.format_exception(type, value, tb)
    ndigits = len(str(world.size - 1))
    rankstring = ('%%0%dd' % ndigits) % world.rank

    lines = []
    # The exception elements may contain newlines themselves
    for element in exception_text:
        lines.extend(element.splitlines())

    line_ndigits = len(str(len(lines) - 1))

    for lineno, line in enumerate(lines):
        lineno = ('%%0%dd' % line_ndigits) % lineno
        sys.stderr.write('rank=%s L%s: %s\n' % (rankstring, lineno, line))


if world.size > 1:  # Triggers for dry-run communicators too, but we care not.
    sys.excepthook = print_mpi_stack_trace


def exit(error='Manual exit'):
    # Note that exit must be called on *all* MPI tasks.
    atexit._exithandlers = []  # not needed; we are intentionally exiting
    if parallel and not (gpaw.dry_run > 1):
        sys.stdout.flush()
        sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred. '
                          'Calling MPI_Finalize!\n') % (world.rank, error))
        sys.stderr.flush()
    else:
        cleanup()
    world.barrier()  # sync up before exiting
    sys.exit()  # quit for serial case, return to _gpaw.c for parallel case


atexit.register(cleanup)