/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file Converter.java
 * \brief Convert Caffe prototxt to MXNet Python code
 */

package io.mxnet.caffetranslator;

import io.mxnet.caffetranslator.generators.*;
import lombok.Setter;
import org.antlr.v4.runtime.CharStream;
import org.antlr.v4.runtime.CharStreams;
import org.antlr.v4.runtime.CommonTokenStream;
import org.stringtemplate.v4.ST;
import org.stringtemplate.v4.STGroup;
import org.stringtemplate.v4.STRawGroupDir;

import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * Translates a Caffe model into MXNet Python code.
 *
 * <p>The converter parses a training/validation prototxt and a solver prototxt
 * (via the ANTLR-generated {@code CaffePrototxtLexer}/{@code CaffePrototxtParser}),
 * builds an {@link MLModel}, and then emits Python code one layer at a time using
 * per-layer-type {@link SymbolGenerator}s registered with
 * {@link SymbolGeneratorFactory}. Code fragments are rendered from StringTemplate
 * templates found in the {@code templates} directory.
 */
public class Converter {

    private final String trainPrototxt, solverPrototxt;
    private final MLModel mlModel;
    private final STGroup stGroup;
    private final SymbolGeneratorFactory generators;
    // Platform line separator, used between generated code fragments.
    private final String NL;
    private final GenerationHelper gh;

    // Optional path to a pre-trained .params file; when set, parameter-loading
    // code is emitted (see generateMXNetCode / generateParamsLoader).
    @Setter
    private String paramsFilePath;

    // Populated by parseSolverPrototxt(); holds solver hyperparameters.
    private Solver solver;

    /**
     * Creates a converter for the given prototxt pair. No parsing happens here;
     * call {@link #generateMXNetCode()} (or the individual parse methods) later.
     *
     * @param trainPrototxt  path to the training/validation network prototxt
     * @param solverPrototxt path to the solver prototxt
     */
    Converter(String trainPrototxt, String solverPrototxt) {
        this.trainPrototxt = trainPrototxt;
        this.solverPrototxt = solverPrototxt;
        this.mlModel = new MLModel();
        this.stGroup = new STRawGroupDir("templates");
        this.generators = SymbolGeneratorFactory.getInstance();
        NL = System.getProperty("line.separator");
        gh = new GenerationHelper();
        addGenerators();
    }

    /** Registers one code generator per supported Caffe layer type. */
    private void addGenerators() {
        generators.addGenerator("Convolution", new ConvolutionGenerator());
        generators.addGenerator("Deconvolution", new DeconvolutionGenerator());
        generators.addGenerator("Pooling", new PoolingGenerator());
        generators.addGenerator("InnerProduct", new FCGenerator());
        generators.addGenerator("ReLU", new ReluGenerator());
        generators.addGenerator("SoftmaxWithLoss", new SoftmaxOutputGenerator());
        generators.addGenerator("PluginIntLayerGenerator", new PluginIntLayerGenerator());
        generators.addGenerator("CaffePluginLossLayer", new PluginLossGenerator());
        generators.addGenerator("Permute", new PermuteGenerator());
        generators.addGenerator("Concat", new ConcatGenerator());
        generators.addGenerator("BatchNorm", new BatchNormGenerator());
        generators.addGenerator("Power", new PowerGenerator());
        generators.addGenerator("Eltwise", new EltwiseGenerator());
        generators.addGenerator("Flatten", new FlattenGenerator());
        generators.addGenerator("Dropout", new DropoutGenerator());
        generators.addGenerator("Scale", new ScaleGenerator());
    }

    /**
     * Parses the training prototxt and populates {@link #mlModel}.
     *
     * @return {@code true} on success, {@code false} if the file could not be read
     */
    public boolean parseTrainingPrototxt() {

        CharStream cs;
        // try-with-resources: the original leaked the FileInputStream;
        // CharStreams.fromStream reads the stream fully, so it can be closed here.
        try (FileInputStream fis = new FileInputStream(new File(trainPrototxt))) {
            cs = CharStreams.fromStream(fis, StandardCharsets.UTF_8);
        } catch (IOException e) {
            System.err.println("Unable to read prototxt: " + trainPrototxt);
            return false;
        }

        CaffePrototxtLexer lexer = new CaffePrototxtLexer(cs);

        CommonTokenStream tokens = new CommonTokenStream(lexer);
        CaffePrototxtParser parser = new CaffePrototxtParser(tokens);

        // The listener builds mlModel as the parser walks the prototxt.
        CreateModelListener modelCreator = new CreateModelListener(parser, mlModel);
        parser.addParseListener(modelCreator);
        parser.prototxt();

        return true;
    }

