1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6
7import os
8from azure.cli.core.profiles import ResourceType
9from datetime import datetime
10
11
def collect_blobs(blob_service, container, pattern=None):
    """
    Return the names of the blobs in the given container whose path matches the given pattern.
    """
    matches = collect_blob_objects(blob_service, container, pattern)
    return [blob_name for blob_name, _ in matches]
17
18
def collect_blob_objects(blob_service, container, pattern=None):
    """
    Yield (blob name, blob) pairs from the given blob container, filtering blobs by comparing
    their path to the given pattern.

    :param blob_service: data-plane blob service client; must be truthy.
    :param container: name of the container to enumerate; must be truthy.
    :param pattern: optional glob pattern; an exact name (no wildcards) is looked up directly.
    :raises ValueError: if blob_service or container is missing.
    """
    if not blob_service:
        raise ValueError('missing parameter blob_service')

    if not container:
        raise ValueError('missing parameter container')

    if not _pattern_has_wildcards(pattern):
        # Exact blob name: a single existence check avoids enumerating the whole container.
        if blob_service.exists(container, pattern):
            yield pattern, blob_service.get_blob_properties(container, pattern)
    else:
        for blob in blob_service.list_blobs(container):
            # NOTE: removed the Python 2 `unicode` -> utf-8 encode shim; on Python 3
            # (the only supported runtime) it merely raised and swallowed a NameError
            # for every blob in the container.
            if not pattern or _match_path(blob.name, pattern):
                yield blob.name, blob
42
43
def collect_files(cmd, file_service, share, pattern=None):
    """
    Search files in the given file share recursively, filtering them by matching their path
    to the given pattern. Returns an iterable of tuple (dir, name), or a one-element list
    containing the pattern itself when it names a single file exactly.
    """
    if not file_service:
        raise ValueError('missing parameter file_service')

    if not share:
        raise ValueError('missing parameter share')

    if _pattern_has_wildcards(pattern):
        return glob_files_remotely(cmd, file_service, share, pattern)

    # No wildcards: the pattern already identifies the one file requested.
    return [pattern]
59
60
def create_blob_service_from_storage_client(cmd, client):
    """Build a BlockBlobService that reuses the credentials of an existing storage client."""
    blob_service_cls = cmd.get_models('blob#BlockBlobService')
    credentials = {
        'account_name': client.account_name,
        'account_key': client.account_key,
        'sas_token': client.sas_token,
    }
    return blob_service_cls(**credentials)
66
67
def create_file_share_from_storage_client(cmd, client):
    """Build a FileService that reuses the credentials of an existing storage client."""
    file_service_cls = cmd.get_models('file.fileservice#FileService')
    credentials = {
        'account_name': client.account_name,
        'account_key': client.account_key,
        'sas_token': client.sas_token,
    }
    return file_service_cls(**credentials)
73
74
def filter_none(iterable):
    """Return an iterator over the items of *iterable* that are not None."""
    return filter(lambda item: item is not None, iterable)
77
78
def glob_files_locally(folder_path, pattern):
    """Walk a local folder and yield (full path, relative path) for files matching *pattern*.

    A falsy pattern matches every file. The pattern is interpreted relative to *folder_path*.
    """
    full_pattern = os.path.join(folder_path, pattern.lstrip('/')) if pattern else None

    # +1 strips the path separator that follows the folder prefix.
    prefix_length = len(folder_path) + 1
    for dirpath, _, filenames in os.walk(folder_path):
        for filename in filenames:
            absolute = os.path.join(dirpath, filename)
            if full_pattern is None or _match_path(absolute, full_pattern):
                yield (absolute, absolute[prefix_length:])
90
91
def glob_files_remotely(cmd, client, share_name, pattern, snapshot=None):
    """Breadth-first search of a remote file share, yielding (directory, file name) pairs
    whose joined path matches *pattern* (a falsy pattern matches everything)."""
    from collections import deque
    t_dir, t_file = cmd.get_models('file.models#Directory', 'file.models#File')

    # FIFO traversal: directories are visited in the order they are discovered.
    pending = deque([""])
    while pending:
        directory = pending.popleft()
        entries = client.list_directories_and_files(share_name, directory, snapshot=snapshot)
        for entry in entries:
            if isinstance(entry, t_file):
                if not pattern or _match_path(os.path.join(directory, entry.name), pattern):
                    yield directory, entry.name
            elif isinstance(entry, t_dir):
                pending.append(os.path.join(directory, entry.name))
