#!/usr/bin/python
"""Dump the weights of a saved Keras model as C source plus a C header.

Usage::

    <this script> MODEL OUTPUT.c OUTPUT.h

Every layer that has weights is written to OUTPUT.c as ``rnn_weight``
arrays followed by a ``GRULayer`` definition (layers with more than two
weight arrays) or a ``DenseLayer`` definition (all other weighted layers).
OUTPUT.h receives the matching ``<NAME>_SIZE`` macros, ``extern``
declarations and an ``RNNState`` struct holding one float state buffer per
GRU layer. The generated code expects ``rnn.h`` to provide ``rnn_weight``,
``GRULayer``, ``DenseLayer`` and the ``ACTIVATION_*`` constants.
"""

from __future__ import print_function

import re
import sys

import numpy as np

# keras is imported inside main() / mean_squared_sqrt_error() rather than
# here, so the C-emitting helpers below can be imported (and tested)
# without a keras installation.

# Float weights are multiplied by this factor and rounded to integers.
WEIGHT_SCALE = 256
# Clamp range for quantized weights. NOTE(review): assumes rnn_weight is an
# 8-bit signed type, as the original upper clamp of 127 implies; confirm
# against rnn.h.
WEIGHT_MIN = -128
WEIGHT_MAX = 127
# Number of values emitted per line of a generated C array.
VALUES_PER_LINE = 8


def _quantize(x):
    """Scale ``x`` by WEIGHT_SCALE, round, and clamp to [WEIGHT_MIN, WEIGHT_MAX]."""
    return max(WEIGHT_MIN, min(WEIGHT_MAX, int(round(WEIGHT_SCALE * x))))


def printVector(f, vector, name):
    """Write ``vector`` to ``f`` as a ``static const rnn_weight`` C array.

    The input is flattened (C order), each value is quantized with
    ``_quantize``, and values are emitted VALUES_PER_LINE per line.

    Args:
        f: writable text stream receiving the C source.
        vector: array-like of floats, any shape.
        name: C identifier for the generated array.
    """
    values = ['{}'.format(_quantize(x)) for x in np.reshape(vector, (-1))]
    rows = [', '.join(values[start:start + VALUES_PER_LINE])
            for start in range(0, len(values), VALUES_PER_LINE)]
    f.write('static const rnn_weight {}[{}] = {{\n '.format(name, len(values)))
    # Commas separate every value; a line break follows each full row.
    f.write(',\n '.join(rows))
    f.write('\n};\n\n')


def _activation_name(activation):
    """Return the upper-cased name of an activation callable.

    Uses ``__name__`` when present. Otherwise it falls back to parsing the
    ``<function NAME at 0x...>`` repr.

    Raises:
        ValueError: if no name can be determined.
    """
    name = getattr(activation, '__name__', None)
    if name is None:
        match = re.search(r'function (.*) at', str(activation))
        if match is None:
            raise ValueError(
                'cannot determine activation name from {!r}'.format(activation))
        name = match.group(1)
    return name.upper()


