1# -*- coding: utf-8 -*-
2"""
3
4requests_toolbelt.multipart.encoder
5===================================
6
7This holds all of the implementation details of the MultipartEncoder
8
9"""
10import contextlib
11import io
12import os
13from uuid import uuid4
14
15import requests
16
17from .._compat import fields
18
19
20class FileNotSupportedError(Exception):
21    """File not supported error."""
22
23
24class MultipartEncoder(object):
25
26    """
27
28    The ``MultipartEncoder`` object is a generic interface to the engine that
29    will create a ``multipart/form-data`` body for you.
30
31    The basic usage is:
32
33    .. code-block:: python
34
35        import requests
36        from requests_toolbelt import MultipartEncoder
37
38        encoder = MultipartEncoder({'field': 'value',
39                                    'other_field', 'other_value'})
40        r = requests.post('https://httpbin.org/post', data=encoder,
41                          headers={'Content-Type': encoder.content_type})
42
43    If you do not need to take advantage of streaming the post body, you can
44    also do:
45
46    .. code-block:: python
47
48        r = requests.post('https://httpbin.org/post',
49                          data=encoder.to_string(),
50                          headers={'Content-Type': encoder.content_type})
51
52    If you want the encoder to use a specific order, you can use an
53    OrderedDict or more simply, a list of tuples:
54
55    .. code-block:: python
56
57        encoder = MultipartEncoder([('field', 'value'),
58                                    ('other_field', 'other_value')])
59
60    .. versionchanged:: 0.4.0
61
62    You can also provide tuples as part values as you would provide them to
63    requests' ``files`` parameter.
64
65    .. code-block:: python
66
67        encoder = MultipartEncoder({
68            'field': ('file_name', b'{"a": "b"}', 'application/json',
69                      {'X-My-Header': 'my-value'})
70        ])
71
72    .. warning::
73
74        This object will end up directly in :mod:`httplib`. Currently,
75        :mod:`httplib` has a hard-coded read size of **8192 bytes**. This
76        means that it will loop until the file has been read and your upload
77        could take a while. This is **not** a bug in requests. A feature is
78        being considered for this object to allow you, the user, to specify
79        what size should be returned on a read. If you have opinions on this,
80        please weigh in on `this issue`_.
81
82    .. _this issue:
83        https://github.com/requests/toolbelt/issues/75
84
85    """
86
87    def __init__(self, fields, boundary=None, encoding='utf-8'):
88        #: Boundary value either passed in by the user or created
89        self.boundary_value = boundary or uuid4().hex
90
91        # Computed boundary
92        self.boundary = '--{}'.format(self.boundary_value)
93
94        #: Encoding of the data being passed in
95        self.encoding = encoding
96
97        # Pre-encoded boundary
98        self._encoded_boundary = b''.join([
99            encode_with(self.boundary, self.encoding),
100            encode_with('\r\n', self.encoding)
101            ])
102
103        #: Fields provided by the user
104        self.fields = fields
105
106        #: Whether or not the encoder is finished
107        self.finished = False
108
109        #: Pre-computed parts of the upload
110        self.parts = []
111
112        # Pre-computed parts iterator
113        self._iter_parts = iter([])
114
115        # The part we're currently working with
116        self._current_part = None
117
118        # Cached computation of the body's length
119        self._len = None
120
121        # Our buffer
122        self._buffer = CustomBytesIO(encoding=encoding)
123
124        # Pre-compute each part's headers
125        self._prepare_parts()
126
127        # Load boundary into buffer
128        self._write_boundary()
129
130    @property
131    def len(self):
132        """Length of the multipart/form-data body.
133
134        requests will first attempt to get the length of the body by calling
135        ``len(body)`` and then by checking for the ``len`` attribute.
136
137        On 32-bit systems, the ``__len__`` method cannot return anything
138        larger than an integer (in C) can hold. If the total size of the body
139        is even slightly larger than 4GB users will see an OverflowError. This
140        manifested itself in `bug #80`_.
141
142        As such, we now calculate the length lazily as a property.
143
144        .. _bug #80:
145            https://github.com/requests/toolbelt/issues/80
146        """
147        # If _len isn't already calculated, calculate, return, and set it
148        return self._len or self._calculate_length()
149
150    def __repr__(self):
151        return '<MultipartEncoder: {!r}>'.format(self.fields)
152
153    def _calculate_length(self):
154        """
155        This uses the parts to calculate the length of the body.
156
157        This returns the calculated length so __len__ can be lazy.
158        """
159        boundary_len = len(self.boundary)  # Length of --{boundary}
160        # boundary length + header length + body length + len('\r\n') * 2
161        self._len = sum(
162            (boundary_len + total_len(p) + 4) for p in self.parts
163            ) + boundary_len + 4
164        return self._len
165
166    def _calculate_load_amount(self, read_size):
167        """This calculates how many bytes need to be added to the buffer.
168
169        When a consumer read's ``x`` from the buffer, there are two cases to
170        satisfy:
171
172            1. Enough data in the buffer to return the requested amount
173            2. Not enough data
174
175        This function uses the amount of unread bytes in the buffer and
176        determines how much the Encoder has to load before it can return the
177        requested amount of bytes.
178
179        :param int read_size: the number of bytes the consumer requests
180        :returns: int -- the number of bytes that must be loaded into the
181            buffer before the read can be satisfied. This will be strictly
182            non-negative
183        """
184        amount = read_size - total_len(self._buffer)
185        return amount if amount > 0 else 0
186
187    def _load(self, amount):
188        """Load ``amount`` number of bytes into the buffer."""
189        self._buffer.smart_truncate()
190        part = self._current_part or self._next_part()
191        while amount == -1 or amount > 0:
192            written = 0
193            if part and not part.bytes_left_to_write():
194                written += self._write(b'\r\n')
195                written += self._write_boundary()
196                part = self._next_part()
197
198            if not part:
199                written += self._write_closing_boundary()
200                self.finished = True
201                break
202
203            written += part.write_to(self._buffer, amount)
204
205            if amount != -1:
206                amount -= written
207
208    def _next_part(self):
209        try:
210            p = self._current_part = next(self._iter_parts)
211        except StopIteration:
212            p = None
213        return p
214
215    def _iter_fields(self):
216        _fields = self.fields
217        if hasattr(self.fields, 'items'):
218            _fields = list(self.fields.items())
219        for k, v in _fields:
220            file_name = None
221            file_type = None
222            file_headers = None
223            if isinstance(v, (list, tuple)):
224                if len(v) == 2:
225                    file_name, file_pointer = v
226                elif len(v) == 3:
227                    file_name, file_pointer, file_type = v
228                else:
229                    file_name, file_pointer, file_type, file_headers = v
230            else:
231                file_pointer = v
232
233            field = fields.RequestField(name=k, data=file_pointer,
234                                        filename=file_name,
235                                        headers=file_headers)
236            field.make_multipart(content_type=file_type)
237            yield field
238
239    def _prepare_parts(self):
240        """This uses the fields provided by the user and creates Part objects.
241
242        It populates the `parts` attribute and uses that to create a
243        generator for iteration.
244        """
245        enc = self.encoding
246        self.parts = [Part.from_field(f, enc) for f in self._iter_fields()]
247        self._iter_parts = iter(self.parts)
248
249    def _write(self, bytes_to_write):
250        """Write the bytes to the end of the buffer.
251
252        :param bytes bytes_to_write: byte-string (or bytearray) to append to
253            the buffer
254        :returns: int -- the number of bytes written
255        """
256        return self._buffer.append(bytes_to_write)
257
258    def _write_boundary(self):
259        """Write the boundary to the end of the buffer."""
260        return self._write(self._encoded_boundary)
261
262    def _write_closing_boundary(self):
263        """Write the bytes necessary to finish a multipart/form-data body."""
264        with reset(self._buffer):
265            self._buffer.seek(-2, 2)
266            self._buffer.write(b'--\r\n')
267        return 2
268
269    def _write_headers(self, headers):
270        """Write the current part's headers to the buffer."""
271        return self._write(encode_with(headers, self.encoding))
272
273    @property
274    def content_type(self):
275        return str(
276            'multipart/form-data; boundary={}'.format(self.boundary_value)
277            )
278
279    def to_string(self):
280        """Return the entirety of the data in the encoder.
281
282        .. note::
283
284            This simply reads all of the data it can. If you have started
285            streaming or reading data from the encoder, this method will only
286            return whatever data is left in the encoder.
287
288        .. note::
289
290            This method affects the internal state of the encoder. Calling
291            this method will exhaust the encoder.
292
293        :returns: the multipart message
294        :rtype: bytes
295        """
296
297        return self.read()
298
299    def read(self, size=-1):
300        """Read data from the streaming encoder.
301
302        :param int size: (optional), If provided, ``read`` will return exactly
303            that many bytes. If it is not provided, it will return the
304            remaining bytes.
305        :returns: bytes
306        """
307        if self.finished:
308            return self._buffer.read(size)
309
310        bytes_to_load = size
311        if bytes_to_load != -1 and bytes_to_load is not None:
312            bytes_to_load = self._calculate_load_amount(int(size))
313
314        self._load(bytes_to_load)
315        return self._buffer.read(size)
316
317
318def IDENTITY(monitor):
319    return monitor
320
321
322class MultipartEncoderMonitor(object):
323
324    """
325    An object used to monitor the progress of a :class:`MultipartEncoder`.
326
327    The :class:`MultipartEncoder` should only be responsible for preparing and
328    streaming the data. For anyone who wishes to monitor it, they shouldn't be
329    using that instance to manage that as well. Using this class, they can
330    monitor an encoder and register a callback. The callback receives the
331    instance of the monitor.
332
333    To use this monitor, you construct your :class:`MultipartEncoder` as you
334    normally would.
335
336    .. code-block:: python
337
338        from requests_toolbelt import (MultipartEncoder,
339                                       MultipartEncoderMonitor)
340        import requests
341
342        def callback(monitor):
343            # Do something with this information
344            pass
345
346        m = MultipartEncoder(fields={'field0': 'value0'})
347        monitor = MultipartEncoderMonitor(m, callback)
348        headers = {'Content-Type': monitor.content_type}
349        r = requests.post('https://httpbin.org/post', data=monitor,
350                          headers=headers)
351
352    Alternatively, if your use case is very simple, you can use the following
353    pattern.
354
355    .. code-block:: python
356
357        from requests_toolbelt import MultipartEncoderMonitor
358        import requests
359
360        def callback(monitor):
361            # Do something with this information
362            pass
363
364        monitor = MultipartEncoderMonitor.from_fields(
365            fields={'field0': 'value0'}, callback
366            )
367        headers = {'Content-Type': montior.content_type}
368        r = requests.post('https://httpbin.org/post', data=monitor,
369                          headers=headers)
370
371    """
372
373    def __init__(self, encoder, callback=None):
374        #: Instance of the :class:`MultipartEncoder` being monitored
375        self.encoder = encoder
376
377        #: Optionally function to call after a read
378        self.callback = callback or IDENTITY
379
380        #: Number of bytes already read from the :class:`MultipartEncoder`
381        #: instance
382        self.bytes_read = 0
383
384        #: Avoid the same problem in bug #80
385        self.len = self.encoder.len
386
387    @classmethod
388    def from_fields(cls, fields, boundary=None, encoding='utf-8',
389                    callback=None):
390        encoder = MultipartEncoder(fields, boundary, encoding)
391        return cls(encoder, callback)
392
393    @property
394    def content_type(self):
395        return self.encoder.content_type
396
397    def to_string(self):
398        return self.read()
399
400    def read(self, size=-1):
401        string = self.encoder.read(size)
402        self.bytes_read += len(string)
403        self.callback(self)
404        return string
405
406
407def encode_with(string, encoding):
408    """Encoding ``string`` with ``encoding`` if necessary.
409
410    :param str string: If string is a bytes object, it will not encode it.
411        Otherwise, this function will encode it with the provided encoding.
412    :param str encoding: The encoding with which to encode string.
413    :returns: encoded bytes object
414    """
415    if not (string is None or isinstance(string, bytes)):
416        return string.encode(encoding)
417    return string
418
419
420def readable_data(data, encoding):
421    """Coerce the data to an object with a ``read`` method."""
422    if hasattr(data, 'read'):
423        return data
424
425    return CustomBytesIO(data, encoding)
426
427
428def total_len(o):
429    if hasattr(o, '__len__'):
430        return len(o)
431
432    if hasattr(o, 'len'):
433        return o.len
434
435    if hasattr(o, 'fileno'):
436        try:
437            fileno = o.fileno()
438        except io.UnsupportedOperation:
439            pass
440        else:
441            return os.fstat(fileno).st_size
442
443    if hasattr(o, 'getvalue'):
444        # e.g. BytesIO, cStringIO.StringIO
445        return len(o.getvalue())
446
447
448@contextlib.contextmanager
449def reset(buffer):
450    """Keep track of the buffer's current position and write to the end.
451
452    This is a context manager meant to be used when adding data to the buffer.
453    It eliminates the need for every function to be concerned with the
454    position of the cursor in the buffer.
455    """
456    original_position = buffer.tell()
457    buffer.seek(0, 2)
458    yield
459    buffer.seek(original_position, 0)
460
461
462def coerce_data(data, encoding):
463    """Ensure that every object's __len__ behaves uniformly."""
464    if not isinstance(data, CustomBytesIO):
465        if hasattr(data, 'getvalue'):
466            return CustomBytesIO(data.getvalue(), encoding)
467
468        if hasattr(data, 'fileno'):
469            return FileWrapper(data)
470
471        if not hasattr(data, 'read'):
472            return CustomBytesIO(data, encoding)
473
474    return data
475
476
477def to_list(fields):
478    if hasattr(fields, 'items'):
479        return list(fields.items())
480    return list(fields)
481
482
483class Part(object):
484    def __init__(self, headers, body):
485        self.headers = headers
486        self.body = body
487        self.headers_unread = True
488        self.len = len(self.headers) + total_len(self.body)
489
490    @classmethod
491    def from_field(cls, field, encoding):
492        """Create a part from a Request Field generated by urllib3."""
493        headers = encode_with(field.render_headers(), encoding)
494        body = coerce_data(field.data, encoding)
495        return cls(headers, body)
496
497    def bytes_left_to_write(self):
498        """Determine if there are bytes left to write.
499
500        :returns: bool -- ``True`` if there are bytes left to write, otherwise
501            ``False``
502        """
503        to_read = 0
504        if self.headers_unread:
505            to_read += len(self.headers)
506
507        return (to_read + total_len(self.body)) > 0
508
509    def write_to(self, buffer, size):
510        """Write the requested amount of bytes to the buffer provided.
511
512        The number of bytes written may exceed size on the first read since we
513        load the headers ambitiously.
514
515        :param CustomBytesIO buffer: buffer we want to write bytes to
516        :param int size: number of bytes requested to be written to the buffer
517        :returns: int -- number of bytes actually written
518        """
519        written = 0
520        if self.headers_unread:
521            written += buffer.append(self.headers)
522            self.headers_unread = False
523
524        while total_len(self.body) > 0 and (size == -1 or written < size):
525            amount_to_read = size
526            if size != -1:
527                amount_to_read = size - written
528            written += buffer.append(self.body.read(amount_to_read))
529
530        return written
531
532
533class CustomBytesIO(io.BytesIO):
534    def __init__(self, buffer=None, encoding='utf-8'):
535        buffer = encode_with(buffer, encoding)
536        super(CustomBytesIO, self).__init__(buffer)
537
538    def _get_end(self):
539        current_pos = self.tell()
540        self.seek(0, 2)
541        length = self.tell()
542        self.seek(current_pos, 0)
543        return length
544
545    @property
546    def len(self):
547        length = self._get_end()
548        return length - self.tell()
549
550    def append(self, bytes):
551        with reset(self):
552            written = self.write(bytes)
553        return written
554
555    def smart_truncate(self):
556        to_be_read = total_len(self)
557        already_read = self._get_end() - to_be_read
558
559        if already_read >= to_be_read:
560            old_bytes = self.read()
561            self.seek(0, 0)
562            self.truncate()
563            self.write(old_bytes)
564            self.seek(0, 0)  # We want to be at the beginning
565
566
567class FileWrapper(object):
568    def __init__(self, file_object):
569        self.fd = file_object
570
571    @property
572    def len(self):
573        return total_len(self.fd) - self.fd.tell()
574
575    def read(self, length=-1):
576        return self.fd.read(length)
577
578
579class FileFromURLWrapper(object):
580    """File from URL wrapper.
581
582    The :class:`FileFromURLWrapper` object gives you the ability to stream file
583    from provided URL in chunks by :class:`MultipartEncoder`.
584    Provide a stateless solution for streaming file from one server to another.
585    You can use the :class:`FileFromURLWrapper` without a session or with
586    a session as demonstated by the examples below:
587
588    .. code-block:: python
589        # no session
590
591        import requests
592        from requests_toolbelt import MultipartEncoder, FileFromURLWrapper
593
594        url = 'https://httpbin.org/image/png'
595        streaming_encoder = MultipartEncoder(
596            fields={
597                'file': FileFromURLWrapper(url)
598            }
599        )
600        r = requests.post(
601            'https://httpbin.org/post', data=streaming_encoder,
602            headers={'Content-Type': streaming_encoder.content_type}
603        )
604
605    .. code-block:: python
606        # using a session
607
608        import requests
609        from requests_toolbelt import MultipartEncoder, FileFromURLWrapper
610
611        session = requests.Session()
612        url = 'https://httpbin.org/image/png'
613        streaming_encoder = MultipartEncoder(
614            fields={
615                'file': FileFromURLWrapper(url, session=session)
616            }
617        )
618        r = session.post(
619            'https://httpbin.org/post', data=streaming_encoder,
620            headers={'Content-Type': streaming_encoder.content_type}
621        )
622
623    """
624
625    def __init__(self, file_url, session=None):
626        self.session = session or requests.Session()
627        requested_file = self._request_for_file(file_url)
628        self.len = int(requested_file.headers['content-length'])
629        self.raw_data = requested_file.raw
630
631    def _request_for_file(self, file_url):
632        """Make call for file under provided URL."""
633        response = self.session.get(file_url, stream=True)
634        content_length = response.headers.get('content-length', None)
635        if content_length is None:
636            error_msg = (
637                "Data from provided URL {url} is not supported. Lack of "
638                "content-length Header in requested file response.".format(
639                    url=file_url)
640            )
641            raise FileNotSupportedError(error_msg)
642        elif not content_length.isdigit():
643            error_msg = (
644                "Data from provided URL {url} is not supported. content-length"
645                " header value is not a digit.".format(url=file_url)
646            )
647            raise FileNotSupportedError(error_msg)
648        return response
649
650    def read(self, chunk_size):
651        """Read file in chunks."""
652        chunk_size = chunk_size if chunk_size >= 0 else self.len
653        chunk = self.raw_data.read(chunk_size) or b''
654        self.len -= len(chunk) if chunk else 0  # left to read
655        return chunk
656