/**
 * Copyright (c) Glow Contributors. See CONTRIBUTORS file.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
16
#include "glow/Backends/DeviceManager.h"
#include "glow/Base/Image.h"
#include "glow/ExecutionEngine/ExecutionEngine.h"
#include "glow/Graph/Graph.h"
#include "glow/Importer/Caffe2ModelLoader.h"
#include "glow/Optimizer/GraphOptimizer/GraphOptimizer.h"
#include "glow/Runtime/RuntimeTypes.h"

#include "llvm/Support/CommandLine.h"
#include "llvm/Support/FileSystem.h"

#include <array>
#include <chrono>
#include <future>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
30
31 using namespace glow;
32 using namespace glow::runtime;
33
// The set of backends whose traces this tool compares. "OpenCL" is only
// included when Glow was built with OpenCL support.
#if (GLOW_WITH_OPENCL)
std::array<std::string, 3> supportedBackends{{"CPU", "Interpreter", "OpenCL"}};
#else
std::array<std::string, 2> supportedBackends{{"CPU", "Interpreter"}};
#endif
39
namespace {
// Command-line option grouping for this tool.
llvm::cl::OptionCategory category("tracing-compare Options");

// Positional argument: path of the png image to classify.
llvm::cl::opt<std::string>
    inputImage(llvm::cl::desc("path to input image to classify, which must be "
                              "a png with standard imagenet normalization"),
               llvm::cl::init("../tests/images/imagenet/dog_207.png"),
               llvm::cl::Positional, llvm::cl::cat(category));

// Output path for the merged json trace file written at the end of main().
llvm::cl::opt<std::string> tracePath("trace-path",
                                     llvm::cl::desc("Write trace logs to disk"),
                                     llvm::cl::init("./glow-trace.json"),
                                     llvm::cl::cat(category));

} // namespace
54
55 /// Loads the model into /p module and returns the input and output
56 /// Placeholders.
loadResnet50Model(TypeRef inputType,Module & module)57 std::pair<Placeholder *, Placeholder *> loadResnet50Model(TypeRef inputType,
58 Module &module) {
59 Function *F = module.createFunction("resnet50");
60
61 llvm::outs() << "Loading resnet50 model.\n";
62
63 const char inputName[] = "gpu_0/data";
64 Caffe2ModelLoader loader("resnet50/predict_net.pb", "resnet50/init_net.pb",
65 {inputName}, {inputType}, *F);
66 Placeholder *input = llvm::cast<Placeholder>(
67 EXIT_ON_ERR(loader.getNodeValueByName(inputName)));
68 Placeholder *output = EXIT_ON_ERR(loader.getSingleOutput());
69
70 return std::make_pair(input, output);
71 }
72
73 /// Compiles the resnet50 function.
compileModel(Module & module,llvm::StringRef backendName)74 std::unique_ptr<CompiledFunction> compileModel(Module &module,
75 llvm::StringRef backendName) {
76 auto *backend = createBackend(backendName);
77 Function *F = module.getFunction("resnet50");
78 Function *F_ = F->clone("resnet50." + backendName.str());
79
80 llvm::outs() << "Starting compile on " << backendName << " device.\n";
81 CompilationContext cctx;
82 cctx.compMode = CompilationMode::Infer;
83 cctx.backendOpts.autoInstrument = true;
84 EXIT_ON_ERR(::glow::optimizeFunction(F_, *backend, cctx));
85 return EXIT_ON_ERR(backend->compile(F_, cctx.backendOpts));
86 }
87
addToDevice(unsigned int id,DeviceManager * device,Module & module,FunctionMapTy functions)88 std::future<void> addToDevice(unsigned int id, DeviceManager *device,
89 Module &module, FunctionMapTy functions) {
90 auto compilePromise = std::make_shared<std::promise<void>>();
91 auto future = compilePromise->get_future();
92
93 device->addNetwork(
94 &module, functions, [compilePromise, id](const Module *, Error err) {
95 if (err) {
96 llvm::errs() << "Failed to compile model for device " << id << ".\n";
97 EXIT_ON_ERR(std::move(err));
98 } else {
99 llvm::outs() << "Successfully added to Device " << id << ".\n";
100 }
101 compilePromise->set_value();
102 });
103
104 return future;
105 }
106
main(int argc,char ** argv)107 int main(int argc, char **argv) {
108 llvm::cl::ParseCommandLineOptions(
109 argc, argv, "Run resnet and export a json file containing trace events");
110
111 std::vector<DeviceManager *> devices(supportedBackends.size());
112 for (unsigned i = 0, e = supportedBackends.size(); i < e; ++i) {
113 devices[i] =
114 DeviceManager::createDeviceManager(DeviceConfig(supportedBackends[i]));
115 EXIT_ON_ERR(devices[i]->init());
116 }
117
118 // Load and compile model.
119
120 Module module;
121 TypeRef inputType(module.uniqueType(ElemKind::FloatTy, {1, 3, 224, 224}));
122 Placeholder *input, *output;
123
124 std::tie(input, output) = loadResnet50Model(inputType, module);
125
126 std::vector<std::unique_ptr<CompiledFunction>> compiledFunctions(
127 supportedBackends.size());
128 for (unsigned i = 0, e = supportedBackends.size(); i < e; ++i) {
129 compiledFunctions[i] = compileModel(module, supportedBackends[i]);
130
131 FunctionMapTy functions;
132 functions.emplace("resnet50", compiledFunctions[i].get());
133
134 auto f = addToDevice(i, devices[i], module, functions);
135 f.wait_for(/* timeout_duration */ std::chrono::seconds(30));
136 }
137
138 auto image = readPngImageAndPreprocess(
139 inputImage, ImageNormalizationMode::k0to1, ImageChannelOrder::BGR,
140 ImageLayout::NCHW, imagenetNormMean, imagenetNormStd);
141
142 Tensor batch = image.getUnowned(inputType->dims());
143
144 llvm::outs() << "Starting Run.\n";
145 std::vector<std::promise<std::unique_ptr<ExecutionContext>>> promises(
146 supportedBackends.size());
147
148 for (unsigned i = 0, e = supportedBackends.size(); i < e; ++i) {
149 auto context = glow::make_unique<ExecutionContext>();
150 context->setTraceContext(
151 glow::make_unique<TraceContext>(TraceLevel::STANDARD));
152 context->getPlaceholderBindings()->allocate(module.getPlaceholders());
153 updateInputPlaceholders(*(context->getPlaceholderBindings()), {input},
154 {&batch});
155
156 devices[i]->runFunction(
157 "resnet50", std::move(context),
158 [&promises, i](RunIdentifierTy, Error err,
159 std::unique_ptr<ExecutionContext> context) {
160 EXIT_ON_ERR(std::move(err));
161 promises[i].set_value(std::move(context));
162 });
163 }
164
165 TraceContext allEvents(TraceLevel::STANDARD);
166 size_t index = 0;
167 for (auto backend : supportedBackends) {
168 allEvents.setThreadName(index++, backend);
169 }
170
171 for (unsigned i = 0, e = supportedBackends.size(); i < e; ++i) {
172 auto f = promises[i].get_future();
173 f.wait_for(/* timeout_duration */ std::chrono::seconds(30));
174 auto runbindings = f.get();
175 allEvents.merge(runbindings->getTraceContext());
176 }
177
178 llvm::outs() << "Dumping json to " << tracePath << ".\n";
179 allEvents.dump(tracePath, "tracing-compare");
180
181 return 0;
182 }
183