1#!/usr/bin/env python
2# -*- coding: utf-8 -*-
3
4# Copyright (C) 2018  Mate Soos
5#
6# This program is free software; you can redistribute it and/or
7# modify it under the terms of the GNU General Public License
8# as published by the Free Software Foundation; version 2
9# of the License.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program; if not, write to the Free Software
18# Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
19# 02110-1301, USA.
20
21from __future__ import print_function
22import os
23import socket
24import sys
25import struct
26import pickle
27import time
28import pprint
29import traceback
30import Queue
31import threading
32import logging
33import server_option_parser
34
# make modules in the current working directory importable even when "." is not on sys.path
36sys.path.append(os.getcwd())
37from common_aws import *
38import RequestSpotClient
39
40
def get_n_bytes_from_connection(sock, MSGLEN):
    """Read exactly MSGLEN bytes from ``sock`` and return them.

    ``recv`` may return fewer bytes than requested, so the data is read in
    chunks of at most 2048 bytes until MSGLEN bytes have arrived.

    :param sock: connected socket (anything with a ``recv(n)`` method).
    :param MSGLEN: number of bytes to read.
    :returns: the concatenated bytes read.
    :raises RuntimeError: if the peer closes the connection before MSGLEN
        bytes were received.
    """
    chunks = []
    bytes_recd = 0
    while bytes_recd < MSGLEN:
        chunk = sock.recv(min(MSGLEN - bytes_recd, 2048))
        # An empty read means the peer closed the connection. Test for
        # emptiness instead of comparing with '' so this also terminates
        # when recv() returns bytes (Python 3), where b'' != ''.
        if not chunk:
            raise RuntimeError("socket connection broken")
        chunks.append(chunk)
        bytes_recd += len(chunk)

    # b'' is str on Python 2 and bytes on Python 3: the type recv() returns.
    return b''.join(chunks)
52
53
def send_command(sock, command, tosend=None):
    """Send one framed message over ``sock``.

    Frame layout: an 8-byte big-endian signed length (``struct`` format
    ``'!q'``) followed by that many bytes of pickle. The pickled object is
    a dict with all items of ``tosend`` plus ``"command": command``.

    :param sock: connected socket (anything with ``sendall``).
    :param command: value stored under the ``"command"`` key.
    :param tosend: optional dict of extra fields; it is NOT modified.
    """
    # Work on a copy: the caller's dict must not gain a "command" key, and
    # a None default avoids the shared-mutable-default pitfall.
    payload = dict(tosend) if tosend else {}
    payload["command"] = command
    data = pickle.dumps(payload)
    sock.sendall(struct.pack('!q', len(data)) + data)
63
64
class ToSolve:
    """A single entry of the CNF work list: its index and its filename."""

    def __init__(self, num, name):
        # num is the integer index of the entry; name is the filename
        self.num, self.name = num, name

    def __str__(self):
        return "%(name)s (num: %(num)d)" % {"name": self.name, "num": self.num}
