/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements. See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License. You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.mxnet.spark.transformer

import java.util.UUID

import org.apache.mxnet.spark.{MXNetModel, MXNetParams}
import org.apache.mxnet.{Context, Shape, Symbol}
import org.apache.spark.SparkContext
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.util.{MLReadable, MLReader, MLWritable, MLWriter}
import org.apache.spark.ml.{PredictionModel, Predictor}
import org.apache.spark.mllib.linalg.Vector
import org.apache.spark.mllib.regression.LabeledPoint
import org.apache.spark.sql.DataFrame
import org.slf4j.{Logger, LoggerFactory}


/**
 * Spark ML `Predictor` that gathers MXNet training settings in an [[MXNetParams]] holder
 * and, when trained, forwards them to `org.apache.mxnet.spark.MXNet` to fit a model.
 *
 * `uid` is a constructor parameter (rather than a `val` assigned in the class body) so that
 * it already has its value while the `Predictor` superclass is being initialized, and so
 * that `defaultCopy` can find a constructor taking a single `String`.
 *
 * @param uid unique identifier of this estimator
 */
class MXNet(override val uid: String) extends Predictor[Vector, MXNet, MXNetModelWrap] {

  /** Creates an estimator with a random UUID as its `uid`. */
  def this() = this(UUID.randomUUID().toString)

  private val logger: Logger = LoggerFactory.getLogger(classOf[MXNet])
  private val p: MXNetParams = new MXNetParams

  /**
   * Converts the features/label columns of `dataset` into `LabeledPoint`s, fits an
   * `org.apache.mxnet.spark.MXNet` configured from the settings stored on this estimator,
   * and wraps the fitted model.
   *
   * @param dataset input data containing the configured features and label columns
   * @return the fitted model wrapped as a Spark ML `PredictionModel`
   */
  override def train(dataset: DataFrame): MXNetModelWrap = {
    // Resolve the column names up front so the closure below captures two strings
    // instead of this estimator.
    val featuresColName = getFeaturesCol
    val labelColName = getLabelCol
    val labeledPoints = dataset.select(featuresColName, labelColName).rdd
      .map(row => new LabeledPoint(row.getAs[Double](labelColName),
        row.getAs[Vector](featuresColName)))
    // Forward every setting exposed by the setters below, including the data name,
    // timeout and java binary.
    val mxNet = new org.apache.mxnet.spark.MXNet()
      .setBatchSize(p.batchSize)
      .setDataName(p.dataName)
      .setLabelName(p.labelName)
      .setContext(p.context)
      .setDimension(p.dimension)
      .setNetwork(p.getNetwork)
      .setNumEpoch(p.numEpoch)
      .setNumServer(p.numServer)
      .setNumWorker(p.numWorker)
      .setTimeout(p.timeout)
      .setJava(p.javabin)
      .setExecutorJars(p.jars.mkString(","))
    val fitted = mxNet.fit(labeledPoints)
    new MXNetModelWrap(labeledPoints.sparkContext, fitted, uid)
  }

  /**
   * Copies this estimator through `defaultCopy`.
   *
   * NOTE(review): the MXNet settings stored in `p` are not Spark `Param`s, so they do not
   * appear to be carried over to the copy; confirm whether callers rely on that.
   */
  override def copy(extra: ParamMap): MXNet = defaultCopy(extra)

  def setBatchSize(batchSize: Int): this.type = {
    p.batchSize = batchSize
    this
  }

  def setNumEpoch(numEpoch: Int): this.type = {
    p.numEpoch = numEpoch
    this
  }

  def setDimension(dimension: Shape): this.type = {
    p.dimension = dimension
    this
  }

  def setNetwork(network: Symbol): this.type = {
    p.setNetwork(network)
    this
  }

  def setContext(ctx: Array[Context]): this.type = {
    p.context = ctx
    this
  }

  def setNumWorker(numWorker: Int): this.type = {
    p.numWorker = numWorker
    this
  }

  def setNumServer(numServer: Int): this.type = {
    p.numServer = numServer
    this
  }

  def setDataName(name: String): this.type = {
    p.dataName = name
    this
  }

  def setLabelName(name: String): this.type = {
    p.labelName = name
    this
  }

  /**
   * The application (including parameter scheduler & servers)
   * will exit if it hasn't received heart beat for over timeout seconds
   * @param timeout timeout in seconds (default 300)
   */
  def setTimeout(timeout: Int): this.type = {
    p.timeout = timeout
    this
  }

  /**
   * These jars are required by the KVStores at runtime.
   * They will be uploaded and distributed to each node automatically
   * @param jars jars required by the KVStore at runtime, separated by `,` or `:`.
   */
  def setExecutorJars(jars: String): this.type = {
    p.jars = jars.split(",|:")
    this
  }

  def setJava(java: String): this.type = {
    p.javabin = java
    this
  }

}

/**
 * Spark ML `PredictionModel` backed by a fitted [[MXNetModel]].
 *
 * @param sc    Spark context used when saving or loading the model; marked transient so it
 *              is not serialized together with this model
 * @param mxNet the fitted MXNet model that produces predictions
 * @param uuid  unique identifier exposed as `uid`
 */
class MXNetModelWrap(
    @transient private val sc: SparkContext,
    mxNet: MXNetModel,
    uuid: String)
  extends PredictionModel[Vector, MXNetModelWrap] with Serializable with MLWritable {

  override def copy(extra: ParamMap): MXNetModelWrap = {
    copyValues(new MXNetModelWrap(sc, mxNet, uuid)).setParent(parent)
  }

  // Lazy so the value is available on first access even if that access happens while the
  // superclass is still being initialized, without renaming the constructor parameter.
  override lazy val uid: String = uuid

  /**
   * Predicts for a single feature vector using the first output of the wrapped model.
   *
   * @param features the feature vector to score
   * @return the sole output value when the output has exactly one element, otherwise the
   *         index of the largest output value
   */
  override def predict(features: Vector): Double = {
    val outputs = mxNet.predict(features)
    val scores = outputs(0).get.toArray
    if (scores.length == 1) {
      scores(0)
    } else {
      scores.indexOf(scores.max)
    }
  }

  /** Writer that saves the wrapped model to `path` through `MXNetModel.save`. */
  protected[MXNetModelWrap] class MXNetModelWriter(instance: MXNetModelWrap) extends MLWriter {
    override protected def saveImpl(path: String): Unit = {
      mxNet.save(sc, path)
    }
  }

  override def write: MLWriter = new MXNetModelWriter(this)

  /**
   * Reader for [[MXNetModel]] instances, loading through `MXNetModel.load` with this
   * wrapper's Spark context. This object is nested inside the class, so it is reached
   * through an existing `MXNetModelWrap` instance.
   */
  object MXNetModelWrap extends MLReadable[MXNetModel] {
    override def read: MLReader[MXNetModel] = new MXNetModelReader
    override def load(path: String): MXNetModel = super.load(path)
    private class MXNetModelReader extends MLReader[MXNetModel] {
      override def load(path: String): MXNetModel = MXNetModel.load(sc, path)
    }
  }

}