1# -*- coding: utf-8 -*-
2# Copyright (C) 2019-2021 Greenbone Networks GmbH
3#
4# SPDX-License-Identifier: GPL-3.0-or-later
5#
6# This program is free software: you can redistribute it and/or modify
7# it under the terms of the GNU General Public License as published by
8# the Free Software Foundation, either version 3 of the License, or
9# (at your option) any later version.
10#
11# This program is distributed in the hope that it will be useful,
12# but WITHOUT ANY WARRANTY; without even the implied warranty of
13# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
14# GNU General Public License for more details.
15#
16# You should have received a copy of the GNU General Public License
17# along with this program.  If not, see <http://www.gnu.org/licenses/>.
18
19import os
20import sys
21import unittest
22
23from unittest.mock import patch
24from pathlib import Path
25
26from argparse import Namespace
27from gvm.connections import (
28    DEFAULT_UNIX_SOCKET_PATH,
29    DEFAULT_TIMEOUT,
30    UnixSocketConnection,
31    TLSConnection,
32    SSHConnection,
33)
34
35from gvmtools.parser import CliParser, create_parser, create_connection
36
37from . import SuppressOutput
38
# Directory containing this test module; the config fixture (test.cfg) and
# the help output snapshot files are looked up relative to it.
__here__ = Path(__file__).parent.resolve()
40
41
class ConfigParserTestCase(unittest.TestCase):
    """Tests for CliParser picking up defaults from a config file.

    The expected values come from the ``test.cfg`` fixture that lives next
    to this module.
    """

    def setUp(self):
        self.test_config_path = __here__ / 'test.cfg'

        # fail early if the fixture is missing instead of producing
        # confusing assertion errors in the individual tests
        self.assertTrue(
            self.test_config_path.is_file(),
            f'Missing test fixture {self.test_config_path}',
        )

        self.parser = CliParser('TestParser', 'test.log')

    def _parse_with_test_config(self, *args):
        """Parse ``args`` with ``--config`` pointing to the test fixture."""
        return self.parser.parse_args(
            ['--config', str(self.test_config_path), *args]
        )

    def _assert_shared_defaults(self, args):
        """Assert the values test.cfg provides for all connection types."""
        self.assertEqual(args.foo, 'bar')
        self.assertEqual(args.timeout, 1000)
        self.assertEqual(args.gmp_password, 'bar')
        self.assertEqual(args.gmp_username, 'bar')

    def test_socket_defaults_from_config(self):
        args = self._parse_with_test_config('socket')

        self._assert_shared_defaults(args)
        self.assertEqual(args.socketpath, '/foo/bar.sock')

    def test_ssh_defaults_from_config(self):
        args = self._parse_with_test_config('ssh', '--hostname', 'foo')

        self._assert_shared_defaults(args)
        self.assertEqual(args.ssh_password, 'lorem')
        self.assertEqual(args.ssh_username, 'ipsum')
        self.assertEqual(args.port, 123)

    def test_tls_defaults_from_config(self):
        args = self._parse_with_test_config('tls', '--hostname', 'foo')

        self._assert_shared_defaults(args)
        self.assertEqual(args.certfile, 'foo.cert')
        self.assertEqual(args.keyfile, 'foo.key')
        self.assertEqual(args.cafile, 'foo.ca')
        self.assertEqual(args.port, 123)

    @patch('gvmtools.parser.logger')
    @patch('gvmtools.parser.Path')
    def test_resolve_file_not_found_error(self, path_mock, logger_mock):
        # Make Path(...).expanduser().resolve() raise FileNotFoundError.
        # Passing the exception class as side_effect raises it no matter
        # which arguments resolve() is called with.
        configpath = unittest.mock.MagicMock()
        configpath.expanduser.return_value.resolve.side_effect = (
            FileNotFoundError
        )
        path_mock.return_value = configpath

        args = self.parser.parse_args(['socket'])

        self.assertIsInstance(args, Namespace)
        self.assertEqual(args.connection_type, 'socket')
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        # parsing still succeeds; the unresolvable config file is just logged
        logger_mock.debug.assert_any_call(
            'Ignoring non existing config file %s', '~/.config/gvm-tools.conf'
        )

    @patch('gvmtools.parser.Path')
    @patch('gvmtools.parser.Config')
    def test_config_load_raises_error(self, config_mock, path_mock):
        # An error raised while loading the config must be reported as a
        # RuntimeError. The exception class is used as side_effect so the
        # mock raises it for any call signature; a zero-argument helper
        # function would instead raise TypeError as soon as load() is called
        # with an argument, letting the test pass for the wrong reason.
        config = unittest.mock.MagicMock()
        config.load.side_effect = Exception
        config_mock.return_value = config

        # make the parser believe the config file exists
        configpath = unittest.mock.Mock()
        configpath.expanduser().resolve().exists.return_value = True
        path_mock.return_value = configpath

        self.assertRaises(RuntimeError, self.parser.parse_args, ['socket'])
