/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import java.io.IOException

import scala.collection.mutable

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel

/**
 * Used by the node id cache to find the child node index that a data point should move to.
 * @param split Split information.
 * @param nodeIndex The current node index of the data points that this will update.
 */
private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {

  /**
   * Determine a child node index based on the feature value and the split.
   * @param binnedFeature Binned feature value.
   * @param splits Split information to convert the bin indices to approximate feature values.
   * @return Child node index to update to.
   */
  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
    if (split.shouldGoLeft(binnedFeature, splits)) {
      LearningNode.leftChildIndex(nodeIndex)
    } else {
      LearningNode.rightChildIndex(nodeIndex)
    }
  }
}
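
// A minimal sketch of how NodeIndexUpdater routes a point (illustrative only; the
// `split`, `binnedFeature`, and `splits` values are assumed to come from the
// surrounding training code):
//
//   val updater = NodeIndexUpdater(split, nodeIndex = 1)
//   // With 1-based node indexing, node 1's children are 2 (left) and 3 (right), so this
//   // returns 2 if the point goes left under `split` and 3 otherwise.
//   val childIndex = updater.updateNodeIndex(binnedFeature, splits)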

/**
 * Caches the current node index of each training instance, per tree.
 * Each TreePoint belongs to exactly one node per tree. Each row in the nodeIdsForInstances
 * RDD is an array, indexed by tree, of the node index the instance currently occupies in
 * that tree. Initially, all values should be 1, since every instance starts at the root.
 * The nodeIdsForInstances RDD needs to be updated at each iteration.
 * @param nodeIdsForInstances The initial values in the cache
 *                            (should be arrays of all 1's, meaning the root nodes).
 * @param checkpointInterval The checkpointing interval
 *                           (how often the cache should be checkpointed).
 */
private[spark] class NodeIdCache(
    var nodeIdsForInstances: RDD[Array[Int]],
    val checkpointInterval: Int) extends Logging {

  // Keep a reference to the previous node ids for instances.
  // Because we keep re-persisting the updated node ids,
  // we want to unpersist the previous RDD.
  private var prevNodeIdsForInstances: RDD[Array[Int]] = null

  // To keep track of past checkpointed RDDs.
  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
  private var rddUpdateCount = 0

  // Indicates whether we can checkpoint (i.e., a checkpoint directory has been set).
  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty

  // Hadoop Configuration for deleting checkpoints as needed.
  private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration

  /**
   * Update the node index values in the cache.
   * This updates the RDD and its lineage.
   * TODO: Passing bin information to executors seems unnecessary and costly.
   * @param data The RDD of training rows.
   * @param nodeIdUpdaters One map of node index updaters per tree.
   *                       The keys are the indices of the nodes to update.
   * @param splits Split information needed to find child node indices.
   */
  def updateNodeIndices(
      data: RDD[BaggedPoint[TreePoint]],
      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
      splits: Array[Array[Split]]): Unit = {
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }

    prevNodeIdsForInstances = nodeIdsForInstances
    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
      var treeId = 0
      while (treeId < nodeIdUpdaters.length) {
        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
        if (nodeIdUpdater != null) {
          val featureIndex = nodeIdUpdater.split.featureIndex
          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
            binnedFeature = point.datum.binnedFeatures(featureIndex),
            splits = splits(featureIndex))
          ids(treeId) = newNodeIndex
        }
        treeId += 1
      }
      ids
    }

    // Keep persisting the updated RDDs.
    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
    rddUpdateCount += 1

    // Handle checkpointing if a checkpoint directory has been set.
    if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) {
      // Let's see if we can delete previous checkpoints.
      var canDelete = true
      while (checkpointQueue.size > 1 && canDelete) {
        // We can delete the oldest checkpoint iff
        // the next checkpoint actually exists in the file system.
        if (checkpointQueue(1).getCheckpointFile.isDefined) {
          val old = checkpointQueue.dequeue()
          // Since the old checkpoint is not deleted by Spark, we manually delete it here.
          try {
            val path = new Path(old.getCheckpointFile.get)
            val fs = path.getFileSystem(hadoopConf)
            fs.delete(path, true)
          } catch {
            case e: IOException =>
              logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
                s" file: ${old.getCheckpointFile.get}", e)
          }
        } else {
          canDelete = false
        }
      }

      nodeIdsForInstances.checkpoint()
      checkpointQueue.enqueue(nodeIdsForInstances)
    }
  }

  /**
   * Call this after training is finished to delete any remaining checkpoints.
   */
  def deleteAllCheckpoints(): Unit = {
    while (checkpointQueue.nonEmpty) {
      val old = checkpointQueue.dequeue()
      if (old.getCheckpointFile.isDefined) {
        try {
          val path = new Path(old.getCheckpointFile.get)
          val fs = path.getFileSystem(hadoopConf)
          fs.delete(path, true)
        } catch {
          case e: IOException =>
            logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
              s" file: ${old.getCheckpointFile.get}", e)
        }
      }
    }
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }
  }
}
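
// Sketch of the per-iteration call pattern (hypothetical driver-side code; the tree
// learner would build `nodeIdUpdaters` from the best split chosen for each node that
// finished splitting in the current iteration):
//
//   val nodeIdUpdaters = Array.fill(numTrees)(mutable.Map[Int, NodeIndexUpdater]())
//   nodeIdUpdaters(treeId)(nodeIndex) = NodeIndexUpdater(bestSplit, nodeIndex)
//   nodeIdCache.updateNodeIndices(baggedInput, nodeIdUpdaters, allSplits)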

private[spark] object NodeIdCache {
  /**
   * Initialize the node id cache with initial node id values.
   * @param data The RDD of training rows.
   * @param numTrees The number of trees to create the cache for.
   * @param checkpointInterval The checkpointing interval
   *                           (how often the cache should be checkpointed).
   * @param initVal The initial value in the cache.
   * @return A node id cache containing an RDD of initial root node indices.
   */
  def init(
      data: RDD[BaggedPoint[TreePoint]],
      numTrees: Int,
      checkpointInterval: Int,
      initVal: Int = 1): NodeIdCache = {
    new NodeIdCache(
      data.map(_ => Array.fill[Int](numTrees)(initVal)),
      checkpointInterval)
  }
}
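
// Example of the cache's lifecycle around training (illustrative; `baggedInput` and
// `numTrees` are assumed to be defined by the caller):
//
//   val nodeIdCache = NodeIdCache.init(baggedInput, numTrees, checkpointInterval = 10)
//   // ... call nodeIdCache.updateNodeIndices(...) once per iteration of tree growth ...
//   nodeIdCache.deleteAllCheckpoints()   // clean up any remaining checkpoint files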