106
107
def create_short_lived_blob_sas(cmd, account_name, account_key, container, blob):
    """Generate a read-only, HTTPS-only SAS token for a single blob, expiring in one day.

    :param cmd: CLI command context used to resolve SDK model classes.
    :param account_name: storage account name.
    :param account_key: storage account key used to sign the SAS.
    :param container: container holding the blob.
    :param blob: name of the blob to grant access to.
    """
    from datetime import timedelta
    if cmd.supported_api_version(min_api='2017-04-17'):
        t_sas = cmd.get_models('blob.sharedaccesssignature#BlobSharedAccessSignature')
    else:
        # Fixed model path: was 'shareaccesssignature#...' (missing 'd'), which does not
        # match the module name used by every sibling helper in this file.
        t_sas = cmd.get_models('sharedaccesssignature#SharedAccessSignature')

    t_blob_permissions = cmd.get_models('blob.models#BlobPermissions')
    expiry = (datetime.utcnow() + timedelta(days=1)).strftime('%Y-%m-%dT%H:%M:%SZ')
    sas = t_sas(account_name, account_key)
    return sas.generate_blob(container, blob, permission=t_blob_permissions(read=True), expiry=expiry, protocol='https')
119
120
def create_short_lived_blob_sas_v2(cmd, account_name, account_key, container, blob):
    """Generate a read-only, HTTPS-only SAS token for a single blob, expiring in one day
    (track-2 data-plane SDK variant)."""
    from datetime import timedelta

    sas_cls = cmd.get_models('_shared_access_signature#BlobSharedAccessSignature',
                             resource_type=ResourceType.DATA_STORAGE_BLOB)
    permissions_cls = cmd.get_models('_models#BlobSasPermissions',
                                     resource_type=ResourceType.DATA_STORAGE_BLOB)

    expiry_time = datetime.utcnow() + timedelta(days=1)
    generator = sas_cls(account_name, account_key)
    return generator.generate_blob(container, blob,
                                   permission=permissions_cls(read=True),
                                   expiry=expiry_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
                                   protocol='https')
131
132
def create_short_lived_file_sas(cmd, account_name, account_key, share, directory_name, file_name):
    """Generate a read-only, HTTPS-only SAS token for a single file, expiring in one day."""
    from datetime import timedelta

    if cmd.supported_api_version(min_api='2017-04-17'):
        sas_cls = cmd.get_models('file.sharedaccesssignature#FileSharedAccessSignature')
    else:
        sas_cls = cmd.get_models('sharedaccesssignature#SharedAccessSignature')
    permissions_cls = cmd.get_models('file.models#FilePermissions')

    expiry_time = datetime.utcnow() + timedelta(days=1)
    generator = sas_cls(account_name, account_key)
    return generator.generate_file(share,
                                   # an empty directory string means the share root -> None
                                   directory_name=directory_name or None,
                                   file_name=file_name,
                                   permission=permissions_cls(read=True),
                                   expiry=expiry_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
                                   protocol='https')
147
148
def create_short_lived_container_sas(cmd, account_name, account_key, container):
    """Generate a read-only, HTTPS-only SAS token for a whole container, expiring in one day."""
    from datetime import timedelta

    if cmd.supported_api_version(min_api='2017-04-17'):
        sas_cls = cmd.get_models('blob.sharedaccesssignature#BlobSharedAccessSignature')
    else:
        sas_cls = cmd.get_models('sharedaccesssignature#SharedAccessSignature')
    permissions_cls = cmd.get_models('blob.models#BlobPermissions')

    expiry_time = datetime.utcnow() + timedelta(days=1)
    return sas_cls(account_name, account_key).generate_container(
        container,
        permission=permissions_cls(read=True),
        expiry=expiry_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
        protocol='https')
160
161
def create_short_lived_share_sas(cmd, account_name, account_key, share):
    """Generate a read-only, HTTPS-only SAS token for a whole file share, expiring in one day."""
    from datetime import timedelta

    if cmd.supported_api_version(min_api='2017-04-17'):
        sas_cls = cmd.get_models('file.sharedaccesssignature#FileSharedAccessSignature')
    else:
        sas_cls = cmd.get_models('sharedaccesssignature#SharedAccessSignature')
    permissions_cls = cmd.get_models('file.models#FilePermissions')

    expiry_time = datetime.utcnow() + timedelta(days=1)
    return sas_cls(account_name, account_key).generate_share(
        share,
        permission=permissions_cls(read=True),
        expiry=expiry_time.strftime('%Y-%m-%dT%H:%M:%SZ'),
        protocol='https')
