1"""
Low Level socket methods to communicate with the bareos-director.
3"""
4
5# Authentication code is taken from
6# https://github.com/hanxiangduo/bacula-console-python
7
8import hashlib
9import hmac
10import logging
11import random
12import re
13from select import select
14import socket
15import ssl
16import struct
17import sys
18import time
19import warnings
20
21from bareos.bsock.constants import Constants
22from bareos.bsock.connectiontype import ConnectionType
23from bareos.bsock.protocolmessageids import ProtocolMessageIds
24from bareos.bsock.protocolmessages import ProtocolMessages
25from bareos.bsock.protocolversions import ProtocolVersions
26from bareos.util.bareosbase64 import BareosBase64
27from bareos.util.password import Password
28import bareos.exceptions
29
# Try to load the sslpsk module,
# which implements TLS-PSK (Transport Layer Security - Pre-Shared-Key)
# on top of the ssl module.
# If it is not available, we continue anyway,
# but don't use TLS-PSK.
35try:
36    import sslpsk
37except ImportError:
38    warnings.warn(
39        u"Connection encryption via TLS-PSK is not available, as the module sslpsk is not installed."
40    )
41
42
43class LowLevel(object):
44    """
45    Low Level socket methods to communicate with the bareos-director.
46    """
47
48    @staticmethod
49    def argparser_get_bareos_parameter(args):
50        """
51        This method is usally used together with the method argparser_add_default_command_line_arguments.
52
53        @param args: Arguments retrieved by ArgumentParser.parse_args()
54        @type args:  ArgParser.Namespace
55
56        @return: returns the relevant parameter from args to initialize a connection.
57        @rtype: dict
58        """
59        result = {}
60        for key, value in vars(args).items():
61            if value is not None:
62                if key.startswith("BAREOS_"):
63                    bareoskey = key.split("BAREOS_", 1)[1]
64                    result[bareoskey] = value
65        return result
66
67    def __init__(self):
68        self.logger = logging.getLogger()
69        self.logger.debug("init")
70        self.status = None
71        self.address = None
72        self.password = None
73        self.pam_username = None
74        self.pam_password = None
75        self.port = None
76        self.dirname = None
77        self.socket = None
78        self.auth_credentials_valid = False
79        self.max_reconnects = 0
80        self.tls_psk_enable = True
81        self.tls_psk_require = False
82        try:
83            self.tls_version = ssl.PROTOCOL_TLS
84        except AttributeError:
85            self.tls_version = ssl.PROTOCOL_SSLv23
86        self.connection_type = None
87        self.requested_protocol_version = None
88        self.protocol_messages = ProtocolMessages()
89        # identity_prefix have to be set in each class
90        self.identity_prefix = u"R_NONE"
91        self.receive_buffer = b""
92
93    def __del__(self):
94        self.close()
95
96    def connect(
97        self, address, port, dirname, connection_type, name=None, password=None
98    ):
99        self.address = address
100        self.port = int(port)
101        if dirname:
102            self.dirname = dirname
103        else:
104            self.dirname = address
105        self.connection_type = connection_type
106        self.name = name
107        if password is None:
108            raise bareos.exceptions.ConnectionError(u"Parameter 'password' is required.")
109        if isinstance(password, Password):
110            self.password = password
111        else:
112            self.password = Password(password)
113
114        return self.__connect()
115
116    def __connect(self):
117        connected = False
118        connected_plain = False
119        auth = False
120        if self.tls_psk_require:
121            if not self.is_tls_psk_available():
122                raise bareos.exceptions.ConnectionError(
123                    u"TLS-PSK is required, but sslpsk module not loaded/available."
124                )
125            if not self.tls_psk_enable:
126                raise bareos.exceptions.ConnectionError(
127                    u"TLS-PSK is required, but not enabled."
128                )
129
130        if self.tls_psk_enable and self.is_tls_psk_available():
131            try:
132                self.__connect_tls_psk()
133            except (bareos.exceptions.ConnectionError, ssl.SSLError) as e:
134                self._handleSocketError(e)
135                if self.tls_psk_require:
136                    raise
137                else:
138                    self.logger.warning(
139                        u"Failed to connect via TLS-PSK. Trying plain connection."
140                    )
141            else:
142                connected = True
143                self.logger.debug("Encryption: {0}".format(self.socket.cipher()))
144
145        if not connected:
146            self.__connect_plain()
147            connected = True
148            connected_plain = True
149            self.logger.debug("Encryption: None")
150
151        if connected:
152            try:
153                auth = self.auth()
154            except bareos.exceptions.PamAuthenticationError:
155                raise
156            except bareos.exceptions.AuthenticationError:
157                if (
158                    self.connection_type == ConnectionType.DIRECTOR
159                    and self.requested_protocol_version is None
160                    and self.get_protocol_version() > ProtocolVersions.bareos_12_4
161                ):
162                    # reconnect and try old protocol
163                    self.logger.warning(
164                        "Failed to connect using protocol version {0}. Trying protocol version {1}. ".format(
165                            self.get_protocol_version(), ProtocolVersions.bareos_12_4
166                        )
167                    )
168                    self.close()
169                    self.__connect_plain()
170                    self.protocol_messages.set_version(ProtocolVersions.bareos_12_4)
171                    auth = self.auth()
172                else:
173                    raise
174
175        return auth
176
177    def __connect_plain(self):
178        self.socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
179        # initialize
180        try:
181            self.socket.connect((self.address, self.port))
182        except (socket.error, socket.gaierror) as e:
183            self._handleSocketError(e)
184            raise bareos.exceptions.ConnectionError(
185                "Failed to connect to host {0}, port {1}: {2}".format(
186                    self.address, self.port, str(e)
187                )
188            )
189
190        self.logger.debug("connected to {0}:{1}".format(self.address, self.port))
191
192        return True
193
194    def __connect_tls_psk(self):
195        """
196        Connect and establish a TLS-PSK connection on top of the connection.
197        """
198        self.__connect_plain()
199        # wrap socket with TLS-PSK
200        client_socket = self.socket
201        identity = self.get_tls_psk_identity()
202        if isinstance(self.password, Password):
203            password = self.password.md5()
204        else:
205            raise bareos.exceptions.ConnectionError(u"No password provided.")
206        self.logger.debug("identity = {0}, password = {1}".format(identity, password))
207        try:
208            self.socket = sslpsk.wrap_socket(
209                client_socket,
210                ssl_version=self.tls_version,
211                ciphers="ALL:!ADH:!LOW:!EXP:!MD5:@STRENGTH",
212                psk=(password, identity),
213                server_side=False,
214            )
215        except ssl.SSLError as e:
216            # raise ConnectionError(
217            #     "failed to connect to host {0}, port {1}: {2}".format(self.address, self.port, str(e)))
218            # Using a general raise keep more information about the type of error.
219            raise
220        return True
221
222    def get_tls_psk_identity(self):
223        """Bareos TLS-PSK excepts the identiy is a specific format."""
224        name = str(self.name)
225        if isinstance(self.name, bytes):
226            name = self.name.decode("utf-8")
227        result = u"{0}{1}{2}".format(self.identity_prefix, Constants.record_separator, name)
228        return bytes(bytearray(result, "utf-8"))
229
230
231    @staticmethod
232    def is_tls_psk_available():
233        """Checks if we have all required modules for TLS-PSK."""
234        return "sslpsk" in sys.modules
235
236    def get_protocol_version(self):
237        return self.protocol_messages.get_version()
238
239    def get_cipher(self):
240        if hasattr(self.socket, "cipher"):
241            return self.socket.cipher()
242        else:
243            return None
244
245    def auth(self):
246        """
247        Login to a Bareos Daemon.
248
249        @return: True, if the authentication succeeds.
250                 In earlier versions, authentication failures returned False.
251                 However, now an authentication failure raises an exception.
252        @rtype: bool
253
254        @raise bareos.exceptions.AuthenticationError: if authentication fails.
255        """
256
257        bashed_name = self.protocol_messages.hello(self.name, type=self.connection_type)
258        # send the bash to the director
259        self.send(bashed_name)
260
261        try:
262            (ssl, result_compatible, result) = self._cram_md5_respond(
263                password=self.password.md5(), tls_remote_need=0
264            )
265        except bareos.exceptions.SignalReceivedException as e:
266            self._handleSocketError(e)
267            raise bareos.exceptions.AuthenticationError(
268                "Received unexcepted signal: {0}".format(str(e))
269            )
270        if not result:
271            raise bareos.exceptions.AuthenticationError("failed (in response)")
272        if not self._cram_md5_challenge(
273            clientname=self.name,
274            password=self.password.md5(),
275            tls_local_need=0,
276            compatible=True,
277        ):
278            raise bareos.exceptions.AuthenticationError("failed (in challenge)")
279
280        self.finalize_authentication()
281
282        return self.auth_credentials_valid
283
284    def receive_and_evaluate_response_message(self):
285        regex_str = r"^(\d\d\d\d){0}(.*)$".format(
286            Constants.record_separator_compat_regex
287        )
288        regex = bytes(bytearray(regex_str, "utf8"))
289        incoming_message = self.recv_msg(regex)
290        match = re.search(regex, incoming_message, re.DOTALL)
291        code = int(match.group(1))
292        text = match.group(2)
293
294        return (code, text)
295
296    def _init_connection(self):
297        pass
298
299    def close(self):
300        """disconnect"""
301        if self.socket is not None:
302            self.socket.close()
303        self.socket = None
304
305    def reconnect(self):
306        result = False
307        if self.max_reconnects > 0:
308            try:
309                self.max_reconnects -= 1
310                if self.__connect() and self._init_connection():
311                    result = True
312            except (socket.error, bareos.exceptions.ConnectionLostError):
313                self.logger.warning("failed to reconnect")
314        return result
315
316    def call(self, command):
317        """
318        call a bareos-director user agent command
319        """
320        if isinstance(command, list):
321            command = " ".join(command)
322        return self._send_a_command_and_receive_result(command)
323
324    def _send_a_command_and_receive_result(self, command):
325        """
326        Send a command and receive the result.
327        If connection is lost, try to reconnect.
328        """
329        result = b""
330        try:
331            self.send(bytearray(command, "utf-8"))
332            result = self.recv_msg()
333        except (
334            bareos.exceptions.SocketEmptyHeader,
335            bareos.exceptions.ConnectionLostError,
336        ) as e:
337            self.logger.error(
338                "connection problem (%s): %s" % (type(e).__name__, str(e))
339            )
340            if self.reconnect():
341                return self._send_a_command_and_receive_result(command, count + 1)
342            else:
343                raise
344        return result
345
346    def send_command(self, command):
347        return self.call(command)
348
349    def send(self, msg=None):
350        """use socket to send request to director"""
351        self.__check_socket_connection()
352        msg_len = len(msg)  # plus the msglen info
353
354        try:
355            # convert to network flow
356            self.logger.debug("{0}".format(msg.rstrip()))
357            self.socket.sendall(struct.pack("!i", msg_len) + msg)
358        except socket.error as e:
359            self._handleSocketError(e)
360
361    def recv_bytes(self, length, timeout=10):
362        """
363        Receive a number of bytes.
364
365        @raise bareos.exceptions.ConnectionLostError:
366               is raised, if the socket connection gets lost.
367        @raise socket.timeout:
368               is raised, if a timeout occurs on the socket connection,
369               meaning no data received.
370        """
371        self.socket.settimeout(timeout)
372        msg = b""
373        # get the message
374        while length > 0:
375            self.logger.debug("expecting {0} bytes.".format(length))
376            submsg = self.socket.recv(length)
377            if len(submsg) == 0:
378                errormsg = u"Failed to retrieve data. Assuming the connection is lost."
379                self._handleSocketError(errormsg)
380                raise bareos.exceptions.ConnectionLostError(errormsg)
381            length -= len(submsg)
382            msg += submsg
383        return msg
384
385    def recv(self):
386        """
387        Receive a single message.
388        This is,
389        header (4 bytes): if
390            > 0: length of the following message
391            < 0: Bareos signal
392        msg: of the length descriped in the header.
393
394        @raise bareos.exceptions.SignalReceivedException:
395               is raised, if a Bareos signal is received.
396        """
397        self.__check_socket_connection()
398        # get the message header
399        header = self.__get_header()
400        if header <= 0:
401            self.logger.debug("header: " + str(header))
402            raise bareos.exceptions.SignalReceivedException(header)
403        # get the message
404        length = header
405        msg = self.recv_submsg(length)
406        return msg
407
408    def recv_msg(self, regex=b"^\d\d\d\d OK.*$"):
409        """
410        Receive a full Director message.
411
412        It retrieves Director messages (header + message text),
413        until
414          1. the message matches the specified regex or
415          2. the header indicates a signal.
416
417        @raise bareos.exceptions.SignalReceivedException:
418               is raised, if a Bareos signal is received.
419        """
420        self.__check_socket_connection()
421        try:
422            timeouts = 0
423            while True:
424                # get the message header
425                try:
426                    header = self.__get_header()
427                except (socket.timeout, ssl.SSLError) as exception:
428                    # When using a SSL connection,
429                    # a timeout is raised as
430                    # ssl.SSLError exception with message: 'The read operation timed out'.
431                    # ssl.SSLError is inherited from socket.error.
432                    # Because we can't be sure,
433                    # that it is really a timeout, we log it.
434                    if isinstance(exception, ssl.SSLError) and self.logger.isEnabledFor(
435                        logging.DEBUG
436                    ):
437                        # self.logger.exception('On SSL connections, timeout are raised as ssl.SSLError exceptions:')
438                        self.logger.debug("{0}".format(repr(exception)))
439                    self.logger.debug("timeout (%i) on receiving header" % (timeouts))
440                    timeouts += 1
441                else:
442                    if header <= 0:
443                        # header is a signal
444                        self.__set_status(header)
445                        if self.is_end_of_message(header):
446                            result = self.receive_buffer
447                            self.receive_buffer = b""
448                            return result
449                    else:
450                        # header is the length of the next message
451                        length = header
452                        submsg = self.recv_submsg(length)
453                        # check for regex in new submsg
454                        # and last line in old message,
455                        # which might have been incomplete without new submsg.
456                        lastlineindex = self.receive_buffer.rfind(b"\n") + 1
457                        self.receive_buffer += submsg
458                        match = re.search(
459                            regex, self.receive_buffer[lastlineindex:], re.DOTALL
460                        )
461                        # Bareos indicates end of command result by line starting with 4 digits
462                        if match:
463                            self.logger.debug(
464                                'msg "{0}" matches regex "{1}"'.format(
465                                    self.receive_buffer.strip(), regex
466                                )
467                            )
468                            result = self.receive_buffer[
469                                0 : lastlineindex + match.end()
470                            ]
471                            self.receive_buffer = self.receive_buffer[
472                                lastlineindex + match.end() + 1 :
473                            ]
474                            return result
475        except socket.error as e:
476            self._handleSocketError(e)
477
478    def recv_submsg(self, length):
479        # get the message
480        msg = self.recv_bytes(length)
481        if type(msg) is str:
482            msg = bytearray(msg.decode("utf-8"), "utf-8")
483        if type(msg) is bytes:
484            msg = bytearray(msg)
485        self.logger.debug(str(msg))
486        return msg
487
488    def interactive(self):
489        """
490        Enter the interactive mode.
491        Exit via typing "exit" or "quit".
492        """
493        command = ""
494        while command != "exit" and command != "quit" and self.is_connected():
495            try:
496                command = self._get_input()
497            except EOFError:
498                return False
499            try:
500                resultmsg = self.call(command)
501                self._show_result(resultmsg)
502            except bareos.exceptions.JsonRpcErrorReceivedException as exp:
503                print(str(exp))
504                # print(str(exp.jsondata))
505
506        return True
507
508    def _get_input(self):
509        # Python2: raw_input, Python3: input
510        try:
511            myinput = raw_input
512        except NameError:
513            myinput = input
514        data = myinput(">>")
515        return data
516
517    def _show_result(self, msg):
518        # print(msg.decode('utf-8'))
519        sys.stdout.write(msg.decode("utf-8"))
520        # add a linefeed, if there isn't one already
521        if len(msg) >= 2:
522            if msg[-2] != ord(b'\n'):
523                sys.stdout.write(b'\n')
524
525    def __get_header(self, timeout=10):
526        header = self.recv_bytes(4, timeout)
527        return self.__get_header_data(header)
528
529    def __get_header_data(self, header):
530        # struct.unpack:
531        #   !: network (big/little endian conversion)
532        #   i: integer (4 bytes)
533        data = struct.unpack("!i", header)[0]
534        return data
535
536    def is_end_of_message(self, data):
537        return (
538            (not self.is_connected())
539            or data == Constants.BNET_EOD
540            or data == Constants.BNET_TERMINATE
541            or data == Constants.BNET_MAIN_PROMPT
542            or data == Constants.BNET_SUB_PROMPT
543        )
544
545    def is_connected(self):
546        return self.status != Constants.BNET_TERMINATE
547
548    def _cram_md5_challenge(
549        self, clientname, password, tls_local_need=0, compatible=True
550    ):
551        """
552        client launch the challenge,
553        client confirm the dir is the correct director
554        """
555
556        # get the timestamp
557        # here is the console
558        # to confirm the director so can do this on bconsole`way
559        rand = random.randint(1000000000, 9999999999)
560        # chal = "<%u.%u@%s>" %(rand, int(time.time()), self.dirname)
561        chal = "<%u.%u@%s>" % (rand, int(time.time()), clientname)
562        msg = bytearray("auth cram-md5 %s ssl=%d\n" % (chal, tls_local_need), "utf-8")
563        # send the confirmation
564        self.send(msg)
565        # get the response
566        msg = self.recv()
567        if msg[-1] == 0:
568            del msg[-1]
569        self.logger.debug("received: " + str(msg))
570
571        # hash with password
572        hmac_md5 = hmac.new(password, None, hashlib.md5)
573        hmac_md5.update(bytes(bytearray(chal, "utf-8")))
574        bbase64compatible = BareosBase64().string_to_base64(
575            bytearray(hmac_md5.digest()), True
576        )
577        bbase64notcompatible = BareosBase64().string_to_base64(
578            bytearray(hmac_md5.digest()), False
579        )
580        self.logger.debug("string_to_base64, compatible:     " + str(bbase64compatible))
581        self.logger.debug(
582            "string_to_base64, not compatible: " + str(bbase64notcompatible)
583        )
584
585        is_correct = (msg == bbase64compatible) or (msg == bbase64notcompatible)
586        # check against compatible base64 and Bareos specific base64
587        if is_correct:
588            self.send(ProtocolMessages.auth_ok())
589        else:
590            self.logger.error(
591                "expected result: %s or %s, but get %s"
592                % (bbase64compatible, bbase64notcompatible, msg)
593            )
594            self.send(ProtocolMessages.auth_failed())
595
596        # check the response is equal to base64
597        return is_correct
598
599    def _cram_md5_respond(self, password, tls_remote_need=0, compatible=True):
600        """
601        client connect to dir,
602        the dir confirm the password and the config is correct
603        """
604        # receive from the director
605        chal = ""
606        ssl = 0
607        result = False
608        msg = ""
609        try:
610            msg = self.recv()
611        except RuntimeError:
612            self.logger.error("RuntimeError exception in recv")
613            return (0, True, False)
614
615        # invalid username
616        if ProtocolMessages.is_not_authorized(msg):
617            self.logger.error("failed: " + str(msg))
618            return (0, True, False)
619
620        # check the receive message
621        self.logger.debug("(recv): " + str(msg).rstrip())
622
623        msg_list = msg.split(b" ")
624        chal = msg_list[2]
625        # get th timestamp and the tle info from director response
626        ssl = int(msg_list[3][4])
627        compatible = True
628        # hmac chal and the password
629        hmac_md5 = hmac.new((password), None, hashlib.md5)
630        hmac_md5.update(bytes(chal))
631
632        # base64 encoding
633        msg = BareosBase64().string_to_base64(bytearray(hmac_md5.digest()))
634
635        # send the base64 encoding to director
636        self.send(msg)
637        received = self.recv()
638        if ProtocolMessages.is_auth_ok(received):
639            result = True
640        else:
641            self.logger.error("failed: " + str(received))
642        return (ssl, compatible, result)
643
644    def __set_status(self, status):
645        self.status = status
646        status_text = Constants.get_description(status)
647        self.logger.debug(str(status_text) + " (" + str(status) + ")")
648
649    def has_data(self):
650        self.__check_socket_connection()
651        timeout = 0.1
652        readable, writable, exceptional = select([self.socket], [], [], timeout)
653        return readable
654
655    def get_to_prompt(self):
656        time.sleep(0.1)
657        if self.has_data():
658            msg = self.recv_msg()
659            self.logger.debug("received message: " + str(msg))
660        # TODO: check prompt
661        return True
662
663    def __check_socket_connection(self):
664        result = True
665        if self.socket is None:
666            result = False
667            if self.auth_credentials_valid:
668                # connection have worked before, but now it is gone
669                raise bareos.exceptions.ConnectionLostError(
670                    "currently no network connection"
671                )
672            else:
673                raise RuntimeError("should connect to director first before send data")
674        return result
675
676    def _handleSocketError(self, exception):
677        self.logger.warning("socket error: {0}".format(str(exception)))
678        self.close()
679