1"""
2Module for data decoding
3"""
4import json
5import struct
6import tarfile
7import warnings
8from io import BytesIO
9from xml.etree import ElementTree
10
11import numpy as np
12import tifffile as tiff
13from PIL import Image
14
15from .constants import MimeType
16from .exceptions import ImageDecodingError
17
18
# Silence Pillow's DecompressionBombWarning (presumably because large downloaded images legitimately exceed
# Pillow's pixel-count limit). NOTE: this filter is installed at import time and applies process-wide.
warnings.filterwarnings('ignore', category=Image.DecompressionBombWarning)
20
21
def decode_data(response_content, data_type):
    """ Interprets downloaded data and returns it.

    JSON is parsed into Python objects; an empty JSON body is returned as an empty string. TAR archives are decoded
    member by member, image formats are decoded into numpy arrays, and XML/GML/SAFE are parsed into an
    ``ElementTree.Element``. RAW and TXT data are returned unchanged as bytes. ZIP data is wrapped into a
    ``BytesIO`` stream.

    :param response_content: downloaded data (i.e. json, png, tiff, xml, zip, ... file)
    :type response_content: bytes
    :param data_type: expected downloaded data type
    :type data_type: constants.MimeType
    :return: downloaded data
    :rtype: numpy array in case of image data type, or other possible data type
    :raises: ValueError if decoding of the given data type is not supported
    """
    if data_type is MimeType.JSON:
        response_text = response_content.decode('utf-8')
        # An empty body is returned as-is because json.loads would fail on it
        if not response_text:
            return response_text
        return json.loads(response_text)
    if data_type is MimeType.TAR:
        return decode_tar(response_content)
    if MimeType.is_image_format(data_type):
        return decode_image(response_content, data_type)
    if data_type is MimeType.XML or data_type is MimeType.GML or data_type is MimeType.SAFE:
        # NOTE(review): xml.etree is not hardened against maliciously constructed XML (e.g. entity expansion);
        # consider defusedxml if this content can come from an untrusted source
        return ElementTree.fromstring(response_content)
    if data_type is MimeType.RAW or data_type is MimeType.TXT:
        return response_content
    if data_type is MimeType.ZIP:
        # Only ZIP payloads need a stream wrapper; it is created only when actually needed
        return BytesIO(response_content)

    raise ValueError(f'Decoding data format {data_type} is not supported')
53
54
def decode_image(data, image_type):
    """ Decodes the image provided in various formats, i.e. png, 16-bit float tiff, 32-bit float tiff, jp2
    and returns it as a numpy array

    TIFF data is read with tifffile; all other formats are read with Pillow. For JPEG 2000 images the encoding bit
    depth is read from the file header and the values are corrected with `fix_jp2_image`. If the bit depth cannot be
    determined or is not supported, a warning is issued and the image is returned as read by Pillow.

    :param data: image in its original format
    :type data: bytes (anything accepted by ``io.BytesIO``)
    :param image_type: expected image format
    :type image_type: constants.MimeType
    :return: image as numpy array
    :rtype: numpy array
    :raises: ImageDecodingError
    """
    bytes_data = BytesIO(data)
    if image_type is MimeType.TIFF:
        image = tiff.imread(bytes_data)
    else:
        # np.array forces Pillow to load the pixel data, after which the image object can be released
        with Image.open(bytes_data) as pil_image:
            image = np.array(pil_image)

        if image_type is MimeType.JP2:
            try:
                # A fresh stream is used so the header scan does not depend on whether closing the Pillow image
                # also closed `bytes_data` (older Pillow versions may do that)
                bit_depth = get_jp2_bit_depth(BytesIO(data))
                image = fix_jp2_image(image, bit_depth)
            except (ValueError, struct.error) as exception:
                # Best effort: the image is still usable, but the caller should know values may be uncorrected
                warnings.warn(
                    f'Could not correct values of a JPEG 2000 image, returning it as read by Pillow: {exception}',
                    stacklevel=2,
                )

    if image is None:
        raise ImageDecodingError('Unable to decode image')
    return image
83
84
def decode_tar(data):
    """ Decodes every regular file inside a tar archive into a dictionary of {filename: value}

    The format of each member is guessed from its file extension, and its content is decoded with that format.
    Directories and other non-regular members are skipped.

    :param data: tar archive, either as raw bytes or as a readable binary stream
    :type data: bytes or IOBase
    :return: mapping from archive member name to its decoded content
    :rtype: dict(str: object)
    """
    stream = BytesIO(data) if isinstance(data, bytes) else data

    decoded = {}
    with tarfile.open(fileobj=stream) as archive:
        for member in archive.getmembers():
            if not member.isfile():
                continue
            member_format = get_data_format(member.name)
            member_bytes = archive.extractfile(member).read()
            decoded[member.name] = decode_data(member_bytes, member_format)
    return decoded
