1# Copyright (C) 2014-2021 Greenbone Networks GmbH
2#
3# SPDX-License-Identifier: AGPL-3.0-or-later
4#
5# This program is free software: you can redistribute it and/or modify
6# it under the terms of the GNU Affero General Public License as
7# published by the Free Software Foundation, either version 3 of the
8# License, or (at your option) any later version.
9#
10# This program is distributed in the hope that it will be useful,
11# but WITHOUT ANY WARRANTY; without even the implied warranty of
12# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
13# GNU Affero General Public License for more details.
14#
15# You should have received a copy of the GNU Affero General Public License
16# along with this program. If not, see <http://www.gnu.org/licenses/>.
17
18""" Test module for ospd ssh support.
19"""
20
import io
import unittest

from ospd import ospd_ssh
from ospd.ospd_ssh import OSPDaemonSimpleSSH
from .helper import FakeDataManager
26
27
class FakeFile:
    """Minimal stand-in for the stdout file object returned by exec_command."""

    def __init__(self, content):
        self.content = content

    def readlines(self):
        """Return the content as a list of lines, the way a real file does.

        Line terminators are kept on each line and empty content yields an
        empty list. Splitting on newlines instead would strip the terminators
        and turn empty content into a list holding one empty string.
        """
        return io.StringIO(self.content).readlines()
34
35
# Module-level log of the commands passed to FakeSSHClient.exec_command().
# FakeSSHClient.__init__ rebinds it to a fresh empty list, so after a test it
# holds only the commands issued through the most recently created client.
commands = None  # pylint: disable=invalid-name
37
38
class FakeSSHClient:
    """In-memory replacement for paramiko.SSHClient.

    Every command handed to exec_command() is appended to the module-level
    ``commands`` list, which is reset each time a client is constructed.
    """

    def __init__(self):
        global commands  # pylint: disable=global-statement,invalid-name
        commands = []

    def set_missing_host_key_policy(self, policy):
        """Accept and ignore the host key policy."""

    def connect(self, **kwargs):
        """Pretend to connect; all connection arguments are ignored."""

    def exec_command(self, cmd):
        """Record ``cmd`` and return a three-item tuple.

        The tuple mirrors paramiko's (stdin, stdout, stderr) shape, but only
        the middle element is populated: an empty FakeFile.
        """
        commands.append(cmd)
        stdout = FakeFile('')
        return None, stdout, None

    def close(self):
        """Nothing to release."""
56
57
class FakeExceptions:
    """Stand-in for the paramiko.ssh_exception module."""

    class AuthenticationException(Exception):
        """Fake of paramiko.ssh_exception.AuthenticationException.

        This must be a real exception class. A ``None`` placeholder makes any
        ``except`` clause that names it raise TypeError as soon as an
        exception actually reaches that clause.
        """
60
61
class fakeparamiko:  # pylint: disable=invalid-name
    """Stand-in for the paramiko module, swapped in as ospd_ssh.paramiko."""

    # Mirrors the paramiko.ssh_exception submodule.
    ssh_exception = FakeExceptions

    @staticmethod
    def SSHClient(*args):  # pylint: disable=invalid-name
        """Build a new FakeSSHClient, which also resets the command log."""
        client = FakeSSHClient(*args)
        return client

    @staticmethod
    def AutoAddPolicy():  # pylint: disable=invalid-name
        """Return a placeholder host key policy (always None)."""
        return None
72
73
class DummyWrapper(OSPDaemonSimpleSSH):
    """Concrete SSH daemon wired to in-memory test fixtures."""

    def __init__(self, niceness=10):
        super().__init__(niceness=niceness)
        collection = self.scan_collection
        collection.data_manager = FakeDataManager()
        collection.file_storage_dir = '/tmp'
        self.initialized = True

    def check(self):
        """Always succeed the daemon's check() hook."""
        return True

    def exec_scan(self, scan_id: str):
        """Do nothing; the tests call run_command directly instead."""
        return None
86
87
class SSHDaemonTestCase(unittest.TestCase):
    """Tests for OSPDaemonSimpleSSH.run_command against a fake paramiko."""

    HOST = 'host.example.com'
    COMMAND = 'cat /etc/passwd'

    def setUp(self):
        # ospd_ssh.paramiko is module-global state. Remember the current
        # value and restore it after every test, so that installing the fake
        # (or None) here cannot leak into other tests or test modules.
        self.addCleanup(setattr, ospd_ssh, 'paramiko', ospd_ssh.paramiko)
        ospd_ssh.paramiko = fakeparamiko

    def _start_scan(self, options, credentials=None):
        """Create a scan against HOST with the given options and start it.

        Args:
            options: Scan options handed to create_scan (ssh port, timeout
                and, in the option-based modes, the login credentials).
            credentials: Target credentials dict; defaults to an empty dict.

        Returns:
            Tuple of (daemon, scan_id).
        """
        daemon = DummyWrapper(niceness=10)
        scan_id = daemon.create_scan(
            None,
            {
                'target': self.HOST,
                'ports': '80, 443',
                'credentials': {} if credentials is None else credentials,
                'exclude_hosts': '',
                'finished_hosts': '',
                'options': {},
            },
            options,
            '',
        )
        daemon.start_queued_scans()
        return daemon, scan_id

    def test_no_paramiko(self):
        """Constructing the daemon without paramiko raises ImportError."""
        ospd_ssh.paramiko = None

        with self.assertRaises(ImportError):
            OSPDaemonSimpleSSH()

    def test_run_command(self):
        """Credentials given as a combined 'user:password' scan option."""
        daemon, scan_id = self._start_scan(
            dict(port=5, ssh_timeout=15, username_password='dummy:pw')
        )
        res = daemon.run_command(scan_id, self.HOST, self.COMMAND)

        self.assertIsInstance(res, list)
        self.assertEqual(commands, ['nice -n 10 cat /etc/passwd'])

    def test_run_command_legacy_credential(self):
        """Credentials given as separate username/password scan options."""
        daemon, scan_id = self._start_scan(
            dict(port=5, ssh_timeout=15, username='dummy', password='pw')
        )
        res = daemon.run_command(scan_id, self.HOST, self.COMMAND)

        self.assertIsInstance(res, list)
        self.assertEqual(commands, ['nice -n 10 cat /etc/passwd'])

    def test_run_command_new_credential(self):
        """Credentials supplied through the target's 'ssh' credential entry."""
        cred_dict = {
            'ssh': {
                'type': 'up',
                'password': 'mypass',
                'port': '22',
                'username': 'scanuser',
            },
            'smb': {'type': 'up', 'password': 'mypass', 'username': 'smbuser'},
        }

        daemon, scan_id = self._start_scan(
            dict(port=5, ssh_timeout=15), credentials=cred_dict
        )
        res = daemon.run_command(scan_id, self.HOST, self.COMMAND)

        self.assertIsInstance(res, list)
        self.assertEqual(commands, ['nice -n 10 cat /etc/passwd'])

    def test_run_command_no_credential(self):
        """Without any credentials run_command raises ValueError."""
        daemon, scan_id = self._start_scan(dict(port=5, ssh_timeout=15))

        with self.assertRaises(ValueError):
            daemon.run_command(scan_id, self.HOST, self.COMMAND)
196