1# Copyright (C) 2017-2021 Pier Carlo Chiodi
2#
3# This program is free software: you can redistribute it and/or modify
4# it under the terms of the GNU General Public License as published by
5# the Free Software Foundation, either version 3 of the License, or
6# (at your option) any later version.
7#
8# This program is distributed in the hope that it will be useful,
9# but WITHOUT ANY WARRANTY; without even the implied warranty of
10# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11# GNU General Public License for more details.
12#
13# You should have received a copy of the GNU General Public License
14# along with this program.  If not, see <http://www.gnu.org/licenses/>.
15
16from six.moves import input
17import sys
18
19from .ipaddresses import IPAddress
20
class Ask(object):
    """Interactive console prompts that return ``(answer_given, value)`` tuples.

    Answers are normally read with ``input()``. Setting ``next_answer``
    pre-loads a single answer that the next prompt consumes (and echoes to
    the output) instead of reading from the terminal.
    """

    def __init__(self):
        # One-shot pre-loaded answer, consumed by the next get_input() call.
        self.next_answer = None

    def get_input(self):
        """Return one line of user input.

        If ``next_answer`` is set (non-empty), it is returned, cleared and
        echoed as if the user had typed it; otherwise a line is read with
        ``input()``.
        """
        if self.next_answer:
            ans = self.next_answer
            self.next_answer = None
            self.wr_out(ans + "\n")
            return ans
        return input()

    def wr_out(self, msg):
        """Write ``msg`` to stdout as-is (no newline is appended)."""
        sys.stdout.write(msg)

    def ask(self, text, options=None, default=None, raise_exc=False):
        """Prompt the user and return ``(answer_given, answer)``.

        :param text: the prompt text.
        :param options: optional list of accepted answers. Matching is
            case-insensitive and the matching option, spelled as it appears
            in ``options``, is returned.
        :param default: value returned when the user enters nothing. It is
            shown in the prompt (upper-cased among ``options``, otherwise as
            "(default: ...)"). A falsy default is treated as no default.
        :param raise_exc: when True, exceptions raised while reading the
            input (including KeyboardInterrupt) are re-raised instead of
            being reported as ``(False, None)``.
        :returns: ``(True, answer)`` for a valid answer (or the default),
            ``(False, None)`` for invalid/missing input or a failed read.
        """
        msg = "{} ".format(text)
        if options:
            msg_options = []
            for opt in options:
                if opt == default:
                    msg_options.append(str(opt).upper())
                else:
                    msg_options.append(str(opt))
            msg += "[" + "/".join(msg_options) + "] "
        elif default:
            msg += "(default: {}) ".format(default)
        self.wr_out(msg)

        try:
            answer = self.get_input()
        except (KeyboardInterrupt, Exception):
            # Ctrl-C, EOF and I/O errors abort the prompt; SystemExit is
            # deliberately no longer intercepted.
            if raise_exc:
                raise
            return False, None

        answer = answer.strip()
        if not answer:
            if default:
                return True, default
            print("No answer given.")
            return False, None

        if options:
            # Map lower-cased spelling -> canonical option (first one wins),
            # so "YES" is returned as "yes" when options are ["yes", "no"].
            canonical = {}
            for opt in options:
                canonical.setdefault(str(opt).lower(), opt)
            key = answer.lower()
            if key not in canonical:
                print("Invalid choice: {} - must be one of {}.".format(
                    answer, ", ".join(str(opt) for opt in options)))
                return False, None
            return True, canonical[key]
        return True, answer

    def ask_yes_no(self, text, default=None, raise_exc=False):
        """Ask a yes/no question; returns ``(answer_given, "yes"|"no")``."""
        return self.ask(text, ["yes", "no"], default, raise_exc)

    def ask_int(self, text, default=None, raise_exc=False):
        """Prompt for a non-negative integer; returns ``(answer_given, int)``.

        ``default`` may be an int or a string. It is stringified before being
        handed to ask(), so that a default of 0 is honored and an int default
        goes through the same validation as typed input.
        """
        if default is not None:
            default = str(default)
        answer_given, v = self.ask(text, None, default, raise_exc)
        if not answer_given:
            return False, None
        # isdigit() rejects signs/whitespace; int() additionally rejects
        # digit-like characters (e.g. superscripts) that isdigit() accepts.
        if v.isdigit():
            try:
                return True, int(v)
            except ValueError:
                pass
        print("Invalid input: it must be an integer.")
        return False, None

    def ask_ipv4_addr(self, text, default=None, raise_exc=False):
        """Prompt for an IPv4 address; returns ``(answer_given, ip.ip)``.

        The answer is parsed with ``IPAddress``; anything that fails to parse
        or whose ``version`` is not 4 is rejected with ``(False, None)``.
        """
        answer_given, v = self.ask(text, None, default, raise_exc)
        if not answer_given:
            return False, None
        try:
            ip = IPAddress(v)
            if ip.version != 4:
                raise ValueError()
        except Exception:
            print("Invalid input: must be a valid IPv4 address.")
            return False, None
        return True, ip.ip
102