#----------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for license information.
#----------------------------------------------------------------------------------------------

from __future__ import division
import os
import sys
import numpy as np
from six import text_type, binary_type, integer_types
import mmdnn.conversion.common.IR.graph_pb2 as graph_pb2


__all__ = ["assign_IRnode_values", "convert_onnx_pad_to_tf", 'convert_tf_pad_to_onnx',
           'compute_tf_same_padding', 'is_valid_padding', 'download_file',
           'shape_to_list', 'list_to_shape']


def assign_attr_value(attr, val):
    """Assign value to AttrValue proto according to data type.

    Args:
        attr: the AttrValue message to fill in.
        val: a bool, integer, float, byte/text string, TensorShape,
            numpy array, or a list whose first element is an integer,
            float or TensorShape. An empty list leaves `attr` untouched.

    Raises:
        NotImplementedError: if `val` (or the first element of a list)
            has a type that is not handled here.
    """
    TensorShape = graph_pb2.TensorShape

    # bool must be tested before integers: bool is a subclass of int.
    if isinstance(val, bool):
        attr.b = val
    elif isinstance(val, integer_types):
        attr.i = val
    elif isinstance(val, float):
        attr.f = val
    elif isinstance(val, (binary_type, text_type)):
        # Only text needs encoding; byte strings are stored as they are.
        if isinstance(val, text_type):
            val = val.encode()
        attr.s = val
    elif isinstance(val, TensorShape):
        attr.shape.MergeFromString(val.SerializeToString())
    elif isinstance(val, list):
        if not val:
            return
        # The first element decides which repeated field receives the list.
        first = val[0]
        if isinstance(first, integer_types):
            attr.list.i.extend(val)
        elif isinstance(first, TensorShape):
            attr.list.shape.extend(val)
        elif isinstance(first, float):
            attr.list.f.extend(val)
        else:
            raise NotImplementedError('AttrValue cannot be of list[{}].'.format(type(first)))
    elif isinstance(val, np.ndarray):
        assign_attr_value(attr, val.tolist())
    else:
        raise NotImplementedError('AttrValue cannot be of %s' % type(val))


def assign_IRnode_values(IR_node, val_dict):
    """Store every `name -> value` pair of `val_dict` in `IR_node.attr`."""
    for name, val in val_dict.items():
        assign_attr_value(IR_node.attr[name], val)


# For padding
def convert_tf_pad_to_onnx(pads):
    """Flatten `pads` and regroup it as all even-index entries, then all odd ones.

    For pads given as [[b1, e1], [b2, e2], ...] the result is
    [b1, b2, ..., e1, e2, ...].
    """
    flat = np.reshape(pads, -1).tolist()
    assert len(flat) % 2 == 0
    return flat[0::2] + flat[1::2]


def convert_onnx_pad_to_tf(pads):
    """Turn [b1, b2, ..., e1, e2, ...] into [[b1, e1], [b2, e2], ...]."""
    return np.array(pads).reshape([2, -1]).T.tolist()


def is_valid_padding(pads):
    """Return whether the entries of `pads` sum to zero."""
    return sum(np.reshape(pads, -1)) == 0


def shape_to_list(shape):
    """Return the `size` of every entry of `shape.dim` as a list."""
    return [dim.size for dim in shape.dim]


def list_to_shape(shape):
    """Build a TensorShape message with one dim per entry of `shape`."""
    ret = graph_pb2.TensorShape()
    for dim in shape:
        new_dim = ret.dim.add()
        new_dim.size = dim
    return ret


