1"""
2Netmiko SCP operations.
3
4Supports file get and file put operations.
5
6SCP requires a separate SSH connection for a control channel.
7
8Currently only supports Cisco IOS and Cisco ASA.
9"""
import hashlib
import os
import platform
import re
import shutil

import scp
16
17
class SCPConn(object):
    """
    Establish a secure copy channel to the remote network device.

    The SCP transfer runs over its own SSH connection, built from the same
    connection parameters as the netmiko connection passed in (which remains
    the SSH control channel).

    Must close the SCP connection to get the file to write to the remote filesystem.
    Instances can also be used as a context manager, which closes the
    connection on exit.
    """

    # Device types for which an EOFError raised during "get" is ignored
    # (presumably the device's SCP server closes the channel abruptly once the
    # transfer is done). Matched as substrings of ssh_ctl_chan.device_type.
    _EOF_TOLERANT_PLATFORMS = ("cisco_ios", "cisco_xe")

    def __init__(self, ssh_conn, socket_timeout=10.0, progress=None, progress4=None):
        """
        Args:
          ssh_conn: netmiko connection used as the control channel; its connection
            parameters are reused to open the SCP connection.
          socket_timeout: socket timeout passed to scp.SCPClient.
          progress: optional progress callback passed to scp.SCPClient.
          progress4: optional 4-argument progress callback passed to scp.SCPClient.
        """
        self.ssh_ctl_chan = ssh_conn
        self.socket_timeout = socket_timeout
        self.progress = progress
        self.progress4 = progress4
        self.establish_scp_conn()

    def __enter__(self):
        """Context manager setup (the connection is already open after __init__)."""
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Context manager cleanup: close the SCP connection."""
        self.close()

    def establish_scp_conn(self):
        """Establish the secure copy connection."""
        ssh_connect_params = self.ssh_ctl_chan._connect_params_dict()
        self.scp_conn = self.ssh_ctl_chan._build_ssh_client()
        self.scp_conn.connect(**ssh_connect_params)
        self.scp_client = scp.SCPClient(
            self.scp_conn.get_transport(),
            socket_timeout=self.socket_timeout,
            progress=self.progress,
            progress4=self.progress4,
        )

    def scp_transfer_file(self, source_file, dest_file):
        """Put file using SCP (for backwards compatibility)."""
        self.scp_client.put(source_file, dest_file)

    def scp_get_file(self, source_file, dest_file):
        """Get file using SCP.

        An EOFError is ignored for device types listed in
        _EOF_TOLERANT_PLATFORMS and re-raised for every other device type.
        """
        # Named device_type (not "platform") to avoid shadowing the stdlib module.
        device_type = self.ssh_ctl_chan.device_type
        try:
            self.scp_client.get(source_file, dest_file)
        except EOFError:
            if not any(name in device_type for name in self._EOF_TOLERANT_PLATFORMS):
                raise

    def scp_put_file(self, source_file, dest_file):
        """Put file using SCP."""
        self.scp_client.put(source_file, dest_file)

    def close(self):
        """Close the SCP channel and the SSH connection it runs over.

        Safe to call more than once.
        """
        self.scp_client.close()
        self.scp_conn.close()
66
67
class BaseFileTransfer(object):
    """Class to manage SCP file transfer and associated SSH control channel."""

    # Substrings in device output meaning the requested remote file/path is missing.
    _FILE_NOT_FOUND_MARKERS = (
        "Error opening",
        "No such file or directory",
        "Path does not exist",
    )

    def __init__(
        self,
        ssh_conn,
        source_file,
        dest_file,
        file_system=None,
        direction="put",
        socket_timeout=10.0,
        progress=None,
        progress4=None,
        hash_supported=True,
    ):
        """
        Args:
          ssh_conn: netmiko connection used as the SSH control channel.
          source_file: file to copy (local for "put", remote for "get").
          dest_file: destination file (remote for "put", local for "get").
          file_system: remote file system; auto-detected for cisco_ios, cisco_xe
            and cisco_xr device types when not given.
          direction: "put" (local -> device) or "get" (device -> local).
          socket_timeout: passed through to SCPConn.
          progress: progress callback passed through to SCPConn.
          progress4: 4-argument progress callback passed through to SCPConn.
          hash_supported: when True, the source MD5 is computed up front.

        Raises:
          ValueError: if file_system is missing and cannot be auto-detected, or
            direction is neither "put" nor "get".
        """
        self.ssh_ctl_chan = ssh_conn
        self.source_file = source_file
        self.dest_file = dest_file
        self.direction = direction
        self.socket_timeout = socket_timeout
        self.progress = progress
        self.progress4 = progress4

        auto_flag = (
            "cisco_ios" in ssh_conn.device_type
            or "cisco_xe" in ssh_conn.device_type
            or "cisco_xr" in ssh_conn.device_type
        )
        if not file_system:
            if auto_flag:
                self.file_system = self.ssh_ctl_chan._autodetect_fs()
            else:
                raise ValueError("Destination file system not specified")
        else:
            self.file_system = file_system

        if direction == "put":
            self.source_md5 = self.file_md5(source_file) if hash_supported else None
            self.file_size = os.stat(source_file).st_size
        elif direction == "get":
            self.source_md5 = (
                self.remote_md5(remote_file=source_file) if hash_supported else None
            )
            self.file_size = self.remote_file_size(remote_file=source_file)
        else:
            raise ValueError("Invalid direction specified")

    def __enter__(self):
        """Context manager setup"""
        self.establish_scp_conn()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        """Context manager cleanup."""
        self.close_scp_chan()

    def establish_scp_conn(self):
        """Establish SCP connection."""
        self.scp_conn = SCPConn(
            self.ssh_ctl_chan,
            socket_timeout=self.socket_timeout,
            progress=self.progress,
            progress4=self.progress4,
        )

    def close_scp_chan(self):
        """Close the SCP connection to the remote network device.

        Safe to call repeatedly, or before a connection was ever established.
        """
        scp_conn = getattr(self, "scp_conn", None)
        if scp_conn is not None:
            scp_conn.close()
        self.scp_conn = None

    def remote_space_available(self, search_pattern=r"(\d+) \w+ free"):
        """Return space available (bytes) on the remote device's file system.

        Runs ``dir <file_system>`` and extracts the first capture group of
        search_pattern. Values reported in kbytes are multiplied by 1000.

        Raises:
          ValueError: if search_pattern does not match the command output.
        """
        remote_cmd = f"dir {self.file_system}"
        remote_output = self.ssh_ctl_chan.send_command_expect(remote_cmd)
        match = re.search(search_pattern, remote_output)
        if not match:
            raise ValueError(
                f"Unable to determine free space from output of {remote_cmd}:\n"
                f"{remote_output}"
            )
        if "kbytes" in match.group(0) or "Kbytes" in match.group(0):
            return int(match.group(1)) * 1000
        return int(match.group(1))

    def _remote_space_available_unix(self, search_pattern=""):
        """Return space available (bytes) on *nix system (BSD/Linux).

        Enters the device shell, runs ``/bin/df -k <file_system>`` and always
        returns to the CLI, even if the command or the parsing fails.

        Expected output:
          Filesystem   1K-blocks  Used   Avail Capacity  Mounted on
          /dev/bo0s3f    1264808 16376 1147248     1%    /cf/var

        Args:
          search_pattern: unused; kept for signature compatibility.

        Raises:
          ValueError: if the df output cannot be parsed.
        """
        remote_cmd = f"/bin/df -k {self.file_system}"
        self.ssh_ctl_chan._enter_shell()
        try:
            remote_output = self.ssh_ctl_chan.send_command(
                remote_cmd, expect_string=r"[\$#]"
            )
        finally:
            self.ssh_ctl_chan._return_cli()

        remote_output = remote_output.strip()
        parse_error = "Parsing error, unexpected output from {}:\n{}".format(
            remote_cmd, remote_output
        )
        output_lines = remote_output.splitlines()
        if len(output_lines) < 2:
            raise ValueError(parse_error)

        # First line is the header; second is the actual file system info
        header_line = output_lines[0]
        header_fields = header_line.split()
        if (
            "Filesystem" not in header_line
            or len(header_fields) < 4
            or "Avail" not in header_fields[3]
        ):
            raise ValueError(parse_error)

        filesystem_fields = output_lines[1].split()
        if len(filesystem_fields) < 4 or not re.search(r"^\d+$", filesystem_fields[3]):
            raise ValueError(parse_error)

        # "df -k" reports sizes in 1K blocks.
        return int(filesystem_fields[3]) * 1024

    def local_space_available(self, path="."):
        """Return free space (bytes) on the local file system containing ``path``.

        Uses shutil.disk_usage (works on both POSIX and Windows). On POSIX,
        statvfs counts f_bavail in units of f_frsize, which can differ from
        f_bsize, so f_bsize * f_bavail is not a reliable free-space figure.

        Args:
          path: any existing path on the file system to check (default: cwd).
        """
        return shutil.disk_usage(path).free

    def verify_space_available(self, search_pattern=r"(\d+) \w+ free"):
        """Verify sufficient space is available on destination file system (return boolean).

        For "put" the remote file system is checked; for "get" the local file
        system holding dest_file's directory is checked.

        Raises:
          ValueError: if direction is neither "put" nor "get".
        """
        if self.direction == "put":
            space_avail = self.remote_space_available(search_pattern=search_pattern)
        elif self.direction == "get":
            dest_dir = os.path.dirname(os.path.abspath(self.dest_file))
            space_avail = self.local_space_available(path=dest_dir)
        else:
            raise ValueError("Invalid direction specified")
        return space_avail > self.file_size

    def check_file_exists(self, remote_cmd=""):
        """Check if the dest_file already exists on the file system (return boolean).

        Raises:
          ValueError: if the remote ``dir`` output is not recognized.
        """
        if self.direction == "put":
            if not remote_cmd:
                remote_cmd = f"dir {self.file_system}/{self.dest_file}"
            remote_out = self.ssh_ctl_chan.send_command_expect(remote_cmd)
            if any(marker in remote_out for marker in self._FILE_NOT_FOUND_MARKERS):
                return False
            # Escape the file name: it is literal text, not a regex.
            search_string = r"Directory of .*{0}".format(re.escape(self.dest_file))
            if re.search(search_string, remote_out, flags=re.DOTALL):
                return True
            raise ValueError("Unexpected output from check_file_exists")
        elif self.direction == "get":
            return os.path.exists(self.dest_file)

    def _check_file_exists_unix(self, remote_cmd=""):
        """Check if the dest_file already exists on the file system (return boolean).

        For "put", lists file_system from the device shell (or runs remote_cmd
        when given) and looks for dest_file in the output. NOTE(review): this is
        a substring test, so e.g. "a.txt" also matches an existing "data.txt".
        """
        if self.direction == "put":
            if not remote_cmd:
                remote_cmd = f"ls {self.file_system}"
            self.ssh_ctl_chan._enter_shell()
            try:
                remote_out = self.ssh_ctl_chan.send_command(
                    remote_cmd, expect_string=r"[\$#]"
                )
            finally:
                self.ssh_ctl_chan._return_cli()
            return self.dest_file in remote_out
        elif self.direction == "get":
            return os.path.exists(self.dest_file)

    def remote_file_size(self, remote_cmd="", remote_file=None):
        """Get the file size (bytes) of the remote file.

        Args:
          remote_cmd: command to run; defaults to ``dir <file_system>/<remote_file>``.
          remote_file: defaults to dest_file for "put" and source_file for "get".

        Raises:
          IOError: if the device reports the file/path does not exist.
          ValueError: if no line of the output mentions the file.
        """
        if remote_file is None:
            if self.direction == "put":
                remote_file = self.dest_file
            elif self.direction == "get":
                remote_file = self.source_file
        if not remote_cmd:
            remote_cmd = f"dir {self.file_system}/{remote_file}"
        remote_out = self.ssh_ctl_chan.send_command(remote_cmd)
        if any(marker in remote_out for marker in self._FILE_NOT_FOUND_MARKERS):
            raise IOError("Unable to find file on remote system")

        # Strip out the "Directory of flash:/filename" line so the search below
        # lands on the file's own listing line.
        remote_out = "".join(re.split(r"Directory of .*", remote_out))
        pattern = r".*({}).*".format(re.escape(remote_file))
        match = re.search(pattern, remote_out)
        if not match:
            raise ValueError(
                "Search pattern not found for remote file size during SCP transfer."
            )
        # Format will be 26  -rw-   6738  Jul 30 2016 19:49:50 -07:00  filename
        return int(match.group(0).split()[2])

    def _remote_file_size_unix(self, remote_cmd="", remote_file=None):
        """Get the file size (bytes) of the remote file from a *nix shell.

        Raises:
          IOError: if ``ls`` reports "No such file or directory".
          ValueError: if no ``ls -l`` line for the file is found.
        """
        if remote_file is None:
            if self.direction == "put":
                remote_file = self.dest_file
            elif self.direction == "get":
                remote_file = self.source_file
        remote_file = f"{self.file_system}/{remote_file}"
        if not remote_cmd:
            remote_cmd = f"ls -l {remote_file}"

        self.ssh_ctl_chan._enter_shell()
        try:
            remote_out = self.ssh_ctl_chan.send_command(
                remote_cmd, expect_string=r"[\$#]"
            )
        finally:
            self.ssh_ctl_chan._return_cli()

        if "No such file or directory" in remote_out:
            raise IOError("Unable to find file on remote system")

        escape_file_name = re.escape(remote_file)
        pattern = r"^.* ({}).*$".format(escape_file_name)
        match = re.search(pattern, remote_out, flags=re.M)
        if match:
            # Format: -rw-r--r--  1 pyclass  wheel  12 Nov  5 19:07 /var/tmp/test3.txt
            line = match.group(0)
            file_size = line.split()[4]
            return int(file_size)

        raise ValueError(
            "Search pattern not found for remote file size during SCP transfer."
        )

    def file_md5(self, file_name, add_newline=False):
        """Compute MD5 hash of file.

        add_newline is needed to support Cisco IOS MD5 calculation which expects the newline in
        the string

        Args:
          file_name: name of file to get md5 digest of
          add_newline: add newline to end of file contents or not

        Returns:
          Hex digest string.
        """
        file_hash = hashlib.md5()
        with open(file_name, "rb") as f:
            for chunk in iter(lambda: f.read(65536), b""):
                file_hash.update(chunk)
        if add_newline:
            # Actually feed the newline into the digest.
            file_hash.update(b"\n")
        return file_hash.hexdigest()

    @staticmethod
    def process_md5(md5_output, pattern=r"=\s+(\S+)"):
        """
        Process the string to retrieve the MD5 hash

        Output from Cisco IOS (ASA is similar)
        .MD5 of flash:file_name Done!
        verify /md5 (flash:file_name) = 410db2a7015eaa42b1fe71f1bf3d59a2

        Raises:
          ValueError: if pattern does not match md5_output.
        """
        match = re.search(pattern, md5_output)
        if match:
            return match.group(1)
        else:
            raise ValueError(f"Invalid output from MD5 command: {md5_output}")

    def compare_md5(self):
        """Compare md5 of file on network device to md5 of local file."""
        if self.direction == "put":
            remote_md5 = self.remote_md5()
            return self.source_md5 == remote_md5
        elif self.direction == "get":
            local_md5 = self.file_md5(self.dest_file)
            return self.source_md5 == local_md5

    def remote_md5(self, base_cmd="verify /md5", remote_file=None):
        """Calculate remote MD5 and returns the hash.

        This command can be CPU intensive on the remote device.
        """
        if remote_file is None:
            if self.direction == "put":
                remote_file = self.dest_file
            elif self.direction == "get":
                remote_file = self.source_file
        remote_md5_cmd = f"{base_cmd} {self.file_system}/{remote_file}"
        dest_md5 = self.ssh_ctl_chan.send_command(remote_md5_cmd, max_loops=1500)
        dest_md5 = self.process_md5(dest_md5)
        return dest_md5

    def transfer_file(self):
        """SCP transfer file."""
        if self.direction == "put":
            self.put_file()
        elif self.direction == "get":
            self.get_file()

    def get_file(self):
        """SCP copy the file from the remote device to local system."""
        source_file = f"{self.file_system}/{self.source_file}"
        self.scp_conn.scp_get_file(source_file, self.dest_file)
        self.scp_conn.close()

    def put_file(self):
        """SCP copy the file from the local system to the remote device."""
        destination = f"{self.file_system}/{self.dest_file}"
        self.scp_conn.scp_transfer_file(self.source_file, destination)
        # Must close the SCP connection to get the file written (flush)
        self.scp_conn.close()

    def verify_file(self):
        """Verify the file has been transferred correctly."""
        return self.compare_md5()

    @staticmethod
    def _as_command_list(cmd, default):
        """Normalize cmd into a command sequence for send_config_set."""
        if cmd is None:
            return [default]
        # A str is iterable in Python 3 but must be treated as ONE command.
        if isinstance(cmd, str) or not hasattr(cmd, "__iter__"):
            return [cmd]
        return cmd

    def enable_scp(self, cmd=None):
        """
        Enable SCP on remote device.

        Defaults to Cisco IOS command. cmd may be a single command string or an
        iterable of commands.
        """
        cmd = self._as_command_list(cmd, "ip scp server enable")
        self.ssh_ctl_chan.send_config_set(cmd)

    def disable_scp(self, cmd=None):
        """
        Disable SCP on remote device.

        Defaults to Cisco IOS command. cmd may be a single command string or an
        iterable of commands.
        """
        cmd = self._as_command_list(cmd, "no ip scp server enable")
        self.ssh_ctl_chan.send_config_set(cmd)
404