    /**
     * Parses the solver prototxt into {@link #solver}.
     *
     * @return {@code true} on success
     */
    public boolean parseSolverPrototxt() {
        solver = new Solver(solverPrototxt);
        return solver.parsePrototxt();
    }

    /**
     * Parses both prototxts and generates the full MXNet Python program.
     *
     * @return the generated Python source, or an empty string if parsing failed
     */
    public String generateMXNetCode() {
        if (!parseTrainingPrototxt()) {
            return "";
        }

        if (!parseSolverPrototxt()) {
            return "";
        }

        StringBuilder code = new StringBuilder();

        code.append(generateImports());
        code.append(System.lineSeparator());

        code.append(generateLogger());
        code.append(System.lineSeparator());

        code.append(generateParamInitializer());
        code.append(System.lineSeparator());

        code.append(generateMetricsClasses());
        code.append(System.lineSeparator());

        if (paramsFilePath != null) {
            code.append(generateParamsLoader());
            code.append(System.lineSeparator());
        }

        // Convert data layers
        code.append(generateIterators());

        // Generate variables for data and label
        code.append(generateInputVars());

        // Convert non data layers
        List<Layer> layers = mlModel.getNonDataLayers();

        for (int layerIndex = 0; layerIndex < layers.size(); ) {
            Layer layer = layers.get(layerIndex);
            SymbolGenerator generator = generators.getGenerator(layer.getType());

            // Handle layers for which there is no Generator
            if (generator == null) {
                if (layer.getType().equalsIgnoreCase("Accuracy")) {
                    // We handle accuracy layers at a later stage. Do nothing for now.
                } else if (layer.getType().toLowerCase().endsWith("loss")) {
                    // This is a loss layer we don't have a generator for. Wrap it in CaffeLoss.
                    generator = generators.getGenerator("CaffePluginLossLayer");
                } else {
                    // This is a layer we don't have a generator for. Wrap it in CaffeOp.
                    generator = generators.getGenerator("PluginIntLayerGenerator");
                }
            }

            if (generator != null) { // If we have a generator
                // Generate code
                GeneratorOutput out = generator.generate(layer, mlModel);
                String segment = out.code;
                code.append(segment);
                code.append(NL);

                // Update layerIndex depending on how many layers we ended up translating
                layerIndex += out.numLayersTranslated;
            } else { // If we don't have a generator
                // We've decided to skip this layer. Generate no code. Just increment layerIndex
                // by 1 and move on to the next layer.
                layerIndex++;
            }
        }

        // getLoss may append a Group() statement to 'code' when there are
        // multiple losses; it returns the variable name holding the final loss.
        String loss = getLoss(mlModel, code);

        String evalMetric = generateValidationMetrics(mlModel);
        code.append(evalMetric);

        String runner = generateRunner(loss);
        code.append(runner);

        return code.toString();
    }

    /** Renders the logging-setup snippet, named after the model. */
    private String generateLogger() {
        ST st = gh.getTemplate("logging");
        st.add("name", mlModel.getName());
        return st.render();
    }

