1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *   http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing,
13  * software distributed under the License is distributed on an
14  * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15  * KIND, either express or implied.  See the License for the
16  * specific language governing permissions and limitations
17  * under the License.
18  */
19 
20 /*!
21  * \file Converter.java
22  * \brief Convert Caffe prototxt to MXNet Python code
23  */
24 
25 package io.mxnet.caffetranslator;
26 
import io.mxnet.caffetranslator.generators.*;
import lombok.Setter;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.stringtemplate.v4.ST;
import org.stringtemplate.v4.STGroup;
import org.stringtemplate.v4.STRawGroupDir;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Set;
44 
45 public class Converter {
46 
47     private final String trainPrototxt, solverPrototxt;
48     private final MLModel mlModel;
49     private final STGroup stGroup;
50     private final SymbolGeneratorFactory generators;
51     private final String NL;
52     private final GenerationHelper gh;
53     @Setter
54 
55     private String paramsFilePath;
56     private Solver solver;
57 
Converter(String trainPrototxt, String solverPrototxt)58     Converter(String trainPrototxt, String solverPrototxt) {
59         this.trainPrototxt = trainPrototxt;
60         this.solverPrototxt = solverPrototxt;
61         this.mlModel = new MLModel();
62         this.stGroup = new STRawGroupDir("templates");
63         this.generators = SymbolGeneratorFactory.getInstance();
64         NL = System.getProperty("line.separator");
65         gh = new GenerationHelper();
66         addGenerators();
67     }
68 
addGenerators()69     private void addGenerators() {
70         generators.addGenerator("Convolution", new ConvolutionGenerator());
71         generators.addGenerator("Deconvolution", new DeconvolutionGenerator());
72         generators.addGenerator("Pooling", new PoolingGenerator());
73         generators.addGenerator("InnerProduct", new FCGenerator());
74         generators.addGenerator("ReLU", new ReluGenerator());
75         generators.addGenerator("SoftmaxWithLoss", new SoftmaxOutputGenerator());
76         generators.addGenerator("PluginIntLayerGenerator", new PluginIntLayerGenerator());
77         generators.addGenerator("CaffePluginLossLayer", new PluginLossGenerator());
78         generators.addGenerator("Permute", new PermuteGenerator());
79         generators.addGenerator("Concat", new ConcatGenerator());
80         generators.addGenerator("BatchNorm", new BatchNormGenerator());
81         generators.addGenerator("Power", new PowerGenerator());
82         generators.addGenerator("Eltwise", new EltwiseGenerator());
83         generators.addGenerator("Flatten", new FlattenGenerator());
84         generators.addGenerator("Dropout", new DropoutGenerator());
85         generators.addGenerator("Scale", new ScaleGenerator());
86     }
87 
parseTrainingPrototxt()88     public boolean parseTrainingPrototxt() {
89 
90         CharStream cs = null;
91         try {
92             FileInputStream fis = new FileInputStream(new File(trainPrototxt));
93             cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
94         } catch (IOException e) {
95             System.err.println("Unable to read prototxt: " + trainPrototxt);
96             return false;
97         }
98 
99         CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);
100 
101         CommonTokenStream tokens = new CommonTokenStream(lexer);
102         CaffePrototxtParser parser = new CaffePrototxtParser(tokens);
103 
104         CreateModelListener modelCreator = new CreateModelListener(parser, mlModel);
105         parser.addParseListener(modelCreator);
106         parser.prototxt();
107 
108         return true;
109     }
110 
parseSolverPrototxt()111     public boolean parseSolverPrototxt() {
112         solver = new Solver(solverPrototxt);
113         return solver.parsePrototxt();
114     }
115 
generateMXNetCode()116     public String generateMXNetCode() {
117         if (!parseTrainingPrototxt()) {
118             return "";
119         }
120 
121         if (!parseSolverPrototxt()) {
122             return "";
123         }
124 
125         StringBuilder code = new StringBuilder();
126 
127         code.append(generateImports());
128         code.append(System.lineSeparator());
129 
130         code.append(generateLogger());
131         code.append(System.lineSeparator());
132 
133         code.append(generateParamInitializer());
134         code.append(System.lineSeparator());
135 
136         code.append(generateMetricsClasses());
137         code.append(System.lineSeparator());
138 
139         if (paramsFilePath != null) {
140             code.append(generateParamsLoader());
141             code.append(System.lineSeparator());
142         }
143 
144         // Convert data layers
145         code.append(generateIterators());
146 
147         // Generate variables for data and label
148         code.append(generateInputVars());
149 
150         // Convert non data layers
151         List<Layer> layers = mlModel.getNonDataLayers();
152 
153         for (int layerIndex = 0; layerIndex < layers.size(); ) {
154             Layer layer = layers.get(layerIndex);
155             SymbolGenerator generator = generators.getGenerator(layer.getType());
156 
157             // Handle layers for which there is no Generator
158             if (generator == null) {
159                 if (layer.getType().equalsIgnoreCase("Accuracy")) {
160                     // We handle accuracy layers at a later stage. Do nothing for now.
161                 } else if (layer.getType().toLowerCase().endsWith("loss")) {
162                     // This is a loss layer we don't have a generator for. Wrap it in CaffeLoss.
163                     generator = generators.getGenerator("CaffePluginLossLayer");
164                 } else {
165                     // This is a layer we don't have a generator for. Wrap it in CaffeOp.
166                     generator = generators.getGenerator("PluginIntLayerGenerator");
167                 }
168             }
169 
170             if (generator != null) { // If we have a generator
171                 // Generate code
172                 GeneratorOutput out = generator.generate(layer, mlModel);
173                 String segment = out.code;
174                 code.append(segment);
175                 code.append(NL);
176 
177                 // Update layerIndex depending on how many layers we ended up translating
178                 layerIndex += out.numLayersTranslated;
179             } else { // If we don't have a generator
180                 // We've decided to skip this layer. Generate no code. Just increment layerIndex
181                 // by 1 and move on to the next layer.
182                 layerIndex++;
183             }
184         }
185 
186         String loss = getLoss(mlModel, code);
187 
188         String evalMetric = generateValidationMetrics(mlModel);
189         code.append(evalMetric);
190 
191         String runner = generateRunner(loss);
192         code.append(runner);
193 
194         return code.toString();
195     }
196 
generateLogger()197     private String generateLogger() {
198         ST st = gh.getTemplate("logging");
199         st.add("name", mlModel.getName());
200         return st.render();
201     }
202 
generateRunner(String loss)203     private String generateRunner(String loss) {
204         ST st = gh.getTemplate("runner");
205         st.add("max_iter", solver.getProperty("max_iter"));
206         st.add("stepsize", solver.getProperty("stepsize"));
207         st.add("snapshot", solver.getProperty("snapshot"));
208         st.add("test_interval", solver.getProperty("test_interval"));
209         st.add("test_iter", solver.getProperty("test_iter"));
210         st.add("snapshot_prefix", solver.getProperty("snapshot_prefix"));
211 
212         st.add("train_data_itr", getIteratorName("TRAIN"));
213         st.add("test_data_itr", getIteratorName("TEST"));
214 
215         String context = solver.getProperty("solver_mode", "cpu").toLowerCase();
216         context = String.format("mx.%s()", context);
217         st.add("ctx", context);
218 
219         st.add("loss", loss);
220 
221         st.add("data_names", getDataNames());
222         st.add("label_names", getLabelNames());
223 
224         st.add("init_params", generateInitializer());
225 
226         st.add("init_optimizer", generateOptimizer());
227         st.add("gamma", solver.getProperty("gamma"));
228         st.add("power", solver.getProperty("power"));
229         st.add("lr_update", generateLRUpdate());
230 
231         return st.render();
232     }
233 
generateParamInitializer()234     private String generateParamInitializer() {
235         return gh.getTemplate("param_initializer").render();
236     }
237 
generateMetricsClasses()238     private String generateMetricsClasses() {
239         ST st = gh.getTemplate("metrics_classes");
240 
241         String display = solver.getProperty("display");
242         String average_loss = solver.getProperty("average_loss");
243 
244         if (display != null) {
245             st.add("display", display);
246         }
247 
248         if (average_loss != null) {
249             st.add("average_loss", average_loss);
250         }
251 
252         return st.render();
253     }
254 
generateParamsLoader()255     private String generateParamsLoader() {
256         return gh.getTemplate("params_loader").render();
257     }
258 
getLoss(MLModel model, StringBuilder out)259     private String getLoss(MLModel model, StringBuilder out) {
260         List<String> losses = new ArrayList<>();
261         for (Layer layer : model.getLayerList()) {
262             if (layer.getType().toLowerCase().endsWith("loss")) {
263                 losses.add(gh.getVarname(layer.getTop()));
264             }
265         }
266 
267         if (losses.size() == 1) {
268             return losses.get(0);
269         } else if (losses.size() > 1) {
270             String loss_var = "combined_loss";
271             ST st = gh.getTemplate("group");
272             st.add("var", loss_var);
273             st.add("symbols", losses);
274             out.append(st.render());
275             return loss_var;
276         } else {
277             System.err.println("No loss found");
278             return "unknown_loss";
279         }
280     }
281 
generateLRUpdate()282     private String generateLRUpdate() {
283         String code;
284         String lrPolicy = solver.getProperty("lr_policy", "fixed").toLowerCase();
285         ST st;
286         switch (lrPolicy) {
287             case "fixed":
288                 // lr stays fixed. No update needed
289                 code = "";
290                 break;
291             case "multistep":
292                 st = gh.getTemplate("lrpolicy_multistep");
293                 st.add("steps", solver.getProperties("stepvalue"));
294                 code = st.render();
295                 break;
296             case "step":
297             case "exp":
298             case "inv":
299             case "poly":
300             case "sigmoid":
301                 st = gh.getTemplate("lrpolicy_" + lrPolicy);
302                 code = st.render();
303                 break;
304             default:
305                 String message = "Unknown lr_policy: " + lrPolicy;
306                 System.err.println(message);
307                 code = "# " + message + System.lineSeparator();
308                 break;
309         }
310         return Utils.indent(code, 2, true, 4);
311     }
312 
generateValidationMetrics(MLModel mlModel)313     private String generateValidationMetrics(MLModel mlModel) {
314         return new AccuracyMetricsGenerator().generate(mlModel);
315     }
316 
generateOptimizer()317     private String generateOptimizer() {
318         Optimizer optimizer = new Optimizer(solver);
319         return optimizer.generateInitCode();
320     }
321 
generateInitializer()322     private String generateInitializer() {
323         ST st = gh.getTemplate("init_params");
324         st.add("params_file", paramsFilePath);
325         return st.render();
326     }
327 
generateImports()328     private String generateImports() {
329         return gh.getTemplate("imports").render();
330     }
331 
generateIterators()332     private StringBuilder generateIterators() {
333         StringBuilder code = new StringBuilder();
334 
335         for (Layer layer : mlModel.getDataLayers()) {
336             String iterator = generateIterator(layer);
337             code.append(iterator);
338         }
339 
340         return code;
341     }
342 
getIteratorName(String phase)343     private String getIteratorName(String phase) {
344         for (Layer layer : mlModel.getDataLayers()) {
345             String layerPhase = layer.getAttr("include.phase", phase);
346             if (phase.equalsIgnoreCase(layerPhase)) {
347                 return layerPhase.toLowerCase() + "_" + layer.getName() + "_" + "itr";
348             }
349         }
350         return null;
351     }
352 
getDataNames()353     private List<String> getDataNames() {
354         return getDataNames(0);
355     }
356 
getLabelNames()357     private List<String> getLabelNames() {
358         return getDataNames(1);
359     }
360 
getDataNames(int topIndex)361     private List<String> getDataNames(int topIndex) {
362         List<String> dataList = new ArrayList<String>();
363         for (Layer layer : mlModel.getDataLayers()) {
364             if (layer.getAttr("include.phase").equalsIgnoreCase("train")) {
365                 String dataName = layer.getTops().get(topIndex);
366                 if (dataName != null) {
367                     dataList.add(String.format("'%s'", dataName));
368                 }
369             }
370         }
371         return dataList;
372     }
373 
generateInputVars()374     private StringBuilder generateInputVars() {
375         StringBuilder code = new StringBuilder();
376 
377         Set<String> tops = new HashSet<String>();
378 
379         for (Layer layer : mlModel.getDataLayers())
380             for (String top : layer.getTops())
381                 tops.add(top);
382 
383         for (String top : tops)
384             code.append(gh.generateVar(gh.getVarname(top), top, null, null, null, null));
385 
386         code.append(System.lineSeparator());
387         return code;
388     }
389 
generateIterator(Layer layer)390     private String generateIterator(Layer layer) {
391         String iteratorName = layer.getAttr("include.phase");
392         iteratorName = iteratorName.toLowerCase();
393         iteratorName = iteratorName + "_" + layer.getName() + "_" + "itr";
394 
395         ST st = stGroup.getInstanceOf("iterator");
396 
397         String prototxt = layer.getPrototxt();
398         prototxt = prototxt.replace("\r", "");
399         prototxt = prototxt.replace("\n", " \\\n");
400         prototxt = "'" + prototxt + "'";
401         prototxt = Utils.indent(prototxt, 1, true, 4);
402 
403         st.add("iter_name", iteratorName);
404         st.add("prototxt", prototxt);
405 
406         String dataName = "???";
407         if (layer.getTops().size() >= 1) {
408             dataName = layer.getTops().get(0);
409         } else {
410             System.err.println(String.format("Data layer %s doesn't have data", layer.getName()));
411         }
412         st.add("data_name", dataName);
413 
414         String labelName = "???";
415         if (layer.getTops().size() >= 1) {
416             labelName = layer.getTops().get(1);
417         } else {
418             System.err.println(String.format("Data layer %s doesn't have label", layer.getName()));
419         }
420         st.add("label_name", labelName);
421 
422         if (layer.hasAttr("data_param.num_examples")) {
423             st.add("num_examples", layer.getAttr("data_param.num_examples"));
424         }
425 
426         return st.render();
427     }
428 
429 }
430