1 package com.taobao.android.mnndemo; 2 3 import android.Manifest; 4 import android.content.Intent; 5 import android.content.pm.PackageManager; 6 import android.graphics.Color; 7 import android.graphics.Matrix; 8 import android.hardware.SensorManager; 9 import android.os.Build; 10 import android.os.Handler; 11 import android.os.HandlerThread; 12 import android.support.v7.app.AppCompatActivity; 13 import android.os.Bundle; 14 import android.util.DisplayMetrics; 15 import android.util.Log; 16 import android.view.OrientationEventListener; 17 import android.view.View; 18 import android.view.ViewStub; 19 import android.view.WindowManager; 20 import android.widget.AdapterView; 21 import android.widget.FrameLayout; 22 import android.widget.RelativeLayout; 23 import android.widget.Spinner; 24 import android.widget.TextView; 25 import android.widget.Toast; 26 27 import com.taobao.android.mnn.MNNForwardType; 28 import com.taobao.android.mnn.MNNImageProcess; 29 import com.taobao.android.mnn.MNNNetInstance; 30 import com.taobao.android.utils.Common; 31 import com.taobao.android.utils.TxtFileReader; 32 33 import java.text.DecimalFormat; 34 import java.util.AbstractMap; 35 import java.util.ArrayList; 36 import java.util.Collections; 37 import java.util.Comparator; 38 import java.util.List; 39 import java.util.Map; 40 import java.util.concurrent.atomic.AtomicBoolean; 41 42 public class VideoActivity extends AppCompatActivity implements AdapterView.OnItemSelectedListener { 43 44 private final String TAG = "VideoActivity"; 45 private final int MAX_CLZ_SIZE = 1000; 46 47 private final String MobileModelFileName = "MobileNet/v2/mobilenet_v2.caffe.mnn"; 48 private final String MobileWordsFileName = "MobileNet/synset_words.txt"; 49 50 private final String SqueezeModelFileName = "SqueezeNet/v1.1/squeezenet_v1.1.caffe.mnn"; 51 private final String SqueezeWordsFileName = "SqueezeNet/squeezenet.txt"; 52 53 private String mMobileModelPath; 54 private List<String> mMobileTaiWords; 55 private String mSqueezeModelPath; 56 private List<String> mSqueezeTaiWords; 57 58 private int mSelectedModelIndex;// current using modle 59 private final MNNNetInstance.Config mConfig = new MNNNetInstance.Config();// session config 60 61 private CameraView mCameraView; 62 private Spinner mForwardTypeSpinner; 63 private Spinner mThreadNumSpinner; 64 private Spinner mModelSpinner; 65 private Spinner mMoreDemoSpinner; 66 67 private TextView mFirstResult; 68 private TextView mSecondResult; 69 private TextView mThirdResult; 70 private TextView mTimeTextView; 71 72 private final int MobileInputWidth = 224; 73 private final int MobileInputHeight = 224; 74 75 private final int SqueezeInputWidth = 227; 76 private final int SqueezeInputHeight = 227; 77 78 HandlerThread mThread; 79 Handler mHandle; 80 81 private AtomicBoolean mLockUIRender = new AtomicBoolean(false); 82 private AtomicBoolean mDrop = new AtomicBoolean(false); 83 84 private MNNNetInstance mNetInstance; 85 private MNNNetInstance.Session mSession; 86 private MNNNetInstance.Session.Tensor mInputTensor; 87 88 private int mRotateDegree;// 0/90/180/360 89 90 /** 91 * 监听屏幕旋转 92 */ detectScreenRotate()93 void detectScreenRotate() { 94 OrientationEventListener orientationListener = new OrientationEventListener(this, 95 SensorManager.SENSOR_DELAY_NORMAL) { 96 @Override 97 public void onOrientationChanged(int orientation) { 98 99 if (orientation == OrientationEventListener.ORIENTATION_UNKNOWN) { 100 return; //手机平放时,检测不到有效的角度 101 } 102 103 //可以根据不同角度检测处理,这里只检测四个角度的改变 104 orientation = (orientation + 45) / 90 * 90; 105 mRotateDegree = orientation % 360; 106 } 107 }; 108 109 110 if (orientationListener.canDetectOrientation()) { 111 orientationListener.enable(); 112 } else { 113 orientationListener.disable(); 114 } 115 } 116 prepareModels()117 private void prepareModels() { 118 119 mMobileModelPath = getCacheDir() + "mobilenet_v1.caffe.mnn"; 120 try { 121 Common.copyAssetResource2File(getBaseContext(), MobileModelFileName, mMobileModelPath); 122 mMobileTaiWords = TxtFileReader.getUniqueUrls(getBaseContext(), MobileWordsFileName, Integer.MAX_VALUE); 123 } catch (Throwable e) { 124 throw new RuntimeException(e); 125 } 126 127 mSqueezeModelPath = getCacheDir() + "squeezenet_v1.1.caffe.mnn"; 128 try { 129 Common.copyAssetResource2File(getBaseContext(), SqueezeModelFileName, mSqueezeModelPath); 130 mSqueezeTaiWords = TxtFileReader.getUniqueUrls(getBaseContext(), SqueezeWordsFileName, Integer.MAX_VALUE); 131 } catch (Throwable e) { 132 throw new RuntimeException(e); 133 } 134 } 135 136 prepareNet()137 private void prepareNet() { 138 if (null != mSession) { 139 mSession.release(); 140 mSession = null; 141 } 142 if (mNetInstance != null) { 143 mNetInstance.release(); 144 mNetInstance = null; 145 } 146 147 String modelPath = mMobileModelPath; 148 if (mSelectedModelIndex == 0) { 149 modelPath = mMobileModelPath; 150 } else if (mSelectedModelIndex == 1) { 151 modelPath = mSqueezeModelPath; 152 } 153 154 // create net instance 155 mNetInstance = MNNNetInstance.createFromFile(modelPath); 156 157 // mConfig.saveTensors; 158 mSession = mNetInstance.createSession(mConfig); 159 160 // get input tensor 161 mInputTensor = mSession.getInput(null); 162 163 int[] dimensions = mInputTensor.getDimensions(); 164 dimensions[0] = 1; // force batch = 1 NCHW [batch, channels, height, width] 165 mInputTensor.reshape(dimensions); 166 mSession.reshape(); 167 168 mLockUIRender.set(false); 169 } 170 171 @Override onCreate(Bundle savedInstanceState)172 protected void onCreate(Bundle savedInstanceState) { 173 super.onCreate(savedInstanceState); 174 getWindow().setFlags(WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON, 175 WindowManager.LayoutParams.FLAG_KEEP_SCREEN_ON); 176 getWindow().addFlags(WindowManager.LayoutParams.FLAG_FULLSCREEN); 177 setContentView(R.layout.activity_main); 178 179 detectScreenRotate(); 180 181 mSelectedModelIndex = 0; 182 mConfig.numThread = 4; 183 mConfig.forwardType = MNNForwardType.FORWARD_CPU.type; 184 185 // prepare mnn net models 186 prepareModels(); 187 188 mForwardTypeSpinner = (Spinner) findViewById(R.id.forwardTypeSpinner); 189 mThreadNumSpinner = (Spinner) findViewById(R.id.threadsSpinner); 190 mThreadNumSpinner.setSelection(2); 191 mModelSpinner = (Spinner) findViewById(R.id.modelTypeSpinner); 192 mMoreDemoSpinner = (Spinner) findViewById(R.id.MoreDemo); 193 194 mFirstResult = findViewById(R.id.firstResult); 195 mSecondResult = findViewById(R.id.secondResult); 196 mThirdResult = findViewById(R.id.thirdResult); 197 mTimeTextView = findViewById(R.id.timeTextView); 198 199 mForwardTypeSpinner.setOnItemSelectedListener(VideoActivity.this); 200 mThreadNumSpinner.setOnItemSelectedListener(VideoActivity.this); 201 mModelSpinner.setOnItemSelectedListener(VideoActivity.this); 202 mMoreDemoSpinner.setOnItemSelectedListener(VideoActivity.this); 203 204 // init sub thread handle 205 mLockUIRender.set(true); 206 clearUIForPrepareNet(); 207 208 if (Build.VERSION.SDK_INT >= Build.VERSION_CODES.M) { 209 if (checkSelfPermission(Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) { 210 requestPermissions(new String[]{Manifest.permission.CAMERA}, 10); 211 } else { 212 handlePreViewCallBack(); 213 } 214 } else { 215 handlePreViewCallBack(); 216 } 217 218 mThread = new HandlerThread("MNNNet"); 219 mThread.start(); 220 mHandle = new Handler(mThread.getLooper()); 221 222 mHandle.post(new Runnable() { 223 @Override 224 public void run() { 225 prepareNet(); 226 } 227 }); 228 229 } 230 231 232 @Override onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults)233 public void onRequestPermissionsResult(int requestCode, String[] permissions, int[] grantResults) { 234 super.onRequestPermissionsResult(requestCode, permissions, grantResults); 235 236 if (10 == requestCode) { 237 if (grantResults.length > 0 && grantResults[0] == PackageManager.PERMISSION_GRANTED) { 238 handlePreViewCallBack(); 239 } else { 240 Toast.makeText(this, "没有获得必要的权限", Toast.LENGTH_SHORT).show(); 241 } 242 } 243 244 } 245 handlePreViewCallBack()246 private void handlePreViewCallBack() { 247 248 ViewStub stub = (ViewStub) findViewById(R.id.stub); 249 stub.inflate(); 250 251 mCameraView = (CameraView) findViewById(R.id.camera_view); 252 253 mCameraView.setPreviewCallback(new CameraView.PreviewCallback() { 254 @Override 255 public void onGetPreviewOptimalSize(int optimalWidth, int optimalHeight) { 256 257 // adjust video preview size according to screen size 258 DisplayMetrics metric = new DisplayMetrics(); 259 getWindowManager().getDefaultDisplay().getMetrics(metric); 260 int fixedVideoHeight = metric.widthPixels * optimalWidth / optimalHeight; 261 262 FrameLayout layoutVideo = findViewById(R.id.videoLayout); 263 RelativeLayout.LayoutParams params = (RelativeLayout.LayoutParams) layoutVideo.getLayoutParams(); 264 params.height = fixedVideoHeight; 265 layoutVideo.setLayoutParams(params); 266 } 267 268 @Override 269 public void onPreviewFrame(final byte[] data, final int imageWidth, final int imageHeight, final int angle) { 270 271 if (mLockUIRender.get()) { 272 return; 273 } 274 275 276 if (mDrop.get()) { 277 Log.w(TAG, "drop frame , net running too slow !!"); 278 } else { 279 mDrop.set(true); 280 mHandle.post(new Runnable() { 281 @Override 282 public void run() { 283 mDrop.set(false); 284 if (mLockUIRender.get()) { 285 return; 286 } 287 288 // calculate corrected angle based on camera orientation and mobile rotate degree. (back camrea) 289 int needRotateAngle = (angle + mRotateDegree) % 360; 290 291 /* 292 * convert data to input tensor 293 */ 294 final MNNImageProcess.Config config = new MNNImageProcess.Config(); 295 if (mSelectedModelIndex == 0) { 296 // normalization params 297 config.mean = new float[]{103.94f, 116.78f, 123.68f}; 298 config.normal = new float[]{0.017f, 0.017f, 0.017f}; 299 config.source = MNNImageProcess.Format.YUV_NV21;// input source format 300 config.dest = MNNImageProcess.Format.BGR;// input data format 301 302 // matrix transform: dst to src 303 Matrix matrix = new Matrix(); 304 matrix.postScale(MobileInputWidth / (float) imageWidth, MobileInputHeight / (float) imageHeight); 305 matrix.postRotate(needRotateAngle, MobileInputWidth / 2, MobileInputHeight / 2); 306 matrix.invert(matrix); 307 308 MNNImageProcess.convertBuffer(data, imageWidth, imageHeight, mInputTensor, config, matrix); 309 310 } else if (mSelectedModelIndex == 1) { 311 // input data format 312 config.source = MNNImageProcess.Format.YUV_NV21;// input source format 313 config.dest = MNNImageProcess.Format.BGR;// input data format 314 315 // matrix transform: dst to src 316 final Matrix matrix = new Matrix(); 317 matrix.postScale(SqueezeInputWidth / (float) (float) imageWidth, SqueezeInputHeight / (float) imageHeight); 318 matrix.postRotate(needRotateAngle, SqueezeInputWidth / 2, SqueezeInputWidth / 2); 319 matrix.invert(matrix); 320 321 MNNImageProcess.convertBuffer(data, imageWidth, imageHeight, mInputTensor, config, matrix); 322 } 323 324 final long startTimestamp = System.nanoTime(); 325 /** 326 * inference 327 */ 328 mSession.run(); 329 330 /** 331 * get output tensor 332 */ 333 MNNNetInstance.Session.Tensor output = mSession.getOutput(null); 334 335 float[] result = output.getFloatData();// get float results 336 final long endTimestamp = System.nanoTime(); 337 final float inferenceTimeCost = (endTimestamp - startTimestamp) / 1000000.0f; 338 339 if (result.length > MAX_CLZ_SIZE) { 340 Log.w(TAG, "session result too big (" + result.length + "), model incorrect ?"); 341 } 342 343 final List<Map.Entry<Integer, Float>> maybes = new ArrayList<>(); 344 for (int i = 0; i < result.length; i++) { 345 float confidence = result[i]; 346 if (confidence > 0.01) { 347 maybes.add(new AbstractMap.SimpleEntry<Integer, Float>(i, confidence)); 348 } 349 } 350 351 Collections.sort(maybes, new Comparator<Map.Entry<Integer, Float>>() { 352 @Override 353 public int compare(Map.Entry<Integer, Float> o1, Map.Entry<Integer, Float> o2) { 354 if (Math.abs(o1.getValue() - o2.getValue()) <= Float.MIN_NORMAL) { 355 return 0; 356 } 357 return o1.getValue() > o2.getValue() ? -1 : 1; 358 } 359 }); 360 361 // show results on ui 362 runOnUiThread(new Runnable() { 363 @Override 364 public void run() { 365 366 if (maybes.size() == 0) { 367 mFirstResult.setText("no data"); 368 mSecondResult.setText(""); 369 mThirdResult.setText(""); 370 } 371 if (maybes.size() > 0) { 372 mFirstResult.setTextColor(maybes.get(0).getValue() > 0.2 ? Color.BLACK : Color.parseColor("#a4a4a4")); 373 final Integer iKey = maybes.get(0).getKey(); 374 final Float fValue = maybes.get(0).getValue(); 375 String strWord = "unknown"; 376 if (0 == mSelectedModelIndex) { 377 if (iKey < mMobileTaiWords.size()) { 378 strWord = mMobileTaiWords.get(iKey); 379 } 380 } else { 381 if (iKey < mSqueezeTaiWords.size()) { 382 strWord = mSqueezeTaiWords.get(iKey); 383 } 384 } 385 final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord; 386 mFirstResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue)); 387 388 } 389 if (maybes.size() > 1) { 390 final Integer iKey = maybes.get(1).getKey(); 391 final Float fValue = maybes.get(1).getValue(); 392 String strWord = "unknown"; 393 if (0 == mSelectedModelIndex) { 394 if (iKey < mMobileTaiWords.size()) { 395 strWord = mMobileTaiWords.get(iKey); 396 } 397 } else { 398 if (iKey < mSqueezeTaiWords.size()) { 399 strWord = mSqueezeTaiWords.get(iKey); 400 } 401 } 402 final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord; 403 mSecondResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue)); 404 405 } 406 if (maybes.size() > 2) { 407 final Integer iKey = maybes.get(2).getKey(); 408 final Float fValue = maybes.get(2).getValue(); 409 String strWord = "unknown"; 410 if (0 == mSelectedModelIndex) { 411 if (iKey < mMobileTaiWords.size()) { 412 strWord = mMobileTaiWords.get(iKey); 413 } 414 } else { 415 if (iKey < mSqueezeTaiWords.size()) { 416 strWord = mSqueezeTaiWords.get(iKey); 417 } 418 } 419 final String resKey = mSelectedModelIndex == 1 ? strWord.length() >= 10 ? strWord.substring(10) : strWord : strWord; 420 mThirdResult.setText(resKey + ":" + new DecimalFormat("0.00").format(fValue)); 421 } 422 423 mTimeTextView.setText("cost time:" + inferenceTimeCost + "ms"); 424 } 425 }); 426 427 } 428 }); 429 } 430 } 431 }); 432 } 433 434 435 @Override onItemSelected(AdapterView<?> adapterView, View view, int i, long l)436 public void onItemSelected(AdapterView<?> adapterView, View view, int i, long l) { 437 438 // forward type 439 if (mForwardTypeSpinner.getId() == adapterView.getId()) { 440 441 if (i == 0) { 442 mConfig.forwardType = MNNForwardType.FORWARD_CPU.type; 443 } else if (i == 1) { 444 mConfig.forwardType = MNNForwardType.FORWARD_OPENCL.type; 445 } else if (i == 2) { 446 mConfig.forwardType = MNNForwardType.FORWARD_OPENGL.type; 447 } else if (i == 3) { 448 mConfig.forwardType = MNNForwardType.FORWARD_VULKAN.type; 449 } 450 } 451 // threads num 452 else if (mThreadNumSpinner.getId() == adapterView.getId()) { 453 454 String[] threadList = getResources().getStringArray(R.array.thread_list); 455 mConfig.numThread = Integer.parseInt(threadList[i].split(" ")[1]); 456 } 457 // model index 458 else if (mModelSpinner.getId() == adapterView.getId()) { 459 460 mSelectedModelIndex = i; 461 } else if (mMoreDemoSpinner.getId() == adapterView.getId()) { 462 463 if (i == 1) { 464 Intent intent = new Intent(VideoActivity.this, ImageActivity.class); 465 startActivity(intent); 466 } else if (i == 2) { 467 Intent intent = new Intent(VideoActivity.this, PortraitActivity.class); 468 startActivity(intent); 469 } else if (i == 3) { 470 Intent intent = new Intent(VideoActivity.this, OpenGLTestActivity.class); 471 startActivity(intent); 472 } 473 } 474 475 476 mLockUIRender.set(true); 477 clearUIForPrepareNet(); 478 479 mHandle.post(new Runnable() { 480 @Override 481 public void run() { 482 prepareNet(); 483 } 484 }); 485 486 } 487 clearUIForPrepareNet()488 private void clearUIForPrepareNet() { 489 mFirstResult.setText("prepare net ..."); 490 mSecondResult.setText(""); 491 mThirdResult.setText(""); 492 mTimeTextView.setText(""); 493 } 494 495 496 @Override onNothingSelected(AdapterView<?> adapterView)497 public void onNothingSelected(AdapterView<?> adapterView) { 498 499 } 500 501 @Override onPause()502 protected void onPause() { 503 mCameraView.onPause(); 504 super.onPause(); 505 } 506 507 @Override onResume()508 protected void onResume() { 509 super.onResume(); 510 mCameraView.onResume(); 511 } 512 513 514 @Override onDestroy()515 protected void onDestroy() { 516 mThread.interrupt(); 517 518 /** 519 * instance release 520 */ 521 mHandle.post(new Runnable() { 522 @Override 523 public void run() { 524 if (mNetInstance != null) { 525 mNetInstance.release(); 526 } 527 } 528 }); 529 530 super.onDestroy(); 531 } 532 }