#!/usr/bin/env python
# Copyright 2014-2018,2021 The PySCF Developers. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import sys
import time
import threading
import traceback
import numpy
from mpi4py import MPI
from . import mpi_pool
from .mpi_pool import MPIPool

_registry = {}

if 'pool' not in _registry:
    import atexit
    pool = MPIPool(debug=False)
    _registry['pool'] = pool
    atexit.register(pool.close)

comm = pool.comm
rank = pool.rank

def static_partition(tasks):
    '''Divide tasks into equal-sized contiguous blocks and return the block
    that belongs to the current process.'''
    size = len(tasks)
    segsize = (size+pool.size-1) // pool.size
    start = pool.rank * segsize
    stop = min(size, start+segsize)
    return tasks[start:stop]

def work_balanced_partition(tasks, costs=None):
    '''Divide tasks into contiguous blocks whose accumulated costs are roughly
    equal and return the block that belongs to the current process.  costs
    gives the relative cost of each task (uniform costs if not specified).'''
    if costs is None:
        costs = numpy.ones(len(tasks))  # one unit of cost per task by default
    if rank == 0:
        segsize = float(sum(costs)) / pool.size
        loads = []
        cum_costs = numpy.cumsum(costs)
        start_id = 0
        for k in range(pool.size):
            stop_id = numpy.argmin(abs(cum_costs - (k+1)*segsize)) + 1
            stop_id = max(stop_id, start_id+1)
            loads.append([start_id, stop_id])
            start_id = stop_id
        comm.bcast(loads)
    else:
        loads = comm.bcast()
    if rank < len(loads):
        start, stop = loads[rank]
        return tasks[start:stop]
    else:
        return tasks[:0]

# Message tags used by the dynamic task-distribution schemes below
INQUIRY = 50050
TASK = 50051

def work_share_partition(tasks, interval=.02, loadmin=1):
    '''Generator that yields tasks for the current process.  Every process
    starts with loadmin tasks; rank 0 hands out the remaining tasks on demand
    (from a background thread) until all tasks are consumed.'''
    loadmin = max(loadmin, len(tasks)//50//pool.size)
    rest_tasks = [x for x in tasks[loadmin*pool.size:]]
    tasks = tasks[loadmin*rank:loadmin*rank+loadmin]

    def distribute_task():
        while True:
            load = len(tasks)
            if rank == 0:
                for i in range(pool.size):
                    if i != 0:
                        load = comm.recv(source=i, tag=INQUIRY)
                    if rest_tasks:
                        if load <= loadmin:
                            task = rest_tasks.pop(0)
                            comm.send(task, i, tag=TASK)
                    else:
                        comm.send('OUT_OF_TASK', i, tag=TASK)
            else:
                comm.send(load, 0, tag=INQUIRY)

            if comm.Iprobe(source=0, tag=TASK):
                tasks.append(comm.recv(source=0, tag=TASK))
                if isinstance(tasks[-1], str) and tasks[-1] == 'OUT_OF_TASK':
                    return
            time.sleep(interval)

    tasks_handler = threading.Thread(target=distribute_task)
    tasks_handler.start()

    while True:
        if tasks:
            task = tasks.pop(0)
            if isinstance(task, str) and task == 'OUT_OF_TASK':
                tasks_handler.join()
                return
            yield task

def work_stealing_partition(tasks, interval=.0001):
    '''Generator that first yields the statically partitioned local tasks,
    then steals tasks from the other processes until every process has run
    out of work.'''
    tasks = static_partition(tasks)
    out_of_task = [False]

    def task_daemon():
        while True:
            time.sleep(interval)
            while comm.Iprobe(source=MPI.ANY_SOURCE, tag=INQUIRY):
                src, req = comm.recv(source=MPI.ANY_SOURCE, tag=INQUIRY)
                if isinstance(req, str) and req == 'STOP_DAEMON':
                    return
                elif tasks:
                    comm.send(tasks.pop(), src, tag=TASK)
                elif src == 0 and isinstance(req, str) and req == 'ALL_DONE':
                    comm.send(out_of_task[0], src, tag=TASK)
                elif out_of_task[0]:
                    comm.send('OUT_OF_TASK', src, tag=TASK)
                else:
                    comm.send('BYPASS', src, tag=TASK)

    def prepare_to_stop():
        out_of_task[0] = True
        if rank == 0:
            while True:
                done = []
                for i in range(1, pool.size):
                    comm.send((0,'ALL_DONE'), i, tag=INQUIRY)
                    done.append(comm.recv(source=i, tag=TASK))
                if all(done):
                    break
                time.sleep(interval)
            for i in range(pool.size):
                comm.send((0,'STOP_DAEMON'), i, tag=INQUIRY)
        tasks_handler.join()

    if pool.size > 1:
        tasks_handler = threading.Thread(target=task_daemon)
        tasks_handler.start()

    while tasks:
        task = tasks.pop(0)
        yield task

    if pool.size > 1:
        def next_proc(proc):
            proc = (proc+1) % pool.size
            if proc == rank:
                proc = (proc+1) % pool.size
            return proc
        proc_last = (rank + 1) % pool.size
        proc = next_proc(proc_last)

        while True:
            comm.send((rank,None), proc, tag=INQUIRY)
            task = comm.recv(source=proc, tag=TASK)
            if isinstance(task, str) and task == 'OUT_OF_TASK':
                prepare_to_stop()
                return
            elif isinstance(task, str) and task == 'BYPASS':
                if proc == proc_last:
                    prepare_to_stop()
                    return
                else:
                    proc = next_proc(proc)
            else:
                if proc != proc_last:
                    proc_last, proc = proc, next_proc(proc)
                yield task

def bcast(buf, root=0):
    '''Broadcast a numpy array from root to all processes.'''
    buf = numpy.asarray(buf, order='C')
    shape, dtype = comm.bcast((buf.shape, buf.dtype.char), root=root)
    if rank != root:
        buf = numpy.empty(shape, dtype=dtype)
    comm.Bcast(buf, root)
    return buf

## Useful when sending large batches of arrays
#def safe_bcast(buf, root=0):
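# A minimal sketch of the chunked broadcast idea hinted at above: split very
# large arrays into fixed-size blocks so each underlying MPI call stays below
# the 32-bit element-count limit.  This is not part of the module's existing
# API; the helper name and the BLKSIZE default are assumptions.
#
#def safe_bcast(buf, root=0, BLKSIZE=2**28):
#    buf = numpy.asarray(buf, order='C')
#    shape, dtype = comm.bcast((buf.shape, buf.dtype.char), root=root)
#    if rank != root:
#        buf = numpy.empty(shape, dtype=dtype)
#    buf_flat = buf.ravel()          # view of the contiguous buffer
#    for p0 in range(0, buf_flat.size, BLKSIZE):
#        comm.Bcast(buf_flat[p0:p0+BLKSIZE], root)
#    return buf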
def reduce(sendbuf, op=MPI.SUM, root=0):
    '''Element-wise reduction (MPI.SUM by default) of numpy arrays onto root.
    The reduced array is returned on root; the other processes get their own
    sendbuf back.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.shape == shape and sendbuf.dtype.char == mpi_dtype)

    recvbuf = numpy.zeros_like(sendbuf)
    comm.Reduce(sendbuf, recvbuf, op, root)
    if rank == root:
        return recvbuf
    else:
        return sendbuf

def allreduce(sendbuf, op=MPI.SUM):
    '''Element-wise reduction (MPI.SUM by default) of numpy arrays; the result
    is returned on every process.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.shape == shape and sendbuf.dtype.char == mpi_dtype)

    recvbuf = numpy.zeros_like(sendbuf)
    comm.Allreduce(sendbuf, recvbuf, op)
    return recvbuf

def gather(sendbuf, root=0):
    '''Concatenate the arrays from all processes along their first axis on
    root.  The other processes get their own sendbuf back.'''
    #if pool.debug:
    #    if rank == 0:
    #        res = [sendbuf]
    #        for k in range(1, pool.size):
    #            dat = comm.recv(source=k)
    #            res.append(dat)
    #        return numpy.vstack([x for x in res if len(x) > 0])
    #    else:
    #        comm.send(sendbuf, dest=0)
    #        return sendbuf

    sendbuf = numpy.asarray(sendbuf, order='C')
    mpi_dtype = sendbuf.dtype.char
    if rank == root:
        size_dtype = comm.gather((sendbuf.size, mpi_dtype), root=root)
        _assert(all(x[1] == mpi_dtype for x in size_dtype if x[0] > 0))
        counts = numpy.array([x[0] for x in size_dtype])
        displs = numpy.append(0, numpy.cumsum(counts[:-1]))
        recvbuf = numpy.empty(sum(counts), dtype=sendbuf.dtype)
        comm.Gatherv([sendbuf.ravel(), mpi_dtype],
                     [recvbuf.ravel(), counts, displs, mpi_dtype], root)
        return recvbuf.reshape((-1,) + sendbuf[0].shape)
    else:
        comm.gather((sendbuf.size, mpi_dtype), root=root)
        comm.Gatherv([sendbuf.ravel(), mpi_dtype], None, root)
        return sendbuf

def allgather(sendbuf):
    '''Concatenate the arrays from all processes along their first axis; the
    result is returned on every process.'''
    sendbuf = numpy.asarray(sendbuf, order='C')
    shape, mpi_dtype = comm.bcast((sendbuf.shape, sendbuf.dtype.char))
    _assert(sendbuf.dtype.char == mpi_dtype or sendbuf.size == 0)
    counts = numpy.array(comm.allgather(sendbuf.size))
    displs = numpy.append(0, numpy.cumsum(counts[:-1]))
    recvbuf = numpy.empty(sum(counts), dtype=sendbuf.dtype)
    comm.Allgatherv([sendbuf.ravel(), mpi_dtype],
                    [recvbuf.ravel(), counts, displs, mpi_dtype])
    return recvbuf.reshape((-1,) + shape[1:])

def alltoall(sendbuf, split_recvbuf=False):
    '''Distribute data from every process to every other process.  sendbuf is
    either an array whose rows are split evenly over the processes, or a list
    of pool.size arrays, one for each destination.  If split_recvbuf is True,
    return a list of arrays, one from each source; otherwise return the
    received data as a single flat array.'''
    if isinstance(sendbuf, numpy.ndarray):
        mpi_dtype = comm.bcast(sendbuf.dtype.char)
        sendbuf = numpy.asarray(sendbuf, mpi_dtype, 'C')
        nrow = sendbuf.shape[0]
        ncol = sendbuf.size // nrow
        segsize = (nrow+pool.size-1) // pool.size * ncol
        sdispls = numpy.arange(0, pool.size*segsize, segsize)
        sdispls[sdispls>sendbuf.size] = sendbuf.size
        scounts = numpy.append(sdispls[1:]-sdispls[:-1], sendbuf.size-sdispls[-1])
    else:
        assert len(sendbuf) == pool.size
        mpi_dtype = comm.bcast(sendbuf[0].dtype.char)
        sendbuf = [numpy.asarray(x, mpi_dtype).ravel() for x in sendbuf]
        scounts = numpy.asarray([x.size for x in sendbuf])
        sdispls = numpy.append(0, numpy.cumsum(scounts[:-1]))
        sendbuf = numpy.hstack(sendbuf)

    rcounts = numpy.asarray(comm.alltoall(scounts))
    rdispls = numpy.append(0, numpy.cumsum(rcounts[:-1]))

    recvbuf = numpy.empty(sum(rcounts), dtype=mpi_dtype)
    comm.Alltoallv([sendbuf.ravel(), scounts, sdispls, mpi_dtype],
                   [recvbuf.ravel(), rcounts, rdispls, mpi_dtype])
    if split_recvbuf:
        return [recvbuf[p0:p0+c] for p0,c in zip(rdispls,rcounts)]
    else:
        return recvbuf

def sendrecv(sendbuf, source=0, dest=0):
    '''Send a numpy array from the source process to the dest process.'''
    if source == dest:
        return sendbuf

    if rank == source:
        sendbuf = numpy.asarray(sendbuf, order='C')
        comm.send((sendbuf.shape, sendbuf.dtype), dest=dest)
        comm.Send(sendbuf, dest=dest)
        return sendbuf
    elif rank == dest:
        shape, dtype = comm.recv(source=source)
        recvbuf = numpy.empty(shape, dtype=dtype)
        comm.Recv(recvbuf, source=source)
        return recvbuf

def _assert(condition):
    '''Abort all MPI processes if condition is not fulfilled.'''
    if not condition:
        sys.stderr.write(''.join(traceback.format_stack()[:-1]))
        comm.Abort()

def register_for(obj):
    global _registry
    key = id(obj)
    # Keep track of the object in a global registry.  On slave nodes, the
    # object can be accessed from the global registry.
    _registry[key] = obj
    keys = comm.gather(key)
    if rank == 0:
        obj._reg_keys = keys
    return obj

def del_registry(reg_keys):
    '''Remove the objects associated with reg_keys from the registry on all
    processes.'''
    if reg_keys:
        def f(reg_keys):
            from mpi4pyscf.tools import mpi
            mpi._registry.pop(reg_keys[mpi.rank])
        pool.apply(f, reg_keys, reg_keys)
    return []