// Copyright (C) 2015  Davis E. King (davis@dlib.net)
// License: Boost Software License   See LICENSE.txt for the full license.
#ifndef DLIB_DNn_LAYERS_H_
#define DLIB_DNn_LAYERS_H_

#include "layers_abstract.h"
#include "../cuda/tensor.h"
#include "core.h"
#include <iostream>
#include <string>
#include "../rand.h"
#include "../string.h"
#include "../cuda/tensor_tools.h"
#include "../vectorstream.h"
#include "utilities.h"
#include <sstream>


namespace dlib
{

// ----------------------------------------------------------------------------------------

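    // A simple wrapper type used to pass a runtime-chosen number of filters to the
    // con_ and cont_ constructors below, overriding the compile-time _num_filters value.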
    struct num_con_outputs
    {
        num_con_outputs(unsigned long n) : num_outputs(n) {}
        unsigned long num_outputs;
    };

    template <
        long _num_filters,
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class con_
    {
    public:

        static_assert(_num_filters > 0, "The number of filters must be > 0");
        static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
        static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(_nr==0 || (0 <= _padding_y && _padding_y < _nr), "The padding must be smaller than the filter size.");
        static_assert(_nc==0 || (0 <= _padding_x && _padding_x < _nc), "The padding must be smaller than the filter size.");
        static_assert(_nr!=0 || 0 == _padding_y, "If _nr==0 then the padding must be set to 0 as well.");
        static_assert(_nc!=0 || 0 == _padding_x, "If _nc==0 then the padding must be set to 0 as well.");

        con_(
            num_con_outputs o
        ) :
            learning_rate_multiplier(1),
            weight_decay_multiplier(1),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(0),
            num_filters_(o.num_outputs),
            padding_y_(_padding_y),
            padding_x_(_padding_x),
            use_bias(true)
        {
            DLIB_CASSERT(num_filters_ > 0);
        }

        con_() : con_(num_con_outputs(_num_filters)) {}

        long num_filters() const { return num_filters_; }
        long nr() const
        {
            if (_nr==0)
                return filters.nr();
            else
                return _nr;
        }
        long nc() const
        {
            if (_nc==0)
                return filters.nc();
            else
                return _nc;
        }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        void set_num_filters(long num)
        {
            DLIB_CASSERT(num > 0);
            if (num != num_filters_)
            {
                DLIB_CASSERT(get_layer_params().size() == 0,
                    "You can't change the number of filters in con_ if the parameter tensor has already been allocated.");
                num_filters_ = num;
            }
        }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }
        void disable_bias() { use_bias = false; }
        bool bias_is_disabled() const { return !use_bias; }

        inline dpoint map_input_to_output (
            dpoint p
        ) const
        {
            p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
            p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
            return p;
        }

        inline dpoint map_output_to_input (
            dpoint p
        ) const
        {
            p.x() = p.x()*stride_x() - padding_x() + nc()/2;
            p.y() = p.y()*stride_y() - padding_y() + nr()/2;
            return p;
        }
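
        // A quick sanity check of the mappings above (a worked example, not library
        // documentation): for a 3x3 filter with stride 1 the default padding is
        // _nc/2 = 1, so map_input_to_output() gives p.x() = (x+1-1)/1 = x.  That is,
        // same-size convolutions map coordinates through unchanged, as expected.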

        con_ (
            const con_& item
        ) :
            params(item.params),
            filters(item.filters),
            biases(item.biases),
            learning_rate_multiplier(item.learning_rate_multiplier),
            weight_decay_multiplier(item.weight_decay_multiplier),
            bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
            bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
            num_filters_(item.num_filters_),
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_),
            use_bias(item.use_bias)
        {
            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
        }

        con_& operator= (
            const con_& item
        )
        {
            if (this == &item)
                return *this;

            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
            params = item.params;
            filters = item.filters;
            biases = item.biases;
            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;
            learning_rate_multiplier = item.learning_rate_multiplier;
            weight_decay_multiplier = item.weight_decay_multiplier;
            bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
            bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
            num_filters_ = item.num_filters_;
            use_bias = item.use_bias;
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            const long filt_nr = _nr!=0 ? _nr : sub.get_output().nr();
            const long filt_nc = _nc!=0 ? _nc : sub.get_output().nc();

            long num_inputs = filt_nr*filt_nc*sub.get_output().k();
            long num_outputs = num_filters_;
            // allocate params for the filters and also for the filter bias values.
            params.set_size(num_inputs*num_filters_ + static_cast<int>(use_bias) * num_filters_);

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            filters = alias_tensor(num_filters_, sub.get_output().k(), filt_nr, filt_nc);
            if (use_bias)
            {
                biases = alias_tensor(1,num_filters_);
                // set the initial bias values to zero
                biases(params,filters.size()) = 0;
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            conv.setup(sub.get_output(),
                       filters(params,0),
                       _stride_y,
                       _stride_x,
                       padding_y_,
                       padding_x_);
            conv(false, output,
                sub.get_output(),
                filters(params,0));
            if (use_bias)
            {
                tt::add(1,output,1,biases(params,filters.size()));
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            conv.get_gradient_for_data (true, gradient_input, filters(params,0), sub.get_gradient_input());
            // no point computing the parameter gradients if they won't be used.
            if (learning_rate_multiplier != 0)
            {
                auto filt = filters(params_grad,0);
                conv.get_gradient_for_filters (false, gradient_input, sub.get_output(), filt);
                if (use_bias)
                {
                    auto b = biases(params_grad, filters.size());
                    tt::assign_conv_bias_gradient(b, gradient_input);
                }
            }
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const con_& item, std::ostream& out)
        {
            serialize("con_5", out);
            serialize(item.params, out);
            serialize(item.num_filters_, out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
            serialize(item.filters, out);
            serialize(item.biases, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.use_bias, out);
        }

        friend void deserialize(con_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "con_4" || version == "con_5")
            {
                deserialize(item.params, in);
                deserialize(item.num_filters_, in);
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
                deserialize(item.filters, in);
                deserialize(item.biases, in);
                deserialize(item.learning_rate_multiplier, in);
                deserialize(item.weight_decay_multiplier, in);
                deserialize(item.bias_learning_rate_multiplier, in);
                deserialize(item.bias_weight_decay_multiplier, in);
                if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::con_");
                if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::con_");
                if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::con_");
                if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::con_");
                if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::con_");
                if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::con_");
                if (version == "con_5")
                {
                    deserialize(item.use_bias, in);
                }
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::con_.");
            }
        }


        friend std::ostream& operator<<(std::ostream& out, const con_& item)
        {
            out << "con\t ("
                << "num_filters="<<item.num_filters_
                << ", nr="<<item.nr()
                << ", nc="<<item.nc()
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            if (item.use_bias)
            {
                out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            }
            else
            {
                out << " use_bias=false";
            }
            return out;
        }

        friend void to_xml(const con_& item, std::ostream& out)
        {
            out << "<con"
                << " num_filters='"<<item.num_filters_<<"'"
                << " nr='"<<item.nr()<<"'"
                << " nc='"<<item.nc()<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
                << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
            out << mat(item.params);
            out << "</con>";
        }

    private:

        resizable_tensor params;
        alias_tensor filters, biases;

        tt::tensor_conv conv;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        long num_filters_;

        // These are here only because older versions of con (which you might encounter
        // serialized to disk) used different padding settings.
        int padding_y_;
        int padding_x_;
        bool use_bias;

    };

    template <
        long num_filters,
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using con = add_layer<con_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
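
    // Example usage (a minimal sketch, not part of the library; the layer sizes and
    // loss are illustrative choices):
    //
    //   using net_type = loss_multiclass_log<
    //                        fc<10,
    //                        relu<con<16,5,5,1,1,
    //                        input<matrix<unsigned char>>>>>>;
    //
    // This builds a network with a single 5x5 convolution (16 filters, stride 1)
    // followed by a ReLU and a 10-output fully connected layer.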

// ----------------------------------------------------------------------------------------

    template <
        long _num_filters,
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class cont_
    {
    public:

        static_assert(_num_filters > 0, "The number of filters must be > 0");
        static_assert(_nr > 0, "The number of rows in a filter must be > 0");
        static_assert(_nc > 0, "The number of columns in a filter must be > 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && _padding_y < _nr, "The padding must be smaller than the filter size.");
        static_assert(0 <= _padding_x && _padding_x < _nc, "The padding must be smaller than the filter size.");

        cont_(
            num_con_outputs o
        ) :
            learning_rate_multiplier(1),
            weight_decay_multiplier(1),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(0),
            num_filters_(o.num_outputs),
            padding_y_(_padding_y),
            padding_x_(_padding_x),
            use_bias(true)
        {
            DLIB_CASSERT(num_filters_ > 0);
        }

        cont_() : cont_(num_con_outputs(_num_filters)) {}

        long num_filters() const { return num_filters_; }
        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        void set_num_filters(long num)
        {
            DLIB_CASSERT(num > 0);
            if (num != num_filters_)
            {
                DLIB_CASSERT(get_layer_params().size() == 0,
                    "You can't change the number of filters in cont_ if the parameter tensor has already been allocated.");
                num_filters_ = num;
            }
        }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }
        void disable_bias() { use_bias = false; }
        bool bias_is_disabled() const { return !use_bias; }

        inline dpoint map_output_to_input (
            dpoint p
        ) const
        {
            p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
            p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
            return p;
        }

        inline dpoint map_input_to_output (
            dpoint p
        ) const
        {
            p.x() = p.x()*stride_x() - padding_x() + nc()/2;
            p.y() = p.y()*stride_y() - padding_y() + nr()/2;
            return p;
        }

        cont_ (
            const cont_& item
        ) :
            params(item.params),
            filters(item.filters),
            biases(item.biases),
            learning_rate_multiplier(item.learning_rate_multiplier),
            weight_decay_multiplier(item.weight_decay_multiplier),
            bias_learning_rate_multiplier(item.bias_learning_rate_multiplier),
            bias_weight_decay_multiplier(item.bias_weight_decay_multiplier),
            num_filters_(item.num_filters_),
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_),
            use_bias(item.use_bias)
        {
            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
        }

        cont_& operator= (
            const cont_& item
        )
        {
            if (this == &item)
                return *this;

            // this->conv is non-copyable and basically stateless, so we have to write our
            // own copy to avoid trying to copy it and getting an error.
            params = item.params;
            filters = item.filters;
            biases = item.biases;
            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;
            learning_rate_multiplier = item.learning_rate_multiplier;
            weight_decay_multiplier = item.weight_decay_multiplier;
            bias_learning_rate_multiplier = item.bias_learning_rate_multiplier;
            bias_weight_decay_multiplier = item.bias_weight_decay_multiplier;
            num_filters_ = item.num_filters_;
            use_bias = item.use_bias;
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            long num_inputs = _nr*_nc*sub.get_output().k();
            long num_outputs = num_filters_;
            // allocate params for the filters and also for the filter bias values.
            params.set_size(num_inputs*num_filters_ + num_filters_ * static_cast<int>(use_bias));

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            filters = alias_tensor(sub.get_output().k(), num_filters_, _nr, _nc);
            if (use_bias)
            {
                biases = alias_tensor(1,num_filters_);
                // set the initial bias values to zero
                biases(params,filters.size()) = 0;
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto filt = filters(params,0);
            unsigned int gnr = _stride_y * (sub.get_output().nr() - 1) + filt.nr() - 2 * padding_y_;
            unsigned int gnc = _stride_x * (sub.get_output().nc() - 1) + filt.nc() - 2 * padding_x_;
            unsigned int gnsamps = sub.get_output().num_samples();
            unsigned int gk = filt.k();
            output.set_size(gnsamps,gk,gnr,gnc);
            conv.setup(output,filt,_stride_y,_stride_x,padding_y_,padding_x_);
            conv.get_gradient_for_data(false, sub.get_output(),filt,output);
            if (use_bias)
            {
                tt::add(1,output,1,biases(params,filters.size()));
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto filt = filters(params,0);
            conv(true, sub.get_gradient_input(),gradient_input, filt);
            // no point computing the parameter gradients if they won't be used.
            if (learning_rate_multiplier != 0)
            {
                auto filt = filters(params_grad,0);
                conv.get_gradient_for_filters (false, sub.get_output(),gradient_input, filt);
                if (use_bias)
                {
                    auto b = biases(params_grad, filters.size());
                    tt::assign_conv_bias_gradient(b, gradient_input);
                }
            }
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const cont_& item, std::ostream& out)
        {
            serialize("cont_2", out);
            serialize(item.params, out);
            serialize(item.num_filters_, out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
            serialize(item.filters, out);
            serialize(item.biases, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.use_bias, out);
        }

        friend void deserialize(cont_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "cont_1" || version == "cont_2")
            {
                deserialize(item.params, in);
                deserialize(item.num_filters_, in);
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
                deserialize(item.filters, in);
                deserialize(item.biases, in);
                deserialize(item.learning_rate_multiplier, in);
                deserialize(item.weight_decay_multiplier, in);
                deserialize(item.bias_learning_rate_multiplier, in);
                deserialize(item.bias_weight_decay_multiplier, in);
                if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::cont_");
                if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::cont_");
                if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::cont_");
                if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::cont_");
                if (stride_y != _stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::cont_");
                if (stride_x != _stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::cont_");
                if (version == "cont_2")
                {
                    deserialize(item.use_bias, in);
                }
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::cont_.");
            }
        }


        friend std::ostream& operator<<(std::ostream& out, const cont_& item)
        {
            out << "cont\t ("
                << "num_filters="<<item.num_filters_
                << ", nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            if (item.use_bias)
            {
                out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            }
            else
            {
                out << " use_bias=false";
            }
            return out;
        }

        friend void to_xml(const cont_& item, std::ostream& out)
        {
            out << "<cont"
                << " num_filters='"<<item.num_filters_<<"'"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
                << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
                << " use_bias='"<<(item.use_bias?"true":"false")<<"'>\n";
            out << mat(item.params);
            out << "</cont>";
        }

    private:

        resizable_tensor params;
        alias_tensor filters, biases;

        tt::tensor_conv conv;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        long num_filters_;

        int padding_y_;
        int padding_x_;

        bool use_bias;

    };

    template <
        long num_filters,
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using cont = add_layer<cont_<num_filters,nr,nc,stride_y,stride_x>, SUBNET>;
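
    // Example usage (an illustrative sketch).  cont_ runs the convolution "backwards",
    // so it is typically used to upsample feature maps, e.g. in the decoder half of a
    // U-Net style network:
    //
    //   using upsampler = relu<cont<32,4,4,2,2, SUBNET>>;   // SUBNET is a placeholder
    //
    // Per the size computation in forward() above, the output has
    // nr = stride_y*(input nr - 1) + filter nr - 2*padding_y rows (and likewise for
    // columns), so a 4x4 filter with stride 2 and the default padding of 0 roughly
    // doubles the spatial dimensions.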

// ----------------------------------------------------------------------------------------

    template <
        int scale_y,
        int scale_x
        >
    class upsample_
    {
    public:
        static_assert(scale_y >= 1, "upsampling scale factor can't be less than 1.");
        static_assert(scale_x >= 1, "upsampling scale factor can't be less than 1.");

        upsample_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            output.set_size(
                sub.get_output().num_samples(),
                sub.get_output().k(),
                scale_y*sub.get_output().nr(),
                scale_x*sub.get_output().nc());
            tt::resize_bilinear(output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input);
        }

        inline dpoint map_input_to_output (dpoint p) const
        {
            p.x() = p.x()*scale_x;
            p.y() = p.y()*scale_y;
            return p;
        }
        inline dpoint map_output_to_input (dpoint p) const
        {
            p.x() = p.x()/scale_x;
            p.y() = p.y()/scale_y;
            return p;
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const upsample_& /*item*/, std::ostream& out)
        {
            serialize("upsample_", out);
            serialize(scale_y, out);
            serialize(scale_x, out);
        }

        friend void deserialize(upsample_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "upsample_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::upsample_.");

            int _scale_y;
            int _scale_x;
            deserialize(_scale_y, in);
            deserialize(_scale_x, in);
            if (_scale_y != scale_y || _scale_x != scale_x)
                throw serialization_error("Wrong scale found while deserializing dlib::upsample_");
        }

        friend std::ostream& operator<<(std::ostream& out, const upsample_& /*item*/)
        {
            out << "upsample\t ("
                << "scale_y="<<scale_y
                << ", scale_x="<<scale_x
                << ")";
            return out;
        }

        friend void to_xml(const upsample_& /*item*/, std::ostream& out)
        {
            out << "<upsample"
                << " scale_y='"<<scale_y<<"'"
                << " scale_x='"<<scale_x<<"'/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <
        int scale,
        typename SUBNET
        >
    using upsample = add_layer<upsample_<scale,scale>, SUBNET>;
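
    // Example usage (an illustrative sketch): upsample<2, SUBNET> doubles both the
    // number of rows and columns of its input via bilinear interpolation, a common
    // parameter-free alternative to cont_ for upsampling feature maps.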

// ----------------------------------------------------------------------------------------

    template <
        long NR_,
        long NC_
        >
    class resize_to_
    {
    public:
        static_assert(NR_ >= 1, "NR resize parameter can't be less than 1.");
        static_assert(NC_ >= 1, "NC resize parameter can't be less than 1.");

        resize_to_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            scale_y = (double)NR_/(double)sub.get_output().nr();
            scale_x = (double)NC_/(double)sub.get_output().nc();

            output.set_size(
                sub.get_output().num_samples(),
                sub.get_output().k(),
                NR_,
                NC_);
            tt::resize_bilinear(output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            tt::resize_bilinear_gradient(sub.get_gradient_input(), gradient_input);
        }

        inline dpoint map_input_to_output (dpoint p) const
        {
            p.x() = p.x()*scale_x;
            p.y() = p.y()*scale_y;
            return p;
        }

        inline dpoint map_output_to_input (dpoint p) const
        {
            p.x() = p.x()/scale_x;
            p.y() = p.y()/scale_y;
            return p;
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const resize_to_& item, std::ostream& out)
        {
            serialize("resize_to_", out);
            serialize(NR_, out);
            serialize(NC_, out);
            serialize(item.scale_y, out);
            serialize(item.scale_x, out);
        }

        friend void deserialize(resize_to_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "resize_to_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_to_.");

            long _nr;
            long _nc;
            deserialize(_nr, in);
            deserialize(_nc, in);
            deserialize(item.scale_y, in);
            deserialize(item.scale_x, in);
            if (_nr != NR_ || _nc != NC_)
                throw serialization_error("Wrong size found while deserializing dlib::resize_to_");
        }

        friend std::ostream& operator<<(std::ostream& out, const resize_to_& /*item*/)
        {
            out << "resize_to ("
                << "nr=" << NR_
                << ", nc=" << NC_
                << ")";
            return out;
        }

        friend void to_xml(const resize_to_& /*item*/, std::ostream& out)
        {
            out << "<resize_to";
            out << " nr='" << NR_ << "'" ;
            out << " nc='" << NC_ << "'/>\n";
        }
    private:
        resizable_tensor params;
        double scale_y;
        double scale_x;

    };  // end of class resize_to_


    template <
        long NR,
        long NC,
        typename SUBNET
        >
    using resize_to = add_layer<resize_to_<NR,NC>, SUBNET>;
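
    // Example usage (an illustrative sketch): resize_to<64,64, SUBNET> bilinearly
    // resizes its input to exactly 64x64, regardless of the input's spatial
    // dimensions.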

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class max_pool_
    {
        static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
        static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
        static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
    public:


        max_pool_(
        ) :
            padding_y_(_padding_y),
            padding_x_(_padding_x)
        {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        inline dpoint map_input_to_output (
            dpoint p
        ) const
        {
            p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
            p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
            return p;
        }

        inline dpoint map_output_to_input (
            dpoint p
        ) const
        {
            p.x() = p.x()*stride_x() - padding_x() + nc()/2;
            p.y() = p.y()*stride_y() - padding_y() + nr()/2;
            return p;
        }

        max_pool_ (
            const max_pool_& item
        )  :
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_)
        {
            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
        }

        max_pool_& operator= (
            const max_pool_& item
        )
        {
            if (this == &item)
                return *this;

            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;

            // this->mp is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            mp(output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            mp.setup_max_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            mp.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const max_pool_& item, std::ostream& out)
        {
            serialize("max_pool_2", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
        }

        friend void deserialize(max_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "max_pool_2")
            {
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::max_pool_.");
            }

            if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::max_pool_");
            if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::max_pool_");
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::max_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::max_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::max_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::max_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const max_pool_& item)
        {
            out << "max_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            return out;
        }

        friend void to_xml(const max_pool_& item, std::ostream& out)
        {
            out << "<max_pool"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << "/>\n";
        }


    private:


        tt::pooling mp;
        resizable_tensor params;

        int padding_y_;
        int padding_x_;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using max_pool = add_layer<max_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

    template <
        typename SUBNET
        >
    using max_pool_everything = add_layer<max_pool_<0,0,1,1>, SUBNET>;
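
    // Example usage (an illustrative sketch): max_pool<3,3,2,2, SUBNET> is the classic
    // 3x3 stride-2 pooling used in many CNNs, while max_pool_everything reduces each
    // channel to a single value, since a filter size of 0 means "cover the entire
    // input" per the setup calls in forward() above.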

// ----------------------------------------------------------------------------------------

    template <
        long _nr,
        long _nc,
        int _stride_y,
        int _stride_x,
        int _padding_y = _stride_y!=1? 0 : _nr/2,
        int _padding_x = _stride_x!=1? 0 : _nc/2
        >
    class avg_pool_
    {
    public:
        static_assert(_nr >= 0, "The number of rows in a filter must be >= 0");
        static_assert(_nc >= 0, "The number of columns in a filter must be >= 0");
        static_assert(_stride_y > 0, "The filter stride must be > 0");
        static_assert(_stride_x > 0, "The filter stride must be > 0");
        static_assert(0 <= _padding_y && ((_nr==0 && _padding_y == 0) || (_nr!=0 && _padding_y < _nr)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");
        static_assert(0 <= _padding_x && ((_nc==0 && _padding_x == 0) || (_nc!=0 && _padding_x < _nc)),
            "The padding must be smaller than the filter size, unless the filter size is 0.");

        avg_pool_(
        ) :
            padding_y_(_padding_y),
            padding_x_(_padding_x)
        {}

        long nr() const { return _nr; }
        long nc() const { return _nc; }
        long stride_y() const { return _stride_y; }
        long stride_x() const { return _stride_x; }
        long padding_y() const { return padding_y_; }
        long padding_x() const { return padding_x_; }

        inline dpoint map_input_to_output (
            dpoint p
        ) const
        {
            p.x() = (p.x()+padding_x()-nc()/2)/stride_x();
            p.y() = (p.y()+padding_y()-nr()/2)/stride_y();
            return p;
        }

        inline dpoint map_output_to_input (
            dpoint p
        ) const
        {
            p.x() = p.x()*stride_x() - padding_x() + nc()/2;
            p.y() = p.y()*stride_y() - padding_y() + nr()/2;
            return p;
        }

        avg_pool_ (
            const avg_pool_& item
        )  :
            padding_y_(item.padding_y_),
            padding_x_(item.padding_x_)
        {
            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
        }

        avg_pool_& operator= (
            const avg_pool_& item
        )
        {
            if (this == &item)
                return *this;

            padding_y_ = item.padding_y_;
            padding_x_ = item.padding_x_;

            // this->ap is non-copyable so we have to write our own copy to avoid trying to
            // copy it and getting an error.
            return *this;
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            ap(output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(const tensor& computed_output, const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            ap.setup_avg_pooling(_nr!=0?_nr:sub.get_output().nr(),
                                 _nc!=0?_nc:sub.get_output().nc(),
                                 _stride_y, _stride_x, padding_y_, padding_x_);

            ap.get_gradient(gradient_input, computed_output, sub.get_output(), sub.get_gradient_input());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const avg_pool_& item, std::ostream& out)
        {
            serialize("avg_pool_2", out);
            serialize(_nr, out);
            serialize(_nc, out);
            serialize(_stride_y, out);
            serialize(_stride_x, out);
            serialize(item.padding_y_, out);
            serialize(item.padding_x_, out);
        }

        friend void deserialize(avg_pool_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);

            long nr;
            long nc;
            int stride_y;
            int stride_x;
            if (version == "avg_pool_2")
            {
                deserialize(nr, in);
                deserialize(nc, in);
                deserialize(stride_y, in);
                deserialize(stride_x, in);
                deserialize(item.padding_y_, in);
                deserialize(item.padding_x_, in);
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::avg_pool_.");
            }

            if (item.padding_y_ != _padding_y) throw serialization_error("Wrong padding_y found while deserializing dlib::avg_pool_");
            if (item.padding_x_ != _padding_x) throw serialization_error("Wrong padding_x found while deserializing dlib::avg_pool_");
            if (_nr != nr) throw serialization_error("Wrong nr found while deserializing dlib::avg_pool_");
            if (_nc != nc) throw serialization_error("Wrong nc found while deserializing dlib::avg_pool_");
            if (_stride_y != stride_y) throw serialization_error("Wrong stride_y found while deserializing dlib::avg_pool_");
            if (_stride_x != stride_x) throw serialization_error("Wrong stride_x found while deserializing dlib::avg_pool_");
        }

        friend std::ostream& operator<<(std::ostream& out, const avg_pool_& item)
        {
            out << "avg_pool ("
                << "nr="<<_nr
                << ", nc="<<_nc
                << ", stride_y="<<_stride_y
                << ", stride_x="<<_stride_x
                << ", padding_y="<<item.padding_y_
                << ", padding_x="<<item.padding_x_
                << ")";
            return out;
        }

        friend void to_xml(const avg_pool_& item, std::ostream& out)
        {
            out << "<avg_pool"
                << " nr='"<<_nr<<"'"
                << " nc='"<<_nc<<"'"
                << " stride_y='"<<_stride_y<<"'"
                << " stride_x='"<<_stride_x<<"'"
                << " padding_y='"<<item.padding_y_<<"'"
                << " padding_x='"<<item.padding_x_<<"'"
                << "/>\n";
        }
    private:

        tt::pooling ap;
        resizable_tensor params;

        int padding_y_;
        int padding_x_;
    };

    template <
        long nr,
        long nc,
        int stride_y,
        int stride_x,
        typename SUBNET
        >
    using avg_pool = add_layer<avg_pool_<nr,nc,stride_y,stride_x>, SUBNET>;

    template <
        typename SUBNET
        >
    using avg_pool_everything = add_layer<avg_pool_<0,0,1,1>, SUBNET>;
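
    // Example usage (an illustrative sketch): avg_pool_everything performs global
    // average pooling, reducing each channel to its mean value.  It is often placed
    // just before a final fc layer in classification networks, e.g.:
    //
    //   using head = fc<10, avg_pool_everything<SUBNET>>;   // SUBNET is a placeholder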
1302 
// ----------------------------------------------------------------------------------------

    const double DEFAULT_LAYER_NORM_EPS = 1e-5;

    class layer_norm_
    {
    public:
        explicit layer_norm_(
            double eps_ = DEFAULT_LAYER_NORM_EPS
        ) :
            learning_rate_multiplier(1),
            weight_decay_multiplier(0),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(1),
            eps(eps_)
        {
        }

        double get_eps() const { return eps; }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            gamma = alias_tensor(sub.get_output().num_samples());
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            tt::layer_normalize(eps, output, means, invstds, sub.get_output(), g, b);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto g = gamma(params, 0);
            auto g_grad = gamma(params_grad, 0);
            auto b_grad = beta(params_grad, gamma.size());
            tt::layer_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const layer_norm_& item, std::ostream& out)
        {
            serialize("layer_norm_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize(item.means, out);
            serialize(item.invstds, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.eps, out);
        }

        friend void deserialize(layer_norm_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "layer_norm_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::layer_norm_.");
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            deserialize(item.means, in);
            deserialize(item.invstds, in);
            deserialize(item.learning_rate_multiplier, in);
            deserialize(item.weight_decay_multiplier, in);
            deserialize(item.bias_learning_rate_multiplier, in);
            deserialize(item.bias_weight_decay_multiplier, in);
            deserialize(item.eps, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const layer_norm_& item)
        {
            out << "layer_norm";
            out << " eps="<<item.eps;
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
            out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            return out;
        }

        friend void to_xml(const layer_norm_& item, std::ostream& out)
        {
            // The opening tag needs the leading '<' so the element matches the
            // closing </layer_norm> below.
            out << "<layer_norm";
            out << " eps='"<<item.eps<<"'";
            out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
            out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
            out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
            out << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
            out << ">\n";
            out << mat(item.params);
            out << "</layer_norm>\n";
        }

    private:
        resizable_tensor params;
        alias_tensor gamma, beta;
        resizable_tensor means, invstds;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        double eps;
    };

    template <typename SUBNET>
    using layer_norm = add_layer<layer_norm_, SUBNET>;

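    // A minimal usage sketch (illustrative only).  layer_norm normalizes each sample
    // to zero mean and unit variance and then applies a learned scale and shift,
    // i.e. output = (input-mean)/sqrt(variance+eps)*gamma + beta.  The surrounding
    // layers are assumptions made just for this example:
    //
    //     using net = fc<10, layer_norm<fc<50, input<matrix<float>>>>>;
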
// ----------------------------------------------------------------------------------------
    enum layer_mode
    {
        CONV_MODE = 0,
        FC_MODE = 1
    };

    const double DEFAULT_BATCH_NORM_EPS = 0.0001;

    template <
        layer_mode mode
        >
    class bn_
    {
    public:
        explicit bn_(
            unsigned long window_size,
            double eps_ = DEFAULT_BATCH_NORM_EPS
        ) :
            num_updates(0),
            running_stats_window_size(window_size),
            learning_rate_multiplier(1),
            weight_decay_multiplier(0),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(1),
            eps(eps_)
        {
            DLIB_CASSERT(window_size > 0, "The batch normalization running stats window size can't be 0.");
        }

        bn_() : bn_(100) {}

        layer_mode get_mode() const { return mode; }
        unsigned long get_running_stats_window_size () const { return running_stats_window_size; }
        void set_running_stats_window_size (unsigned long new_window_size )
        {
            DLIB_CASSERT(new_window_size > 0, "The batch normalization running stats window size can't be 0.");
            running_stats_window_size = new_window_size;
        }
        double get_eps() const { return eps; }

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }


        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;

            running_means.copy_size(gamma(params,0));
            running_variances.copy_size(gamma(params,0));
            running_means = 0;
            running_variances = 1;
            num_updates = 0;
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (sub.get_output().num_samples() > 1)
            {
                // Accumulate the running stats as a cumulative average at first, then
                // as an exponential moving average once num_updates reaches the
                // running stats window size.
                const double decay = 1.0 - num_updates/(num_updates+1.0);
                ++num_updates;
                if (num_updates > running_stats_window_size)
                    num_updates = running_stats_window_size;

                if (mode == FC_MODE)
                    tt::batch_normalize(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
                else
                    tt::batch_normalize_conv(eps, output, means, invstds, decay, running_means, running_variances, sub.get_output(), g, b);
            }
            else // we are running in testing mode so we just linearly scale the input tensor.
            {
                if (mode == FC_MODE)
                    tt::batch_normalize_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
                else
                    tt::batch_normalize_conv_inference(eps, output, sub.get_output(), g, b, running_means, running_variances);
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            auto g = gamma(params,0);
            auto g_grad = gamma(params_grad, 0);
            auto b_grad = beta(params_grad, gamma.size());
            if (mode == FC_MODE)
                tt::batch_normalize_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
            else
                tt::batch_normalize_conv_gradient(eps, gradient_input, means, invstds, sub.get_output(), g, sub.get_gradient_input(), g_grad, b_grad );
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const bn_& item, std::ostream& out)
        {
            if (mode == CONV_MODE)
                serialize("bn_con2", out);
            else // if FC_MODE
                serialize("bn_fc2", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize(item.means, out);
            serialize(item.invstds, out);
            serialize(item.running_means, out);
            serialize(item.running_variances, out);
            serialize(item.num_updates, out);
            serialize(item.running_stats_window_size, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.eps, out);
        }

        friend void deserialize(bn_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (mode == CONV_MODE)
            {
                if (version != "bn_con2")
                    throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
            }
            else // must be in FC_MODE
            {
                if (version != "bn_fc2")
                    throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::bn_.");
            }

            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            deserialize(item.means, in);
            deserialize(item.invstds, in);
            deserialize(item.running_means, in);
            deserialize(item.running_variances, in);
            deserialize(item.num_updates, in);
            deserialize(item.running_stats_window_size, in);
            deserialize(item.learning_rate_multiplier, in);
            deserialize(item.weight_decay_multiplier, in);
            deserialize(item.bias_learning_rate_multiplier, in);
            deserialize(item.bias_weight_decay_multiplier, in);
            deserialize(item.eps, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const bn_& item)
        {
            if (mode == CONV_MODE)
                out << "bn_con  ";
            else
                out << "bn_fc   ";
            out << " eps="<<item.eps;
            out << " running_stats_window_size="<<item.running_stats_window_size;
            out << " learning_rate_mult="<<item.learning_rate_multiplier;
            out << " weight_decay_mult="<<item.weight_decay_multiplier;
            out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
            out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
            return out;
        }

        friend void to_xml(const bn_& item, std::ostream& out)
        {
            if (mode==CONV_MODE)
                out << "<bn_con";
            else
                out << "<bn_fc";

            out << " eps='"<<item.eps<<"'";
            out << " running_stats_window_size='"<<item.running_stats_window_size<<"'";
            out << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'";
            out << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
            out << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'";
            out << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'";
            out << ">\n";

            out << mat(item.params);

            if (mode==CONV_MODE)
                out << "</bn_con>\n";
            else
                out << "</bn_fc>\n";
        }

    private:

        friend class affine_;

        resizable_tensor params;
        alias_tensor gamma, beta;
        resizable_tensor means, running_means;
        resizable_tensor invstds, running_variances;
        unsigned long num_updates;
        unsigned long running_stats_window_size;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        double eps;
    };

    template <typename SUBNET>
    using bn_con = add_layer<bn_<CONV_MODE>, SUBNET>;
    template <typename SUBNET>
    using bn_fc = add_layer<bn_<FC_MODE>, SUBNET>;

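    // A minimal usage sketch (illustrative only): use bn_con after convolutional
    // layers and bn_fc after fully connected ones.  The surrounding layers are
    // assumptions made just for this example:
    //
    //     using net = fc<10, relu<bn_fc<fc<50, relu<bn_con<con<16,5,5,1,1, input<matrix<float>>>>>>>>>;
    //
    // Note that when a tensor containing only one sample is run through the layer it
    // operates in inference mode, scaling by the accumulated running statistics
    // rather than the statistics of the current mini-batch.
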
// ----------------------------------------------------------------------------------------

    namespace impl
    {
        class visitor_bn_running_stats_window_size
        {
        public:

            visitor_bn_running_stats_window_size(unsigned long new_window_size_) : new_window_size(new_window_size_) {}

            template <typename T>
            void set_window_size(T&) const
            {
                // ignore other layer detail types
            }

            template < layer_mode mode >
            void set_window_size(bn_<mode>& l) const
            {
                l.set_running_stats_window_size(new_window_size);
            }

            template<typename input_layer_type>
            void operator()(size_t , input_layer_type& )  const
            {
                // ignore other layers
            }

            template <typename T, typename U, typename E>
            void operator()(size_t , add_layer<T,U,E>& l)  const
            {
                set_window_size(l.layer_details());
            }

        private:

            unsigned long new_window_size;
        };

        class visitor_disable_input_bias
        {
        public:

            template <typename T>
            void disable_input_bias(T&) const
            {
                // ignore other layer types
            }

            // handle the standard case
            template <typename U, typename E>
            void disable_input_bias(add_layer<layer_norm_, U, E>& l)
            {
                disable_bias(l.subnet().layer_details());
                set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
            }

            template <layer_mode mode, typename U, typename E>
            void disable_input_bias(add_layer<bn_<mode>, U, E>& l)
            {
                disable_bias(l.subnet().layer_details());
                set_bias_learning_rate_multiplier(l.subnet().layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().layer_details(), 0);
            }

            // handle input repeat layer case
            template <layer_mode mode, size_t N, template <typename> class R, typename U, typename E>
            void disable_input_bias(add_layer<bn_<mode>, repeat<N, R, U>, E>& l)
            {
                disable_bias(l.subnet().get_repeated_layer(0).layer_details());
                set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
            }

            template <size_t N, template <typename> class R, typename U, typename E>
            void disable_input_bias(add_layer<layer_norm_, repeat<N, R, U>, E>& l)
            {
                disable_bias(l.subnet().get_repeated_layer(0).layer_details());
                set_bias_learning_rate_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
                set_bias_weight_decay_multiplier(l.subnet().get_repeated_layer(0).layer_details(), 0);
            }

            // handle input repeat layer with tag case
            template <layer_mode mode, unsigned long ID, typename E, typename F>
            void disable_input_bias(add_layer<bn_<mode>, add_tag_layer<ID, impl::repeat_input_layer, E>, F>& )
            {
            }

            template <unsigned long ID, typename E, typename F>
            void disable_input_bias(add_layer<layer_norm_, add_tag_layer<ID, impl::repeat_input_layer, E>, F>& )
            {
            }

            template<typename input_layer_type>
            void operator()(size_t , input_layer_type& ) const
            {
                // ignore other layers
            }

            template <typename T, typename U, typename E>
            void operator()(size_t , add_layer<T,U,E>& l)
            {
                disable_input_bias(l);
            }
        };
    }

    template <typename net_type>
    void set_all_bn_running_stats_window_sizes (
        net_type& net,
        unsigned long new_window_size
    )
    {
        visit_layers(net, impl::visitor_bn_running_stats_window_size(new_window_size));
    }

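    // For example (a sketch; net_type stands for any network type containing bn_
    // layers):
    //
    //     net_type net;
    //     set_all_bn_running_stats_window_sizes(net, 1000);
    //
    // makes every batch normalization layer in net average its running statistics
    // over the last 1000 mini-batches.
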
    template <typename net_type>
    void disable_duplicative_biases (
        net_type& net
    )
    {
        visit_layers(net, impl::visitor_disable_input_bias());
    }

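    // Since bn_ and layer_norm_ layers already apply their own additive term, a bias
    // in the layer feeding into them is redundant.  For example (a sketch; the
    // network type is an assumption made just for this example):
    //
    //     using net_type = fc<10, relu<bn_con<con<16,5,5,1,1, input<matrix<float>>>>>>;
    //     net_type net;
    //     disable_duplicative_biases(net); // disables the bias in the con_ layer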
// ----------------------------------------------------------------------------------------

    enum fc_bias_mode
    {
        FC_HAS_BIAS = 0,
        FC_NO_BIAS = 1
    };

    struct num_fc_outputs
    {
        num_fc_outputs(unsigned long n) : num_outputs(n) {}
        unsigned long num_outputs;
    };

    template <
        unsigned long num_outputs_,
        fc_bias_mode bias_mode
        >
    class fc_
    {
        static_assert(num_outputs_ > 0, "The number of outputs from a fc_ layer must be > 0");

    public:
        fc_(num_fc_outputs o) : num_outputs(o.num_outputs), num_inputs(0),
            learning_rate_multiplier(1),
            weight_decay_multiplier(1),
            bias_learning_rate_multiplier(1),
            bias_weight_decay_multiplier(0),
            use_bias(true)
        {}

        fc_() : fc_(num_fc_outputs(num_outputs_)) {}

        double get_learning_rate_multiplier () const  { return learning_rate_multiplier; }
        double get_weight_decay_multiplier () const   { return weight_decay_multiplier; }
        void set_learning_rate_multiplier(double val) { learning_rate_multiplier = val; }
        void set_weight_decay_multiplier(double val)  { weight_decay_multiplier  = val; }

        double get_bias_learning_rate_multiplier () const  { return bias_learning_rate_multiplier; }
        double get_bias_weight_decay_multiplier () const   { return bias_weight_decay_multiplier; }
        void set_bias_learning_rate_multiplier(double val) { bias_learning_rate_multiplier = val; }
        void set_bias_weight_decay_multiplier(double val)  { bias_weight_decay_multiplier  = val; }
        void disable_bias() { use_bias = false; }
        bool bias_is_disabled() const { return !use_bias; }

        unsigned long get_num_outputs (
        ) const { return num_outputs; }

        void set_num_outputs(long num)
        {
            DLIB_CASSERT(num > 0);
            if (num != (long)num_outputs)
            {
                DLIB_CASSERT(get_layer_params().size() == 0,
                    "You can't change the number of outputs in fc_ if the parameter tensor has already been allocated.");
                num_outputs = num;
            }
        }

        fc_bias_mode get_bias_mode (
        ) const { return bias_mode; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            num_inputs = sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k();
            if (bias_mode == FC_HAS_BIAS && use_bias)
                params.set_size(num_inputs+1, num_outputs);
            else
                params.set_size(num_inputs, num_outputs);

            dlib::rand rnd(std::rand());
            randomize_parameters(params, num_inputs+num_outputs, rnd);

            weights = alias_tensor(num_inputs, num_outputs);

            if (bias_mode == FC_HAS_BIAS && use_bias)
            {
                biases = alias_tensor(1,num_outputs);
                // set the initial bias values to zero
                biases(params,weights.size()) = 0;
            }
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            DLIB_CASSERT((long)num_inputs == sub.get_output().nr()*sub.get_output().nc()*sub.get_output().k(),
                "The size of the input tensor to this fc layer doesn't match the size the fc layer was trained with.");
            output.set_size(sub.get_output().num_samples(), num_outputs);

            auto w = weights(params, 0);
            tt::gemm(0,output, 1,sub.get_output(),false, w,false);
            if (bias_mode == FC_HAS_BIAS && use_bias)
            {
                auto b = biases(params, weights.size());
                tt::add(1,output,1,b);
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& params_grad)
        {
            // no point computing the parameter gradients if they won't be used.
            if (learning_rate_multiplier != 0)
            {
                // compute the gradient of the weight parameters.
                auto pw = weights(params_grad, 0);
                tt::gemm(0,pw, 1,sub.get_output(),true, gradient_input,false);

                if (bias_mode == FC_HAS_BIAS && use_bias)
                {
                    // compute the gradient of the bias parameters.
                    auto pb = biases(params_grad, weights.size());
                    tt::assign_bias_gradient(pb, gradient_input);
                }
            }

            // compute the gradient for the data
            auto w = weights(params, 0);
            tt::gemm(1,sub.get_gradient_input(), 1,gradient_input,false, w,true);
        }

        alias_tensor_instance get_weights()
        {
            return weights(params, 0);
        }

        alias_tensor_const_instance get_weights() const
        {
            return weights(params, 0);
        }

        alias_tensor_instance get_biases()
        {
            static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector "
                "to be retrieved, as per template parameter 'bias_mode'.");
            return biases(params, weights.size());
        }

        alias_tensor_const_instance get_biases() const
        {
            static_assert(bias_mode == FC_HAS_BIAS, "This fc_ layer doesn't have a bias vector "
                "to be retrieved, as per template parameter 'bias_mode'.");
            return biases(params, weights.size());
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const fc_& item, std::ostream& out)
        {
            serialize("fc_3", out);
            serialize(item.num_outputs, out);
            serialize(item.num_inputs, out);
            serialize(item.params, out);
            serialize(item.weights, out);
            serialize(item.biases, out);
            serialize((int)bias_mode, out);
            serialize(item.learning_rate_multiplier, out);
            serialize(item.weight_decay_multiplier, out);
            serialize(item.bias_learning_rate_multiplier, out);
            serialize(item.bias_weight_decay_multiplier, out);
            serialize(item.use_bias, out);
        }

        friend void deserialize(fc_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "fc_2" || version == "fc_3")
            {
                deserialize(item.num_outputs, in);
                deserialize(item.num_inputs, in);
                deserialize(item.params, in);
                deserialize(item.weights, in);
                deserialize(item.biases, in);
                int bmode = 0;
                deserialize(bmode, in);
                if (bias_mode != (fc_bias_mode)bmode) throw serialization_error("Wrong fc_bias_mode found while deserializing dlib::fc_");
                deserialize(item.learning_rate_multiplier, in);
                deserialize(item.weight_decay_multiplier, in);
                deserialize(item.bias_learning_rate_multiplier, in);
                deserialize(item.bias_weight_decay_multiplier, in);
                if (version == "fc_3")
                {
                    deserialize(item.use_bias, in);
                }
            }
            else
            {
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::fc_.");
            }
        }

        friend std::ostream& operator<<(std::ostream& out, const fc_& item)
        {
            if (bias_mode == FC_HAS_BIAS)
            {
                out << "fc\t ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
                out << " learning_rate_mult="<<item.learning_rate_multiplier;
                out << " weight_decay_mult="<<item.weight_decay_multiplier;
                if (item.use_bias)
                {
                    out << " bias_learning_rate_mult="<<item.bias_learning_rate_multiplier;
                    out << " bias_weight_decay_mult="<<item.bias_weight_decay_multiplier;
                }
                else
                {
                    out << " use_bias=false";
                }
            }
            else
            {
                out << "fc_no_bias ("
                    << "num_outputs="<<item.num_outputs
                    << ")";
                out << " learning_rate_mult="<<item.learning_rate_multiplier;
                out << " weight_decay_mult="<<item.weight_decay_multiplier;
            }
            return out;
        }

        friend void to_xml(const fc_& item, std::ostream& out)
        {
            if (bias_mode==FC_HAS_BIAS)
            {
                out << "<fc"
                    << " num_outputs='"<<item.num_outputs<<"'"
                    << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                    << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'"
                    << " bias_learning_rate_mult='"<<item.bias_learning_rate_multiplier<<"'"
                    << " bias_weight_decay_mult='"<<item.bias_weight_decay_multiplier<<"'"
                    << " use_bias='"<<(item.use_bias?"true":"false")<<"'";
                out << ">\n";
                out << mat(item.params);
                out << "</fc>\n";
            }
            else
            {
                out << "<fc_no_bias"
                    << " num_outputs='"<<item.num_outputs<<"'"
                    << " learning_rate_mult='"<<item.learning_rate_multiplier<<"'"
                    << " weight_decay_mult='"<<item.weight_decay_multiplier<<"'";
                out << ">\n";
                out << mat(item.params);
                out << "</fc_no_bias>\n";
            }
        }

    private:

        unsigned long num_outputs;
        unsigned long num_inputs;
        resizable_tensor params;
        alias_tensor weights, biases;
        double learning_rate_multiplier;
        double weight_decay_multiplier;
        double bias_learning_rate_multiplier;
        double bias_weight_decay_multiplier;
        bool use_bias;
    };

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc = add_layer<fc_<num_outputs,FC_HAS_BIAS>, SUBNET>;

    template <
        unsigned long num_outputs,
        typename SUBNET
        >
    using fc_no_bias = add_layer<fc_<num_outputs,FC_NO_BIAS>, SUBNET>;

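    // A minimal usage sketch (illustrative only): a small multilayer perceptron.  The
    // input and loss layers are assumptions made just for this example (the loss
    // layers live in loss.h, not in this file):
    //
    //     using net = loss_multiclass_log<fc<10, relu<fc<50, input<matrix<float>>>>>>;
    //
    // fc_no_bias behaves identically except that no bias vector is learned, which is
    // useful when the next layer (e.g. bn_fc) supplies its own additive term.
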
// ----------------------------------------------------------------------------------------

    class dropout_
    {
    public:
        explicit dropout_(
            float drop_rate_ = 0.5
        ) :
            drop_rate(drop_rate_),
            rnd(std::rand())
        {
            DLIB_CASSERT(0 <= drop_rate && drop_rate <= 1);
        }

        // We have to add a copy constructor and assignment operator because the rnd object
        // is non-copyable.
        dropout_(
            const dropout_& item
        ) : drop_rate(item.drop_rate), mask(item.mask), rnd(std::rand())
        {}

        dropout_& operator= (
            const dropout_& item
        )
        {
            if (this == &item)
                return *this;

            drop_rate = item.drop_rate;
            mask = item.mask;
            return *this;
        }

        float get_drop_rate (
        ) const { return drop_rate; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            // create a random mask and use it to filter the data
            mask.copy_size(input);
            rnd.fill_uniform(mask);
            tt::threshold(mask, drop_rate);
            tt::multiply(false, output, input, mask);
        }

        void backward_inplace(
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& /*params_grad*/
        )
        {
            if (is_same_object(gradient_input, data_grad))
                tt::multiply(false, data_grad, mask, gradient_input);
            else
                tt::multiply(true, data_grad, mask, gradient_input);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const dropout_& item, std::ostream& out)
        {
            serialize("dropout_", out);
            serialize(item.drop_rate, out);
            serialize(item.mask, out);
        }

        friend void deserialize(dropout_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "dropout_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::dropout_.");
            deserialize(item.drop_rate, in);
            deserialize(item.mask, in);
        }

        void clean(
        )
        {
            mask.clear();
        }

        friend std::ostream& operator<<(std::ostream& out, const dropout_& item)
        {
            out << "dropout\t ("
                << "drop_rate="<<item.drop_rate
                << ")";
            return out;
        }

        friend void to_xml(const dropout_& item, std::ostream& out)
        {
            out << "<dropout"
                << " drop_rate='"<<item.drop_rate<<"'";
            out << "/>\n";
        }

    private:
        float drop_rate;
        resizable_tensor mask;

        tt::tensor_rand rnd;
        resizable_tensor params; // unused
    };


    template <typename SUBNET>
    using dropout = add_layer<dropout_, SUBNET>;

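    // A minimal usage sketch (illustrative only): randomly zero activations during
    // training.  The surrounding layers are assumptions made just for this example:
    //
    //     using net = fc<10, dropout<fc<50, input<matrix<float>>>>>;
    //
    // Once training is done it is common to replace each dropout layer with the
    // multiply layer defined below, which deterministically scales by 1-drop_rate.
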
// ----------------------------------------------------------------------------------------

    class multiply_
    {
    public:
        explicit multiply_(
            float val_ = 0.5
        ) :
            val(val_)
        {
        }

        multiply_ (
            const dropout_& item
        ) : val(1-item.get_drop_rate()) {}

        float get_multiply_value (
        ) const { return val; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::affine_transform(output, input, val);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        void backward_inplace(
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& /*params_grad*/
        )
        {
            if (is_same_object(gradient_input, data_grad))
                tt::affine_transform(data_grad, gradient_input, val);
            else
                tt::affine_transform(data_grad, data_grad, gradient_input, 1, val);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const multiply_& item, std::ostream& out)
        {
            serialize("multiply_", out);
            serialize(item.val, out);
        }

        friend void deserialize(multiply_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "dropout_")
            {
                // Since we can build a multiply_ from a dropout_ we check if that's what
                // is in the stream and if so then just convert it right here.
                unserialize sin(version, in);
                dropout_ temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "multiply_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::multiply_.");
            deserialize(item.val, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const multiply_& item)
        {
            out << "multiply ("
                << "val="<<item.val
                << ")";
            return out;
        }

        friend void to_xml(const multiply_& item, std::ostream& out)
        {
            out << "<multiply"
                << " val='"<<item.val<<"'";
            out << "/>\n";
        }
    private:
        float val;
        resizable_tensor params; // unused
    };

    template <typename SUBNET>
    using multiply = add_layer<multiply_, SUBNET>;

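    // Because multiply_ is constructible from dropout_ and its deserialize() accepts
    // a serialized dropout_, a network trained with dropout layers can be loaded into
    // an otherwise identical network type that uses multiply in their place.  A
    // sketch (the network types and file name are assumptions made for the example):
    //
    //     using train_net = fc<10, dropout<fc<50, input<matrix<float>>>>>;
    //     using infer_net = fc<10, multiply<fc<50, input<matrix<float>>>>>;
    //     infer_net net;
    //     deserialize("trained.dat") >> net; // dropout_ states become multiply_
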
// ----------------------------------------------------------------------------------------

    class affine_
    {
    public:
        affine_(
        ) : mode(FC_MODE)
        {
        }

        affine_(
            layer_mode mode_
        ) : mode(mode_)
        {
        }

        template <
            layer_mode bnmode
            >
        affine_(
            const bn_<bnmode>& item
        )
        {
            gamma = item.gamma;
            beta = item.beta;
            mode = bnmode;

            params.copy_size(item.params);

            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());

            resizable_tensor temp(item.params);
            auto sg = gamma(temp,0);
            auto sb = beta(temp,gamma.size());

            g = pointwise_divide(mat(sg), sqrt(mat(item.running_variances)+item.get_eps()));
            b = mat(sb) - pointwise_multiply(mat(g), mat(item.running_means));
        }

        layer_mode get_mode() const { return mode; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            if (mode == FC_MODE)
            {
                gamma = alias_tensor(1,
                                sub.get_output().k(),
                                sub.get_output().nr(),
                                sub.get_output().nc());
            }
            else
            {
                gamma = alias_tensor(1, sub.get_output().k());
            }
            beta = gamma;

            params.set_size(gamma.size()+beta.size());

            gamma(params,0) = 1;
            beta(params,gamma.size()) = 0;
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());
            if (mode == FC_MODE)
                tt::affine_transform(output, input, g, b);
            else
                tt::affine_transform_conv(output, input, g, b);
        }

        void backward_inplace(
            const tensor& gradient_input,
            tensor& data_grad,
            tensor& /*params_grad*/
        )
        {
            auto g = gamma(params,0);
            auto b = beta(params,gamma.size());

            // We are computing the gradient of dot(gradient_input, computed_output*g + b)
            if (mode == FC_MODE)
            {
                if (is_same_object(gradient_input, data_grad))
                    tt::multiply(false, data_grad, gradient_input, g);
                else
                    tt::multiply(true, data_grad, gradient_input, g);
            }
            else
            {
                if (is_same_object(gradient_input, data_grad))
                    tt::multiply_conv(false, data_grad, gradient_input, g);
                else
                    tt::multiply_conv(true, data_grad, gradient_input, g);
            }
        }

        const tensor& get_layer_params() const { return empty_params; }
        tensor& get_layer_params() { return empty_params; }

        friend void serialize(const affine_& item, std::ostream& out)
        {
            serialize("affine_", out);
            serialize(item.params, out);
            serialize(item.gamma, out);
            serialize(item.beta, out);
            serialize((int)item.mode, out);
        }

        friend void deserialize(affine_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version == "bn_con2")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<CONV_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }
            else if (version == "bn_fc2")
            {
                // Since we can build an affine_ from a bn_ we check if that's what is in
                // the stream and if so then just convert it right here.
                unserialize sin(version, in);
                bn_<FC_MODE> temp;
                deserialize(temp, sin);
                item = temp;
                return;
            }

            if (version != "affine_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::affine_.");
            deserialize(item.params, in);
            deserialize(item.gamma, in);
            deserialize(item.beta, in);
            int mode;
            deserialize(mode, in);
            item.mode = (layer_mode)mode;
        }

        friend std::ostream& operator<<(std::ostream& out, const affine_& /*item*/)
        {
            out << "affine";
            return out;
        }

        friend void to_xml(const affine_& item, std::ostream& out)
        {
            if (item.mode==CONV_MODE)
                out << "<affine_con>\n";
            else
                out << "<affine_fc>\n";

            out << mat(item.params);

            if (item.mode==CONV_MODE)
                out << "</affine_con>\n";
            else
                out << "</affine_fc>\n";
        }

    private:
        resizable_tensor params, empty_params;
        alias_tensor gamma, beta;
        layer_mode mode;
    };

    template <typename SUBNET>
    using affine = add_layer<affine_, SUBNET>;

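    // Similarly, since affine_'s deserialize() accepts serialized bn_ layers, a
    // network trained with bn_con or bn_fc can be loaded into an identical network
    // type where each of those layers has been replaced by affine, turning batch
    // normalization into a cheap fixed linear transform for inference.  A sketch
    // (the network types and file name are assumptions made for the example):
    //
    //     using train_net = fc<10, relu<bn_con<con<16,5,5,1,1, input<matrix<float>>>>>>;
    //     using infer_net = fc<10, relu<affine<con<16,5,5,1,1, input<matrix<float>>>>>>;
    //     infer_net net;
    //     deserialize("trained.dat") >> net;
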
// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class add_prev_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        add_prev_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto&& t1 = sub.get_output();
            auto&& t2 = layer<tag>(sub).get_output();
            output.set_size(std::max(t1.num_samples(),t2.num_samples()),
                            std::max(t1.k(),t2.k()),
                            std::max(t1.nr(),t2.nr()),
                            std::max(t1.nc(),t2.nc()));
            tt::add(output, t1, t2);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            // The gradient just flows backwards to the two layers that forward() added
            // together.
            tt::add(sub.get_gradient_input(), sub.get_gradient_input(), gradient_input);
            tt::add(layer<tag>(sub).get_gradient_input(), layer<tag>(sub).get_gradient_input(), gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        friend void serialize(const add_prev_& /*item*/, std::ostream& out)
        {
            serialize("add_prev_", out);
        }

        friend void deserialize(add_prev_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "add_prev_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::add_prev_.");
        }
        friend std::ostream& operator<<(std::ostream& out, const add_prev_& /*item*/)
        {
            out << "add_prev"<<id;
            return out;
        }

        friend void to_xml(const add_prev_& /*item*/, std::ostream& out)
        {
            out << "<add_prev tag='"<<id<<"'/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using add_prev = add_layer<add_prev_<tag>, SUBNET>;

    template <typename SUBNET> using add_prev1  = add_prev<tag1, SUBNET>;
    template <typename SUBNET> using add_prev2  = add_prev<tag2, SUBNET>;
    template <typename SUBNET> using add_prev3  = add_prev<tag3, SUBNET>;
    template <typename SUBNET> using add_prev4  = add_prev<tag4, SUBNET>;
    template <typename SUBNET> using add_prev5  = add_prev<tag5, SUBNET>;
    template <typename SUBNET> using add_prev6  = add_prev<tag6, SUBNET>;
    template <typename SUBNET> using add_prev7  = add_prev<tag7, SUBNET>;
    template <typename SUBNET> using add_prev8  = add_prev<tag8, SUBNET>;
    template <typename SUBNET> using add_prev9  = add_prev<tag9, SUBNET>;
    template <typename SUBNET> using add_prev10 = add_prev<tag10, SUBNET>;

    using add_prev1_  = add_prev_<tag1>;
    using add_prev2_  = add_prev_<tag2>;
    using add_prev3_  = add_prev_<tag3>;
    using add_prev4_  = add_prev_<tag4>;
    using add_prev5_  = add_prev_<tag5>;
    using add_prev6_  = add_prev_<tag6>;
    using add_prev7_  = add_prev_<tag7>;
    using add_prev8_  = add_prev_<tag8>;
    using add_prev9_  = add_prev_<tag9>;
    using add_prev10_ = add_prev_<tag10>;

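    // A minimal usage sketch (illustrative only): add_prev is the standard way to
    // express residual connections.  The block below adds the input of two stacked
    // convolutions back onto their output; the layer sizes are assumptions made for
    // the example:
    //
    //     template <typename SUBNET>
    //     using residual = relu<add_prev1<con<8,3,3,1,1, relu<con<8,3,3,1,1, tag1<SUBNET>>>>>>;
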
// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class mult_prev_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        mult_prev_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto&& t1 = sub.get_output();
            auto&& t2 = layer<tag>(sub).get_output();
            output.set_size(std::max(t1.num_samples(),t2.num_samples()),
                            std::max(t1.k(),t2.k()),
                            std::max(t1.nr(),t2.nr()),
                            std::max(t1.nc(),t2.nc()));
            tt::multiply_zero_padded(false, output, t1, t2);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto&& t1 = sub.get_output();
            auto&& t2 = layer<tag>(sub).get_output();
            // The gradient just flows backwards to the two layers that forward()
            // multiplied together.
            tt::multiply_zero_padded(true, sub.get_gradient_input(), t2, gradient_input);
            tt::multiply_zero_padded(true, layer<tag>(sub).get_gradient_input(), t1, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        friend void serialize(const mult_prev_& /*item*/, std::ostream& out)
        {
            serialize("mult_prev_", out);
        }

        friend void deserialize(mult_prev_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "mult_prev_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mult_prev_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const mult_prev_& /*item*/)
        {
            out << "mult_prev"<<id;
            return out;
        }

        friend void to_xml(const mult_prev_& /*item*/, std::ostream& out)
        {
            out << "<mult_prev tag='"<<id<<"'/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using mult_prev = add_layer<mult_prev_<tag>, SUBNET>;

    template <typename SUBNET> using mult_prev1  = mult_prev<tag1, SUBNET>;
    template <typename SUBNET> using mult_prev2  = mult_prev<tag2, SUBNET>;
    template <typename SUBNET> using mult_prev3  = mult_prev<tag3, SUBNET>;
    template <typename SUBNET> using mult_prev4  = mult_prev<tag4, SUBNET>;
2658     template <typename SUBNET> using mult_prev5  = mult_prev<tag5, SUBNET>;
2659     template <typename SUBNET> using mult_prev6  = mult_prev<tag6, SUBNET>;
2660     template <typename SUBNET> using mult_prev7  = mult_prev<tag7, SUBNET>;
2661     template <typename SUBNET> using mult_prev8  = mult_prev<tag8, SUBNET>;
2662     template <typename SUBNET> using mult_prev9  = mult_prev<tag9, SUBNET>;
2663     template <typename SUBNET> using mult_prev10 = mult_prev<tag10, SUBNET>;
2664 
2665     using mult_prev1_  = mult_prev_<tag1>;
2666     using mult_prev2_  = mult_prev_<tag2>;
2667     using mult_prev3_  = mult_prev_<tag3>;
2668     using mult_prev4_  = mult_prev_<tag4>;
2669     using mult_prev5_  = mult_prev_<tag5>;
2670     using mult_prev6_  = mult_prev_<tag6>;
2671     using mult_prev7_  = mult_prev_<tag7>;
2672     using mult_prev8_  = mult_prev_<tag8>;
2673     using mult_prev9_  = mult_prev_<tag9>;
2674     using mult_prev10_ = mult_prev_<tag10>;
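
    // A minimal usage sketch (illustrative sizes): because forward() multiplies the two
    // outputs element-wise, mult_prev composes naturally with sig to form a learned
    // gate.  Here tag1 remembers the input, a convolution predicts gate values in
    // [0,1], and mult_prev1 applies them (this assumes the input to the block also has
    // 8 channels, so no zero-padded broadcasting is needed):
    //
    //     template <typename SUBNET>
    //     using gated8 = mult_prev1<sig<con<8,3,3,1,1,tag1<SUBNET>>>>;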

// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class resize_prev_to_tagged_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        resize_prev_to_tagged_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto& prev = sub.get_output();
            auto& tagged = layer<tag>(sub).get_output();

            DLIB_CASSERT(prev.num_samples() == tagged.num_samples());

            output.set_size(prev.num_samples(),
                            prev.k(),
                            tagged.nr(),
                            tagged.nc());

            if (prev.nr() == tagged.nr() && prev.nc() == tagged.nc())
            {
                tt::copy_tensor(false, output, 0, prev, 0, prev.k());
            }
            else
            {
                tt::resize_bilinear(output, prev);
            }
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto& prev = sub.get_gradient_input();

            DLIB_CASSERT(prev.k() == gradient_input.k());
            DLIB_CASSERT(prev.num_samples() == gradient_input.num_samples());

            if (prev.nr() == gradient_input.nr() && prev.nc() == gradient_input.nc())
            {
                tt::copy_tensor(true, prev, 0, gradient_input, 0, prev.k());
            }
            else
            {
                tt::resize_bilinear_gradient(prev, gradient_input);
            }
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        friend void serialize(const resize_prev_to_tagged_& /*item*/, std::ostream& out)
        {
            serialize("resize_prev_to_tagged_", out);
        }

        friend void deserialize(resize_prev_to_tagged_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "resize_prev_to_tagged_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::resize_prev_to_tagged_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const resize_prev_to_tagged_& /*item*/)
        {
            out << "resize_prev_to_tagged"<<id;
            return out;
        }

        friend void to_xml(const resize_prev_to_tagged_& /*item*/, std::ostream& out)
        {
            out << "<resize_prev_to_tagged tag='"<<id<<"'/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using resize_prev_to_tagged = add_layer<resize_prev_to_tagged_<tag>, SUBNET>;
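
    // A minimal usage sketch (the pooling and conv parameters are illustrative):
    // resize_prev_to_tagged is useful for skip connections between feature maps of
    // different spatial sizes.  Here tag1 remembers the full-resolution input, the
    // inner layers work at half resolution, and the result is bilinearly resized back
    // to tag1's size before the element-wise sum:
    //
    //     template <typename SUBNET>
    //     using fused = add_prev1<resize_prev_to_tagged<tag1,con<16,3,3,1,1,max_pool<2,2,2,2,tag1<SUBNET>>>>>;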

// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class scale_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        scale_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto&& scales = sub.get_output();
            auto&& src = layer<tag>(sub).get_output();
            DLIB_CASSERT(scales.num_samples() == src.num_samples() &&
                         scales.k()           == src.k() &&
                         scales.nr()          == 1 &&
                         scales.nc()          == 1,
                         "scales.k(): " << scales.k() <<
                         "\nsrc.k(): " << src.k()
                         );

            output.copy_size(src);
            tt::scale_channels(false, output, src, scales);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto&& scales = sub.get_output();
            auto&& src = layer<tag>(sub).get_output();
            // The gradient just flows backwards to the two layers that forward()
            // read from.
            tt::scale_channels(true, layer<tag>(sub).get_gradient_input(), gradient_input, scales);

            if (reshape_src.num_samples() != src.num_samples())
            {
                reshape_scales = alias_tensor(src.num_samples()*src.k());
                reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
            }

            auto&& scales_grad = sub.get_gradient_input();
            auto sgrad = reshape_scales(scales_grad);
            tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const scale_& item, std::ostream& out)
        {
            serialize("scale_", out);
            serialize(item.reshape_scales, out);
            serialize(item.reshape_src, out);
        }

        friend void deserialize(scale_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "scale_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_.");
            deserialize(item.reshape_scales, in);
            deserialize(item.reshape_src, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const scale_& /*item*/)
        {
            out << "scale"<<id;
            return out;
        }

        friend void to_xml(const scale_& /*item*/, std::ostream& out)
        {
            out << "<scale tag='"<<id<<"'/>\n";
        }

    private:
        alias_tensor reshape_scales;
        alias_tensor reshape_src;
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using scale = add_layer<scale_<tag>, SUBNET>;

    template <typename SUBNET> using scale1  = scale<tag1, SUBNET>;
    template <typename SUBNET> using scale2  = scale<tag2, SUBNET>;
    template <typename SUBNET> using scale3  = scale<tag3, SUBNET>;
    template <typename SUBNET> using scale4  = scale<tag4, SUBNET>;
    template <typename SUBNET> using scale5  = scale<tag5, SUBNET>;
    template <typename SUBNET> using scale6  = scale<tag6, SUBNET>;
    template <typename SUBNET> using scale7  = scale<tag7, SUBNET>;
    template <typename SUBNET> using scale8  = scale<tag8, SUBNET>;
    template <typename SUBNET> using scale9  = scale<tag9, SUBNET>;
    template <typename SUBNET> using scale10 = scale<tag10, SUBNET>;

    using scale1_  = scale_<tag1>;
    using scale2_  = scale_<tag2>;
    using scale3_  = scale_<tag3>;
    using scale4_  = scale_<tag4>;
    using scale5_  = scale_<tag5>;
    using scale6_  = scale_<tag6>;
    using scale7_  = scale_<tag7>;
    using scale8_  = scale_<tag8>;
    using scale9_  = scale_<tag9>;
    using scale10_ = scale_<tag10>;
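
    // A minimal usage sketch (a squeeze-and-excitation style block; the 8-channel
    // width and the reduction to 2 units are illustrative).  scale expects its
    // subnetwork to emit one scale per channel (an n x k x 1 x 1 tensor), so global
    // average pooling followed by two small fully connected layers is a natural fit.
    // tag1 marks the features being scaled, and fc<8,...> must match their channel
    // count:
    //
    //     template <typename SUBNET>
    //     using se8 = scale1<sig<fc<8,relu<fc<2,avg_pool_everything<tag1<SUBNET>>>>>>>;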

// ----------------------------------------------------------------------------------------

    template <
        template<typename> class tag
        >
    class scale_prev_
    {
    public:
        const static unsigned long id = tag_id<tag>::id;

        scale_prev_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            auto&& src = sub.get_output();
            auto&& scales = layer<tag>(sub).get_output();
            DLIB_CASSERT(scales.num_samples() == src.num_samples() &&
                         scales.k()           == src.k() &&
                         scales.nr()          == 1 &&
                         scales.nc()          == 1,
                         "scales.k(): " << scales.k() <<
                         "\nsrc.k(): " << src.k()
                         );

            output.copy_size(src);
            tt::scale_channels(false, output, src, scales);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            auto&& src = sub.get_output();
            auto&& scales = layer<tag>(sub).get_output();
            tt::scale_channels(true, sub.get_gradient_input(), gradient_input, scales);

            if (reshape_src.num_samples() != src.num_samples())
            {
                reshape_scales = alias_tensor(src.num_samples()*src.k());
                reshape_src = alias_tensor(src.num_samples()*src.k(),src.nr()*src.nc());
            }

            auto&& scales_grad = layer<tag>(sub).get_gradient_input();
            auto sgrad = reshape_scales(scales_grad);
            tt::dot_prods(true, sgrad, reshape_src(src), reshape_src(gradient_input));
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        friend void serialize(const scale_prev_& item, std::ostream& out)
        {
            serialize("scale_prev_", out);
            serialize(item.reshape_scales, out);
            serialize(item.reshape_src, out);
        }

        friend void deserialize(scale_prev_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "scale_prev_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::scale_prev_.");
            deserialize(item.reshape_scales, in);
            deserialize(item.reshape_src, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const scale_prev_& /*item*/)
        {
            out << "scale_prev"<<id;
            return out;
        }

        friend void to_xml(const scale_prev_& /*item*/, std::ostream& out)
        {
            out << "<scale_prev tag='"<<id<<"'/>\n";
        }

    private:
        alias_tensor reshape_scales;
        alias_tensor reshape_src;
        resizable_tensor params;
    };

    template <
        template<typename> class tag,
        typename SUBNET
        >
    using scale_prev = add_layer<scale_prev_<tag>, SUBNET>;

    template <typename SUBNET> using scale_prev1  = scale_prev<tag1, SUBNET>;
    template <typename SUBNET> using scale_prev2  = scale_prev<tag2, SUBNET>;
    template <typename SUBNET> using scale_prev3  = scale_prev<tag3, SUBNET>;
    template <typename SUBNET> using scale_prev4  = scale_prev<tag4, SUBNET>;
    template <typename SUBNET> using scale_prev5  = scale_prev<tag5, SUBNET>;
    template <typename SUBNET> using scale_prev6  = scale_prev<tag6, SUBNET>;
    template <typename SUBNET> using scale_prev7  = scale_prev<tag7, SUBNET>;
    template <typename SUBNET> using scale_prev8  = scale_prev<tag8, SUBNET>;
    template <typename SUBNET> using scale_prev9  = scale_prev<tag9, SUBNET>;
    template <typename SUBNET> using scale_prev10 = scale_prev<tag10, SUBNET>;

    using scale_prev1_  = scale_prev_<tag1>;
    using scale_prev2_  = scale_prev_<tag2>;
    using scale_prev3_  = scale_prev_<tag3>;
    using scale_prev4_  = scale_prev_<tag4>;
    using scale_prev5_  = scale_prev_<tag5>;
    using scale_prev6_  = scale_prev_<tag6>;
    using scale_prev7_  = scale_prev_<tag7>;
    using scale_prev8_  = scale_prev_<tag8>;
    using scale_prev9_  = scale_prev_<tag9>;
    using scale_prev10_ = scale_prev_<tag10>;
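
    // A minimal usage sketch: scale_prev is scale with the roles of the two inputs
    // swapped, which lets the features stay on the main path of the network.  The same
    // squeeze-and-excitation block as above then becomes (illustrative sizes, assuming
    // the features in tag1 have 8 channels):
    //
    //     template <typename SUBNET>
    //     using se8_v2 = scale_prev2<skip1<tag2<sig<fc<8,relu<fc<2,avg_pool_everything<tag1<SUBNET>>>>>>>>>;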

// ----------------------------------------------------------------------------------------

    class relu_
    {
    public:
        relu_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::relu(output, input);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::relu_gradient(data_grad, computed_output, gradient_input);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const relu_& /*item*/, std::ostream& out)
        {
            serialize("relu_", out);
        }

        friend void deserialize(relu_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "relu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::relu_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const relu_& /*item*/)
        {
            out << "relu";
            return out;
        }

        friend void to_xml(const relu_& /*item*/, std::ostream& out)
        {
            out << "<relu/>\n";
        }

    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using relu = add_layer<relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class prelu_
    {
    public:
        explicit prelu_(
            float initial_param_value_ = 0.25
        ) : initial_param_value(initial_param_value_)
        {
        }

        float get_initial_param_value (
        ) const { return initial_param_value; }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
            params.set_size(1);
            params = initial_param_value;
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub,
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::prelu(data_output, sub.get_output(), params);
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input,
            SUBNET& sub,
            tensor& params_grad
        )
        {
            tt::prelu_gradient(sub.get_gradient_input(), sub.get_output(),
                gradient_input, params, params_grad);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const prelu_& item, std::ostream& out)
        {
            serialize("prelu_", out);
            serialize(item.params, out);
            serialize(item.initial_param_value, out);
        }

        friend void deserialize(prelu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "prelu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::prelu_.");
            deserialize(item.params, in);
            deserialize(item.initial_param_value, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const prelu_& item)
        {
            out << "prelu\t ("
                << "initial_param_value="<<item.initial_param_value
                << ")";
            return out;
        }

        friend void to_xml(const prelu_& item, std::ostream& out)
        {
            out << "<prelu initial_param_value='"<<item.initial_param_value<<"'>\n";
            out << mat(item.params);
            out << "</prelu>\n";
        }

    private:
        resizable_tensor params;
        float initial_param_value;
    };

    template <typename SUBNET>
    using prelu = add_layer<prelu_, SUBNET>;
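
    // Note: as setup() above shows, prelu_ learns a single slope parameter shared by
    // every element of its input (params.set_size(1)).  A different starting slope can
    // be given per layer instance, e.g. (illustrative value):
    //
    //     prelu_ p(0.1f);  // starts with a leak of 0.1 instead of the default 0.25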

// ----------------------------------------------------------------------------------------

    class leaky_relu_
    {
    public:
        explicit leaky_relu_(
            float alpha_ = 0.01f
        ) : alpha(alpha_)
        {
        }

        float get_alpha(
        ) const {
            return alpha;
        }

        template <typename SUBNET>
        void setup(const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::leaky_relu(output, input, alpha);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::leaky_relu_gradient(data_grad, computed_output, gradient_input, alpha);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const leaky_relu_& item, std::ostream& out)
        {
            serialize("leaky_relu_", out);
            serialize(item.alpha, out);
        }

        friend void deserialize(leaky_relu_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "leaky_relu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::leaky_relu_.");
            deserialize(item.alpha, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const leaky_relu_& item)
        {
            out << "leaky_relu\t("
                << "alpha=" << item.alpha
                << ")";
            return out;
        }

        friend void to_xml(const leaky_relu_& item, std::ostream& out)
        {
            out << "<leaky_relu alpha='"<< item.alpha << "'/>\n";
        }

    private:
        resizable_tensor params;
        float alpha;
    };

    template <typename SUBNET>
    using leaky_relu = add_layer<leaky_relu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class sig_
    {
    public:
        sig_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::sigmoid(output, input);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::sigmoid_gradient(data_grad, computed_output, gradient_input);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const sig_& /*item*/, std::ostream& out)
        {
            serialize("sig_", out);
        }

        friend void deserialize(sig_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "sig_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::sig_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const sig_& /*item*/)
        {
            out << "sig";
            return out;
        }

        friend void to_xml(const sig_& /*item*/, std::ostream& out)
        {
            out << "<sig/>\n";
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using sig = add_layer<sig_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class mish_
    {
    public:
        mish_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub,
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::mish(data_output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input,
            SUBNET& sub,
            tensor&
        )
        {
            tt::mish_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const mish_& /*item*/, std::ostream& out)
        {
            serialize("mish_", out);
        }

        friend void deserialize(mish_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "mish_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::mish_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const mish_& /*item*/)
        {
            out << "mish";
            return out;
        }

        friend void to_xml(const mish_& /*item*/, std::ostream& out)
        {
            out << "<mish/>\n";
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using mish = add_layer<mish_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class htan_
    {
    public:
        htan_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::tanh(output, input);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::tanh_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const htan_& /*item*/, std::ostream& out)
        {
            serialize("htan_", out);
        }

        friend void deserialize(htan_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "htan_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::htan_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const htan_& /*item*/)
        {
            out << "htan";
            return out;
        }

        friend void to_xml(const htan_& /*item*/, std::ostream& out)
        {
            out << "<htan/>\n";
        }


    private:
        resizable_tensor params;
    };


    template <typename SUBNET>
    using htan = add_layer<htan_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class gelu_
    {
    public:
        gelu_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        template <typename SUBNET>
        void forward(
            const SUBNET& sub,
            resizable_tensor& data_output
        )
        {
            data_output.copy_size(sub.get_output());
            tt::gelu(data_output, sub.get_output());
        }

        template <typename SUBNET>
        void backward(
            const tensor& gradient_input,
            SUBNET& sub,
            tensor&
        )
        {
            tt::gelu_gradient(sub.get_gradient_input(), sub.get_output(), gradient_input);
        }

        inline dpoint map_input_to_output (const dpoint& p) const { return p; }
        inline dpoint map_output_to_input (const dpoint& p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const gelu_& /*item*/, std::ostream& out)
        {
            serialize("gelu_", out);
        }

        friend void deserialize(gelu_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "gelu_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::gelu_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const gelu_& /*item*/)
        {
            out << "gelu";
            return out;
        }

        friend void to_xml(const gelu_& /*item*/, std::ostream& out)
        {
            out << "<gelu/>\n";
        }


    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using gelu = add_layer<gelu_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class softmax_
    {
    public:
        softmax_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::softmax(output, input);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::softmax_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const softmax_& /*item*/, std::ostream& out)
        {
            serialize("softmax_", out);
        }

        friend void deserialize(softmax_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "softmax_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const softmax_& /*item*/)
        {
            out << "softmax";
            return out;
        }

        friend void to_xml(const softmax_& /*item*/, std::ostream& out)
        {
            out << "<softmax/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using softmax = add_layer<softmax_, SUBNET>;

// ----------------------------------------------------------------------------------------

    class softmax_all_
    {
    public:
        softmax_all_()
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& /*sub*/)
        {
        }

        void forward_inplace(const tensor& input, tensor& output)
        {
            tt::softmax_all(output, input);
        }

        void backward_inplace(
            const tensor& computed_output,
            const tensor& gradient_input,
            tensor& data_grad,
            tensor&
        )
        {
            tt::softmax_all_gradient(data_grad, computed_output, gradient_input);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const softmax_all_& /*item*/, std::ostream& out)
        {
            serialize("softmax_all_", out);
        }

        friend void deserialize(softmax_all_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "softmax_all_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::softmax_all_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const softmax_all_& /*item*/)
        {
            out << "softmax_all";
            return out;
        }

        friend void to_xml(const softmax_all_& /*item*/, std::ostream& out)
        {
            out << "<softmax_all/>\n";
        }

    private:
        resizable_tensor params;
    };

    template <typename SUBNET>
    using softmax_all = add_layer<softmax_all_, SUBNET>;
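
    // Note on the two softmax layers above: softmax (tt::softmax) normalizes each
    // spatial location independently across the k channels, while softmax_all
    // (tt::softmax_all) normalizes over all k*nr*nc elements of each sample at once.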

// ----------------------------------------------------------------------------------------

    namespace impl
    {
        template <template<typename> class TAG_TYPE, template<typename> class... TAG_TYPES>
        struct concat_helper_impl{

            constexpr static size_t tag_count() {return 1 + concat_helper_impl<TAG_TYPES...>::tag_count();}
            static void list_tags(std::ostream& out)
            {
                out << tag_id<TAG_TYPE>::id << (tag_count() > 1 ? "," : "");
                concat_helper_impl<TAG_TYPES...>::list_tags(out);
            }

            template<typename SUBNET>
            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                concat_helper_impl<TAG_TYPES...>::resize_out(out, sub, sum_k + t.k());
            }
            template<typename SUBNET>
            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                tt::copy_tensor(false, out, k_offset, t, 0, t.k());
                k_offset += t.k();
                concat_helper_impl<TAG_TYPES...>::concat(out, sub, k_offset);
            }
            template<typename SUBNET>
            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
                tt::copy_tensor(true, t, 0, input, k_offset, t.k());
                k_offset += t.k();
                concat_helper_impl<TAG_TYPES...>::split(input, sub, k_offset);
            }
        };
        template <template<typename> class TAG_TYPE>
        struct concat_helper_impl<TAG_TYPE>{
            constexpr static size_t tag_count() {return 1;}
            static void list_tags(std::ostream& out)
            {
                out << tag_id<TAG_TYPE>::id;
            }

            template<typename SUBNET>
            static void resize_out(resizable_tensor& out, const SUBNET& sub, long sum_k)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                out.set_size(t.num_samples(), t.k() + sum_k, t.nr(), t.nc());
            }
            template<typename SUBNET>
            static void concat(tensor& out, const SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_output();
                tt::copy_tensor(false, out, k_offset, t, 0, t.k());
            }
            template<typename SUBNET>
            static void split(const tensor& input, SUBNET& sub, size_t k_offset)
            {
                auto& t = layer<TAG_TYPE>(sub).get_gradient_input();
                tt::copy_tensor(true, t, 0, input, k_offset, t.k());
            }
        };
    }
    // concat layer
    template<
        template<typename> class... TAG_TYPES
        >
    class concat_
    {
        static void list_tags(std::ostream& out) { impl::concat_helper_impl<TAG_TYPES...>::list_tags(out); }

    public:
        constexpr static size_t tag_count() { return impl::concat_helper_impl<TAG_TYPES...>::tag_count(); }

        template <typename SUBNET>
        void setup (const SUBNET&)
        {
            // do nothing
        }
        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            // The total depth of the result is the sum of the depths of all the tagged layers.
            impl::concat_helper_impl<TAG_TYPES...>::resize_out(output, sub, 0);

            // Copy the output of each tagged layer into a different part of the result.
            impl::concat_helper_impl<TAG_TYPES...>::concat(output, sub, 0);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor&)
        {
            // The gradient is split back into one part per tagged layer.
            impl::concat_helper_impl<TAG_TYPES...>::split(gradient_input, sub, 0);
        }

        dpoint map_input_to_output(dpoint p) const { return p; }
        dpoint map_output_to_input(dpoint p) const { return p; }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const concat_& /*item*/, std::ostream& out)
        {
            serialize("concat_", out);
            size_t count = tag_count();
            serialize(count, out);
        }

        friend void deserialize(concat_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "concat_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::concat_.");
            size_t count_tags;
            deserialize(count_tags, in);
            if (count_tags != tag_count())
                throw serialization_error("Invalid count of tags " + std::to_string(count_tags) + ", expecting " +
                                          std::to_string(tag_count()) + " found while deserializing dlib::concat_.");
        }

        friend std::ostream& operator<<(std::ostream& out, const concat_& /*item*/)
        {
            out << "concat\t (";
            list_tags(out);
            out << ")";
            return out;
        }

        friend void to_xml(const concat_& /*item*/, std::ostream& out)
        {
            out << "<concat tags='";
            list_tags(out);
            out << "'/>\n";
        }

    private:
        resizable_tensor params; // unused
    };


    // concat layer definitions
    template <template<typename> class TAG1,
            template<typename> class TAG2,
            typename SUBNET>
    using concat2 = add_layer<concat_<TAG1, TAG2>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            typename SUBNET>
    using concat3 = add_layer<concat_<TAG1, TAG2, TAG3>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            template<typename> class TAG4,
            typename SUBNET>
    using concat4 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4>, SUBNET>;

    template <template<typename> class TAG1,
            template<typename> class TAG2,
            template<typename> class TAG3,
            template<typename> class TAG4,
            template<typename> class TAG5,
            typename SUBNET>
    using concat5 = add_layer<concat_<TAG1, TAG2, TAG3, TAG4, TAG5>, SUBNET>;

    // The inception layer uses tags internally.  If the user uses tags as well there
    // could be conflicts, so these tags are reserved specifically for inception layers.
    template <typename SUBNET> using itag0  = add_tag_layer< 1000 + 0, SUBNET>;
    template <typename SUBNET> using itag1  = add_tag_layer< 1000 + 1, SUBNET>;
    template <typename SUBNET> using itag2  = add_tag_layer< 1000 + 2, SUBNET>;
    template <typename SUBNET> using itag3  = add_tag_layer< 1000 + 3, SUBNET>;
    template <typename SUBNET> using itag4  = add_tag_layer< 1000 + 4, SUBNET>;
    template <typename SUBNET> using itag5  = add_tag_layer< 1000 + 5, SUBNET>;
    // skip back to the inception input
    template <typename SUBNET> using iskip  = add_skip_layer< itag0, SUBNET>;

    // Here are some templates for creating inception layer groups.
    template <template<typename>class B1,
            template<typename>class B2,
            typename SUBNET>
    using inception2 = concat2<itag1, itag2, itag1<B1<iskip< itag2<B2< itag0<SUBNET>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            typename SUBNET>
    using inception3 = concat3<itag1, itag2, itag3, itag1<B1<iskip< itag2<B2<iskip< itag3<B3<  itag0<SUBNET>>>>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            template<typename>class B4,
            typename SUBNET>
    using inception4 = concat4<itag1, itag2, itag3, itag4,
                itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip<  itag4<B4<  itag0<SUBNET>>>>>>>>>>>>>;

    template <template<typename>class B1,
            template<typename>class B2,
            template<typename>class B3,
            template<typename>class B4,
            template<typename>class B5,
            typename SUBNET>
    using inception5 = concat5<itag1, itag2, itag3, itag4, itag5,
                itag1<B1<iskip< itag2<B2<iskip< itag3<B3<iskip<  itag4<B4<iskip<  itag5<B5<  itag0<SUBNET>>>>>>>>>>>>>>>>;
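
    // A minimal usage sketch (filter counts and sizes are illustrative): each branch
    // is a layer template, and the inceptionN alias runs every branch on the same
    // input and concatenates their outputs along the channel dimension, so this block
    // produces 8 + 8 = 16 channels:
    //
    //     template <typename SUBNET> using branch_1x1 = relu<con<8,1,1,1,1,SUBNET>>;
    //     template <typename SUBNET> using branch_3x3 = relu<con<8,3,3,1,1,SUBNET>>;
    //     template <typename SUBNET> using my_inception = inception2<branch_1x1,branch_3x3,SUBNET>;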
3870 
3871 // ----------------------------------------------------------------------------------------
3872 // ----------------------------------------------------------------------------------------
3873 
3874     const double DEFAULT_L2_NORM_EPS = 1e-5;
3875 
3876     class l2normalize_
3877     {
3878     public:
3879         explicit l2normalize_(
3880             double eps_ = DEFAULT_L2_NORM_EPS
3881         ) :
3882             eps(eps_)
3883         {
3884         }
3885 
3886         double get_eps() const { return eps; }
3887 
3888         template <typename SUBNET>
3889         void setup (const SUBNET& /*sub*/)
3890         {
3891         }
3892 
3893         void forward_inplace(const tensor& input, tensor& output)
3894         {
3895             tt::inverse_norms(norm, input, eps);
3896             tt::scale_rows(output, input, norm);
3897         }
3898 
3899         void backward_inplace(
3900             const tensor& computed_output,
3901             const tensor& gradient_input,
3902             tensor& data_grad,
3903             tensor& /*params_grad*/
3904         )
3905         {
3906             if (is_same_object(gradient_input, data_grad))
3907             {
3908                 tt::dot_prods(temp, gradient_input, computed_output);
3909                 tt::scale_rows2(0, data_grad, gradient_input, computed_output, temp, norm);
3910             }
3911             else
3912             {
3913                 tt::dot_prods(temp, gradient_input, computed_output);
3914                 tt::scale_rows2(1, data_grad, gradient_input, computed_output, temp, norm);
3915             }
3916         }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const l2normalize_& item, std::ostream& out)
        {
            serialize("l2normalize_", out);
            serialize(item.eps, out);
        }

        friend void deserialize(l2normalize_& item, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "l2normalize_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::l2normalize_.");
            deserialize(item.eps, in);
        }

        friend std::ostream& operator<<(std::ostream& out, const l2normalize_& item)
        {
            out << "l2normalize";
            out << " eps="<<item.eps;
            return out;
        }

        friend void to_xml(const l2normalize_& item, std::ostream& out)
        {
            out << "<l2normalize";
            out << " eps='"<<item.eps<<"'";
            out << "/>\n";
        }
    private:
        double eps;

        resizable_tensor params; // unused
        // Here only to avoid reallocation and as a cache between forward/backward
        // functions.
        resizable_tensor norm;
        resizable_tensor temp;
    };

    template <typename SUBNET>
    using l2normalize = add_layer<l2normalize_, SUBNET>;
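
    // A minimal usage sketch (the layer sizes and the loss are illustrative
    // assumptions, not part of this header): l2normalize is typically placed
    // just before a loss that compares feature vectors, so embeddings are
    // compared on the unit sphere:
    //
    //    using net_type = loss_metric<l2normalize<fc<128,
    //                         relu<fc<256, input<matrix<float>>>>>>>;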

// ----------------------------------------------------------------------------------------

    template <
        long _offset,
        long _k,
        long _nr,
        long _nc
        >
    class extract_
    {
        static_assert(_offset >= 0, "The offset must be >= 0.");
        static_assert(_k > 0,  "The number of channels must be > 0.");
        static_assert(_nr > 0, "The number of rows must be > 0.");
        static_assert(_nc > 0, "The number of columns must be > 0.");
    public:
        extract_(
        )
        {
        }

        template <typename SUBNET>
        void setup (const SUBNET& sub)
        {
            DLIB_CASSERT((long)sub.get_output().size() >= sub.get_output().num_samples()*(_offset+_k*_nr*_nc),
                "The tensor we are trying to extract is bigger than the input tensor it would be extracted from.");

            // View both tensors as 2D, one row per sample: aout spans just the
            // extracted block while ain spans each whole (flattened) input sample.
            aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc);
            ain = alias_tensor(sub.get_output().num_samples(),  sub.get_output().size()/sub.get_output().num_samples());
        }

        template <typename SUBNET>
        void forward(const SUBNET& sub, resizable_tensor& output)
        {
            // Rebuild the aliases if the batch size changed since setup() ran.
            if (aout.num_samples() != sub.get_output().num_samples())
            {
                aout = alias_tensor(sub.get_output().num_samples(), _k*_nr*_nc);
                ain = alias_tensor(sub.get_output().num_samples(),  sub.get_output().size()/sub.get_output().num_samples());
            }

            output.set_size(sub.get_output().num_samples(), _k, _nr, _nc);
            auto out = aout(output,0);
            auto in = ain(sub.get_output(),0);
            // Copy _k*_nr*_nc elements per sample, starting _offset elements into
            // each input sample.
            tt::copy_tensor(false, out, 0, in, _offset, _k*_nr*_nc);
        }

        template <typename SUBNET>
        void backward(const tensor& gradient_input, SUBNET& sub, tensor& /*params_grad*/)
        {
            // Route the gradient back into the slice it was extracted from,
            // adding (first argument true) to the gradient already stored there.
            auto out = ain(sub.get_gradient_input(),0);
            auto in = aout(gradient_input,0);
            tt::copy_tensor(true, out, _offset, in, 0, _k*_nr*_nc);
        }

        const tensor& get_layer_params() const { return params; }
        tensor& get_layer_params() { return params; }

        friend void serialize(const extract_& /*item*/, std::ostream& out)
        {
            serialize("extract_", out);
            serialize(_offset, out);
            serialize(_k, out);
            serialize(_nr, out);
            serialize(_nc, out);
        }

        friend void deserialize(extract_& /*item*/, std::istream& in)
        {
            std::string version;
            deserialize(version, in);
            if (version != "extract_")
                throw serialization_error("Unexpected version '"+version+"' found while deserializing dlib::extract_.");

            long offset;
            long k;
            long nr;
            long nc;
            deserialize(offset, in);
            deserialize(k, in);
            deserialize(nr, in);
            deserialize(nc, in);

            if (offset != _offset) throw serialization_error("Wrong offset found while deserializing dlib::extract_");
            if (k != _k)   throw serialization_error("Wrong k found while deserializing dlib::extract_");
            if (nr != _nr) throw serialization_error("Wrong nr found while deserializing dlib::extract_");
            if (nc != _nc) throw serialization_error("Wrong nc found while deserializing dlib::extract_");
        }

        friend std::ostream& operator<<(std::ostream& out, const extract_& /*item*/)
        {
            out << "extract\t ("
                << "offset="<<_offset
                << ", k="<<_k
                << ", nr="<<_nr
                << ", nc="<<_nc
                << ")";
            return out;
        }

        friend void to_xml(const extract_& /*item*/, std::ostream& out)
        {
            out << "<extract";
            out << " offset='"<<_offset<<"'";
            out << " k='"<<_k<<"'";
            out << " nr='"<<_nr<<"'";
            out << " nc='"<<_nc<<"'";
            out << "/>\n";
        }
    private:
        alias_tensor aout, ain;

        resizable_tensor params; // unused
    };

    template <
        long offset,
        long k,
        long nr,
        long nc,
        typename SUBNET
        >
    using extract = add_layer<extract_<offset,k,nr,nc>, SUBNET>;
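
    // A minimal usage sketch (the sizes are illustrative assumptions): extract
    // presents a contiguous run of each input sample as its own tensor.  For
    // example, to expose the first 64 of the 128 outputs of a fully connected
    // layer as a 64-channel 1x1 tensor:
    //
    //    using first_half = extract<0,64,1,1, fc<128, input<matrix<float>>>>;
    //
    // The remaining values could then be pulled out with extract<64,64,1,1,...>
    // applied to the same subnetwork.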

// ----------------------------------------------------------------------------------------

}

#endif // DLIB_DNn_LAYERS_H_
