1#!/usr/bin/python
2
3from __future__ import print_function
4
5from keras.models import Sequential
6from keras.layers import Dense
7from keras.layers import LSTM
8from keras.layers import GRU
9from keras.models import load_model
10from keras import backend as K
11import sys
12import re
13import numpy as np
14
def printVector(f, vector, name):
    """Write *vector* to *f* as a C ``static const rnn_weight`` array.

    Every element is scaled by 256, rounded, and clamped to [-128, 127]
    before being written. Values are laid out eight per line, separated by
    ", ", inside ``{ ... };`` and followed by a blank line.

    Args:
        f: writable text file object that receives the generated C source.
        vector: array-like of weights; flattened with ``np.reshape(vector, -1)``.
        name: C identifier used for the emitted array.
    """
    v = np.reshape(vector, (-1))
    # Fixed-point quantization with a scale of 256. The original exporter
    # clamped only the upper end (127), so weights below -0.5 produced values
    # under -128. These do not fit the 8-bit range the 127 cap implies for
    # rnn_weight (presumably signed char; confirm against rnn.h).
    quantized = [str(max(-128, min(127, int(round(256 * x))))) for x in v]
    # Group into rows of eight to keep the original line layout.
    rows = [', '.join(quantized[start:start + 8])
            for start in range(0, len(quantized), 8)]
    f.write('static const rnn_weight {}[{}] = {{\n   '.format(name, len(v)))
    f.write(',\n   '.join(rows))
    f.write('\n};\n\n')
32
33def printLayer(f, hf, layer):
34    weights = layer.get_weights()
35    printVector(f, weights[0], layer.name + '_weights')
36    if len(weights) > 2:
37        printVector(f, weights[1], layer.name + '_recurrent_weights')
38    printVector(f, weights[-1], layer.name + '_bias')
39    name = layer.name
40    activation = re.search('function (.*) at', str(layer.activation)).group(1).upper()
41    if len(weights) > 2:
42        f.write('const GRULayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}_recurrent_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
43                .format(name, name, name, name, weights[0].shape[0], weights[0].shape[1]/3, activation))
44        hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]/3))
45        hf.write('extern const GRULayer {};\n\n'.format(name));
46    else:
47        f.write('const DenseLayer {} = {{\n   {}_bias,\n   {}_weights,\n   {}, {}, ACTIVATION_{}\n}};\n\n'
48                .format(name, name, name, weights[0].shape[0], weights[0].shape[1], activation))
49        hf.write('#define {}_SIZE {}\n'.format(name.upper(), weights[0].shape[1]))
50        hf.write('extern const DenseLayer {};\n\n'.format(name));
51
52
def foo(c, name):
    """Placeholder factory for the custom ``WeightClip`` object.

    Registered as ``'WeightClip'`` in the ``custom_objects`` passed to
    ``load_model``, so deserialization can succeed without the training-time
    class. Both arguments are ignored; the constant 1 is returned. Presumably
    Keras passes the saved constraint config (``c``, ``name``); confirm.
    """
    placeholder = 1
    return placeholder
55
def mean_squared_sqrt_error(y_true, y_pred):
    """Mean over the last axis of the squared difference of square roots.

    Computes ``mean((sqrt(y_pred) - sqrt(y_true)) ** 2, axis=-1)`` with the
    Keras backend ``K``. This script registers it under several loss names
    so the saved model can be loaded.
    """
    root_diff = K.sqrt(y_pred) - K.sqrt(y_true)
    squared = K.square(root_diff)
    return K.mean(squared, axis=-1)
58
59
# Usage: dump_rnn <model file> <output .c> <output .h>
if len(sys.argv) != 4:
    sys.exit('usage: {} <model> <output.c> <output.h>'.format(sys.argv[0]))

# The custom losses/constraints used at training time are mapped to
# placeholders: only the weights are needed for export.
model = load_model(sys.argv[1], custom_objects={'msse': mean_squared_sqrt_error, 'mean_squared_sqrt_error': mean_squared_sqrt_error, 'my_crossentropy': mean_squared_sqrt_error, 'mycost': mean_squared_sqrt_error, 'WeightClip': foo})

# Context managers guarantee both outputs are closed even if a layer fails
# to export part-way through.
with open(sys.argv[2], 'w') as f, open(sys.argv[3], 'w') as hf:
    f.write('/*This file is automatically generated from a Keras model*/\n\n')
    f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')

    hf.write('/*This file is automatically generated from a Keras model*/\n\n')
    hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n')

    # Names of recurrent (3+ weight array) layers, which need a state buffer.
    layer_list = []
    for layer in model.layers:
        layer_weights = layer.get_weights()
        if len(layer_weights) > 0:
            printLayer(f, hf, layer)
        if len(layer_weights) > 2:
            layer_list.append(layer.name)

    hf.write('struct RNNState {\n')
    for name in layer_list:
        hf.write('  float {}_state[{}_SIZE];\n'.format(name, name.upper()))
    hf.write('};\n')

    hf.write('\n\n#endif\n')
89