1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6from knack.log import get_logger
7from knack.util import CLIError
8
9from azure.mgmt.datalake.store.models import (
10    UpdateDataLakeStoreAccountParameters,
11    CreateDataLakeStoreAccountParameters,
12    EncryptionConfigType,
13    EncryptionIdentity,
14    EncryptionConfig,
15    EncryptionState,
16    KeyVaultMetaInfo,
17    UpdateEncryptionConfig,
18    UpdateKeyVaultMetaInfo)
19
20from azure.datalake.store.enums import ExpiryOptionType
21from azure.datalake.store.multithread import (ADLUploader, ADLDownloader)
22from azure.cli.command_modules.dls._client_factory import (cf_dls_filesystem)
23from azure.cli.core.commands.client_factory import get_mgmt_service_client
24from azure.cli.core.profiles import ResourceType
25
26
# Module-level logger; used below to warn about ignored user-supplied arguments.
logger = get_logger(__name__)
28
29
def get_update_progress(cli_ctx):
    """Build a ``(current, total)`` progress callback bound to ``cli_ctx``.

    The returned callable reports determinate progress through the CLI's
    progress controller and ends the display once ``current`` equals ``total``.
    Ticks with a falsy ``total`` are not reported.
    """

    def _report(current, total):
        # A determinate controller is requested on every tick, before the total is checked.
        controller = cli_ctx.get_progress_controller(det=True)
        if not total:
            return
        controller.add(message='Alive', value=current, total_val=total)
        if current == total:
            controller.end()

    return _report
39
40
41# region account
def list_adls_account(client, resource_group_name=None):
    """Return Data Lake Store accounts as a list.

    :param client: account management client.
    :param resource_group_name: when truthy, restrict the listing to this group;
        otherwise every account visible to ``client`` is listed.
    """
    if not resource_group_name:
        return list(client.list())
    return list(client.list_by_resource_group(resource_group_name=resource_group_name))
46
47
def create_adls_account(cmd, client, resource_group_name, account_name, location=None, default_group=None, tags=None,
                        encryption_type=EncryptionConfigType.service_managed.value, key_vault_id=None, key_name=None,
                        key_version=None, disable_encryption=False, tier=None):
    """Create a Data Lake Store account and wait for the operation to finish.

    :param cmd: CLI command context; used to resolve the resource group location.
    :param client: account management client exposing ``create``.
    :param resource_group_name: resource group that will contain the account.
    :param account_name: name of the new account.
    :param location: account location; defaults to the resource group's location.
    :param default_group: passed through as ``default_group`` on the create parameters.
    :param tags: tags applied to the account.
    :param encryption_type: an ``EncryptionConfigType`` member or its string value
        (service-managed by default).
    :param key_vault_id: Key Vault resource id (user-managed encryption only).
    :param key_name: Key Vault key name (user-managed encryption only).
    :param key_version: Key Vault key version (user-managed encryption only).
    :param disable_encryption: when true, encryption is disabled and all
        encryption-related arguments are ignored.
    :param tier: commitment tier, sent as ``new_tier``.
    :return: the result of the create operation's poller.
    :raises CLIError: user-managed encryption requested without all three Key Vault values.
    """
    location = location or _get_resource_group_location(cmd.cli_ctx, resource_group_name)
    create_params = CreateDataLakeStoreAccountParameters(
        location=location,
        tags=tags,
        default_group=default_group,
        new_tier=tier)

    if disable_encryption:
        create_params.encryption_state = EncryptionState.disabled
        create_params.identity = None
        create_params.encryption_config = None
    else:
        identity = EncryptionIdentity()
        config = EncryptionConfig(type=encryption_type)
        # Compare plain string values so that both the enum member and its string
        # form (the default above is the string) are recognized, without relying
        # on the SDK enum subclassing ``str``.
        type_value = getattr(encryption_type, 'value', encryption_type)
        if type_value == EncryptionConfigType.user_managed.value:
            if not (key_name and key_vault_id and key_version):
                # Option names use the CLI's hyphenated spelling.
                # pylint: disable=line-too-long
                raise CLIError('For user managed encryption, --key-vault-id, --key-name and --key-version are required parameters and must be supplied.')
            config.key_vault_meta_info = KeyVaultMetaInfo(
                key_vault_resource_id=key_vault_id,
                encryption_key_name=key_name,
                encryption_key_version=key_version)
        elif key_name or key_vault_id or key_version:
            # pylint: disable=line-too-long
            logger.warning('User supplied Key Vault information. For service managed encryption user supplied Key Vault information is ignored.')

        create_params.encryption_config = config
        create_params.identity = identity

    return client.create(resource_group_name, account_name, create_params).result()
