1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 *     http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16 
17 #include <stdlib.h>
18 
19 #include "tests/test_thread.hpp"
20 
21 #include "rnn/rnn.hpp"
22 #include "rnn/rnn_aux.hpp"
23 
24 #include "rnn/cells.hpp"
25 
26 namespace rnn {
27 template <typename T1, typename T2>
lbr_gru_fwd_postgemm_template(T1 func1,T2 func2,const prb_t & prb,float * gates_,const float * src_iter_,const float * bias_,float * dst_layer_,float * cell_scratchpad_)28 void lbr_gru_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb,
29         float *gates_, const float *src_iter_, const float *bias_,
30         float *dst_layer_, float *cell_scratchpad_) {
31     AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
32     AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
33     AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
34     AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
35     AOC<float> cell_scratchpad(
36             cell_scratchpad_, prb.mb, prb.n_gates(), prb.dhc);
37 
38     for (int64_t i = 0; i < prb.mb; i++)
39         for (int64_t j = 0; j < prb.n_gates() - 1; j++)
40             for (int64_t k = 0; k < prb.dhc; k++) {
41                 gates(i, j, k) = func1(prb.linear_scales[j],
42                         gates(i, j, k) + cell_scratchpad(i, j, k) + bias(j, k));
43             }
44 
45     for (int64_t i = 0; i < prb.mb; i++)
46         for (int64_t k = 0; k < prb.dhc; k++) {
47             gates(i, GRU_O, k) = func2(prb.linear_scales[GRU_O],
48                     gates(i, GRU_O, k)
49                             + gates(i, GRU_R, k)
50                                     * (cell_scratchpad(i, GRU_O, k)
51                                             + bias(LBR_GRU_U_PRIME, k))
52                             + bias(GRU_O, k));
53         }
54 
55     for (int64_t i = 0; i < prb.mb; i++)
56         for (int64_t k = 0; k < prb.dhc; k++) {
57             dst_layer(i, k) = gates(i, GRU_U, k) * src_iter(i, k)
58                     + (1 - gates(i, GRU_U, k)) * gates(i, GRU_O, k);
59         }
60 }
61 
lbr_gru_fwd_postgemm(const prb_t & prb,float * gates_,const float * src_iter_,const float * bias_,float * dst_layer_,float * cell_scratchpad_)62 void lbr_gru_fwd_postgemm(const prb_t &prb, float *gates_,
63         const float *src_iter_, const float *bias_, float *dst_layer_,
64         float *cell_scratchpad_) {
65     if (prb.skip_nonlinear)
66         lbr_gru_fwd_postgemm_template(
67                 [](float scale, float a) { return scale * a; },
68                 [](float scale, float a) { return scale * a; }, prb, gates_,
69                 src_iter_, bias_, dst_layer_, cell_scratchpad_);
70     else
71         lbr_gru_fwd_postgemm_template(
72                 [](float scale, float a) { return logistic(a); },
73                 [](float scale, float a) { return tanhf(a); }, prb, gates_,
74                 src_iter_, bias_, dst_layer_, cell_scratchpad_);
75 }
76 
lbr_gru_fwd(const prb_t & prb,float * dst_layer_,float * gates_,const float * weights_layer_,const float * weights_iter_,const float * bias_,const float * src_layer_,const float * src_iter_,float * cell_scratchpad_)77 void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
78         const float *weights_layer_, const float *weights_iter_,
79         const float *bias_, const float *src_layer_, const float *src_iter_,
80         float *cell_scratchpad_) {
81     gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
82             src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
83             gates_, prb.n_gates() * prb.dhc);
84 
85     gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0,
86             src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 0.0,
87             cell_scratchpad_, prb.n_gates() * prb.dhc);
88 
89     lbr_gru_fwd_postgemm(
90             prb, gates_, src_iter_, bias_, dst_layer_, cell_scratchpad_);
91 }
92 
lbr_gru_bwd_pregemm(const prb_t & prb,const float * src_iter_,const float * diff_dst_layer_,const float * diff_dst_iter_,const float * gates_,const float * Wh_b_,float * diff_src_iter_,float * b_gates_,float * b_gates_r_)93 void lbr_gru_bwd_pregemm(const prb_t &prb, const float *src_iter_,
94         const float *diff_dst_layer_, const float *diff_dst_iter_,
95         const float *gates_, const float *Wh_b_, float *diff_src_iter_,
96         float *b_gates_, float *b_gates_r_) {
97     AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
98     AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
99     AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
100     AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
101     AOC<const float> Wh_b(Wh_b_, prb.mb, prb.dhc);
102 
103     AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
104     AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
105     AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
106 
107     // do = (1 - u) * dh; do^ = one_m_square(o) * do;
108     // du = (h - o) * dh; du^ = x_m_square(u) * du;
109     // dr = (Wh + b) * do^; dr^ = x_m_square(r) * dr;
110     for (int64_t ib = 0; ib < prb.mb; ib++)
111         for (int64_t ih = 0; ih < prb.dhc; ih++) {
112             float h = src_iter(ib, ih);
113             float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
114             float u = gates(ib, GRU_U, ih);
115             float r = gates(ib, GRU_R, ih);
116             float o = gates(ib, GRU_O, ih);
117             float du = (h - o) * dh;
118             float dO = (1.0f - u) * dh;
119 
120             b_gates(ib, GRU_U, ih) = x_m_square(u) * du;
121             b_gates(ib, GRU_O, ih) = one_m_square(o) * dO;
122 
123             float dr = Wh_b(ib, ih) * b_gates(ib, GRU_O, ih);
124             b_gates(ib, GRU_R, ih) = x_m_square(r) * dr;
125 
126             b_gates_r(ib, GRU_U, ih) = b_gates(ib, GRU_U, ih);
127             b_gates_r(ib, GRU_R, ih) = b_gates(ib, GRU_R, ih);
128             b_gates_r(ib, GRU_O, ih) = b_gates(ib, GRU_O, ih) * r;
129             diff_src_iter(ib, ih) = dh * u;
130         }
131 }
132 
lbr_gru_bwd(const prb_t & prb,float * diff_src_layer_,float * diff_src_iter_,float * diff_weights_layer_,float * diff_weights_iter_,float * diff_bias_,float * b_gates_,const float * src_layer_,const float * src_iter_,const float * weights_layer_,const float * weights_iter_,const float * bias_,const float * gates_,const float * diff_dst_layer_,const float * diff_dst_iter_,float * cell_scratchpad_)133 void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_,
134         float *diff_src_iter_, float *diff_weights_layer_,
135         float *diff_weights_iter_, float *diff_bias_, float *b_gates_,
136         const float *src_layer_, const float *src_iter_,
137         const float *weights_layer_, const float *weights_iter_,
138         const float *bias_, const float *gates_, const float *diff_dst_layer_,
139         const float *diff_dst_iter_, float *cell_scratchpad_) {
140     AOC<const float> weights_iter(
141             weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
142     AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
143 
144     float *Wh_b_ = cell_scratchpad_;
145     float *b_gates_r_ = cell_scratchpad_ + prb.dhc * prb.mb;
146     AOC<float> Wh_b(Wh_b_, prb.mb, prb.dhc);
147     AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
148 
149     // TODO: save this this GEMM + bias in the fwd pass
150     for (int64_t ib = 0; ib < prb.mb; ib++)
151         for (int64_t ih = 0; ih < prb.dhc; ih++)
152             Wh_b(ib, ih) = bias(LBR_GRU_U_PRIME, ih);
153 
154     gemm("C", "N", "N", prb.mb, prb.dhc, prb.sic, 1.0, src_iter_, prb.wc,
155             &weights_iter(0, GRU_O, 0), prb.n_gates() * prb.dhc, 1.0, Wh_b_,
156             prb.dhc);
157 
158     lbr_gru_bwd_pregemm(prb, src_iter_, diff_dst_layer_, diff_dst_iter_, gates_,
159             Wh_b_, diff_src_iter_, b_gates_, b_gates_r_);
160 
161     gemm("C", "T", "N", prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
162             src_iter_, prb.wc, b_gates_r_, prb.n_gates() * prb.dhc, 1.0,
163             diff_weights_iter_, prb.n_gates() * prb.dhc);
164     gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
165             src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
166             diff_weights_layer_, prb.n_gates() * prb.dhc);
167 
168     gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
169             prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
170             0.0, diff_src_layer_, prb.wc);
171     gemm("C", "N", "T", prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0,
172             b_gates_r_, prb.n_gates() * prb.dhc, weights_iter_,
173             prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc);
174 
175     gates_reduction(prb, b_gates_, diff_bias_);
176     for (int64_t i = 0; i < prb.mb; i++)
177         for (int64_t k = 0; k < prb.dhc; k++)
178             diff_bias_[LBR_GRU_U_PRIME * prb.dhc + k] += b_gates_r(i, GRU_O, k);
179 }
180 
181 } // namespace rnn
182