1# --------------------------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for license information.
4# --------------------------------------------------------------------------------------------
5
6
7import os
8import json
9import platform
10import subprocess
11import datetime
12import sys
13import zipfile
14import stat
15from six.moves.urllib.parse import urlparse
16from six.moves.urllib.request import urlopen  # pylint: disable=import-error
17from azure.cli.core._profile import Profile
18from knack.log import get_logger
19from knack.util import CLIError
20
# Module logger; run_command logs each azcopy invocation through it.
logger = get_logger(__name__)


# AAD resource requested via Profile.get_raw_token and echoed into the azcopy token payload.
STORAGE_RESOURCE_ENDPOINT = "https://storage.azure.com"
# Service names accepted by storage_client_auth_for_azcopy.
SERVICES = {'blob', 'file'}
# Pinned azcopy version: downloaded by AzCopy.install_azcopy and compared with `azcopy --version`.
# NOTE(review): the download URL in install_azcopy also hard-codes a release folder
# (release20201211); presumably it must be bumped together with this value — verify.
AZCOPY_VERSION = '10.8.0'
27
28
class AzCopy:
    """Wrapper that locates, installs if needed, and runs the azcopy executable.

    Constructing an instance resolves the default install location and downloads
    azcopy ``AZCOPY_VERSION`` there when the file is missing or reports a
    different version.
    """

    def __init__(self, creds=None):
        """
        :param creds: Optional AzCopyCredentials. When its ``token_info`` is set,
            ``run_command`` exports it to azcopy via ``AZCOPY_OAUTH_TOKEN_INFO``.
        :raises CLIError: If the platform is unsupported or the download fails.
        """
        self.system = platform.system()
        install_location = _get_default_install_location()
        self.executable = install_location
        self.creds = creds
        # Reinstall when the binary is absent or is not the pinned version.
        if not os.path.isfile(install_location) or self.check_version() != AZCOPY_VERSION:
            self.install_azcopy(install_location)

    def install_azcopy(self, install_location):
        """Download the pinned azcopy release and place the executable at ``install_location``.

        :param install_location: Target file path; its parent directory is created if needed.
        :raises CLIError: On unsupported platforms, or when downloading/writing raises IOError.
        """
        install_dir = os.path.dirname(install_location)
        if not os.path.exists(install_dir):
            os.makedirs(install_dir)
        base_url = 'https://azcopyvnext.azureedge.net/release20201211/azcopy_{}_{}_{}.{}'

        # NOTE(review): Linux and Darwin always fetch the amd64 build regardless of
        # platform.machine(); on other architectures the binary may not run.
        if self.system == 'Windows':
            if platform.machine().endswith('64'):
                file_url = base_url.format('windows', 'amd64', AZCOPY_VERSION, 'zip')
            else:
                file_url = base_url.format('windows', '386', AZCOPY_VERSION, 'zip')
        elif self.system == 'Linux':
            file_url = base_url.format('linux', 'amd64', AZCOPY_VERSION, 'tar.gz')
        elif self.system == 'Darwin':
            file_url = base_url.format('darwin', 'amd64', AZCOPY_VERSION, 'zip')
        elif self.system == 'FreeBSD':
            raise CLIError('Azcopy ({}) binary not available, follow instructions at https://wiki.freebsd.org/Ports/sysutils/py-azure-cli'.format(self.system))
        else:
            raise CLIError('Azcopy ({}) does not exist.'.format(self.system))
        try:
            # Only the owner needs write access to drop the binary in place; also
            # granting group/other write would let other local accounts replace it.
            os.chmod(install_dir, os.stat(install_dir).st_mode | stat.S_IWUSR)
            _urlretrieve(file_url, install_location)
            # Mark the downloaded file as executable.
            os.chmod(install_location,
                     os.stat(install_location).st_mode | stat.S_IXUSR | stat.S_IXGRP | stat.S_IXOTH)
        except IOError as err:
            raise CLIError('Connection error while attempting to download azcopy {}. You could also install the '
                           'specified azcopy version to {} manually. ({})'.format(AZCOPY_VERSION, install_dir, err)) from err

    def check_version(self):
        """Return the version reported by ``azcopy --version``, or "" if it cannot be determined."""
        try:
            out_bytes = subprocess.check_output([self.executable, "--version"])
        except (subprocess.CalledProcessError, OSError):
            # A non-zero exit, or a file that cannot be executed at all (missing exec
            # bit, wrong architecture), counts as "unknown version" so the caller
            # reinstalls instead of crashing.
            return ""
        return self._parse_version(out_bytes.decode('utf-8', errors='replace'))

    @staticmethod
    def _parse_version(out_text):
        """Extract the version from ``azcopy --version`` output; "" when no version line is present."""
        import re
        # '.' does not match '\n', so this captures the rest of the first version line;
        # strip() drops a trailing '\r' from CRLF output.
        match = re.search(r"azcopy version (.+)", out_text)
        return match.group(1).strip() if match else ""

    def run_command(self, args):
        """Run azcopy with ``args`` and raise CLIError if it does not exit with status 0.

        :param args: azcopy arguments, starting with the subcommand name (e.g. ``copy``).
        """
        args = [self.executable] + args
        # NOTE(review): args may contain URLs carrying SAS query strings; consider
        # redacting them before logging at warning level.
        logger.warning("Azcopy command: %s", args)
        env_kwargs = {}
        if self.creds and self.creds.token_info:
            env_kwargs = {'AZCOPY_OAUTH_TOKEN_INFO': json.dumps(self.creds.token_info)}
        result = subprocess.call(args, env=dict(os.environ, **env_kwargs))
        # On POSIX a child killed by a signal yields a negative code, so any
        # non-zero value is a failure.
        if result != 0:
            raise CLIError('Failed to perform {} operation.'.format(args[1]))

    def copy(self, source, destination, flags=None):
        """Run ``azcopy copy source destination [flags...]``."""
        flags = flags or []
        self.run_command(['copy', source, destination] + flags)

    def remove(self, target, flags=None):
        """Run ``azcopy remove target [flags...]``."""
        flags = flags or []
        self.run_command(['remove', target] + flags)

    def sync(self, source, destination, flags=None):
        """Run ``azcopy sync source destination [flags...]``."""
        flags = flags or []
        self.run_command(['sync', source, destination] + flags)
