1# This file is part of the Astrometry.net suite.
2# Licensed under a 3-clause BSD style license - see LICENSE
3from __future__ import print_function
4
5import multiprocessing
6
class FakeAsyncResult(object):
    """Already-completed stand-in mirroring the AsyncResult query methods.

    The value is supplied at construction, so waiting is a no-op and the
    status queries always report completion.
    """

    def __init__(self, X):
        # The stored result; handed back unchanged by get().
        self.X = X

    def wait(self, *_ignored):
        """Return immediately (any timeout argument is ignored)."""
        return None

    def get(self, *_ignored):
        """Return the stored value (any timeout argument is ignored)."""
        return self.X

    def ready(self):
        """Always True."""
        return True

    def successful(self):
        """Always True."""
        return True
18
class funcwrapper(object):
    """Callable wrapper that reports exceptions raised by the wrapped function.

    On failure it prints the positional arguments and the traceback, then
    re-raises the original exception unchanged.  Presumably used so errors
    inside pool workers are visible -- confirm against callers.
    """

    def __init__(self, func):
        self.func = func

    def __call__(self, *X):
        try:
            result = self.func(*X)
        except BaseException:
            import traceback
            print('Exception while calling your function:')
            print('  params:', X)
            print('  exception:')
            traceback.print_exc()
            raise
        return result
34
class memberfuncwrapper(object):
    """Callable that invokes the method named `funcname` on `obj`.

    Storing the object and method name (instead of a bound method) is
    presumably meant to keep the wrapper picklable for multiprocessing --
    confirm.  On failure it prints diagnostic details and re-raises.
    """

    def __init__(self, obj, funcname):
        self.obj = obj
        self.funcname = funcname

    def __call__(self, *X):
        """Call ``obj.<funcname>(*X)`` and return its result.

        Raises whatever the lookup or the method raises (e.g. AttributeError
        if the method does not exist), after printing a report.
        """
        func = None
        try:
            # The builtin getattr on the instance yields a *bound* method, so
            # the instance must not be passed again as the first argument.
            func = getattr(self.obj, self.funcname)
            return func(*X)
        except BaseException:
            import traceback
            print('Exception while calling your function:')
            print('  object:', self.obj)
            print('  member function:', self.funcname)
            print('  ', func)
            print('  params:', X)
            print('  exception:')
            traceback.print_exc()
            raise
55
56
57
class multiproc(object):
    """Front end over a multiprocessing pool with a serial fallback.

    If ``pool`` is given it is used as-is.  Otherwise ``nthreads == 1`` runs
    everything serially in the calling process, and any other value creates
    ``multiprocessing.Pool(nthreads, init, initargs)``.

    Parameters
    ----------
    nthreads : int
        Number of worker processes; 1 means run serially in this process.
    init, initargs :
        Optional initializer, called as ``init(*initargs)`` once in this
        process (serial case) or handed to the new Pool.  Ignored when an
        existing ``pool`` is supplied.
    map_chunksize : int
        Default chunksize for ``map`` and ``imap_unordered``.
    pool :
        An existing pool object to use instead of creating one.
    wrap_all : bool
        If True, functions sent to the pool are always wrapped in
        ``funcwrapper`` so exceptions are printed with their arguments.
    """

    def __init__(self, nthreads=1, init=None, initargs=(),
                 map_chunksize=1, pool=None, wrap_all=False):
        self.wrap_all = wrap_all
        if pool is not None:
            self.pool = pool
            self.applyfunc = self.pool.apply_async
        elif nthreads == 1:
            self.pool = None
            self.applyfunc = lambda f, a, k: f(*a, **k)
            if init is not None:
                init(*initargs)
        else:
            self.pool = multiprocessing.Pool(nthreads, init, initargs)
            self.applyfunc = self.pool.apply_async
        # AsyncResults from apply(), drained by waitforall().
        self.async_results = []
        self.map_chunksize = map_chunksize

    def _maybe_wrap(self, f, wrap):
        # Wrap f in funcwrapper when requested per call or globally.
        if wrap or self.wrap_all:
            return funcwrapper(f)
        return f

    def map(self, f, args, chunksize=None, wrap=False):
        """Return ``[f(a) for a in args]`` as a list.

        ``chunksize`` defaults to ``map_chunksize``.  In serial mode both
        ``chunksize`` and ``wrap`` are ignored.
        """
        cs = self.map_chunksize if chunksize is None else chunksize
        if self.pool is not None:
            return self.pool.map(self._maybe_wrap(f, wrap), args, cs)
        return list(map(f, args))

    def map_async(self, func, iterable, wrap=False):
        """Start mapping ``func`` over ``iterable``; return an async result.

        In serial mode the work is done immediately.
        """
        if self.pool is None:
            # Materialize eagerly: on Python 3 map() is lazy, so get() would
            # otherwise return a one-shot iterator instead of a list and the
            # calls (and any exceptions) would be deferred.
            return FakeAsyncResult(list(map(func, iterable)))
        return self.pool.map_async(self._maybe_wrap(func, wrap), iterable)

    def imap_unordered(self, func, iterable, chunksize=None, wrap=False):
        """Return a lazy iterator of ``func`` applied over ``iterable``.

        In serial mode results come back in input order and ``chunksize``
        and ``wrap`` are ignored.
        """
        cs = self.map_chunksize if chunksize is None else chunksize
        if self.pool is None:
            import itertools
            # itertools.imap exists only on Python 2; Python 3's map is lazy.
            lazy_map = getattr(itertools, 'imap', map)
            return lazy_map(func, iterable)
        return self.pool.imap_unordered(self._maybe_wrap(func, wrap),
                                        iterable, chunksize=cs)

    def apply(self, f, args, wrap=False, kwargs=None):
        """Run ``f(*args, **kwargs)``; return an async result for it.

        In serial mode ``f`` runs immediately.  Otherwise the pool's result
        is also recorded in ``async_results`` for ``waitforall``.
        """
        if kwargs is None:
            kwargs = {}
        if self.pool is None:
            return FakeAsyncResult(f(*args, **kwargs))
        res = self.applyfunc(self._maybe_wrap(f, wrap), args, kwargs)
        self.async_results.append(res)
        return res

    def waitforall(self):
        """Block on every recorded apply() result, then clear the list."""
        print('Waiting for async results to finish...')
        for r in self.async_results:
            print('  waiting for', r)
            r.wait()
        print('all done')
        self.async_results = []

    def close(self):
        """Close the pool (without joining); later calls run serially."""
        if self.pool is not None:
            self.pool.close()
            self.pool = None
136
137