1/*
2 * Licensed to the Apache Software Foundation (ASF) under one or more
3 * contributor license agreements.  See the NOTICE file distributed with
4 * this work for additional information regarding copyright ownership.
5 * The ASF licenses this file to You under the Apache License, Version 2.0
6 * (the "License"); you may not use this file except in compliance with
7 * the License.  You may obtain a copy of the License at
8 *
9 *    http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 */
17
18package org.apache.mxnetexamples.neuralstyle.end2end
19
20import java.io.File
21
22import com.sksamuel.scrimage.{Image, Pixel}
23import com.sksamuel.scrimage.filter.GaussianBlurFilter
24import com.sksamuel.scrimage.nio.JpegWriter
25import org.apache.mxnet.{Context, NDArray, Shape}
26
27
/**
 * Image <-> NDArray conversion helpers for the end-to-end neural style example.
 *
 * Tensors produced and consumed here are indexed as (batch, channel, height, width):
 * dimension 2 is used as the image height and dimension 3 as the image width.
 */
object DataProcessing {

  // Per-channel offsets subtracted when loading an image and added back when
  // converting a tensor to an image.
  // NOTE(review): presumably the ImageNet/VGG channel means - confirm against the model used.
  private val MeanR = 123.68f
  private val MeanG = 116.779f
  private val MeanB = 103.939f

  /**
   * Converts an image into a mean-centred tensor of shape (1, 3, height, width).
   *
   * The data is laid out channel by channel: all red values, then all green values,
   * then all blue values, each in the order produced by the image's pixel iterator.
   *
   * @param img image to convert
   * @param ctx device context on which the NDArray is allocated
   * @return the newly allocated NDArray holding the mean-centred pixel data
   */
  private def toMeanCenteredNDArray(img: Image, ctx: Context): NDArray = {
    val pixels = img.iterator.toArray
    val reds = pixels.map(_.red - MeanR)
    val greens = pixels.map(_.green - MeanG)
    val blues = pixels.map(_.blue - MeanB)
    // Allocate only after the pixel data is ready so a failure above does not
    // leave an unused NDArray behind.
    val sample = NDArray.empty(Shape(1, 3, img.height, img.width), ctx)
    sample.set(reds ++ greens ++ blues)
    sample
  }

  /**
   * Loads the content image from disk and converts it to a mean-centred tensor.
   *
   * @param path   path of the image file to load
   * @param dShape target shape; dimension 3 is used as the width and dimension 2 as
   *               the height to scale to. When `null` (the default) the image keeps
   *               its native size instead of failing with a NullPointerException.
   * @param ctx    device context on which the NDArray is allocated
   * @return an NDArray of shape (1, 3, height, width)
   */
  def preprocessContentImage(path: String,
      dShape: Shape = null, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    // Option(...) maps the null default to None, in which case no scaling is applied.
    val resizedImg = Option(dShape).fold(img)(s => img.scaleTo(s(3), s(2)))
    toMeanCenteredNDArray(resizedImg, ctx)
  }

  /**
   * Loads the style image from disk, scales it to the requested size and converts
   * it to a mean-centred tensor.
   *
   * @param path  path of the image file to load
   * @param shape target shape; dimension 3 is used as the width and dimension 2 as
   *              the height to scale to
   * @param ctx   device context on which the NDArray is allocated
   * @return an NDArray of shape (1, 3, shape(2), shape(3))
   */
  def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    val resizedImg = img.scaleTo(shape(3), shape(2))
    toMeanCenteredNDArray(resizedImg, ctx)
  }

  /**
   * Clamps every value of the array into the range [0, 255].
   *
   * @param array values to clamp
   * @return a new array of the same length; the input is not modified
   */
  def clip(array: Array[Float]): Array[Float] = array.map { a =>
    if (a < 0) 0f
    else if (a > 255) 255f
    else a
  }

  /**
   * Converts a mean-centred tensor back into an image.
   *
   * The three channel planes are read from the start of the flattened data, the
   * per-channel offsets are added back, values are clamped to [0, 255] and truncated
   * to integers. Every pixel is fully opaque (alpha 255).
   *
   * @param img tensor whose dimension 2 is the height and dimension 3 the width;
   *            only the first three planes of the flattened data are used
   * @return an image of size img.shape(3) x img.shape(2)
   */
  def postprocessImage(img: NDArray): Image = {
    val shape = img.shape
    val height = shape(2)
    val width = shape(3)
    val spatialSize = height * width
    val datas = img.toArray

    // All three planes are sliced relative to the start of the data so they always
    // belong to the same sample (taking blue from the tail would mix samples when
    // the tensor holds more than one image).
    def channel(index: Int, mean: Float): Array[Float] =
      clip(datas.slice(index * spatialSize, (index + 1) * spatialSize).map(_ + mean))

    val r = channel(0, MeanR)
    val g = channel(1, MeanG)
    val b = channel(2, MeanB)
    val pixels = Array.tabulate(spatialSize) { i =>
      Pixel(r(i).toInt, g(i).toInt, b(i).toInt, 255)
    }
    Image(width, height, pixels)
  }

  /**
   * Converts the tensor to an image, applies a Gaussian blur and writes the result
   * to disk as a JPEG.
   *
   * @param img      tensor to convert, see [[postprocessImage]]
   * @param filename path of the file to write
   * @param radius   radius passed to the Gaussian blur filter
   */
  def saveImage(img: NDArray, filename: String, radius: Int): Unit = {
    val out = postprocessImage(img)
    val gauss = GaussianBlurFilter(radius).op
    // The filter reads from `out` and writes into a blank image of the same size.
    val result = Image(out.width, out.height)
    gauss.filter(out.awt, result.awt)
    result.output(filename)(JpegWriter())
  }
}
90