73
74
class Server (threading.Thread):
    """Thread that hands CNF files out to solver clients.

    Accepted connections arrive on the global ``acc_queue``; ``run`` takes
    them one at a time and ``handle_one_client`` serves each to completion.

    Bookkeeping, all keyed by the integer file number:

    * ``files``           -- file number -> ToSolve
    * ``files_available`` -- numbers not yet handed out, or re-queued after
                             their run exceeded the timeout
    * ``files_running``   -- number -> ``time.time()`` when it was handed out
    * ``files_finished``  -- numbers reported as done (each at most once)
    """

    def __init__(self):
        threading.Thread.__init__(self)
        self.files_available = []
        self.files_finished = []
        self.files = {}
        self.files_running = {}
        self.uniq_cnt = 0

        logging.info("Getting list of files %s", options.cnf_list)
        key = boto.connect_s3().get_bucket("msoos-solve-data").get_key(
            "solvers/" + options.cnf_list)
        key.get_contents_to_filename(options.cnf_list)

        logging.info("CNF list is file %s", options.cnf_list)
        with open(options.cnf_list, "r") as fnames:
            for fname in fnames:
                fname = fname.strip()
                if not fname:
                    # Skip blank lines so no client is ever asked to solve
                    # an empty filename.
                    continue
                num = len(self.files)
                self.files[num] = ToSolve(num, fname)
                self.files_available.append(num)
                logging.info("File added: %s", fname)

        logging.info("Solving %d files", len(self.files_available))

    def ready_to_shutdown(self):
        """Return True when nothing is queued and every file is finished."""
        if self.files_available:
            return False
        return len(self.files_finished) >= len(self.files)

    def handle_done(self, connection, cli_addr, indata):
        """Record that a client finished ``indata["file_num"]``.

        ``indata["files"]`` is a list of (current_name, final_name) pairs
        that are handed to rename_files_to_final.
        """
        file_num = indata["file_num"]

        logging.info("Finished with file %s (num %d), got files %s",
                     self.files[file_num], file_num, indata["files"])

        # A file whose run timed out is re-queued and may be handed out
        # again, so the same number can be reported done more than once.
        # Count it only once; otherwise ready_to_shutdown() could become
        # True while other files are still unsolved.
        if file_num not in self.files_finished:
            self.files_finished.append(file_num)
        self.files_running.pop(file_num, None)
        # After a timeout it may also still be waiting in the queue; a
        # finished file must not be handed out again.
        if file_num in self.files_available:
            self.files_available.remove(file_num)

        logging.info("Num files_available: %d Num files_finished %d",
                     len(self.files_available), len(self.files_finished))

        self.rename_files_to_final(indata["files"])
        sys.stdout.flush()

    def rename_files_to_final(self, files):
        """Move every (origname, toname) pair inside ``options.s3_bucket``.

        Uses the ``aws s3 mv`` CLI. Failures are logged, never raised.
        """
        import subprocess

        for fnames in files:
            origname, toname = fnames[0], fnames[1]
            logging.info("Renaming file %s to %s", origname, toname)
            # Pass an argument list (no shell): the names come from the
            # client over the network and must not be interpreted by a
            # shell; this also keeps names with spaces intact.
            cmd = ["aws", "s3", "mv",
                   "s3://{bucket}/{name}".format(bucket=options.s3_bucket,
                                                 name=origname),
                   "s3://{bucket}/{name}".format(bucket=options.s3_bucket,
                                                 name=toname),
                   "--region", str(options.region)]
            try:
                ret = subprocess.call(cmd)
            except OSError:
                # e.g. the aws CLI is not installed / not on PATH
                ret = -1
            if ret:
                logging.warning("Renaming file to final name failed! (exit status %s)",
                                ret)

    def check_for_dead_files(self):
        """Re-queue files running longer than timeout_in_secs * tout_mult."""
        this_time = time.time()
        dead = []
        for file_num, starttime in self.files_running.items():
            duration = this_time - starttime
            if duration > options.timeout_in_secs * options.tout_mult:
                logging.warning("* dead file %s duration: %d re-inserting",
                                file_num, duration)
                dead.append(file_num)

        # mutate only after iterating; deleting during iteration is an error
        for file_num in dead:
            del self.files_running[file_num]
            self.files_available.append(file_num)

    def find_something_to_solve(self):
        """Pop and return the next queued file number, or None if none."""
        self.check_for_dead_files()
        logging.info("Num files_available pre-send: %d",
                     len(self.files_available))

        if not self.files_available:
            return None

        file_num = self.files_available.pop(0)
        logging.info("Num files_available post-send: %d",
                     len(self.files_available))
        sys.stdout.flush()

        return file_num

    def handle_build(self, connection, cli_addr, indata):
        """Reply to a "build" request with the run-wide settings."""
        tosend = self.default_tosend()
        logging.info("Sending git revision %s to %s", options.git_rev,
                     cli_addr)
        send_command(connection, "build_data", tosend)

    def send_termination(self, connection, cli_addr):
        """Tell the client to finish and remember when that was sent.

        The time is stored in the module-global ``last_termination_sent``.
        """
        tosend = {}
        tosend["noshutdown"] = options.noshutdown
        send_command(connection, "finish", tosend)

        logging.info("No more to solve, terminating %s", cli_addr)
        global last_termination_sent
        last_termination_sent = time.time()

    def send_wait(self, connection, cli_addr):
        """Tell the client to wait: nothing is queued but files still run."""
        tosend = {}
        tosend["noshutdown"] = options.noshutdown
        logging.info("Everything is in sent queue, sending wait to %s", cli_addr)
        send_command(connection, "wait", tosend)

    def default_tosend(self):
        """Return a dict of the run-wide options sent with "build_data" and
        "solve" replies."""
        tosend = {}
        tosend["solver"] = options.solver
        tosend["git_rev"] = options.git_rev
        tosend["stats"] = options.stats
        tosend["gauss"] = options.gauss
        tosend["s3_bucket"] = options.s3_bucket
        tosend["given_folder"] = options.given_folder
        tosend["timeout_in_secs"] = options.timeout_in_secs
        tosend["mem_limit_in_mb"] = options.mem_limit_in_mb
        tosend["noshutdown"] = options.noshutdown
        tosend["extra_opts"] = options.extra_opts
        tosend["drat"] = options.drat
        tosend["region"] = options.region

        return tosend

    def send_one_to_solve(self, connection, cli_addr, file_num):
        """Send file ``file_num`` to the client and start its timeout clock."""
        # start the timer used by check_for_dead_files
        self.files_running[file_num] = time.time()
        filename = self.files[file_num].name

        tosend = self.default_tosend()
        tosend["file_num"] = file_num
        tosend["cnf_filename"] = filename
        tosend["uniq_cnt"] = str(self.uniq_cnt)
        logging.info("Sending file %s (num %d) to %s",
                     filename, file_num, cli_addr)
        send_command(connection, "solve", tosend)
        self.uniq_cnt += 1

    def handle_need(self, connection, cli_addr, indata):
        """Answer a "need" request with a file, a wait, or a termination."""
        # TODO don't ignore 'indata' for solving CNF instances, use it to
        # optimize for uptime
        file_num = self.find_something_to_solve()

        if file_num is None:
            if not self.files_running:
                self.send_termination(connection, cli_addr)
            else:
                self.send_wait(connection, cli_addr)
        else:
            self.send_one_to_solve(connection, cli_addr, file_num)

    def handle_one_client(self, conn, cli_addr):
        """Read one request from ``conn``, dispatch it, and close ``conn``.

        Request framing mirrors send_command: an 8-byte big-endian length,
        then a pickled dict whose "command" key selects the handler.
        """
        try:
            logging.info("connection from %s", cli_addr)

            data = get_n_bytes_from_connection(conn, 8)
            length = struct.unpack('!q', data)[0]
            data = get_n_bytes_from_connection(conn, length)
            # SECURITY NOTE(review): pickle.loads can execute arbitrary code
            # embedded in the payload; this is only safe if every peer able
            # to reach options.port is trusted.
            data = pickle.loads(data)

            command = data["command"]
            if command == "done":
                self.handle_done(conn, cli_addr, data)
            elif command == "error":
                logging.error("Client %s reported an error, shutting down",
                              cli_addr)
                shutdown(-1)
            elif command == "need":
                self.handle_need(conn, cli_addr, data)
            elif command == "build":
                self.handle_build(conn, cli_addr, data)
            else:
                logging.warning("Unknown command %r from %s", command,
                                cli_addr)

            sys.stdout.flush()
        except:
            # NOTE(review): a bare except also catches the SystemExit raised
            # by shutdown() (it ends with exit()), so this thread keeps
            # serving after an "error" command; confirm that is intended.
            traceback.print_exc()
            logging.error("Exception from %s, Trace: %s", cli_addr,
                          traceback.format_exc())

        finally:
            # Clean up the connection
            logging.info("Finished with client %s", cli_addr)
            conn.close()

    def run(self):
        """Serve connections from the global acc_queue forever."""
        while True:
            conn, cli_addr = acc_queue.get()
            self.handle_one_client(conn, cli_addr)