83
84
def update_adls_account(client, account_name, resource_group_name, tags=None, default_group=None, firewall_state=None,
                        allow_azure_ips=None, trusted_id_provider_state=None, tier=None, key_version=None):
    """Update settings of an existing Data Lake Store account and wait for completion.

    ``allow_azure_ips`` is sent as ``firewall_allow_azure_ips`` and ``tier`` as
    ``new_tier``. When ``key_version`` is truthy, a key-version rotation is
    requested as well.

    :return: the result of the update operation's poller.
    """
    params = UpdateDataLakeStoreAccountParameters(
        tags=tags,
        default_group=default_group,
        firewall_state=firewall_state,
        firewall_allow_azure_ips=allow_azure_ips,
        trusted_id_provider_state=trusted_id_provider_state,
        new_tier=tier)

    if not key_version:
        return client.update(resource_group_name, account_name, params).result()

    # Key rotation is only meaningful for user-managed encryption; service-managed
    # key rotation is not supported, so the service is expected to reject it.
    rotated_key = UpdateKeyVaultMetaInfo(encryption_key_version=key_version)
    params.encryption_config = UpdateEncryptionConfig(key_vault_meta_info=rotated_key)
    return client.update(resource_group_name, account_name, params).result()
101# endregion
102
103
104# region firewall
def add_adls_firewall_rule(client,
                           account_name,
                           firewall_rule_name,
                           start_ip_address,
                           end_ip_address,
                           resource_group_name):
    """Create or update a firewall rule covering an IP range on an account.

    :return: whatever ``client.create_or_update`` returns.
    """
    rule_args = (resource_group_name, account_name, firewall_rule_name,
                 start_ip_address, end_ip_address)
    return client.create_or_update(*rule_args)
116# endregion
117
118
119# region virtual network
def add_adls_virtual_network_rule(client,
                                  account_name,
                                  virtual_network_rule_name,
                                  subnet,
                                  resource_group_name):
    """Create or update a virtual network rule binding ``subnet`` to an account.

    :return: whatever ``client.create_or_update`` returns.
    """
    rule_args = (resource_group_name, account_name, virtual_network_rule_name, subnet)
    return client.create_or_update(*rule_args)
129# endregion
130
131
132# region filesystem
def get_adls_item(cmd, account_name, path):
    """Return the filesystem ``info`` record for ``path``."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    return filesystem.info(path)
135
136
def list_adls_items(cmd, account_name, path):
    """List the entries under ``path`` with full detail (``ls(..., detail=True)``)."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    return filesystem.ls(path, detail=True)
139
140
def create_adls_item(cmd, account_name, path, content=None, folder=False, force=False):
    """Create a folder, an empty file, or a file holding ``content`` at ``path``.

    :param content: file body; ``str`` values are UTF-8 encoded first. When
        falsy, an empty file is created via ``touch``.
    :param folder: create a folder instead of a file (``content`` is ignored).
    :param force: delete an existing item at ``path`` before creating.
    :raises CLIError: an item already exists at ``path`` and ``force`` is not set.
    """
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)
    if client.exists(path):
        if not force:
            # pylint: disable=line-too-long
            raise CLIError('An item at path: \'{}\' already exists. To overwrite the existing item, specify --force'.format(path))
        # Recurse only when a folder is requested, so creating a file can never
        # silently wipe out an existing folder and its contents.
        client.rm(path, recursive=folder)

    if folder:
        return client.mkdir(path)

    if not content:
        return client.touch(path)

    payload = str.encode(content) if isinstance(content, str) else content
    with client.open(path, mode='wb') as handle:
        return handle.write(payload)
