"""Tables of Chainer functions handled by onnx_chainer.

This module also provides a context manager that temporarily replaces
selected plain-function APIs with ``fake_as_funcnode`` wrappers.
"""

from contextlib import contextmanager

import chainer.functions as F

from onnx_chainer import functions
from onnx_chainer.functions.converter import FunctionConverter
from onnx_chainer.replace_func import fake_as_funcnode


# Names of function nodes for which a converter is looked up as
# ``functions.convert_<name>`` (see ``_get_converters``). The lowercase
# entries ('sign', 'n_step_gru') match the names passed to
# ``fake_as_funcnode`` in ``patch_functions`` rather than Chainer class names.
_supported_function_node_set = {
    # Activation
    'ClippedReLU',
    'ELU',
    'HardSigmoid',
    'LeakyReLU',
    'LogSoftmax',
    'PReLUFunction',
    'ReLU',
    'Selu',
    'Sigmoid',
    'Softmax',
    'Softplus',
    'Tanh',

    # Array
    'Cast',
    'Concat',
    'Copy',
    'Depth2Space',
    'Dstack',
    'ExpandDims',
    'GetItem',
    'Hstack',
    'Moveaxis',
    'Pad',
    'Permutate',
    'Repeat',
    'Reshape',
    'ResizeImages',
    'Rollaxis',
    'SelectItem',
    'Separate',
    'Shape',
    'Space2Depth',
    'SplitAxis',
    'Squeeze',
    'Stack',
    'Swapaxes',
    'Tile',
    'Transpose',
    'TransposeSequence',
    'Vstack',
    'Where',

    # Connection
    'Convolution2DFunction',
    'ConvolutionND',
    'Deconvolution2DFunction',
    'DeconvolutionND',
    'EmbedIDFunction',
    'LinearFunction',

    # Loss
    'SoftmaxCrossEntropy',

    # Math
    'Absolute',
    'Add',
    'AddConstant',
    'Arccos',
    'Arcsin',
    'Arctan',
    'ArgMax',
    'ArgMin',
    'BroadcastTo',
    'Clip',
    'Cos',
    'Cosh',
    'Div',
    'DivFromConstant',
    'Exp',
    'Identity',
    'LinearInterpolate',
    'Log',
    'LogSumExp',
    'MatMul',
    'Max',
    'Maximum',
    'Mean',
    'Min',
    'Minimum',
    'Mul',
    'MulConstant',
    'Neg',
    'PowConstVar',
    'PowVarConst',
    'PowVarVar',
    'Prod',
    'RsqrtGPU',
    'sign',
    'Sin',
    'Sinh',
    'Sqrt',
    'Square',
    'Sub',
    'SubFromConstant',
    'Sum',
    'Tan',

    # Noise
    'Dropout',

    # Normalization
    'BatchNormalization',
    'FixedBatchNormalization',
    'GroupNormalization',
    'LocalResponseNormalization',
    'NormalizeL2',

    # Pooling
    'AveragePooling2D',
    'AveragePoolingND',
    'MaxPooling2D',
    'MaxPoolingND',
    'ROIPooling2D',
    'Unpooling2D',

    # RNN
    'n_step_gru',
}

# Lazily-built cache for ``_get_converters``; ``None`` until first call.
_converters = None


def _get_converters():
    """Return the name -> ``FunctionConverter`` mapping, building it once.

    For every name in ``_supported_function_node_set`` the converter wraps
    ``functions.convert_<name>``. When that attribute does not exist,
    ``FunctionConverter(None)`` is stored instead of raising.
    NOTE(review): how a ``None``-backed converter behaves is decided by
    ``FunctionConverter``; confirm against its implementation.

    Returns:
        dict: Module-level cached mapping. The same object is returned on
        every call.
    """
    global _converters

    if _converters is not None:
        return _converters

    _converters = {
        name: FunctionConverter(getattr(functions, 'convert_' + name, None))
        for name in _supported_function_node_set}
    return _converters


# Built eagerly at import time so callers can use ``converters`` directly.
converters = _get_converters()


# (module, attribute name) pairs that ``patch_functions`` replaces with
# ``fake_as_funcnode`` wrappers while its context is active.
_supported_function_set = {
    # Math
    (F, 'sign'),

    # RNN
    (F, 'n_step_gru'),
    (F.rnn.n_step_gru, 'n_step_gru'),
}


@contextmanager
def patch_functions():
    """Temporarily wrap each function in ``_supported_function_set``.

    Within the ``with`` block, every ``module.<name>`` listed in
    ``_supported_function_set`` is replaced with
    ``fake_as_funcnode(original, name, experimental_warning=False)``.
    On exit, including exit by exception, every attribute that was
    recorded is restored to its original object.

    Patching happens inside the ``try`` so that a failure partway through
    the loop still undoes the attributes already replaced. Only the
    recorded entries are restored, never ones that were never read.
    """
    org_funcs = {}
    try:
        for mod, name in _supported_function_set:
            org_func = getattr(mod, name)
            # Record before replacing so a failure in fake_as_funcnode or
            # setattr still leaves a correct restore entry.
            org_funcs[(mod, name)] = org_func
            setattr(mod, name, fake_as_funcnode(
                org_func, name, experimental_warning=False))
        yield
    finally:
        for (mod, name), org_func in org_funcs.items():
            setattr(mod, name, org_func)