/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet.spark.transformer

import java.util.UUID

import org.apache.mxnet.spark.{MXNetModel, MXNetParams}
import org.apache.mxnet.{Context, Shape, Symbol}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.slf4j.{Logger, LoggerFactory}

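/**
  * Spark ML [[Predictor]] that trains an MXNet network on a DataFrame by
  * delegating to the RDD-based [[org.apache.mxnet.spark.MXNet]] trainer.
  * Training parameters are collected through the setters below; `train`
  * converts the features/label columns into `LabeledPoint`s, runs the
  * distributed fit and wraps the resulting [[MXNetModel]] in an
  * [[MXNetModelWrap]].
  */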
class MXNet extends Predictor[Vector, MXNet, MXNetModelWrap] {

  private val logger: Logger = LoggerFactory.getLogger(classOf[MXNet])
  private val p: MXNetParams = new MXNetParams
  private var _featuresCol: String = _
  private var _labelCol: String = _

  override val uid = UUID.randomUUID().toString

  override def train(dataset: DataFrame): MXNetModelWrap = {
    // Convert the selected feature/label columns into the RDD[LabeledPoint]
    // form expected by the RDD-based MXNet trainer.
    val lps = dataset.select(getFeaturesCol, getLabelCol).rdd
      .map(row => new LabeledPoint(row.getAs[Double](getLabelCol),
        row.getAs[Vector](getFeaturesCol)))
    // Copy the collected parameters onto the underlying trainer and fit.
    val mxNet = new org.apache.mxnet.spark.MXNet()
      .setBatchSize(p.batchSize)
      .setLabelName(p.labelName)
      .setContext(p.context)
      .setDimension(p.dimension)
      .setNetwork(p.getNetwork)
      .setNumEpoch(p.numEpoch)
      .setNumServer(p.numServer)
      .setNumWorker(p.numWorker)
      .setExecutorJars(p.jars.mkString(","))
    val fitted = mxNet.fit(lps)
    new MXNetModelWrap(lps.sparkContext, fitted, uid)
  }

  override def copy(extra: ParamMap): MXNet = defaultCopy(extra)

  def setBatchSize(batchSize: Int): this.type = {
    p.batchSize = batchSize
    this
  }

  def setNumEpoch(numEpoch: Int): this.type = {
    p.numEpoch = numEpoch
    this
  }

  def setDimension(dimension: Shape): this.type = {
    p.dimension = dimension
    this
  }

  def setNetwork(network: Symbol): this.type = {
    p.setNetwork(network)
    this
  }

  def setContext(ctx: Array[Context]): this.type = {
    p.context = ctx
    this
  }

  def setNumWorker(numWorker: Int): this.type = {
    p.numWorker = numWorker
    this
  }

  def setNumServer(numServer: Int): this.type = {
    p.numServer = numServer
    this
  }

  def setDataName(name: String): this.type = {
    p.dataName = name
    this
  }

  def setLabelName(name: String): this.type = {
    p.labelName = name
    this
  }

  /**
    * The application (including the parameter scheduler and servers)
    * will exit if it has not received a heartbeat for more than `timeout` seconds.
    * @param timeout timeout in seconds (default 300)
    */
  def setTimeout(timeout: Int): this.type = {
    p.timeout = timeout
    this
  }

  /**
    * Jars required by the KVStore at runtime. They will be uploaded to and
    * distributed across each node automatically.
    * @param jars comma- or colon-separated list of jar paths required by the
    *             KVStore at runtime
    */
  def setExecutorJars(jars: String): this.type = {
    p.jars = jars.split(",|:")
    this
  }

  def setJava(java: String): this.type = {
    p.javabin = java
    this
  }

}

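/**
  * Spark ML [[PredictionModel]] wrapping a fitted [[MXNetModel]].
  * `predict` returns the single output value when the network emits one
  * output and the index of the largest output (argmax) otherwise.
  * Persistence is delegated to `MXNetModel.save`/`MXNetModel.load` through
  * the `MLWritable`/`MLReadable` helpers below.
  */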
class MXNetModelWrap(sc: SparkContext, mxNet: MXNetModel, uuid: String)
  extends PredictionModel[Vector, MXNetModelWrap] with Serializable with MLWritable {

  override def copy(extra: ParamMap): MXNetModelWrap = {
    copyValues(new MXNetModelWrap(sc, mxNet, uuid)).setParent(parent)
  }

  override val uid: String = uuid

  override def predict(features: Vector): Double = {
    val probArrays = mxNet.predict(features)
    val prob = probArrays(0)
    val arr = prob.get.toArray
    if (arr.length == 1) {
      // Single output: return it directly (regression-style score).
      arr(0)
    } else {
      // Multiple outputs: return the index of the largest one (class index).
      arr.indexOf(arr.max)
    }
  }

  /** Writes the model by delegating to the underlying MXNetModel. */
  protected[MXNetModelWrap] class MXNetModelWriter(instance: MXNetModelWrap) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      mxNet.save(sc, path)
    }
  }

  override def write: MLWriter = new MXNetModelWriter(this)

  /** Reads back the underlying MXNetModel saved by [[MXNetModelWriter]]. */
  object MXNetModelWrap extends MLReadable[MXNetModel] {
    override def read: MLReader[MXNetModel] = new MXNetModelReader
    override def load(path: String): MXNetModel = super.load(path)
    private class MXNetModelReader extends MLReader[MXNetModel] {
      override def load(path: String): MXNetModel = MXNetModel.load(sc, path)
    }
  }

}
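
// Minimal usage sketch, not part of the original transformer: it assumes a
// training DataFrame `trainDf` with a "features" column of mllib Vectors and a
// Double "label" column, a network Symbol built elsewhere, and the path of the
// MXNet-on-Spark assembly jar to ship to executors. The dimension, batch size,
// epoch count and label name below are illustrative values only.
object MXNetTransformerExample {
  def fit(trainDf: DataFrame, network: Symbol, mxnetJar: String): MXNetModelWrap = {
    val estimator = new MXNet()
      .setNetwork(network)               // network Symbol, built by the caller
      .setDimension(Shape(784))          // input feature dimension (example value)
      .setContext(Array(Context.cpu()))  // train on CPU contexts
      .setBatchSize(128)
      .setNumEpoch(10)
      .setNumServer(1)
      .setNumWorker(2)
      .setLabelName("softmax_label")     // MXNet's conventional label name
      .setExecutorJars(mxnetJar)         // jars shipped to each executor
    estimator
      .setFeaturesCol("features")
      .setLabelCol("label")
      .fit(trainDf)                      // returns a fitted MXNetModelWrap
  }
}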