275
276
class Listener (threading.Thread):
    """Thread that accepts TCP connections and queues them.

    Every accepted (connection, address) pair is put on the global
    ``acc_queue``.
    """

    def __init__(self):
        threading.Thread.__init__(self)
        # listening socket; created in run()
        self.sock = None

    def listen_to_connection(self):
        """Create, bind (all interfaces, options.port) and return a
        listening TCP socket."""
        sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
        # Allow re-binding the port right after a restart instead of failing
        # while connections of the previous process sit in TIME_WAIT.
        sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)

        server_address = ('0.0.0.0', options.port)
        logging.info('starting up on %s port %s', server_address[0], options.port)
        sock.bind(server_address)

        # Listen for incoming connections
        sock.listen(128)
        return sock

    def handle_one_connection(self):
        """Block until a client connects, then queue the connection."""
        conn, client_addr = self.sock.accept()
        acc_queue.put_nowait((conn, client_addr))

    def run(self):
        """Open the listening socket (shutting down on failure), then accept
        connections forever."""
        try:
            self.sock = self.listen_to_connection()
        except Exception:
            the_trace = traceback.format_exc().rstrip().replace("\n", " || ")
            logging.error("Cannot listen on socket! Traceback: %s", the_trace)
            shutdown(-1)
            raise
        while True:
            self.handle_one_connection()
