1# -------------------------------------------------------------------------
2# Copyright (c) Microsoft.  All rights reserved.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#   http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14# --------------------------------------------------------------------------
15import sys
16import uuid
17from datetime import date
18from io import (BytesIO, IOBase, SEEK_SET, SEEK_END, UnsupportedOperation)
19from os import fstat
20from time import time
21from wsgiref.handlers import format_date_time
22
23from dateutil.tz import tzutc
24
25if sys.version_info >= (3,):
26    from urllib.parse import quote as url_quote
27else:
28    from urllib2 import quote as url_quote
29
30try:
31    from xml.etree import cElementTree as ETree
32except ImportError:
33    from xml.etree import ElementTree as ETree
34
35from ._error import (
36    _ERROR_VALUE_SHOULD_BE_BYTES,
37    _ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM,
38    _ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM
39)
40from .models import (
41    _unicode_type,
42)
43from ._common_conversion import (
44    _str,
45)
46
47
def _to_utc_datetime(value):
    '''Format a date/datetime as the UTC timestamp string sent to Azure.

    Azure expects date values to be UTC and always returns values as UTC.
    A value without timezone info is assumed to already be UTC. Aware
    datetimes are converted to UTC first.

    :param value: a ``datetime.datetime`` or ``datetime.date``. A plain
        ``date`` has no time component and is rendered at midnight.
    :return: string formatted as ``YYYY-MM-DDTHH:MM:SSZ``.
    :rtype: str
    '''
    # Plain ``date`` objects have no ``tzinfo`` attribute at all, so a direct
    # attribute access would raise AttributeError; treat them as naive.
    if getattr(value, 'tzinfo', None):
        value = value.astimezone(tzutc())
    # For ``date`` objects strftime substitutes 0 for the hour/minute/second
    # directives, giving a midnight timestamp.
    return value.strftime('%Y-%m-%dT%H:%M:%SZ')
55
56
def _update_request(request, x_ms_version, user_agent_string):
    '''Finalize ``request`` in place before it is sent.

    Steps, in order:
      * validate the body (bytes or readable stream) and, for PUT/POST/MERGE/
        DELETE, record its size in the ``Content-Length`` header;
      * add the ``x-ms-version``, ``User-Agent`` and a fresh
        ``x-ms-client-request-id`` header;
      * move any path component embedded in ``request.host`` (e.g. local
        storage) to the front of ``request.path``;
      * percent-encode the path.

    :raises ValueError: if the body's length cannot be determined.
    :raises TypeError: if the body is neither bytes nor a stream.
    '''
    if request.body:
        request.body = _get_data_bytes_or_stream_only('request.body', request.body)
        body_length = _len_plus(request.body)

        # A length of None should only arise for a stream that is not seekable.
        if body_length is None:
            raise ValueError(_ERROR_VALUE_SHOULD_BE_SEEKABLE_STREAM)

        # Verbs that carry a payload need an explicit Content-Length header.
        if request.method in ('PUT', 'POST', 'MERGE', 'DELETE'):
            request.headers['Content-Length'] = str(body_length)

    # Service-wide headers attached to every request.
    service_headers = (
        ('x-ms-version', x_ms_version),
        ('User-Agent', user_agent_string),
        ('x-ms-client-request-id', str(uuid.uuid1())),
    )
    for header_name, header_value in service_headers:
        request.headers[header_name] = header_value

    # A host such as "127.0.0.1:10000/account" (e.g. local storage) carries a
    # path segment that belongs at the front of the request path instead.
    bare_host, separator, host_path = request.host.partition('/')
    if separator:
        request.host = bare_host
        request.path = '/{}{}'.format(host_path, request.path)

    # Percent-encode the path, leaving the listed characters untouched.
    request.path = url_quote(request.path, '/()$=\',~')
84
85
def _add_metadata_headers(metadata, request):
    '''Copy each ``metadata`` entry onto ``request`` as an ``x-ms-meta-*`` header.

    Nothing happens when ``metadata`` is empty or None. If the request has no
    (or empty) headers, a new dict is attached first.
    '''
    if not metadata:
        return

    if not request.headers:
        request.headers = {}

    headers = request.headers
    for key, val in metadata.items():
        headers['x-ms-meta-' + key] = val
