1"""Networked spam-signature detection server.
2
The server receives the request in the form of an RFC 5322 message, and
responds with another RFC 5322 message.  Neither of these messages has a
body - all of the data is encapsulated in the headers.
6
The response headers will always include a "Code" header, which is an
HTTP-style response code, and a "Diag" header, which is a human-readable
message explaining the response code (typically this will be "OK").
10
11Both the request and response headers always include a "PV" header, which
12indicates the protocol version that is being used (in a major.minor format).
Both the request and response headers also always include a "Thread" header,
which uniquely identifies the request (this is a requirement of using UDP).
Responses to requests may arrive in any order, but the "Thread" header of
a response will always match the "Thread" header of the appropriate request.
17
18Authenticated requests must also have "User", "Time" (timestamp), and "Sig"
19(signature) headers.
20"""
21import os
22import sys
23import time
24import errno
25import socket
26import signal
27import logging
28import threading
29import traceback
30import email.message
31
32try:
33    import SocketServer
34except ImportError:
35    import socketserver as SocketServer
36
37import pyzor.config
38import pyzor.account
39import pyzor.engines.common
40
41import pyzor.hacks.py26
42
43
# Apply pyzor's compatibility hacks once, at import time (judging by the
# module name, workarounds for Python 2.6).
# NOTE(review): presumably this is what makes ``email.message_from_bytes``
# (used by RequestHandler) available on old interpreters -- confirm.
pyzor.hacks.py26.hack_all()
45
46
47def _eintr_retry(func, *args):
48    """restart a system call interrupted by EINTR"""
49    while True:
50        try:
51            return func(*args)
52        except OSError as e:
53            if e.args[0] != errno.EINTR:
54                raise
55
56
class Server(SocketServer.UDPServer):
    """The pyzord server.  Handles incoming UDP datagrams in a single
    thread and single process."""
    # Largest datagram UDPServer reads for a single request.
    max_packet_size = 8192
    time_diff_allowance = 180

    def __init__(self, address, database, passwd_fn, access_fn,
                 forwarder=None):
        """Load the configuration, then create, bind and activate the
        listening socket and install the SIGUSR1/SIGTERM handlers.

        :param address: ``(host, port)`` to listen on; a host containing a
            ``:`` is treated as an IPv6 address.
        :param database: the digest store used by the request handlers (it
            is also probed for a ``handles_one_step`` attribute).
        :param passwd_fn: passwords file, read with
            ``pyzor.config.load_passwd_file``.
        :param access_fn: access (ACL) file, read with
            ``pyzor.config.load_access_file``.
        :param forwarder: optional object whose ``queue_forward_request`` is
            called by the handlers for every reported/whitelisted digest.
        """
        # Assigned on the instance (UDPServer reads ``self.address_family``)
        # rather than on the Server class, so that constructing one server
        # never changes the class-wide default seen by every other server.
        self.address_family = self._family_for_host(address[0])
        self.log = logging.getLogger("pyzord")
        self.usage_log = logging.getLogger("pyzord-usage")
        self.database = database
        self.one_step = getattr(self.database, "handles_one_step", False)

        # Handle configuration files
        self.passwd_fn = passwd_fn
        self.access_fn = access_fn
        self.accounts = {}
        self.acl = {}
        self.load_config()

        self.forwarder = forwarder

        self.log.debug("Listening on %s", address)
        SocketServer.UDPServer.__init__(self, address, RequestHandler,
                                        bind_and_activate=False)
        # Try to accept IPv4 traffic on an IPv6 socket as well; this fails
        # harmlessly on IPv4 sockets or platforms lacking the option.
        try:
            self.socket.setsockopt(socket.IPPROTO_IPV6, socket.IPV6_V6ONLY, 0)
        except (AttributeError, socket.error) as e:
            self.log.debug("Unable to set IPV6_V6ONLY to false %s", e)
        try:
            self.server_bind()
            self.server_activate()
        except BaseException:
            # Same as UDPServer's own bind_and_activate path: do not leak the
            # socket when binding fails (e.g. address already in use).
            self.server_close()
            raise

        # Finally, set signals
        signal.signal(signal.SIGUSR1, self.reload_handler)
        signal.signal(signal.SIGTERM, self.shutdown_handler)

    @staticmethod
    def _family_for_host(host):
        """Return AF_INET6 for a host containing ':', AF_INET otherwise."""
        return socket.AF_INET6 if ":" in host else socket.AF_INET

    def load_config(self):
        """Reads the configuration files and loads the accounts and ACLs."""
        self.accounts = pyzor.config.load_passwd_file(self.passwd_fn)
        self.acl = pyzor.config.load_access_file(self.access_fn, self.accounts)

    def shutdown_handler(self, *args, **kwargs):
        """Handler for the SIGTERM signal. This should be used to kill the
        daemon and ensure proper clean-up.

        shutdown() waits for serve_forever() to stop, so it is run in a
        separate thread instead of directly inside the signal handler.
        """
        self.log.info("SIGTERM received. Shutting down.")
        t = threading.Thread(target=self.shutdown)
        t.start()

    def reload_handler(self, *args, **kwargs):
        """Handler for the SIGUSR1 signal. This should be used to reload
        the configuration files (done in a separate thread).
        """
        self.log.info("SIGUSR1 received. Reloading configuration.")
        t = threading.Thread(target=self.load_config)
        t.start()

    def handle_error(self, request, client_address):
        """Log, with traceback, an exception raised while serving a request."""
        self.log.error("Error while processing request from: %s",
                       client_address, exc_info=True)