99
100
class AzCopyCredentials:  # pylint: disable=too-few-public-methods
    """Holder for the credential azcopy should use.

    At most one of the two fields is expected to be populated by the helpers in
    this module: a SAS token string, or an OAuth token-info dict that
    ``AzCopy.run_command`` serialises into ``AZCOPY_OAUTH_TOKEN_INFO``.
    """

    def __init__(self, sas_token=None, token_info=None):
        # OAuth token payload (dict) or None.
        self.token_info = token_info
        # SAS token string or None.
        self.sas_token = sas_token
105
106
def login_auth_for_azcopy(cmd):
    """Build OAuth AzCopyCredentials from the signed-in CLI account's storage token.

    :param cmd: CLI command object; its ``cli_ctx`` is used to construct the Profile.
    :returns: AzCopyCredentials whose ``token_info`` is the azcopy-shaped payload from
        ``_unserialize_non_msi_token_payload``.
    :raises ValidationError: If the raw token lacks a field the conversion needs
        (the original comment attributes this to MSI token payloads).
    """
    token_info = Profile(cli_ctx=cmd.cli_ctx).get_raw_token(resource=STORAGE_RESOURCE_ENDPOINT)[0][2]
    try:
        token_info = _unserialize_non_msi_token_payload(token_info)
    except KeyError as ex:  # unserialized MSI token payload
        # Use the same CLI error type as client_auth_for_azcopy instead of a bare
        # Exception, and keep the KeyError as the explicit cause.
        from azure.cli.core.azclierror import ValidationError
        raise ValidationError('MSI auth not yet supported.') from ex
    return AzCopyCredentials(token_info=token_info)
114
115
def client_auth_for_azcopy(cmd, client, service='blob'):
    """Derive azcopy credentials from a storage SDK client.

    A SAS token on the client takes precedence. Otherwise, if the client was built
    with a token credential, the CLI account's storage OAuth token is converted for
    azcopy. If neither applies, no credentials are returned.

    :param cmd: CLI command object; its ``cli_ctx`` is used for the Profile lookup.
    :param client: Storage SDK client; ``sas_token`` and ``token_credential`` are read.
    :param service: Storage service name, validated against ``SERVICES``.
    :returns: AzCopyCredentials, or None.
    :raises ValidationError: If the OAuth token payload is missing a required field.
    """
    sas_creds = storage_client_auth_for_azcopy(client, service)
    if sas_creds is not None:
        return sas_creds

    if not client.token_credential:
        return None

    # OAuth mode.
    raw_token = Profile(cli_ctx=cmd.cli_ctx).get_raw_token(resource=STORAGE_RESOURCE_ENDPOINT)[0][2]
    try:
        converted = _unserialize_non_msi_token_payload(raw_token)
    except KeyError as missing_key:  # token payload lacks a field azcopy needs
        from azure.cli.core.azclierror import ValidationError
        raise ValidationError('No {}. MSI auth and service principal are not yet supported.'.format(missing_key))
    return AzCopyCredentials(token_info=converted)
132
133
def storage_client_auth_for_azcopy(client, service):
    """Return SAS-based AzCopyCredentials when the client carries a SAS token.

    :param client: Storage SDK client; only its ``sas_token`` attribute is read.
    :param service: Storage service name; must be a member of ``SERVICES``.
    :returns: AzCopyCredentials wrapping the SAS token, or None if the client has none.
    :raises Exception: If ``service`` is not in ``SERVICES``.
    """
    if service in SERVICES:
        return AzCopyCredentials(sas_token=client.sas_token) if client.sas_token else None
    raise Exception('{} not one of: {}'.format(service, str(SERVICES)))