def printLayer(f, hf, layer):
    """Emit the C definitions for one weighted layer.

    The kernel, optional recurrent kernel and bias are written to ``f`` as
    ``rnn_weight`` arrays, followed by the layer struct definition. ``hf``
    gets the ``<NAME>_SIZE`` macro and an ``extern`` declaration.

    NOTE: any layer with more than two weight arrays is treated as a GRU
    (kernel, recurrent kernel, bias); every other layer is a Dense layer.

    Args:
        f: writable text stream for the C source.
        hf: writable text stream for the C header.
        layer: object providing ``get_weights()``, ``name`` and
            ``activation`` (a Keras layer in normal use).
    """
    weights = layer.get_weights()
    name = layer.name
    is_gru = len(weights) > 2

    printVector(f, weights[0], name + '_weights')
    if is_gru:
        printVector(f, weights[1], name + '_recurrent_weights')
    printVector(f, weights[-1], name + '_bias')

    activation = _activation_name(layer.activation)
    inputs = weights[0].shape[0]

    if is_gru:
        # The kernel's second axis holds the three GRU gates side by side,
        # hence the division by 3. Integer division keeps the size an int;
        # true division would emit e.g. "24.0", an invalid C array size.
        units = weights[0].shape[1] // 3
        f.write('const GRULayer {} = {{\n {}_bias,\n {}_weights,\n {}_recurrent_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, name, inputs, units, activation))
        hf.write('#define {}_SIZE {}\n'.format(name.upper(), units))
        hf.write('extern const GRULayer {};\n\n'.format(name))
    else:
        units = weights[0].shape[1]
        f.write('const DenseLayer {} = {{\n {}_bias,\n {}_weights,\n {}, {}, ACTIVATION_{}\n}};\n\n'
                .format(name, name, name, inputs, units, activation))
        hf.write('#define {}_SIZE {}\n'.format(name.upper(), units))
        hf.write('extern const DenseLayer {};\n\n'.format(name))


def foo(c, name):
    """Stand-in registered for the 'WeightClip' custom object in main().

    NOTE(review): presumably the training script defined a WeightClip weight
    constraint taking ``c`` and ``name``. Only stored weights are read
    here, so this stub just lets deserialization resolve the name.
    """
    return 1


def mean_squared_sqrt_error(y_true, y_pred):
    """Mean over the last axis of ``(sqrt(y_pred) - sqrt(y_true)) ** 2``."""
    from keras import backend as K
    return K.mean(K.square(K.sqrt(y_pred) - K.sqrt(y_true)), axis=-1)


def main(argv=None):
    """Load the model named in ``argv`` and write the C source/header pair.

    Args:
        argv: ``[prog, model_path, c_path, h_path]``; defaults to ``sys.argv``.
    """
    if argv is None:
        argv = sys.argv
    if len(argv) < 4:
        sys.exit('usage: {} MODEL OUTPUT.c OUTPUT.h'.format(
            argv[0] if argv else '<script>'))

    # Imported lazily (see module note). Dense, GRU, LSTM and Sequential
    # are not referenced; they are kept from the original import list.
    from keras.layers import Dense, GRU, LSTM  # noqa: F401
    from keras.models import Sequential, load_model  # noqa: F401

    # Only weights are dumped, so the custom losses/constraints merely need
    # to resolve. NOTE(review): several loss names are mapped to the same
    # stand-in; confirm that is intended.
    model = load_model(argv[1], custom_objects={
        'msse': mean_squared_sqrt_error,
        'mean_squared_sqrt_error': mean_squared_sqrt_error,
        'my_crossentropy': mean_squared_sqrt_error,
        'mycost': mean_squared_sqrt_error,
        'WeightClip': foo,
    })

    with open(argv[2], 'w') as f, open(argv[3], 'w') as hf:
        f.write('/*This file is automatically generated from a Keras model*/\n\n')
        f.write('#ifdef HAVE_CONFIG_H\n#include "config.h"\n#endif\n\n#include "rnn.h"\n\n')

        hf.write('/*This file is automatically generated from a Keras model*/\n\n')
        hf.write('#ifndef RNN_DATA_H\n#define RNN_DATA_H\n\n#include "rnn.h"\n\n')

        gru_names = []
        for layer in model.layers:
            layer_weights = layer.get_weights()
            if not layer_weights:
                continue  # layers without parameters are not emitted
            printLayer(f, hf, layer)
            if len(layer_weights) > 2:
                gru_names.append(layer.name)

        # One recurrent state buffer per GRU layer, sized by its _SIZE macro.
        hf.write('struct RNNState {\n')
        for name in gru_names:
            hf.write(' float {}_state[{}_SIZE];\n'.format(name, name.upper()))
        hf.write('};\n')

        hf.write('\n\n#endif\n')


if __name__ == '__main__':
    main()