/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "glow/Partitioner/Partitioner.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Partitioner/PartitionerOptimizer.h"
#include "glow/Partitioner/PartitionerUtils.h"
#include "glow/Partitioner/PartitionerValidation.h"
#include "glow/Support/Support.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/raw_ostream.h"

#include <fstream>
namespace glow {
bool GlowEnableLoadBalancedPartitioning = true;
bool GlowLogPartition = false;
bool GlowDumpPartition = false;
static llvm::cl::opt<bool, /* ExternalStorage */ true>
    GlowEnableLoadBalancedPartitioningOpt(
        "glow_partitioner_enable_load_balance",
        llvm::cl::desc(
            "Enable a partitioner pass to optimize for "
            "load balance in addition to memory capacity constraints"),
        llvm::cl::location(GlowEnableLoadBalancedPartitioning));
} // namespace glow

/// -log-partition - Command line option to dump Partitioner logs.
static llvm::cl::OptionCategory PartitionerCat("Glow Partitioner Options");
static llvm::cl::opt<bool, /* ExternalStorage */ true> logPartition(
    "log-partition", llvm::cl::desc("Enable logging partition info"),
    llvm::cl::location(glow::GlowLogPartition), llvm::cl::cat(PartitionerCat));

/// -dump-partition - Command line option to dump the graph of each partition
/// by calling F->dumpDAG().
static llvm::cl::opt<bool, /* ExternalStorage */ true>
    dumpPartition("dump-partition",
                  llvm::cl::desc("Enable dumping the graph of each partition"),
                  llvm::cl::location(glow::GlowDumpPartition),
                  llvm::cl::cat(PartitionerCat));

using namespace glow;
using llvm::isa;

// Sorts std::pair<Function *, uint64_t> pairs by the second element, from min
// to max.
bool sortMinMemory(const std::pair<Function *, uint64_t> &a,
                   const std::pair<Function *, uint64_t> &b) {
  return a.second < b.second;
}

void Partitioner::init() {
  memSize_ = module_->getConstantsSize();
  logicalDeviceID_ = 0;
  multiBackendNames_ = false;
  for (size_t i = 1, e = deviceInfo_.size(); i < e; i++) {
    if (deviceInfo_[i].backendName != deviceInfo_[0].backendName) {
      multiBackendNames_ = true;
      break;
    }
  }
}

Error Partitioner::finalize(const DAGListTy &partitions,
                            const NodeToFunctionMap &mapping) {

  // NOTE: Cannot validate the functions after partitioning here. The validation
  // needs the backend specific verifier. Tensor layouts, for example, might
  // have gone from canonical form to backend specific form.

  if (logPartition) {
    LOG(INFO) << "The number of partitions is : "
              << module_->getFunctions().size();
    LOG(INFO) << "Dumping partitioning DAG to DAG.dot file.";
    dumpDAG("DAG.dot", partitions);
    logPartitionInfo(mapping);
  }

  // Dump the graph of each function after partitioning.
  if (dumpPartition) {
    for (const auto &node : partitions[0].nodes) {
      Function *subF = module_->getFunction(node->name);
      if (!subF) {
        // If we fail, dump the partition info for debugging.
        logPartitionInfo(mapping);
        return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                        "Invalid function name " + node->name);
      }
      subF->dumpDAG("partitionLogicalID" +
                    std::to_string(node->logicalDevices[0]) + "__" +
                    subF->getName().str() + "__" + node->backendName + ".dot");
    }
  }
  return Error::success();
}

Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
                         const std::vector<Backend *> &backends, bool optimized)
    : module_(parent), deviceInfo_(devices), backends_(backends),
      optimized_(optimized) {
  init();
}

Partitioner::Partitioner(Module *parent, const std::vector<DeviceInfo> &devices,
                         bool optimized, PartitionConfig partitionConfig)
    : module_(parent), deviceInfo_(devices), optimized_(optimized),
      partitionConfig_(partitionConfig) {
  init();
}

Function *Partitioner::selectRepFunc(Module *parent, uint64_t &memSize) {
  auto funcList = parent->getFunctions();
  Function *ret = nullptr;
  uint64_t maxMemSize = 0;
  for (Function *F : funcList) {
    uint64_t curSize = memSize;

    // The set to keep the placeholders (only for inputs) whose sizes have
    // already been counted.
    std::set<llvm::StringRef> pSet;

    for (auto &node : F->getNodes()) {
      int n = node.getNumInputs();
      if (node.getKind() == Kinded::Kind::SaveNodeKind) {
        // Skip SaveNode so that its placeholder is not counted as an input.
        continue;
      }
      for (int i = 0; i < n; i++) {
        Placeholder *in =
            llvm::dyn_cast<Placeholder>(node.getNthInput(i).getNode());
        if (in && pSet.find(in->getName()) == pSet.end()) {
          auto ty = in->getType();
          curSize += ty->getSizeInBytes();
          pSet.insert(in->getName());
        }
      }
    }
    // Pick the function with the largest required memory as the
    // representative function.
    if (!ret || curSize > maxMemSize) {
      ret = F;
      maxMemSize = curSize;
    }
  }
  memSize = maxMemSize;
  return ret;
}

void Partitioner::partitionsAdjust(NodeToFunctionMap &partitions,
                                   uint64_t availableMemory) {
  // For each partition, create a node set.
  FunctionToNodesMap nodesSet;
  for (auto it = partitions.begin(); it != partitions.end(); ++it) {
    nodesSet[(*it).second].insert((*it).first);
  }

  // Optimize the communication cost.
  optimizeCommunicationCost(partitions, nodesSet, module_, availableMemory);

  // Combine the current partitions if necessary.
  partitionsCombine(partitions, nodesSet, module_, availableMemory);
}

/// Assign nodes to partitions and return the mapping.
NodeToFunctionMap Partitioner::selectPartitions(Function *F,
                                                uint64_t availableMemory,
                                                llvm::StringRef backendName) {
  NodeToFunctionMap mapping;
  BFSLevel bfs = getBFSLevel(F);
  size_t level = bfs.size();

  // Step 1 : get the initial cut based on BFS levels and availableMemory.
  int color = 0;
  Function *newF;
  newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" +
                                        std::to_string(++color));
  mapping.createPartition(newF, backendName);
  NodesSet currentPartition;
  GraphMemInfo graphMem;
  graphMem.contextCount = contextCount_;

  for (int i = level - 1; i >= 0; i--) {
    for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
      Node *N = bfs[i][j];
      graphMem = updateGraphMemInfoByAddingNode(currentPartition, graphMem, N);
      // If after adding node N, the memory usage of this partition exceeds the
      // device memory limitations, N can't be added into the current partition
      // and a new partition is created.
      if (graphMem.getTotalMemSize() > availableMemory) {
        newF = F->getParent()->createFunction(
            std::string(F->getName()) + "_part" + std::to_string(++color));
        mapping.createPartition(newF, backendName);
        currentPartition.clear();
        graphMem =
            updateGraphMemInfoByAddingNode(currentPartition, GraphMemInfo{}, N);
      }
      currentPartition.insert(N);
      mapping.add(N, newF);
      mapping.setGraphMemInfo(newF, graphMem);
    }
  }

  // Step 2 : adjust the partition based on performance.
  partitionsAdjust(mapping, availableMemory);

  return mapping;
}

void Partitioner::saturateHost(unsigned logicalDeviceCount,
                               const DAGListTy &partitions) {
  unsigned duplications = deviceInfo_.size() / logicalDeviceCount;
  if (duplications < 2) {
    return;
  }
  // Add additional logical devices to each node.
  for (auto &network : partitions) {
    for (auto &node : network.nodes) {
      // Set instanceCount.
      node->instanceCount = duplications;
      // Build list of new logical devices to add to node.
      std::vector<unsigned> newDevices;
      for (auto logical : node->logicalDevices) {
        // To ensure we do not have a logicalID collision we use the following
        // scheme. We have an iterator starting at 1 for each duplication pass.
        // The new ID we add is calculated as follows:
        // (iterator * logicalDeviceCount) + initialLogicalID
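        // For example, with logicalDeviceCount = 2 and duplications = 3, a
        // node whose initial logical ID is 0 gains IDs 2 and 4, and a node
        // whose initial ID is 1 gains IDs 3 and 5, so no generated ID ever
        // collides with another.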
        for (unsigned i = 1; i < duplications; i++) {
          newDevices.push_back(logical + (i * logicalDeviceCount));
        }
      }
      // Append the new logical devices to the node's logical device vector.
      node->logicalDevices.insert(node->logicalDevices.end(),
                                  newDevices.begin(), newDevices.end());
    }
  }
}

Expected<DAGListTy> Partitioner::backendBasedPartition(
    FunctionToBackendNameMap &funcToBackend, Function *F,
    std::vector<Backend *> &backends, CompilationContext &cctx) {
  NodeToFunctionMap mapping;
  llvm::DenseMap<Node *, std::string> nodeToBackendName;

  // For each node find a backend that supports it.
  for (auto &N : F->getNodes()) {
    for (auto &backend : backends) {
      // Find the first backend that supports this node. The order of backends
      // is important. The check flow is:

      // Step 1: If a node is in the pre-defined non-supported nodes set, it
      // cannot be assigned to this backend. Continue.
      const auto &nonSupportedNodesKinds =
          backendMap_[backend->getBackendName()].nonSupportedNodesKinds;
      if (nonSupportedNodesKinds.count(N.getKind())) {
        // This op is on the pre-defined non-supported op list:
        continue;
      }
      // Step 2: If the pre-defined supported nodes set is empty, all nodes can
      // be assigned to this backend. If the set is not empty, check whether
      // the node is in it. If not, continue.
      const auto &supportedNodesKinds =
          backendMap_[backend->getBackendName()].supportedNodesKinds;
      if (!supportedNodesKinds.empty() &&
          !supportedNodesKinds.count(N.getKind())) {
        // This op is not on the pre-defined supported op list:
        continue;
      }
      // Step 3: Check if the node is actually supported by this backend. If
      // so, assign it to this backend and break. Otherwise continue.
      // TODO: the logic here needs to be improved.
      if (backend->shouldLower(&N) || backend->isOpSupported(N)) {
        // Put this node into a partition for this backend.
        nodeToBackendName[&N] = backend->getBackendName();
        break;
      }
    }
    if (nodeToBackendName.find(&N) == nodeToBackendName.end()) {
      logPartitionInfo(mapping);
      return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                      "Node is not supported by any of the provided backends");
    }
  }

  BFSLevel bfs = getBFSLevel(F);
  size_t level = bfs.size();
  int color = 0;
  Function *newF;
  newF = F->getParent()->createFunction(std::string(F->getName()) + "_part" +
                                        std::to_string(++color));
  auto backendName = nodeToBackendName[bfs[level - 1][0]];
  if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
    // When profiling, every partition is assigned to the profilingBackend.
    mapping.createPartition(newF, profilingBackend);
    funcToBackend[newF] = profilingBackend;
  } else {
    mapping.createPartition(newF, backendName);
    funcToBackend[newF] = backendName;
  }
  for (int i = level - 1; i >= 0; i--) {
    for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
      Node *N = bfs[i][j];
      auto bk = nodeToBackendName[N];
      if (bk != backendName) {
        backendName = bk;
        newF = F->getParent()->createFunction(
            std::string(F->getName()) + "_part" + std::to_string(++color));
        if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
          // When profiling, every partition is assigned to the
          // profilingBackend.
          mapping.createPartition(newF, profilingBackend);
          funcToBackend[newF] = profilingBackend;
        } else {
          mapping.createPartition(newF, backendName);
          funcToBackend[newF] = backendName;
        }
      }
      mapping.add(N, newF);
    }
  }

  std::vector<Function *> funcs;
  funcs.push_back(F);
  // When profiling, the partition flow stops after backendBasedPartition, so
  // the DAG needs to be generated here. Otherwise there is no need to
  // generate the DAG.
  bool genDAG = cctx.precisionConfig.quantMode == QuantizationMode::Profile;
  if (genDAG) {
    DeviceIDTy logicalDeviceID = 0;
    for (auto &func : mapping.getPartitions()) {
      mapping.appendLogicalDeviceID(func, logicalDeviceID++);
    }
  }
  return doPartitioning(F->getName(), funcs, module_, mapping, genDAG,
                        cctx.backendOpts.backendSpecificNodeInfo);
}

void Partitioner::genBackendMap(
    std::map<std::string, BackendInfo> &backendMap,
    std::vector<std::unique_ptr<Backend>> &backendsHolder,
    std::vector<Backend *> &backends) {
  // If the backends are created already, we use them directly.
  bool hasBackends = backends_.size() != 0;
  if (hasBackends) {
    DCHECK(backends_.size() == deviceInfo_.size())
        << "Number of backends and number of devices do not match.";
  }

  int n = 0;
  for (size_t i = 0, e = deviceInfo_.size(); i < e; i++) {
    std::string backendName = deviceInfo_[i].backendName;
    if (hasBackends) {
      DCHECK(backends_[i]->getBackendName() == backendName)
          << "Backend type mismatch.";
    }
    if (backendMap.find(backendName) == backendMap.end()) {
      BackendInfo backendInfo;
      backendInfo.num = 1;
      // We assume that devices of the same type have the same amount of
      // available memory.
      // TODO: improve the algorithm to handle devices with different memory
      // sizes.
      backendInfo.memSize = deviceInfo_[i].availableMemory;
      backendInfo.peakDramBw = deviceInfo_[i].peakDramBw;
      backendInfo.peakSramBw = deviceInfo_[i].peakSramBw;
      backendInfo.sramCapacity = deviceInfo_[i].sramCapacity;
      backendInfo.peakCompute = deviceInfo_[i].peakCompute;
      backendInfo.nonSupportedNodesKinds =
          generateNodeKindsSet(deviceInfo_[i].nonSupportedNodes);
      backendInfo.supportedNodesKinds =
          generateNodeKindsSet(deviceInfo_[i].supportedNodes);
      if (hasBackends) {
        backendInfo.backend = backends_[i];
      } else {
        backendsHolder.emplace_back(createBackend(backendName));
        backendInfo.backend = backendsHolder[n++].get();
      }
      backendMap[backendName] = backendInfo;
      backends.push_back(backendMap[backendName].backend);
    } else {
      backendMap[backendName].num += 1;
      // Since we are currently assuming one value it should be the max.
      backendMap[backendName].memSize = std::max(
          backendMap[backendName].memSize, deviceInfo_[i].availableMemory);
    }
  }
}

const DeviceInfo &
Partitioner::getDeviceInfoForBackend(llvm::StringRef backendName) {
  for (DeviceInfo &devInfo : deviceInfo_) {
    if (devInfo.backendName == backendName)
      return devInfo;
  }
  llvm_unreachable("Each backend should have at least one device");
}

Expected<DAGListTy> Partitioner::createDAGWithoutPartition(
    llvm::StringRef backendName, std::map<std::string, BackendInfo> &backendMap,
    CompilationContext &cctx) {
  DAGListTy partitions;
  const DeviceIDTy logDevice = 0;
  for (auto F : module_->getFunctions()) {
    if (!optimized_) {
      auto backend = backendMap[backendName].backend;
      RETURN_IF_ERR(::glow::optimizeFunction(
          F, *backend, cctx, &getDeviceInfoForBackend(backendName)));
    }
    std::unique_ptr<DAGNode> DAG0 = glow::make_unique<DAGNode>();
    DAG0->logicalDevices = {logDevice};
    DAG0->name = F->getName();
    DAG0->module = module_;
    std::unique_ptr<DAGNode> DAG1 = glow::make_unique<DAGNode>();
    DAG1->logicalDevices = {logDevice};
    DAG1->name = F->getName();
    DAG1->backendName = backendName;
    DAG1->parents.push_back(DAG0.get());
    DAG0->children.push_back(DAG1.get());
    DAG1->replicationCount = cctx.replicationCount;
    DAGNodePtrVec nodes;
    nodes.push_back(std::move(DAG1));
    partitions.push_back({std::move(DAG0), std::move(nodes)});
  }
  if (cctx.saturateHost) {
    // Saturate the host.
    saturateHost(1, partitions);
  }

  NodeToFunctionMap mapping;
  for (auto func : module_->getFunctions()) {
    mapping.createPartition(func, backendName);
    mapping.setGraphMemInfo(func, getFunctionMemory(func));

    // Use the same hard-coded logical device ID as used for the DAG itself.
    mapping.appendLogicalDeviceID(func, logDevice);
  }

  RETURN_IF_ERR(finalize(partitions, mapping));

  return std::move(partitions);
}

Expected<DAGListTy> Partitioner::loadBalancedPartition(CompilationContext &cctx,
                                                       size_t numDevices) {

  if (multiBackendNames_) {
    VLOG(1) << "For multi backend types, load-balanced partition can't be "
               "applied. Call heterogeneous partition instead.";
    return heterogeneousPartition(cctx);
  }
  F_ = selectRepFunc(module_, memSize_);
  std::string origName(F_->getName().data());
  DAGListTy partitions;
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);
  auto backendName = backends[0]->getBackendName();

  if (memSize_ < backendMap_[backendName].memSize) {
    // No partition is needed. Create DAGNode and return. This root is always a
    // dummy function.
    if (logPartition) {
      LOG(INFO) << "The model is too small for applying partition.\n"
                << "Model size : " << memSize_ << "\n"
                << "Backend Name : " << backendName << "\n"
                << "Device memory: " << backendMap_[backendName].memSize
                << "\n";
    }
    return createDAGWithoutPartition(backendName, backendMap_, cctx);
  }

  // Step 1: Get the minimal number of partitions from auto-partition.
  uint64_t availableMemory = backendMap_[backendName].memSize;
  if (!optimized_) {
    RETURN_IF_ERR(::glow::optimizeFunction(F_, *(backends[0]), cctx));
  }
  NodeToFunctionMap mapping =
      selectPartitions(F_, availableMemory, backendName);
  logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_);

  if (logicalDeviceID_ > numDevices) {
    numDevices = logicalDeviceID_;
  }
  // Step 2:
  // Currently, the load balanced partitioner disregards the input mapping
  // and only uses the numPartitions input from previous partitioning passes,
  // but we take this in to leave open the option of using the previous
  // mapping at a later point.
  // The main idea here is to use the roofline estimates to load balance
  // partitions. At this point, we stick to one partition per device, so
  // we ensure that we only have edges from nodes in smaller partition ids to
  // nodes in larger partition ids to ensure an acyclic DAGNode graph.
  //
  // The overall algorithm is as follows:
  // Iterate through all operators in breadth-first fashion.
  // For each operator do:
  // (a) Find the maximum partition id of each input node.
  // (b) Assign the operator to this partition if the memory constraints are
  //     satisfied and the total sum of operator runtimes assigned to the
  //     partition does not exceed a 1/numPartitions fraction of the overall
  //     roofline runtime.
  // (c) If the constraints aren't satisfied, try to put the operator in
  //     successively higher partitions until the conditions are satisfied.
  //     If we cannot find such a partition where this operator can be
  //     assigned, throw an error.
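  //
  // For example, with numDevices = 2 and a total roofline estimate of 10ms
  // over all operators, timePerPartition is 5ms: operators accumulate in
  // partition 0 until its estimated runtime approaches 5ms, after which
  // subsequent operators spill into partition 1.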

  // Initialize runtimes and memory availability per device.
  std::vector<float> deviceTime(numDevices, 0);
  std::vector<size_t> memoryAvailable(numDevices, availableMemory);
  std::vector<NodesSet> nodesInPartitions(numDevices);
  std::vector<GraphMemInfo> graphMem(numDevices, GraphMemInfo{});
  std::vector<Function *> partitionFuncs(numDevices);

  // Compute the total roofline time.
  NodeToFunctionMap partitionMap;
  float totalRooflineTime = 0;
  for (auto &n : F_->getNodes()) {
    totalRooflineTime +=
        getNodeComputeTime(&n, backendMap_[deviceInfo_[0].backendName]);
  }

  float timePerPartition = totalRooflineTime / numDevices;

  // Get the BFS levels.
  Function *newF;
  BFSLevel bfs = getBFSLevel(F_);
  size_t level = bfs.size();

  // Create the functions and push them into the mapping.
  for (DeviceIDTy curPartition = 0; curPartition < numDevices; curPartition++) {
    std::string funcName =
        std::string(F_->getName()) + "_part" + std::to_string(curPartition + 1);
    if (F_->getParent()->hasFunction(funcName)) {
      newF = F_->getParent()->getFunction(funcName);
      F_->getParent()->eraseFunction(newF);
    }
    newF = F_->getParent()->createFunction(funcName);
    partitionMap.createPartition(newF, backendName);
    partitionMap.appendLogicalDeviceID(newF, curPartition);
    partitionFuncs[curPartition] = newF;
  }

  // Go through the operators level by level.
  for (int i = level - 1; i >= 0; i--) {
    for (size_t j = 0, e = bfs[i].size(); j < e; j++) {
      Node *N = bfs[i][j];

      // Find the maximum partition id of the inputs to the node.
      DeviceIDTy maxLogicalDeviceId = 0;
      for (auto &I : getInputs(N)) {
        Function *inpF = partitionMap[I];
        auto logicalDeviceIds = partitionMap.getLogicalDeviceIDList(inpF);
        DCHECK(logicalDeviceIds.size() == 1);
        auto logicalDeviceId = logicalDeviceIds[0];
        if (logicalDeviceId > maxLogicalDeviceId) {
          maxLogicalDeviceId = logicalDeviceId;
        }
      }

      auto curOpTime =
          getNodeComputeTime(N, backendMap_[deviceInfo_[0].backendName]);
      auto curOpMemory = getNodeMemUsage(N);

      // Find a partition to put this node into.
      DeviceIDTy curPartition = maxLogicalDeviceId;
      const float allowedLoadImbalanceFraction = 0.5f;
      for (; curPartition < numDevices; curPartition++) {
        // Put the op in the current partition if
        // (a) memory constraints and load balance constraints are not
        // violated, or (b) this is the last partition and memory capacity
        // isn't exceeded.
        // The allowedLoadImbalanceFraction in the load balance case is to
        // avoid edge cases where load balance is only violated by a small
        // amount and moving to the next partition would result in significant
        // imbalance in runtime. Hence if the violation is by less than
        // allowedLoadImbalanceFraction of the operator cost, then we prefer to
        // keep it in the current partition.
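        // For example, with timePerPartition = 5.0 and curOpTime = 1.0, a
        // partition at deviceTime = 4.4 still accepts the op
        // (4.4 + 0.5 * 1.0 < 5.0) even though the total becomes 5.4, while a
        // partition at deviceTime = 4.6 does not (4.6 + 0.5 * 1.0 >= 5.0).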
        bool loadBalanceValid = deviceTime[curPartition] +
                                    curOpTime * allowedLoadImbalanceFraction <
                                timePerPartition;
        bool memValid = memoryAvailable[curPartition] >= curOpMemory;

        if (memValid && (loadBalanceValid || curPartition == numDevices - 1)) {
          // Valid, put the node in the current partition.
          Function *curF = partitionFuncs[curPartition];
          partitionMap.add(N, curF);
          deviceTime[curPartition] += curOpTime;
          memoryAvailable[curPartition] -= curOpMemory;
          graphMem[curPartition] = updateGraphMemInfoByAddingNode(
              nodesInPartitions[curPartition], graphMem[curPartition], N);
          nodesInPartitions[curPartition].insert(N);
          partitionMap.setGraphMemInfo(curF, graphMem[curPartition]);
          break;
        }
      }

      // Throw an error if we were not able to put this node into any
      // partition.
      if (curPartition >= numDevices) {
        logPartitionInfo(partitionMap);
        return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                        "Load balance partition error");
      }
    }
  }
  for (size_t i = 0; i < numDevices; i++) {
    VLOG(1) << "Partition #" << i << " has estimated runtime " << deviceTime[i];
  }
  // Check if the memory usage meets the device memory limitation.
  RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));

  // assignLogicalDeviceID adds all partitions to their logical devices; clear
  // the existing assignments first to prevent duplication.
  partitionMap.clearLogicalDeviceID();
  logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_);
  RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));

  partitions =
      doPartitioning(origName, {F_}, module_, partitionMap, /* saveDAG */ true,
                     cctx.backendOpts.backendSpecificNodeInfo);
  module_->eraseFunction(F_);

  if (cctx.saturateHost &&
      partitionMap.getPartitions().size() < deviceInfo_.size()) {
    saturateHost(logicalDeviceID_, partitions);
  }

  RETURN_IF_ERR(finalize(partitions, partitionMap));

  return std::move(partitions);
}

Expected<DAGListTy>
Partitioner::quantizationProfilingPartition(CompilationContext &cctx) {
  // For the quantization profiling flow, we currently assume there is only 1
  // function in a module.
  if (module_->getFunctions().size() != 1) {
    return MAKE_ERR(
        ErrorValue::ErrorCode::PARTITIONER_ERROR,
        strFormat(
            "Invalid : %lu functions in a module. In quantization profiling "
            "partition flow, the module can only contain 1 function",
            module_->getFunctions().size()));
  }

  // The quantization profiling flow runs on the CPU backend, so we don't
  // really need the concrete partition. The backendBasedPartition is necessary
  // since we need the mapping between quantized tensors and original tensors.
  DAGListTy partitions;
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);
  F_ = selectRepFunc(module_, memSize_);

  FunctionToBackendNameMap funcToBackend;
  ASSIGN_VALUE_OR_RETURN_ERR(
      partitions, backendBasedPartition(funcToBackend, F_, backends, cctx));
  module_->eraseFunction(F_);
  std::unique_ptr<Backend> backend(createBackend(profilingBackend));
  for (Function *subF : module_->getFunctions()) {
    DCHECK(subF->verify()) << "Conversion led to invalid function";
    if (!optimized_) {
      RETURN_IF_ERR(::glow::optimizeFunction(subF, *backend, cctx));
    }
  }
  if (logPartition) {
    LOG(INFO)
        << "Profiling a model to be partitioned across different backends. "
           "Each sub-network will be optimized and run on the CPU backend.\n";
  }
  return std::move(partitions);
}

Expected<DAGListTy>
Partitioner::heterogeneousPartition(CompilationContext &cctx) {
  DAGListTy partitions;
  // Prepare the mapping between BackendName and BackendInfo.
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);

  // Step 0: Find the representative function for running the partitioning
  // algorithm.
  F_ = selectRepFunc(module_, memSize_);

  // Step 1: Do the partition based on backend types.
  FunctionToBackendNameMap funcToBackend;
  std::string origName(F_->getName().data());
  if (backends.size() == 1) {
    // Only one type of backend, so no backendName-based partition is needed.
    auto backendName = backends[0]->getBackendName();
    funcToBackend[F_] = backendName;

    if (memSize_ < backendMap_[backendName].memSize) {
      // No partition is needed. Create DAGNode and return. This root is always
      // a dummy function.
      if (logPartition) {
        LOG(INFO) << "The model is too small for applying partition.\n"
                  << "Model size : " << memSize_ << "\n"
                  << "Backend Name : " << backendName << "\n"
                  << "Device memory: " << backendMap_[backendName].memSize
                  << "\n";
      }
      return createDAGWithoutPartition(backendName, backendMap_, cctx);
    }
    // NOTE: the following error detection will be removed once multiple
    // functions in a module are supported.
    if (module_->getFunctions().size() != 1) {
      return MAKE_ERR(
          ErrorValue::ErrorCode::PARTITIONER_ERROR,
          strFormat("Invalid : %lu functions in a module. Now in heterogeneous "
                    "partition flow, the module can only contain 1 function",
                    module_->getFunctions().size()));
    }
  } else {
    // NOTE: the following error detection will be removed once multiple
    // functions in a module are supported.
    if (module_->getFunctions().size() != 1) {
      return MAKE_ERR(
          ErrorValue::ErrorCode::PARTITIONER_ERROR,
          strFormat("Invalid : %lu functions in a module. Now in heterogeneous "
                    "partition flow, the module can only contain 1 function",
                    module_->getFunctions().size()));
    }
    ASSIGN_VALUE_OR_RETURN_ERR(
        partitions, backendBasedPartition(funcToBackend, F_, backends, cctx));
    module_->eraseFunction(F_);
  }

  // Step 2: Optimize each function based on its backend type and apply the
  // partition algorithm.
  NodeToFunctionMap mapping;
  std::vector<Function *> funcs;
  for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) {
    auto *func = i->first;
    auto *backend = backendMap_[i->second].backend;
    auto availMem = backendMap_[i->second].memSize;
    funcs.push_back(func);
    DCHECK(func->verify()) << "Conversion led to invalid function";
    // Step 2.1: Optimize a function if it has not been optimized yet.
    if (!optimized_) {
      RETURN_IF_ERR(::glow::optimizeFunction(
          func, *backend, cctx,
          &getDeviceInfoForBackend(backend->getBackendName())));
    }

    // Step 2.2: Apply the graph partitioning algorithm to find the partition.
    NodeToFunctionMap partitionMap =
        selectPartitions(func, availMem, i->second);
    mapping.insert(partitionMap);
  }

  // Check if the memory usage meets the device memory limitation.
  RETURN_IF_ERR(memoryUsageValidation(mapping, backendMap_));

  // Step 3: Assign each partition a logical device id. Partitions with the
  // same logical device id will be assigned to the same physical device.
  logicalDeviceID_ = assignLogicalDeviceID(mapping, backendMap_);

  // Check if the number of logical devices is less than the given physical
  // devices.
  RETURN_IF_ERR(logicalDevicesValidation(mapping, backendMap_));

  // Step 4: Do the real partitioning for the function list.
  partitions =
      doPartitioning(origName, funcs, module_, mapping, /* saveDAG */ true,
                     cctx.backendOpts.backendSpecificNodeInfo);

  // Step 5: Post-partition optimization - adjust the logicalDevice for each
  // DAGNode.
  if (cctx.saturateHost && backends.size() == 1 &&
      mapping.getPartitions().size() < deviceInfo_.size()) {
    // Attempt to saturate the host when there is only one type of backend.
    // Passing in the count of logical devices. Since logicalId starts at 0 we
    // add one.
    saturateHost(logicalDeviceID_, partitions);
  }

  // Step 6: Clean up and verify the generated new functions.
  for (auto i = funcToBackend.begin(); i != funcToBackend.end(); ++i) {
    module_->eraseFunction(i->first);
  }

  RETURN_IF_ERR(finalize(partitions, mapping));

  return std::move(partitions);
}

Expected<DAGListTy>
Partitioner::partitionFromConfig(const PartitionConfig &partitionConfig,
                                 CompilationContext &cctx) {
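  // Illustrative sketch (hypothetical names and values): to split a function
  // "main" in two, a config could set funcName = "main", numOfPartitions = 2,
  // partitionNames = {"p0", "p1"}, backendNames = {"Interpreter",
  // "Interpreter"}, and assign node names to partition ids via
  // nodeToPartition; nodes left unmapped all fall into the single remaining
  // unused partition (see the unMapped handling below).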
  DAGListTy partitions;
  // Prepare the mapping between BackendName and BackendInfo.
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);
  Function *F = module_->getFunction(partitionConfig.funcName);
  if (!F) {
    // Use the config's function name in the message; F is null here.
    return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    strFormat("Can't find function %s in current module.",
                              partitionConfig.funcName.c_str()));
  }

  DCHECK(
      partitionConfig.numOfPartitions == partitionConfig.backendNames.size() &&
      partitionConfig.numOfPartitions == partitionConfig.partitionNames.size())
      << "Invalid user-defined partition config.";

  if (partitionConfig.backendHints.size()) {
    DCHECK(partitionConfig.numOfPartitions ==
           partitionConfig.backendHints.size())
        << "Invalid user-defined partition config (backendHints).";
  }

  NodeToFunctionMap partitionMap;
  std::vector<Function *> funcList;
  std::unordered_set<size_t> unused;
  std::vector<NodesSet> nodesSets(partitionConfig.numOfPartitions);
  // Create partitions based on the given number and names.
  for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
    Function *newF = module_->createFunction(partitionConfig.partitionNames[i]);
    funcList.push_back(newF);
    partitionMap.createPartition(newF, partitionConfig.backendNames[i]);
    unused.insert(i);
  }

  // Map the nodes to the partitions.
  std::vector<Node *> unMapped;
  for (auto &node : F->getNodes()) {
    auto iter = partitionConfig.nodeToPartition.find(node.getName());
    if (iter == partitionConfig.nodeToPartition.end()) {
      // If a node in F is not in the node-to-partition mapping, put it into
      // the unMapped list.
      unMapped.push_back(&node);
    } else {
      size_t partitionID = iter->second;
      DCHECK(partitionID < partitionConfig.numOfPartitions)
          << "Invalid partition id :" << partitionID;
      partitionMap.add(&node, funcList[partitionID]);
      unused.erase(partitionID);
      nodesSets[partitionID].insert(&node);
    }
  }

  // If there is an unused partition and there are unmapped nodes, map those
  // nodes to the unused partition.
  if (unMapped.size()) {
    DCHECK(unused.size() == 1) << "There must be exactly 1 unused partition.";
    auto partitionID = *(unused.begin());
    for (auto &node : unMapped) {
      partitionMap.add(node, funcList[partitionID]);
      nodesSets[partitionID].insert(node);
    }
  }

  // Set backend hints if they exist.
  if (partitionConfig.backendHints.size()) {
    for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
      auto func = funcList[i];
      partitionMap.setBackendHints(func, partitionConfig.backendHints[i]);
    }
  }

  // Validate memory usage.
  for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
    GraphMemInfo cost = getGraphMemInfo(nodesSets[i]);
    partitionMap.setGraphMemInfo(funcList[i], cost);
  }
  RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));

  // If logical device assignments are provided, use them; otherwise assign
  // them.
  if (partitionConfig.logicalIDs.size()) {
    DCHECK(partitionConfig.numOfPartitions ==
           partitionConfig.logicalIDs.size());
    for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
      auto func = funcList[i];
      for (auto logicalDevice : partitionConfig.logicalIDs[i]) {
        partitionMap.appendLogicalDeviceID(func, logicalDevice);
      }
    }

  } else {
    // Logical device ID validation.
    logicalDeviceID_ = assignLogicalDeviceID(partitionMap, backendMap_);
  }
  // Add the replication count to the config if provided.
  for (auto &replicationAssignment : partitionConfig.replicationCount) {
    auto func = funcList.at(replicationAssignment.first);
    partitionMap.addReplicationCount(func, replicationAssignment.second);
  }

  RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));

  // Do the partitioning.
  partitions = doPartitioning(F->getName(), {F}, module_, partitionMap,
                              /* saveDAG */ true,
                              cctx.backendOpts.backendSpecificNodeInfo);
  module_->eraseFunction(F);

  // DAG validation.
  RETURN_IF_ERR(dagValidation(partitions[0]));

  // Verify the functions.
  for (size_t i = 0; i < partitionConfig.numOfPartitions; i++) {
    auto func = funcList[i];
    DCHECK(func->verify()) << "Conversion led to invalid function";
  }

  RETURN_IF_ERR(finalize(partitions, partitionMap));

  return std::move(partitions);
}

Expected<DAGListTy>
Partitioner::setupPrepartitionedModule(CompilationContext &cctx) {
  const PrePartitionedConfig &config = *cctx.prepartitionedConfig;

  RETURN_ERR_IF_NOT(
      !multiBackendNames_,
      "Do not support multiple backend kinds in prepartitioned flow.");

  // Prepare the mapping between BackendName and BackendInfo.
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);

  const std::vector<Function *> &funcs = config.funcs;

  Backend *B = backends[0];
  auto backendName = B->getBackendName();

  // Optimize all Functions if necessary.
  if (!optimized_) {
    for (Function *F : funcs) {
      RETURN_IF_ERR(::glow::optimizeFunction(
          F, *B, cctx, &getDeviceInfoForBackend(backendName)));
    }
  }

  NodeToFunctionMap partitionMap;
  // Create partitions based on the given number and names.
  for (size_t i = 0, e = funcs.size(); i < e; i++) {
    partitionMap.createPartition(funcs[i], deviceInfo_[0].backendName);
  }

  // Map the nodes to the partitions.
  for (Function *F : funcs) {
    for (auto &node : F->getNodes()) {
      partitionMap.add(&node, F);
    }
  }

  // Validate memory usage.
  for (Function *F : funcs) {
    partitionMap.setGraphMemInfo(F, getFunctionMemory(F));
  }
  RETURN_IF_ERR(memoryUsageValidation(partitionMap, backendMap_));

  // Use the logical device assignments provided in the config; one list of
  // logical IDs is expected per Function.
  DCHECK(funcs.size() == config.logicalIDs.size());
  for (size_t i = 0; i < funcs.size(); i++) {
    Function *F = funcs[i];
    for (auto logicalDevice : config.logicalIDs[i]) {
      partitionMap.appendLogicalDeviceID(F, logicalDevice);
    }
  }
  RETURN_IF_ERR(logicalDevicesValidation(partitionMap, backendMap_));

  // Copy in or validate all members of the PPC.
  RETURN_ERR_IF_NOT(
      funcs.size() == config.backendSpecificOpts.size(),
      "Number of Functions must equal number of backendSpecificOpts");
  RETURN_ERR_IF_NOT(funcs.size() == config.backendHints.size(),
                    "Number of Functions must equal number of backendHints");
  RETURN_ERR_IF_NOT(funcs.size() == config.replicationCounts.size(),
                    "Number of Functions must equal number of "
                    "replicationCounts");
  RETURN_ERR_IF_NOT(
      funcs.size() == config.backendNames.size() || config.backendNames.empty(),
      "If there are backendNames specified, there must be one per Function");
  for (size_t i = 0, e = funcs.size(); i < e; i++) {
    Function *F = funcs[i];
    partitionMap.setBackendSpecificOpts(F, config.backendSpecificOpts[i]);
    partitionMap.setBackendHints(F, config.backendHints[i]);
    partitionMap.addReplicationCount(F, config.replicationCounts[i]);
    if (!config.backendNames.empty()) {
      RETURN_ERR_IF_NOT(backendName == config.backendNames[i],
                        "Mismatch on backendName for partition");
    }
  }

  // Do the partitioning.
  DAGListTy partitions = doPartitioning(
      config.funcName, funcs, module_, partitionMap,
      /* saveDAG */ true, cctx.backendOpts.backendSpecificNodeInfo,
      /* skipCloning */ true);

  // DAG validation.
  RETURN_IF_ERR(dagValidation(partitions[0]));

  // Verify the functions.
  for (Function *F : funcs) {
    DCHECK(F->verify()) << "Conversion led to invalid function";
  }

  RETURN_IF_ERR(finalize(partitions, partitionMap));

  return std::move(partitions);
}

struct SLSTableInfo {
  Node *node;
  uint64_t numBytesInTable;
  unsigned int deviceId;
  NodeValue slsResult;
  uint64_t cost;
};

struct SLSDeviceInfo {
  unsigned int deviceId;
  uint64_t memAvailableInBytes;
  size_t currentCost;
};

/// Helper function for the SparseNN partitioning scheme. Checks whether
/// \p node is the given kind of SLS table and, if so, appends its metadata to
/// the vector.
template <typename SLSType>
Error appendSLSTable(Node &node, std::vector<SLSTableInfo> &slsTables,
                     bool doPerfModelBalance, Backend *backend) {
  auto *SLS0 = llvm::dyn_cast<SLSType>(&node);
  if (SLS0) {
    uint64_t cost = 1;
    uint64_t numBytesInTable =
        (uint64_t)SLS0->getData().getType()->getSizeInBytes();

    // If the average length is available, compute the cost using the perf
    // model.
    if (doPerfModelBalance) {
      double cost_d;
      ASSIGN_VALUE_OR_RETURN_ERR(cost_d, backend->estimateNodeCost(SLS0));
      cost = (uint64_t)cost_d;
    }
    auto slsResult = SLS0->getResult();
    slsTables.push_back({SLS0, numBytesInTable, 0, slsResult, cost});
  }
  return Error::success();
}

// Check if the weights input for \p SLWS is a SplatNode with more than one
// user, and if so clone the splat node into \p F and set it to be the new
// weights of \p SLWS.
template <class T>
static void cloneSplatWeightsIfNecessary(T *SLWS, Function *F) {
  SplatNode *splatWeights = llvm::dyn_cast<SplatNode>(SLWS->getWeights());
  if (!splatWeights || splatWeights->getNumUsers() <= 1) {
    return;
  }
  SplatNode *splatWeightsClone =
      F->addNode(llvm::cast<SplatNode>(splatWeights->clone()));
  SLWS->setNthInput(T::WeightsIdx, splatWeightsClone->getResult());
}

// Insert Split->Concat at the barrier between SLS and non-SLS partitions.
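// The intent (as reflected in the assignments below) is that each device's
// SLS results cross the partition boundary as one concatenated tensor rather
// than as many small ones: the concat is assigned to the SLS partition p,
// while the recovering slices are assigned to the last (non-SLS) partition.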
Error sparseNNInsertSplitConcat(Function *F,
                                std::vector<SLSDeviceInfo> slsDevices,
                                std::vector<std::vector<NodeValue>> frontiers,
                                PartitionConfig &partitionConfig) {

  // Walk through the SLS tables and check that all the results can be
  // concatenated.
  std::vector<std::vector<NodeValue>> concatInputs(slsDevices.size());
  // Insert concat and slice nodes and assign them to partitions.
  for (size_t p = 0; p < slsDevices.size(); p++) {
    auto frontier = frontiers[p];

    if (frontier.size() == 0) {
      continue;
    }
    auto templateResult = frontier[0];
    auto templateDims = templateResult.dims();
    auto templateConcatDim = templateDims.size() - 1;

    for (auto &tableResult : frontier) {
      auto tableDims = tableResult.dims();
      RETURN_ERR_IF_NOT(tableDims.size() == templateDims.size(),
                        strFormat("SLS concat addition encountered tensors "
                                  "with differing dimensions (%zu vs %zu)",
                                  (size_t)tableDims.size(),
                                  (size_t)templateDims.size()));
      for (dim_t otherDim = 0; otherDim < templateConcatDim; otherDim++) {
        RETURN_ERR_IF_NOT(tableDims[otherDim] == templateDims[otherDim],
                          strFormat("SLS concat addition encountered tensors "
                                    "with differing dimension (%zu vs %zu)",
                                    (size_t)tableDims[otherDim],
                                    (size_t)templateDims[otherDim]));
      }
      RETURN_ERR_IF_NOT(
          tableResult.getType()->getElementType() ==
              templateResult.getType()->getElementType(),
          "SLS concat addition encountered tensors with differing ElementType");
      concatInputs[p].push_back(tableResult);
    }

    if (concatInputs[p].size() > 1) {

      // Insert a concat.
      auto *deviceConcat = F->createConcat("concat_dev_" + std::to_string(p),
                                           concatInputs[p], templateConcatDim);
      partitionConfig.nodeToPartition[deviceConcat->getName()] = p;

      // Insert slices.
      std::vector<dim_t> splits(concatInputs[p].size());
      for (dim_t i = 0; i < concatInputs[p].size(); i++) {
        auto inputDim = concatInputs[p][i].dims();
        splits[i] = inputDim[templateConcatDim];
      }
      std::vector<SliceNode *> splitOutputs;
      F->createSplit("split_dev" + std::to_string(p), deviceConcat,
                     splits.size(), templateConcatDim, splits, splitOutputs);
      for (dim_t i = 0; i < concatInputs[p].size(); i++) {
        assert(i < splitOutputs.size());
        concatInputs[p][i].replaceAllUsesOfWith(splitOutputs[i]);
        deviceConcat->setNthInput(i, concatInputs[p][i]);
        partitionConfig.nodeToPartition[splitOutputs[i]->getName()] =
            partitionConfig.numOfPartitions - 1;
      }
    }
  }
  return Error::success();
}

// Do a search starting at an SLS output to capture any Clip or
// LayerNormalization nodes which follow it.
void expandFrontier(const Node *node, const NodeValue &value,
                    std::vector<NodeValue> &frontier,
                    std::vector<const Node *> &traversedNodes, bool includeLN) {
  traversedNodes.push_back(node);
  bool covered = true;
  auto users = node->getUsers();
  for (auto j = users.begin(), f = users.end(); j != f; ++j) {
    const Node *user = (*j).getUser();
    if (const ClipNode *CN = llvm::dyn_cast<ClipNode>(user)) {
      expandFrontier(user, CN->getResult(), frontier, traversedNodes,
                     includeLN);
    } else if ((includeLN) &&
               (user->getKind() ==
                glow::Kinded::Kind::LayerNormalizationNodeKind)) {
      const LayerNormalizationNode *LN =
          llvm::dyn_cast<LayerNormalizationNode>(user);
      expandFrontier(user, LN->getResult(), frontier, traversedNodes,
                     includeLN);
    } else {
      covered = false;
    }
  }
  if (!covered) {
    frontier.push_back(value);
  }
}

Expected<DAGListTy> Partitioner::partitionSparseNN(CompilationContext &cctx) {

  VLOG(1) << "Doing SparseNN partitioning" << std::endl;
  PartitionConfig partitionConfig;
  partitionConfig.numOfPartitions = 0;

  // Find the first function with an SLS node.
  std::string funcName;
  bool foundFunction = false;
  for (Function *F : module_->getFunctions()) {
    for (auto &node : F->getNodes()) {
      if (node.getKind() ==
              glow::Kinded::Kind::
                  FusedRowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
          node.getKind() == glow::Kinded::Kind::
                                FusedRowwiseQuantizedSparseLengthsSumNodeKind ||
          node.getKind() ==
              glow::Kinded::Kind::
                  RowwiseQuantizedSparseLengthsWeightedSumNodeKind ||
          node.getKind() == glow::Kinded::Kind::SparseLengthsSumNodeKind ||
          node.getKind() ==
              glow::Kinded::Kind::SparseLengthsWeightedSumNodeKind) {
        funcName = std::string(F->getName());
        foundFunction = true;
        break;
      }
    }
    if (foundFunction) {
      break;
    }
  }

  // If no matching function was found, return an error.
  if (!foundFunction) {
    return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                    "Did not find a partition with an SLS node");
  }

  if (deviceInfo_.size() <
      cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards) {
    return MAKE_ERR(
        ErrorValue::ErrorCode::PARTITIONER_ERROR,
        strFormat("Not enough devices to partition. Num Devices is %zu and Num "
                  "SparseNN Cards Needed is %u",
                  deviceInfo_.size(),
                  cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards));
  }

  // Otherwise partition this function.
  partitionConfig.funcName = funcName;

  // Look up the function and prepare the backend map.
  Function *F = module_->getFunction(funcName);
  std::vector<Backend *> backends;
  genBackendMap(backendMap_, backendHolder_, backends);
  // Optimize the function first, if it has not been optimized yet.
  if (!optimized_) {
    RETURN_IF_ERR(::glow::optimizeFunction(F, *(backends[0]), cctx));
  }

  // Now we may want to duplicate Splat weights in case they have been CSE'd
  // into a single SplatNode. This is because if two SLWS that share weights
  // are separated into two partitions, partitioning will force a dependence
  // from whichever partition the weights are placed in to the other
  // partition. After partitioning, when we optimize each partition
  // individually, they may be merged again inside the partition.
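  // For example, if two SLWS nodes placed in different partitions both use a
  // single shared SplatNode as weights, that splat lives in one partition and
  // the other partition acquires a cross-partition dependence on it; cloning
  // the splat gives each partition its own copy and removes that edge.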
  for (auto &node : F->getNodes()) {
    if (auto *SLWS = llvm::dyn_cast<SparseLengthsWeightedSumNode>(&node)) {
      cloneSplatWeightsIfNecessary(SLWS, F);
    } else if (auto *SLWS = llvm::dyn_cast<
                   FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(&node)) {
      cloneSplatWeightsIfNecessary(SLWS, F);
    } else if (auto *SLWS =
                   llvm::dyn_cast<RowwiseQuantizedSparseLengthsWeightedSumNode>(
                       &node)) {
      cloneSplatWeightsIfNecessary(SLWS, F);
    }
  }

  // Create the list of SLS tables.
  std::vector<SLSTableInfo> slsTables;
  VLOG(1) << "Function: " << std::string(F->getName()) << std::endl;
  for (auto &node : F->getNodes()) {
    bool doPerfModelBalance =
        cctx.optimizationOpts.sparseNNPartitioningBalancePerfModel;
    RETURN_IF_ERR(
        appendSLSTable<FusedRowwiseQuantizedSparseLengthsWeightedSumNode>(
            node, slsTables, doPerfModelBalance, backends[0]));
    RETURN_IF_ERR(appendSLSTable<FusedRowwiseQuantizedSparseLengthsSumNode>(
        node, slsTables, doPerfModelBalance, backends[0]));
    RETURN_IF_ERR(appendSLSTable<RowwiseQuantizedSparseLengthsWeightedSumNode>(
        node, slsTables, doPerfModelBalance, backends[0]));
    RETURN_IF_ERR(appendSLSTable<SparseLengthsSumNode>(
        node, slsTables, doPerfModelBalance, backends[0]));
    RETURN_IF_ERR(appendSLSTable<SparseLengthsWeightedSumNode>(
        node, slsTables, doPerfModelBalance, backends[0]));
  }
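
  // Each appendSLSTable call above checks whether the node has the given SLS
  // kind and, if so, appends an SLSTableInfo entry recording the node, its
  // embedding-table size in bytes, and an estimated cost (taken from a perf
  // model when doPerfModelBalance is set).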

  // Sort the SLS tables by size, decreasing.
  VLOG(1) << "SLS tables sorted by size decreasing" << std::endl;
  std::sort(slsTables.begin(), slsTables.end(),
            [](const SLSTableInfo &l, const SLSTableInfo &r) {
              return l.numBytesInTable > r.numBytesInTable;
            });

  // Print the SLS tables.
  for (auto &table : slsTables) {
    VLOG(1) << "(numBytesInTable, deviceId, cost)"
            << "\t" << table.numBytesInTable << "\t" << table.deviceId << "\t"
            << table.cost << std::endl;
  }

  // Create the table of devices; each card gets a memory budget of
  // sparseNNPartitioningSchemeSLSTableKBytesPerCard kilobytes.
  std::vector<SLSDeviceInfo> slsDevices;
  for (unsigned int device = 0;
       device < cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards;
       device++) {
    slsDevices.push_back(
        {device,
         (uint64_t)cctx.optimizationOpts
                 .sparseNNPartitioningSchemeSLSTableKBytesPerCard *
             1024,
         0});
  }
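
  // For instance (illustrative numbers only): with
  // sparseNNPartitioningSchemeSLSTableKBytesPerCard set to 4194304, each card
  // gets an SLS memory budget of 4194304 KB * 1024 = 4 GiB.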
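  // The assignment below is a first-fit-decreasing style greedy heuristic:
  // tables are visited from largest to smallest; for each table the devices
  // are re-sorted by accumulated cost, and the least-loaded device with
  // enough free memory receives the table.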
  // Now assign SLS nodes to devices.
  for (auto &table : slsTables) {

    // Sort the devices by current cost, increasing.
    std::sort(slsDevices.begin(), slsDevices.end(),
              [](const SLSDeviceInfo &l, const SLSDeviceInfo &r) {
                return l.currentCost < r.currentCost;
              });

    VLOG(1) << "Devices sorted by cost increasing" << std::endl;
    for (auto &device : slsDevices) {
      VLOG(1) << "(deviceId, memAvailableInBytes, currentCost): "
              << device.deviceId << "\t" << device.memAvailableInBytes << "\t"
              << device.currentCost << std::endl;
    }

    // Pick the first device that fits.
    bool deviceFound = false;
    for (unsigned int d = 0; d < slsDevices.size(); d++) {
      if (slsDevices[d].memAvailableInBytes >= table.numBytesInTable) {
        deviceFound = true;
        slsDevices[d].memAvailableInBytes -= table.numBytesInTable;
        slsDevices[d].currentCost += (size_t)table.cost;
        table.deviceId = slsDevices[d].deviceId;
        break;
      }
    }
    if (!deviceFound) {
      return MAKE_ERR(ErrorValue::ErrorCode::PARTITIONER_ERROR,
                      "SLS Balancing Partitioning Error: Not enough memory");
    }
  }

  // Print the final device info.
  VLOG(1) << "Final device info:" << std::endl;
  for (auto &device : slsDevices) {
    VLOG(1) << "(deviceId, memAvailableInBytes, currentCost): "
            << device.deviceId << "\t" << device.memAvailableInBytes << "\t"
            << device.currentCost << std::endl;
  }

  // Print the assignments.
  for (auto &table : slsTables) {
    VLOG(1) << "(numBytesInTable, deviceId, cost)"
            << "\t" << table.numBytesInTable << "\t" << table.deviceId << "\t"
            << table.cost << std::endl;
  }

  // Create the manual partition config: one partition per SLS device, plus a
  // final partition for everything else.
  partitionConfig.numOfPartitions = slsDevices.size() + 1;
  std::vector<unsigned int> allLogicalIDs;

  // Add the SLS partitions.
  for (size_t p = 0; p < slsDevices.size(); p++) {
    partitionConfig.partitionNames.push_back(std::string("SLSPartition_") +
                                             std::to_string(p));
    partitionConfig.backendNames.push_back(deviceInfo_[p].backendName);
    partitionConfig.logicalIDs.push_back({(unsigned int)p});
    BackendHints backendHints;
    backendHints.executionUnits =
        cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresSLS;
    partitionConfig.backendHints.push_back(backendHints);
    allLogicalIDs.push_back(p);
  }

  // Add the last (non-SLS) partition, which spans all logical devices.
  partitionConfig.partitionNames.push_back(std::string("NonSLSPartition_"));
  partitionConfig.backendNames.push_back(deviceInfo_[0].backendName);
  partitionConfig.logicalIDs.push_back(allLogicalIDs);
  BackendHints backendHints;
  backendHints.executionUnits =
      cctx.optimizationOpts.sparseNNPartitioningSchemeNumCoresOther;
  partitionConfig.backendHints.push_back(backendHints);

  // Map SLS nodes to their partitions.
  for (auto &table : slsTables) {
    partitionConfig.nodeToPartition[table.node->getName()] = table.deviceId;
  }
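
  // table.deviceId can be used directly as the partition index because the
  // SLS partitions were created above in device order (SLSPartition_p has
  // logical ID p).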

  // For each partition, walk backwards from its assigned SLS tables and pull
  // all of their predecessors into the same partition (a breadth-first
  // traversal over node inputs).
  for (size_t p = 0; p < slsDevices.size(); p++) {
    std::queue<Node *> preds;
    for (auto &table : slsTables) {
      if (table.deviceId == p) {
        preds.push(table.node);
      }
    }
    while (!preds.empty()) {
      auto cur = preds.front();
      preds.pop();
      for (auto &node : getInputs(cur)) {
        if (partitionConfig.nodeToPartition.find(node->getName()) ==
            partitionConfig.nodeToPartition.end()) {
          partitionConfig.nodeToPartition[node->getName()] = p;
          preds.push(node);
        }
      }
    }
  }

  // Also walk forward from each SLS table and assign any Clip or LayerNorm
  // nodes that follow it to the same partition.
  std::vector<std::vector<NodeValue>> frontierValues(slsDevices.size());
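  // frontierValues collects, per device, the NodeValues at the boundary
  // between each SLS partition and the rest of the graph; it is consumed
  // below by sparseNNInsertSplitConcat when sparseNNPartitioningAddSLSConcats
  // is set.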
  for (dim_t deviceId = 0; deviceId < slsDevices.size(); deviceId++) {
    std::vector<const Node *> traversedNodes;
    for (auto &table : slsTables) {
      if (table.deviceId == deviceId) {
        bool includeLN =
            cctx.optimizationOpts.sparseNNPartitioningPairLNWithSLS;
        expandFrontier(table.node, table.slsResult, frontierValues[deviceId],
                       traversedNodes, includeLN);
      }
    }
    // Assign the nodes encountered to this partition.
    for (auto &node : traversedNodes) {
      partitionConfig.nodeToPartition[node->getName()] = deviceId;
    }
  }

  // All other nodes go in the last partition.
  for (auto &node : F->getNodes()) {
    if (partitionConfig.nodeToPartition.find(node.getName()) ==
        partitionConfig.nodeToPartition.end()) {
      partitionConfig.nodeToPartition[node.getName()] = slsDevices.size();
    }
  }

  // Insert Split->Concat at the barrier between the SLS and non-SLS
  // partitions.
  if (cctx.optimizationOpts.sparseNNPartitioningAddSLSConcats) {
    RETURN_IF_ERR(sparseNNInsertSplitConcat(F, slsDevices, frontierValues,
                                            partitionConfig));
  }

  VLOG(1) << " Finished SparseNN partitioning" << std::endl;
  VLOG(1) << " PartitionConfig ::: funcName = " << partitionConfig.funcName
          << "\n";
  VLOG(1) << " PartitionConfig ::: numOfPartitions = "
          << partitionConfig.numOfPartitions << "\n";
  VLOG(1) << " PartitionConfig ::: partitionNames = ";
  for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) {
    VLOG(1) << partitionConfig.partitionNames[i] << " ";
  }
  VLOG(1) << "\n";
  VLOG(1) << " PartitionConfig ::: logicalIDs = ";
  for (unsigned i = 0; i < partitionConfig.numOfPartitions; i++) {
    for (auto &id : partitionConfig.logicalIDs[i]) {
      VLOG(1) << id << " ";
    }
    VLOG(1) << "\n";
  }

  DAGListTy partitions;
  ASSIGN_VALUE_OR_RETURN_ERR(partitions,
                             partitionFromConfig(partitionConfig, cctx));
  if (cctx.saturateHost) {
    saturateHost(cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards,
                 partitions);
  }
  return std::move(partitions);
}

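/// Dispatch to the appropriate partitioning flow. The precedence, reflected
/// in the order of the checks below, is: prepartitioned modules, a
/// user-supplied partition config, the SparseNN scheme (single backend kind
/// only), quantization profiling, load-balanced partitioning (single backend
/// kind only), and finally the generic heterogeneous flow.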
Expected<DAGListTy> Partitioner::partition(CompilationContext &cctx) {
  if (cctx.prepartitionedConfig &&
      !cctx.prepartitionedConfig->funcs.empty()) {
    return setupPrepartitionedModule(cctx);
  }

  if (cctx.partitionConfig) {
    partitionConfig_ = *cctx.partitionConfig;
  }

  if (partitionConfig_.enabled()) {
    // Call user-defined partition flow.
    return partitionFromConfig(partitionConfig_, cctx);
  }

  if (!multiBackendNames_ &&
      cctx.optimizationOpts.useSparseNNPartitioningScheme) {
    return partitionSparseNN(cctx);
  }

  if (cctx.precisionConfig.quantMode == QuantizationMode::Profile) {
    // Call quantization profiling partition flow.
    return quantizationProfilingPartition(cctx);
  }

  if (!multiBackendNames_ && glow::GlowEnableLoadBalancedPartitioning) {
    // Call load-balance partition flow.
    return loadBalancedPartition(cctx);
  }

  // Call heterogeneous partition flow.
  return heterogeneousPartition(cctx);
}
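
// Example (illustrative sketch only; "partitioner" stands for an
// already-constructed Partitioner): selecting the SparseNN scheme from client
// code. The option fields match the ones read in partitionSparseNN() above.
//
//   CompilationContext cctx;
//   cctx.optimizationOpts.useSparseNNPartitioningScheme = true;
//   cctx.optimizationOpts.sparseNNPartitioningSchemeNumCards = 2;
//   cctx.optimizationOpts.sparseNNPartitioningSchemeSLSTableKBytesPerCard =
//       4194304;
//   DAGListTy partitions = EXIT_ON_ERR(partitioner.partition(cctx));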