100
101
def decode_sentinelhub_err_msg(response):
    """ Decodes error message from Sentinel Hub service

    The response body is parsed as XML. The text of every direct child of the root element whose tag contains
    ``ServiceException`` or ``Message`` is stripped and concatenated. Elements without text contribute nothing. If
    the body is not valid XML, the raw response text is returned instead.

    :param response: Sentinel Hub service response
    :type response: requests.Response
    :return: An error message
    :rtype: str
    """
    try:
        root = decode_data(response.content, MimeType.XML)
    except ElementTree.ParseError:
        return response.text

    server_message = [
        # elem.text is None for elements without text content, which previously raised AttributeError
        (elem.text or '').strip('\n\t ')
        for elem in root
        if 'ServiceException' in elem.tag or 'Message' in elem.tag
    ]
    return ''.join(server_message)
118
119
def get_jp2_bit_depth(stream):
    """ Reads bit encoding depth of jpeg2000 file in binary stream format

    The stream is parsed as a sequence of JP2 boxes. Each box starts with a 4-byte big-endian length and a 4-byte box
    type. A length of 1 means an 8-byte extended length follows; a length of 0 means the box extends to the end of
    the stream. Boxes are skipped by their length until the Image Header box (``ihdr``) is reached. The ``jp2h``
    superbox is entered rather than skipped because ``ihdr`` is nested inside it.

    :param stream: binary stream format
    :type stream: Binary I/O supporting ``seek`` and ``tell`` (e.g. io.BytesIO, io.BufferedReader, ...)
    :return: bit depth
    :rtype: int
    :raises: ValueError if the Image Header box is missing, truncated, or a box header is malformed
    """
    stream.seek(0)
    while True:
        header = stream.read(8)
        if len(header) < 8:
            raise ValueError('Image Header Box not found in Jpeg2000 file')

        box_length, box_id = struct.unpack('>I4s', header)
        header_length = 8
        if box_length == 1:
            # The real box length is stored in the following 8-byte extended length field
            extended_length = stream.read(8)
            if len(extended_length) < 8:
                raise ValueError('Image Header Box not found in Jpeg2000 file')
            (box_length,) = struct.unpack('>Q', extended_length)
            header_length = 16

        if box_id == b'ihdr':
            payload = stream.read(14)
            if len(payload) < 14:
                raise ValueError('Image Header Box in Jpeg2000 file is truncated')
            # ihdr payload: height, width, number of components, BPC, compression type, colourspace unknown, IPR
            bits_per_component = struct.unpack('>IIHBBBB', payload)[3]
            # The low 7 bits of BPC hold (bit depth - 1); the high bit is the signedness flag
            return (bits_per_component & 0x7f) + 1

        if box_id == b'jp2h':
            # Superbox: its payload is itself a sequence of boxes, so keep reading inside it
            continue
        if box_length == 0:
            # The box runs to the end of the stream, so no further box can follow it
            raise ValueError('Image Header Box not found in Jpeg2000 file')
        if box_length < header_length:
            raise ValueError('Malformed box found in Jpeg2000 file')
        stream.seek(stream.tell() + box_length - header_length)
140
141
def fix_jp2_image(image, bit_depth):
    """ Corrects values of a JPEG 2000 image read by the Pillow library, which reads images with 15-bit encoding
    incorrectly.

    Images with 8- or 16-bit encoding are returned unchanged (the same object). For 15-bit encoding the values are
    shifted right by one bit to undo Pillow's incorrect reading. Other bit depths are not supported.

    :param image: image read by Pillow library
    :type image: numpy array
    :param bit_depth: bit depth of jp2 image encoding
    :type bit_depth: int
    :return: corrected image
    :rtype: numpy array
    :raises: IOError if a 15-bit image has a dtype that does not support bit shifting (e.g. float), ValueError for an
        unsupported bit depth
    """
    if bit_depth in (8, 16):
        return image
    if bit_depth == 15:
        try:
            return image >> 1
        except TypeError as exception:
            # A non-integer array here means Pillow did not decode the JPEG 2000 data as expected
            raise IOError('Failed to read JPEG 2000 image correctly. Most likely reason is that Pillow did not '
                          'install OpenJPEG library correctly. Try reinstalling Pillow from a wheel') from exception

    raise ValueError(f'Bit depth {bit_depth} of jp2 image is currently not supported. '
                     'Please raise an issue on package Github page')
164
165
def get_data_format(filename):
    """ Guesses the data format of a file from the extension in its name

    The extension is the text after the last dot; a name without any dot is used as a whole.

    :param filename: name of file
    :type filename: str
    :return: data format corresponding to the file extension
    :rtype: MimeType
    """
    _, _, extension = filename.rpartition('.')
    return MimeType.from_string(extension)
176