#!/usr/bin/env python
# Copyright 2014-2018,2021 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

'''
Task partitioning helpers (static, cost-balanced, work-sharing and
work-stealing) and numpy-aware wrappers around the common MPI collectives,
built on top of mpi4py and the MPIPool defined in mpi_pool.
'''

import sys
import time
import threading
import traceback
import numpy
from mpi4py import MPI
from . import mpi_pool
from .mpi_pool import MPIPool

_registry = {}

if 'pool' not in _registry:
    import atexit
    pool = MPIPool(debug=False)
    _registry['pool'] = pool
    atexit.register(pool.close)

comm = pool.comm
rank = pool.rank

def static_partition(tasks):
    '''Statically partition tasks into contiguous blocks of (roughly) equal
    length, one block per MPI rank.'''
    size = len(tasks)
    segsize = (size+pool.size-1) // pool.size
    start = pool.rank * segsize
    stop = min(size, start+segsize)
    return tasks[start:stop]
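
# Hypothetical usage sketch: with 4 MPI processes and 10 tasks, each rank
# receives a contiguous block of ceil(10/4) = 3 tasks (the last rank gets
# whatever remains), e.g.
#
#   my_tasks = static_partition(list(range(10)))
#   # rank 0 -> [0, 1, 2], rank 1 -> [3, 4, 5], rank 2 -> [6, 7, 8], rank 3 -> [9]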

def work_balanced_partition(tasks, costs=None):
    '''Partition tasks into contiguous blocks whose accumulated costs are
    approximately equal on every rank.  The cut points are computed on rank 0
    and broadcast to the other ranks.'''
    if costs is None:
        costs = numpy.ones(len(tasks))
    if rank == 0:
        segsize = float(sum(costs)) / pool.size
        loads = []
        cum_costs = numpy.cumsum(costs)
        start_id = 0
        for k in range(pool.size):
            stop_id = numpy.argmin(abs(cum_costs - (k+1)*segsize)) + 1
            stop_id = max(stop_id, start_id+1)
            loads.append([start_id, stop_id])
            start_id = stop_id
        comm.bcast(loads)
    else:
        loads = comm.bcast(None)
    if rank < len(loads):
        start, stop = loads[rank]
        return tasks[start:stop]
    else:
        return tasks[:0]
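
# Hypothetical usage sketch: with 2 MPI processes the cut points are chosen so
# that the accumulated costs of the two blocks are roughly equal, e.g.
#
#   tasks = list(range(6))
#   my_tasks = work_balanced_partition(tasks, costs=[1, 2, 3, 4, 5, 6])
#   # rank 0 -> tasks[0:4] (total cost 10), rank 1 -> tasks[4:6] (total cost 11)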

INQUIRY = 50050
TASK = 50051

def work_share_partition(tasks, interval=.02, loadmin=1):
    '''Generator that yields tasks dynamically.  Each rank starts with a small
    block of tasks; rank 0 keeps the remaining tasks and hands them out, one
    at a time, to whichever rank reports a low local load.'''
    loadmin = max(loadmin, len(tasks)//50//pool.size)
    # Tasks beyond the initial per-rank blocks are distributed on demand.
    rest_tasks = [x for x in tasks[loadmin*pool.size:]]
    tasks = list(tasks[loadmin*rank:loadmin*rank+loadmin])

    def distribute_task():
        while True:
            load = len(tasks)
            if rank == 0:
                for i in range(pool.size):
                    if i != 0:
                        load = comm.recv(source=i, tag=INQUIRY)
                    if rest_tasks:
                        if load <= loadmin:
                            task = rest_tasks.pop(0)
                            comm.send(task, i, tag=TASK)
                    else:
                        comm.send('OUT_OF_TASK', i, tag=TASK)
            else:
                comm.send(load, 0, tag=INQUIRY)
            if comm.Iprobe(source=0, tag=TASK):
                tasks.append(comm.recv(source=0, tag=TASK))
                if isinstance(tasks[-1], str) and tasks[-1] == 'OUT_OF_TASK':
                    return
            time.sleep(interval)

    tasks_handler = threading.Thread(target=distribute_task)
    tasks_handler.start()

    while True:
        if tasks:
            task = tasks.pop(0)
            if isinstance(task, str) and task == 'OUT_OF_TASK':
                tasks_handler.join()
                return
            yield task
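
# Hypothetical usage sketch: the generator is simply consumed in a loop; here
# run_one_task is a placeholder for the per-task work.
#
#   results = [run_one_task(t) for t in work_share_partition(list(range(100)))]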

def work_stealing_partition(tasks, interval=.0001):
    '''Generator that yields tasks with work stealing.  Each rank starts from
    its static partition; once a rank runs out of local tasks it steals tasks
    from the other ranks until every rank is done.'''
    tasks = list(static_partition(tasks))
    out_of_task = [False]

    def task_daemon():
        # Answer task requests from the other ranks in a background thread.
        while True:
            time.sleep(interval)
            while comm.Iprobe(source=MPI.ANY_SOURCE, tag=INQUIRY):
                src, req = comm.recv(source=MPI.ANY_SOURCE, tag=INQUIRY)
                if isinstance(req, str) and req == 'STOP_DAEMON':
                    return
                elif tasks:
                    comm.send(tasks.pop(), src, tag=TASK)
                elif src == 0 and isinstance(req, str) and req == 'ALL_DONE':
                    comm.send(out_of_task[0], src, tag=TASK)
                elif out_of_task[0]:
                    comm.send('OUT_OF_TASK', src, tag=TASK)
                else:
                    comm.send('BYPASS', src, tag=TASK)

    def prepare_to_stop():
        out_of_task[0] = True
        if rank == 0:
            while True:
                done = []
                for i in range(1, pool.size):
                    comm.send((0, 'ALL_DONE'), i, tag=INQUIRY)
                    done.append(comm.recv(source=i, tag=TASK))
                if all(done):
                    break
                time.sleep(interval)
            for i in range(pool.size):
                comm.send((0, 'STOP_DAEMON'), i, tag=INQUIRY)
        tasks_handler.join()

    if pool.size > 1:
        tasks_handler = threading.Thread(target=task_daemon)
        tasks_handler.start()

    # Consume the local tasks first.
    while tasks:
        task = tasks.pop(0)
        yield task

    # Then steal tasks from the other ranks, cycling through them.
    if pool.size > 1:
        def next_proc(proc):
            proc = (proc+1) % pool.size
            if proc == rank:
                proc = (proc+1) % pool.size
            return proc
        proc_last = (rank + 1) % pool.size
        proc = next_proc(proc_last)

        while True:
            comm.send((rank, None), proc, tag=INQUIRY)
            task = comm.recv(source=proc, tag=TASK)
            if isinstance(task, str) and task == 'OUT_OF_TASK':
                prepare_to_stop()
                return
            elif isinstance(task, str) and task == 'BYPASS':
                if proc == proc_last:
                    prepare_to_stop()
                    return
                else:
                    proc = next_proc(proc)
            else:
                if proc != proc_last:
                    proc_last, proc = proc, next_proc(proc)
                yield task
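
# Hypothetical usage sketch: same consumption pattern as work_share_partition,
# except that idle ranks steal tasks directly from the busy ones; process is a
# placeholder for the per-task work.
#
#   for task in work_stealing_partition(list(range(100))):
#       process(task)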

def bcast(buf, root=0):
    '''Broadcast a numpy array from root to all other ranks.'''
    buf = numpy.asarray(buf, order='C')
    shape, dtype = comm.bcast((buf.shape, buf.dtype.char), root=root)
    if rank != root:
        buf = numpy.empty(shape, dtype=dtype)
    comm.Bcast(buf, root)
    return buf
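
# Hypothetical usage sketch: rank 0 supplies the array, every other rank passes
# a placeholder and receives a copy.
#
#   data = numpy.random.rand(4, 4) if rank == 0 else None
#   data = bcast(data)    # afterwards every rank holds rank 0's array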

## Useful when sending large batches of arrays
#def safe_bcast(buf, root=0):

def reduce(sendbuf, op=MPI.SUM, root=0):
    '''Reduce (by default sum) a numpy array over all ranks onto root.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.shape == shape and sendbuf.dtype.char == mpi_dtype)

    recvbuf = numpy.zeros_like(sendbuf)
    comm.Reduce(sendbuf, recvbuf, op, root)
    if rank == root:
        return recvbuf
    else:
        return sendbuf
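
# Hypothetical usage sketch: partial results computed on each rank are summed
# element-wise onto rank 0 (the other ranks get their own sendbuf back).
#
#   partial = numpy.arange(3.) * rank
#   total = reduce(partial)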

def allreduce(sendbuf, op=MPI.SUM):
    '''Like reduce, but every rank receives the reduced array.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.shape == shape and sendbuf.dtype.char == mpi_dtype)

    recvbuf = numpy.zeros_like(sendbuf)
    comm.Allreduce(sendbuf, recvbuf, op)
    return recvbuf
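
# Hypothetical usage sketch: as reduce, but every rank ends up with the sum.
#
#   local = numpy.arange(3.) * rank
#   total = allreduce(local)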

def gather(sendbuf, root=0):
    '''Gather numpy arrays from all ranks and stack them along the first axis
    on root.  The trailing dimensions must agree on all ranks.'''
    #if pool.debug:
    #    if rank == 0:
    #        res = [sendbuf]
    #        for k in range(1, pool.size):
    #            dat = comm.recv(source=k)
    #            res.append(dat)
    #        return numpy.vstack([x for x in res if len(x) > 0])
    #    else:
    #        comm.send(sendbuf, dest=0)
    #        return sendbuf

    sendbuf = numpy.asarray(sendbuf, order='C')
    mpi_dtype = sendbuf.dtype.char
    if rank == root:
        size_dtype = comm.gather((sendbuf.size, mpi_dtype), root=root)
        _assert(all(x[1] == mpi_dtype for x in size_dtype if x[0] > 0))
        counts = numpy.array([x[0] for x in size_dtype])
        displs = numpy.append(0, numpy.cumsum(counts[:-1]))
        recvbuf = numpy.empty(sum(counts), dtype=sendbuf.dtype)
        comm.Gatherv([sendbuf.ravel(), mpi_dtype],
                     [recvbuf.ravel(), counts, displs, mpi_dtype], root)
        return recvbuf.reshape((-1,) + sendbuf.shape[1:])
    else:
        comm.gather((sendbuf.size, mpi_dtype), root=root)
        comm.Gatherv([sendbuf.ravel(), mpi_dtype], None, root)
        return sendbuf
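
# Hypothetical usage sketch: per-rank blocks with identical trailing dimensions
# are stacked along the first axis on rank 0.
#
#   block = numpy.zeros((2, 5)) + rank
#   full = gather(block)    # shape (2*pool.size, 5) on rank 0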

def allgather(sendbuf):
    '''Gather numpy arrays from all ranks and stack them along the first axis;
    every rank receives the stacked result.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.dtype.char == mpi_dtype or sendbuf.size == 0)
    counts = numpy.array(comm.allgather(sendbuf.size))
    displs = numpy.append(0, numpy.cumsum(counts[:-1]))
    recvbuf = numpy.empty(sum(counts), dtype=sendbuf.dtype)
    comm.Allgatherv([sendbuf.ravel(), mpi_dtype],
                    [recvbuf.ravel(), counts, displs, mpi_dtype])
    return recvbuf.reshape((-1,) + shape[1:])
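
# Hypothetical usage sketch: as gather, but the stacked array is returned on
# every rank.
#
#   block = numpy.zeros((2, 5)) + rank
#   full = allgather(block)    # shape (2*pool.size, 5) everywhere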

def alltoall(sendbuf, split_recvbuf=False):
    '''All-to-all exchange.  sendbuf is either a single numpy array whose rows
    are split evenly over the ranks, or a list of pool.size arrays, one
    destined for each rank.'''
    if isinstance(sendbuf, numpy.ndarray):
        mpi_dtype = comm.bcast(sendbuf.dtype.char)
        sendbuf = numpy.asarray(sendbuf, mpi_dtype, 'C')
        nrow = sendbuf.shape[0]
        ncol = sendbuf.size // nrow
        segsize = (nrow+pool.size-1) // pool.size * ncol
        sdispls = numpy.arange(0, pool.size*segsize, segsize)
        sdispls[sdispls>sendbuf.size] = sendbuf.size
        scounts = numpy.append(sdispls[1:]-sdispls[:-1], sendbuf.size-sdispls[-1])
    else:
        assert len(sendbuf) == pool.size
        mpi_dtype = comm.bcast(sendbuf[0].dtype.char)
        sendbuf = [numpy.asarray(x, mpi_dtype).ravel() for x in sendbuf]
        scounts = numpy.asarray([x.size for x in sendbuf])
        sdispls = numpy.append(0, numpy.cumsum(scounts[:-1]))
        sendbuf = numpy.hstack(sendbuf)

    rcounts = numpy.asarray(comm.alltoall(scounts))
    rdispls = numpy.append(0, numpy.cumsum(rcounts[:-1]))

    recvbuf = numpy.empty(sum(rcounts), dtype=mpi_dtype)
    comm.Alltoallv([sendbuf.ravel(), scounts, sdispls, mpi_dtype],
                   [recvbuf.ravel(), rcounts, rdispls, mpi_dtype])
    if split_recvbuf:
        return [recvbuf[p0:p0+c] for p0, c in zip(rdispls, rcounts)]
    else:
        return recvbuf
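
# Hypothetical usage sketch: each rank sends one array to every rank and
# receives one from every rank.
#
#   outgoing = [numpy.full(3, rank*10 + i) for i in range(pool.size)]
#   incoming = alltoall(outgoing, split_recvbuf=True)
#   # incoming[i] is the (flattened) array that rank i addressed to this rank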

def sendrecv(sendbuf, source=0, dest=0):
    '''Send a numpy array from the source rank to the dest rank.'''
    if source == dest:
        return sendbuf

    if rank == source:
        sendbuf = numpy.asarray(sendbuf, order='C')
        comm.send((sendbuf.shape, sendbuf.dtype), dest=dest)
        comm.Send(sendbuf, dest=dest)
        return sendbuf
    elif rank == dest:
        shape, dtype = comm.recv(source=source)
        recvbuf = numpy.empty(shape, dtype=dtype)
        comm.Recv(recvbuf, source=source)
        return recvbuf
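
# Hypothetical usage sketch: point-to-point transfer of a numpy array.
#
#   buf = numpy.arange(4.) if rank == 0 else None
#   buf = sendrecv(buf, source=0, dest=1)    # rank 1 now holds rank 0's array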

def _assert(condition):
    # Dump a traceback and abort all MPI processes if the condition fails.
    if not condition:
        sys.stderr.write(''.join(traceback.format_stack()[:-1]))
        comm.Abort()

def register_for(obj):
    global _registry
    key = id(obj)
    # Keep track of the object in the global registry so that the worker
    # processes can look it up later by its key.
    _registry[key] = obj
    keys = comm.gather(key)
    if rank == 0:
        obj._reg_keys = keys
    return obj
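
# Hypothetical usage sketch: SomeParallelMethod is a placeholder for an object
# that takes part in an MPI-parallel calculation; registering it lets every
# rank look it up in mpi._registry through the gathered keys, and del_registry
# (below) removes it again.
#
#   obj = register_for(SomeParallelMethod())   # rank 0 records obj._reg_keys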

def del_registry(reg_keys):
    '''Remove a previously registered object from the registry on every process.'''
    if reg_keys:
        def f(reg_keys):
            from mpi4pyscf.tools import mpi
            mpi._registry.pop(reg_keys[mpi.rank])
        pool.apply(f, reg_keys, reg_keys)
    return []