92
93
def _add_date_header(request):
    '''Set the ``x-ms-date`` header to the current time.

    The value comes from ``wsgiref.handlers.format_date_time``, an HTTP-date
    string such as ``Tue, 15 Nov 1994 08:12:31 GMT``.
    '''
    request.headers['x-ms-date'] = format_date_time(time())
97
98
def _get_data_bytes_only(param_name, param_value):
    '''Ensure ``param_value`` is bytes.

    :param str param_name: parameter name, used in the error message.
    :param param_value: value to check; None is treated as empty.
    :return: ``b''`` for None, otherwise ``param_value`` unchanged.
    :raises TypeError: if the value is not bytes.
    '''
    if param_value is None:
        return b''

    if not isinstance(param_value, bytes):
        raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES.format(param_name))

    return param_value
109
110
def _get_data_bytes_or_stream_only(param_name, param_value):
    '''Ensure ``param_value`` is bytes or a file-like object with ``read``.

    :param str param_name: parameter name, used in the error message.
    :param param_value: value to check; None is treated as empty.
    :return: ``b''`` for None, otherwise ``param_value`` unchanged.
    :raises TypeError: if the value is neither bytes nor has a ``read``
        attribute.
    '''
    if param_value is None:
        return b''

    # bytes is tested first so hasattr() is only consulted for non-bytes input.
    acceptable = isinstance(param_value, bytes) or hasattr(param_value, 'read')
    if not acceptable:
        raise TypeError(_ERROR_VALUE_SHOULD_BE_BYTES_OR_STREAM.format(param_name))

    return param_value
121
122
def _get_request_body(request_body):
    '''Coerce ``request_body`` into something usable as a request body.

    * None becomes empty bytes (``b''``).
    * bytes and ``io.IOBase`` streams are returned unchanged.
    * text (``_unicode_type``) is UTF-8 encoded.
    * any other object is passed through ``str()``. The result is UTF-8
      encoded if it is text, otherwise returned as-is.
    '''
    if request_body is None:
        return b''

    if isinstance(request_body, (bytes, IOBase)):
        return request_body

    if not isinstance(request_body, _unicode_type):
        request_body = str(request_body)
        # str() may already yield a byte string (non-unicode type); keep it.
        if not isinstance(request_body, _unicode_type):
            return request_body

    return request_body.encode('utf-8')
142
143
def _convert_signed_identifiers_to_xml(signed_identifiers):
    '''Serialize signed identifiers into a ``SignedIdentifiers`` XML document.

    :param signed_identifiers: mapping of identifier id -> access policy
        object exposing ``start``, ``expiry`` and ``permission`` attributes,
        or None.
    :return: ``''`` when ``signed_identifiers`` is None; otherwise the UTF-8
        encoded XML document (bytes) including the XML declaration.

    Falsy ``start``/``expiry``/``permission`` values are omitted. ``date`` or
    ``datetime`` values for ``start``/``expiry`` are formatted with
    ``_to_utc_datetime``. Other values are used as the element text unchanged
    (presumably pre-formatted strings).
    '''
    if signed_identifiers is None:
        # NOTE: str here but bytes below; kept as-is for caller compatibility.
        return ''

    sis = ETree.Element('SignedIdentifiers')
    for identifier, access_policy in signed_identifiers.items():
        # Root signed identifier element
        si = ETree.SubElement(sis, 'SignedIdentifier')

        # Id element
        ETree.SubElement(si, 'Id').text = identifier

        # Access policy element
        policy = ETree.SubElement(si, 'AccessPolicy')

        start = access_policy.start
        if start:
            if isinstance(start, date):
                start = _to_utc_datetime(start)
            ETree.SubElement(policy, 'Start').text = start

        expiry = access_policy.expiry
        if expiry:
            if isinstance(expiry, date):
                expiry = _to_utc_datetime(expiry)
            ETree.SubElement(policy, 'Expiry').text = expiry

        if access_policy.permission:
            ETree.SubElement(policy, 'Permission').text = _str(access_policy.permission)

    # Add xml declaration and serialize. The buffer is always closed, and the
    # value is read before closing; write() errors propagate to the caller.
    stream = BytesIO()
    try:
        ETree.ElementTree(sis).write(stream, xml_declaration=True, encoding='utf-8', method='xml')
        return stream.getvalue()
    finally:
        stream.close()
