/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet

import java.io.File
import java.net.URL

import javax.imageio.ImageIO
import org.apache.commons.io.FileUtils
import org.scalatest.{BeforeAndAfterAll, FunSuite}
import org.slf4j.LoggerFactory

import scala.annotation.tailrec
import scala.util.control.NonFatal

/**
 * Tests for the [[Image]] helpers: decoding (from a file and from a network stream),
 * resizing, fixed cropping, border padding, conversion to an AWT image and
 * bounding-box drawing.
 *
 * The input image is downloaded once into `java.io.tmpdir/inputImages` in `beforeAll`,
 * so the suite needs network access on its first run. Output PNGs produced by the tests
 * are written to the same directory for manual inspection.
 */
class ImageSuite extends FunSuite with BeforeAndAfterAll {
  import ImageSuite._

  private val logger = LoggerFactory.getLogger(classOf[ImageSuite])

  /** Scratch directory holding the downloaded input image and the test outputs. */
  private val imageDir = new File(System.getProperty("java.io.tmpdir"), "inputImages")

  /** Local path of the test image; the file itself is fetched in `beforeAll`. */
  private val imLocation: String = new File(imageDir, "Pug-Cookie.jpg").getPath

  /**
   * Downloads `url` to `filePath` unless a file already exists at that path.
   *
   * Each failed attempt is logged and any partially written file is removed. Without
   * the removal, a later run would find the truncated file, skip the download and test
   * against a corrupt image.
   *
   * @param url      remote resource to fetch
   * @param filePath destination path; parent directories are created by commons-io
   * @param maxRetry number of attempts (default 3); a value <= 0 means no attempt is made
   * @throws Exception if every attempt fails; the last failure is attached as the cause
   */
  private def downloadUrl(url: String, filePath: String, maxRetry: Option[Int] = None): Unit = {
    val target = new File(filePath)

    @tailrec
    def attempt(remaining: Int, lastError: Option[Throwable]): Unit = {
      if (remaining <= 0) {
        throw new Exception(s"$url Download failed!", lastError.orNull)
      }
      // Compute the outcome first so the retry below stays in tail position
      // (a recursive call inside `catch` would not be).
      val failure: Option[Throwable] =
        try {
          FileUtils.copyURLToFile(new URL(url), target, ConnectTimeoutMs, ReadTimeoutMs)
          None
        } catch {
          case NonFatal(e) =>
            FileUtils.deleteQuietly(target)
            logger.warn(s"Download of $url failed (${remaining - 1} attempt(s) left)", e)
            Some(e)
        }
      failure match {
        case Some(e) => attempt(remaining - 1, Some(e))
        case None => ()
      }
    }

    if (!target.exists()) attempt(maxRetry.getOrElse(3), None)
  }

  override def beforeAll(): Unit = {
    downloadUrl(PugCookieUrl, imLocation)
  }

  test("Test load image") {
    val nd = Image.imRead(imLocation)
    logger.debug(s"OpenCV load image with shape: ${nd.shape}")
    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
  }

  test("Test load image from Socket") {
    val connection = new URL(PugCookieUrl).openConnection()
    connection.setConnectTimeout(ConnectTimeoutMs)
    connection.setReadTimeout(ReadTimeoutMs)
    val inputStream = connection.getInputStream
    // Close the network stream even if decoding throws.
    val nd = try Image.imDecode(inputStream) finally inputStream.close()
    logger.debug(s"OpenCV load image with shape: ${nd.shape}")
    assert(nd.shape == Shape(576, 1024, 3), "image shape not Match!")
  }

  test("Test resize image") {
    val nd = Image.imRead(imLocation)
    val resizeIm = Image.imResize(nd, 224, 224)
    logger.debug(s"OpenCV resize image with shape: ${resizeIm.shape}")
    assert(resizeIm.shape == Shape(224, 224, 3), "image shape not Match!")
  }

  test("Test crop image") {
    val nd = Image.imRead(imLocation)
    val nd2 = Image.fixedCrop(nd, 0, 0, 224, 224)
    assert(nd2.shape == Shape(224, 224, 3), "image shape not Match!")
  }

  test("Test apply border") {
    val nd = Image.imRead(imLocation)
    val nd2 = Image.copyMakeBorder(nd, 1, 1, 1, 1)
    assert(nd2.shape == Shape(578, 1026, 3), "image shape not Match!")
  }

  test("Test convert to Image") {
    val nd = Image.imRead(imLocation)
    val resizeIm = Image.imResize(nd, 224, 224)
    val img = Image.toImage(resizeIm)
    assert(img.getWidth == 224 && img.getHeight == 224, "image shape not Match!")
    val outFile = new File(imageDir, "out.png")
    // ImageIO.write returns false (rather than throwing) when no writer is found.
    assert(ImageIO.write(img, "png", outFile), "no ImageIO writer available for png")
    logger.debug(s"converted image stored in ${outFile.getPath}")
  }

  test("Test draw Bounding box") {
    // ImageIO.read returns null (not an exception) when it cannot decode the file.
    val buf = Option(ImageIO.read(new File(imLocation)))
      .getOrElse(fail(s"ImageIO could not decode $imLocation"))
    val box = Array(
      Map("xmin" -> 190, "xmax" -> 850, "ymin" -> 50, "ymax" -> 450),
      Map("xmin" -> 200, "xmax" -> 350, "ymin" -> 440, "ymax" -> 530)
    )
    val names = Array("pug", "cookie")
    Image.drawBoundingBox(buf, box, Some(names), fontSizeMult = Some(1.4f))
    val outFile = new File(imageDir, "out2.png")
    assert(ImageIO.write(buf, "png", outFile), "no ImageIO writer available for png")
    logger.debug(s"converted image stored in ${outFile.getPath}")
    // The test expects the drawn box corners to share one pixel value (a single outline
    // colour). The top-left corner is deliberately not compared, presumably because the
    // label is rendered near (xmin, ymin).
    // NOTE(review): confirm against Image.drawBoundingBox.
    for (coord <- box) {
      val downLeft = buf.getRGB(coord("xmin"), coord("ymax"))
      val topRight = buf.getRGB(coord("xmax"), coord("ymin"))
      val downRight = buf.getRGB(coord("xmax"), coord("ymax"))
      assert(downLeft == downRight, s"bottom corners differ for box $coord")
      assert(topRight == downRight, s"right corners differ for box $coord")
    }
  }

}

object ImageSuite {
  /** Remote location of the test image used by every test in the suite. */
  private val PugCookieUrl = "https://s3.amazonaws.com/model-server/inputs/Pug-Cookie.jpg"

  /** Connection timeout for network access, in milliseconds, so CI cannot hang forever. */
  private val ConnectTimeoutMs = 10000

  /** Read timeout for network access, in milliseconds. */
  private val ReadTimeoutMs = 30000
}