130
131
class IgnoreConfigParserTestCase(unittest.TestCase):
    """Tests that a non-existing config file is ignored.

    When the given config file does not exist, the built-in defaults have to
    be used for all options.
    """

    def _assert_builtin_defaults(self, config_path):
        """Parse with the missing ``config_path`` and check the defaults."""
        self.assertFalse(config_path.is_file())

        self.parser = CliParser('TestParser', 'test.log')

        argv = ['--config', str(config_path), 'socket']
        parsed = self.parser.parse_args(argv)

        self.assertEqual(parsed.timeout, DEFAULT_TIMEOUT)
        self.assertEqual(parsed.gmp_password, '')
        self.assertEqual(parsed.gmp_username, '')
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_unkown_config_file(self):
        self._assert_builtin_defaults(__here__ / 'foo.cfg')

    def test_unkown_config_file_in_unkown_dir(self):
        self._assert_builtin_defaults(__here__ / 'foo' / 'foo.cfg')
164
165
class ParserTestCase(unittest.TestCase):
    """Base class providing a fresh CliParser as ``self.parser``.

    The parser is created with ``ignore_config=True`` so the tests do not
    depend on a config file on disk (assumed semantics of the flag). It also
    gets a fixed ``prog`` name so the usage line does not depend on how the
    test run was started.
    """

    def setUp(self):
        self.parser = CliParser(
            'TestParser',
            'test.log',
            ignore_config=True,
            prog='gvm-test-cli',
        )
171
172
class RootArgumentsParserTest(ParserTestCase):
    """Tests for the options accepted by the root parser.

    Root options have to be passed before the connection sub command.
    Passing them after the sub command makes parsing exit with SystemExit.
    """

    def test_config(self):
        args = self.parser.parse_args(['--config', 'foo.cfg', 'socket'])
        self.assertEqual(args.config, 'foo.cfg')

    def test_defaults(self):
        args = self.parser.parse_args(['socket'])
        self.assertEqual(args.config, '~/.config/gvm-tools.conf')
        self.assertEqual(args.gmp_password, '')
        self.assertEqual(args.gmp_username, '')
        self.assertEqual(args.timeout, 60)
        self.assertIsNone(args.loglevel)

    def test_loglevel(self):
        args = self.parser.parse_args(['--log', 'ERROR', 'socket'])
        self.assertEqual(args.loglevel, 'ERROR')

    def test_loglevel_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--log', 'ERROR'])

    def test_timeout(self):
        args = self.parser.parse_args(['--timeout', '1000', 'socket'])
        self.assertEqual(args.timeout, 1000)

    def test_timeout_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--timeout', '1000'])

    def test_gmp_username(self):
        args = self.parser.parse_args(['--gmp-username', 'foo', 'socket'])
        self.assertEqual(args.gmp_username, 'foo')

    def test_gmp_username_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--gmp-username', 'foo'])

    def test_gmp_password(self):
        args = self.parser.parse_args(['--gmp-password', 'foo', 'socket'])
        self.assertEqual(args.gmp_password, 'foo')

    def test_gmp_password_after_subparser(self):
        with SuppressOutput(suppress_stderr=True):
            with self.assertRaises(SystemExit):
                self.parser.parse_args(['socket', '--gmp-password', 'foo'])

    def test_with_unknown_args(self):
        # unknown arguments are not an error; they are handed back separately
        args, script_args = self.parser.parse_known_args(
            ['--gmp-password', 'foo', 'socket', '--bar', '--bar2']
        )
        self.assertEqual(args.gmp_password, 'foo')
        self.assertEqual(script_args, ['--bar', '--bar2'])

    @patch('gvmtools.parser.logging')
    def test_socket_has_no_timeout(self, _logging_mock):
        # pylint: disable=protected-access
        # Bypass argparse and hand back a parsed timeout of -1 directly. The
        # wrapper has to translate it into "no timeout" (None).
        args_mock = unittest.mock.MagicMock()
        args_mock.timeout = -1
        self.parser._parser.parse_known_args = unittest.mock.MagicMock(
            return_value=(args_mock, unittest.mock.MagicMock())
        )

        args, _ = self.parser.parse_known_args(
            ['socket', '--timeout', '--', '-1']
        )

        self.assertIsNone(args.timeout)

    # Stacked patch decorators are applied bottom-up, so the mock of the
    # last listed decorator is passed as the first argument.
    @patch('gvmtools.parser.logging')
    @patch('gvmtools.parser.argparse.ArgumentParser.print_usage')
    @patch('gvmtools.parser.argparse.ArgumentParser._print_message')
    def test_no_args_provided(
        self, _print_message_mock, _print_usage_mock, _logging_mock
    ):
        # pylint: disable=protected-access
        # _set_defaults is replaced by a mock so the test does not depend on
        # its behaviour; without any arguments parsing has to exit.
        self.parser._set_defaults = unittest.mock.MagicMock()

        self.assertRaises(SystemExit, self.parser.parse_known_args, None)
