/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.memory

import javax.annotation.concurrent.GuardedBy

import scala.collection.mutable

import org.apache.spark.internal.Logging

/**
 * Implements policies and bookkeeping for sharing an adjustable-sized pool of memory between tasks.
 *
 * Tries to ensure that each task gets a reasonable share of memory, instead of some task ramping up
 * to a large amount first and then causing others to spill to disk repeatedly.
 *
 * If there are N tasks, it ensures that each task can acquire at least 1 / 2N of the memory
 * before it has to spill, and at most 1 / N. Because N varies dynamically, we keep track of the
 * set of active tasks and redo the calculations of 1 / 2N and 1 / N in waiting tasks whenever this
 * set changes. This is all done by synchronizing access to mutable state and using wait() and
 * notifyAll() to signal changes to callers. Prior to Spark 1.6, this arbitration of memory across
 * tasks was performed by the ShuffleMemoryManager.
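 *
 * For example, with a 1000-byte pool and 4 active tasks, each task is guaranteed at least
 * 125 bytes (1 / 2N) before it is forced to spill and may occupy at most 250 bytes (1 / N).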
 *
 * @param lock a [[MemoryManager]] instance to synchronize on
 * @param memoryMode the type of memory tracked by this pool (on- or off-heap)
 */
private[memory] class ExecutionMemoryPool(
    lock: Object,
    memoryMode: MemoryMode
  ) extends MemoryPool(lock) with Logging {

  private[this] val poolName: String = memoryMode match {
    case MemoryMode.ON_HEAP => "on-heap execution"
    case MemoryMode.OFF_HEAP => "off-heap execution"
  }

  /**
   * Map from taskAttemptId -> memory consumption in bytes
   */
  @GuardedBy("lock")
  private val memoryForTask = new mutable.HashMap[Long, Long]()

  override def memoryUsed: Long = lock.synchronized {
    memoryForTask.values.sum
  }

  /**
   * Returns the memory consumption, in bytes, for the given task.
   */
  def getMemoryUsageForTask(taskAttemptId: Long): Long = lock.synchronized {
    memoryForTask.getOrElse(taskAttemptId, 0L)
  }

  /**
   * Try to acquire up to `numBytes` of memory for the given task and return the number of bytes
   * obtained, or 0 if none can be allocated.
   *
   * This call may block until there is enough free memory in some situations, to make sure each
   * task has a chance to ramp up to at least 1 / 2N of the total memory pool (where N is the # of
   * active tasks) before it is forced to spill. This can happen if the number of tasks increases
   * but an older task had a lot of memory already.
   *
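   * A minimal sketch of a call that relies on the default callbacks (for illustration only;
   * within Spark the actual callers are [[MemoryManager]] implementations, and `pool` here
   * stands for an instance of this class):
   * {{{
   *   val granted = pool.acquireMemory(numBytes = 1L << 20, taskAttemptId = 0L)
   * }}}
   *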
   * @param numBytes number of bytes to acquire
   * @param taskAttemptId the task attempt acquiring memory
   * @param maybeGrowPool a callback that potentially grows the size of this pool. It takes in
   *                      one parameter (Long) that represents the desired amount of memory by
   *                      which this pool should be expanded.
   * @param computeMaxPoolSize a callback that returns the maximum allowable size of this pool
   *                           at this given moment. This is not a field because the max pool
   *                           size is variable in certain cases. For instance, in unified
   *                           memory management, the execution pool can be expanded by evicting
   *                           cached blocks, thereby shrinking the storage pool.
   *
   * @return the number of bytes granted to the task.
   */
  private[memory] def acquireMemory(
      numBytes: Long,
      taskAttemptId: Long,
      maybeGrowPool: Long => Unit = (additionalSpaceNeeded: Long) => Unit,
      computeMaxPoolSize: () => Long = () => poolSize): Long = lock.synchronized {
    assert(numBytes > 0, s"invalid number of bytes requested: $numBytes")

    // TODO: clean up this clunky method signature

    // Add this task to the memoryForTask map just so we can keep an accurate count of the number
    // of active tasks, to let other tasks ramp down their memory in calls to `acquireMemory`
    if (!memoryForTask.contains(taskAttemptId)) {
      memoryForTask(taskAttemptId) = 0L
      // This will later cause waiting tasks to wake up and check numTasks again
      lock.notifyAll()
    }

    // Keep looping until we're either sure that we don't want to grant this request (because this
    // task would have more than 1 / numActiveTasks of the memory) or we have enough free
    // memory to give it (we always let each task get at least 1 / (2 * numActiveTasks)).
    // TODO: simplify this to limit each task to its own slot
    while (true) {
      val numActiveTasks = memoryForTask.keys.size
      val curMem = memoryForTask(taskAttemptId)

      // In every iteration of this loop, we should first try to reclaim any borrowed execution
      // space from storage. This is necessary because of the potential race condition where new
      // storage blocks may steal the free execution memory that this task was waiting for.
      maybeGrowPool(numBytes - memoryFree)

      // Maximum size the pool would have after potentially growing the pool.
      // This is used to compute the upper bound of how much memory each task can occupy. This
      // must take into account potential free memory as well as the amount this pool currently
      // occupies. Otherwise, we may run into SPARK-12155 where, in unified memory management,
      // we did not take into account space that could have been freed by evicting cached blocks.
      val maxPoolSize = computeMaxPoolSize()
      val maxMemoryPerTask = maxPoolSize / numActiveTasks
      val minMemoryPerTask = poolSize / (2 * numActiveTasks)

      // How much we can grant this task; keep its share within 0 <= X <= 1 / numActiveTasks
      val maxToGrant = math.min(numBytes, math.max(0, maxMemoryPerTask - curMem))
      // Only give it as much memory as is free, which might be none if it reached 1 / numTasks
      val toGrant = math.min(maxToGrant, memoryFree)
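      // As a concrete (hypothetical) illustration: if maxPoolSize = 1000, poolSize = 800,
      // numActiveTasks = 4, memoryFree = 100, curMem = 200 and numBytes = 300, then
      // maxMemoryPerTask = 250, minMemoryPerTask = 100,
      // maxToGrant = min(300, max(0, 250 - 200)) = 50 and toGrant = min(50, 100) = 50.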

      // We want to let each task get at least 1 / (2 * numActiveTasks) before blocking;
      // if we can't give it this much now, wait for other tasks to free up memory
      // (this happens if older tasks allocated lots of memory before N grew)
      if (toGrant < numBytes && curMem + toGrant < minMemoryPerTask) {
        logInfo(s"TID $taskAttemptId waiting for at least 1/2N of $poolName pool to be free")
        lock.wait()
      } else {
        memoryForTask(taskAttemptId) += toGrant
        return toGrant
      }
    }
    0L  // Never reached
  }

  /**
   * Release `numBytes` of memory acquired by the given task.
   */
  def releaseMemory(numBytes: Long, taskAttemptId: Long): Unit = lock.synchronized {
    val curMem = memoryForTask.getOrElse(taskAttemptId, 0L)
    val memoryToFree = if (curMem < numBytes) {
      logWarning(
        s"Internal error: release called on $numBytes bytes but task only has $curMem bytes " +
          s"of memory from the $poolName pool")
      curMem
    } else {
      numBytes
    }
    if (memoryForTask.contains(taskAttemptId)) {
      memoryForTask(taskAttemptId) -= memoryToFree
      if (memoryForTask(taskAttemptId) <= 0) {
        memoryForTask.remove(taskAttemptId)
      }
    }
    lock.notifyAll() // Notify waiters in acquireMemory() that memory has been freed
  }

  /**
   * Release all memory for the given task and mark it as inactive (e.g. when a task ends).
   * @return the number of bytes freed.
   */
  def releaseAllMemoryForTask(taskAttemptId: Long): Long = lock.synchronized {
    val numBytesToFree = getMemoryUsageForTask(taskAttemptId)
    releaseMemory(numBytesToFree, taskAttemptId)
    numBytesToFree
  }

}