    /**
     * Renders the training-loop "runner" from solver hyperparameters.
     *
     * @param loss name of the Python variable holding the final loss symbol
     */
    private String generateRunner(String loss) {
        ST st = gh.getTemplate("runner");
        st.add("max_iter", solver.getProperty("max_iter"));
        st.add("stepsize", solver.getProperty("stepsize"));
        st.add("snapshot", solver.getProperty("snapshot"));
        st.add("test_interval", solver.getProperty("test_interval"));
        st.add("test_iter", solver.getProperty("test_iter"));
        st.add("snapshot_prefix", solver.getProperty("snapshot_prefix"));

        st.add("train_data_itr", getIteratorName("TRAIN"));
        st.add("test_data_itr", getIteratorName("TEST"));

        // Caffe's solver_mode (CPU/GPU) maps directly to mx.cpu()/mx.gpu().
        String context = solver.getProperty("solver_mode", "cpu").toLowerCase();
        context = String.format("mx.%s()", context);
        st.add("ctx", context);

        st.add("loss", loss);

        st.add("data_names", getDataNames());
        st.add("label_names", getLabelNames());

        st.add("init_params", generateInitializer());

        st.add("init_optimizer", generateOptimizer());
        st.add("gamma", solver.getProperty("gamma"));
        st.add("power", solver.getProperty("power"));
        st.add("lr_update", generateLRUpdate());

        return st.render();
    }

    /** Renders the parameter-initializer helper class. */
    private String generateParamInitializer() {
        return gh.getTemplate("param_initializer").render();
    }

    /** Renders metric classes, honoring the solver's display/average_loss settings. */
    private String generateMetricsClasses() {
        ST st = gh.getTemplate("metrics_classes");

        String display = solver.getProperty("display");
        String average_loss = solver.getProperty("average_loss");

        if (display != null) {
            st.add("display", display);
        }

        if (average_loss != null) {
            st.add("average_loss", average_loss);
        }

        return st.render();
    }

    /** Renders the .params file loader snippet. */
    private String generateParamsLoader() {
        return gh.getTemplate("params_loader").render();
    }

    /**
     * Determines the final loss variable. With multiple loss layers, appends a
     * Group() of all losses to {@code out} and returns the grouped variable name.
     *
     * @param model the parsed model
     * @param out   buffer the grouping code (if any) is appended to
     * @return Python variable name of the loss symbol
     */
    private String getLoss(MLModel model, StringBuilder out) {
        List<String> losses = new ArrayList<>();
        for (Layer layer : model.getLayerList()) {
            // By Caffe convention, loss layer types end with "loss".
            if (layer.getType().toLowerCase().endsWith("loss")) {
                losses.add(gh.getVarname(layer.getTop()));
            }
        }

        if (losses.size() == 1) {
            return losses.get(0);
        } else if (losses.size() > 1) {
            String loss_var = "combined_loss";
            ST st = gh.getTemplate("group");
            st.add("var", loss_var);
            st.add("symbols", losses);
            out.append(st.render());
            return loss_var;
        } else {
            System.err.println("No loss found");
            return "unknown_loss";
        }
    }

    /** Renders the learning-rate update code for the solver's lr_policy. */
    private String generateLRUpdate() {
        String code;
        String lrPolicy = solver.getProperty("lr_policy", "fixed").toLowerCase();
        ST st;
        switch (lrPolicy) {
            case "fixed":
                // lr stays fixed. No update needed
                code = "";
                break;
            case "multistep":
                st = gh.getTemplate("lrpolicy_multistep");
                st.add("steps", solver.getProperties("stepvalue"));
                code = st.render();
                break;
            case "step":
            case "exp":
            case "inv":
            case "poly":
            case "sigmoid":
                st = gh.getTemplate("lrpolicy_" + lrPolicy);
                code = st.render();
                break;
            default:
                String message = "Unknown lr_policy: " + lrPolicy;
                System.err.println(message);
                // Emit the problem as a Python comment so the generated code still runs.
                code = "# " + message + System.lineSeparator();
                break;
        }
        return Utils.indent(code, 2, true, 4);
    }

    /** Renders accuracy/validation metric code for the model. */
    private String generateValidationMetrics(MLModel mlModel) {
        return new AccuracyMetricsGenerator().generate(mlModel);
    }

    /** Renders optimizer-initialization code from solver settings. */
    private String generateOptimizer() {
        Optimizer optimizer = new Optimizer(solver);
        return optimizer.generateInitCode();
    }

    /** Renders parameter-initialization code, wiring in paramsFilePath if set. */
    private String generateInitializer() {
        ST st = gh.getTemplate("init_params");
        st.add("params_file", paramsFilePath);
        return st.render();
    }

    /** Renders the Python import statements. */
    private String generateImports() {
        return gh.getTemplate("imports").render();
    }

