1from contextlib import contextmanager
2
3import chainer.functions as F
4
5from onnx_chainer import functions
6from onnx_chainer.functions.converter import FunctionConverter
7from onnx_chainer.replace_func import fake_as_funcnode
8
9
# Names of the function nodes that have an ONNX converter registered.
# ``_get_converters`` resolves each entry to ``functions.convert_<name>``.
# The lower-case entries ('sign', 'n_step_gru') match the attribute names
# wrapped by ``patch_functions`` (presumably the node name the fake wrapper
# reports -- TODO confirm against ``fake_as_funcnode``).
_supported_function_node_set = {
    node_name
    for category in (
        (  # Activation
            'ClippedReLU', 'ELU', 'HardSigmoid', 'LeakyReLU', 'LogSoftmax',
            'PReLUFunction', 'ReLU', 'Selu', 'Sigmoid', 'Softmax',
            'Softplus', 'Tanh',
        ),
        (  # Array
            'Cast', 'Concat', 'Copy', 'Depth2Space', 'Dstack', 'ExpandDims',
            'GetItem', 'Hstack', 'Moveaxis', 'Pad', 'Permutate', 'Repeat',
            'Reshape', 'ResizeImages', 'Rollaxis', 'SelectItem', 'Separate',
            'Shape', 'Space2Depth', 'SplitAxis', 'Squeeze', 'Stack',
            'Swapaxes', 'Tile', 'Transpose', 'TransposeSequence', 'Vstack',
            'Where',
        ),
        (  # Connection
            'Convolution2DFunction', 'ConvolutionND',
            'Deconvolution2DFunction', 'DeconvolutionND', 'EmbedIDFunction',
            'LinearFunction',
        ),
        (  # Loss
            'SoftmaxCrossEntropy',
        ),
        (  # Math
            'Absolute', 'Add', 'AddConstant', 'Arccos', 'Arcsin', 'Arctan',
            'ArgMax', 'ArgMin', 'BroadcastTo', 'Clip', 'Cos', 'Cosh', 'Div',
            'DivFromConstant', 'Exp', 'Identity', 'LinearInterpolate', 'Log',
            'LogSumExp', 'MatMul', 'Max', 'Maximum', 'Mean', 'Min',
            'Minimum', 'Mul', 'MulConstant', 'Neg', 'PowConstVar',
            'PowVarConst', 'PowVarVar', 'Prod', 'RsqrtGPU', 'sign', 'Sin',
            'Sinh', 'Sqrt', 'Square', 'Sub', 'SubFromConstant', 'Sum', 'Tan',
        ),
        (  # Noise
            'Dropout',
        ),
        (  # Normalization
            'BatchNormalization',
            'FixedBatchNormalization',
            'GroupNormalization',
            'LocalResponseNormalization',
            'NormalizeL2',
        ),
        (  # Pooling
            'AveragePooling2D', 'AveragePoolingND', 'MaxPooling2D',
            'MaxPoolingND', 'ROIPooling2D', 'Unpooling2D',
        ),
        (  # RNN
            'n_step_gru',
        ),
    )
    for node_name in category
}
131
# Lazily-built cache for ``_get_converters``; ``None`` until the first call.
_converters = None


def _get_converters():
    """Return the dict mapping supported node names to ``FunctionConverter``s.

    The dict is built on the first call from ``_supported_function_node_set``.
    Each name ``X`` is wrapped as ``FunctionConverter(functions.convert_X)``.
    The result is stored in the module-level ``_converters`` global, and later
    calls return that same dict object.
    """
    global _converters

    if _converters is None:
        table = {}
        for node_name in _supported_function_node_set:
            # NOTE(review): a missing ``convert_<name>`` is not an error here;
            # it silently becomes ``FunctionConverter(None)``.
            convert_fn = getattr(functions, 'convert_' + node_name, None)
            table[node_name] = FunctionConverter(convert_fn)
        _converters = table
    return _converters


# Public registry, built eagerly at import time.
converters = _get_converters()
148
149
# ``(owner, attribute_name)`` pairs that ``patch_functions`` temporarily
# replaces with ``fake_as_funcnode`` wrappers. Each function name is listed
# with every owner on which it is patched. ``n_step_gru`` is patched both on
# ``chainer.functions`` and on ``chainer.functions.rnn.n_step_gru``,
# presumably so that lookups through either path see the wrapper
# (TODO confirm).
_supported_function_set = {
    (owner, func_name)
    for func_name, owners in (
        ('sign', (F,)),  # Math
        ('n_step_gru', (F, F.rnn.n_step_gru)),  # RNN
    )
    for owner in owners
}
158
159
@contextmanager
def patch_functions():
    """Temporarily wrap selected Chainer functions with ``fake_as_funcnode``.

    For every ``(mod, name)`` pair in ``_supported_function_set``, the
    attribute ``mod.name`` is replaced for the duration of the ``with`` block
    by ``fake_as_funcnode(original, name, experimental_warning=False)``.

    On exit, every attribute that was actually replaced is restored to its
    original object. This also happens when the ``with`` body raises, and
    when patching itself fails part-way through. In that case only the pairs
    patched so far are restored.

    Yields:
        None.
    """
    org_funcs = {}
    try:
        # Patch inside the ``try`` so that a failure while wrapping a later
        # entry still rolls back the entries already replaced.
        for mod, name in _supported_function_set:
            org_func = getattr(mod, name)
            # Record the original before replacing it; restoring an entry
            # whose setattr never happened just re-assigns the same object.
            org_funcs[(mod, name)] = org_func
            setattr(mod, name, fake_as_funcnode(
                org_func, name, experimental_warning=False))
        yield
    finally:
        # Restore only what was recorded, never the whole configured set, so
        # an early patching failure cannot trigger a KeyError here.
        for (mod, name), org_func in org_funcs.items():
            setattr(mod, name, org_func)
173