254
255
class SocketParserTestCase(ParserTestCase):
    """Tests for the ``socket`` sub command and its options."""

    def _parse_socket(self, *options):
        """Parse the ``socket`` sub command followed by ``options``."""
        return self.parser.parse_args(['socket', *options])

    def test_defaults(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.socketpath, DEFAULT_UNIX_SOCKET_PATH)

    def test_connection_type(self):
        parsed = self._parse_socket()
        self.assertEqual(parsed.connection_type, 'socket')

    def test_sockpath(self):
        # --sockpath is accepted as an alternative spelling of --socketpath
        parsed = self._parse_socket('--sockpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')

    def test_socketpath(self):
        parsed = self._parse_socket('--socketpath', 'foo.sock')
        self.assertEqual(parsed.socketpath, 'foo.sock')
272
273
class SshParserTestCase(ParserTestCase):
    """Tests for the ``ssh`` sub command and its options."""

    def _parse_ssh(self, *options):
        """Parse the ``ssh`` sub command followed by ``options``."""
        return self.parser.parse_args(['ssh', *options])

    def test_defaults(self):
        parsed = self._parse_ssh('--hostname=foo')

        self.assertEqual(parsed.port, 22)
        self.assertEqual(parsed.ssh_username, 'gmp')
        self.assertEqual(parsed.ssh_password, 'gmp')

    def test_connection_type(self):
        parsed = self._parse_ssh('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'ssh')

    def test_hostname(self):
        parsed = self._parse_ssh('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_ssh_username(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-username', 'foo')
        self.assertEqual(parsed.ssh_username, 'foo')

    def test_ssh_password(self):
        parsed = self._parse_ssh('--hostname', 'foo', '--ssh-password', 'foo')
        self.assertEqual(parsed.ssh_password, 'foo')
306
307
class TlsParserTestCase(ParserTestCase):
    """Tests for the ``tls`` sub command and its options."""

    def _parse_tls(self, *options):
        """Parse the ``tls`` sub command followed by ``options``."""
        return self.parser.parse_args(['tls', *options])

    def test_defaults(self):
        parsed = self._parse_tls('--hostname=foo')

        # no certificate files are configured unless given explicitly
        self.assertIsNone(parsed.certfile)
        self.assertIsNone(parsed.keyfile)
        self.assertIsNone(parsed.cafile)
        self.assertEqual(parsed.port, 9390)

    def test_connection_type(self):
        parsed = self._parse_tls('--hostname=foo')
        self.assertEqual(parsed.connection_type, 'tls')

    def test_hostname(self):
        parsed = self._parse_tls('--hostname', 'foo')
        self.assertEqual(parsed.hostname, 'foo')

    def test_port(self):
        parsed = self._parse_tls('--hostname', 'foo', '--port', '123')
        self.assertEqual(parsed.port, 123)

    def test_certfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--certfile', 'foo.cert')
        self.assertEqual(parsed.certfile, 'foo.cert')

    def test_keyfile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--keyfile', 'foo.key')
        self.assertEqual(parsed.keyfile, 'foo.key')

    def test_cafile(self):
        parsed = self._parse_tls('--hostname', 'foo', '--cafile', 'foo.ca')
        self.assertEqual(parsed.cafile, 'foo.ca')

    def test_no_credentials(self):
        parsed = self._parse_tls('--hostname', 'foo', '--no-credentials')
        self.assertTrue(parsed.no_credentials)
353
354
class CustomizeParserTestCase(ParserTestCase):
    """Tests for extending the CliParser with script specific arguments."""

    def test_add_optional_argument(self):
        self.parser.add_argument('--foo', type=int)

        # the added option has to be accepted after every connection type
        for argv in (
            ['socket', '--foo', '123'],
            ['ssh', '--hostname', 'bar', '--foo', '123'],
            ['tls', '--hostname', 'bar', '--foo', '123'],
        ):
            self.assertEqual(self.parser.parse_args(argv).foo, 123)

    def test_add_positional_argument(self):
        self.parser.add_argument('foo', type=int)

        parsed = self.parser.parse_args(['socket', '123'])
        self.assertEqual(parsed.foo, 123)

    def test_add_protocol_argument(self):
        self.parser.add_protocol_argument()

        # GMP is used unless another protocol is requested
        default_args = self.parser.parse_args(['socket'])
        self.assertEqual(default_args.protocol, 'GMP')

        osp_args = self.parser.parse_args(['--protocol', 'OSP', 'socket'])
        self.assertEqual(osp_args.protocol, 'OSP')
387
388
class HelpFormattingParserTestCase(ParserTestCase):
    """Snapshot tests for the help output of the root and sub parsers.

    The formatted help is compared against ``<name>.snap`` next to this
    module. A Python-version-specific ``<name>.<major>.<minor>.snap`` takes
    precedence when it exists, because argparse's help output differs
    between Python versions. A missing snapshot is created from the current
    output. On a mismatch the actual output is written to
    ``<name>.<major>.<minor>-failed.snap`` for inspection.
    """

    # pylint: disable=protected-access
    maxDiff = None
    python_version = '.'.join(map(str, sys.version_info[:2]))

    def setUp(self):
        super().setUp()

        # argparse wraps help text according to the terminal width, which
        # can be set through COLUMNS. Pin it so all tests use the same width.
        self.columns = os.environ.get('COLUMNS')
        os.environ['COLUMNS'] = '80'

    def tearDown(self):
        super().tearDown()

        # Restore the original environment. Compare against None so that an
        # originally empty COLUMNS value is restored instead of removed, and
        # use pop() so a variable that is already gone does not raise.
        if self.columns is None:
            os.environ.pop('COLUMNS', None)
        else:
            os.environ['COLUMNS'] = self.columns

    def _snapshot_specific_path(self, name):
        """Return the snapshot path for the running Python version."""
        return __here__ / f'{name}.{self.python_version}.snap'

    def _snapshot_generic_path(self, name):
        """Return the version independent snapshot path."""
        return __here__ / f'{name}.snap'

    def _snapshot_failed_path(self, name):
        """Return the path the output of a failed comparison is saved to."""
        return __here__ / f'{name}.{self.python_version}-failed.snap'

    def _snapshot_path(self, name):
        """Prefer the version specific snapshot, else use the generic one."""
        snapshot_specific_path = self._snapshot_specific_path(name)

        if snapshot_specific_path.exists():
            return snapshot_specific_path

        return self._snapshot_generic_path(name)

    def assert_snapshot(self, name, output):
        """Assert that ``output`` matches the stored snapshot ``name``.

        Creates the snapshot from ``output`` if it does not exist yet. On a
        mismatch the output is saved to the "failed" snapshot path and the
        AssertionError is re-raised.
        """
        path = self._snapshot_path(name)

        # write with the same encoding used for reading below instead of
        # the platform dependent locale default
        if not path.exists():
            path.write_text(output, encoding='utf-8')

        content = path.read_text(encoding='utf-8')

        try:
            self.assertEqual(content, output, 'Snapshot differs from output')
        except AssertionError:
            # keep the actual output around for inspection, then fail
            self._snapshot_failed_path(name).write_text(
                output, encoding='utf-8'
            )
            raise

    def test_root_help(self):
        help_output = self.parser._parser.format_help()
        self.assert_snapshot('root_help', help_output)

    def test_socket_help(self):
        help_output = self.parser._parser_socket.format_help()
        self.assert_snapshot('socket_help', help_output)

    def test_ssh_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_ssh.format_help()
        self.assert_snapshot('ssh_help', help_output)

    def test_tls_help(self):
        self.parser._set_defaults(None)
        help_output = self.parser._parser_tls.format_help()
        self.assert_snapshot('tls_help', help_output)
460
461
class CreateParserFunctionTestCase(unittest.TestCase):
    """Tests for the create_parser() factory function."""

    # pylint: disable=protected-access
    def test_create_parser(self):
        expected_description = 'parser description'
        expected_logfilename = 'logfilename'

        cli_parser = create_parser(expected_description, expected_logfilename)

        # the factory returns a CliParser configured with the given values
        self.assertIsInstance(cli_parser, CliParser)
        self.assertEqual(cli_parser._logfilename, expected_logfilename)
        self.assertEqual(
            cli_parser._bootstrap_parser.description, expected_description
        )
473
474
class CreateConnectionTestCase(unittest.TestCase):
    """Tests that create_connection() builds the matching connection class."""

    def test_create_unix_socket_connection(self):
        self.perform_create_connection_test(
            connection_type='socket', connection_class=UnixSocketConnection
        )

    def test_create_tls_connection(self):
        self.perform_create_connection_test(
            connection_type='tls', connection_class=TLSConnection
        )

    def test_create_ssh_connection(self):
        self.perform_create_connection_test(
            connection_type='ssh', connection_class=SSHConnection, port=22
        )

    def perform_create_connection_test(
        self,
        connection_type='socket',
        connection_class=UnixSocketConnection,
        port=None,
    ):
        """Create a ``connection_type`` connection and check its class."""
        self.assertIsInstance(
            create_connection(connection_type, port=port), connection_class
        )
493