1#----------------------------------------------------------------------------------------------
2#  Copyright (c) Microsoft Corporation. All rights reserved.
3#  Licensed under the MIT License. See License.txt in the project root for license information.
4#----------------------------------------------------------------------------------------------
5
6from __future__ import division
7import os
8import sys
9import numpy as np
10from six import text_type, binary_type, integer_types
11import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2
12
13
14__all__ = ["assign_IRnode_values", "convert_onnx_pad_to_tf", 'convert_tf_pad_to_onnx',
15           'compute_tf_same_padding', 'is_valid_padding', 'download_file',
16           'shape_to_list', 'list_to_shape']
17
18
def assign_attr_value(attr, val):
    """Assign `val` to an AttrValue proto according to its Python type.

    Args:
        attr: graph_pb2.AttrValue message, filled in place.
        val: bool, int, float, str/bytes, TensorShape, a list of those,
            or a numpy.ndarray (handled recursively via tolist()).

    Raises:
        NotImplementedError: if `val` (or its list element type) is unsupported.
    """
    # Fix: the docstring used to sit *after* this import, so Python did not
    # recognize it as the function docstring.
    from mmdnn.conversion.common.IR.graph_pb2 import TensorShape
    # bool must be checked before integers: bool is a subclass of int.
    if isinstance(val, bool):
        attr.b = val
    elif isinstance(val, integer_types):
        attr.i = val
    elif isinstance(val, float):
        attr.f = val
    elif isinstance(val, (binary_type, text_type)):
        # The proto `s` field stores bytes; encode text values first.
        if hasattr(val, 'encode'):
            val = val.encode()
        attr.s = val
    elif isinstance(val, TensorShape):
        attr.shape.MergeFromString(val.SerializeToString())
    elif isinstance(val, list):
        if not val:
            return
        # Element type is inferred from the first item only.
        if isinstance(val[0], integer_types):
            attr.list.i.extend(val)
        elif isinstance(val[0], TensorShape):
            attr.list.shape.extend(val)
        elif isinstance(val[0], float):
            attr.list.f.extend(val)
        else:
            # Fix: report the element *type* (the message used to print the
            # element value), consistent with the final branch below.
            raise NotImplementedError('AttrValue cannot be of list[{}].'.format(type(val[0])))
    elif isinstance(val, np.ndarray):
        assign_attr_value(attr, val.tolist())
    else:
        raise NotImplementedError('AttrValue cannot be of %s' % type(val))
48
49
def assign_IRnode_values(IR_node, val_dict):
    """Copy every (name, value) pair from val_dict into IR_node.attr."""
    for attr_name, attr_val in val_dict.items():
        assign_attr_value(IR_node.attr[attr_name], attr_val)
53
54
55# For padding
def convert_tf_pad_to_onnx(pads):
    """Convert tf-style pads [[b1, e1], [b2, e2], ...] into onnx order
    [b1, b2, ..., e1, e2, ...]."""
    flat = np.reshape(pads, -1).tolist()
    assert len(flat) % 2 == 0
    # Even positions hold the begin pads, odd positions the end pads.
    begins = flat[0::2]
    ends = flat[1::2]
    return begins + ends
66
67
def convert_onnx_pad_to_tf(pads):
    """Convert onnx pads [b1, ..., e1, ...] back into tf-style pairs
    [[b1, e1], [b2, e2], ...]."""
    halves = np.array(pads).reshape([2, -1])
    paired = halves.T.reshape(-1, 2)
    return paired.tolist()
70
71
def is_valid_padding(pads):
    """Return True when every pad amount is zero (tensorflow VALID padding)."""
    flattened = np.reshape(pads, -1)
    return sum(flattened) == 0
74
75
def shape_to_list(shape):
    """Extract the dimension sizes of a TensorShape proto as a plain list."""
    sizes = []
    for dimension in shape.dim:
        sizes.append(dimension.size)
    return sizes
