/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Optimizer/GraphOptimizer/TrainingPreparation.h"

#include "glow/Base/Tensor.h"
#include "glow/Graph/Graph.h"
#include "glow/Graph/Nodes.h"
#include "glow/Graph/PlaceholderBindings.h"

#include "llvm/Support/Casting.h"

namespace glow {

namespace {
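/// Fills \p tensor, the \p inputIdx'th input of \p node in \p F, with a
/// default starting value for training: Xavier initialization for weights,
/// small broadcast constants for biases, and fixed values for
/// batch-normalization statistics and SoftMax labels.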
void defaultTensorInitializer(Function *F, Node *node, unsigned inputIdx,
                              Tensor *tensor) {
  switch (node->getKind()) {
  case Kinded::Kind::ConvolutionNodeKind: {
    if (ConvolutionNode::FilterIdx == inputIdx) {
      ConvolutionNode *CN = llvm::cast<ConvolutionNode>(node);
      ShapeNHWC idim = ShapeNHWC(CN->getInput().dims());
      ShapeHW kdim(CN->getKernels());
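      // Xavier fan-in for a convolution filter: the receptive-field size
      // (kernel height * width) times the number of input channels.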
      size_t fanIn = kdim.height * kdim.width * idim.c;
      tensor->init(Tensor::InitKind::Xavier, fanIn, F->getPRNG());
    } else if (ConvolutionNode::BiasIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
    }
    break;
  }
  case Kinded::Kind::BatchNormalizationNodeKind: {
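    // Mean and variance are running statistics rather than learned
    // parameters; prepareFunctionForTraining() below gives them
    // non-trainable Placeholders.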
    if (BatchNormalizationNode::ScaleIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG());
    } else if (BatchNormalizationNode::BiasIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
    } else if (BatchNormalizationNode::MeanIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Zero, 0, F->getPRNG());
    } else if (BatchNormalizationNode::VarIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Broadcast, 1.0, F->getPRNG());
    }
    break;
  }
  case Kinded::Kind::FullyConnectedNodeKind: {
    if (FullyConnectedNode::WeightsIdx == inputIdx) {
      FullyConnectedNode *FCN = llvm::cast<FullyConnectedNode>(node);
      auto in = FCN->getInput();
      tensor->init(Tensor::InitKind::Xavier, in.dims()[1], F->getPRNG());
    } else if (FullyConnectedNode::BiasIdx == inputIdx) {
      tensor->init(Tensor::InitKind::Broadcast, 0.1, F->getPRNG());
    }
    break;
  }
  case Kinded::Kind::SoftMaxNodeKind: {
    if (SoftMaxNode::SelectedIdx == inputIdx) {
      tensor->zero();
    }
    break;
  }
  default:
    break;
  }
}
} // namespace

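/// \returns the default TensorInitializer defined above.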
TensorInitializer getDefaultTensorInitializer() {
  return defaultTensorInitializer;
}

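/// Replaces every Constant input of the computation nodes in \p F with a
/// Placeholder, initializes the Placeholder's backing tensor in \p bindings
/// via \p initializer, and erases the now-unused Constants. \p selected is
/// set to the Placeholder that replaces a SoftMax node's constant input,
/// i.e. the expected-labels tensor.
///
/// A minimal usage sketch, assuming a Function \c F whose graph is already
/// built and the EXIT_ON_ERR helper from glow/Support/Error.h:
///
/// \code
///   PlaceholderBindings bindings;
///   Placeholder *selected = nullptr;
///   EXIT_ON_ERR(prepareFunctionForTraining(F, bindings, selected,
///                                          getDefaultTensorInitializer()));
/// \endcode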
Error prepareFunctionForTraining(Function *F, PlaceholderBindings &bindings,
                                 Placeholder *&selected,
                                 TensorInitializer &&initializer) {

  auto &nodes = F->getNodes();

  selected = nullptr;
  // Walk all nodes, skipping Storage nodes, and enumerate each node's
  // inputs. Each Constant input is replaced with a trainable Placeholder,
  // except for inputs that must not receive gradients (SoftMax's
  // expected-labels input and BatchNormalization's mean and variance), which
  // become non-trainable Placeholders.
  for (auto &node : nodes) {
    // Skip storages.
    if (llvm::isa<Storage>(&node)) {
      continue;
    }

    const bool isSoftMax = node.getKind() == Kinded::Kind::SoftMaxNodeKind;
    const bool isBatchNormalization =
        node.getKind() == Kinded::Kind::BatchNormalizationNodeKind;

    for (unsigned idx = 0, e = node.getNumInputs(); idx < e; ++idx) {
      auto *IN = node.getNthInput(idx).getNode();
      Constant *C = llvm::dyn_cast<Constant>(IN);
      if (!C) {
        continue;
      }

      // The new Placeholder is non-trainable when it feeds a SoftMax node
      // (the expected-labels input) or the mean/variance inputs of a
      // BatchNormalization node; those hold labels and running statistics,
      // not learned parameters.
      const bool isTrainable =
          !isSoftMax &&
          (!isBatchNormalization || (BatchNormalizationNode::MeanIdx != idx &&
                                     BatchNormalizationNode::VarIdx != idx));

      auto *PH = F->getParent()->createPlaceholder(C->getType(), C->getName(),
                                                   isTrainable);

      if (isSoftMax) {
        selected = PH;
      }
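      // Rewire all of F's uses of the Constant to the new Placeholder, run
      // the initializer on the Constant's payload, move that tensor into the
      // bindings as the Placeholder's backing storage, and erase the
      // now-unused Constant from the module.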
      C->getOutput().replaceAllUsesOfWith(PH, F);
      auto &tensor = C->getPayloadMutable();
      initializer(F, &node, idx, &tensor);
      bindings.insert(PH, std::move(tensor));
      RETURN_ERR_IF_NOT(!C->hasUsers(), "Constant is still in use.");
      F->getParent()->eraseConstant(C);
    }
  }

  return Error::success();
}
} // namespace glow
134