# Copyright (C) 2003  CAMP
# Please see the accompanying LICENSE file for further information.

import sys
import time
import traceback
import atexit
import pickle
from contextlib import contextmanager
from typing import Any

from ase.parallel import world as aseworld
import numpy as np

import gpaw
from .broadcast_imports import world
import _gpaw

MASTER = 0


def is_contiguous(*args, **kwargs):
    from gpaw.utilities import is_contiguous
    return is_contiguous(*args, **kwargs)


@contextmanager
def broadcast_exception(comm):
    """Make sure an exception raised on any rank is raised on all ranks.

    This example::

        with broadcast_exception(world):
            if world.rank == 0:
                x = 1 / 0

    will raise ZeroDivisionError in the whole world.
    """
    # Each core will send -1 on success or its rank on failure.
    try:
        yield
    except Exception as ex:
        rank = comm.max(comm.rank)
        if rank == comm.rank:
            broadcast(ex, rank, comm)
            raise
    else:
        rank = comm.max(-1)
    # rank will now be the highest failing rank or -1
    if rank >= 0:
        raise broadcast(None, rank, comm)


class _Communicator:
    def __init__(self, comm, parent=None):
        """Construct a wrapper of the C-object for any MPI-communicator.

        Parameters:

        comm: MPI-communicator
            Communicator.

        Attributes:

        ============  ======================================================
        ``size``      Number of ranks in the MPI group.
        ``rank``      Number of this CPU in the MPI group.
        ``parent``    Parent MPI-communicator.
        ============  ======================================================
        """
        self.comm = comm
        self.size = comm.size
        self.rank = comm.rank
        self.parent = parent  # XXX check C-object against comm.parent?

    def new_communicator(self, ranks):
        """Create a new MPI communicator for a subset of ranks in a group.
        Must be called with identical arguments by all relevant processes.

        Note that a valid communicator is only returned to the processes
        which are included in the new group; other ranks get None returned.

        Parameters:

        ranks: ndarray (type int)
            List of integers of the ranks to include in the new group.
            Note that these ranks are indices in the current group,
            whereas the rank attribute of the new communicator
            corresponds to the index within the subset.

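        A minimal sketch (assuming ``world`` and an even number of ranks),
        splitting the world communicator into two halves::

          half = world.size // 2
          if world.rank < half:
              comm = world.new_communicator(np.arange(0, half))
          else:
              comm = world.new_communicator(np.arange(half, world.size))
          # comm.size == half and comm.rank counts from 0 within each half
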
        """

        comm = self.comm.new_communicator(ranks)
        if comm is None:
            # This cpu is not in the new communicator:
            return None
        else:
            return _Communicator(comm, parent=self)

    def sum(self, a, root=-1):
        """Perform summation by MPI reduce operations of numerical data.

        Parameters:

        a: ndarray or value (type int, float or complex)
            Numerical data to sum over all ranks in the communicator group.
            If the data is a single value of type int, float or complex,
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the sum of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

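        A minimal sketch (assuming four ranks and the ``world``
        communicator), summing an array in place::

          a = np.ones(3) * (world.rank + 1)
          world.sum(a)
          # On every rank: a == [10., 10., 10.]  (i.e. 1 + 2 + 3 + 4)
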
        """
        if isinstance(a, (int, float, complex)):
            return self.comm.sum(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float or tc == complex
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.sum(a, root)

    def product(self, a, root=-1):
        """Do multiplication by MPI reduce operations of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to multiply across all ranks in the communicator
            group. NB: Find the global product from the local products.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the product
            of the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

        """
        if isinstance(a, (int, float)):
            return self.comm.product(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.product(a, root)

    def max(self, a, root=-1):
        """Find maximal value by an MPI reduce operation of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to find the maximum value of across all ranks in
            the communicator group. NB: Find global maximum from local max.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the max of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

        """
        if isinstance(a, (int, float)):
            return self.comm.max(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.max(a, root)

    def min(self, a, root=-1):
        """Find minimal value by an MPI reduce operation of numerical data.

        Parameters:

        a: ndarray or value (type int or float)
            Numerical data to find the minimal value of across all ranks in
            the communicator group. NB: Find global minimum from local min.
            If the data is a single value of type int or float (no complex),
            the result is returned because the input argument is immutable.
            Otherwise, the reduce operation is carried out in-place such
            that the elements of the input array will represent the min of
            the equivalent elements across all processes in the group.
        root: int (default -1)
            Rank of the root process, on which the outcome of the reduce
            operation is valid. A root rank of -1 signifies that the result
            will be distributed back to all processes, i.e. a broadcast.

        """
        if isinstance(a, (int, float)):
            return self.comm.min(a, root)
        else:
            tc = a.dtype
            assert tc == int or tc == float
            assert is_contiguous(a, tc)
            assert root == -1 or 0 <= root < self.size
            self.comm.min(a, root)

    def scatter(self, a, b, root):
        """Distribute data from one rank to all other processes in a group.

        Parameters:

        a: ndarray (ignored on all ranks different from root; use None)
            Source of the data to distribute, i.e. send buffer on root rank.
        b: ndarray
            Destination of the distributed data, i.e. local receive buffer.
            The size of this array multiplied by the number of processes in
            the group must match the size of the source array on the root.
        root: int
            Rank of the root process, from which the source data originates.

        The reverse operation is ``gather``.

        Example::

          # The master has all the interesting data. Distribute it.
          if comm.rank == 0:
              data = np.random.normal(size=N*comm.size)
          else:
              data = None
          mydata = np.empty(N, dtype=float)
          comm.scatter(data, mydata, 0)

          # .. which is equivalent to ..

          if comm.rank == 0:
              # Extract my part directly
              mydata[:] = data[0:N]
              # Distribute parts to the slaves
              for rank in range(1, comm.size):
                  buf = data[rank*N:(rank+1)*N]
                  comm.send(buf, rank, tag=123)
          else:
              # Receive from the master
              comm.receive(mydata, 0, tag=123)

        """
        if self.rank == root:
            assert a.dtype == b.dtype
            assert a.size == self.size * b.size
            assert a.flags.contiguous
        assert b.flags.contiguous
        assert 0 <= root < self.size
        self.comm.scatter(a, b, root)

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        """All-to-all in a group.

        Parameters:

        sbuffer: ndarray
            Source of the data to distribute, i.e., the send buffer of
            this rank.
        scounts: ndarray
            Integer array (of length group size) specifying the number of
            elements to send to each processor.
        sdispls: ndarray
            Integer array (of length group size). Entry j specifies the
            displacement (relative to sbuffer) from which to take the
            outgoing data destined for process j.
        rbuffer: ndarray
            Destination of the distributed data, i.e., local receive buffer.
        rcounts: ndarray
            Integer array (of length group size) specifying the maximum
            number of elements that can be received from each processor.
        rdispls: ndarray
            Integer array (of length group size). Entry i specifies the
            displacement (relative to rbuffer) at which to place the
            incoming data from process i.
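
        A small sketch (assuming the ``world`` communicator), where every
        rank sends exactly one element to every other rank::

          n = world.size
          sbuffer = np.arange(n, dtype=float) + 100 * world.rank
          scounts = np.ones(n, dtype=int)
          sdispls = np.arange(n, dtype=int)
          rbuffer = np.empty(n, dtype=float)
          rcounts = np.ones(n, dtype=int)
          rdispls = np.arange(n, dtype=int)
          world.alltoallv(sbuffer, scounts, sdispls,
                          rbuffer, rcounts, rdispls)
          # rbuffer[j] now holds element number world.rank of rank j's sbuffer
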
        """
        assert sbuffer.flags.contiguous
        assert scounts.flags.contiguous
        assert sdispls.flags.contiguous
        assert rbuffer.flags.contiguous
        assert rcounts.flags.contiguous
        assert rdispls.flags.contiguous
        assert sbuffer.dtype == rbuffer.dtype

        for arr in [scounts, sdispls, rcounts, rdispls]:
            assert arr.dtype == int, arr.dtype
            assert len(arr) == self.size

        assert np.all(0 <= sdispls)
        assert np.all(0 <= rdispls)
        assert np.all(sdispls + scounts <= sbuffer.size)
        assert np.all(rdispls + rcounts <= rbuffer.size)
        self.comm.alltoallv(sbuffer, scounts, sdispls,
                            rbuffer, rcounts, rdispls)

    def all_gather(self, a, b):
        """Gather data from all ranks onto all processes in a group.

        Parameters:

        a: ndarray
            Source of the data to gather, i.e. send buffer of this rank.
        b: ndarray
            Destination of the distributed data, i.e. receive buffer.
            The size of this array must match the size of the distributed
            source arrays multiplied by the number of processes in the group.

        Example::

          # All ranks have parts of interesting data. Gather on all ranks.
          mydata = np.random.normal(size=N)
          data = np.empty(N*comm.size, dtype=float)
          comm.all_gather(mydata, data)

          # .. which is equivalent to ..

          if comm.rank == 0:
              # Insert my part directly
              data[0:N] = mydata
              # Gather parts from the slaves
              buf = np.empty(N, dtype=float)
              for rank in range(1, comm.size):
                  comm.receive(buf, rank, tag=123)
                  data[rank*N:(rank+1)*N] = buf
          else:
              # Send to the master
              comm.send(mydata, 0, tag=123)
          # Broadcast from master to all slaves
          comm.broadcast(data, 0)

        """
        assert a.flags.contiguous
        assert b.flags.contiguous
        assert b.dtype == a.dtype
        assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
                a.size * self.size == b.size)
        self.comm.all_gather(a, b)

    def gather(self, a, root, b=None):
        """Gather data from all ranks onto a single process in a group.

        Parameters:

        a: ndarray
            Source of the data to gather, i.e. send buffer of this rank.
        root: int
            Rank of the root process, on which the data is to be gathered.
        b: ndarray (ignored on all ranks different from root; default None)
            Destination of the distributed data, i.e. root's receive buffer.
            The size of this array must match the size of the distributed
            source arrays multiplied by the number of processes in the group.

        The reverse operation is ``scatter``.

        Example::

          # All ranks have parts of interesting data. Gather it on master.
          mydata = np.random.normal(size=N)
          if comm.rank == 0:
              data = np.empty(N*comm.size, dtype=float)
          else:
              data = None
          comm.gather(mydata, 0, data)

          # .. which is equivalent to ..

          if comm.rank == 0:
              # Extract my part directly
              data[0:N] = mydata
              # Gather parts from the slaves
              buf = np.empty(N, dtype=float)
              for rank in range(1, comm.size):
                  comm.receive(buf, rank, tag=123)
                  data[rank*N:(rank+1)*N] = buf
          else:
              # Send to the master
              comm.send(mydata, 0, tag=123)

        """
        assert a.flags.contiguous
        assert 0 <= root < self.size
        if root == self.rank:
            assert b.flags.contiguous and b.dtype == a.dtype
            assert (b.shape[0] == self.size and a.shape == b.shape[1:] or
                    a.size * self.size == b.size)
            self.comm.gather(a, root, b)
        else:
            assert b is None
            self.comm.gather(a, root)

    def broadcast(self, a, root):
        """Share data from a single process to all ranks in a group.

        Parameters:

        a: ndarray
            Data, i.e. send buffer on root rank, receive buffer elsewhere.
            Note that after the broadcast, all ranks have the same data.
        root: int
            Rank of the root process, from which the data is to be shared.

        Example::

          # All ranks have parts of interesting data. Take a given index.
          mydata[:] = np.random.normal(size=N)

          # Who has the element at global index 13? Everybody needs it!
          index = 13
          root, myindex = divmod(index, N)
          element = np.empty(1, dtype=float)
          if comm.rank == root:
              # This process has the requested element so extract it
              element[:] = mydata[myindex]

          # Broadcast from owner to everyone else
          comm.broadcast(element, root)

          # .. which is equivalent to ..

          if comm.rank == root:
              # We are root so send it to the other ranks
              for rank in range(comm.size):
                  if rank != root:
                      comm.send(element, rank, tag=123)
          else:
              # We don't have it so receive from root
              comm.receive(element, root, tag=123)

        """
        assert 0 <= root < self.size
        assert is_contiguous(a)
        self.comm.broadcast(a, root)

    def sendreceive(self, a, dest, b, src, sendtag=123, recvtag=123):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        assert 0 <= src < self.size
        assert src != self.rank
        assert is_contiguous(b)
        return self.comm.sendreceive(a, dest, b, src, sendtag, recvtag)

    def send(self, a, dest, tag=123, block=True):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        if not block:
            pass  # assert sys.getrefcount(a) > 3
        return self.comm.send(a, dest, tag, block)

    def ssend(self, a, dest, tag=123):
        assert 0 <= dest < self.size
        assert dest != self.rank
        assert is_contiguous(a)
        return self.comm.ssend(a, dest, tag)

    def receive(self, a, src, tag=123, block=True):
        assert 0 <= src < self.size
        assert src != self.rank
        assert is_contiguous(a)
        return self.comm.receive(a, src, tag, block)

    def test(self, request):
        """Test whether a non-blocking MPI operation has completed. A boolean
        is returned immediately and the request is not modified in any way.

        Parameters:

        request: MPI request
            Request e.g. returned from send/receive when block=False is used.

        """
        return self.comm.test(request)

    def testall(self, requests):
        """Test whether non-blocking MPI operations have completed. A boolean
        is returned immediately but requests may have been deallocated as a
        result, provided they have completed before or during this invocation.

        Parameters:

        requests: list
            List of MPI requests e.g. returned from send/receive calls
            where block=False was used.

        """
        return self.comm.testall(requests)  # may deallocate requests!

    def wait(self, request):
        """Wait for a non-blocking MPI operation to complete before returning.

        Parameters:

        request: MPI request
            Request e.g. returned from send/receive when block=False is used.

        """
        self.comm.wait(request)

    def waitall(self, requests):
        """Wait for non-blocking MPI operations to complete before returning.

        Parameters:

        requests: list
            List of MPI requests e.g. aggregated from returned requests of
            multiple send/receive calls where block=False was used.

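        A minimal sketch (assuming at least two ranks in ``world``) that
        overlaps a non-blocking ring exchange with other work::

          sbuf = np.array([world.rank], dtype=float)
          rbuf = np.empty(1, dtype=float)
          right = (world.rank + 1) % world.size
          left = (world.rank - 1) % world.size
          requests = [world.send(sbuf, right, tag=42, block=False),
                      world.receive(rbuf, left, tag=42, block=False)]
          # ... do other useful work here ...
          world.waitall(requests)  # rbuf now holds the left neighbour's rank
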
        """
        self.comm.waitall(requests)

    def abort(self, errcode):
        """Terminate MPI execution environment of all tasks in the group.
        This function only returns in the event of an error.

        Parameters:

        errcode: int
            Error code to return to the invoking environment.

        """
        return self.comm.abort(errcode)

    def name(self):
        """Return the name of the processor as a string."""
        return self.comm.name()

    def barrier(self):
        """Block execution until all processes have reached this point."""
        self.comm.barrier()

    def compare(self, othercomm):
        """Compare communicator to other.

        Returns 'ident' if they are identical, 'congruent' if they are
        copies of each other, 'similar' if they are permutations of
        each other, and otherwise 'unequal'.

        This method corresponds to MPI_Comm_compare."""
        if isinstance(self.comm, SerialCommunicator):
            return self.comm.compare(othercomm.comm)  # argh!
        result = self.comm.compare(othercomm.get_c_object())
        assert result in ['ident', 'congruent', 'similar', 'unequal']
        return result

    def translate_ranks(self, other, ranks):
        """Translate ranks from communicator to other.

        ranks must be valid on this communicator.  Returns ranks
        on other communicator corresponding to the same processes.
        Ranks that are not defined on the other communicator are
        assigned values of -1.  (In contrast to MPI which would
        assign MPI_UNDEFINED)."""
        assert hasattr(other, 'translate_ranks'), \
            'Expected communicator, got %s' % other
        assert all(0 <= rank for rank in ranks)
        assert all(rank < self.size for rank in ranks)
        if isinstance(self.comm, SerialCommunicator):
            return self.comm.translate_ranks(other.comm, ranks)  # argh!
        otherranks = self.comm.translate_ranks(other.get_c_object(), ranks)
        assert all(-1 <= rank for rank in otherranks)
        assert ranks.dtype == otherranks.dtype
        return otherranks

    def get_members(self):
        """Return the subset of processes which are members of this MPI group
        in terms of the ranks they are assigned on the parent communicator.
        For the world communicator, this is all integers up to ``size``.

        Example::

          >>> world.rank, world.size  # doctest: +SKIP
          (3, 4)
          >>> world.get_members()  # doctest: +SKIP
          array([0, 1, 2, 3])
          >>> comm = world.new_communicator(np.array([2, 3]))  # doctest: +SKIP
          >>> comm.rank, comm.size  # doctest: +SKIP
          (1, 2)
          >>> comm.get_members()  # doctest: +SKIP
          array([2, 3])
          >>> comm.get_members()[comm.rank] == world.rank  # doctest: +SKIP
          True

        """
        return self.comm.get_members()

    def get_c_object(self):
        """Return the C-object wrapped by this debug interface.

        Whenever a communicator object is passed to C code, that object
        must be a proper C-object - *not* e.g. this debug wrapper.  For
        this reason, the C-communicator object has a get_c_object()
        implementation which returns itself; thus, always call
        comm.get_c_object() and pass the resulting object to the C code.
        """
        c_obj = self.comm.get_c_object()
        assert isinstance(c_obj, _gpaw.Communicator)
        return c_obj


# Serial communicator
class SerialCommunicator:
    size = 1
    rank = 0

    def __init__(self, parent=None):
        self.parent = parent

    def sum(self, array, root=-1):
        if isinstance(array, (int, float, complex)):
            return array

    def scatter(self, s, r, root):
        r[:] = s

    def min(self, value, root=-1):
        return value

    def max(self, value, root=-1):
        return value

    def broadcast(self, buf, root):
        pass

    def send(self, buff, root, tag=123, block=True):
        pass

    def barrier(self):
        pass

    def gather(self, a, root, b):
        b[:] = a

    def all_gather(self, a, b):
        b[:] = a

    def alltoallv(self, sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls):
        assert len(scounts) == 1
        assert len(sdispls) == 1
        assert len(rcounts) == 1
        assert len(rdispls) == 1
        assert len(sbuffer) == len(rbuffer)

        rbuffer[rdispls[0]:rdispls[0] + rcounts[0]] = \
            sbuffer[sdispls[0]:sdispls[0] + scounts[0]]

    def new_communicator(self, ranks):
        if self.rank not in ranks:
            return None
        comm = SerialCommunicator(parent=self)
        comm.size = len(ranks)
        return comm

    def test(self, request):
        return 1

    def testall(self, requests):
        return 1

    def wait(self, request):
        raise NotImplementedError('Calls to mpi wait should not happen in '
                                  'serial mode')

    def waitall(self, requests):
        if not requests:
            return
        raise NotImplementedError('Calls to mpi waitall should not happen in '
                                  'serial mode')

    def get_members(self):
        return np.array([0])

    def compare(self, other):
        if self == other:
            return 'ident'
        elif isinstance(other, SerialCommunicator):
            return 'congruent'
        else:
            raise NotImplementedError('Compare serial comm to other')

    def translate_ranks(self, other, ranks):
        if isinstance(other, SerialCommunicator):
            assert all(rank == 0 for rank in ranks) or gpaw.dry_run
            return np.zeros(len(ranks), dtype=int)
        raise NotImplementedError(
            'Translate non-trivial ranks with serial comm')

    def get_c_object(self):
        if gpaw.dry_run:
            return None  # won't actually be passed to C
        raise NotImplementedError('Should not get C-object for serial comm')


serial_comm = SerialCommunicator()

have_mpi = world is not None

if world is None:
    world = serial_comm

if gpaw.debug:
    serial_comm = _Communicator(serial_comm)  # type: ignore
    world = _Communicator(world)  # type: ignore

rank = world.rank
size = world.size
parallel = (size > 1)

if world.size != aseworld.size:
    raise RuntimeError('Please use "gpaw python" to run in parallel')


def broadcast(obj, root=0, comm=world):
    """Broadcast a Python object across an MPI communicator and return it."""
    if comm.rank == root:
        assert obj is not None
        b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    else:
        assert obj is None
        b = None
    b = broadcast_bytes(b, root, comm)
    if comm.rank == root:
        return obj
    else:
        return pickle.loads(b)


def broadcast_float(x, comm):
    array = np.array([x])
    comm.broadcast(array, 0)
    return array[0]


def synchronize_atoms(atoms, comm, tolerance=1e-8):
    """Synchronize atoms between multiple CPUs removing numerical noise.

    If the atoms differ significantly, raise ValueError on all ranks.
    The error object contains the ranks where the check failed.

    In debug mode, write atoms to files in case of failure."""

    if len(atoms) == 0:
        return

    if comm.rank == 0:
        src = (atoms.positions, atoms.cell, atoms.numbers, atoms.pbc)
    else:
        src = None

    # XXX replace with ase.cell.same_cell in the future
    # (if that function gets to exist)
    # def same_cell(cell1, cell2):
    #     return ((cell1 is None) == (cell2 is None) and
    #             (cell1 is None or (cell1 == cell2).all()))

    # Cell vectors should be compared with a tolerance like positions?
    def same_cell(cell1, cell2, tolerance=1e-8):
        return ((cell1 is None) == (cell2 is None) and
                (cell1 is None or (abs(cell1 - cell2).max() <= tolerance)))

    positions, cell, numbers, pbc = broadcast(src, root=0, comm=comm)
    ok = (len(positions) == len(atoms.positions) and
          (abs(positions - atoms.positions).max() <= tolerance) and
          (numbers == atoms.numbers).all() and
          same_cell(cell, atoms.cell) and
          (pbc == atoms.pbc).all())

    # We need to fail equally on all ranks to avoid trouble.  Thus
    # we use an array to gather check results from everyone.
    my_fail = np.array(not ok, dtype=bool)
    all_fail = np.zeros(comm.size, dtype=bool)
    comm.all_gather(my_fail, all_fail)

    if all_fail.any():
        if gpaw.debug:
            with open('synchronize_atoms_r%d.pckl' % comm.rank, 'wb') as fd:
                pickle.dump((atoms.positions, atoms.cell,
                             atoms.numbers, atoms.pbc,
                             positions, cell, numbers, pbc), fd)
        err_ranks = np.arange(comm.size)[all_fail]
        raise ValueError('Mismatch of Atoms objects.  In debug '
                         'mode, atoms will be dumped to files.',
                         err_ranks)

    atoms.positions = positions
    atoms.cell = cell


def broadcast_string(string=None, root=0, comm=world):
    if comm.rank == root:
        string = string.encode()
    return broadcast_bytes(string, root, comm).decode()


def broadcast_bytes(b=None, root=0, comm=world):
    """Broadcast a bytes object across an MPI communicator and return it."""
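    # The length is broadcast first so that non-root ranks can allocate a
    # receive buffer of the right size before the payload itself is sent.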
    if comm.rank == root:
        assert isinstance(b, bytes)
        n = np.array(len(b), int)
    else:
        assert b is None
        n = np.zeros(1, int)
    comm.broadcast(n, root)
    if comm.rank == root:
        b = np.frombuffer(b, np.int8)
    else:
        b = np.zeros(n, np.int8)
    comm.broadcast(b, root)
    return b.tobytes()


def broadcast_array(array: np.ndarray, *communicators) -> np.ndarray:
    """Broadcast np.ndarray across sequence of MPI-communicators."""
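    # The data is taken from the process that has rank 0 in every
    # communicator.  Working from the last communicator towards the first,
    # a broadcast is done only by processes that are rank 0 in all the
    # communicators not yet processed, so the data spreads out one
    # communicator at a time.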
    comms = list(communicators)
    while comms:
        comm = comms.pop()
        if all(comm.rank == 0 for comm in comms):
            comm.broadcast(array, 0)
    return array


def send(obj, rank: int, comm) -> None:
    """Send object to rank on the MPI communicator comm."""
    b = pickle.dumps(obj, pickle.HIGHEST_PROTOCOL)
    comm.send(np.array(len(b)), rank)
    comm.send(np.frombuffer(b, np.int8).copy(), rank)


def receive(rank: int, comm) -> Any:
    """Receive object from rank on the MPI communicator comm."""
    n = np.array(0)
    comm.receive(n, rank)
    buf = np.empty(int(n), np.int8)
    comm.receive(buf, rank)
    return pickle.loads(buf.tobytes())


def send_string(string, rank, comm=world):
    b = string.encode()
    comm.send(np.array(len(b)), rank)
    comm.send(np.frombuffer(b, np.int8).copy(), rank)


def receive_string(rank, comm=world):
    n = np.array(0)
    comm.receive(n, rank)
    string = np.empty(n, np.int8)
    comm.receive(string, rank)
    return string.tobytes().decode()


def alltoallv_string(send_dict, comm=world):
    scounts = np.zeros(comm.size, dtype=int)
    sdispls = np.zeros(comm.size, dtype=int)
    stotal = 0
    for proc in range(comm.size):
        if proc in send_dict:
            data = np.frombuffer(send_dict[proc].encode(), np.int8)
            scounts[proc] = data.size
            sdispls[proc] = stotal
            stotal += scounts[proc]

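    # Exchange the send counts (one integer per rank) so that every rank
    # knows how many bytes it will receive from each of the others.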
    rcounts = np.zeros(comm.size, dtype=int)
    comm.alltoallv(scounts, np.ones(comm.size, dtype=int),
                   np.arange(comm.size, dtype=int),
                   rcounts, np.ones(comm.size, dtype=int),
                   np.arange(comm.size, dtype=int))
    rdispls = np.zeros(comm.size, dtype=int)
    rtotal = 0
    for proc in range(comm.size):
        rdispls[proc] = rtotal
        rtotal += rcounts[proc]

    sbuffer = np.zeros(stotal, dtype=np.int8)
    for proc in range(comm.size):
        if proc in send_dict:
            sbuffer[sdispls[proc]:(sdispls[proc] + scounts[proc])] = (
                np.frombuffer(send_dict[proc].encode(), np.int8))

    rbuffer = np.zeros(rtotal, dtype=np.int8)
    comm.alltoallv(sbuffer, scounts, sdispls, rbuffer, rcounts, rdispls)

    rdict = {}
    for proc in range(comm.size):
        i = rdispls[proc]
        rdict[proc] = rbuffer[i:i + rcounts[proc]].tobytes().decode()

    return rdict


def ibarrier(timeout=None, root=0, tag=123, comm=world):
    """Non-blocking barrier returning a list of requests to wait for.
    An optional time-out may be given, turning the call into a blocking
    barrier with an upper time limit, beyond which an exception is raised."""
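    # Handshake: the root posts a non-blocking one-byte exchange with every
    # other rank, and every other rank posts the matching exchange with root.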
    requests = []
    byte = np.ones(1, dtype=np.int8)
    if comm.rank == root:
        # Root exchanges one byte with everybody else:
        for rank in range(comm.size):
            if rank == root:
                continue
            rbuf, sbuf = np.empty_like(byte), byte.copy()
            requests.append(comm.send(sbuf, rank, tag=2 * tag + 0,
                                      block=False))
            requests.append(comm.receive(rbuf, rank, tag=2 * tag + 1,
                                         block=False))
    else:
        rbuf, sbuf = np.empty_like(byte), byte
        requests.append(comm.receive(rbuf, root, tag=2 * tag + 0, block=False))
        requests.append(comm.send(sbuf, root, tag=2 * tag + 1, block=False))

    if comm.size == 1 or timeout is None:
        return requests

    t0 = time.time()
    while not comm.testall(requests):  # automatic clean-up upon success
        if time.time() - t0 > timeout:
            raise RuntimeError('MPI barrier timeout.')
    return []


def run(iterators):
    """Run through list of iterators one step at a time."""
    if not isinstance(iterators, list):
        # It's a single iterator - empty it:
        for i in iterators:
            pass
        return

    if len(iterators) == 0:
        return

    results = None  # in case the iterators are exhausted immediately
    while True:
        try:
            results = [next(it) for it in iterators]
        except StopIteration:
            return results


class Parallelization:
    def __init__(self, comm, nkpts):
        self.comm = comm
        self.size = comm.size
        self.nkpts = nkpts

        self.kpt = None
        self.domain = None
        self.band = None

        self.nclaimed = 1
        self.navail = comm.size

    def set(self, kpt=None, domain=None, band=None):
        if kpt is not None:
            self.kpt = kpt
        if domain is not None:
            self.domain = domain
        if band is not None:
            self.band = band

        nclaimed = 1
        for group, name in zip([self.kpt, self.domain, self.band],
                               ['k-point', 'domain', 'band']):
            if group is not None:
                assert group > 0, ('Bad: Only {} cores requested for '
                                   '{} parallelization'.format(group, name))
                if self.size % group != 0:
                    msg = ('Cannot parallelize as the '
                           'communicator size %d is not divisible by the '
                           'requested number %d of ranks for %s '
                           'parallelization' % (self.size, group, name))
                    raise ValueError(msg)
                nclaimed *= group
        navail = self.size // nclaimed

        assert self.size % nclaimed == 0
        assert self.size % navail == 0

        self.navail = navail
        self.nclaimed = nclaimed

    def get_communicator_sizes(self, kpt=None, domain=None, band=None):
        self.set(kpt=kpt, domain=domain, band=band)
        self.autofinalize()
        return self.kpt, self.domain, self.band

    def build_communicators(self, kpt=None, domain=None, band=None,
                            order='kbd'):
        """Construct communicators.

        Returns a communicator for k-points, domains, bands and
        k-points/bands.  The last one "unites" all ranks that are
        responsible for the same domain.

        The order must be a permutation of the characters 'kbd', each
        corresponding to a parallelization mode.  The last character
        signifies the communicator that will be assigned contiguous
        ranks, i.e. order='kbd' will yield contiguous domain ranks,
        whereas order='kdb' will yield contiguous band ranks."""
        self.set(kpt=kpt, domain=domain, band=band)
        self.autofinalize()

        comm = self.comm
        rank = comm.rank
        communicators = {}
        parent_stride = self.size
        offset = 0

        groups = dict(k=self.kpt, b=self.band, d=self.domain)

        # Build communicators in hierarchical manner
        # The ranks in the first group have largest separation while
        # the ranks in the last group are next to each other
        for name in order:
            group = groups[name]
            stride = parent_stride // group
            # First rank in this group
            r0 = rank % stride + offset
            # End of the rank range for this group (exclusive)
            r1 = r0 + stride * group
            ranks = np.arange(r0, r1, stride)
            communicators[name] = comm.new_communicator(ranks)
            parent_stride = stride
            # Offset for the next communicator
            offset += communicators[name].rank * stride

        # We want a communicator for kpts/bands, i.e. the complement of the
        # grid comm: a communicator uniting all cores with the same domain.
        c1, c2, c3 = [communicators[name] for name in order]
        allranks = [range(c1.size), range(c2.size), range(c3.size)]

        def get_communicator_complement(name):
            relevant_ranks = list(allranks)
            relevant_ranks[order.find(name)] = [communicators[name].rank]
            ranks = np.array([r3 + c3.size * (r2 + c2.size * r1)
                              for r1 in relevant_ranks[0]
                              for r2 in relevant_ranks[1]
                              for r3 in relevant_ranks[2]])
            return comm.new_communicator(ranks)

        # The communicator of all processes that share a domain, i.e.
        # the combination of k-point and band communicators.
        communicators['D'] = get_communicator_complement('d')
        # For each k-point comm rank, a communicator of all
        # band/domain ranks.  This is typically used with ScaLAPACK
        # and LCAO orbital stuff.
        communicators['K'] = get_communicator_complement('k')
        return communicators

    def autofinalize(self):
        if self.kpt is None:
            self.set(kpt=self.get_optimal_kpt_parallelization())
        if self.domain is None:
            self.set(domain=self.navail)
        if self.band is None:
            self.set(band=self.navail)

        if self.navail > 1:
            assignments = dict(kpt=self.kpt,
                               domain=self.domain,
                               band=self.band)
            raise gpaw.BadParallelization(
                f'All the CPUs must be used.  Have {assignments} but '
                f'{self.navail} times more are available.')

    def get_optimal_kpt_parallelization(self, kptprioritypower=1.4):
        if self.domain and self.band:
            # Try to use all the CPUs for k-point parallelization
            ncpus = min(self.nkpts, self.navail)
            return ncpus
        ncpuvalues, wastevalues = self.find_kpt_parallelizations()
        scores = ((self.navail // ncpuvalues) *
                  ncpuvalues**kptprioritypower)**(1.0 - wastevalues)
        arg = np.argmax(scores)
        ncpus = ncpuvalues[arg]
        return ncpus

    def find_kpt_parallelizations(self):
        nkpts = self.nkpts
        ncpuvalues = []
        wastevalues = []

        ncpus = nkpts
        while ncpus > 0:
            if self.navail % ncpus == 0:
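                # Ceiling division: the largest number of k-points that any
                # single k-point group will have to handle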
                nkptsmax = -(-nkpts // ncpus)
                effort = nkptsmax * ncpus
                efficiency = nkpts / float(effort)
                waste = 1.0 - efficiency
                wastevalues.append(waste)
                ncpuvalues.append(ncpus)
            ncpus -= 1
        return np.array(ncpuvalues), np.array(wastevalues)


def cleanup():
    error = getattr(sys, 'last_type', None)
    if error is not None:  # else: Python script completed or raise SystemExit
        if parallel and not (gpaw.dry_run > 1):
            sys.stdout.flush()
            sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred.  '
                              'Calling MPI_Abort!\n') % (world.rank, error))
            sys.stderr.flush()
            # Give other nodes a moment to crash by themselves (perhaps
            # producing helpful error messages)
            time.sleep(10)
            world.abort(42)


def print_mpi_stack_trace(type, value, tb):
    """Format exceptions nicely when running in parallel.

    Use this function as an except hook.  Adds rank
    and line number to each line of the exception.  Lines will
    still be printed from different ranks in random order, but
    one can grep for a rank or run 'sort' on the output to obtain
    readable data."""

    exception_text = traceback.format_exception(type, value, tb)
    ndigits = len(str(world.size - 1))
    rankstring = ('%%0%dd' % ndigits) % world.rank

    lines = []
    # The exception elements may contain newlines themselves
    for element in exception_text:
        lines.extend(element.splitlines())

    line_ndigits = len(str(len(lines) - 1))

    for lineno, line in enumerate(lines):
        lineno = ('%%0%dd' % line_ndigits) % lineno
        sys.stderr.write('rank=%s L%s: %s\n' % (rankstring, lineno, line))


if world.size > 1:  # Triggers for dry-run communicators too, but we care not.
    sys.excepthook = print_mpi_stack_trace


def exit(error='Manual exit'):
    # Note that exit must be called on *all* MPI tasks
    atexit._exithandlers = []  # not needed because we are intentionally exiting
    if parallel and not (gpaw.dry_run > 1):
        sys.stdout.flush()
        sys.stderr.write(('GPAW CLEANUP (node %d): %s occurred.  ' +
                          'Calling MPI_Finalize!\n') % (world.rank, error))
        sys.stderr.flush()
    else:
        cleanup()
    world.barrier()  # sync up before exiting
    sys.exit()  # quit for serial case, return to _gpaw.c for parallel case


atexit.register(cleanup)