141
142
def _unserialize_non_msi_token_payload(token_info):
    """Convert a CLI raw token entry into the dict azcopy reads from ``AZCOPY_OAUTH_TOKEN_INFO``.

    :param token_info: Raw token dict; the keys read are ``_authority``, ``accessToken``,
        ``refreshToken``, ``expiresIn``, ``expiresOn``, ``tokenType`` and ``_clientId``.
    :returns: dict with azcopy's snake_case token fields, values stringified where azcopy
        expects strings.
    :raises KeyError: If a required key is missing; callers use this to detect
        unsupported token kinds.
    """
    import jwt  # pylint: disable=import-error

    parsed_authority = urlparse(token_info['_authority'])
    # The signature is not verified: the token is only inspected for its 'nbf' claim.
    decode = jwt.decode(token_info['accessToken'], algorithms=['RS256'], options={"verify_signature": False})
    return {
        'access_token': token_info['accessToken'],
        'refresh_token': token_info['refreshToken'],
        'expires_in': str(token_info['expiresIn']),
        'not_before': str(decode['nbf']),
        # .timestamp() interprets the naive datetime as local time.
        'expires_on': str(int(_parse_expires_on(token_info['expiresOn']).timestamp())),
        'resource': STORAGE_RESOURCE_ENDPOINT,
        'token_type': token_info['tokenType'],
        '_tenant': parsed_authority.path.strip('/'),
        '_client_id': token_info['_clientId'],
        '_ad_endpoint': '{uri.scheme}://{uri.netloc}'.format(uri=parsed_authority)
    }


def _parse_expires_on(expires_on):
    """Parse an ``expiresOn`` string into a naive datetime, with or without fractional seconds."""
    # str(datetime) omits the '.%f' part when microsecond == 0, so accept both forms.
    try:
        return datetime.datetime.strptime(expires_on, "%Y-%m-%d %H:%M:%S.%f")
    except ValueError:
        return datetime.datetime.strptime(expires_on, "%Y-%m-%d %H:%M:%S")
161
162
def _generate_sas_token(cmd, account_name, account_key, service, resource_types='sco', permissions='rwdlacup'):
    """Create an account SAS token that expires one day from now (UTC).

    :param cmd: CLI command object; ``cli_ctx`` and ``loader`` are used.
    :param account_name: Storage account name.
    :param account_key: Storage account key used to sign the SAS.
    :param service: Service name; only its first character is passed to ``services_type``.
    :param resource_types: Resource-type letters passed to ``resource_type_type``.
    :param permissions: Permission letters used to build the SDK AccountPermissions.
    :returns: Whatever ``generate_shared_access_signature`` returns for the account client.
    """
    from .._client_factory import cloud_storage_account_service_factory
    from .._validators import resource_type_type, services_type

    storage_client = cloud_storage_account_service_factory(
        cmd.cli_ctx, {'account_name': account_name, 'account_key': account_key})
    account_permissions_cls = cmd.loader.get_sdk('common.models#AccountPermissions')

    sas_services = services_type(cmd.loader)(service[0])
    sas_resource_types = resource_type_type(cmd.loader)(resource_types)
    sas_permissions = account_permissions_cls(_str=permissions)
    expiry = datetime.datetime.utcnow() + datetime.timedelta(days=1)
    return storage_client.generate_shared_access_signature(
        sas_services, sas_resource_types, sas_permissions, expiry)
180
181
182def _get_default_install_location():
183    system = platform.system()
184    if system == 'Windows':
185        home_dir = os.environ.get('USERPROFILE')
186        if not home_dir:
187            raise CLIError('In the Windows platform, please specify the environment variable "USERPROFILE" '
188                           'as the installation location.')
189        install_location = os.path.join(home_dir, r'.azcopy\azcopy.exe')
190    elif system in ('Linux', 'Darwin', 'FreeBSD'):
191        install_location = os.path.expanduser(os.path.join('~', 'bin/azcopy'))
192    else:
193        raise CLIError('The {} platform is not currently supported. If you want to know which platforms are supported, '
194                       'please refer to the document for supported platforms: '
195                       'https://docs.microsoft.com/en-us/azure/storage/common/storage-use-azcopy-v10#download-azcopy'
196                       .format(system))
197    return install_location
198
199
200def _urlretrieve(url, install_location):
201    import io
202    req = urlopen(url)
203    compressedFile = io.BytesIO(req.read())
204    if url.endswith('zip'):
205        if sys.version_info.major >= 3:
206            zip_file = zipfile.ZipFile(compressedFile)
207        else:
208            # If Python version is 2.X, use StringIO instead.
209            import StringIO  # pylint: disable=import-error
210            zip_file = zipfile.ZipFile(StringIO.StringIO(req.read()))
211        for fileName in zip_file.namelist():
212            if fileName.endswith('azcopy') or fileName.endswith('azcopy.exe'):
213                with open(install_location, 'wb') as f:
214                    f.write(zip_file.read(fileName))
215    elif url.endswith('gz'):
216        import tarfile
217        with tarfile.open(fileobj=compressedFile, mode="r:gz") as tar:
218            for tarinfo in tar:
219                if tarinfo.isfile() and tarinfo.name.endswith('azcopy'):
220                    with open(install_location, 'wb') as f:
221                        f.write(tar.extractfile(tarinfo).read())
222    else:
223        raise CLIError('Invalid downloading url {}'.format(url))
224