1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft.  All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# --------------------------------------------------------------------------
15import sys
16import uuid
17from datetime import date
18from io import (BytesIO, IOBase, SEEK_SET, SEEK_END, UnsupportedOperation)
19from os import fstat
20from time import time
21from wsgiref.handlers import format_date_time
22
23from dateutil.tz import tzutc
24
25if sys.version_info >= (3,):
26    from urllib.parse import quote as url_quote
27else:
28    from urllib2 import quote as url_quote
29
30try:
31    from xml.etree import cElementTree as ETree
32except ImportError:
33    from xml.etree import ElementTree as ETree
34
35from ._error import (
36    _ERROR_VALUE_SHOULD_BE_BYTES,
37    _ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM,
38    _ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM
39)
40from ._constants import (
41    X_MS_VERSION,
42    USER_AGENT_STRING,
43)
44from .models import (
45    _unicode_type,
46)
47from ._common_conversion import (
48    _str,
49)
50
51
52def _to_utc_datetime(value):
53    # Azure expects the date value passed in to be UTC.
54    # Azure will always return values as UTC.
55    # If a date is passed in without timezone info, it is assumed to be UTC.
56    if value.tzinfo:
57        value = value.astimezone(tzutc())
58    return value.strftime('%Y-%m-%dT%H:%M:%SZ')
59
60
def _update_request(request):
    '''
    Prepares ``request`` in place just before it is sent.

    * A non-empty ``request.body`` is validated (bytes, or an object with a
      ``read`` method) and, for PUT/POST/MERGE/DELETE, its length is written
      to the ``Content-Length`` header.
    * The ``x-ms-version``, ``User-Agent`` and a freshly generated
      ``x-ms-client-request-id`` header are added.
    * If ``request.host`` carries a path component (e.g. local storage), that
      component is moved to the front of ``request.path``.
    * ``request.path`` is URL-quoted.

    :raises TypeError: if the body is neither bytes nor a readable stream.
    :raises ValueError: if the body length cannot be determined, which in
        practice only happens for a non-seekable stream.
    '''
    if request.body:
        request.body = _get_data_bytes_or_stream_only('request.body', request.body)
        body_length = _len_plus(request.body)
        if body_length is None:
            raise ValueError(_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM)
        # Only these verbs get an explicit Content-Length header.
        if request.method in ('PUT', 'POST', 'MERGE', 'DELETE'):
            request.headers['Content-Length'] = str(body_length)

    # Service-wide headers sent with every request.
    headers = request.headers
    headers['x-ms-version'] = X_MS_VERSION
    headers['User-Agent'] = USER_AGENT_STRING
    headers['x-ms-client-request-id'] = str(uuid.uuid1())

    # "host/extra" -> host="host", path="/extra" + original path.
    host, separator, host_path = request.host.partition('/')
    if separator:
        request.host = host
        request.path = '/{}{}'.format(host_path, request.path)

    request.path = url_quote(request.path, '/()$=\',~')
