1 /*
2  * Licensed to the Apache Software Foundation (ASF) under one or more
3  * contributor license agreements.  See the NOTICE file distributed with
4  * this work for additional information regarding copyright ownership.
5  * The ASF licenses this file to You under the Apache License, Version 2.0
6  * (the "License"); you may not use this file except in compliance with
7  * the License.  You may obtain a copy of the License at
8  *
9  *    http://www.apache.org/licenses/LICENSE-2.0
10  *
11  * Unless required by applicable law or agreed to in writing, software
12  * distributed under the License is distributed on an "AS IS" BASIS,
13  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14  * See the License for the specific language governing permissions and
15  * limitations under the License.
16  */
17 
18 package org.apache.mxnetexamples.javaapi.infer.predictor;
19 
20 import org.apache.mxnet.infer.javaapi.Predictor;
21 import org.apache.mxnet.javaapi.*;
22 import org.kohsuke.args4j.CmdLineParser;
23 import org.kohsuke.args4j.Option;
24 import org.slf4j.Logger;
25 import org.slf4j.LoggerFactory;
26 
27 import java.awt.image.BufferedImage;
28 import java.io.BufferedReader;
29 import java.io.File;
30 import java.io.FileReader;
31 import java.io.IOException;
32 import java.util.ArrayList;
33 import java.util.List;
34 
35 /**
36  * This Class is a demo to show how users can use Predictor APIs to do
37  * Image Classification with all hand-crafted Pre-processing.
38  * All helper functions for image pre-processing are
39  * currently available in ObjectDetector class.
40  */
41 public class PredictorExample {
42     @Option(name = "--model-path-prefix", usage = "input model directory and prefix of the model")
43     private String modelPathPrefix = "/model/ssd_resnet50_512";
44     @Option(name = "--input-image", usage = "the input image")
45     private String inputImagePath = "/images/dog.jpg";
46 
47     final static Logger logger = LoggerFactory.getLogger(PredictorExample.class);
48     private static NDArray$ NDArray = NDArray$.MODULE$;
49 
50     /**
51      * Helper class to print the maximum prediction result
52      * @param probabilities The float array of probability
53      * @param modelPathPrefix model Path needs to load the synset.txt
54      */
printMaximumClass(float[] probabilities, String modelPathPrefix)55     private static String printMaximumClass(float[] probabilities,
56                                             String modelPathPrefix) throws IOException {
57         String synsetFilePath = modelPathPrefix.substring(0,
58                 1 + modelPathPrefix.lastIndexOf(File.separator)) + "/synset.txt";
59         BufferedReader reader = new BufferedReader(new FileReader(synsetFilePath));
60         ArrayList<String> list = new ArrayList<>();
61         String line = reader.readLine();
62 
63         while (line != null){
64             list.add(line);
65             line = reader.readLine();
66         }
67         reader.close();
68 
69         int maxIdx = 0;
70         for (int i = 1;i<probabilities.length;i++) {
71             if (probabilities[i] > probabilities[maxIdx]) {
72                 maxIdx = i;
73             }
74         }
75 
76         return "Probability : " + probabilities[maxIdx] + " Class : " + list.get(maxIdx) ;
77     }
78 
main(String[] args)79     public static void main(String[] args) {
80         PredictorExample inst = new PredictorExample();
81         CmdLineParser parser  = new CmdLineParser(inst);
82         try {
83             parser.parseArgument(args);
84         } catch (Exception e) {
85             logger.error(e.getMessage(), e);
86             parser.printUsage(System.err);
87             System.exit(1);
88         }
89         // Prepare the model
90         List<Context> context = new ArrayList<Context>();
91         if (System.getenv().containsKey("SCALA_TEST_ON_GPU") &&
92                 Integer.valueOf(System.getenv("SCALA_TEST_ON_GPU")) == 1) {
93             context.add(Context.gpu());
94         } else {
95             context.add(Context.cpu());
96         }
97         List<DataDesc> inputDesc = new ArrayList<>();
98         Shape inputShape = new Shape(new int[]{1, 3, 224, 224});
99         inputDesc.add(new DataDesc("data", inputShape, DType.Float32(), "NCHW"));
100         Predictor predictor = new Predictor(inst.modelPathPrefix, inputDesc, context,0);
101         // Prepare data
102         NDArray img = Image.imRead(inst.inputImagePath, 1, true);
103         img = Image.imResize(img, 224, 224, null);
104         // predict
105         float[][] result = predictor.predict(new float[][]{img.toArray()});
106         try {
107             System.out.println("Predict with Float input");
108             System.out.println(printMaximumClass(result[0], inst.modelPathPrefix));
109         } catch (IOException e) {
110             System.err.println(e);
111         }
112         // predict with NDArray
113         NDArray nd = img;
114         nd = NDArray.transpose(nd, new Shape(new int[]{2, 0, 1}), null)[0];
115         nd = NDArray.expand_dims(nd, 0, null)[0];
116         nd = nd.asType(DType.Float32());
117         List<NDArray> ndList = new ArrayList<>();
118         ndList.add(nd);
119         List<NDArray> ndResult = predictor.predictWithNDArray(ndList);
120         try {
121             System.out.println("Predict with NDArray");
122             System.out.println(printMaximumClass(ndResult.get(0).toArray(), inst.modelPathPrefix));
123         } catch (IOException e) {
124             System.err.println(e);
125         }
126     }
127 
128 }
129