1"""Example of using Hierarchical RNN (HRNN) to classify MNIST digits.
2
3HRNNs can learn across multiple levels
4of temporal hierarchy over a complex sequence.
5Usually, the first recurrent layer of an HRNN
6encodes a sentence (e.g. of word vectors)
7into a  sentence vector.
8The second recurrent layer then encodes a sequence of
9such vectors (encoded by the first layer) into a document vector.
10This document vector is considered to preserve both
11the word-level and sentence-level structure of the context.
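
For instance, a minimal sketch of this sentence -> document hierarchy in the
same Keras functional API (max_sents, max_words and word_dim are hypothetical
dimensions, not defined in this script):

    sents = Input(shape=(max_sents, max_words, word_dim))
    sent_vecs = TimeDistributed(LSTM(128))(sents)  # one vector per sentence
    doc_vec = LSTM(128)(sent_vecs)                 # one vector per document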

# References

- [A Hierarchical Neural Autoencoder for Paragraphs and Documents]
    (https://arxiv.org/abs/1506.01057)
    Encodes paragraphs and documents with an HRNN.
    Results have shown that an HRNN outperforms standard
    RNNs and may play some role in more sophisticated generation tasks like
    summarization or question answering.
- [Hierarchical recurrent neural network for skeleton based action recognition]
    (http://ieeexplore.ieee.org/stamp/stamp.jsp?tp=&arnumber=7298714)
    Achieved state-of-the-art results on
    skeleton-based action recognition with 3 levels
    of bidirectional HRNN combined with fully connected layers.

In the MNIST example below, the first LSTM layer encodes every
row of pixels of shape (28, 1) into a row vector of shape (128,).
The second LSTM layer then encodes these 28 row vectors of shape (28, 128)
into an image vector representing the whole image.
A final Dense layer is added for prediction.
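
Shape flow through the model (B = batch size):
    (B, 28, 28, 1) -> TimeDistributed(LSTM(128)) -> (B, 28, 128)
    -> LSTM(128) -> (B, 128) -> Dense(10, softmax) -> (B, 10)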

After 5 epochs: train acc: 0.9858, val acc: 0.9864
"""
from __future__ import print_function

import keras
from keras.datasets import mnist
from keras.models import Model
from keras.layers import Input, Dense, TimeDistributed
from keras.layers import LSTM

# Training parameters.
batch_size = 32
num_classes = 10
epochs = 5

# Hidden state sizes for the row and column encoders.
row_hidden = 128
col_hidden = 128

# The data, split between train and test sets.
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Reshapes data to 4D for Hierarchical RNN.
x_train = x_train.reshape(x_train.shape[0], 28, 28, 1)
x_test = x_test.reshape(x_test.shape[0], 28, 28, 1)
x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
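# Scales pixel values from [0, 255] to [0, 1].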
x_train /= 255
x_test /= 255
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

# Converts class vectors to binary class matrices.
y_train = keras.utils.to_categorical(y_train, num_classes)
y_test = keras.utils.to_categorical(y_test, num_classes)
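# Each label is now a one-hot vector, e.g. 3 -> [0, 0, 0, 1, 0, 0, 0, 0, 0, 0].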

row, col, pixel = x_train.shape[1:]

# 4D input.
x = Input(shape=(row, col, pixel))

# Encodes each row of pixels using a TimeDistributed wrapper.
encoded_rows = TimeDistributed(LSTM(row_hidden))(x)
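# encoded_rows has shape (batch, 28, 128): one 128-d vector per pixel row.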

# Encodes columns of encoded rows.
encoded_columns = LSTM(col_hidden)(encoded_rows)
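# encoded_columns has shape (batch, 128): one vector for the whole image.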

# Final predictions and model.
prediction = Dense(num_classes, activation='softmax')(encoded_columns)
model = Model(x, prediction)
model.compile(loss='categorical_crossentropy',
              optimizer='rmsprop',
              metrics=['accuracy'])

# Training.
model.fit(x_train, y_train,
          batch_size=batch_size,
          epochs=epochs,
          verbose=1,
          validation_data=(x_test, y_test))

# Evaluation.
scores = model.evaluate(x_test, y_test, verbose=0)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])
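
# Illustrative inference sketch: classify one held-out digit with the trained
# model. numpy is imported here only for argmax; everything else is defined above.
import numpy as np

pred = model.predict(x_test[:1], verbose=0)
print('Predicted digit:', np.argmax(pred, axis=-1)[0])
print('True digit:', np.argmax(y_test[0]))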