#include "Halide.h"

namespace {

using namespace Halide;
using namespace Halide::BoundaryConditions;

class DepthwiseSeparableConvolution : public Generator<DepthwiseSeparableConvolution> {
public:
    // [in_channels, width, height, batch_size]
    Input<Buffer<float>> input{"input", 4};

    // [channel_multiplier, in_channels, filter_width, filter_height]
    Input<Buffer<float>> depthwise_filter{"depthwise_filter", 4};

    // [out_channels, channel_multiplier * in_channels]
    Input<Buffer<float>> pointwise_filter{"pointwise_filter", 2};

    // [out_channels]
    Input<Buffer<float>> bias{"bias", 1};

    // [out_channels, width, height, batch_size]
    Output<Buffer<float>> output{"output", 4};

    void generate() {
        // The algorithm. It will be a generic depthwise convolution,
        // with no assumptions about input sizes or shapes. This makes
        // it especially challenging to schedule.

        // Some free variables, where x and y represent the spatial dimensions.
        Var x("x"), y("y"), d("d"), b("b");

        // Pad x and y with 0. Unfortunately the built-in boundary
        // condition helpers cause unwanted loop partitioning.
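        // (They tag the in-bounds case with likely() hints, which is what
        // drives the partitioning; a plain select does not.)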
        Func input_bounded;
        Expr in_bounds = (x >= 0 && x < input.dim(1).extent() &&
                          y >= 0 && y < input.dim(2).extent());
        Expr clamped_x = clamp(x, 0, input.dim(1).max());
        Expr clamped_y = clamp(y, 0, input.dim(2).max());
        input_bounded(d, x, y, b) =
            select(in_bounds, input(d, clamped_x, clamped_y, b), 0.0f);

        Expr channel_multiplier = depthwise_filter.dim(0).extent();

        // Convolve the image depthwise -- for each input channel,
        // generate channel_multiplier intermediate channels using convolution.
        Func depthwise_convolved("depthwise_convolved");
        Expr pad_width = depthwise_filter.dim(2).extent() / 2;
        Expr pad_height = depthwise_filter.dim(3).extent() / 2;
        RDom depthwise_filter_dom(0, depthwise_filter.dim(0).extent(),
                                  0, depthwise_filter.dim(2).extent(),
                                  0, depthwise_filter.dim(3).extent());
        // Give clearer names to the reduction over the filter's
        // channel-multiplier, x, and y dimensions.
        RVar rd = depthwise_filter_dom[0];
        RVar rx = depthwise_filter_dom[1];
        RVar ry = depthwise_filter_dom[2];
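        // Output channel d of the depthwise stage reads input channel
        // d / channel_multiplier.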
        depthwise_convolved(d, x, y, b) +=
            depthwise_filter(rd, d, rx, ry) *
            input_bounded(d / channel_multiplier,
                          x + rx - pad_width,
                          y + ry - pad_height,
                          b);

        // Convolve the image point-wise: for each pixel we map from
        // input_channels * channel_multiplier channels down to output_channels.
        Func pointwise_convolved("pointwise_convolved");
        // Reduction over the channels in the depthwise output
        RDom rc(0, pointwise_filter.dim(1).extent());
        pointwise_convolved(d, x, y, b) = bias(d);
        pointwise_convolved(d, x, y, b) +=
            pointwise_filter(d, rc) * depthwise_convolved(rc, x, y, b);

        // ReLU
        output(d, x, y, b) = max(pointwise_convolved(d, x, y, b), 0.f);

        // The schedule.
        if (auto_schedule) {
            // Second layer of MobileNet v2
            const int N = 4, CI = 32, CO = 16, CM = 1, W = 112, H = 112;

            input.dim(0).set_estimate(0, CI);
            input.dim(1).set_estimate(0, W);
            input.dim(2).set_estimate(0, H);
            input.dim(3).set_estimate(0, N);

            depthwise_filter.dim(0).set_estimate(0, CI / CO);
            depthwise_filter.dim(1).set_estimate(0, CI);
            depthwise_filter.dim(2).set_estimate(0, 3);
            depthwise_filter.dim(3).set_estimate(0, 3);

            pointwise_filter.dim(0).set_estimate(0, CO);
            pointwise_filter.dim(1).set_estimate(0, CI * CM);

            bias.dim(0).set_estimate(0, CO);

            output.dim(0).set_estimate(0, CO);
            output.dim(1).set_estimate(0, W);
            output.dim(2).set_estimate(0, H);
            output.dim(3).set_estimate(0, N);
        } else if (get_target().has_gpu_feature()) {
            // 0.066ms on a 2060 RTX super. This is about 1.2 TFlops,
            // which is not a very large fraction of peak. For
            // comparison though, tensorflow 2.3 achieves 0.13ms via
            // cudnn 7. So we're twice as fast.

            // This schedule fuses the depthwise conv into the pointwise
            // conv. The results of the depthwise conv are computed inside
            // the outer of the two pointwise reduction loops.

            Var xi, yi, di, dii, xii, yii;
            RVar ro, ri;

            // The pointwise convolution kernel. Produces a 4x4 tile of output.
            Func(output)
                .tile({d, x, y}, {di, xi, yi}, {16, 4, 4})
                .tile({di, xi, yi}, {dii, xii, yii}, {1, 2, 2})
                .gpu_threads(di, xi, yi)
                .fuse(y, b, b)
                .gpu_blocks(d, x, b)
                .unroll(xii)
                .unroll(yii)
                .unroll(dii);

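            // Accumulate the pointwise conv in registers. Each thread's
            // sub-tile of output is fully unrolled, and the channel
            // reduction is split so the inner factor of four can be
            // unrolled (and fed by the vectorized loads staged below).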
            pointwise_convolved.compute_at(output, di)
                .reorder(x, y, d)
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .update()
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .split(rc, ro, ri, 4)
                .reorder(ri, x, y, d, ro)
                .unroll(ri);

            // We're going to call in() on depthwise_convolved twice.
            // The first will be to give it a wrapper to do the
            // accumulation in registers before writing the result to
            // shared. The second will be staging the loads from
            // shared into registers. We write them in reverse order
            // below:

            // We can do 4-wide vectorized loads from shared memory if
            // we unroll the reduction loop by a factor of four above
            // and stage the loads from the depthwise_convolved
            // output.

            depthwise_convolved.in()
                .in()
                .compute_at(pointwise_convolved, x)
                .bound_extent(d, 4)
                .vectorize(d)
                .unroll(x)
                .unroll(y);

            // The depthwise convolution kernel. Produces a 4x4 tile
            // of intermediate state, storing the result in shared.
            depthwise_convolved.in()
                .compute_at(output, d)
                .tile({d, x, y}, {di, xi, yi}, {32, 4, 4}, TailStrategy::RoundUp)
                .tile({di, xi, yi}, {dii, xii, yii}, {2, 2, 2})
                .gpu_threads(di, xi, yi)
                .unroll(xii)
                .unroll(yii)
                .unroll(dii);

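            // The depthwise accumulation itself happens per-thread in
            // registers; the wrapper above then writes the finished
            // tile to shared.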
            depthwise_convolved
                .compute_at(depthwise_convolved.in(), di)
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .update()
                .reorder(d, x, y, rx, ry, rd)
                .unroll(x)
                .unroll(y)
                .unroll(d);
        } else {
            // CPU schedule

            // 0.13ms on an Intel i9-9960X using 16 threads pinned to 3.0 GHz,
            // which is only about 20% of peak flops.

            int tile_w = 1;
            int tile_h = 1;
            int tile_d = 1;
            const int vec = natural_vector_size<float>();

            // Figure out how many registers we have in the register
            // file on this target.
            int num_regs = 16;
            if (get_target().has_feature(Target::AVX512_Skylake) ||
                (get_target().arch == Target::ARM &&
                 get_target().bits == 64)) {
                num_regs = 32;
            }

            // Pick a tile size designed to fit into the register file.
            if (num_regs == 32 && vec == 16) {
                // 32 vector registers available of size 16. Use 24 of
                // them for accumulators.
                tile_d = 1;
                tile_w = 6;
                tile_h = 4;
                // Using a larger tile in the d dimension would be
                // better, but we're tuning for 16 output channels and
                // our vectors are already that wide (on avx512).
            } else if (num_regs == 32 && vec == 4) {
                // 32 vector registers, of size 4. We'll use 24.
                tile_d = 4;
                tile_w = 3;
                tile_h = 2;
            } else if (num_regs == 16 && vec == 8) {
                // 16 registers available of size 8. Use 12 for accumulators.
                tile_d = 2;
                tile_w = 3;
                tile_h = 2;
            } else {
                // Old x86 or 32-bit arm. Assume vectors of size 4,
                // 16 registers. No FMA so we need to reserve a few
                // more registers for things other than the
                // accumulators.
                tile_d = 4;
                tile_w = 2;
                tile_h = 1;
            }
            // Change units from vectors to elements
            tile_d *= vec;

            // This schedule aggressively fuses the depthwise conv into
            // the pointwise conv. We do the depthwise convolution within
            // slices of the channel reduction loop in the pointwise
            // convolution.

            Var di, xi, yi;
            RVar ro, ri;

            Func(output)
                .tile({d, x, y}, {di, xi, yi}, {tile_d, tile_w, tile_h})
                .vectorize(di)
                .unroll(xi)
                .unroll(yi)
                .fuse(y, b, b)
                .parallel(b);

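            // Accumulate the pointwise conv in registers: a tile of
            // tile_d channels by tile_w x tile_h pixels, vectorized
            // across channels, with the channel reduction split into
            // slices of tile_d.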
            pointwise_convolved.compute_at(output, d)
                .vectorize(d)
                .unroll(x)
                .unroll(y)
                .update()
                .reorder(d, x, y, rc, b)
                .vectorize(d)
                .unroll(x)
                .unroll(y)
                .split(rc, ro, ri, tile_d);

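            // Compute each slice of tile_d intermediate channels of the
            // depthwise conv on the stack, vectorized across channels,
            // right where the pointwise reduction consumes it.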
            depthwise_convolved
                .store_in(MemoryType::Stack)
                .bound_extent(d, tile_d)
                .compute_at(pointwise_convolved, ro)
                .vectorize(d)
                .reorder(x, y, d)
                .unroll(x)
                .unroll(y)
                .update()
                .vectorize(d)
                .reorder(x, y, d, rd, rx, ry, b)
                .unroll(x)
                .unroll(y);

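            // Also stage the zero-padded input on the stack per
            // reduction slice, vectorized over the channel dimension.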
            input_bounded
                .store_in(MemoryType::Stack)
                .compute_at(pointwise_convolved, ro)
                .tile(d, x, di, xi, vec, 4, TailStrategy::RoundUp)
                .vectorize(di)
                .unroll(xi);
        }

        if (!auto_schedule) {
            // We're going to specialize both schedules for channel_multiplier = 1,
            // in which case it's nice to know that depthwise_filter
            // is dense across the second dimension.
            depthwise_filter.dim(1).set_stride(channel_multiplier);
            Expr intermediate_channels = pointwise_filter.dim(1).extent();
            // We'll also specialize for a multiple-of-32 number of
            // intermediate channels, and a 3x3 conv.
            output.specialize(channel_multiplier == 1 &&
                              intermediate_channels == (intermediate_channels / 32) * 32 &&
                              depthwise_filter.dim(2).extent() == 3 &&
                              depthwise_filter.dim(3).extent() == 3);
        }
    }
};
}  // namespace

HALIDE_REGISTER_GENERATOR(DepthwiseSeparableConvolution, depthwise_separable_conv)