164
165
def append_adls_item(cmd, account_name, path, content):
    """Append ``content`` (``str`` is UTF-8 encoded) to the existing file at ``path``.

    :raises CLIError: no item exists at ``path``.
    """
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)
    if not client.exists(path):
        # pylint: disable=line-too-long
        raise CLIError('File at path: \'{}\' does not exist. Create the file before attempting to append to it.'.format(path))

    with client.open(path, mode='ab') as handle:
        handle.write(str.encode(content) if isinstance(content, str) else content)
176
177
def upload_to_adls(cmd, account_name, source_path, destination_path, chunk_size, buffer_size, block_size,
                   thread_count=None, overwrite=False, progress_callback=None):
    """Upload local ``source_path`` to ``destination_path`` in the account.

    The ``ADLUploader`` instance is discarded, so this relies on its constructor
    performing the transfer (presumably its default run behaviour — confirm).
    When no ``progress_callback`` is given, the CLI progress reporter is used.
    """
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    reporter = progress_callback or get_update_progress(cmd.cli_ctx)
    tuning = dict(chunksize=chunk_size, buffersize=buffer_size, blocksize=block_size)
    ADLUploader(filesystem, destination_path, source_path, thread_count,
                overwrite=overwrite, progress_callback=reporter, **tuning)
191
192
def remove_adls_item(cmd, account_name, path, recurse=False):
    """Delete the item at ``path``; ``recurse`` is forwarded to ``rm`` positionally."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    filesystem.rm(path, recurse)
195
196
def download_from_adls(cmd, account_name, source_path, destination_path, chunk_size, buffer_size, block_size,
                       thread_count=None, overwrite=False, progress_callback=None):
    """Download ``source_path`` from the account to local ``destination_path``.

    The ``ADLDownloader`` instance is discarded, so this relies on its constructor
    performing the transfer (presumably its default run behaviour — confirm).
    When no ``progress_callback`` is given, the CLI progress reporter is used.
    """
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    reporter = progress_callback or get_update_progress(cmd.cli_ctx)
    tuning = dict(chunksize=chunk_size, buffersize=buffer_size, blocksize=block_size)
    ADLDownloader(filesystem, source_path, destination_path, thread_count,
                  overwrite=overwrite, progress_callback=reporter, **tuning)
210
211
def test_adls_item(cmd, account_name, path):
    """Return the filesystem's ``exists`` result for ``path``."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    return filesystem.exists(path)
214
215
def preview_adls_item(cmd, account_name, path, length=None, offset=0, force=False):
    """Read and return a block of bytes from the file at ``path``.

    :param length: number of bytes to read. When missing or non-positive, the
        remainder of the file after ``offset`` is read instead.
    :param offset: byte offset at which reading starts.
    :param force: allow reading more than 1 MiB when ``length`` is not given.
    :return: the result of ``client.read_block(path, offset, length)``.
    :raises CLIError: ``length`` was not given, the remaining data exceeds 1 MiB,
        and ``force`` is not set.
    """
    max_unforced_preview = 1 * 1024 * 1024
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)

    # int() is arbitrary-precision on Python 3 and auto-promotes to long on
    # Python 2, so the former long()/NameError fallback is unnecessary.
    if length:
        length = int(length)
    if offset:
        offset = int(offset)

    if not length or length <= 0:
        length = client.info(path)['length'] - offset
        if length > max_unforced_preview and not force:
            # pylint: disable=line-too-long
            raise CLIError('The remaining data to preview is greater than {} bytes. Please specify a length or use the --force parameter to preview the entire file. The length of the file that would have been previewed: {}'.format(str(max_unforced_preview), str(length)))

    return client.read_block(path, offset, length)
237
238
def join_adls_items(cmd, account_name, source_paths, destination_path, force=False):
    """Concatenate ``source_paths`` into ``destination_path``.

    With ``force``, an existing destination is removed first. Existence is only
    queried when ``force`` is set.
    """
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    clear_destination = force and filesystem.exists(destination_path)
    if clear_destination:
        filesystem.rm(destination_path)
    filesystem.concat(destination_path, source_paths)
