1 /**
2  * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 #include "glow/Converter/FunctionConverter.h"
17 
18 #include "glow/Graph/Graph.h" // For Function.
19 #include "glow/Graph/Node.h"  // For Node.
20 #include "glow/Graph/Nodes.h" // For Placeholder and Constant.
21 #include "glow/Graph/PlaceholderBindings.h"
22 
23 #include "llvm/ADT/DenseMap.h"
24 
25 using namespace glow;
26 
27 TypeRef
getTargetTypeForOutput(const NodeValue & nodeVal) const28 FunctionConverter::getTargetTypeForOutput(const NodeValue &nodeVal) const {
29   // Default implementation says there is nothing to do.
30   return nullptr;
31 }
32 
getTargetTypeForInput(const Node & use,unsigned idx) const33 TypeRef FunctionConverter::getTargetTypeForInput(const Node &use,
34                                                  unsigned idx) const {
35   // Default implementation says there is nothing to do.
36   return nullptr;
37 }
38 
canConvert(const Node & node) const39 bool FunctionConverter::canConvert(const Node &node) const {
40   // By default, we assume everything is convertible.
41   switch (node.getKind()) {
42   default:
43     return true;
44   case Kinded::Kind::PlaceholderKind:
45   case Kinded::Kind::SaveNodeKind:
46     // Save node or placeholder special because
47     // they are or their effects are visible from
48     // the outside of the function being converted.
49     // Thus, we cannot convert them, unless we change
50     // the semantic of this function and the related
51     // placeholder.
52     return false;
53   }
54 }
55 
getConversionOutput(Node & conversion) const56 NodeValue FunctionConverter::getConversionOutput(Node &conversion) const {
57   assert(conversion.getNumResults() == 1 && "This method should be overloaded");
58   return NodeValue(&conversion, 0);
59 }
60 
morphNode(Node & node)61 Node &FunctionConverter::morphNode(Node &node) { return node; }
62 
postProcessing(Node & node)63 void FunctionConverter::postProcessing(Node &node) {}
64 
convertOutputs(Node & node)65 void FunctionConverter::convertOutputs(Node &node) {
66   using FunctionAndValIdx = std::pair<Function *, unsigned>;
67   llvm::DenseMap<FunctionAndValIdx, NodeValue> functionAndValToConversion;
68   for (unsigned idx = 0, end = node.getNumResults(); idx != end; ++idx) {
69     NodeValue val = node.getNthResult(idx);
70     TypeRef targetTy = getTargetTypeForOutput(val);
71     if (!targetTy || targetTy == val.getType()) {
72       continue;
73     }
74     // convert the node and create a conversion to keep the users happy.
75     assert(targetTy->dims() == val.getType()->dims() &&
76            "Conversion does not preserve shape");
77     TypeRef origTy = val.getType();
78     // Fake the morphing of the node so that the creation
79     // of the conversion works properly.
80     val.setType(targetTy);
81     // Store the users in a temporary object because setOperand
82     // will invalidate the iterator.
83     llvm::SmallVector<NodeUse, 4> users(val.getUsers().begin(),
84                                         val.getUsers().end());
85     // We cannot use replaceAllUsesWith here because:
86     // 1. At this point, val and conversion don't have the same type
87     //    (one is converted the other is the original type), and that
88     //    would trigger an assertion.
89     // 2. We would end up replacing the use of val in "conversion" by
90     //   "conversion".
91     // 3. Node may be a module level value and we need one conversion per
92     //    function.
93     for (auto use : users) {
94       Node *user = use.getUser();
95       Function *parent = user->getParent();
96       assert(parent && "User not in a function?!");
97 
98       SaveNode *saveNode = llvm::dyn_cast<SaveNode>(user);
99       // The output of save nodes is special because it doesn't use
100       // the value of the node, but its address.
101       // Thus, if we want to change the value of the output of
102       // a save node, we actually have to convert the input.
103       if (saveNode && saveNode->getOutput() == val) {
104         NodeValue input = saveNode->getInput();
105         Node *conversion = createConversion(*parent, node, input, targetTy,
106                                             /* isInput */ false);
107         saveNode->setNthInput(SaveNode::InputIdx,
108                               getConversionOutput(*conversion));
109         continue;
110       }
111 
112       FunctionAndValIdx functionAndVal = std::make_pair(parent, idx);
113       auto conversionValIt = functionAndValToConversion.find(functionAndVal);
114       if (conversionValIt == functionAndValToConversion.end()) {
115         // Create the conversion.
116         Node *conversion =
117             createConversion(*parent, node, val, origTy, /* isInput */ false);
118         // "conversion" uses val so after this call,
119         // we will get a use of conversion inside conversion.
120         NodeValue conversionVal = getConversionOutput(*conversion);
121         auto insertion =
122             functionAndValToConversion.insert({functionAndVal, conversionVal});
123         assert(insertion.second && "Conversion already there?!");
124         conversionValIt = insertion.first;
125       }
126 
127       NodeValue conversionVal = conversionValIt->second;
128       if (user == conversionVal.getNode()) {
129         continue;
130       }
131       // Log the change of node input(operand).
132       if (Function *F = node.getParent()) {
133         F->getLogContext()->logNodeInputChange(*user, *(use.get()),
134                                                conversionVal);
135       }
136 
137       use.get()->setOperand(conversionVal.getNode(), conversionVal.getResNo());
138     }
139   }
140 }
141 
convertInputs(Node & node)142 void FunctionConverter::convertInputs(Node &node) {
143   // We shouldn't have to convert the inputs of something that is not in
144   // function_.
145   assert((node.getNumInputs() == 0 || node.getParent() == &function_) &&
146          "Invalid requested conversion");
147   for (unsigned idx = 0, end = node.getNumInputs(); idx != end; ++idx) {
148     NodeValue val = node.getNthInput(idx);
149     TypeRef targetTy = getTargetTypeForInput(node, idx);
150     if (!targetTy || targetTy == val.getType()) {
151       continue;
152     }
153     // convert the node and create a conversion to keep the users happy.
154     assert(targetTy->dims() == val.getType()->dims() &&
155            "Conversion does not preserve shape");
156     // Create the conversion.
157     Node *conversion =
158         createConversion(function_, node, val, targetTy, /* isInput */ true);
159     node.setNthInput(idx, getConversionOutput(*conversion));
160   }
161 }
162 
convert()163 void FunctionConverter::convert() {
164   assert(function_.verify() && "Input function must be valid");
165 
166   // Traverse all nodes.
167   // Check what the conversion should look like, if any.
168   // Convert the node appropriately.
169 
170   // For every unprocessed node in the graph we keep the invariant of having
171   // all inputs to be of the uncovered type.
172   // I.e., if we have:
173   // res(outTy) = node arg1(in2Ty), arg2(in2Ty)
174   //
175   // after converting "node", we will have something that looks like:
176   // newArg1(convertedIn1Ty) = conversion arg1
177   // newArg2(convertedIn2Ty) = conversion arg2
178   // newRes(convertedOutTy) = node newArg1, newArg2
179   // res(outTy) = conversion newRes
180   //
181   // In other words, the boundaries (in and out) are unchanged.
182 
183   // The iterator looks weird because we only want to iterate through
184   // the original nodes.
185   auto nodeIt = function_.getNodes().end();
186   auto stopIt = function_.getNodes().begin();
187   do {
188     --nodeIt;
189     Node &node = *nodeIt;
190     if (!canConvert(node)) {
191       continue;
192     }
193     // Mutate the output types and insert the conversion to keep our
194     // invariant.
195     convertOutputs(node);
196     // Convert the inputs of the node.
197     convertInputs(node);
198     // All the surrounding code is properly typed, finally the morph node.
199     Node &morphedNode = morphNode(node);
200     // Do some post processing if need be.
201     postProcessing(morphedNode);
202   } while (nodeIt != stopIt);
203 
204   // Allow a late clean-up before verifying the conversation produced a valid
205   // function.
206   cleanUp();
207 
208   assert(function_.verify() && "Conversion led to invalid function");
209 }
210 
convertPlaceholder(Placeholder & placeholder,PlaceholderBindings * bindings)211 void FunctionConverter::convertPlaceholder(Placeholder &placeholder,
212                                            PlaceholderBindings *bindings) {
213   TypeRef destTy = getTargetTypeForOutput(placeholder.getOutput());
214   if (!destTy || destTy == placeholder.getType()) {
215     return;
216   }
217   convertOutputs(placeholder);
218   if (!bindings) {
219     return;
220   }
221   Tensor *tensor = bindings->get(&placeholder);
222   if (tensor) {
223     convertTensor(*tensor, destTy);
224   }
225 }
226