def compute_tf_same_padding(input_shape, kernel_shape, strides, data_format='NHWC'):
    """ Convert [SAME] padding in tensorflow, keras to onnx pads,
        i.e. [x1_begin, x2_begin...x1_end, x2_end,...]

    Args:
        input_shape: full input shape, including the non-spatial axes.
        kernel_shape: kernel size, one entry per spatial axis.
        strides: strides, either per spatial axis or for the full shape.
        data_format: a format starting with 'NC' selects the
            channels-first layout; anything else is handled as NHWC-like.

    Returns:
        A flat list of begin pads followed by end pads, with zeros in the
        positions of the two non-spatial axes.
    """
    channels_first = data_format.startswith('NC')

    if channels_first:
        # Not tested
        input_shape = input_shape[2:]
        remove_dim = len(strides) - len(input_shape)
        if remove_dim > 0:
            # Drop the leading (non-spatial) stride entries.
            strides = strides[remove_dim:]
    else:
        input_shape = input_shape[1:-1]
        remove_dim = len(input_shape) - len(strides) + 1
        if remove_dim < 0:
            # Drop the leading entry and the trailing non-spatial entries.
            strides = strides[1:remove_dim]

    # Both lists start with the zero pad of the first axis.
    up_list = [0]
    down_list = [0]

    # NOTE: dilation is not taken into account here.
    for idx, in_size in enumerate(input_shape):
        stride = strides[idx]
        output_shape = (in_size + stride - 1) // stride  # ceil(in_size / stride)
        this_padding = max(0, (output_shape - 1) * stride + kernel_shape[idx] - in_size)
        # When the total is odd, the extra cell goes to the end side.
        up_list.append(this_padding // 2)
        down_list.append(this_padding - this_padding // 2)

    if channels_first:
        return [0] + up_list + [0] + down_list
    return up_list + [0] + down_list + [0]


# network library
def sizeof_fmt(num, suffix='B'):
    """Format `num` with a binary (1024-based) unit prefix, e.g. '1.5 KiB'."""
    for unit in ['', 'Ki', 'Mi', 'Gi', 'Ti', 'Pi', 'Ei', 'Zi']:
        if abs(num) < 1024.0:
            return "%3.1f %s%s" % (num, unit, suffix)
        num /= 1024.0
    return "%.1f %s%s" % (num, 'Yi', suffix)


def _progress_check(count, block_size, total_size):
    """Progress callback for `urlretrieve`: rewrite a one-line status on stdout."""
    read_size = count * block_size
    read_size_str = sizeof_fmt(read_size)
    if total_size > 0:
        # The last block may overshoot the real size, so cap at 100.
        percent = min(int(read_size * 100 / total_size), 100)
        sys.stdout.write("\rprogress: {} downloaded, {}%.".format(read_size_str, percent))
        if read_size >= total_size:
            sys.stdout.write("\n")
    else:
        sys.stdout.write("\rprogress: {} downloaded.".format(read_size_str))
    sys.stdout.flush()


def _single_thread_download(url, file_name):
    """Download `url` to `file_name` with urlretrieve and return the local path."""
    from six.moves import urllib
    result, _ = urllib.request.urlretrieve(url, file_name, _progress_check)
    return result


def _downloader(start, end, url, filename):
    """Fetch bytes `start`..`end` (inclusive) of `url` and write them at offset `start`.

    `filename` must already exist; it is opened for in-place update.
    """
    import requests
    headers = {'Range': 'bytes=%d-%d' % (start, end)}
    r = requests.get(url, headers=headers, stream=True)
    # Do not write an HTTP error page into the target file.
    r.raise_for_status()
    with open(filename, "r+b") as fp:
        fp.seek(start)
        fp.write(r.content)


def _multi_thread_download(url, file_name, file_size, thread_count):
    """Download `url` into `file_name` using `thread_count` ranged requests.

    Returns:
        `file_name`.
    """
    import threading

    # Pre-size the file so that every worker can seek to its own offset.
    with open(file_name, "wb") as fp:
        fp.truncate(file_size)

    part = file_size // thread_count
    workers = []
    for i in range(thread_count):
        start = part * i
        # HTTP byte ranges are inclusive; the last worker takes the remainder.
        end = file_size - 1 if i == thread_count - 1 else start + part - 1
        if end < start:
            # Nothing to fetch for this slot (file smaller than thread_count).
            continue

        t = threading.Thread(target=_downloader,
                             kwargs={'start': start, 'end': end, 'url': url, 'filename': file_name})
        t.daemon = True
        t.start()
        workers.append(t)

    # Wait only for the threads started here, not for unrelated ones.
    for t in workers:
        t.join()

    return file_name


def _ensure_members_inside(directory, member_names, error_message):
    """Raise ValueError(error_message) if a member path resolves outside `directory`."""
    root = os.path.realpath(directory) + os.sep
    for name in member_names:
        target = os.path.realpath(os.path.join(directory, name)) + os.sep
        if not target.startswith(root):
            raise ValueError(error_message)


def download_file(url, directory='./', local_fname=None, force_write=False, auto_unzip=False, compre_type=''):
    """Download the data from source url, unless it's already here.

    Args:
        url: url to download from.
        directory: string, path to the directory the file is stored in;
            created if it does not exist.
        local_fname: string, name of the file inside `directory`. Defaults
            to the part of `url` after its last '/'.
        force_write: download even if the local file already exists.
        auto_unzip: after downloading, extract `.tar.gz`, `.tgz` and `.zip`
            files into `directory`.
        compre_type: not used by this function.

    Returns:
        Path to resulting file.

    Raises:
        ValueError: if an archive member would be extracted outside
            `directory` (other ValueErrors raised during extraction
            propagate as well).
    """
    if not os.path.isdir(directory):
        # makedirs also creates missing parent directories.
        os.makedirs(directory)

    if not local_fname:
        local_fname = url.rsplit('/', 1)[-1]

    local_fname = os.path.join(directory, local_fname)

    if os.path.exists(local_fname) and not force_write:
        print ("File [{}] existed!".format(local_fname))
        return local_fname

    print ("Downloading file [{}] from [{}]".format(local_fname, url))
    try:
        import wget
        ret = wget.download(url, local_fname)
        print ("")
    except Exception:
        # wget missing or failed: fall back to urlretrieve. Not a bare
        # except, so Ctrl-C / SystemExit still abort the download.
        ret = _single_thread_download(url, local_fname)

    if auto_unzip:
        if ret.endswith((".tar.gz", ".tgz")):
            try:
                import tarfile
                with tarfile.open(ret) as tar:
                    _ensure_members_inside(
                        directory, tar.getnames(),
                        'The decompression path does not match the current path. For more info: https://docs.python.org/3/library/tarfile.html#tarfile.TarFile.extractall')
                    tar.extractall(directory)
            except ValueError:
                raise
            except Exception:
                # Extraction is best effort; the downloaded file is still returned.
                print("Unzip file [{}] failed.".format(ret))

        elif ret.endswith('.zip'):
            try:
                import zipfile
                with zipfile.ZipFile(ret, 'r') as zip_ref:
                    _ensure_members_inside(
                        directory, zip_ref.namelist(),
                        'The decompression path does not match the current path. For more info: https://docs.python.org/3/library/zipfile.html?highlight=zipfile#zipfile.ZipFile.extractall')
                    zip_ref.extractall(directory)
            except ValueError:
                raise
            except Exception:
                # Extraction is best effort; the downloaded file is still returned.
                print("Unzip file [{}] failed.".format(ret))
    return ret


# NOTE: download_file does not call the multi-threaded helpers above.
# A disabled draft of that path is kept here for reference:
#
#     r = requests.head(url)
#     try:
#         file_size = int(r.headers['content-length'])
#         return _multi_thread_download(url, local_fname, file_size, 5)
#
#     except:
#         # not support multi-threads download
#         return _single_thread_download(url, local_fname)
#
#     return result