121
122
class PreForkServer(Server):
    """The same as Server, but when it starts serving it pre-forks a fixed
    number of child processes, each running the serving loop on the socket
    created before the fork.

    The parent process then waits for all of its child processes to exit.
    SIGTERM and SIGUSR1 received by the parent are relayed to the children.
    """
    def __init__(self, address, database, passwd_fn, access_fn, prefork=4):
        """The same as Server.__init__, except that ``database`` is not a
        connection but an iterator yielding callables that each create a
        database connection; one is consumed per child process.

        :param prefork: number of child processes to fork.
        """
        # ``pids`` must exist before Server.__init__ runs, because that calls
        # the overridden load_config(), which reads it.  ``None`` means "not
        # the parent of any forked children".
        self.pids = None
        Server.__init__(self, address, database, passwd_fn, access_fn)
        self._prefork = prefork

    def serve_forever(self, poll_interval=0.5):
        """Fork ``prefork`` children, then wait for all of them to finish.

        Children never return from this method: they serve requests and
        then terminate the process (see _run_child).
        """
        pids = []
        for _ in range(self._prefork):
            factory = next(self.database)
            pid = os.fork()
            if pid == 0:
                self._run_child(factory, poll_interval)  # does not return
            pids.append(pid)
        self.pids = pids
        for pid in self.pids:
            _eintr_retry(os.waitpid, pid, 0)

    def _run_child(self, factory, poll_interval):
        """Serve requests in a forked child, then always exit the process."""
        exit_code = 1
        try:
            # Create the database in the child process, to prevent issues
            # with connections shared between processes.
            self.database = factory()
            # Server.__init__ probed the factory iterator, not a real
            # database, for this flag; re-probe the actual connection.
            self.one_step = getattr(self.database, "handles_one_step", False)
            Server.serve_forever(self, poll_interval=poll_interval)
            exit_code = 0
        except Exception:
            self.log.error("Pre-forked child %d failed", os.getpid(),
                           exc_info=True)
        finally:
            # Never let a child fall back into the parent's fork loop.
            os._exit(exit_code)

    def _signal_children(self, signum):
        """Send ``signum`` to every child, skipping already-reaped ones."""
        for pid in self.pids or ():
            try:
                os.kill(pid, signum)
            except OSError as e:
                if e.errno != errno.ESRCH:
                    raise

    def shutdown(self):
        """If this is the parent process send the TERM signal to all children,
        else call the super method.
        """
        if self.pids is None:
            Server.shutdown(self)
        else:
            self._signal_children(signal.SIGTERM)

    def load_config(self):
        """If this is the parent process send the USR1 signal to all children,
        else call the super method.
        """
        if self.pids is None:
            Server.load_config(self)
        else:
            self._signal_children(signal.SIGUSR1)
171
172
class ThreadingServer(SocketServer.ThreadingMixIn, Server):
    """Multi-threaded variant of the pyzord server.

    ThreadingMixIn serves every incoming request on a fresh thread.  Not
    every database type copes with that, so use it only with suitable
    engines.
    """
177
178
class BoundedThreadingServer(ThreadingServer):
    """Same as ThreadingServer but this also accepts a limited number of
    concurrent threads: once ``max_threads`` requests are being handled, the
    serving loop blocks until one of them finishes.
    """

    def __init__(self, address, database, passwd_fn, access_fn, max_threads,
                 forwarding_server=None):
        ThreadingServer.__init__(self, address, database, passwd_fn, access_fn,
                                 forwarder=forwarding_server)
        # One slot per in-flight request thread.
        self.semaphore = threading.Semaphore(max_threads)

    def process_request(self, request, client_address):
        """Wait for a free slot, then hand the request to a new thread."""
        self.semaphore.acquire()
        try:
            ThreadingServer.process_request(self, request, client_address)
        except Exception:
            # If the worker could not be started it will never release its
            # slot; give it back here or the slot is lost for good.
            self.semaphore.release()
            raise

    def process_request_thread(self, request, client_address):
        """Handle the request on the worker thread, then free its slot."""
        try:
            ThreadingServer.process_request_thread(self, request,
                                                   client_address)
        finally:
            self.semaphore.release()
