/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet.infer

import org.apache.mxnet.io.NDArrayIter
import org.apache.mxnet.module.{BaseModule, Module}
import org.apache.mxnet._
import org.mockito.Matchers._
import org.mockito.Mockito
import org.scalatest.{BeforeAndAfterAll, FunSuite}

/**
 * Unit tests for [[Predictor]].
 *
 * The underlying [[Module]] is replaced with a Mockito mock, so `predict` returns a canned
 * result and the tests can verify how the module was bound.
 */
class PredictorSuite extends FunSuite with BeforeAndAfterAll {

  /**
   * Test double for [[Predictor]]. It injects a mocked [[Module]] and exposes the
   * protected fields the tests inspect.
   *
   * @param modelPathPrefix  model path prefix forwarded to [[Predictor]]
   * @param inputDescriptors input descriptors forwarded to [[Predictor]]
   */
  class MyPredictor(val modelPathPrefix: String,
                    override val inputDescriptors: IndexedSeq[DataDesc])
    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {

    override def loadModule(): Module = mockModule

    // Snapshots of Predictor's protected state, taken when this subclass is initialized.
    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
    val getBatchSize: Int = batchSize
    val getBatchIndex: Int = batchIndex

    // NOTE(review): `lazy` presumably because Predictor calls loadModule() from its own
    // constructor, before this subclass's strict vals are initialized. A plain val would
    // still be null at that point. Confirm against Predictor.
    lazy val mockModule: Module = Mockito.mock(classOf[Module])
  }

  /** Stubs `predictor`'s mocked module so that any `predict` call returns `result`. */
  private def stubPredict(predictor: MyPredictor, result: IndexedSeq[NDArray]): Unit = {
    Mockito.doReturn(result).when(predictor.mockModule)
      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
  }

  /** Verifies that `bind` was invoked exactly twice on `module`, with any arguments. */
  private def verifyBoundTwice(module: Module): Unit = {
    Mockito.verify(module, Mockito.times(2)).bind(any[IndexedSeq[DataDesc]],
      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean],
      any[Option[BaseModule]], any[String])
  }

  test("PredictorSuite-testPredictorConstruction") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2),
      layout = Layout.NCHW))

    val mockPredictor = new MyPredictor("xyz", inputDescriptor)

    assert(mockPredictor.getBatchSize == 1)
    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))

    // Two inputs whose batch dimensions disagree (1 vs 2) must be rejected.
    val inputDescriptor2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2),
      layout = Layout.NCHW),
      new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW))

    assertThrows[IllegalArgumentException] {
      new MyPredictor("xyz", inputDescriptor2)
    }

    // The batch size defaults to 1 when the layout has no 'N' (batch) axis.
    // The original code passed `inputDescriptor` (explicit batch size 1) here, so the
    // default path was never exercised.
    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
    val p2 = new MyPredictor("xyz", iDesc2)
    assert(p2.getBatchSize == 1, "should use a default batch size of 1")
  }

  test("PredictorSuite-testWithFlatArrays") {
    // The descriptor declares batch size 2, but only one 3x2x2 sample (12 values) is fed in.
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW))
    val inputData = Array.fill[Float](12)(1)

    // This will be disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)
    stubPredict(testPredictor, predictResult)

    val testFun = testPredictor.predict(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1 ")

    assert(Array.fill[Float](12)(1).mkString == testFun(0).mkString)

    // Verify that the module was bound with batch size 1 and rebound back to the original
    // input descriptor. The number of times is twice here because loadModule overrides the
    // initial bind.
    verifyBoundTwice(testPredictor.mockModule)
  }

  test("PredictorSuite-testWithFlatFloat64Arrays") {
    // Same as testWithFlatArrays, but with Float64 data end to end.
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW, dtype = DType.Float64))
    val inputData = Array.fill[Double](12)(1d)

    // This will be disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)
    stubPredict(testPredictor, predictResult)

    val testFun = testPredictor.predict(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1 ")

    // The element type of the output must be Double, not Float.
    assert(testFun(0)(0).getClass == 1d.getClass)

    assert(Array.fill[Double](12)(1d).mkString == testFun(0).mkString)

    // Verify that the module was bound with batch size 1 and rebound back to the original
    // input descriptor. The number of times is twice here because loadModule overrides the
    // initial bind.
    verifyBoundTwice(testPredictor.mockModule)
  }

  test("PredictorSuite-testWithNDArray") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW))
    val inputData = NDArray.ones(Shape(1, 3, 2, 2))

    // This will be disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)
    stubPredict(testPredictor, predictResult)

    val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1")

    assert(Array.fill[Float](12)(1).mkString == testFun(0).toArray.mkString)

    // As in the flat-array tests: one bind for the batch-size-1 input, plus the rebind to
    // the original descriptor.
    verifyBoundTwice(testPredictor.mockModule)
  }
}