1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft.  All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# --------------------------------------------------------------------------
15
16import base64
17import hashlib
18import hmac
19import sys
20from io import (SEEK_SET)
21
22from dateutil.tz import tzutc
23
24from ._error import (
25    _ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM,
26    _ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM,
27)
28from .models import (
29    _unicode_type,
30)
31
32if sys.version_info < (3,):
33    def _str(value):
34        if isinstance(value, unicode):
35            return value.encode('utf-8')
36
37        return str(value)
38else:
39    _str = str
40
41
def _to_str(value):
    """Convert *value* to a string via ``_str``, passing ``None`` through.

    :param value: Any value, or ``None``.
    :return: ``_str(value)``, or ``None`` when *value* is ``None``.
    """
    if value is None:
        return None
    return _str(value)
44
45
46def _int_to_str(value):
47    return str(int(value)) if value is not None else None
48
49
50def _bool_to_str(value):
51    if value is None:
52        return None
53
54    if isinstance(value, bool):
55        if value:
56            return 'true'
57        else:
58            return 'false'
59
60    return str(value)
61
62
63def _to_utc_datetime(value):
64    return value.strftime('%Y-%m-%dT%H:%M:%SZ')
65
66
67def _datetime_to_utc_string(value):
68    # Azure expects the date value passed in to be UTC.
69    # Azure will always return values as UTC.
70    # If a date is passed in without timezone info, it is assumed to be UTC.
71    if value is None:
72        return None
73
74    if value.tzinfo:
75        value = value.astimezone(tzutc())
76
77    return value.strftime('%a, %d %b %Y %H:%M:%S GMT')
78
79
def _encode_base64(data):
    """Base64-encode *data* and return the encoding as text.

    :param data: Bytes, or text (an ``_unicode_type`` instance), which is
        UTF-8 encoded before base64 encoding.
    :return: The base64 output decoded as a UTF-8 string.
    """
    raw = data.encode('utf-8') if isinstance(data, _unicode_type) else data
    return base64.b64encode(raw).decode('utf-8')
85
86
def _decode_base64_to_bytes(data):
    """Decode base64 *data* into raw bytes.

    :param data: Base64 as bytes, or as text (an ``_unicode_type`` instance),
        which is UTF-8 encoded before decoding.
    :return: The decoded bytes.
    """
    encoded = data
    if isinstance(encoded, _unicode_type):
        encoded = encoded.encode('utf-8')
    return base64.b64decode(encoded)
91
92
def _decode_base64_to_text(data):
    """Decode base64 *data* and interpret the resulting bytes as UTF-8 text.

    :param data: Base64 input accepted by ``_decode_base64_to_bytes``.
    :return: The decoded text.
    """
    return _decode_base64_to_bytes(data).decode('utf-8')
96
97
def _sign_string(key, string_to_sign, key_is_base64=True):
    """Compute a base64-encoded HMAC-SHA256 signature of *string_to_sign*.

    :param key: The signing key, base64-encoded when *key_is_base64* is true;
        otherwise raw bytes or text (text is UTF-8 encoded).
    :param string_to_sign: The message, as bytes or text (text is UTF-8
        encoded).
    :param key_is_base64: Whether *key* must be base64-decoded first.
    :return: The HMAC-SHA256 digest, base64-encoded as text.
    """
    if key_is_base64:
        secret = _decode_base64_to_bytes(key)
    elif isinstance(key, _unicode_type):
        secret = key.encode('utf-8')
    else:
        secret = key

    message = string_to_sign
    if isinstance(message, _unicode_type):
        message = message.encode('utf-8')

    mac = hmac.new(secret, message, hashlib.sha256)
    return _encode_base64(mac.digest())
110
111
112def _get_content_md5(data):
113    md5 = hashlib.md5()
114    if isinstance(data, bytes):
115        md5.update(data)
116    elif hasattr(data, 'read'):
117        pos = 0
118        try:
119            pos = data.tell()
120        except:
121            pass
122        for chunk in iter(lambda: data.read(4096), b""):
123            md5.update(chunk)
124        try:
125            data.seek(pos, SEEK_SET)
126        except (AttributeError, IOError):
127            raise ValueError(_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM.format('data'))
128    else:
129        raise ValueError(_ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM.format('data'))
130
131    return base64.b64encode(md5.digest()).decode('utf-8')
132
133
134def _lower(text):
135    return text.lower()
136