185
186
def _convert_service_properties_to_xml(logging, hour_metrics, minute_metrics, cors, target_version=None):
    '''
    Serialize storage service properties into an XML document of this shape:

    <?xml version="1.0" encoding="utf-8"?>
    <StorageServiceProperties>
        <Logging>
            <Version>version-number</Version>
            <Delete>true|false</Delete>
            <Read>true|false</Read>
            <Write>true|false</Write>
            <RetentionPolicy>
                <Enabled>true|false</Enabled>
                <Days>number-of-days</Days>
            </RetentionPolicy>
        </Logging>
        <HourMetrics>
            <Version>version-number</Version>
            <Enabled>true|false</Enabled>
            <IncludeAPIs>true|false</IncludeAPIs>
            <RetentionPolicy>
                <Enabled>true|false</Enabled>
                <Days>number-of-days</Days>
            </RetentionPolicy>
        </HourMetrics>
        <MinuteMetrics>
            <Version>version-number</Version>
            <Enabled>true|false</Enabled>
            <IncludeAPIs>true|false</IncludeAPIs>
            <RetentionPolicy>
                <Enabled>true|false</Enabled>
                <Days>number-of-days</Days>
            </RetentionPolicy>
        </MinuteMetrics>
        <Cors>
            <CorsRule>
                <AllowedOrigins>comma-separated-list-of-allowed-origins</AllowedOrigins>
                <AllowedMethods>comma-separated-list-of-HTTP-verb</AllowedMethods>
                <MaxAgeInSeconds>max-caching-age-in-seconds</MaxAgeInSeconds>
                <ExposedHeaders>comma-separated-list-of-response-headers</ExposedHeaders>
                <AllowedHeaders>comma-separated-list-of-request-headers</AllowedHeaders>
            </CorsRule>
        </Cors>
    </StorageServiceProperties>

    Sections whose argument is falsy are omitted. The exception is ``cors``:
    an empty list still produces an empty <Cors/> element, and only None
    omits it. A truthy ``target_version`` adds <DefaultServiceVersion>.
    Boolean values are written with ``str()`` (e.g. ``True``).

    Note: the ``logging`` parameter shadows the stdlib module name; it is
    kept for interface compatibility.

    :return: the UTF-8 encoded XML document including the XML declaration.
    :rtype: bytes
    '''
    service_properties_element = ETree.Element('StorageServiceProperties')

    # Logging
    if logging:
        logging_element = ETree.SubElement(service_properties_element, 'Logging')
        ETree.SubElement(logging_element, 'Version').text = logging.version
        ETree.SubElement(logging_element, 'Delete').text = str(logging.delete)
        ETree.SubElement(logging_element, 'Read').text = str(logging.read)
        ETree.SubElement(logging_element, 'Write').text = str(logging.write)

        retention_element = ETree.SubElement(logging_element, 'RetentionPolicy')
        _convert_retention_policy_to_xml(logging.retention_policy, retention_element)

    # HourMetrics
    if hour_metrics:
        hour_metrics_element = ETree.SubElement(service_properties_element, 'HourMetrics')
        _convert_metrics_to_xml(hour_metrics, hour_metrics_element)

    # MinuteMetrics
    if minute_metrics:
        minute_metrics_element = ETree.SubElement(service_properties_element, 'MinuteMetrics')
        _convert_metrics_to_xml(minute_metrics, minute_metrics_element)

    # CORS
    # An empty list is still serialized (as an empty <Cors/>), so test for None.
    if cors is not None:
        cors_element = ETree.SubElement(service_properties_element, 'Cors')
        for rule in cors:
            cors_rule = ETree.SubElement(cors_element, 'CorsRule')
            ETree.SubElement(cors_rule, 'AllowedOrigins').text = ",".join(rule.allowed_origins)
            ETree.SubElement(cors_rule, 'AllowedMethods').text = ",".join(rule.allowed_methods)
            ETree.SubElement(cors_rule, 'MaxAgeInSeconds').text = str(rule.max_age_in_seconds)
            ETree.SubElement(cors_rule, 'ExposedHeaders').text = ",".join(rule.exposed_headers)
            ETree.SubElement(cors_rule, 'AllowedHeaders').text = ",".join(rule.allowed_headers)

    # Target version
    if target_version:
        ETree.SubElement(service_properties_element, 'DefaultServiceVersion').text = target_version

    # Add xml declaration and serialize. The buffer is always closed, and the
    # value is read before closing; write() errors propagate to the caller.
    stream = BytesIO()
    try:
        ETree.ElementTree(service_properties_element).write(stream, xml_declaration=True, encoding='utf-8',
                                                            method='xml')
        return stream.getvalue()
    finally:
        stream.close()
