1 /* 2 * Licensed to the Apache Software Foundation (ASF) under one or more 3 * contributor license agreements. See the NOTICE file distributed with 4 * this work for additional information regarding copyright ownership. 5 * The ASF licenses this file to You under the Apache License, Version 2.0 6 * (the "License"); you may not use this file except in compliance with 7 * the License. You may obtain a copy of the License at 8 * 9 * http://www.apache.org/licenses/LICENSE-2.0 10 * 11 * Unless required by applicable law or agreed to in writing, software 12 * distributed under the License is distributed on an "AS IS" BASIS, 13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 14 * See the License for the specific language governing permissions and 15 * limitations under the License. 16 */ 17 18 package org.apache.mxnetexamples.javaapi.infer.predictor; 19 20 import org.apache.mxnet.infer.javaapi.Predictor; 21 import org.apache.mxnet.javaapi.*; 22 import org.kohsuke.args4j.CmdLineParser; 23 import org.kohsuke.args4j.Option; 24 import org.slf4j.Logger; 25 import org.slf4j.LoggerFactory; 26 27 import java.awt.image.BufferedImage; 28 import java.io.BufferedReader; 29 import java.io.File; 30 import java.io.FileReader; 31 import java.io.IOException; 32 import java.util.ArrayList; 33 import java.util.List; 34 35 /** 36 * This Class is a demo to show how users can use Predictor APIs to do 37 * Image Classification with all hand-crafted Pre-processing. 38 * All helper functions for image pre-processing are 39 * currently available in ObjectDetector class. 40 */ 41 public class PredictorExample { 42 @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model") 43 private String modelPathPrefix = "/model/ssd_resnet50_512"; 44 @Option(name = "--input-image", usage = "the input image") 45 private String inputImagePath = "/images/dog.jpg"; 46 47 final static Logger logger = LoggerFactory.getLogger(PredictorExample.class); 48 private static NDArray$ NDArray = NDArray$.MODULE$; 49 50 /** 51 * Helper class to print the maximum prediction result 52 * @param probabilities The float array of probability 53 * @param modelPathPrefix model Path needs to load the synset.txt 54 */ printMaximumClass(float[] probabilities, String modelPathPrefix)55 private static String printMaximumClass(float[] probabilities, 56 String modelPathPrefix) throws IOException { 57 String synsetFilePath = modelPathPrefix.substring(0, 58 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt"; 59 BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath)); 60 ArrayList<String> list = new ArrayList<>(); 61 String line = reader.readLine(); 62 63 while (line != null){ 64 list.add(line); 65 line = reader.readLine(); 66 } 67 reader.close(); 68 69 int maxIdx = 0; 70 for (int i = 1;i<probabilities.length;i++) { 71 if (probabilities[i] > probabilities[maxIdx]) { 72 maxIdx = i; 73 } 74 } 75 76 return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ; 77 } 78 main(String[] args)79 public static void main(String[] args) { 80 PredictorExample inst = new PredictorExample(); 81 CmdLineParser parser = new CmdLineParser(inst); 82 try { 83 parser.parseArgument(args); 84 } catch (Exception e) { 85 logger.error(e.getMessage(), e); 86 parser.printUsage(System.err); 87 System.exit(1); 88 } 89 // Prepare the model 90 List<Context> context = new ArrayList<Context>(); 91 if (System.getenv().containsKey("SCALA_TEST_ON_GPU") && 92 Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) { 93 context.add(Context.gpu()); 94 } else { 95 context.add(Context.cpu()); 96 } 97 List<DataDesc> inputDesc = new ArrayList<>(); 98 Shape inputShape = new Shape(new int[]{1, 3, 224, 224}); 99 inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW")); 100 Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0); 101 // Prepare data 102 NDArray img = Image.imRead(inst.inputImagePath, 1, true); 103 img = Image.imResize(img, 224, 224, null); 104 // predict 105 float[][] result = predictor.predict(new float[][]{img.toArray()}); 106 try { 107 System.out.println("Predict with Float input"); 108 System.out.println(printMaximumClass(result[0], inst.modelPathPrefix)); 109 } catch (IOException e) { 110 System.err.println(e); 111 } 112 // predict with NDArray 113 NDArray nd = img; 114 nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0]; 115 nd = NDArray.expand_dims(nd, 0, null)[0]; 116 nd = nd.asType(DType.Float32()); 117 List<NDArray> ndList = new ArrayList<>(); 118 ndList.add(nd); 119 List<NDArray> ndResult = predictor.predictWithNDArray(ndList); 120 try { 121 System.out.println("Predict with NDArray"); 122 System.out.println(printMaximumClass(ndResult.get(0).toArray(), inst.modelPathPrefix)); 123 } catch (IOException e) { 124 System.err.println(e); 125 } 126 } 127 128 } 129