78
79
def list_to_shape(shape):
    """Build a TensorShape proto from an iterable of dimension sizes."""
    tensor_shape = graph_pb2.TensorShape()
    for size in shape:
        tensor_shape.dim.add().size = size
    return tensor_shape
86
87
def compute_tf_same_padding(input_shape, kernel_shape, strides, data_format='NHWC'):
    """ Convert [SAME] padding in tensorflow, keras to onnx pads,
        i.e. [x1_begin, x2_begin...x1_end, x2_end,...] """
    channels_first = data_format.startswith('NC')

    if channels_first:
        # NOTE(review): this branch was marked "Not tested" in the original.
        # Spatial dims follow batch and channel.
        spatial = input_shape[2:]
        extra = len(strides) - len(spatial)
        if extra > 0:
            strides = strides[extra:]
    else:
        # Spatial dims sit between batch (first) and channel (last).
        spatial = input_shape[1:-1]
        extra = len(spatial) - len(strides) + 1
        if extra < 0:
            # Strip the batch/channel entries from a full-rank stride list.
            strides = strides[1:extra]

    begins = [0]
    ends = [0]
    for axis, size in enumerate(spatial):
        # SAME output size: ceil(size / stride).
        out_size = (size + strides[axis] - 1) // strides[axis]
        total_pad = max(0, (out_size - 1) * strides[axis] + kernel_shape[axis] - size)
        # Split padding, putting the extra unit (if odd) at the end.
        begins.append(total_pad // 2)
        ends.append(total_pad - total_pad // 2)

    if channels_first:
        return [0] + begins + [0] + ends
    return begins + [0] + ends + [0]
125
126
127
128# network library
def sizeof_fmt(num, suffix='B'):
    """Render a byte count as a human-readable string with binary prefixes."""
    value = num
    for prefix in ('', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi'):
        if abs(value) < 1024.0:
            return "%3.1f %s%s" % (value, prefix, suffix)
        value /= 1024.0
    # Anything past zebi falls through to yobi.
    return "%.1f %s%s" % (value, 'Yi', suffix)
135
136
def _progress_check(count, block_size, total_size):
    """urlretrieve reporthook: rewrite a single-line download progress message.

    `total_size` may be <= 0 when the server sends no content-length; in that
    case only the running byte count is shown.
    """
    read_size = count * block_size
    human_size = sizeof_fmt(read_size)
    if total_size > 0:
        percent = min(int(count * block_size * 100 / total_size), 100)
        sys.stdout.write("\rprogress: {} downloaded, {}%.".format(human_size, percent))
        if read_size >= total_size:
            sys.stdout.write("\n")
    else:
        sys.stdout.write("\rprogress: {} downloaded.".format(human_size))
    sys.stdout.flush()
149
150
def _single_thread_download(url, file_name):
    """Download `url` to `file_name` via urlretrieve, reporting progress.

    Returns the local path reported by urlretrieve.
    """
    from six.moves import urllib
    local_path, _headers = urllib.request.urlretrieve(url, file_name, _progress_check)
    return local_path
155
156
def _downloader(start, end, url, filename):
    """Fetch bytes [start, end] of `url` and write them into `filename` at
    offset `start`.

    Worker for the multi-threaded range download; the target file must
    already exist and be large enough to hold the written range.
    """
    import requests
    headers = {'Range': 'bytes=%d-%d' % (start, end)}
    r = requests.get(url, headers=headers, stream=True)
    with open(filename, "r+b") as fp:
        fp.seek(start)
        # Fix: dropped the dead `var = fp.tell()` local, which was never read.
        fp.write(r.content)
165
166
def _multi_thread_download(url, file_name, file_size, thread_count):
    """Download `url` into `file_name` using `thread_count` range-request workers.

    Pre-allocates the destination file, spawns one daemon thread per byte
    range, and waits for all of them to finish.

    Returns:
        The local file name.
    """
    import threading

    # Fix: pre-allocate the file with a context manager instead of a bare
    # open/close pair, so the handle is released even if truncate raises.
    with open(file_name, "wb") as fp:
        fp.truncate(file_size)

    part = file_size // thread_count
    workers = []
    for i in range(thread_count):
        start = part * i
        # The last chunk absorbs the remainder of the integer division.
        end = file_size if i == thread_count - 1 else start + part

        t = threading.Thread(target=_downloader,
                             kwargs={'start': start, 'end': end, 'url': url, 'filename': file_name})
        # Fix: Thread.setDaemon() is deprecated; assign the attribute instead.
        t.daemon = True
        t.start()
        workers.append(t)

    # Fix: join only the threads we created. The original joined every
    # non-main thread in the process, which could block on unrelated threads.
    for t in workers:
        t.join()

    return file_name
192
193
def download_file(url, directory='./', local_fname=None, force_write=False, auto_unzip=False, compre_type=''):
    """Download the data from `url`, unless it's already here.

    Args:
        url: source url to download from.
        directory: working directory the file is stored in (created if absent).
        local_fname: optional local file name; derived from the url when omitted.
        force_write: re-download even when the file already exists.
        auto_unzip: extract .tar.gz/.tgz/.zip archives after download.
        compre_type: unused; kept for backward compatibility with callers.

    Returns:
        Path to the resulting (downloaded or pre-existing) file.
    """
    if not os.path.isdir(directory):
        # Fix: os.mkdir fails for nested paths; makedirs creates parents too.
        os.makedirs(directory)

    if not local_fname:
        # Default to the last path component of the url.
        local_fname = url[url.rfind('/') + 1:]

    local_fname = os.path.join(directory, local_fname)

    if os.path.exists(local_fname) and not force_write:
        print("File [{}] existed!".format(local_fname))
        return local_fname

    print("Downloading file [{}] from [{}]".format(local_fname, url))
    try:
        import wget
        ret = wget.download(url, local_fname)
        print("")
    except Exception:
        # wget not installed or download failed -- fall back to urllib.
        # (Fix: narrowed the bare `except:` so KeyboardInterrupt still works.)
        ret = _single_thread_download(url, local_fname)

    if auto_unzip:
        if ret.endswith(".tar.gz") or ret.endswith(".tgz"):
            try:
                import tarfile
                # Fix: context manager closes the archive even on error.
                with tarfile.open(ret) as tar:
                    for name in tar.getnames():
                        # Reject members that would escape `directory` (path traversal).
                        if not (os.path.realpath(os.path.join(directory, name)) + os.sep).startswith(os.path.realpath(directory) + os.sep):
                            raise ValueError('The decompression path does not match the current path. For more info: https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall')
                    tar.extractall(directory)
            except ValueError:
                raise
            except Exception:
                print("Unzip file [{}] failed.".format(ret))

        elif ret.endswith('.zip'):
            try:
                import zipfile
                with zipfile.ZipFile(ret, 'r') as zip_ref:
                    for name in zip_ref.namelist():
                        # Reject members that would escape `directory` (path traversal).
                        if not (os.path.realpath(os.path.join(directory, name)) + os.sep).startswith(os.path.realpath(directory) + os.sep):
                            raise ValueError('The decompression path does not match the current path. For more info: https://docs.python.org/3/library/zipfile.html?highlight=zipfile#zipfile.ZipFile.extractall')
                    zip_ref.extractall(directory)
            except ValueError:
                raise
            except Exception:
                print("Unzip file [{}] failed.".format(ret))
    return ret
257"""
258    r = requests.head(url)
259    try:
260        file_size = int(r.headers['content-length'])
261        return _multi_thread_download(url, local_fname, file_size, 5)
262
263    except:
264        # not support multi-threads download
265        return _single_thread_download(url, local_fname)
266
267    return result
268"""
269