/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnetexamples.neuralstyle.end2end

import java.io.File

import com.sksamuel.scrimage.{Image, Pixel}
import com.sksamuel.scrimage.filter.GaussianBlurFilter
import com.sksamuel.scrimage.nio.JpegWriter
import org.apache.mxnet.{Context, NDArray, Shape}


/**
 * Image <-> NDArray conversion helpers for the end-to-end neural style example.
 *
 * Images are stored in NDArrays of shape `(1, 3, height, width)` in planar
 * order: all red values, then all green values, then all blue values, each
 * with a fixed per-channel offset subtracted.
 */
object DataProcessing {

  // Per-channel offsets subtracted on the way in and added back on the way out.
  // NOTE(review): these look like the ImageNet RGB means used by VGG-style
  // networks -- confirm against the model this example loads.
  private val MeanR = 123.68f
  private val MeanG = 116.779f
  private val MeanB = 103.939f

  /**
   * Flattens an image into planar R, G, B float data with the channel
   * offsets subtracted.
   *
   * @param img source image
   * @return array of length `3 * pixelCount`: red plane, green plane, blue plane,
   *         each in the order produced by `img.iterator`
   */
  private def toMeanCenteredPlanar(img: Image): Array[Float] = {
    val pixels = img.iterator.toArray
    val r = pixels.map(_.red - MeanR)
    val g = pixels.map(_.green - MeanG)
    val b = pixels.map(_.blue - MeanB)
    r ++ g ++ b
  }

  /**
   * Loads the content image from disk and converts it to a mean-centered
   * NDArray of shape `(1, 3, height, width)`.
   *
   * @param path   path of the image file to read
   * @param dShape target shape; index 2 is used as the height and index 3 as
   *               the width to scale to. When `null` (the default) the image
   *               is kept at its original size instead of being rescaled.
   * @param ctx    context on which the NDArray is allocated
   * @return a newly allocated NDArray holding the planar, mean-centered pixels
   */
  def preprocessContentImage(path: String,
      dShape: Shape = null, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    // The signature advertises null as a valid default, so it must not be
    // dereferenced; without a target shape the native resolution is used.
    val resizedImg = Option(dShape)
      .map(s => img.scaleTo(s(3), s(2)))
      .getOrElse(img)
    val sample = NDArray.empty(Shape(1, 3, resizedImg.height, resizedImg.width), ctx)
    sample.set(toMeanCenteredPlanar(resizedImg))
    sample
  }

  /**
   * Loads the style image from disk, scales it to the requested size and
   * converts it to a mean-centered NDArray of shape
   * `(1, 3, shape(2), shape(3))`.
   *
   * @param path  path of the image file to read
   * @param shape target shape; index 2 is the height and index 3 the width
   * @param ctx   context on which the NDArray is allocated
   * @return a newly allocated NDArray holding the planar, mean-centered pixels
   */
  def preprocessStyleImage(path: String, shape: Shape, ctx: Context): NDArray = {
    val img = Image.fromFile(new File(path))
    val resizedImg = img.scaleTo(shape(3), shape(2))
    val sample = NDArray.empty(Shape(1, 3, shape(2), shape(3)), ctx)
    sample.set(toMeanCenteredPlanar(resizedImg))
    sample
  }

  /**
   * Clamps every value into the range `[0, 255]`.
   *
   * @param array values to clamp
   * @return a new array of the same length; the input is not modified
   */
  def clip(array: Array[Float]): Array[Float] = array.map { a =>
    if (a < 0) 0f
    else if (a > 255) 255f
    else a
  }

  /**
   * Converts a mean-centered planar NDArray back into an image.
   *
   * The channel offsets are added back, values are clamped to `[0, 255]` and
   * truncated to integers. Only the first `3 * height * width` values of the
   * array are read, i.e. the first image of the batch.
   *
   * @param img NDArray whose dimension 2 is the height and dimension 3 the width
   * @return an opaque (alpha 255) image of size `img.shape(3)` x `img.shape(2)`
   */
  def postprocessImage(img: NDArray): Image = {
    val datas = img.toArray
    val spatialSize = img.shape(2) * img.shape(3)
    // All three planes are sliced from the front of the buffer so they always
    // belong to the same image, even if the batch dimension is larger than 1.
    val r = clip(datas.slice(0, spatialSize).map(_ + MeanR))
    val g = clip(datas.slice(spatialSize, 2 * spatialSize).map(_ + MeanG))
    val b = clip(datas.slice(2 * spatialSize, 3 * spatialSize).map(_ + MeanB))
    val pixels = Array.tabulate(spatialSize) { i =>
      Pixel(r(i).toInt, g(i).toInt, b(i).toInt, 255)
    }
    Image(img.shape(3), img.shape(2), pixels)
  }

  /**
   * Converts the NDArray to an image, applies a Gaussian blur and writes the
   * result to disk as a JPEG.
   *
   * @param img      NDArray to convert, as accepted by [[postprocessImage]]
   * @param filename path of the JPEG file to write
   * @param radius   radius passed to the Gaussian blur filter
   */
  def saveImage(img: NDArray, filename: String, radius: Int): Unit = {
    val out = postprocessImage(img)
    val gauss = GaussianBlurFilter(radius).op
    // The filter writes into a separate destination image of the same size.
    val result = Image(out.width, out.height)
    gauss.filter(out.awt, result.awt)
    result.output(filename)(JpegWriter())
  }
}