1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft Corporation. All rights reserved.
3# Licensed under the MIT License. See License.txt in the project root for
4# license information.
5# --------------------------------------------------------------------------
import sys
import uuid
from datetime import date
from io import (BytesIO, IOBase, SEEK_SET, SEEK_END, UnsupportedOperation)
from os import fstat
from stat import S_ISREG
from time import time
from wsgiref.handlers import format_date_time
13
14from dateutil.tz import tzutc
15
16if sys.version_info >= (3,):
17    from urllib.parse import quote as url_quote
18else:
19    from urllib2 import quote as url_quote
20
21try:
22    from xml.etree import cElementTree as ETree
23except ImportError:
24    from xml.etree import ElementTree as ETree
25
26from ._error import (
27    _ERROR_VALUE_SHOULD_BE_BYTES,
28    _ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM,
29    _ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM
30)
31from .models import (
32    _unicode_type,
33)
34from ._common_conversion import (
35    _str,
36)
37
38
39def _to_utc_datetime(value):
40    # Azure expects the date value passed in to be UTC.
41    # Azure will always return values as UTC.
42    # If a date is passed in without timezone info, it is assumed to be UTC.
43    if value.tzinfo:
44        value = value.astimezone(tzutc())
45    return value.strftime('%Y-%m-%dT%H:%M:%SZ')
46
47
48def _update_request(request, x_ms_version, user_agent_string):
49    # Verify body
50    if request.body:
51        request.body = _get_data_bytes_or_stream_only('request.body', request.body)
52        length = _len_plus(request.body)
53
54        # only scenario where this case is plausible is if the stream object is not seekable.
55        if length is None:
56            raise ValueError(_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM)
57
58        # if it is PUT, POST, MERGE, DELETE, need to add content-length to header.
59        if request.method in ['PUT', 'POST', 'MERGE', 'DELETE']:
60            request.headers['Content-Length'] = str(length)
61
62    # append addtional headers based on the service
63    request.headers['x-ms-version'] = x_ms_version
64    request.headers['User-Agent'] = user_agent_string
65    request.headers['x-ms-client-request-id'] = str(uuid.uuid1())
66
67    # If the host has a path component (ex local storage), move it
68    path = request.host.split('/', 1)
69    if len(path) == 2:
70        request.host = path[0]
71        request.path = '/{}{}'.format(path[1], request.path)
72
73    # Encode and optionally add local storage prefix to path
74    request.path = url_quote(request.path, '/()$=\',~')
75
76
77def _add_metadata_headers(metadata, request):
78    if metadata:
79        if not request.headers:
80            request.headers = {}
81        for name, value in metadata.items():
82            request.headers['x-ms-meta-' + name] = value
83
84
85def _add_date_header(request):
86    current_time = format_date_time(time())
87    request.headers['x-ms-date'] = current_time
88
89
90def _get_data_bytes_only(param_name, param_value):
91    '''Validates the request body passed in and converts it to bytes
92    if our policy allows it.'''
93    if param_value is None:
94        return b''
95
96    if isinstance(param_value, bytes):
97        return param_value
98
99    raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))
100
101
102def _get_data_bytes_or_stream_only(param_name, param_value):
103    '''Validates the request body passed in is a stream/file-like or bytes
104    object.'''
105    if param_value is None:
106        return b''
107
108    if isinstance(param_value, bytes) or hasattr(param_value, 'read'):
109        return param_value
110
111    raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM.format(param_name))
112
113
114def _get_request_body(request_body):
115    '''Converts an object into a request body.  If it's None
116    we'll return an empty string, if it's one of our objects it'll
117    convert it to XML and return it.  Otherwise we just use the object
118    directly'''
119    if request_body is None:
120        return b''
121
122    if isinstance(request_body, bytes) or isinstance(request_body, IOBase):
123        return request_body
124
125    if isinstance(request_body, _unicode_type):
126        return request_body.encode('utf-8')
127
128    request_body = str(request_body)
129    if isinstance(request_body, _unicode_type):
130        return request_body.encode('utf-8')
131
132    return request_body
133
134
135def _convert_signed_identifiers_to_xml(signed_identifiers):
136    if signed_identifiers is None:
137        return ''
138
139    sis = ETree.Element('SignedIdentifiers')
140    for id, access_policy in signed_identifiers.items():
141        # Root signed identifers element
142        si = ETree.SubElement(sis, 'SignedIdentifier')
143
144        # Id element
145        ETree.SubElement(si, 'Id').text = id
146
147        # Access policy element
148        policy = ETree.SubElement(si, 'AccessPolicy')
149
150        if access_policy.start:
151            start = access_policy.start
152            if isinstance(access_policy.start, date):
153                start = _to_utc_datetime(start)
154            ETree.SubElement(policy, 'Start').text = start
155
156        if access_policy.expiry:
157            expiry = access_policy.expiry
158            if isinstance(access_policy.expiry, date):
159                expiry = _to_utc_datetime(expiry)
160            ETree.SubElement(policy, 'Expiry').text = expiry
161
162        if access_policy.permission:
163            ETree.SubElement(policy, 'Permission').text = _str(access_policy.permission)
164
165    # Add xml declaration and serialize
166    try:
167        stream = BytesIO()
168        ETree.ElementTree(sis).write(stream, xml_declaration=True, encoding='utf-8', method='xml')
169    except:
170        raise
171    finally:
172        output = stream.getvalue()
173        stream.close()
174
175    return output
176
177
178def _convert_service_properties_to_xml(logging, hour_metrics, minute_metrics,
179                                       cors, target_version=None, delete_retention_policy=None, static_website=None):
180    '''
181    <?xml version="1.0" encoding="utf-8"?>
182    <StorageServiceProperties>
183        <Logging>
184            <Version>version-number</Version>
185            <Delete>true|false</Delete>
186            <Read>true|false</Read>
187            <Write>true|false</Write>
188            <RetentionPolicy>
189                <Enabled>true|false</Enabled>
190                <Days>number-of-days</Days>
191            </RetentionPolicy>
192        </Logging>
193        <HourMetrics>
194            <Version>version-number</Version>
195            <Enabled>true|false</Enabled>
196            <IncludeAPIs>true|false</IncludeAPIs>
197            <RetentionPolicy>
198                <Enabled>true|false</Enabled>
199                <Days>number-of-days</Days>
200            </RetentionPolicy>
201        </HourMetrics>
202        <MinuteMetrics>
203            <Version>version-number</Version>
204            <Enabled>true|false</Enabled>
205            <IncludeAPIs>true|false</IncludeAPIs>
206            <RetentionPolicy>
207                <Enabled>true|false</Enabled>
208                <Days>number-of-days</Days>
209            </RetentionPolicy>
210        </MinuteMetrics>
211        <Cors>
212            <CorsRule>
213                <AllowedOrigins>comma-separated-list-of-allowed-origins</AllowedOrigins>
214                <AllowedMethods>comma-separated-list-of-HTTP-verb</AllowedMethods>
215                <MaxAgeInSeconds>max-caching-age-in-seconds</MaxAgeInSeconds>
216                <ExposedHeaders>comma-seperated-list-of-response-headers</ExposedHeaders>
217                <AllowedHeaders>comma-seperated-list-of-request-headers</AllowedHeaders>
218            </CorsRule>
219        </Cors>
220        <DeleteRetentionPolicy>
221            <Enabled>true|false</Enabled>
222            <Days>number-of-days</Days>
223        </DeleteRetentionPolicy>
224        <StaticWebsite>
225            <Enabled>true|false</Enabled>
226            <IndexDocument></IndexDocument>
227            <ErrorDocument404Path></ErrorDocument404Path>
228        </StaticWebsite>
229    </StorageServiceProperties>
230    '''
231    service_properties_element = ETree.Element('StorageServiceProperties')
232
233    # Logging
234    if logging:
235        logging_element = ETree.SubElement(service_properties_element, 'Logging')
236        ETree.SubElement(logging_element, 'Version').text = logging.version
237        ETree.SubElement(logging_element, 'Delete').text = str(logging.delete)
238        ETree.SubElement(logging_element, 'Read').text = str(logging.read)
239        ETree.SubElement(logging_element, 'Write').text = str(logging.write)
240
241        retention_element = ETree.SubElement(logging_element, 'RetentionPolicy')
242        _convert_retention_policy_to_xml(logging.retention_policy, retention_element)
243
244    # HourMetrics
245    if hour_metrics:
246        hour_metrics_element = ETree.SubElement(service_properties_element, 'HourMetrics')
247        _convert_metrics_to_xml(hour_metrics, hour_metrics_element)
248
249    # MinuteMetrics
250    if minute_metrics:
251        minute_metrics_element = ETree.SubElement(service_properties_element, 'MinuteMetrics')
252        _convert_metrics_to_xml(minute_metrics, minute_metrics_element)
253
254    # CORS
255    # Make sure to still serialize empty list
256    if cors is not None:
257        cors_element = ETree.SubElement(service_properties_element, 'Cors')
258        for rule in cors:
259            cors_rule = ETree.SubElement(cors_element, 'CorsRule')
260            ETree.SubElement(cors_rule, 'AllowedOrigins').text = ",".join(rule.allowed_origins)
261            ETree.SubElement(cors_rule, 'AllowedMethods').text = ",".join(rule.allowed_methods)
262            ETree.SubElement(cors_rule, 'MaxAgeInSeconds').text = str(rule.max_age_in_seconds)
263            ETree.SubElement(cors_rule, 'ExposedHeaders').text = ",".join(rule.exposed_headers)
264            ETree.SubElement(cors_rule, 'AllowedHeaders').text = ",".join(rule.allowed_headers)
265
266    # Target version
267    if target_version:
268        ETree.SubElement(service_properties_element, 'DefaultServiceVersion').text = target_version
269
270    # DeleteRetentionPolicy
271    if delete_retention_policy:
272        policy_element = ETree.SubElement(service_properties_element, 'DeleteRetentionPolicy')
273        ETree.SubElement(policy_element, 'Enabled').text = str(delete_retention_policy.enabled)
274
275        if delete_retention_policy.enabled:
276            ETree.SubElement(policy_element, 'Days').text = str(delete_retention_policy.days)
277
278    # StaticWebsite
279    if static_website:
280        static_website_element = ETree.SubElement(service_properties_element, 'StaticWebsite')
281        ETree.SubElement(static_website_element, 'Enabled').text = str(static_website.enabled)
282
283        if static_website.enabled:
284
285            if static_website.index_document is not None:
286                ETree.SubElement(static_website_element, 'IndexDocument').text = str(static_website.index_document)
287
288            if static_website.error_document_404_path is not None:
289                ETree.SubElement(static_website_element, 'ErrorDocument404Path').text = \
290                    str(static_website.error_document_404_path)
291
292    # Add xml declaration and serialize
293    try:
294        stream = BytesIO()
295        ETree.ElementTree(service_properties_element).write(stream, xml_declaration=True, encoding='utf-8',
296                                                            method='xml')
297    except:
298        raise
299    finally:
300        output = stream.getvalue()
301        stream.close()
302
303    return output
304
305
def _convert_metrics_to_xml(metrics, root):
    '''
    Append the children of an HourMetrics/MinuteMetrics element to *root*:

    <Version>version-number</Version>
    <Enabled>true|false</Enabled>
    <IncludeAPIs>true|false</IncludeAPIs>
    <RetentionPolicy>
        <Enabled>true|false</Enabled>
        <Days>number-of-days</Days>
    </RetentionPolicy>

    IncludeAPIs is written only when metrics are enabled and ``include_apis``
    is not None.
    '''
    add_child = ETree.SubElement

    add_child(root, 'Version').text = metrics.version
    add_child(root, 'Enabled').text = str(metrics.enabled)

    if metrics.enabled and metrics.include_apis is not None:
        add_child(root, 'IncludeAPIs').text = str(metrics.include_apis)

    policy_element = add_child(root, 'RetentionPolicy')
    _convert_retention_policy_to_xml(metrics.retention_policy, policy_element)
329
330
331def _convert_retention_policy_to_xml(retention_policy, root):
332    '''
333    <Enabled>true|false</Enabled>
334    <Days>number-of-days</Days>
335    '''
336    # Enabled
337    ETree.SubElement(root, 'Enabled').text = str(retention_policy.enabled)
338
339    # Days
340    if retention_policy.enabled and retention_policy.days:
341        ETree.SubElement(root, 'Days').text = str(retention_policy.days)
342
343
344def _len_plus(data):
345    length = None
346    # Check if object implements the __len__ method, covers most input cases such as bytearray.
347    try:
348        length = len(data)
349    except:
350        pass
351
352    if not length:
353        # Check if the stream is a file-like stream object.
354        # If so, calculate the size using the file descriptor.
355        try:
356            fileno = data.fileno()
357        except (AttributeError, UnsupportedOperation):
358            pass
359        else:
360            return fstat(fileno).st_size
361
362        # If the stream is seekable and tell() is implemented, calculate the stream size.
363        try:
364            current_position = data.tell()
365            data.seek(0, SEEK_END)
366            length = data.tell() - current_position
367            data.seek(current_position, SEEK_SET)
368        except (AttributeError, UnsupportedOperation):
369            pass
370
371    return length
372