245
246
def move_adls_item(cmd, account_name, source_path, destination_path, force=False):
    """Move ``source_path`` to ``destination_path``.

    With ``force``, an existing destination is removed first. Existence is only
    queried when ``force`` is set.
    """
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    clear_destination = force and filesystem.exists(destination_path)
    if clear_destination:
        filesystem.rm(destination_path)
    filesystem.mv(source_path, destination_path)
252
253
def set_adls_item_expiry(cmd, account_name, path, expiration_time):
    """Set an absolute expiry on the file at ``path``.

    :param expiration_time: numeric value or numeric string; fractional values
        are truncated toward zero before being sent.
    :raises CLIError: ``info(path)`` does not report the item as a ``'FILE'``.
    """
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)
    if client.info(path)['type'] != 'FILE':
        # pylint: disable=line-too-long
        raise CLIError('The specified path does not exist or is not a file. Please ensure the path points to a file and it exists. Path supplied: {}'.format(path))

    # Parse via float so inputs like "1500.0" are accepted, then truncate.
    # int() replaces the Python 2-only long()/NameError fallback; it
    # auto-promotes to long on Python 2 and is unbounded on Python 3.
    expiry = int(float(expiration_time))
    client.set_expiry(path, ExpiryOptionType.absolute.value, expiry)
266
267
def remove_adls_item_expiry(cmd, account_name, path):
    """Clear any expiry on the file at ``path`` by setting it to never expire.

    :raises CLIError: ``info(path)`` does not report the item as a ``'FILE'``.
    """
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)
    item_type = client.info(path)['type']
    if item_type == 'FILE':
        client.set_expiry(path, ExpiryOptionType.never_expire.value)
        return
    # pylint: disable=line-too-long
    raise CLIError('The specified path does not exist or is not a file. Please ensure the path points to a file and it exists. Path supplied: {}'.format(path))
275# endregion
276
277
278# region filesystem permissions
def get_adls_item_acl(cmd, account_name, path):
    """Return the ACL status reported by the filesystem for ``path``."""
    return cf_dls_filesystem(cmd.cli_ctx, account_name).get_acl_status(path)
282
283
def remove_adls_item_acl(cmd, account_name, path, default_acl=False):
    """Remove the ACL on ``path``; with ``default_acl``, remove the default ACL instead."""
    client = cf_dls_filesystem(cmd.cli_ctx, account_name)
    remover = client.remove_default_acl if default_acl else client.remove_acl
    remover(path)
290
291
def remove_adls_item_acl_entry(cmd, account_name, path, acl_spec):
    """Remove the ACL entries described by ``acl_spec`` from ``path``."""
    cf_dls_filesystem(cmd.cli_ctx, account_name).remove_acl_entries(path, acl_spec)
295
296
def set_adls_item_acl(cmd, account_name, path, acl_spec):
    """Set the ACL of ``path`` to ``acl_spec`` via ``set_acl``."""
    cf_dls_filesystem(cmd.cli_ctx, account_name).set_acl(path, acl_spec)
300
301
def set_adls_item_acl_entry(cmd, account_name, path, acl_spec):
    """Add or modify the ACL entries in ``acl_spec`` on ``path`` via ``modify_acl_entries``."""
    cf_dls_filesystem(cmd.cli_ctx, account_name).modify_acl_entries(path, acl_spec)
305
306
def set_adls_item_owner(cmd, account_name, path, owner=None, group=None):
    """Change the owner and/or group of ``path`` via ``chown``."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    filesystem.chown(path, owner, group)
309
310
def set_adls_item_permissions(cmd, account_name, path, permission):
    """Apply ``permission`` to ``path`` via ``chmod``."""
    filesystem = cf_dls_filesystem(cmd.cli_ctx, account_name)
    filesystem.chmod(path, permission)
313# endregion
314
315
316# helpers
def _get_resource_group_location(cli_ctx, resource_group_name):
    """Return the ``location`` of the named resource group."""
    resource_client = get_mgmt_service_client(cli_ctx, ResourceType.MGMT_RESOURCE_RESOURCES)
    # pylint: disable=no-member
    group = resource_client.resource_groups.get(resource_group_name)
    return group.location
321