1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnet.javaapi
19// scalastyle:off
20import java.awt.image.BufferedImage
21// scalastyle:on
22import java.io.InputStream
23import scala.collection.JavaConverters._
24
/**
  * Java-friendly facade over [[org.apache.mxnet.Image]].
  *
  * Every method forwards to the Scala implementation, translating Java types
  * (boxed integers, `java.util` collections, nullable references) into the
  * Scala types (`Option`, `Array`, `Map`) that the underlying API expects.
  */
object Image {
  /**
    * Decode image with OpenCV.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param buf   Buffer containing binary encoded image
    * @param flag  Convert decoded image to grayscale (0) or color (1).
    * @param toRGB Whether to convert decoded image
    *              to mxnet's default RGB format (instead of opencv's default BGR).
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(buf: Array[Byte], flag: Int, toRGB: Boolean): NDArray = {
    org.apache.mxnet.Image.imDecode(buf, flag, toRGB, None)
  }

  /**
    * Decode image with OpenCV, as a color image (flag = 1) converted to RGB.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param buf   Buffer containing binary encoded image
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(buf: Array[Byte]): NDArray = {
    imDecode(buf, flag = 1, toRGB = true)
  }

  /**
    * Same as imDecode on a byte buffer, but reads the encoded image from an InputStream.
    *
    * @param inputStream the inputStream of the image
    * @param flag        Convert decoded image to grayscale (0) or color (1).
    * @param toRGB       Whether to convert decoded image
    *                    to mxnet's default RGB format (instead of opencv's default BGR).
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(inputStream: InputStream, flag: Int, toRGB: Boolean): NDArray = {
    org.apache.mxnet.Image.imDecode(inputStream, flag, toRGB, None)
  }

  /**
    * Same as imDecode on a byte buffer, but reads the encoded image from an InputStream.
    * Decodes as a color image (flag = 1) converted to RGB.
    *
    * @param inputStream the inputStream of the image
    * @return NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imDecode(inputStream: InputStream): NDArray = {
    imDecode(inputStream, flag = 1, toRGB = true)
  }

  /**
    * Read and decode image with OpenCV.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param filename Name of the image file to be loaded.
    * @param flag     Convert decoded image to grayscale (0) or color (1).
    * @param toRGB    Whether to convert decoded image to mxnet's default RGB format
    *                 (instead of opencv's default BGR).
    * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imRead(filename: String, flag: Int, toRGB: Boolean): NDArray = {
    org.apache.mxnet.Image.imRead(filename, Some(flag), Some(toRGB), None)
  }

  /**
    * Read and decode image with OpenCV, as a color image (flag = 1) converted to RGB.
    * Note: return image in RGB by default, instead of OpenCV's default BGR.
    * @param filename Name of the image file to be loaded.
    * @return org.apache.mxnet.NDArray in HWC format with DType [[DType.UInt8]]
    */
  def imRead(filename: String): NDArray = {
    imRead(filename, flag = 1, toRGB = true)
  }

  /**
    * Resize image with OpenCV.
    * @param src    source image in NDArray
    * @param w      Width of resized image.
    * @param h      Height of resized image.
    * @param interp Interpolation method (default=cv2.INTER_LINEAR).
    *               May be null, in which case no interpolation value is
    *               forwarded and the underlying implementation picks its default.
    * @return org.apache.mxnet.NDArray
    */
  def imResize(src: NDArray, w: Int, h: Int, interp: Integer): NDArray = {
    // Option(...) maps a null boxed Integer to None; otherwise unbox it.
    val interpVal = Option(interp).map(_.intValue())
    org.apache.mxnet.Image.imResize(src, w, h, interpVal, None)
  }

  /**
    * Resize image with OpenCV, without specifying an interpolation method.
    * @param src    source image in NDArray
    * @param w      Width of resized image.
    * @param h      Height of resized image.
    * @return org.apache.mxnet.NDArray
    */
  def imResize(src: NDArray, w: Int, h: Int): NDArray = {
    // A null interp is translated to None by the four-argument overload.
    imResize(src, w, h, null)
  }

  /**
    * Do a fixed crop on the image
    * @param src Src image in NDArray
    * @param x0  starting x point
    * @param y0  starting y point
    * @param w   width of the image
    * @param h   height of the image
    * @return cropped NDArray
    */
  def fixedCrop(src: NDArray, x0: Int, y0: Int, w: Int, h: Int): NDArray = {
    org.apache.mxnet.Image.fixedCrop(src, x0, y0, w, h)
  }

  /**
    * Convert a NDArray image to a real image
    * The time cost will increase if the image resolution is big
    * @param src Source image file in RGB
    * @return Buffered Image
    */
  def toImage(src: NDArray): BufferedImage = {
    org.apache.mxnet.Image.toImage(src)
  }

  /**
    * Draw bounding boxes on the image
    * @param src        buffered image to draw on
    * @param coordinate Contains Map of xmin, xmax, ymin, ymax
    *                   corresponding to top-left and down-right points
    * @param names      The name set of the bounding box. May be null, in which
    *                   case no names are forwarded to the underlying call.
    */
  def drawBoundingBox(src: BufferedImage,
                      coordinate: java.util.List[
                        java.util.Map[java.lang.String, java.lang.Integer]],
                      names: java.util.List[java.lang.String]): Unit = {
    // Convert each Java map of boxed integers into an immutable Scala Map[String, Int].
    val coord = coordinate.asScala.map { box =>
      box.asScala.map { case (key, value) => (key, value.intValue()) }.toMap
    }.toArray
    // Guard the null check on `names` itself: converting first would throw a
    // NullPointerException before Option(...) ever had a chance to yield None.
    val labels: Option[Array[String]] = Option(names).map(_.asScala.toArray)
    org.apache.mxnet.Image.drawBoundingBox(src, coord, labels)
  }

}
158