1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet.infer
19
20import org.apache.mxnet.io.NDArrayIter
21import org.apache.mxnet.module.{BaseModule, Module}
22import org.apache.mxnet._
23import org.mockito.Matchers._
24import org.mockito.Mockito
25import org.scalatest.{BeforeAndAfterAll, FunSuite}
26
class PredictorSuite extends FunSuite with BeforeAndAfterAll {

  /**
   * Test double for [[Predictor]] whose `loadModule()` returns a Mockito mock instead of loading
   * a model from disk, and which exposes the batch bookkeeping computed by the Predictor
   * constructor.
   *
   * @param modelPathPrefix  model path prefix passed through to [[Predictor]]
   * @param inputDescriptors input descriptors passed through to [[Predictor]]
   */
  class MyPredictor(val modelPathPrefix: String,
                    override val inputDescriptors: IndexedSeq[DataDesc])
    extends Predictor(modelPathPrefix, inputDescriptors, epoch = Some(0)) {

    override def loadModule(): Module = mockModule

    // Subclass vals are initialised after the Predictor constructor has run, so these capture
    // the values the constructor settled on.
    val getIDescriptor: IndexedSeq[DataDesc] = iDescriptors
    val getBatchSize: Int = batchSize
    val getBatchIndex: Int = batchIndex

    // Lazy, presumably so the mock is usable if loadModule() is invoked from the Predictor
    // constructor before this subclass's strict vals are initialised — confirm in Predictor.
    lazy val mockModule: Module = Mockito.mock(classOf[Module])
  }

  /** Stubs `module.predict` (for any iterator and arguments) to return `result`. */
  private def stubPredict(module: Module, result: IndexedSeq[NDArray]): Unit = {
    Mockito.doReturn(result).when(module)
      .predict(any(classOf[NDArrayIter]), any[Int], any[Boolean])
  }

  /** Verifies that `module.bind` was called exactly `times` times, with any arguments. */
  private def verifyBindCount(module: Module, times: Int): Unit = {
    Mockito.verify(module, Mockito.times(times)).bind(any[IndexedSeq[DataDesc]],
      any[Option[IndexedSeq[DataDesc]]], any[Boolean], any[Boolean], any[Boolean],
      any[Option[BaseModule]], any[String])
  }

  test("PredictorSuite-testPredictorConstruction") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(1, 3, 2, 2),
      layout = Layout.NCHW))

    val mockPredictor = new MyPredictor("xyz", inputDescriptor)

    assert(mockPredictor.getBatchSize == 1)
    assert(mockPredictor.getBatchIndex == inputDescriptor(0).layout.indexOf('N'))

    // Inputs whose batch dimensions disagree (1 vs 2) must be rejected.
    val inputDescriptor2 = IndexedSeq[DataDesc](
      new DataDesc("data", Shape(1, 3, 2, 2), layout = Layout.NCHW),
      new DataDesc("data", Shape(2, 3, 2, 2), layout = Layout.NCHW))

    assertThrows[IllegalArgumentException] {
      new MyPredictor("xyz", inputDescriptor2)
    }

    // A layout without an 'N' axis has no batch dimension, so the batch size defaults to 1.
    val iDesc2 = IndexedSeq[DataDesc](new DataDesc("data", Shape(3, 2, 2), layout = "CHW"))
    val p2 = new MyPredictor("xyz", iDesc2)
    assert(p2.getBatchSize == 1, "should use a default batch size of 1")
  }

  test("PredictorSuite-testWithFlatArrays") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW))
    val inputData = Array.fill[Float](12)(1)

    // This NDArray is disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)
    stubPredict(testPredictor.mockModule, predictResult)

    val testFun = testPredictor.predict(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1 ")
    assert(testFun(0).toSeq == Seq.fill[Float](12)(1f))

    // Verify that the module was bound with batch size 1 and then rebound to the original
    // input descriptor. The count is two because loadModule() is overridden and performs no
    // initial bind.
    verifyBindCount(testPredictor.mockModule, 2)
  }

  test("PredictorSuite-testWithFlatFloat64Arrays") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW, dtype = DType.Float64))
    val inputData = Array.fill[Double](12)(1d)

    // This NDArray is disposed at the end of the predict call on Predictor.
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2), dtype = DType.Float64))

    val testPredictor = new MyPredictor("xyz", inputDescriptor)
    stubPredict(testPredictor.mockModule, predictResult)

    val testFun = testPredictor.predict(IndexedSeq(inputData))

    assert(testFun.size == 1, "output size should be 1 ")

    // Elements must come back as Double, not Float.
    assert(testFun(0)(0).getClass == 1d.getClass)

    assert(testFun(0).toSeq == Seq.fill[Double](12)(1d))

    // Verify that the module was bound with batch size 1 and then rebound to the original
    // input descriptor. The count is two because loadModule() is overridden and performs no
    // initial bind.
    verifyBindCount(testPredictor.mockModule, 2)
  }

  test("PredictorSuite-testWithNDArray") {
    val inputDescriptor = IndexedSeq[DataDesc](new DataDesc("data", Shape(2, 3, 2, 2),
      layout = Layout.NCHW))
    val inputData = NDArray.ones(Shape(1, 3, 2, 2))
    val predictResult = IndexedSeq(NDArray.ones(Shape(1, 3, 2, 2)))

    try {
      val testPredictor = new MyPredictor("xyz", inputDescriptor)
      stubPredict(testPredictor.mockModule, predictResult)

      val testFun = testPredictor.predictWithNDArray(IndexedSeq(inputData))

      assert(testFun.size == 1, "output size should be 1")
      assert(testFun(0).toArray.toSeq == Seq.fill[Float](12)(1f))

      // Rebound for batch size 1, then back to the original descriptor.
      verifyBindCount(testPredictor.mockModule, 2)
    } finally {
      // Release the native memory of the NDArrays this test allocated. testFun(0) is expected
      // to be the mocked predictResult(0), so it is covered here as well.
      inputData.dispose()
      predictResult.foreach(_.dispose())
    }
  }
}
148