1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet
19
20import java.io.File
21import java.net.URL
22
23import javax.imageio.ImageIO
24import org.apache.commons.io.FileUtils
25import org.scalatest.{BeforeAndAfterAll, FunSuite}
26import org.slf4j.LoggerFactory
27
/**
  * Tests for the [[Image]] helpers: reading from disk, decoding from a stream,
  * resizing, cropping, adding a border, converting to a `BufferedImage` and
  * drawing bounding boxes.
  *
  * The suite downloads a sample image into `java.io.tmpdir/inputImages` before any
  * test runs, so it needs network access on the first run.
  */
class ImageSuite extends FunSuite with BeforeAndAfterAll {
  import scala.util.control.NonFatal

  private val logger = LoggerFactory.getLogger(classOf[ImageSuite])

  /** Remote location of the sample image used by every test. */
  private val imageUrl = "https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg"

  /** Directory (under the JVM temp dir) holding the downloaded input and generated outputs. */
  private val imageDir = System.getProperty("java.io.tmpdir") + "/inputImages"

  /** Connection and read timeouts for the download, so a stalled network cannot hang the suite. */
  private val connectTimeoutMs = 30000
  private val readTimeoutMs = 30000

  /** Expected shape of the sample image as asserted by the load tests. */
  private val originalShape = Shape(576, 1024, 3)

  /** Local path of the sample image; assigned in `beforeAll`. */
  private var imLocation = ""

  /**
    * Downloads `url` to `filePath` unless the file already exists.
    *
    * @param url      remote resource to fetch
    * @param filePath destination path on the local file system
    * @param maxRetry maximum number of attempts; defaults to 3 when `None`
    * @throws Exception when every attempt failed; the last failure is attached as the cause
    */
  private def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None) : Unit = {
    val target = new File(filePath)
    if (!target.exists()) {
      var remaining = maxRetry.getOrElse(3)
      var success = false
      var lastError: Option[Throwable] = None
      while (remaining > 0 && !success) {
        try {
          FileUtils.copyURLToFile(new URL(url), target, connectTimeoutMs, readTimeoutMs)
          success = true
        } catch {
          case NonFatal(e) =>
            remaining -= 1
            lastError = Some(e)
            logger.warn(s"Download of $url failed, $remaining attempt(s) left", e)
            // Remove a possibly truncated file so neither a later attempt nor a later
            // run mistakes it for a completed download.
            target.delete()
        }
      }
      if (!success) throw new Exception(s"$url Download failed!", lastError.orNull)
    }
  }

  /** Fetches the sample image once for the whole suite. */
  override def beforeAll(): Unit = {
    imLocation = imageDir + "/Pug-Cookie.jpg"
    downloadUrl(imageUrl, imLocation)
  }

  test("Test load image") {
    val nd = Image.imRead(imLocation)
    logger.debug(s"OpenCV load image with shape: ${nd.shape}")
    assert(nd.shape == originalShape, "image shape not Match!")
  }

  test("Test load image from Socket") {
    val inputStream = new URL(imageUrl).openStream
    try {
      val nd = Image.imDecode(inputStream)
      logger.debug(s"OpenCV load image with shape: ${nd.shape}")
      assert(nd.shape == originalShape, "image shape not Match!")
    } finally {
      // Always release the network stream, even when decoding or the assertion fails.
      inputStream.close()
    }
  }

  test("Test resize image") {
    val nd = Image.imRead(imLocation)
    val resizeIm = Image.imResize(nd, 224, 224)
    logger.debug(s"OpenCV resize image with shape: ${resizeIm.shape}")
    assert(resizeIm.shape == Shape(224, 224, 3), "image shape not Match!")
  }

  test("Test crop image") {
    val nd = Image.imRead(imLocation)
    val cropped = Image.fixedCrop(nd, 0, 0, 224, 224)
    assert(cropped.shape == Shape(224, 224, 3), "image shape not Match!")
  }

  test("Test apply border") {
    val nd = Image.imRead(imLocation)
    // A one-pixel border on every side grows each spatial dimension by two.
    val bordered = Image.copyMakeBorder(nd, 1, 1, 1, 1)
    assert(bordered.shape == Shape(578, 1026, 3), "image shape not Match!")
  }

  test("Test convert to Image") {
    val nd = Image.imRead(imLocation)
    val resizeIm = Image.imResize(nd, 224, 224)
    val img = Image.toImage(resizeIm)
    val outPath = imageDir + "/out.png"
    ImageIO.write(img, "png", new File(outPath))
    logger.debug(s"converted image stored in $outPath")
  }

  test("Test draw Bounding box") {
    val buf = ImageIO.read(new File(imLocation))
    val box = Array(
      Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
      Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
    )
    val names = Array("pug", "cookie")
    Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
    val outPath = imageDir + "/out2.png"
    ImageIO.write(buf, "png", new File(outPath))
    logger.debug(s"converted image stored in $outPath")
    // Three corners of every box must share one colour after drawing.
    // NOTE(review): the top-left corner is deliberately not compared; presumably the
    // label text drawn there can overwrite that pixel - confirm against drawBoundingBox.
    for (coord <- box) {
      val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
      val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
      val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
      assert(downLeft == downRight)
      assert(topRight == downRight)
    }
  }

}
122