88
89
90def _add_metadata_headers(metadata, request):
91    if metadata:
92        if not request.headers:
93            request.headers = {}
94        for name, value in metadata.items():
95            request.headers['x-ms-meta-' + name] = value
96
97
98def _add_date_header(request):
99    current_time = format_date_time(time())
100    request.headers['x-ms-date'] = current_time
101
102
103def _get_data_bytes_only(param_name, param_value):
104    '''Validates the request body passed in and converts it to bytes
105    if our policy allows it.'''
106    if param_value is None:
107        return b''
108
109    if isinstance(param_value, bytes):
110        return param_value
111
112    raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))
113
114
115def _get_data_bytes_or_stream_only(param_name, param_value):
116    '''Validates the request body passed in is a stream/file-like or bytes
117    object.'''
118    if param_value is None:
119        return b''
120
121    if isinstance(param_value, bytes) or hasattr(param_value, 'read'):
122        return param_value
123
124    raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM.format(param_name))
125
126
127def _get_request_body(request_body):
128    '''Converts an object into a request body.  If it's None
129    we'll return an empty string, if it's one of our objects it'll
130    convert it to XML and return it.  Otherwise we just use the object
131    directly'''
132    if request_body is None:
133        return b''
134
135    if isinstance(request_body, bytes) or isinstance(request_body, IOBase):
136        return request_body
137
138    if isinstance(request_body, _unicode_type):
139        return request_body.encode('utf-8')
140
141    request_body = str(request_body)
142    if isinstance(request_body, _unicode_type):
143        return request_body.encode('utf-8')
144
145    return request_body
146
147
148def _convert_signed_identifiers_to_xml(signed_identifiers):
149    if signed_identifiers is None:
150        return ''
151
152    sis = ETree.Element('SignedIdentifiers')
153    for id, access_policy in signed_identifiers.items():
154        # Root signed identifers element
155        si = ETree.SubElement(sis, 'SignedIdentifier')
156
157        # Id element
158        ETree.SubElement(si, 'Id').text = id
159
160        # Access policy element
161        policy = ETree.SubElement(si, 'AccessPolicy')
162
163        if access_policy.start:
164            start = access_policy.start
165            if isinstance(access_policy.start, date):
166                start = _to_utc_datetime(start)
167            ETree.SubElement(policy, 'Start').text = start
168
169        if access_policy.expiry:
170            expiry = access_policy.expiry
171            if isinstance(access_policy.expiry, date):
172                expiry = _to_utc_datetime(expiry)
173            ETree.SubElement(policy, 'Expiry').text = expiry
174
175        if access_policy.permission:
176            ETree.SubElement(policy, 'Permission').text = _str(access_policy.permission)
177
178    # Add xml declaration and serialize
179    try:
180        stream = BytesIO()
181        ETree.ElementTree(sis).write(stream, xml_declaration=True, encoding='utf-8', method='xml')
182    except:
183        raise
184    finally:
185        output = stream.getvalue()
186        stream.close()
187
188    return output
189
190
191def _convert_service_properties_to_xml(logging, hour_metrics, minute_metrics, cors, target_version=None):
192    '''
193    <?xml version="1.0" encoding="utf-8"?>
194    <StorageServiceProperties>
195        <Logging>
196            <Version>version-number</Version>
197            <Delete>true|false</Delete>
198            <Read>true|false</Read>
199            <Write>true|false</Write>
200            <RetentionPolicy>
201                <Enabled>true|false</Enabled>
202                <Days>number-of-days</Days>
203            </RetentionPolicy>
204        </Logging>
205        <HourMetrics>
206            <Version>version-number</Version>
207            <Enabled>true|false</Enabled>
208            <IncludeAPIs>true|false</IncludeAPIs>
209            <RetentionPolicy>
210                <Enabled>true|false</Enabled>
211                <Days>number-of-days</Days>
212            </RetentionPolicy>
213        </HourMetrics>
214        <MinuteMetrics>
215            <Version>version-number</Version>
216            <Enabled>true|false</Enabled>
217            <IncludeAPIs>true|false</IncludeAPIs>
218            <RetentionPolicy>
219                <Enabled>true|false</Enabled>
220                <Days>number-of-days</Days>
221            </RetentionPolicy>
222        </MinuteMetrics>
223        <Cors>
224            <CorsRule>
225                <AllowedOrigins>comma-separated-list-of-allowed-origins</AllowedOrigins>
226                <AllowedMethods>comma-separated-list-of-HTTP-verb</AllowedMethods>
227                <MaxAgeInSeconds>max-caching-age-in-seconds</MaxAgeInSeconds>
228                <ExposedHeaders>comma-seperated-list-of-response-headers</ExposedHeaders>
229                <AllowedHeaders>comma-seperated-list-of-request-headers</AllowedHeaders>
230            </CorsRule>
231        </Cors>
232    </StorageServiceProperties>
233    '''
234    service_properties_element = ETree.Element('StorageServiceProperties')
235
236    # Logging
237    if logging:
238        logging_element = ETree.SubElement(service_properties_element, 'Logging')
239        ETree.SubElement(logging_element, 'Version').text = logging.version
240        ETree.SubElement(logging_element, 'Delete').text = str(logging.delete)
241        ETree.SubElement(logging_element, 'Read').text = str(logging.read)
242        ETree.SubElement(logging_element, 'Write').text = str(logging.write)
243
244        retention_element = ETree.SubElement(logging_element, 'RetentionPolicy')
245        _convert_retention_policy_to_xml(logging.retention_policy, retention_element)
246
247    # HourMetrics
248    if hour_metrics:
249        hour_metrics_element = ETree.SubElement(service_properties_element, 'HourMetrics')
250        _convert_metrics_to_xml(hour_metrics, hour_metrics_element)
251
252    # MinuteMetrics
253    if minute_metrics:
254        minute_metrics_element = ETree.SubElement(service_properties_element, 'MinuteMetrics')
255        _convert_metrics_to_xml(minute_metrics, minute_metrics_element)
256
257    # CORS
258    # Make sure to still serialize empty list
259    if cors is not None:
260        cors_element = ETree.SubElement(service_properties_element, 'Cors')
261        for rule in cors:
262            cors_rule = ETree.SubElement(cors_element, 'CorsRule')
263            ETree.SubElement(cors_rule, 'AllowedOrigins').text = ",".join(rule.allowed_origins)
264            ETree.SubElement(cors_rule, 'AllowedMethods').text = ",".join(rule.allowed_methods)
265            ETree.SubElement(cors_rule, 'MaxAgeInSeconds').text = str(rule.max_age_in_seconds)
266            ETree.SubElement(cors_rule, 'ExposedHeaders').text = ",".join(rule.exposed_headers)
267            ETree.SubElement(cors_rule, 'AllowedHeaders').text = ",".join(rule.allowed_headers)
268
269    # Target version
270    if target_version:
271        ETree.SubElement(service_properties_element, 'DefaultServiceVersion').text = target_version
272
273    # Add xml declaration and serialize
274    try:
275        stream = BytesIO()
276        ETree.ElementTree(service_properties_element).write(stream, xml_declaration=True, encoding='utf-8',
277                                                            method='xml')
278    except:
279        raise
280    finally:
281        output = stream.getvalue()
282        stream.close()
283
284    return output
285
286
def _convert_metrics_to_xml(metrics, root):
    '''
    Appends the child elements describing ``metrics`` to ``root``:

    <Version>version-number</Version>
    <Enabled>true|false</Enabled>
    <IncludeAPIs>true|false</IncludeAPIs>
    <RetentionPolicy>
        <Enabled>true|false</Enabled>
        <Days>number-of-days</Days>
    </RetentionPolicy>

    ``IncludeAPIs`` is written only when ``metrics.enabled`` is truthy and
    ``metrics.include_apis`` is not ``None``.
    '''
    add_child = ETree.SubElement

    add_child(root, 'Version').text = metrics.version
    add_child(root, 'Enabled').text = str(metrics.enabled)

    if metrics.enabled and metrics.include_apis is not None:
        add_child(root, 'IncludeAPIs').text = str(metrics.include_apis)

    retention_root = add_child(root, 'RetentionPolicy')
    _convert_retention_policy_to_xml(metrics.retention_policy, retention_root)
310
311
312def _convert_retention_policy_to_xml(retention_policy, root):
313    '''
314    <Enabled>true|false</Enabled>
315    <Days>number-of-days</Days>
316    '''
317    # Enabled
318    ETree.SubElement(root, 'Enabled').text = str(retention_policy.enabled)
319
320    # Days
321    if retention_policy.enabled and retention_policy.days:
322        ETree.SubElement(root, 'Days').text = str(retention_policy.days)
323
324
325def _len_plus(data):
326    length = None
327    # Check if object implements the __len__ method, covers most input cases such as bytearray.
328    try:
329        length = len(data)
330    except:
331        pass
332
333    if not length:
334        # Check if the stream is a file-like stream object.
335        # If so, calculate the size using the file descriptor.
336        try:
337            fileno = data.fileno()
338        except (AttributeError, UnsupportedOperation):
339            pass
340        else:
341            return fstat(fileno).st_size
342
343        # If the stream is seekable and tell() is implemented, calculate the stream size.
344        try:
345            current_position = data.tell()
346            data.seek(0, SEEK_END)
347            length = data.tell() - current_position
348            data.seek(current_position, SEEK_SET)
349        except (AttributeError, UnsupportedOperation):
350            pass
351
352    return length
353