1import requests
2import platform
3from numbers import Number
4import xml.etree.cElementTree as xml
5from collections import namedtuple
6from http.client import responses as HTTP_CODES
7from urllib.parse import urlparse
8
# Chunk size handed to ``iter_content`` when writing a download to disk (1 MiB).
DOWNLOAD_CHUNK_SIZE_BYTES = 1024 ** 2
10
class WebdavException(Exception):
    """Common base class for the WebDAV errors defined in this module."""
13
class ConnectionFailed(WebdavException):
    """WebDAV error subclass, presumably reserved for connection failures."""
16
17
def codestr(code):
    """Return the standard reason phrase for HTTP status *code*.

    Codes missing from ``http.client.responses`` map to ``'UNKNOWN'``.
    """
    try:
        return HTTP_CODES[code]
    except KeyError:
        return 'UNKNOWN'
20
21
# One directory-listing entry: href, content length, last-modified date,
# creation date and content type, as extracted from a PROPFIND response.
File = namedtuple('File', 'name size mtime ctime contenttype')
23
24
def prop(elem, name, default=None):
    """Return the text of the first ``{DAV:}<name>`` descendant of *elem* that has text.

    A multistatus reply can mention the same property more than once. Per
    RFC 4918, properties the server does not have are listed as empty
    elements inside a non-200 ``propstat``. Matches without text are therefore
    skipped instead of being returned as ``None``.

    :param elem: element to search (descendants at any depth are considered).
    :param name: local name of the property in the ``DAV:`` namespace.
    :param default: value returned when no matching element carries text.
    :returns: the element text, or *default*.
    """
    for child in elem.iterfind('.//{DAV:}' + name):
        if child.text is not None:
            return child.text
    return default
28
29
def elem2file(elem):
    """Build a ``File`` from one ``{DAV:}response`` element of a PROPFIND reply.

    ``name`` is the raw ``href`` text (``None`` if absent). ``size`` is the
    integer ``getcontentlength``, or ``0`` when missing or empty. The date and
    content-type fields default to ``''``.
    """
    return File(
        prop(elem, 'href'),
        # ``or 0``: an empty <getcontentlength/> can make prop() return None
        # or '' instead of the default, and int() would reject either.
        int(prop(elem, 'getcontentlength', 0) or 0),
        prop(elem, 'getlastmodified', ''),
        prop(elem, 'creationdate', ''),
        prop(elem, 'getcontenttype', ''),
    )
38
39
class OperationFailed(WebdavException):
    """Raised when a WebDAV request returns an unexpected HTTP status.

    Attributes:
        method: HTTP method of the failed request (e.g. ``'PUT'``).
        path: the path the request was made for.
        expected_code: the accepted status code, or a collection of codes.
        actual_code: the status code the server actually returned.
        reason: one-line summary, e.g. ``Failed to upload "a.txt"``.
    """

    # Human-readable verb for each HTTP method, used to build ``reason``.
    _OPERATIONS = dict(
        HEAD = "get header",
        GET = "download",
        PUT = "upload",
        DELETE = "delete",
        MKCOL = "create directory",
        PROPFIND = "list directory",
        )

    def __init__(self, method, path, expected_code, actual_code):
        self.method = method
        self.path = path
        self.expected_code = expected_code
        self.actual_code = actual_code
        # Fall back to the raw method name so an unmapped method (e.g. MOVE)
        # still yields a readable error rather than a KeyError that would
        # hide the real HTTP failure.
        operation_name = self._OPERATIONS.get(method, method)
        self.reason = 'Failed to {operation_name} "{path}"'.format(
            operation_name=operation_name, path=path)
        # Accept either a single code or an iterable of codes.
        expected_codes = (expected_code,) if isinstance(expected_code, Number) else expected_code
        expected_codes_str = ", ".join('{0} {1}'.format(code, codestr(code)) for code in expected_codes)
        actual_code_str = codestr(actual_code)
        msg = '''\
{self.reason}.
  Operation     :  {method} {path}
  Expected code :  {expected_codes_str}
  Actual code   :  {actual_code} {actual_code_str}'''.format(
            self=self,
            method=method,
            path=path,
            expected_codes_str=expected_codes_str,
            actual_code=actual_code,
            actual_code_str=actual_code_str,
        )
        super(OperationFailed, self).__init__(msg)
