1from concurrent.futures import ThreadPoolExecutor
2from .utils import CountUpDownLatch
3import threading
4import logging
5import multiprocessing
6import os
7import logging.handlers
8from .exceptions import  FileNotFoundError
9try:
10    from queue import Empty     # Python 3
11except ImportError:
12    from Queue import Empty     # Python 2
13
# Number of worker threads in each child process's ThreadPoolExecutor (and in
# the directory-walk pool of the parent).
WORKER_THREAD_PER_PROCESS = 50
# Paths are batched onto the multiprocessing queue in buckets of this size to
# reduce per-item queue overhead.
QUEUE_BUCKET_SIZE = 10
# Sentinel object pushed onto queues to tell consumers to shut down. A list is
# used so the comparison is by value (==) across process boundaries.
END_QUEUE_SENTINEL = [None, None]
# First exception reported by any child process; written by monitor_exception
# under GLOBAL_EXCEPTION_LOCK.
GLOBAL_EXCEPTION = None
GLOBAL_EXCEPTION_LOCK = threading.Lock()
19
20
def monitor_exception(exception_queue, process_ids):
    """Watch *exception_queue* for failures reported by worker processes.

    Runs in a daemon-style thread in the parent. Polls the queue; on the
    END_QUEUE_SENTINEL it returns normally. On any other payload (a formatted
    traceback string from a child) it records the payload in the module-level
    GLOBAL_EXCEPTION, terminates and joins every process in *process_ids*,
    and re-raises the payload as an Exception to abort this thread.

    :param exception_queue: multiprocessing.Queue carrying traceback strings
        or END_QUEUE_SENTINEL.
    :param process_ids: list of multiprocessing.Process objects to tear down
        when a child reports a failure.
    """
    global GLOBAL_EXCEPTION
    logger = logging.getLogger(__name__)

    while True:
        try:
            local_exception = exception_queue.get(timeout=0.1)
        except Empty:
            continue            # Nothing reported yet; keep polling.
        if local_exception == END_QUEUE_SENTINEL:
            break
        logger.log(logging.DEBUG, "Setting global exception")
        # Use `with` so the lock is released even if a later step raises.
        with GLOBAL_EXCEPTION_LOCK:
            GLOBAL_EXCEPTION = local_exception
        logger.log(logging.DEBUG, "Closing processes")
        for p in process_ids:
            p.terminate()
        logger.log(logging.DEBUG, "Joining processes")
        for p in process_ids:
            p.join()
        # NOTE: the original `import thread` here was Python-2-only dead code;
        # on Python 3 it raised ModuleNotFoundError and masked the real error.
        logger.log(logging.DEBUG, "Interrupting main")
        raise Exception(local_exception)
45
46
def log_listener_process(queue):
    """Drain log records from *queue* and dispatch them to local loggers.

    Intended to run in its own thread in the parent process. Child processes
    forward LogRecord objects through *queue* (via QueueHandler); this loop
    replays each record on the logger named in the record. Terminates when
    END_QUEUE_SENTINEL is received. Logging problems are reported to stderr
    on a best-effort basis and never kill the listener.
    """
    while True:
        try:
            record = queue.get(timeout=0.1)
            queue.task_done()
            if record == END_QUEUE_SENTINEL:  # Shutdown sentinel from the parent.
                break
            target = logging.getLogger(record.name)
            # Drop any handlers attached in this process so the record is
            # emitted exactly once by the root configuration.
            target.handlers.clear()
            target.handle(record)   # No level or filter logic applied - just do it!
        except Empty:
            # Queue was quiet for 0.1s; poll again.
            pass
        except Exception:
            # Best effort: report and keep listening.
            import sys
            import traceback
            print('Problems in logging')
            traceback.print_exc(file=sys.stderr)
