# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.
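"""Accelerate a convolution layer with a vertical-horizontal (VH) decomposition.

The target layer's N x C x y x x kernel is unfolded and factorised with a
truncated SVD into a vertical (y x 1) convolution with K output channels
followed by a horizontal (1 x x) convolution with N output channels, so the
pair approximates the original layer at lower cost.

Example invocation (script, model and layer names are illustrative):

  python acc_conv.py -m mymodel --load-epoch 1 --layer conv2 --K 16 \
      --save-model mymodel-vh
"""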
import numpy as np
from scipy import linalg as LA
import mxnet as mx
import argparse
import utils


def conv_vh_decomposition(model, args):
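  """Decompose the convolution layer ``args.layer`` of ``model``.

  The layer is replaced by a vertical convolution named ``<layer>_v`` and a
  horizontal convolution named ``<layer>_h``, whose weights are built from the
  ``args.K`` leading singular vectors of the unfolded kernel.  Returns the new
  model.
  """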
  W = model.arg_params[args.layer+'_weight'].asnumpy()
  N, C, y, x = W.shape
  b = model.arg_params[args.layer+'_bias'].asnumpy()
  # Unfold the kernel: rows index (input channel, kernel row),
  # columns index (output filter, kernel column).
  W = W.transpose((1,2,0,3)).reshape((C*y, -1))

  # Truncated SVD: keep the K leading singular values and split sqrt(D)
  # evenly between the vertical (V) and horizontal (H) factors.
  U, D, Q = np.linalg.svd(W, full_matrices=False)
  sqrt_D = LA.sqrtm(np.diag(D))
  K = args.K
  V = U[:,:K].dot(sqrt_D[:K, :K])
  H = Q.T[:,:K].dot(sqrt_D[:K, :K])
  # Vertical convolution: K filters of shape (C, y, 1) with zero bias.
  V = V.T.reshape(K, C, y, 1)
  b_1 = np.zeros((K, ))
  # Horizontal convolution: N filters of shape (K, 1, x) with the original bias.
  H = H.reshape(N, x, 1, K).transpose((0,3,2,1))
  b_2 = b

  W1, b1, W2, b2 = V, b_1, H, b_2
  def sym_handle(data, node):
    # Read the original layer's kernel size and padding from the symbol graph.
    kernel = eval(node['param']['kernel'])
    pad = eval(node['param']['pad'])
    name = node['name']

    # Vertical convolution: kernel (y, 1), padded only along the height.
    name1 = name + '_v'
    kernel1 = (kernel[0], 1)
    pad1 = (pad[0], 0)
    num_filter = W1.shape[0]
    sym1 = mx.symbol.Convolution(data=data, kernel=kernel1, pad=pad1, num_filter=num_filter, name=name1)

    # Horizontal convolution: kernel (1, x), padded only along the width.
    name2 = name + '_h'
    kernel2 = (1, kernel[1])
    pad2 = (0, pad[1])
    num_filter = W2.shape[0]
    sym2 = mx.symbol.Convolution(data=sym1, kernel=kernel2, pad=pad2, num_filter=num_filter, name=name2)
    return sym2
  def arg_handle(arg_shape_dic, arg_params):
    # Install the factorised weights, checking them against the shapes
    # inferred from the rewritten symbol.
    name1 = args.layer + '_v'
    name2 = args.layer + '_h'
    weight1 = mx.ndarray.array(W1)
    bias1 = mx.ndarray.array(b1)
    weight2 = mx.ndarray.array(W2)
    bias2 = mx.ndarray.array(b2)
    assert weight1.shape == arg_shape_dic[name1+'_weight'], 'weight1'
    assert weight2.shape == arg_shape_dic[name2+'_weight'], 'weight2'
    assert bias1.shape == arg_shape_dic[name1+'_bias'], 'bias1'
    assert bias2.shape == arg_shape_dic[name2+'_bias'], 'bias2'

    arg_params[name1 + '_weight'] = weight1
    arg_params[name1 + '_bias'] = bias1
    arg_params[name2 + '_weight'] = weight2
    arg_params[name2 + '_bias'] = bias2

  new_model = utils.replace_conv_layer(args.layer, model, sym_handle, arg_handle)
  return new_model


def main():
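  """Load the model, decompose the requested convolution layer, and save the result."""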
  model = utils.load_model(args)
  new_model = conv_vh_decomposition(model, args)
  new_model.save(args.save_model)


if __name__ == '__main__':
  parser = argparse.ArgumentParser(description='Accelerate a convolution layer with a VH decomposition')
  parser.add_argument('-m', '--model', help='the model to speed up')
  parser.add_argument('-g', '--gpus', default='0', help='the gpus to be used in ctx')
  parser.add_argument('--load-epoch', type=int, default=1, help='the epoch of the saved model to load')
  parser.add_argument('--layer', help='name of the convolution layer to decompose')
  parser.add_argument('--K', type=int, help='number of singular values to keep (the rank of the decomposition)')
  parser.add_argument('--save-model', help='prefix under which the decomposed model is saved')
  args = parser.parse_args()
  main()