#include "Halide.h"

namespace {

using namespace Halide;
using namespace Halide::BoundaryConditions;

class DepthwiseSeparableConvolution : public Generator<DepthwiseSeparableConvolution> {
public:
    // [in_channels, width, height, batch_size]
    Input<Buffer<float>> input{"input", 4};

    // [channel_multiplier, in_channels, filter_width, filter_height]
    Input<Buffer<float>> depthwise_filter{"depthwise_filter", 4};

    // [out_channels, channel_multiplier * in_channels]
    Input<Buffer<float>> pointwise_filter{"pointwise_filter", 2};

    // [out_channels]
    Input<Buffer<float>> bias{"bias", 1};

    // [out_channels, width, height, batch_size]
    Output<Buffer<float>> output{"output", 4};

    void generate() {
        // The algorithm. It is a generic depthwise-separable convolution,
        // with no assumptions about input sizes or shapes. This makes it
        // especially challenging to schedule.

        // Some free variables, where x and y represent the spatial dimensions.
        Var x("x"), y("y"), d("d"), b("b");

        // Pad x and y with 0. Unfortunately the built-in boundary
        // condition helpers cause unwanted loop partitioning.
        Func input_bounded;
        Expr in_bounds = (x >= 0 && x < input.dim(1).extent() &&
                          y >= 0 && y < input.dim(2).extent());
        Expr clamped_x = clamp(x, 0, input.dim(1).max());
        Expr clamped_y = clamp(y, 0, input.dim(2).max());
        input_bounded(d, x, y, b) =
            select(in_bounds, input(d, clamped_x, clamped_y, b), 0.0f);

        Expr channel_multiplier = depthwise_filter.dim(0).extent();

        // Convolve the image depthwise -- for each input channel,
        // generate channel_multiplier intermediate channels using
        // convolution.
        Func depthwise_convolved("depthwise_convolved");
        Expr pad_width = depthwise_filter.dim(2).extent() / 2;
        Expr pad_height = depthwise_filter.dim(3).extent() / 2;
        RDom depthwise_filter_dom(0, depthwise_filter.dim(0).extent(),
                                  0, depthwise_filter.dim(2).extent(),
                                  0, depthwise_filter.dim(3).extent());
        // Give clearer names to the reduction over the filter's
        // channel-multiplier dimension (depth), x, and y.
        RVar rd = depthwise_filter_dom[0];
        RVar rx = depthwise_filter_dom[1];
        RVar ry = depthwise_filter_dom[2];
        depthwise_convolved(d, x, y, b) +=
            depthwise_filter(rd, d, rx, ry) *
            input_bounded(d / channel_multiplier,
                          x + rx - pad_width,
                          y + ry - pad_height,
                          b);

        // Convolve the image pointwise: for each pixel we map from
        // in_channels * channel_multiplier channels to out_channels.
        Func pointwise_convolved("pointwise_convolved");
        // Reduction over the channels in the depthwise output.
        RDom rc(0, pointwise_filter.dim(1).extent());
        pointwise_convolved(d, x, y, b) = bias(d);
        pointwise_convolved(d, x, y, b) +=
            pointwise_filter(d, rc) * depthwise_convolved(rc, x, y, b);

        // ReLU
        output(d, x, y, b) = max(pointwise_convolved(d, x, y, b), 0.f);

        // The schedule.
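        //
        // Three variants follow: estimates for the autoscheduler, a
        // manual GPU schedule, and a manual CPU schedule.
        //
        // As a rough sanity check on why the separable factorization is
        // worth it (using the MobileNet v2 sizes from the estimates
        // below: CI = 32, CM = 1, CO = 16, 3x3 filter), the depthwise
        // stage costs CI * CM * 3 * 3 = 288 multiply-adds per output
        // pixel and the pointwise stage CI * CM * CO = 512, for 800 in
        // total, versus CI * CO * 3 * 3 = 4608 for a dense 3x3
        // convolution over the same channels -- roughly a 5.8x reduction
        // in arithmetic.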
        if (auto_schedule) {
            // Second layer of MobileNet v2
            const int N = 4, CI = 32, CO = 16, CM = 1, W = 112, H = 112;

            input.dim(0).set_estimate(0, CI);
            input.dim(1).set_estimate(0, W);
            input.dim(2).set_estimate(0, H);
            input.dim(3).set_estimate(0, N);

            depthwise_filter.dim(0).set_estimate(0, CM);
            depthwise_filter.dim(1).set_estimate(0, CI);
            depthwise_filter.dim(2).set_estimate(0, 3);
            depthwise_filter.dim(3).set_estimate(0, 3);

            pointwise_filter.dim(0).set_estimate(0, CO);
            pointwise_filter.dim(1).set_estimate(0, CI * CM);

            bias.dim(0).set_estimate(0, CO);

            output.dim(0).set_estimate(0, CO);
            output.dim(1).set_estimate(0, W);
            output.dim(2).set_estimate(0, H);
            output.dim(3).set_estimate(0, N);
        } else if (get_target().has_gpu_feature()) {
            // 0.066ms on an RTX 2060 Super. That's about 1.2 TFlops,
            // which is not a very large fraction of peak. For
            // comparison, though, TensorFlow 2.3 achieves 0.13ms via
            // cuDNN 7, so we're twice as fast.

            // This schedule fuses the depthwise conv into the pointwise
            // conv. The results of the depthwise conv are computed inside
            // the outer of the two pointwise reduction loops.

            Var xi, yi, di, dii, xii, yii;
            RVar ro, ri;

            // The pointwise convolution kernel. Produces a 4x4 tile of output.
            Func(output)
                .tile({d, x, y}, {di, xi, yi}, {16, 4, 4})
                .tile({di, xi, yi}, {dii, xii, yii}, {1, 2, 2})
                .gpu_threads(di, xi, yi)
                .fuse(y, b, b)
                .gpu_blocks(d, x, b)
                .unroll(xii)
                .unroll(yii)
                .unroll(dii);

            pointwise_convolved.compute_at(output, di)
                .reorder(x, y, d)
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .update()
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .split(rc, ro, ri, 4)
                .reorder(ri, x, y, d, ro)
                .unroll(ri);

            // We're going to call in() on depthwise_convolved twice.
            // The first wrapper does the accumulation in registers
            // before writing the result to shared memory. The second
            // stages the loads from shared memory back into registers.
            // They appear in reverse order below.

            // We can do 4-wide vectorized loads from shared memory if
            // we unroll the reduction loop by a factor of four above
            // and stage the loads from the depthwise_convolved
            // output.
            depthwise_convolved.in()
                .in()
                .compute_at(pointwise_convolved, x)
                .bound_extent(d, 4)
                .vectorize(d)
                .unroll(x)
                .unroll(y);

            // The depthwise convolution kernel. Produces a 4x4 tile
            // of intermediate state, storing the result in shared memory.
            depthwise_convolved.in()
                .compute_at(output, d)
                .tile({d, x, y}, {di, xi, yi}, {32, 4, 4}, TailStrategy::RoundUp)
                .tile({di, xi, yi}, {dii, xii, yii}, {2, 2, 2})
                .gpu_threads(di, xi, yi)
                .unroll(xii)
                .unroll(yii)
                .unroll(dii);

            depthwise_convolved
                .compute_at(depthwise_convolved.in(), di)
                .unroll(x)
                .unroll(y)
                .unroll(d)
                .update()
                .reorder(d, x, y, rx, ry, rd)
                .unroll(x)
                .unroll(y)
                .unroll(d);
        } else {
            // CPU schedule.

            // 0.13ms on an Intel i9-9960X using 16 threads pinned to
            // 3.0 GHz, which is only about 20% of peak flops.

            int tile_w = 1;
            int tile_h = 1;
            int tile_d = 1;
            const int vec = natural_vector_size<float>();

            // Figure out how many registers we have in the register
            // file on this target.
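            // (x86-64 with SSE/AVX2 exposes 16 vector registers;
            // AVX-512 and 64-bit ARM expose 32.)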
            int num_regs = 16;
            if (get_target().has_feature(Target::AVX512_Skylake) ||
                (get_target().arch == Target::ARM &&
                 get_target().bits == 64)) {
                num_regs = 32;
            }

            // Pick a tile size designed to fit into the register file.
            if (num_regs == 32 && vec == 16) {
                // 32 vector registers available of size 16. Use 24 of
                // them for accumulators.
                tile_d = 1;
                tile_w = 6;
                tile_h = 4;
                // Using more tiles in the d dimension would be
                // better, but we're tuning for 16 output channels and
                // our vectors are already that wide (on AVX-512).
            } else if (num_regs == 32 && vec == 4) {
                // 32 vector registers, of size 4. We'll use 24.
                tile_d = 4;
                tile_w = 3;
                tile_h = 2;
            } else if (num_regs == 16 && vec == 8) {
                // 16 registers available of size 8. Use 12 for accumulators.
                tile_d = 2;
                tile_w = 3;
                tile_h = 2;
            } else {
                // Old x86 or 32-bit ARM. Assume vectors of size 4 and
                // 16 registers. There's no FMA, so we need to reserve
                // a few more registers for things other than the
                // accumulators.
                tile_d = 4;
                tile_w = 2;
                tile_h = 1;
            }
            // Change units from vectors to elements.
            tile_d *= vec;

            // This schedule aggressively fuses the depthwise conv into
            // the pointwise conv. We do the depthwise convolution within
            // slices of the channel reduction loop in the pointwise
            // convolution.

            Var di, xi, yi;
            RVar ro, ri;

            Func(output)
                .tile({d, x, y}, {di, xi, yi}, {tile_d, tile_w, tile_h})
                .vectorize(di)
                .unroll(xi)
                .unroll(yi)
                .fuse(y, b, b)
                .parallel(b);

            pointwise_convolved.compute_at(output, d)
                .vectorize(d)
                .unroll(x)
                .unroll(y)
                .update()
                .reorder(d, x, y, rc, b)
                .vectorize(d)
                .unroll(x)
                .unroll(y)
                .split(rc, ro, ri, tile_d);

            depthwise_convolved
                .store_in(MemoryType::Stack)
                .bound_extent(d, tile_d)
                .compute_at(pointwise_convolved, ro)
                .vectorize(d)
                .reorder(x, y, d)
                .unroll(x)
                .unroll(y)
                .update()
                .vectorize(d)
                .reorder(x, y, d, rd, rx, ry, b)
                .unroll(x)
                .unroll(y);

            input_bounded
                .store_in(MemoryType::Stack)
                .compute_at(pointwise_convolved, ro)
                .tile(d, x, di, xi, vec, 4, TailStrategy::RoundUp)
                .vectorize(di)
                .unroll(xi);
        }

        if (!auto_schedule) {
            // We're going to specialize both schedules for
            // channel_multiplier == 1, in which case it's nice to know
            // that depthwise_filter is dense across the second dimension.
            depthwise_filter.dim(1).set_stride(channel_multiplier);
            Expr intermediate_channels = pointwise_filter.dim(1).extent();
            // We'll also specialize for a multiple-of-32 number of
            // intermediate channels, and a 3x3 conv.
            output.specialize(channel_multiplier == 1 &&
                              intermediate_channels == (intermediate_channels / 32) * 32 &&
                              depthwise_filter.dim(2).extent() == 3 &&
                              depthwise_filter.dim(3).extent() == 3);
        }
    }
};

}  // namespace

HALIDE_REGISTER_GENERATOR(DepthwiseSeparableConvolution, depthwise_separable_conv)
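// Usage sketch (illustrative, not part of the generator): assuming this file
// is built into a generator binary and run ahead-of-time, e.g.
//
//     depthwise_separable_conv.generator -g depthwise_separable_conv -o . target=host
//
// the emitted header exposes a C-ABI function whose buffer arguments follow
// the layouts declared above. A minimal caller might look like:
//
//     #include "HalideBuffer.h"
//     #include "depthwise_separable_conv.h"
//
//     int main() {
//         const int N = 4, CI = 32, CM = 1, CO = 16, W = 112, H = 112;
//         Halide::Runtime::Buffer<float> in(CI, W, H, N);
//         Halide::Runtime::Buffer<float> dw_filter(CM, CI, 3, 3);
//         Halide::Runtime::Buffer<float> pw_filter(CO, CI * CM);
//         Halide::Runtime::Buffer<float> bias(CO);
//         Halide::Runtime::Buffer<float> out(CO, W, H, N);
//         // ... fill the inputs, then run the pipeline:
//         depthwise_separable_conv(in, dw_filter, pw_filter, bias, out);
//         return 0;
//     }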