/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.ml.tree.impl

import java.io.IOException

import scala.collection.mutable

import org.apache.hadoop.fs.Path

import org.apache.spark.internal.Logging
import org.apache.spark.ml.tree.{LearningNode, Split}
import org.apache.spark.rdd.RDD
import org.apache.spark.storage.StorageLevel


/**
 * This is used by the node id cache to find the child id that a data point would belong to.
 * @param split Split information.
 * @param nodeIndex The current node index of a data point that this will update.
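 *
 * Node indices are assumed to follow LearningNode's binary-heap numbering (root = 1, with the
 * children of node i at indices 2 * i and 2 * i + 1), so, for example, a point at the root
 * that goes left moves to node index 2.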
 */
private[tree] case class NodeIndexUpdater(split: Split, nodeIndex: Int) {

  /**
   * Determine a child node index based on the feature value and the split.
   * @param binnedFeature Binned feature value.
   * @param splits Split information to convert the bin indices to approximate feature values.
   * @return Child node index to update to.
   */
  def updateNodeIndex(binnedFeature: Int, splits: Array[Split]): Int = {
    if (split.shouldGoLeft(binnedFeature, splits)) {
      LearningNode.leftChildIndex(nodeIndex)
    } else {
      LearningNode.rightChildIndex(nodeIndex)
    }
  }
}

/**
 * Each TreePoint belongs to a particular node per tree.
 * Each row in the nodeIdsForInstances RDD is an array, indexed by tree, of the node index
 * that the corresponding instance currently occupies in each tree. Initially, all values
 * should be 1, meaning the root node.
 * The nodeIdsForInstances RDD needs to be updated at each iteration.
 * @param nodeIdsForInstances The initial values in the cache
 *                           (should be an Array of all 1's (meaning the root nodes)).
 * @param checkpointInterval The checkpointing interval
 *                           (how often the cache should be checkpointed).
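 *
 * For illustration (hypothetical values), a single cached row for 3 trees might evolve as:
 * {{{
 *   Array(1, 1, 1)   // every instance starts at the root (index 1) of each tree
 *   Array(2, 3, 2)   // after one split round: left child, right child, left child
 * }}}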
 */
private[spark] class NodeIdCache(
  var nodeIdsForInstances: RDD[Array[Int]],
  val checkpointInterval: Int) extends Logging {

  // Keep a reference to the previous node ids for instances.
  // Because we will keep on re-persisting updated node ids,
  // we want to unpersist the previous RDD.
  private var prevNodeIdsForInstances: RDD[Array[Int]] = null

  // To keep track of the past checkpointed RDDs.
  private val checkpointQueue = mutable.Queue[RDD[Array[Int]]]()
  private var rddUpdateCount = 0

  // Indicates whether we can checkpoint
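  // (requires the user to have set a checkpoint directory via SparkContext.setCheckpointDir).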
  private val canCheckpoint = nodeIdsForInstances.sparkContext.getCheckpointDir.nonEmpty

  // Hadoop Configuration for deleting checkpoints as needed
  private val hadoopConf = nodeIdsForInstances.sparkContext.hadoopConfiguration

  /**
   * Update the node index values in the cache.
   * This updates the RDD and its lineage.
   * TODO: Passing bin information to executors seems unnecessary and costly.
   * @param data The RDD of training rows.
   * @param nodeIdUpdaters A map of node index updaters.
   *                       The keys are the indices of the nodes that we want to update.
   * @param splits  Split information needed to find child node indices.
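   *
   * For example (hypothetical identifiers), to route the instances currently sitting at node 1
   * of tree 0 to its children according to a chosen split:
   * {{{
   *   nodeIdUpdaters(0)(1) = NodeIndexUpdater(chosenSplit, 1)
   * }}}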
   */
  def updateNodeIndices(
      data: RDD[BaggedPoint[TreePoint]],
      nodeIdUpdaters: Array[mutable.Map[Int, NodeIndexUpdater]],
      splits: Array[Array[Split]]): Unit = {
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }

    prevNodeIdsForInstances = nodeIdsForInstances
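    // Note: zip assumes that data and nodeIdsForInstances have the same partitioning and the
    // same per-partition ordering, which holds because the cache was derived from this RDD.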
    nodeIdsForInstances = data.zip(nodeIdsForInstances).map { case (point, ids) =>
      var treeId = 0
      while (treeId < nodeIdUpdaters.length) {
        // Only nodes that are being split in this round have an updater; points at other
        // nodes keep their current index for this tree.
        val nodeIdUpdater = nodeIdUpdaters(treeId).getOrElse(ids(treeId), null)
        if (nodeIdUpdater != null) {
          val featureIndex = nodeIdUpdater.split.featureIndex
          val newNodeIndex = nodeIdUpdater.updateNodeIndex(
            binnedFeature = point.datum.binnedFeatures(featureIndex),
            splits = splits(featureIndex))
          ids(treeId) = newNodeIndex
        }
        treeId += 1
      }
      ids
    }

    // Keep on persisting new ones.
    nodeIdsForInstances.persist(StorageLevel.MEMORY_AND_DISK)
    rddUpdateCount += 1

    // Handle checkpointing if a checkpoint directory has been set.
    if (canCheckpoint && checkpointInterval != -1 && (rddUpdateCount % checkpointInterval) == 0) {
      // Let's see if we can delete previous checkpoints.
      var canDelete = true
      while (checkpointQueue.size > 1 && canDelete) {
        // We can delete the oldest checkpoint iff
        // the next checkpoint actually exists in the file system.
        if (checkpointQueue(1).getCheckpointFile.isDefined) {
          val old = checkpointQueue.dequeue()
          // Since the old checkpoint is not deleted by Spark, we'll manually delete it here.
          try {
            val path = new Path(old.getCheckpointFile.get)
            val fs = path.getFileSystem(hadoopConf)
            fs.delete(path, true)
          } catch {
            case e: IOException =>
              logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
                s" file: ${old.getCheckpointFile.get}", e)
          }
        } else {
          canDelete = false
        }
      }

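      // checkpoint() only marks the RDD; the checkpoint file is actually written the next time
      // the RDD is materialized, which is why older checkpoints are retained above until the
      // following one exists on disk.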
      nodeIdsForInstances.checkpoint()
      checkpointQueue.enqueue(nodeIdsForInstances)
    }
  }

  /**
   * Call this after training is finished to delete any remaining checkpoints.
   */
  def deleteAllCheckpoints(): Unit = {
    while (checkpointQueue.nonEmpty) {
      val old = checkpointQueue.dequeue()
      if (old.getCheckpointFile.isDefined) {
        try {
          val path = new Path(old.getCheckpointFile.get)
          val fs = path.getFileSystem(hadoopConf)
          fs.delete(path, true)
        } catch {
          case e: IOException =>
            logError("Decision Tree learning using cacheNodeIds failed to remove checkpoint" +
              s" file: ${old.getCheckpointFile.get}", e)
        }
      }
    }
    if (prevNodeIdsForInstances != null) {
      // Unpersist the previous one if one exists.
      prevNodeIdsForInstances.unpersist()
    }
  }
}

private[spark] object NodeIdCache {
  /**
   * Initialize the node Id cache with initial node Id values.
   * @param data The RDD of training rows.
   * @param numTrees The number of trees that we want to create the cache for.
   * @param checkpointInterval The checkpointing interval
   *                           (how often the cache should be checkpointed).
   * @param initVal The initial values in the cache.
   * @return A node Id cache containing an RDD of initial root node indices.
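   *
   * A minimal lifecycle sketch (baggedInput, nodeIdUpdaters, and splits are assumptions, not
   * defined here):
   * {{{
   *   val cache = NodeIdCache.init(baggedInput, numTrees = 10, checkpointInterval = 10)
   *   // each iteration:  cache.updateNodeIndices(baggedInput, nodeIdUpdaters, splits)
   *   // after training:  cache.deleteAllCheckpoints()
   * }}}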
   */
  def init(
      data: RDD[BaggedPoint[TreePoint]],
      numTrees: Int,
      checkpointInterval: Int,
      initVal: Int = 1): NodeIdCache = {
    new NodeIdCache(
      data.map(_ => Array.fill[Int](numTrees)(initVal)),
      checkpointInterval)
  }
}