1 /*******************************************************************************
2 * Copyright 2019-2020 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17 #include <stdlib.h>
18
19 #include "tests/test_thread.hpp"
20
21 #include "rnn/rnn.hpp"
22 #include "rnn/rnn_aux.hpp"
23
24 #include "rnn/cells.hpp"
25
26 namespace rnn {
27 template <typename T1, typename T2>
lbr_gru_fwd_postgemm_template(T1 func1,T2 func2,const prb_t & prb,float * gates_,const float * src_iter_,const float * bias_,float * dst_layer_,float * cell_scratchpad_)28 void lbr_gru_fwd_postgemm_template(T1 func1, T2 func2, const prb_t &prb,
29 float *gates_, const float *src_iter_, const float *bias_,
30 float *dst_layer_, float *cell_scratchpad_) {
31 AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
32 AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
33 AOC<float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
34 AOC<float> dst_layer(dst_layer_, prb.mb, prb.wc);
35 AOC<float> cell_scratchpad(
36 cell_scratchpad_, prb.mb, prb.n_gates(), prb.dhc);
37
38 for (int64_t i = 0; i < prb.mb; i++)
39 for (int64_t j = 0; j < prb.n_gates() - 1; j++)
40 for (int64_t k = 0; k < prb.dhc; k++) {
41 gates(i, j, k) = func1(prb.linear_scales[j],
42 gates(i, j, k) + cell_scratchpad(i, j, k) + bias(j, k));
43 }
44
45 for (int64_t i = 0; i < prb.mb; i++)
46 for (int64_t k = 0; k < prb.dhc; k++) {
47 gates(i, GRU_O, k) = func2(prb.linear_scales[GRU_O],
48 gates(i, GRU_O, k)
49 + gates(i, GRU_R, k)
50 * (cell_scratchpad(i, GRU_O, k)
51 + bias(LBR_GRU_U_PRIME, k))
52 + bias(GRU_O, k));
53 }
54
55 for (int64_t i = 0; i < prb.mb; i++)
56 for (int64_t k = 0; k < prb.dhc; k++) {
57 dst_layer(i, k) = gates(i, GRU_U, k) * src_iter(i, k)
58 + (1 - gates(i, GRU_U, k)) * gates(i, GRU_O, k);
59 }
60 }
61
lbr_gru_fwd_postgemm(const prb_t & prb,float * gates_,const float * src_iter_,const float * bias_,float * dst_layer_,float * cell_scratchpad_)62 void lbr_gru_fwd_postgemm(const prb_t &prb, float *gates_,
63 const float *src_iter_, const float *bias_, float *dst_layer_,
64 float *cell_scratchpad_) {
65 if (prb.skip_nonlinear)
66 lbr_gru_fwd_postgemm_template(
67 [](float scale, float a) { return scale * a; },
68 [](float scale, float a) { return scale * a; }, prb, gates_,
69 src_iter_, bias_, dst_layer_, cell_scratchpad_);
70 else
71 lbr_gru_fwd_postgemm_template(
72 [](float scale, float a) { return logistic(a); },
73 [](float scale, float a) { return tanhf(a); }, prb, gates_,
74 src_iter_, bias_, dst_layer_, cell_scratchpad_);
75 }
76
lbr_gru_fwd(const prb_t & prb,float * dst_layer_,float * gates_,const float * weights_layer_,const float * weights_iter_,const float * bias_,const float * src_layer_,const float * src_iter_,float * cell_scratchpad_)77 void lbr_gru_fwd(const prb_t &prb, float *dst_layer_, float *gates_,
78 const float *weights_layer_, const float *weights_iter_,
79 const float *bias_, const float *src_layer_, const float *src_iter_,
80 float *cell_scratchpad_) {
81 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.slc, 1.0,
82 src_layer_, prb.wc, weights_layer_, prb.n_gates() * prb.dhc, 0.0,
83 gates_, prb.n_gates() * prb.dhc);
84
85 gemm("C", "N", "N", prb.mb, prb.n_gates() * prb.dhc, prb.sic, 1.0,
86 src_iter_, prb.wc, weights_iter_, prb.n_gates() * prb.dhc, 0.0,
87 cell_scratchpad_, prb.n_gates() * prb.dhc);
88
89 lbr_gru_fwd_postgemm(
90 prb, gates_, src_iter_, bias_, dst_layer_, cell_scratchpad_);
91 }
92
lbr_gru_bwd_pregemm(const prb_t & prb,const float * src_iter_,const float * diff_dst_layer_,const float * diff_dst_iter_,const float * gates_,const float * Wh_b_,float * diff_src_iter_,float * b_gates_,float * b_gates_r_)93 void lbr_gru_bwd_pregemm(const prb_t &prb, const float *src_iter_,
94 const float *diff_dst_layer_, const float *diff_dst_iter_,
95 const float *gates_, const float *Wh_b_, float *diff_src_iter_,
96 float *b_gates_, float *b_gates_r_) {
97 AOC<const float> src_iter(src_iter_, prb.mb, prb.wc);
98 AOC<const float> diff_dst_layer(diff_dst_layer_, prb.mb, prb.wc);
99 AOC<const float> diff_dst_iter(diff_dst_iter_, prb.mb, prb.wc);
100 AOC<const float> gates(gates_, prb.mb, prb.n_gates(), prb.dhc);
101 AOC<const float> Wh_b(Wh_b_, prb.mb, prb.dhc);
102
103 AOC<float> diff_src_iter(diff_src_iter_, prb.mb, prb.wc);
104 AOC<float> b_gates(b_gates_, prb.mb, prb.n_gates(), prb.dhc);
105 AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
106
107 // do = (1 - u) * dh; do^ = one_m_square(o) * do;
108 // du = (h - o) * dh; du^ = x_m_square(u) * du;
109 // dr = (Wh + b) * do^; dr^ = x_m_square(r) * dr;
110 for (int64_t ib = 0; ib < prb.mb; ib++)
111 for (int64_t ih = 0; ih < prb.dhc; ih++) {
112 float h = src_iter(ib, ih);
113 float dh = diff_dst_layer(ib, ih) + diff_dst_iter(ib, ih);
114 float u = gates(ib, GRU_U, ih);
115 float r = gates(ib, GRU_R, ih);
116 float o = gates(ib, GRU_O, ih);
117 float du = (h - o) * dh;
118 float dO = (1.0f - u) * dh;
119
120 b_gates(ib, GRU_U, ih) = x_m_square(u) * du;
121 b_gates(ib, GRU_O, ih) = one_m_square(o) * dO;
122
123 float dr = Wh_b(ib, ih) * b_gates(ib, GRU_O, ih);
124 b_gates(ib, GRU_R, ih) = x_m_square(r) * dr;
125
126 b_gates_r(ib, GRU_U, ih) = b_gates(ib, GRU_U, ih);
127 b_gates_r(ib, GRU_R, ih) = b_gates(ib, GRU_R, ih);
128 b_gates_r(ib, GRU_O, ih) = b_gates(ib, GRU_O, ih) * r;
129 diff_src_iter(ib, ih) = dh * u;
130 }
131 }
132
lbr_gru_bwd(const prb_t & prb,float * diff_src_layer_,float * diff_src_iter_,float * diff_weights_layer_,float * diff_weights_iter_,float * diff_bias_,float * b_gates_,const float * src_layer_,const float * src_iter_,const float * weights_layer_,const float * weights_iter_,const float * bias_,const float * gates_,const float * diff_dst_layer_,const float * diff_dst_iter_,float * cell_scratchpad_)133 void lbr_gru_bwd(const prb_t &prb, float *diff_src_layer_,
134 float *diff_src_iter_, float *diff_weights_layer_,
135 float *diff_weights_iter_, float *diff_bias_, float *b_gates_,
136 const float *src_layer_, const float *src_iter_,
137 const float *weights_layer_, const float *weights_iter_,
138 const float *bias_, const float *gates_, const float *diff_dst_layer_,
139 const float *diff_dst_iter_, float *cell_scratchpad_) {
140 AOC<const float> weights_iter(
141 weights_iter_, prb.sic, prb.n_gates(), prb.dhc);
142 AOC<const float> bias(bias_, prb.n_gates() + 1, prb.dhc);
143
144 float *Wh_b_ = cell_scratchpad_;
145 float *b_gates_r_ = cell_scratchpad_ + prb.dhc * prb.mb;
146 AOC<float> Wh_b(Wh_b_, prb.mb, prb.dhc);
147 AOC<float> b_gates_r(b_gates_r_, prb.mb, prb.n_gates(), prb.dhc);
148
149 // TODO: save this this GEMM + bias in the fwd pass
150 for (int64_t ib = 0; ib < prb.mb; ib++)
151 for (int64_t ih = 0; ih < prb.dhc; ih++)
152 Wh_b(ib, ih) = bias(LBR_GRU_U_PRIME, ih);
153
154 gemm("C", "N", "N", prb.mb, prb.dhc, prb.sic, 1.0, src_iter_, prb.wc,
155 &weights_iter(0, GRU_O, 0), prb.n_gates() * prb.dhc, 1.0, Wh_b_,
156 prb.dhc);
157
158 lbr_gru_bwd_pregemm(prb, src_iter_, diff_dst_layer_, diff_dst_iter_, gates_,
159 Wh_b_, diff_src_iter_, b_gates_, b_gates_r_);
160
161 gemm("C", "T", "N", prb.sic, prb.n_gates() * prb.dhc, prb.mb, 1.0,
162 src_iter_, prb.wc, b_gates_r_, prb.n_gates() * prb.dhc, 1.0,
163 diff_weights_iter_, prb.n_gates() * prb.dhc);
164 gemm("C", "T", "N", prb.slc, prb.n_gates() * prb.dhc, prb.mb, 1.0,
165 src_layer_, prb.wc, b_gates_, prb.n_gates() * prb.dhc, 1.0,
166 diff_weights_layer_, prb.n_gates() * prb.dhc);
167
168 gemm("C", "N", "T", prb.mb, prb.slc, prb.n_gates() * prb.dhc, 1.0, b_gates_,
169 prb.n_gates() * prb.dhc, weights_layer_, prb.n_gates() * prb.dhc,
170 0.0, diff_src_layer_, prb.wc);
171 gemm("C", "N", "T", prb.mb, prb.sic, prb.n_gates() * prb.dhc, 1.0,
172 b_gates_r_, prb.n_gates() * prb.dhc, weights_iter_,
173 prb.n_gates() * prb.dhc, 1.0, diff_src_iter_, prb.wc);
174
175 gates_reduction(prb, b_gates_, diff_bias_);
176 for (int64_t i = 0; i < prb.mb; i++)
177 for (int64_t k = 0; k < prb.dhc; k++)
178 diff_bias_[LBR_GRU_U_PRIME * prb.dhc + k] += b_gates_r(i, GRU_O, k);
179 }
180
181 } // namespace rnn
182