66
class Client(object):
    """Small WebDAV client built on a ``requests`` session.

    Paths given to the public methods are either absolute (starting with
    ``/``, resolved against ``baseurl``) or relative to ``cwd``. ``cwd`` is
    changed with :meth:`cd` and always starts and ends with ``/``.
    """

    def __init__(self, host, port=0, auth=None, username=None, password=None,
                 protocol='http', verify_ssl=True, path=None, cert=None):
        """Create a client for ``protocol://host:port[/path]``.

        :param host: server host name or address.
        :param port: server port; a falsy value selects 443 for ``https``
            and 80 otherwise.
        :param auth: ``requests``-style auth object; takes precedence over
            ``username``/``password``.
        :param username: basic-auth user name (used only when it and
            ``password`` are both truthy and ``auth`` is not given).
        :param password: basic-auth password.
        :param protocol: URL scheme, e.g. ``'http'`` or ``'https'``.
        :param verify_ssl: value assigned to ``session.verify``.
        :param path: optional path prefix on the server. Leading and trailing
            slashes are ignored.
        :param cert: optional value assigned to ``session.cert``.
        """
        if not port:
            port = 443 if protocol == 'https' else 80
        self.baseurl = '{0}://{1}:{2}'.format(protocol, host, port)
        if path:
            # Strip slashes so "/dav/" does not produce "...://h:80//dav/",
            # and later "...//dav//x" when an absolute path is appended.
            path = str(path).strip('/')
            if path:
                self.baseurl = '{0}/{1}'.format(self.baseurl, path)
        self.cwd = '/'
        self.session = requests.session()
        self.session.verify = verify_ssl
        self.session.stream = True

        if cert:
            self.session.cert = cert

        if auth:
            self.session.auth = auth
        elif username and password:
            self.session.auth = (username, password)

    def _send(self, method, path, expected_code, **kwargs):
        """Send ``method`` for ``path`` and check the response status.

        :param expected_code: a single acceptable status code or a
            collection of them.
        :param kwargs: forwarded to ``session.request``.
        :returns: the response (redirects are not followed).
        :raises OperationFailed: if the status code is not acceptable.
        """
        url = self._get_url(path)
        response = self.session.request(method, url, allow_redirects=False, **kwargs)
        expected_codes = (expected_code,) if isinstance(expected_code, Number) else expected_code
        if response.status_code not in expected_codes:
            # The session streams responses, so release the connection
            # explicitly rather than leaving the body unread.
            response.close()
            raise OperationFailed(method, path, expected_code, response.status_code)
        return response

    def _get_url(self, path):
        """Return the absolute URL for ``path``.

        Absolute paths are appended to ``baseurl``; relative ones to
        ``baseurl + cwd``. Surrounding whitespace is stripped.
        """
        path = str(path).strip()
        if path.startswith('/'):
            return self.baseurl + path
        return "".join((self.baseurl, self.cwd, path))

    def cd(self, path):
        """Change the working directory used to resolve relative paths.

        An absolute ``path`` replaces ``cwd``; a relative one is applied on
        top of it. Empty and ``.`` components are ignored. ``..`` moves up
        one level but never above ``/``. Blank input is a no-op.
        """
        path = path.strip()
        if not path:
            return
        parts = [] if path.startswith('/') else [p for p in self.cwd.split('/') if p]
        for part in path.split('/'):
            if not part or part == '.':
                continue
            if part == '..':
                if parts:
                    parts.pop()
            else:
                parts.append(part)
        self.cwd = '/' + ''.join(p + '/' for p in parts)

    def mkdir(self, path, safe=False):
        """Create the collection ``path`` (MKCOL, expecting 201).

        With ``safe=True``, 301 and 405 are accepted as well.
        """
        expected_codes = 201 if not safe else (201, 301, 405)
        self._send('MKCOL', path, expected_codes)

    def mkdirs(self, path):
        """Create ``path`` one component at a time, including parents.

        Each level is created with ``mkdir(safe=True)`` and then entered
        with :meth:`cd`. ``cwd`` is restored afterwards, even on error.

        :raises OperationFailed: if a level fails with 409 Conflict. Other
            unexpected statuses are ignored (best effort).
        """
        parts = [d for d in path.split('/') if d]
        if not parts:
            return
        if path.startswith('/'):
            # Only the first component needs to be absolute; the rest are
            # created relative to it after cd().
            parts[0] = '/' + parts[0]
        old_cwd = self.cwd
        try:
            for part in parts:
                try:
                    self.mkdir(part, safe=True)
                except OperationFailed as e:
                    if e.actual_code == 409:
                        raise
                finally:
                    self.cd(part)
        finally:
            self.cd(old_cwd)

    def rmdir(self, path, safe=False):
        """Delete the collection ``path`` (expects 204).

        A single trailing slash is enforced. With ``safe=True``, 404 is also
        accepted.
        """
        path = str(path).rstrip('/') + '/'
        expected_codes = 204 if not safe else (204, 404)
        self._send('DELETE', path, expected_codes)

    def delete(self, path):
        """Delete the resource ``path`` (expects 204)."""
        self._send('DELETE', path, 204)

    def upload(self, local_path_or_fileobj, remote_path):
        """PUT a local file to ``remote_path``.

        :param local_path_or_fileobj: a ``str`` path (opened in binary mode)
            or an already-open file-like object.
        """
        if isinstance(local_path_or_fileobj, str):
            with open(local_path_or_fileobj, 'rb') as f:
                self._upload(f, remote_path)
        else:
            self._upload(local_path_or_fileobj, remote_path)

    def _upload(self, fileobj, remote_path):
        """Send ``fileobj`` as the PUT body; 200, 201 and 204 count as success."""
        self._send('PUT', remote_path, (200, 201, 204), data=fileobj)

    def download(self, remote_path, local_path_or_fileobj):
        """GET ``remote_path`` (expects 200) and write the body locally.

        :param local_path_or_fileobj: a ``str`` path (opened with ``'wb'``)
            or a writable file-like object.
        """
        response = self._send('GET', remote_path, 200, stream=True)
        try:
            if isinstance(local_path_or_fileobj, str):
                with open(local_path_or_fileobj, 'wb') as f:
                    self._download(f, response)
            else:
                self._download(local_path_or_fileobj, response)
        finally:
            # Release the streamed connection even if open()/write() fails.
            response.close()

    def _download(self, fileobj, response):
        """Copy the response body into ``fileobj`` in fixed-size chunks."""
        for chunk in response.iter_content(DOWNLOAD_CHUNK_SIZE_BYTES):
            fileobj.write(chunk)

    def ls(self, remote_path='.'):
        """List ``remote_path`` with a Depth-1 PROPFIND.

        A 301 reply is followed via its ``Location`` header.

        :returns: a list of ``File`` entries, one per ``{DAV:}response``.
        """
        headers = {'Depth': '1'}
        response = self._send('PROPFIND', remote_path, (207, 301), headers=headers)

        # Redirect
        if response.status_code == 301:
            response.close()
            location = urlparse(response.headers['location']).path
            # The Location path is server-absolute and already contains the
            # path prefix of ``baseurl``. Strip it so _get_url() does not
            # prepend it a second time.
            base_path = urlparse(self.baseurl).path.rstrip('/')
            if base_path and (location == base_path or location.startswith(base_path + '/')):
                location = location[len(base_path):] or '/'
            return self.ls(location)

        tree = xml.fromstring(response.content)
        return [elem2file(elem) for elem in tree.findall('{DAV:}response')]

    def exists(self, remote_path):
        """Return whether a HEAD on ``remote_path`` does not yield 404.

        200 and 301 both count as existing; any other status raises
        ``OperationFailed``.
        """
        response = self._send('HEAD', remote_path, (200, 301, 404))
        return response.status_code != 404
182