# --------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
# --------------------------------------------------------------------------------------------

"""
Commands for storage file share operations
"""

import os
from knack.log import get_logger

from azure.cli.command_modules.storage.util import (filter_none, collect_blobs, collect_files,
                                                    create_blob_service_from_storage_client,
                                                    create_short_lived_container_sas, create_short_lived_share_sas,
                                                    guess_content_type)
from azure.cli.command_modules.storage.url_quote_util import encode_for_url, make_encoded_file_url_and_params
from azure.cli.core.profiles import ResourceType


def create_share_rm(cmd, client, resource_group_name, account_name, share_name, metadata=None, share_quota=None,
                    enabled_protocols=None, root_squash=None, access_tier=None):

    return _create_share_rm(cmd, client, resource_group_name, account_name, share_name, metadata=metadata,
                            share_quota=share_quota, enabled_protocols=enabled_protocols, root_squash=root_squash,
                            access_tier=access_tier, snapshot=False)


def snapshot_share_rm(cmd, client, resource_group_name, account_name, share_name, metadata=None, share_quota=None,
                      enabled_protocols=None, root_squash=None, access_tier=None):

    return _create_share_rm(cmd, client, resource_group_name, account_name, share_name, metadata=metadata,
                            share_quota=share_quota, enabled_protocols=enabled_protocols, root_squash=root_squash,
                            access_tier=access_tier, snapshot=True)


def _create_share_rm(cmd, client, resource_group_name, account_name, share_name, metadata=None, share_quota=None,
                     enabled_protocols=None, root_squash=None, access_tier=None, snapshot=None):
    FileShare = cmd.get_models('FileShare', resource_type=ResourceType.MGMT_STORAGE)

    file_share = FileShare()
    expand = None
    if share_quota is not None:
        file_share.share_quota = share_quota
    if enabled_protocols is not None:
        file_share.enabled_protocols = enabled_protocols
    if root_squash is not None:
        file_share.root_squash = root_squash
    if metadata is not None:
        file_share.metadata = metadata
    if access_tier is not None:
        file_share.access_tier = access_tier
    if snapshot:
        expand = 'snapshots'
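        # Per the management API, creating with expand='snapshots' takes a snapshot
        # of the existing share rather than creating a new one; this flag is the
        # only difference between create_share_rm and snapshot_share_rm above.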

    return client.create(resource_group_name=resource_group_name, account_name=account_name, share_name=share_name,
                         file_share=file_share, expand=expand)


def get_stats(client, resource_group_name, account_name, share_name):
    return client.get(resource_group_name=resource_group_name, account_name=account_name, share_name=share_name,
                      expand='stats')


def list_share_rm(client, resource_group_name, account_name, include_deleted=None, include_snapshot=None):
    expand = None
    expand_item = []
    if include_deleted:
        expand_item.append('deleted')
    if include_snapshot:
        expand_item.append('snapshots')
    if expand_item:
        expand = ','.join(expand_item)
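    # e.g. include_deleted=True together with include_snapshot=True yields
    # expand='deleted,snapshots'; with neither flag set, expand stays None and
    # only live shares are listed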
    return client.list(resource_group_name=resource_group_name, account_name=account_name, expand=expand)


def restore_share_rm(cmd, client, resource_group_name, account_name, share_name, deleted_version, restored_name=None):

    # default to restoring under the original share name
    restored_name = restored_name or share_name

    DeletedShare = cmd.get_models('DeletedShare', resource_type=ResourceType.MGMT_STORAGE)
    deleted_share = DeletedShare(deleted_share_name=share_name, deleted_share_version=deleted_version)

    return client.restore(resource_group_name=resource_group_name, account_name=account_name,
                          share_name=restored_name, deleted_share=deleted_share)


def update_share_rm(cmd, instance, metadata=None, share_quota=None, root_squash=None, access_tier=None):
    FileShare = cmd.get_models('FileShare', resource_type=ResourceType.MGMT_STORAGE)

    # build the updated model; any parameter left as None falls back to the
    # share's current value so unrelated properties are not clobbered
    params = FileShare(
        share_quota=share_quota if share_quota is not None else instance.share_quota,
        root_squash=root_squash if root_squash is not None else instance.root_squash,
        metadata=metadata if metadata is not None else instance.metadata,
        enabled_protocols=instance.enabled_protocols,
        access_tier=access_tier if access_tier is not None else instance.access_tier
    )

    return params


def create_share_url(client, share_name, unc=None, protocol=None):
    url = client.make_file_url(share_name, None, '', protocol=protocol).rstrip('/')
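    # For illustration (hypothetical account): make_file_url returns something like
    # 'https://myaccount.file.core.windows.net/myshare'; with unc set, the split
    # below drops the scheme prefix up to the first ':', leaving the UNC-style
    # '//myaccount.file.core.windows.net/myshare'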
    if unc:
        url = ':'.join(url.split(':')[1:])
    return url


def create_file_url(client, share_name, directory_name, file_name, protocol=None):
    return client.make_file_url(
        share_name, directory_name, file_name, protocol=protocol, sas_token=client.sas_token)


def list_share_files(cmd, client, share_name, directory_name=None, timeout=None, exclude_dir=False, snapshot=None,
                     num_results=None, marker=None):
    if cmd.supported_api_version(min_api='2017-04-17'):
        generator = client.list_directories_and_files(
            share_name, directory_name, timeout=timeout, num_results=num_results, marker=marker, snapshot=snapshot)
    else:
        generator = client.list_directories_and_files(
            share_name, directory_name, timeout=timeout, num_results=num_results, marker=marker)

    if exclude_dir:
        t_file_properties = cmd.get_models('file.models#FileProperties')

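        # directory entries carry DirectoryProperties rather than FileProperties,
        # so keeping only FileProperties entries filters directories out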
        return list(f for f in generator if isinstance(f.properties, t_file_properties))

    return generator


def storage_file_upload_batch(cmd, client, destination, source, destination_path=None, pattern=None, dryrun=False,
                              validate_content=False, content_settings=None, max_connections=1, metadata=None,
                              progress_callback=None):
    """ Upload local files to Azure Storage File Share in batch """

    from azure.cli.command_modules.storage.util import glob_files_locally, normalize_blob_file_path

    source_files = list(glob_files_locally(source, pattern))
    logger = get_logger(__name__)
    settings_class = cmd.get_models('file.models#ContentSettings')

    if dryrun:
        logger.info('upload files to file share')
        logger.info('    account %s', client.account_name)
        logger.info('      share %s', destination)
        logger.info('      total %d', len(source_files))
        return [{'File': client.make_file_url(destination, os.path.dirname(dst) or None, os.path.basename(dst)),
                 'Type': guess_content_type(src, content_settings, settings_class).content_type} for src, dst in
                source_files]

    # TODO: Performance improvement
    # 1. Upload files in parallel
    def _upload_action(src, dst):
        dst = normalize_blob_file_path(destination_path, dst)
        dir_name = os.path.dirname(dst)
        file_name = os.path.basename(dst)
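        # e.g. with destination_path='backup' and a matched relative path
        # 'logs/app.log', the normalized dst is 'backup/logs/app.log', giving
        # dir_name='backup/logs' and file_name='app.log' (assuming
        # normalize_blob_file_path simply joins the two with '/')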

        _make_directory_in_files_share(client, destination, dir_name)
        create_file_args = {'share_name': destination, 'directory_name': dir_name, 'file_name': file_name,
                            'local_file_path': src, 'progress_callback': progress_callback,
                            'content_settings': guess_content_type(src, content_settings, settings_class),
                            'metadata': metadata, 'max_connections': max_connections}

        if cmd.supported_api_version(min_api='2016-05-31'):
            create_file_args['validate_content'] = validate_content

        logger.warning('uploading %s', src)
        client.create_file_from_path(**create_file_args)

        return client.make_file_url(destination, dir_name, file_name)

    return list(_upload_action(src, dst) for src, dst in source_files)


def storage_file_download_batch(cmd, client, source, destination, pattern=None, dryrun=False, validate_content=False,
                                max_connections=1, progress_callback=None, snapshot=None):
    """
    Download files from file share to local directory in batch
    """

    from azure.cli.command_modules.storage.util import glob_files_remotely, mkdir_p

    source_files = glob_files_remotely(cmd, client, source, pattern, snapshot=snapshot)

    if dryrun:
        source_files_list = list(source_files)

        logger = get_logger(__name__)
        logger.warning('download files from file share')
        logger.warning('    account %s', client.account_name)
        logger.warning('      share %s', source)
        logger.warning('destination %s', destination)
        logger.warning('    pattern %s', pattern)
        logger.warning('      total %d', len(source_files_list))
        logger.warning(' operations')
        for f in source_files_list:
            logger.warning('  - %s/%s => %s', f[0], f[1], os.path.join(destination, *f))

        return []

    def _download_action(pair):
        destination_dir = os.path.join(destination, pair[0])
        mkdir_p(destination_dir)

        get_file_args = {'share_name': source, 'directory_name': pair[0], 'file_name': pair[1],
                         'file_path': os.path.join(destination, *pair), 'max_connections': max_connections,
                         'progress_callback': progress_callback, 'snapshot': snapshot}

        if cmd.supported_api_version(min_api='2016-05-31'):
            get_file_args['validate_content'] = validate_content

        client.get_file_to_path(**get_file_args)
        return client.make_file_url(source, *pair)

    return list(_download_action(f) for f in source_files)


def storage_file_copy_batch(cmd, client, source_client, destination_share=None, destination_path=None,
                            source_container=None, source_share=None, source_sas=None, pattern=None, dryrun=False,
                            metadata=None, timeout=None):
    """
    Copy a group of files asynchronously
    """
    logger = None
    if dryrun:
        logger = get_logger(__name__)
        logger.warning('copy files or blobs to file share')
        logger.warning('    account %s', client.account_name)
        logger.warning('      share %s', destination_share)
        logger.warning('       path %s', destination_path)
        logger.warning('     source %s', source_container or source_share)
        logger.warning('source type %s', 'blob' if source_container else 'file')
        logger.warning('    pattern %s', pattern)
        logger.warning(' operations')

    if source_container:
        # copy blobs to file share

        # if the source client is None, create one from the destination client
        source_client = source_client or create_blob_service_from_storage_client(cmd, client)

        # cache of directories that already exist in the destination file share; it
        # avoids repeatedly creating directories that are known to exist
        existing_dirs = set()

        if not source_sas:
            source_sas = create_short_lived_container_sas(cmd, source_client.account_name, source_client.account_key,
                                                          source_container)

        # pylint: disable=inconsistent-return-statements
        def action_blob_copy(blob_name):
            if dryrun:
                logger.warning('  - copy blob %s', blob_name)
            else:
                return _create_file_and_directory_from_blob(client, source_client, destination_share, source_container,
                                                            source_sas, blob_name, destination_dir=destination_path,
                                                            metadata=metadata, timeout=timeout,
                                                            existing_dirs=existing_dirs)

        return list(
            filter_none(action_blob_copy(blob) for blob in collect_blobs(source_client, source_container, pattern)))

    if source_share:
        # copy files from share to share

        # if the source client is None, assume the source share is in the same
        # storage account as the destination, so the destination client is reused
        source_client = source_client or client

        # cache of directories that already exist in the destination file share; it
        # avoids repeatedly creating directories that are known to exist
        existing_dirs = set()

        if not source_sas:
            source_sas = create_short_lived_share_sas(cmd, source_client.account_name, source_client.account_key,
                                                      source_share)

        # pylint: disable=inconsistent-return-statements
        def action_file_copy(file_info):
            dir_name, file_name = file_info
            if dryrun:
                logger.warning('  - copy file %s', os.path.join(dir_name, file_name))
            else:
                return _create_file_and_directory_from_file(client, source_client, destination_share, source_share,
                                                            source_sas, dir_name, file_name,
                                                            destination_dir=destination_path, metadata=metadata,
                                                            timeout=timeout, existing_dirs=existing_dirs)

        return list(filter_none(
            action_file_copy(file) for file in collect_files(cmd, source_client, source_share, pattern)))
    # should not happen; the validator ensures either source_container or source_share is set
    raise ValueError('Failed to find a source. Neither a blob container nor a file share is specified.')


def storage_file_delete_batch(cmd, client, source, pattern=None, dryrun=False, timeout=None):
    """
    Delete files from file share in batch
    """

    def delete_action(file_pair):
        delete_file_args = {'share_name': source, 'directory_name': file_pair[0], 'file_name': file_pair[1],
                            'timeout': timeout}

        return client.delete_file(**delete_file_args)

    from azure.cli.command_modules.storage.util import glob_files_remotely
    source_files = list(glob_files_remotely(cmd, client, source, pattern))

    if dryrun:
        logger = get_logger(__name__)
        logger.warning('delete files from %s', source)
        logger.warning('    pattern %s', pattern)
        logger.warning('      share %s', source)
        logger.warning('      total %d', len(source_files))
        logger.warning(' operations')
        for f in source_files:
            logger.warning('  - %s/%s', f[0], f[1])
        return []

    for f in source_files:
        delete_action(f)


def _create_file_and_directory_from_blob(file_service, blob_service, share, container, sas, blob_name,
                                         destination_dir=None, metadata=None, timeout=None, existing_dirs=None):
    """
    Copy a blob to file share and create the directory if needed.
    """
    from azure.common import AzureException
    from azure.cli.command_modules.storage.util import normalize_blob_file_path

    blob_url = blob_service.make_blob_url(container, encode_for_url(blob_name), sas_token=sas)
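    # encode_for_url above percent-encodes characters in the blob name that are
    # not URL-safe, and make_blob_url appends the SAS token as the query string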
    full_path = normalize_blob_file_path(destination_dir, blob_name)
    file_name = os.path.basename(full_path)
    dir_name = os.path.dirname(full_path)
    _make_directory_in_files_share(file_service, share, dir_name, existing_dirs)

    try:
        file_service.copy_file(share, dir_name, file_name, blob_url, metadata, timeout)
        return file_service.make_file_url(share, dir_name, file_name)
    except AzureException:
        error_template = 'Failed to copy blob {} to file share {}. Please check that you have permission to ' \
                         'read the source, or provide a valid SAS token.'
        from knack.util import CLIError
        raise CLIError(error_template.format(blob_name, share))


def _create_file_and_directory_from_file(file_service, source_file_service, share, source_share, sas, source_file_dir,
                                         source_file_name, destination_dir=None, metadata=None, timeout=None,
                                         existing_dirs=None):
    """
    Copy a file from one file share to another
    """
    from azure.common import AzureException
    from azure.cli.command_modules.storage.util import normalize_blob_file_path

    file_url, source_file_dir, source_file_name = make_encoded_file_url_and_params(source_file_service, source_share,
                                                                                   source_file_dir, source_file_name,
                                                                                   sas_token=sas)

    full_path = normalize_blob_file_path(destination_dir, os.path.join(source_file_dir, source_file_name))
    file_name = os.path.basename(full_path)
    dir_name = os.path.dirname(full_path)
    _make_directory_in_files_share(file_service, share, dir_name, existing_dirs)

    try:
        file_service.copy_file(share, dir_name, file_name, file_url, metadata, timeout)
        return file_service.make_file_url(share, dir_name or None, file_name)
    except AzureException:
        error_template = 'Failed to copy file {} from share {} to file share {}. Please check that you have ' \
                         'permission to read the source, or provide a valid SAS token.'
        from knack.util import CLIError
        raise CLIError(error_template.format(file_name, source_share, share))


def _make_directory_in_files_share(file_service, file_share, directory_path, existing_dirs=None):
    """
    Create directories recursively.

    This method accepts an existing_dirs set that serves as a cache of directories
    known to exist. If the parameter is given, the set is consulted first to avoid
    repeatedly creating directories that already exist.
    """
    from azure.common import AzureHttpError

    if not directory_path:
        return

    parents = [directory_path]
    p = os.path.dirname(directory_path)
    while p:
        parents.append(p)
        p = os.path.dirname(p)
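    # e.g. directory_path 'a/b/c' yields parents ['a/b/c', 'a/b', 'a'], which the
    # loop below walks in reverse so directories are created root-first:
    # 'a', then 'a/b', then 'a/b/c'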

    for dir_name in reversed(parents):
        if existing_dirs and (dir_name in existing_dirs):
            continue

        try:
            file_service.create_directory(share_name=file_share, directory_name=dir_name, fail_on_exist=False)
        except AzureHttpError:
            from knack.util import CLIError
            raise CLIError('Failed to create directory {}'.format(dir_name))

        if existing_dirs is not None:
            # cache the directory just created (not only the leaf path), so later
            # calls can skip it
            existing_dirs.add(dir_name)


def _file_share_exists(client, resource_group_name, account_name, share_name):
    from azure.core.exceptions import HttpResponseError
    try:
        file_share = client.get(resource_group_name, account_name, share_name, expand=None)
        return file_share is not None
    except HttpResponseError:
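        # any service error (including 404 Not Found) is treated as
        # "share does not exist"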
        return False