313
314
class SpotManager (threading.Thread):
    """Thread that periodically asks RequestSpotClient for spot instances.

    Every 60 seconds, unless the global ``server`` reports
    ``ready_to_shutdown()``, it calls ``create_spots_if_needed()``. Any
    error is logged and the loop carries on.
    """

    def __init__(self):
        threading.Thread.__init__(self)
        cnf_list_is_test = "test" in options.cnf_list
        self.spot_creator = RequestSpotClient.RequestSpotClient(
            options.git_rev,
            cnf_list_is_test,
            noshutdown=options.noshutdown,
            count=options.client_count)

    def _request_spots_once(self):
        """Run one spot-creation round; log (never raise) any failure."""
        try:
            if server.ready_to_shutdown():
                return
            self.spot_creator.create_spots_if_needed()
        except:
            trace_line = traceback.format_exc().rstrip().replace("\n", " || ")
            logging.error("Cannot create spots! Traceback: %s", trace_line)

    def run(self):
        while True:
            self._request_spots_once()
            time.sleep(60)
335
336
def shutdown(exitval=0):
    """Email a completion notice, optionally power off the host, then exit.

    :param exitval: exit status; 0 gives an "OK" email subject, anything
        else "FAIL".
    :raises SystemExit: always, with code ``exitval``. When called from a
        non-main thread this ends only that thread, not the process.
    """
    toexec = "sudo shutdown -h now"
    logging.info("SHUTTING DOWN")

    # Best effort: failing to send the email must not prevent the shutdown.
    try:
        email_subject = "Server shutting down "
        if exitval == 0:
            email_subject += "OK"
        else:
            email_subject += "FAIL"

        full_s3_folder = get_s3_folder(
            options.given_folder,
            options.git_rev,
            options.solver,
            options.timeout_in_secs,
            options.mem_limit_in_mb)
        text = """Server finished. Please download the final data:

mkdir {0}
cd {0}
aws s3 cp --recursive s3://{1}/{0}/ .

Don't forget to:

* check volume
* check EC2 still running

So long and thanks for all the fish!
""".format(full_s3_folder, options.s3_bucket)
        send_email(email_subject, text, options.logfile_name)
    except Exception:
        the_trace = traceback.format_exc().rstrip().replace("\n", " || ")
        logging.error("Cannot send email! Traceback: %s", the_trace)

    if not options.noshutdown:
        ret = os.system(toexec)
        if ret:
            logging.error("Command '%s' failed with status %s", toexec, ret)

    # sys.exit rather than the exit() builtin, which is only installed by
    # the site module and is absent under "python -S" or frozen builds.
    sys.exit(exitval)
378
379
def set_up_logging():
    """Send INFO-and-above root-logger output to ``options.logfile_name``.

    Any existing file at that path is deleted first, so each run starts
    with an empty log.
    """
    form = '[ %(asctime)-15s  %(levelname)s  %(message)s ]'
    logformatter = logging.Formatter(form)

    try:
        os.unlink(options.logfile_name)
    except OSError:
        # Usually the file just does not exist yet; a real permission
        # problem will surface when FileHandler opens the path below.
        pass
    fileHandler = logging.FileHandler(options.logfile_name)
    fileHandler.setFormatter(logformatter)
    logging.getLogger().addHandler(fileHandler)
    logging.getLogger().setLevel(logging.INFO)
392
if __name__ == "__main__":
    # The names bound here (options, args, acc_queue,
    # last_termination_sent, server) are module globals read by the
    # functions and classes above.
    options, args = server_option_parser.parse_arguments()
    if options.drat and "cryptominisat" not in options.solver:
        # Explicit check rather than assert, which "python -O" strips.
        sys.exit("drat is only supported with cryptominisat solvers "
                 "(solver: %s)" % options.solver)

    # connections accepted by the Listener thread, consumed by Server.run
    acc_queue = Queue.Queue()
    # time.time() of the last "finish" reply, set by send_termination
    last_termination_sent = None

    set_up_logging()
    logging.info("Server called with parameters: %s",
                 pprint.pformat(options, indent=4).replace("\n", " || "))

    if not options.git_rev:
        options.git_rev = get_revision(options.base_dir + options.solver, options.base_dir)
        logging.info("Revision not given, taking HEAD: %s", options.git_rev)

    server = Server()
    listener = Listener()
    spotmanager = SpotManager()
    # Daemon threads do not keep the process alive once the main thread
    # leaves the loop below and calls shutdown().
    listener.daemon = True
    server.daemon = True
    spotmanager.daemon = True

    listener.start()
    server.start()
    time.sleep(20)
    spotmanager.start()

    # threading.active_count() always includes the main thread, so the loop
    # only ends through the break below.
    while True:
        time.sleep(0.5)
        if last_termination_sent is not None and server.ready_to_shutdown():
            # Wait `limit` seconds after the last "finish" reply before
            # shutting down (presumably to let straggling clients get their
            # reply; TODO confirm).
            diff = time.time() - last_termination_sent
            limit = 100
            if diff > limit:
                break

    shutdown()
433