281
282
def _convert_metrics_to_xml(metrics, root):
    '''
    Append the children of a metrics element (hour or minute) to ``root``:

    <Version>version-number</Version>
    <Enabled>true|false</Enabled>
    <IncludeAPIs>true|false</IncludeAPIs>
    <RetentionPolicy>
        <Enabled>true|false</Enabled>
        <Days>number-of-days</Days>
    </RetentionPolicy>

    <IncludeAPIs> is written only when ``metrics.enabled`` is truthy and
    ``metrics.include_apis`` is not None. Booleans are written with ``str()``.
    '''
    version_node = ETree.SubElement(root, 'Version')
    version_node.text = metrics.version

    enabled_node = ETree.SubElement(root, 'Enabled')
    enabled_node.text = str(metrics.enabled)

    if metrics.enabled:
        if metrics.include_apis is not None:
            include_node = ETree.SubElement(root, 'IncludeAPIs')
            include_node.text = str(metrics.include_apis)

    # The retention element is created before its children are filled in.
    policy_node = ETree.SubElement(root, 'RetentionPolicy')
    _convert_retention_policy_to_xml(metrics.retention_policy, policy_node)
306
307
def _convert_retention_policy_to_xml(retention_policy, root):
    '''
    Append the children of a RetentionPolicy element to ``root``:

    <Enabled>true|false</Enabled>
    <Days>number-of-days</Days>

    <Days> is written only when the policy is enabled and ``days`` is truthy.
    Values are written with ``str()``.
    '''
    is_enabled = retention_policy.enabled
    ETree.SubElement(root, 'Enabled').text = str(is_enabled)

    if not (is_enabled and retention_policy.days):
        return

    ETree.SubElement(root, 'Days').text = str(retention_policy.days)
319
320
def _len_plus(data):
    '''Best-effort count of the bytes ``data`` will still yield.

    Tried in order:
      1. ``len(data)`` for sized objects such as bytes/bytearray;
      2. for objects exposing ``fileno()``, the ``fstat`` file size minus the
         current read position (if ``tell()`` is available);
      3. for seekable streams, the distance from the current position to the
         end of the stream. The original position is restored afterwards.

    :return: the length, or None if none of the strategies applied (e.g. a
        stream without ``fileno``/``tell``/``seek`` support).
    '''
    length = None
    # Sized objects cover most non-stream inputs; unsized ones raise TypeError.
    try:
        length = len(data)
    except TypeError:
        pass

    if not length:
        # Check if the stream is a file-like stream object.
        # If so, calculate the size using the file descriptor.
        try:
            fileno = data.fileno()
        except (AttributeError, UnsupportedOperation):
            pass
        else:
            size = fstat(fileno).st_size
            # Only bytes after the current position will be read, so subtract
            # it. This keeps the result consistent with the seek/tell branch
            # below. If tell() is unavailable, fall back to the full size.
            try:
                position = data.tell()
            except (AttributeError, OSError):
                position = 0
            return max(size - position, 0)

        # If the stream is seekable and tell() is implemented, calculate the
        # remaining stream size, then restore the original position.
        try:
            current_position = data.tell()
            data.seek(0, SEEK_END)
            length = data.tell() - current_position
            data.seek(current_position, SEEK_SET)
        except (AttributeError, UnsupportedOperation):
            pass

    return length
349