173
174
175def mkdir_p(path):
176    import errno
177    try:
178        os.makedirs(path)
179    except OSError as exc:  # Python <= 2.5
180        if exc.errno == errno.EEXIST and os.path.isdir(path):
181            pass
182        else:
183            raise
184
185
186def _pattern_has_wildcards(p):
187    return not p or p.find('*') != -1 or p.find('?') != -1 or p.find('[') != -1
188
189
190def _match_path(path, pattern):
191    from fnmatch import fnmatch
192    return fnmatch(path, pattern)
193
194
def guess_content_type(file_path, original, settings_class):
    """Return content settings for *file_path*, guessing the MIME type from its extension.

    If *original* already carries an explicit content type or encoding, it is returned
    unchanged; otherwise a new *settings_class* instance is built that copies every other
    field from *original*.
    """
    if original.content_encoding or original.content_type:
        # The caller supplied explicit values; never override them.
        return original

    import mimetypes
    # Register types that some platforms' mime databases are missing.
    mimetypes.add_type('application/json', '.json')
    mimetypes.add_type('application/javascript', '.js')
    mimetypes.add_type('application/wasm', '.wasm')

    guessed_type, _ = mimetypes.guess_type(file_path)
    return settings_class(content_type=guessed_type,
                          content_encoding=original.content_encoding,
                          content_disposition=original.content_disposition,
                          content_language=original.content_language,
                          content_md5=original.content_md5,
                          cache_control=original.cache_control)
212
213
def get_storage_client(cli_ctx, service_type, namespace):
    """Create a storage data-service client of *service_type*, resolving credentials from
    the parsed *namespace* first and falling back to the CLI's [storage] config section."""
    from azure.cli.command_modules.storage._client_factory import get_storage_data_service_client

    config = cli_ctx.config

    def _resolve(attr, config_key):
        # Prefer the command-line value; fall back to the persisted configuration.
        return getattr(namespace, attr, config.get('storage', config_key, None))

    name = _resolve('account_name', 'account')
    key = _resolve('account_key', 'key')
    connection_string = _resolve('connection_string', 'connection_string')
    sas_token = _resolve('sas_token', 'sas_token')

    return get_storage_data_service_client(cli_ctx, service_type, name, key, connection_string, sas_token)
225
226
def normalize_blob_file_path(path, name):
    """Join *path* and *name* and normalize the result to the '/'-separated form that
    blob and file services use, with no leading or trailing separator."""
    separator = '/'
    combined = separator.join((path, name)) if path else name
    # Normalize with the OS rules, then translate the OS separator back to '/'.
    segments = os.path.normpath(combined).split(os.path.sep)
    return separator.join(segments).strip(separator)
233
234
def check_precondition_success(func):
    """Decorator: call *func* and report precondition failures instead of raising.

    Returns a wrapper that yields ``(True, result)`` on success, or ``(False, None)``
    when the service answered 304 Not Modified or 412 Precondition Failed. Any other
    AzureHttpError is re-raised.
    """
    import functools

    @functools.wraps(func)  # preserve func's name/docstring on the wrapper
    def wrapper(*args, **kwargs):
        from azure.common import AzureHttpError
        try:
            return True, func(*args, **kwargs)
        except AzureHttpError as ex:
            # Precondition failed error
            # https://developer.mozilla.org/docs/Web/HTTP/Status/412
            # Not modified error
            # https://developer.mozilla.org/docs/Web/HTTP/Status/304
            if ex.status_code not in [304, 412]:
                raise
            return False, None
    return wrapper
249
250
def get_datetime_from_string(dt_str):
    """Parse *dt_str* as a datetime, trying each accepted ISO-like format in turn.

    :raises ValueError: if no accepted format matches.
    """
    formats = ('%Y-%m-%dT%H:%M:%SZ', '%Y-%m-%dT%H:%MZ', '%Y-%m-%dT%HZ', '%Y-%m-%d')
    for fmt in formats:
        try:
            return datetime.strptime(dt_str, fmt)
        except ValueError:
            pass  # try the next, less precise format
    raise ValueError("datetime string '{}' not valid. Valid example: 2000-12-31T12:59:59Z".format(dt_str))
260