63
64
def multi_processor_change_acl(adl, path=None, method_name="", acl_spec="", number_of_sub_process=None):
    """Recursively apply an ACL operation to *path* and everything beneath it.

    Orchestration entry point. A thread pool walks the remote directory tree
    via ``adl.ls`` and pushes buckets of paths onto a multiprocessing queue;
    child processes (see ``processor``) consume the queue and invoke the ACL
    method selected by *method_name* ("mod_acl", "set_acl" or "rem_acl") with
    *acl_spec* on each path. Logging from children is funneled back through a
    queue to a listener thread, and child failures are surfaced through an
    exception-monitor thread.

    :param adl: client object exposing ``ls`` and the ACL methods — presumably
        an AzureDLFileSystem; confirm against caller.
    :param path: root path to process (also processed itself).
    :param method_name: key selecting which ACL method children run.
    :param acl_spec: ACL specification string passed to the method.
    :param number_of_sub_process: child process count; defaults to
        ``max(2, cpu_count() - 1)`` when None.
    """
    logger = logging.getLogger(__name__)

    def launch_processes(number_of_processes):
        # Start the consumer processes; each one runs `processor` until the
        # finish flag is set.
        if number_of_processes is None:
            number_of_processes = max(2, multiprocessing.cpu_count() - 1)
        process_list = []
        for i in range(number_of_processes):
            process_list.append(multiprocessing.Process(target=processor,
                                    args=(adl, file_path_queue, finish_queue_processing_flag,
                                          method_name, acl_spec, log_queue, exception_queue)))
            process_list[-1].start()
        return process_list

    def walk(walk_path):
        # Producer: list one directory, enqueue its entries in buckets, and
        # recurse into subdirectories on the thread pool. The latch counts
        # outstanding directories so the parent knows when the walk is done.
        try:
            paths = []
            all_files = adl.ls(path=walk_path, detail=True)

            for files in all_files:
                if files['type'] == 'DIRECTORY':
                    dir_processed_counter.increment()               # A new directory to process
                    walk_thread_pool.submit(walk, files['name'])
                paths.append(files['name'])
                if len(paths) == QUEUE_BUCKET_SIZE:
                    file_path_queue.put(list(paths))
                    paths = []
            if paths != []:
                file_path_queue.put(list(paths))  # For leftover paths < bucket_size
        except FileNotFoundError:
            pass                    # Continue in case the file was deleted in between
        except Exception:
            import traceback
            logger.exception("Failed to walk for path: " + str(walk_path) + ". Exiting!")
            exception_queue.put(traceback.format_exc())
        finally:
            dir_processed_counter.decrement()           # Processing complete for this directory

    # Initialize concurrency primitives
    log_queue = multiprocessing.JoinableQueue()
    exception_queue = multiprocessing.Queue()
    finish_queue_processing_flag = multiprocessing.Event()
    file_path_queue = multiprocessing.JoinableQueue()
    dir_processed_counter = CountUpDownLatch()

    # Start relevant threads and processes
    log_listener = threading.Thread(target=log_listener_process, args=(log_queue,))
    log_listener.start()
    child_processes = launch_processes(number_of_sub_process)
    exception_monitor_thread = threading.Thread(target=monitor_exception, args=(exception_queue, child_processes))
    exception_monitor_thread.start()
    walk_thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREAD_PER_PROCESS)

    # Root directory needs to be explicitly passed
    file_path_queue.put([path])
    dir_processed_counter.increment()

    # Processing starts here
    walk(path)

    if dir_processed_counter.is_zero():  # Done processing all directories. Blocking call.
        walk_thread_pool.shutdown()
        file_path_queue.close()          # No new elements to add
        file_path_queue.join()           # Wait for operations to be done
        logger.log(logging.DEBUG, "file path queue closed")
        finish_queue_processing_flag.set()  # Set flag to break loop of child processes
        for child in child_processes:  # Wait for all child process to finish
            logger.log(logging.DEBUG, "Joining process: "+str(child.pid))
            child.join()

    # Cleanup: shut down the exception monitor first, then the log listener,
    # each via its queue sentinel. Order matters — children must have exited
    # before their queues are drained and closed.
    logger.log(logging.DEBUG, "Sending exception sentinel")
    exception_queue.put(END_QUEUE_SENTINEL)
    exception_monitor_thread.join()
    logger.log(logging.DEBUG, "Exception monitor thread finished")
    logger.log(logging.DEBUG, "Sending logger sentinel")
    log_queue.put(END_QUEUE_SENTINEL)
    log_queue.join()
    log_queue.close()
    logger.log(logging.DEBUG, "Log queue closed")
    log_listener.join()
    logger.log(logging.DEBUG, "Log thread finished")
147
148
def processor(adl, file_path_queue, finish_queue_processing_flag, method_name, acl_spec, log_queue, exception_queue):
    """Child-process worker: apply an ACL operation to batches of paths.

    Pulls buckets of paths from *file_path_queue* and invokes the ACL method
    named by *method_name* ("mod_acl", "set_acl" or "rem_acl") with *acl_spec*
    on each path, fanning the calls out over a thread pool. Loops until the
    parent sets *finish_queue_processing_flag*. Per-path failures are logged
    and swallowed; any fatal error in the worker itself is reported to the
    parent through *exception_queue* as a formatted traceback string.
    Log records are forwarded to the parent via *log_queue* (Python 3 only).
    """
    logger = logging.getLogger(__name__)

    try:
        logger.addHandler(logging.handlers.QueueHandler(log_queue))
        logger.propagate = False                                                        # Prevents double logging
    except AttributeError:
        # Python 2 doesn't have QueueHandler. Default to best effort logging.
        pass

    # Guard so `finally` can't hit a NameError if setup fails before the pool
    # is created (e.g. the func_table attribute lookups raise).
    function_thread_pool = None
    try:
        func_table = {"mod_acl": adl.modify_acl_entries, "set_acl": adl.set_acl, "rem_acl": adl.remove_acl_entries}
        function_thread_pool = ThreadPoolExecutor(max_workers=WORKER_THREAD_PER_PROCESS)
        adl_function = func_table[method_name]
        logger.log(logging.DEBUG, "Started processor pid:" + str(os.getpid()))

        def func_wrapper(func, path, spec):
            # Run one ACL call; never let an exception escape into the pool.
            try:
                func(path=path, acl_spec=spec)
            except FileNotFoundError:
                logger.exception("File " + str(path) + " not found")
                # Complete Exception is being logged in the relevant acl method. Don't print exception here
            except Exception as e:
                logger.exception("File " + str(path) + " not set. Exception " + str(e))

            logger.log(logging.DEBUG, "Completed running on path:" + str(path))

        while not finish_queue_processing_flag.is_set():
            try:
                file_paths = file_path_queue.get(timeout=0.1)
                file_path_queue.task_done()                 # Will not be called if empty
            except Empty:
                continue                                    # Poll again until flag is set
            for file_path in file_paths:
                logger.log(logging.DEBUG, "Starting on path:" + str(file_path))
                function_thread_pool.submit(func_wrapper, adl_function, file_path, acl_spec)

    except Exception as e:
        import traceback
        logger.exception("Exception in pid " + str(os.getpid()) + "Exception: " + str(e))
        exception_queue.put(traceback.format_exc())
    finally:
        if function_thread_pool is not None:
            function_thread_pool.shutdown()  # Blocking call. Will wait till all threads are done executing.
        logger.log(logging.DEBUG, "Finished processor pid: " + str(os.getpid()))
193