197
198
class ProcessServer(SocketServer.ForkingMixIn, Server):
    """A multi-processing version of the pyzord server.  Each connection is
    served in a new process. This may not be suitable for all database types.
    """

    def __init__(self, address, database, passwd_fn, access_fn,
                 max_children=40, forwarding_server=None):
        """Same as Server.__init__, plus ``max_children``: the limit on
        concurrently running child processes enforced by ForkingMixIn.
        """
        # Per-instance value (ForkingMixIn reads ``self.max_children``);
        # assigning it on the class would silently change the limit of every
        # other ProcessServer in the same process.
        self.max_children = max_children
        Server.__init__(self, address, database, passwd_fn, access_fn,
                        forwarder=forwarding_server)
209
210
class RequestHandler(SocketServer.DatagramRequestHandler):
    """Handle a single pyzord request.

    One instance is created per received datagram.  The request headers are
    parsed, the user is authenticated and authorised, the operation named in
    the "Op" header is dispatched through ``dispatches`` and the response
    message is written back to the client.
    """

    def __init__(self, *args, **kwargs):
        # Must exist before the base constructor runs: that constructor is
        # what invokes handle().
        self.response = email.message.Message()
        SocketServer.DatagramRequestHandler.__init__(self, *args, **kwargs)

    def handle(self):
        """Handle a pyzord operation, cleanly handling any errors.

        Failures are mapped to HTTP-style codes in the "Code"/"Diag"
        response headers; the response is then written back to the client.
        """
        self.response["Code"] = "200"
        self.response["Diag"] = "OK"
        self.response["PV"] = "%s" % pyzor.proto_version
        # NOTE: clauses are tried in order; keep the more specific pyzor
        # errors (e.g. UnsupportedVersionError) ahead of broader ones in case
        # they share a base class.
        try:
            self._really_handle()
        except NotImplementedError as e:
            self.handle_error(501, "Not implemented: %s" % e)
        except pyzor.UnsupportedVersionError as e:
            self.handle_error(505, "Version Not Supported: %s" % e)
        except pyzor.ProtocolError as e:
            self.handle_error(400, "Bad request: %s" % e)
        except pyzor.SignatureError as e:
            self.handle_error(401, "Unauthorized: Signature Error: %s" % e)
        except pyzor.AuthorizationError as e:
            self.handle_error(403, "Forbidden: %s" % e)
        except Exception as e:
            self.handle_error(500, "Internal Server Error: %s" % e)
            self.server.log.error(traceback.format_exc())
        raw_response = self.response.as_string()
        self.server.log.debug("Sending: %r", raw_response)
        self.wfile.write(raw_response.encode("utf8"))

    def _really_handle(self):
        """handle() without the exception handling.

        :raises pyzor.SignatureError: the "User" header names an unknown
            account (verify_signature may presumably raise it as well).
        :raises pyzor.ProtocolError: the "PV" header is missing or malformed.
        :raises pyzor.UnsupportedVersionError: the major version differs.
        :raises pyzor.AuthorizationError: the user may not run "Op".
        :raises NotImplementedError: "Op" is not a known operation.
        """
        self.server.log.debug("Received: %r", self.packet)

        # Read the request.
        # Old versions of the client sent a double \n after the signature,
        # which breaks the RFC 5322 header format.  Specifically handle that
        # here - this could be removed in time.
        request = email.message_from_bytes(
            self.rfile.read().replace(b"\n\n", b"\n") + b"\n")

        # Ensure that the response can be paired with the request.
        self.response["Thread"] = request["Thread"]

        # If this is an authenticated request, then check the authentication
        # details.
        user = request["User"] or pyzor.anonymous_user
        if user != pyzor.anonymous_user:
            try:
                pyzor.account.verify_signature(request,
                                               self.server.accounts[user])
            except KeyError:
                raise pyzor.SignatureError("Unknown user.")

        if "PV" not in request:
            raise pyzor.ProtocolError("Protocol Version not specified in "
                                      "request")

        # The protocol version is compatible if the major number is
        # identical (changes in the minor number are unimportant).
        # OverflowError covers values such as "inf" or "1e400".
        try:
            compatible = (int(float(request["PV"])) ==
                          int(pyzor.proto_version))
        except (ValueError, OverflowError):
            self.server.log.warning("Invalid PV: %s", request["PV"])
            raise pyzor.ProtocolError("Invalid Protocol Version")
        if not compatible:
            raise pyzor.UnsupportedVersionError()

        # Check that the user has permission to execute the requested
        # operation.  A user without an ACL entry is denied (403) rather than
        # triggering a KeyError (500); indexing keeps any defaultdict
        # behaviour of the ACL mapping intact.
        opcode = request["Op"]
        try:
            allowed = self.server.acl[user]
        except KeyError:
            allowed = ()
        if opcode not in allowed:
            raise pyzor.AuthorizationError(
                "User is not authorized to request the operation.")
        self.server.log.debug("Got a %s command from %s", opcode,
                              self.client_address[0])
        # Get a handle to the appropriate method to execute this operation.
        try:
            dispatch = self.dispatches[opcode]
        except KeyError:
            raise NotImplementedError("Requested operation is not "
                                      "implemented.")
        digests = request.get_all("Op-Digest")

        # Do the requested operation ("ping" has no handler and needs no
        # digests), log what we have done, and return.
        if dispatch and digests:
            dispatch(self, digests)
        self.server.usage_log.info("%s,%s,%s,%r,%s", user,
                                   self.client_address[0], opcode, digests,
                                   self.response["Code"])

    def handle_error(self, code, message):
        """Create an appropriate response for an error.

        :param code: HTTP-style status code placed in the "Code" header.
        :param message: human-readable text placed in the "Diag" header.
        """
        self.server.usage_log.error("%s: %s", code, message)
        self.response.replace_header("Code", "%d" % code)
        self.response.replace_header("Diag", message)

    def _get_record(self, digest):
        """Return the stored record for ``digest``, or a blank Record if the
        database has no entry for it."""
        try:
            return self.server.database[digest]
        except KeyError:
            return pyzor.engines.common.Record()

    def handle_pong(self, digests):
        """Handle the 'pong' command.

        This command returns the largest native integer (sys.maxsize) as the
        report count and 0 as the whitelist count.
        """
        self.server.log.debug("Request pong for %s", digests[0])
        self.response["Count"] = "%d" % sys.maxsize
        self.response["WL-Count"] = "%d" % 0

    def handle_check(self, digests):
        """Handle the 'check' command.

        This command returns the spam/ham counts for the first digest.
        """
        digest = digests[0]
        record = self._get_record(digest)
        self.server.log.debug("Request to check digest %s", digest)
        self.response["Count"] = "%d" % record.r_count
        self.response["WL-Count"] = "%d" % record.wl_count

    def handle_report(self, digests):
        """Handle the 'report' command in a single step.

        This command increases the spam count for the specified digests,
        and queues them with the forwarder when one is configured."""
        self.server.log.debug("Request to report digests %s", digests)
        if self.server.one_step:
            self.server.database.report(digests)
        else:
            for digest in digests:
                record = self._get_record(digest)
                record.r_increment()
                self.server.database[digest] = record
        if self.server.forwarder:
            for digest in digests:
                self.server.forwarder.queue_forward_request(digest)

    def handle_whitelist(self, digests):
        """Handle the 'whitelist' command in a single step.

        This command increases the ham count for the specified digests,
        and queues them with the forwarder when one is configured."""
        self.server.log.debug("Request to whitelist digests %s", digests)
        if self.server.one_step:
            self.server.database.whitelist(digests)
        else:
            for digest in digests:
                record = self._get_record(digest)
                record.wl_increment()
                self.server.database[digest] = record
        if self.server.forwarder:
            for digest in digests:
                self.server.forwarder.queue_forward_request(digest, True)

    def handle_info(self, digests):
        """Handle the 'info' command.

        This command returns diagnostic data about the first digest
        (timestamps for when it was first/last seen as spam/ham, and the
        spam/ham counts).
        """
        digest = digests[0]
        record = self._get_record(digest)
        self.server.log.debug("Request for information about digest %s",
                              digest)

        def time_output(time_obj):
            """Convert a datetime (read as local time) to a POSIX timestamp.

            If the object is None, then return 0.
            """
            if not time_obj:
                return 0
            return time.mktime(time_obj.timetuple())

        self.response["Entered"] = "%d" % time_output(record.r_entered)
        self.response["Updated"] = "%d" % time_output(record.r_updated)
        self.response["WL-Entered"] = "%d" % time_output(record.wl_entered)
        self.response["WL-Updated"] = "%d" % time_output(record.wl_updated)
        self.response["Count"] = "%d" % record.r_count
        self.response["WL-Count"] = "%d" % record.wl_count

    # Maps the "Op" header to the method handling it; None means the
    # operation needs no work beyond the default 200 response.
    dispatches = {
        'ping': None,
        'pong': handle_pong,
        'info': handle_info,
        'check': handle_check,
        'report': handle_report,
        'whitelist': handle_whitelist,
    }
409