    /** Generates one data iterator per data layer. */
    private StringBuilder generateIterators() {
        StringBuilder code = new StringBuilder();

        for (Layer layer : mlModel.getDataLayers()) {
            String iterator = generateIterator(layer);
            code.append(iterator);
        }

        return code;
    }

    /**
     * Finds the generated iterator variable name for the given phase.
     *
     * @param phase "TRAIN" or "TEST"; also used as the default when a data layer
     *              declares no include.phase
     * @return the iterator name, or {@code null} if no data layer matches
     */
    private String getIteratorName(String phase) {
        for (Layer layer : mlModel.getDataLayers()) {
            String layerPhase = layer.getAttr("include.phase", phase);
            if (phase.equalsIgnoreCase(layerPhase)) {
                // Must match the naming scheme used in generateIterator().
                return layerPhase.toLowerCase() + "_" + layer.getName() + "_" + "itr";
            }
        }
        return null;
    }

    /** @return quoted data (top index 0) names of the training-phase data layers */
    private List<String> getDataNames() {
        return getDataNames(0);
    }

    /** @return quoted label (top index 1) names of the training-phase data layers */
    private List<String> getLabelNames() {
        return getDataNames(1);
    }

    /**
     * Collects the quoted top names at {@code topIndex} from training-phase data layers.
     *
     * @param topIndex 0 for data, 1 for label
     */
    private List<String> getDataNames(int topIndex) {
        List<String> dataList = new ArrayList<String>();
        for (Layer layer : mlModel.getDataLayers()) {
            // Constant-first comparison avoids an NPE when include.phase is absent.
            if ("train".equalsIgnoreCase(layer.getAttr("include.phase"))
                    && layer.getTops().size() > topIndex) {
                String dataName = layer.getTops().get(topIndex);
                if (dataName != null) {
                    dataList.add(String.format("'%s'", dataName));
                }
            }
        }
        return dataList;
    }

    /** Generates mx.sym.Variable declarations for every data-layer top (data/label). */
    private StringBuilder generateInputVars() {
        StringBuilder code = new StringBuilder();

        // A Set de-duplicates tops shared by multiple data layers (e.g. train/test).
        Set<String> tops = new HashSet<String>();

        for (Layer layer : mlModel.getDataLayers())
            for (String top : layer.getTops())
                tops.add(top);

        for (String top : tops)
            code.append(gh.generateVar(gh.getVarname(top), top, null, null, null, null));

        code.append(System.lineSeparator());
        return code;
    }

    /**
     * Generates a CaffeDataIter for one data layer, embedding the layer's own
     * prototxt (as a Python string) in the generated code.
     */
    private String generateIterator(Layer layer) {
        String iteratorName = layer.getAttr("include.phase");
        iteratorName = iteratorName.toLowerCase();
        iteratorName = iteratorName + "_" + layer.getName() + "_" + "itr";

        ST st = stGroup.getInstanceOf("iterator");

        // Re-emit the layer prototxt as a single-quoted Python string with
        // backslash line continuations.
        String prototxt = layer.getPrototxt();
        prototxt = prototxt.replace("\r", "");
        prototxt = prototxt.replace("\n", " \\\n");
        prototxt = "'" + prototxt + "'";
        prototxt = Utils.indent(prototxt, 1, true, 4);

        st.add("iter_name", iteratorName);
        st.add("prototxt", prototxt);

        String dataName = "???";
        if (layer.getTops().size() >= 1) {
            dataName = layer.getTops().get(0);
        } else {
            System.err.println(String.format("Data layer %s doesn't have data", layer.getName()));
        }
        st.add("data_name", dataName);

        String labelName = "???";
        // Fixed off-by-one: the label is top index 1, which exists only when
        // there are at least two tops (the original checked >= 1 and could
        // throw IndexOutOfBoundsException for a single-top data layer).
        if (layer.getTops().size() >= 2) {
            labelName = layer.getTops().get(1);
        } else {
            System.err.println(String.format("Data layer %s doesn't have label", layer.getName()));
        }
        st.add("label_name", labelName);

        if (layer.hasAttr("data_param.num_examples")) {
            st.add("num_examples", layer.getAttr("data_param.num_examples"));
        }

        return st.render();
    }

}