1 package com.taobao.android.mnndemo;
2 
3 import android.Manifest;
4 import android.content.Intent;
5 import android.content.pm.PackageManager;
6 import android.graphics.Color;
7 import android.graphics.Matrix;
8 import android.hardware.SensorManager;
9 import android.os.Build;
10 import android.os.Handler;
11 import android.os.HandlerThread;
12 import android.support.v7.app.AppCompatActivity;
13 import android.os.Bundle;
14 import android.util.DisplayMetrics;
15 import android.util.Log;
16 import android.view.OrientationEventListener;
17 import android.view.View;
18 import android.view.ViewStub;
19 import android.view.WindowManager;
20 import android.widget.AdapterView;
21 import android.widget.FrameLayout;
22 import android.widget.RelativeLayout;
23 import android.widget.Spinner;
24 import android.widget.TextView;
25 import android.widget.Toast;
26 
27 import com.taobao.android.mnn.MNNForwardType;
28 import com.taobao.android.mnn.MNNImageProcess;
29 import com.taobao.android.mnn.MNNNetInstance;
30 import com.taobao.android.utils.Common;
31 import com.taobao.android.utils.TxtFileReader;
32 
33 import java.text.DecimalFormat;
34 import java.util.AbstractMap;
35 import java.util.ArrayList;
36 import java.util.Collections;
37 import java.util.Comparator;
38 import java.util.List;
39 import java.util.Map;
40 import java.util.concurrent.atomic.AtomicBoolean;
41 
42 public class VideoActivity extends AppCompatActivity implements AdapterView.OnItemSelectedListener {
43 
44     private final String TAG = "VideoActivity";
45     private final int MAX_CLZ_SIZE = 1000;
46 
47     private final String MobileModelFileName = "MobileNet/v2/mobilenet_v2.caffe.mnn";
48     private final String MobileWordsFileName = "MobileNet/synset_words.txt";
49 
50     private final String SqueezeModelFileName = "SqueezeNet/v1.1/squeezenet_v1.1.caffe.mnn";
51     private final String SqueezeWordsFileName = "SqueezeNet/squeezenet.txt";
52 
53     private String mMobileModelPath;
54     private List<String> mMobileTaiWords;
55     private String mSqueezeModelPath;
56     private List<String> mSqueezeTaiWords;
57 
58     private int mSelectedModelIndex;// current using modle
59     private final MNNNetInstance.Config mConfig = new MNNNetInstance.Config();// session config
60 
61     private CameraView mCameraView;
62     private Spinner mForwardTypeSpinner;
63     private Spinner mThreadNumSpinner;
64     private Spinner mModelSpinner;
65     private Spinner mMoreDemoSpinner;
66 
67     private TextView mFirstResult;
68     private TextView mSecondResult;
69     private TextView mThirdResult;
70     private TextView mTimeTextView;
71 
72     private final int MobileInputWidth = 224;
73     private final int MobileInputHeight = 224;
74 
75     private final int SqueezeInputWidth = 227;
76     private final int SqueezeInputHeight = 227;
77 
78     HandlerThread mThread;
79     Handler mHandle;
80 
81     private AtomicBoolean mLockUIRender = new AtomicBoolean(false);
82     private AtomicBoolean mDrop = new AtomicBoolean(false);
83 
84     private MNNNetInstance mNetInstance;
85     private MNNNetInstance.Session mSession;
86     private MNNNetInstance.Session.Tensor mInputTensor;
87 
88     private int mRotateDegree;// 0/90/180/360
89 
90     /**
91      * 监听屏幕旋转
92      */
detectScreenRotate()93     void detectScreenRotate() {
94         OrientationEventListener orientationListener = new OrientationEventListener(this,
95                 SensorManager.SENSOR_DELAY_NORMAL) {
96             @Override
97             public void onOrientationChanged(int orientation) {
98 
99                 if (orientation == OrientationEventListener.ORIENTATION_UNKNOWN) {
100                     return;  //手机平放时,检测不到有效的角度
101                 }
102 
103                 //可以根据不同角度检测处理,这里只检测四个角度的改变
104                 orientation = (orientation + 45) / 90 * 90;
105                 mRotateDegree = orientation % 360;
106             }
107         };
108 
109 
110         if (orientationListener.canDetectOrientation()) {
111             orientationListener.enable();
112         } else {
113             orientationListener.disable();
114         }
115     }
116 
prepareModels()117     private void prepareModels() {
118 
119         mMobileModelPath = getCacheDir() + "mobilenet_v1.caffe.mnn";
120         try {
121             Common.copyAssetResource2File(getBaseContext(), MobileModelFileName, mMobileModelPath);
122             mMobileTaiWords = TxtFileReader.getUniqueUrls(getBaseContext(), MobileWordsFileName, Integer.MAX_VALUE);
123         } catch (Throwable e) {
124             throw new RuntimeException(e);
125         }
126 
127         mSqueezeModelPath = getCacheDir() + "squeezenet_v1.1.caffe.mnn";
128         try {
129             Common.copyAssetResource2File(getBaseContext(), SqueezeModelFileName, mSqueezeModelPath);
130             mSqueezeTaiWords = TxtFileReader.getUniqueUrls(getBaseContext(), SqueezeWordsFileName, Integer.MAX_VALUE);
131         } catch (Throwable e) {
132             throw new RuntimeException(e);
133         }
134     }
135 
136 
prepareNet()137     private void prepareNet() {
138         if (null != mSession) {
139             mSession.release();
140             mSession = null;
141         }
142         if (mNetInstance != null) {
143             mNetInstance.release();
144             mNetInstance = null;
145         }
146 
147         String modelPath = mMobileModelPath;
148         if (mSelectedModelIndex == 0) {
149             modelPath = mMobileModelPath;
150         } else if (mSelectedModelIndex == 1) {
151             modelPath = mSqueezeModelPath;
152         }
153 
154         // create net instance
155         mNetInstance = MNNNetInstance.createFromFile(modelPath);
156 
157         // mConfig.saveTensors;
158         mSession = mNetInstance.createSession(mConfig);
159 
160         // get input tensor
161         mInputTensor = mSession.getInput(null);
162 
163         int[] dimensions = mInputTensor.getDimensions();
164         dimensions[0] = 1; // force batch = 1  NCHW  [batch, channels, height, width]
165         mInputTensor.reshape(dimensions);
166         mSession.reshape();
167 
168         mLockUIRender.set(false);
169     }
170 
171     @Override
onCreate(Bundle savedInstanceState)172     protected void onCreate(Bundle savedInstanceState) {
173         super.onCreate(savedInstanceState);
174         getWindow().setFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON,
175                 WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON);
176         getWindow().addFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN);
177         setContentView(R.layout.activity_main);
178 
179         detectScreenRotate();
180 
181         mSelectedModelIndex = 0;
182         mConfig.numThread = 4;
183         mConfig.forwardType = MNNForwardType.FORWARD_CPU.type;
184 
185         // prepare mnn net models
186         prepareModels();
187 
188         mForwardTypeSpinner = (Spinner) findViewById(R.id.forwardTypeSpinner);
189         mThreadNumSpinner = (Spinner) findViewById(R.id.threadsSpinner);
190         mThreadNumSpinner.setSelection(2);
191         mModelSpinner = (Spinner) findViewById(R.id.modelTypeSpinner);
192         mMoreDemoSpinner = (Spinner) findViewById(R.id.MoreDemo);
193 
194         mFirstResult = findViewById(R.id.firstResult);
195         mSecondResult = findViewById(R.id.secondResult);
196         mThirdResult = findViewById(R.id.thirdResult);
197         mTimeTextView = findViewById(R.id.timeTextView);
198 
199         mForwardTypeSpinner.setOnItemSelectedListener(VideoActivity.this);
200         mThreadNumSpinner.setOnItemSelectedListener(VideoActivity.this);
201         mModelSpinner.setOnItemSelectedListener(VideoActivity.this);
202         mMoreDemoSpinner.setOnItemSelectedListener(VideoActivity.this);
203 
204         // init sub thread handle
205         mLockUIRender.set(true);
206         clearUIForPrepareNet();
207 
208         if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) {
209             if (checkSelfPermission(Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
210                 requestPermissions(new String[]{Manifest.permission.CAMERA}, 10);
211             } else {
212                 handlePreViewCallBack();
213             }
214         } else {
215             handlePreViewCallBack();
216         }
217 
218         mThread = new HandlerThread("MNNNet");
219         mThread.start();
220         mHandle = new Handler(mThread.getLooper());
221 
222         mHandle.post(new Runnable() {
223             @Override
224             public void run() {
225                 prepareNet();
226             }
227         });
228 
229     }
230 
231 
232     @Override
onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults)233     public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) {
234         super.onRequestPermissionsResult(requestCode, permissions, grantResults);
235 
236         if (10 == requestCode) {
237             if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) {
238                 handlePreViewCallBack();
239             } else {
240                 Toast.makeText(this, "没有获得必要的权限", Toast.LENGTH_SHORT).show();
241             }
242         }
243 
244     }
245 
handlePreViewCallBack()246     private void handlePreViewCallBack() {
247 
248         ViewStub stub = (ViewStub) findViewById(R.id.stub);
249         stub.inflate();
250 
251         mCameraView = (CameraView) findViewById(R.id.camera_view);
252 
253         mCameraView.setPreviewCallback(new CameraView.PreviewCallback() {
254             @Override
255             public void onGetPreviewOptimalSize(int optimalWidth, int optimalHeight) {
256 
257                 // adjust video preview size according to screen size
258                 DisplayMetrics metric = new DisplayMetrics();
259                 getWindowManager().getDefaultDisplay().getMetrics(metric);
260                 int fixedVideoHeight = metric.widthPixels * optimalWidth / optimalHeight;
261 
262                 FrameLayout layoutVideo = findViewById(R.id.videoLayout);
263                 RelativeLayout.LayoutParams params = (RelativeLayout.LayoutParams) layoutVideo.getLayoutParams();
264                 params.height = fixedVideoHeight;
265                 layoutVideo.setLayoutParams(params);
266             }
267 
268             @Override
269             public void onPreviewFrame(final byte[] data, final int imageWidth, final int imageHeight, final int angle) {
270 
271                 if (mLockUIRender.get()) {
272                     return;
273                 }
274 
275 
276                 if (mDrop.get()) {
277                     Log.w(TAG, "drop frame , net running too slow !!");
278                 } else {
279                     mDrop.set(true);
280                     mHandle.post(new Runnable() {
281                         @Override
282                         public void run() {
283                             mDrop.set(false);
284                             if (mLockUIRender.get()) {
285                                 return;
286                             }
287 
288                             // calculate corrected angle based on camera orientation and mobile rotate degree. (back camrea)
289                             int needRotateAngle = (angle + mRotateDegree) % 360;
290 
291                             /*
292                              *  convert data to input tensor
293                              */
294                             final MNNImageProcess.Config config = new MNNImageProcess.Config();
295                             if (mSelectedModelIndex == 0) {
296                                 // normalization params
297                                 config.mean = new float[]{103.94f, 116.78f, 123.68f};
298                                 config.normal = new float[]{0.017f, 0.017f, 0.017f};
299                                 config.source = MNNImageProcess.Format.YUV_NV21;// input source format
300                                 config.dest = MNNImageProcess.Format.BGR;// input data format
301 
302                                 // matrix transform: dst to src
303                                 Matrix matrix = new Matrix();
304                                 matrix.postScale(MobileInputWidth / (float) imageWidth, MobileInputHeight / (float) imageHeight);
305                                 matrix.postRotate(needRotateAngle, MobileInputWidth / 2, MobileInputHeight / 2);
306                                 matrix.invert(matrix);
307 
308                                 MNNImageProcess.convertBuffer(data, imageWidth, imageHeight, mInputTensor, config, matrix);
309 
310                             } else if (mSelectedModelIndex == 1) {
311                                 // input data format
312                                 config.source = MNNImageProcess.Format.YUV_NV21;// input source format
313                                 config.dest = MNNImageProcess.Format.BGR;// input data format
314 
315                                 // matrix transform: dst to src
316                                 final Matrix matrix = new Matrix();
317                                 matrix.postScale(SqueezeInputWidth / (float) (float) imageWidth, SqueezeInputHeight / (float) imageHeight);
318                                 matrix.postRotate(needRotateAngle, SqueezeInputWidth / 2, SqueezeInputWidth / 2);
319                                 matrix.invert(matrix);
320 
321                                 MNNImageProcess.convertBuffer(data, imageWidth, imageHeight, mInputTensor, config, matrix);
322                             }
323 
324                             final long startTimestamp = System.nanoTime();
325                             /**
326                              * inference
327                              */
328                             mSession.run();
329 
330                             /**
331                              * get output tensor
332                              */
333                             MNNNetInstance.Session.Tensor output = mSession.getOutput(null);
334 
335                             float[] result = output.getFloatData();// get float results
336                             final long endTimestamp = System.nanoTime();
337                             final float inferenceTimeCost = (endTimestamp - startTimestamp) / 1000000.0f;
338 
339                             if (result.length > MAX_CLZ_SIZE) {
340                                 Log.w(TAG, "session result too big (" + result.length + "), model incorrect ?");
341                             }
342 
343                             final List<Map.Entry<Integer, Float>> maybes = new ArrayList<>();
344                             for (int i = 0; i < result.length; i++) {
345                                 float confidence = result[i];
346                                 if (confidence > 0.01) {
347                                     maybes.add(new AbstractMap.SimpleEntry<Integer, Float>(i, confidence));
348                                 }
349                             }
350 
351                             Collections.sort(maybes, new Comparator<Map.Entry<Integer, Float>>() {
352                                 @Override
353                                 public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) {
354                                     if (Math.abs(o1.getValue() - o2.getValue()) <= Float.MIN_NORMAL) {
355                                         return 0;
356                                     }
357                                     return o1.getValue() > o2.getValue() ? -1 : 1;
358                                 }
359                             });
360 
361                             // show results on ui
362                             runOnUiThread(new Runnable() {
363                                 @Override
364                                 public void run() {
365 
366                                     if (maybes.size() == 0) {
367                                         mFirstResult.setText("no data");
368                                         mSecondResult.setText("");
369                                         mThirdResult.setText("");
370                                     }
371                                     if (maybes.size() > 0) {
372                                         mFirstResult.setTextColor(maybes.get(0).getValue() > 0.2 ? Color.BLACK : Color.parseColor("#a4a4a4"));
373                                         final Integer iKey = maybes.get(0).getKey();
374                                         final Float fValue = maybes.get(0).getValue();
375                                         String strWord = "unknown";
376                                         if (0 == mSelectedModelIndex) {
377                                             if (iKey < mMobileTaiWords.size()) {
378                                                 strWord = mMobileTaiWords.get(iKey);
379                                             }
380                                         } else {
381                                             if (iKey < mSqueezeTaiWords.size()) {
382                                                 strWord = mSqueezeTaiWords.get(iKey);
383                                             }
384                                         }
385                                         final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord;
386                                         mFirstResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue));
387 
388                                     }
389                                     if (maybes.size() > 1) {
390                                         final Integer iKey = maybes.get(1).getKey();
391                                         final Float fValue = maybes.get(1).getValue();
392                                         String strWord = "unknown";
393                                         if (0 == mSelectedModelIndex) {
394                                             if (iKey < mMobileTaiWords.size()) {
395                                                 strWord = mMobileTaiWords.get(iKey);
396                                             }
397                                         } else {
398                                             if (iKey < mSqueezeTaiWords.size()) {
399                                                 strWord = mSqueezeTaiWords.get(iKey);
400                                             }
401                                         }
402                                         final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord;
403                                         mSecondResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue));
404 
405                                     }
406                                     if (maybes.size() > 2) {
407                                         final Integer iKey = maybes.get(2).getKey();
408                                         final Float fValue = maybes.get(2).getValue();
409                                         String strWord = "unknown";
410                                         if (0 == mSelectedModelIndex) {
411                                             if (iKey < mMobileTaiWords.size()) {
412                                                 strWord = mMobileTaiWords.get(iKey);
413                                             }
414                                         } else {
415                                             if (iKey < mSqueezeTaiWords.size()) {
416                                                 strWord = mSqueezeTaiWords.get(iKey);
417                                             }
418                                         }
419                                         final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord;
420                                         mThirdResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue));
421                                     }
422 
423                                     mTimeTextView.setText("cost time:" + inferenceTimeCost + "ms");
424                                 }
425                             });
426 
427                         }
428                     });
429                 }
430             }
431         });
432     }
433 
434 
435     @Override
onItemSelected(AdapterView<?> adapterView, View view, int i, long l)436     public void onItemSelected(AdapterView<?> adapterView, View view, int i, long l) {
437 
438         // forward type
439         if (mForwardTypeSpinner.getId() == adapterView.getId()) {
440 
441             if (i == 0) {
442                 mConfig.forwardType = MNNForwardType.FORWARD_CPU.type;
443             } else if (i == 1) {
444                 mConfig.forwardType = MNNForwardType.FORWARD_OPENCL.type;
445             } else if (i == 2) {
446                 mConfig.forwardType = MNNForwardType.FORWARD_OPENGL.type;
447             } else if (i == 3) {
448                 mConfig.forwardType = MNNForwardType.FORWARD_VULKAN.type;
449             }
450         }
451         // threads num
452         else if (mThreadNumSpinner.getId() == adapterView.getId()) {
453 
454             String[] threadList = getResources().getStringArray(R.array.thread_list);
455             mConfig.numThread = Integer.parseInt(threadList[i].split(" ")[1]);
456         }
457         // model index
458         else if (mModelSpinner.getId() == adapterView.getId()) {
459 
460             mSelectedModelIndex = i;
461         } else if (mMoreDemoSpinner.getId() == adapterView.getId()) {
462 
463             if (i == 1) {
464                 Intent intent = new Intent(VideoActivity.this, ImageActivity.class);
465                 startActivity(intent);
466             } else if (i == 2) {
467                 Intent intent = new Intent(VideoActivity.this, PortraitActivity.class);
468                 startActivity(intent);
469             } else if (i == 3) {
470                 Intent intent = new Intent(VideoActivity.this, OpenGLTestActivity.class);
471                 startActivity(intent);
472             }
473         }
474 
475 
476         mLockUIRender.set(true);
477         clearUIForPrepareNet();
478 
479         mHandle.post(new Runnable() {
480             @Override
481             public void run() {
482                 prepareNet();
483             }
484         });
485 
486     }
487 
clearUIForPrepareNet()488     private void clearUIForPrepareNet() {
489         mFirstResult.setText("prepare net ...");
490         mSecondResult.setText("");
491         mThirdResult.setText("");
492         mTimeTextView.setText("");
493     }
494 
495 
496     @Override
onNothingSelected(AdapterView<?> adapterView)497     public void onNothingSelected(AdapterView<?> adapterView) {
498 
499     }
500 
501     @Override
onPause()502     protected void onPause() {
503         mCameraView.onPause();
504         super.onPause();
505     }
506 
507     @Override
onResume()508     protected void onResume() {
509         super.onResume();
510         mCameraView.onResume();
511     }
512 
513 
514     @Override
onDestroy()515     protected void onDestroy() {
516         mThread.interrupt();
517 
518         /**
519          * instance release
520          */
521         mHandle.post(new Runnable() {
522             @Override
523             public void run() {
524                 if (mNetInstance != null) {
525                     mNetInstance.release();
526                 }
527             }
528         });
529 
530         super.onDestroy();
531     }
532 }