1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 package org.apache.hadoop.mapreduce.task.reduce;
19 
20 import java.io.IOException;
21 import java.net.InetAddress;
22 import java.net.URI;
23 import java.net.UnknownHostException;
24 import java.text.DecimalFormat;
25 import java.util.ArrayList;
26 import java.util.Collections;
27 import java.util.HashMap;
28 import java.util.HashSet;
29 import java.util.Iterator;
30 import java.util.List;
31 import java.util.Map;
32 import java.util.Random;
33 import java.util.Set;
34 import java.util.concurrent.DelayQueue;
35 import java.util.concurrent.Delayed;
36 import java.util.concurrent.TimeUnit;
37 
38 import org.apache.commons.logging.Log;
39 import org.apache.commons.logging.LogFactory;
40 import org.apache.hadoop.classification.InterfaceAudience;
41 import org.apache.hadoop.classification.InterfaceStability;
42 import org.apache.hadoop.io.IntWritable;
43 import org.apache.hadoop.mapred.Counters;
44 import org.apache.hadoop.mapred.JobConf;
45 import org.apache.hadoop.mapred.TaskCompletionEvent;
46 import org.apache.hadoop.mapred.TaskStatus;
47 import org.apache.hadoop.mapreduce.MRJobConfig;
48 import org.apache.hadoop.mapreduce.TaskAttemptID;
49 import org.apache.hadoop.mapreduce.TaskID;
50 import org.apache.hadoop.mapreduce.task.reduce.MapHost.State;
51 import org.apache.hadoop.util.Progress;
52 import org.apache.hadoop.util.Time;
53 
54 @InterfaceAudience.Private
55 @InterfaceStability.Unstable
56 public class ShuffleSchedulerImpl<K,V> implements ShuffleScheduler<K,V> {
57   static ThreadLocal<Long> shuffleStart = new ThreadLocal<Long>() {
58     protected Long initialValue() {
59       return 0L;
60     }
61   };
62 
63   private static final Log LOG = LogFactory.getLog(ShuffleSchedulerImpl.class);
64   private static final int MAX_MAPS_AT_ONCE = 20;
65   private static final long INITIAL_PENALTY = 10000;
66   private static final float PENALTY_GROWTH_RATE = 1.3f;
67   private final static int REPORT_FAILURE_LIMIT = 10;
68   private static final float BYTES_PER_MILLIS_TO_MBS = 1000f / 1024 / 1024;
69 
70   private final boolean[] finishedMaps;
71 
72   private final int totalMaps;
73   private int remainingMaps;
74   private Map<String, MapHost> mapLocations = new HashMap<String, MapHost>();
75   private Set<MapHost> pendingHosts = new HashSet<MapHost>();
76   private Set<TaskAttemptID> obsoleteMaps = new HashSet<TaskAttemptID>();
77 
78   private final TaskAttemptID reduceId;
79   private final Random random = new Random();
80   private final DelayQueue<Penalty> penalties = new DelayQueue<Penalty>();
81   private final Referee referee = new Referee();
82   private final Map<TaskAttemptID,IntWritable> failureCounts =
83     new HashMap<TaskAttemptID,IntWritable>();
84   private final Map<String,IntWritable> hostFailures =
85     new HashMap<String,IntWritable>();
86   private final TaskStatus status;
87   private final ExceptionReporter reporter;
88   private final int abortFailureLimit;
89   private final Progress progress;
90   private final Counters.Counter shuffledMapsCounter;
91   private final Counters.Counter reduceShuffleBytes;
92   private final Counters.Counter failedShuffleCounter;
93 
94   private final long startTime;
95   private long lastProgressTime;
96 
97   private final CopyTimeTracker copyTimeTracker;
98 
99   private volatile int maxMapRuntime = 0;
100   private final int maxFailedUniqueFetches;
101   private final int maxFetchFailuresBeforeReporting;
102 
103   private long totalBytesShuffledTillNow = 0;
104   private final DecimalFormat mbpsFormat = new DecimalFormat("0.00");
105 
106   private final boolean reportReadErrorImmediately;
107   private long maxDelay = MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY;
108   private int maxHostFailures;
109 
ShuffleSchedulerImpl(JobConf job, TaskStatus status, TaskAttemptID reduceId, ExceptionReporter reporter, Progress progress, Counters.Counter shuffledMapsCounter, Counters.Counter reduceShuffleBytes, Counters.Counter failedShuffleCounter)110   public ShuffleSchedulerImpl(JobConf job, TaskStatus status,
111                           TaskAttemptID reduceId,
112                           ExceptionReporter reporter,
113                           Progress progress,
114                           Counters.Counter shuffledMapsCounter,
115                           Counters.Counter reduceShuffleBytes,
116                           Counters.Counter failedShuffleCounter) {
117     totalMaps = job.getNumMapTasks();
118     abortFailureLimit = Math.max(30, totalMaps / 10);
119     copyTimeTracker = new CopyTimeTracker();
120     remainingMaps = totalMaps;
121     finishedMaps = new boolean[remainingMaps];
122     this.reporter = reporter;
123     this.status = status;
124     this.reduceId = reduceId;
125     this.progress = progress;
126     this.shuffledMapsCounter = shuffledMapsCounter;
127     this.reduceShuffleBytes = reduceShuffleBytes;
128     this.failedShuffleCounter = failedShuffleCounter;
129     this.startTime = Time.monotonicNow();
130     lastProgressTime = startTime;
131     referee.start();
132     this.maxFailedUniqueFetches = Math.min(totalMaps, 5);
133     this.maxFetchFailuresBeforeReporting = job.getInt(
134         MRJobConfig.SHUFFLE_FETCH_FAILURES, REPORT_FAILURE_LIMIT);
135     this.reportReadErrorImmediately = job.getBoolean(
136         MRJobConfig.SHUFFLE_NOTIFY_READERROR, true);
137 
138     this.maxDelay = job.getLong(MRJobConfig.MAX_SHUFFLE_FETCH_RETRY_DELAY,
139         MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_RETRY_DELAY);
140     this.maxHostFailures = job.getInt(
141         MRJobConfig.MAX_SHUFFLE_FETCH_HOST_FAILURES,
142         MRJobConfig.DEFAULT_MAX_SHUFFLE_FETCH_HOST_FAILURES);
143   }
144 
145   @Override
resolve(TaskCompletionEvent event)146   public void resolve(TaskCompletionEvent event) {
147     switch (event.getTaskStatus()) {
148     case SUCCEEDED:
149       URI u = getBaseURI(reduceId, event.getTaskTrackerHttp());
150       addKnownMapOutput(u.getHost() + ":" + u.getPort(),
151           u.toString(),
152           event.getTaskAttemptId());
153       maxMapRuntime = Math.max(maxMapRuntime, event.getTaskRunTime());
154       break;
155     case FAILED:
156     case KILLED:
157     case OBSOLETE:
158       obsoleteMapOutput(event.getTaskAttemptId());
159       LOG.info("Ignoring obsolete output of " + event.getTaskStatus() +
160           " map-task: '" + event.getTaskAttemptId() + "'");
161       break;
162     case TIPFAILED:
163       tipFailed(event.getTaskAttemptId().getTaskID());
164       LOG.info("Ignoring output of failed map TIP: '" +
165           event.getTaskAttemptId() + "'");
166       break;
167     }
168   }
169 
getBaseURI(TaskAttemptID reduceId, String url)170   static URI getBaseURI(TaskAttemptID reduceId, String url) {
171     StringBuffer baseUrl = new StringBuffer(url);
172     if (!url.endsWith("/")) {
173       baseUrl.append("/");
174     }
175     baseUrl.append("mapOutput?job=");
176     baseUrl.append(reduceId.getJobID());
177     baseUrl.append("&reduce=");
178     baseUrl.append(reduceId.getTaskID().getId());
179     baseUrl.append("&map=");
180     URI u = URI.create(baseUrl.toString());
181     return u;
182   }
183 
copySucceeded(TaskAttemptID mapId, MapHost host, long bytes, long startMillis, long endMillis, MapOutput<K,V> output )184   public synchronized void copySucceeded(TaskAttemptID mapId,
185                                          MapHost host,
186                                          long bytes,
187                                          long startMillis,
188                                          long endMillis,
189                                          MapOutput<K,V> output
190                                          ) throws IOException {
191     failureCounts.remove(mapId);
192     hostFailures.remove(host.getHostName());
193     int mapIndex = mapId.getTaskID().getId();
194 
195     if (!finishedMaps[mapIndex]) {
196       output.commit();
197       finishedMaps[mapIndex] = true;
198       shuffledMapsCounter.increment(1);
199       if (--remainingMaps == 0) {
200         notifyAll();
201       }
202 
203       // update single copy task status
204       long copyMillis = (endMillis - startMillis);
205       if (copyMillis == 0) copyMillis = 1;
206       float bytesPerMillis = (float) bytes / copyMillis;
207       float transferRate = bytesPerMillis * BYTES_PER_MILLIS_TO_MBS;
208       String individualProgress = "copy task(" + mapId + " succeeded"
209           + " at " + mbpsFormat.format(transferRate) + " MB/s)";
210       // update the aggregated status
211       copyTimeTracker.add(startMillis, endMillis);
212 
213       totalBytesShuffledTillNow += bytes;
214       updateStatus(individualProgress);
215       reduceShuffleBytes.increment(bytes);
216       lastProgressTime = Time.monotonicNow();
217       LOG.debug("map " + mapId + " done " + status.getStateString());
218     }
219   }
220 
updateStatus(String individualProgress)221   private synchronized void updateStatus(String individualProgress) {
222     int mapsDone = totalMaps - remainingMaps;
223     long totalCopyMillis = copyTimeTracker.getCopyMillis();
224     if (totalCopyMillis == 0) totalCopyMillis = 1;
225     float bytesPerMillis = (float) totalBytesShuffledTillNow / totalCopyMillis;
226     float transferRate = bytesPerMillis * BYTES_PER_MILLIS_TO_MBS;
227     progress.set((float) mapsDone / totalMaps);
228     String statusString = mapsDone + " / " + totalMaps + " copied.";
229     status.setStateString(statusString);
230 
231     if (individualProgress != null) {
232       progress.setStatus(individualProgress + " Aggregated copy rate(" +
233           mapsDone + " of " + totalMaps + " at " +
234       mbpsFormat.format(transferRate) + " MB/s)");
235     } else {
236       progress.setStatus("copy(" + mapsDone + " of " + totalMaps + " at "
237           + mbpsFormat.format(transferRate) + " MB/s)");
238     }
239   }
240 
updateStatus()241   private void updateStatus() {
242     updateStatus(null);
243   }
244 
hostFailed(String hostname)245   public synchronized void hostFailed(String hostname) {
246     if (hostFailures.containsKey(hostname)) {
247       IntWritable x = hostFailures.get(hostname);
248       x.set(x.get() + 1);
249     } else {
250       hostFailures.put(hostname, new IntWritable(1));
251     }
252   }
253 
copyFailed(TaskAttemptID mapId, MapHost host, boolean readError, boolean connectExcpt)254   public synchronized void copyFailed(TaskAttemptID mapId, MapHost host,
255       boolean readError, boolean connectExcpt) {
256     host.penalize();
257     int failures = 1;
258     if (failureCounts.containsKey(mapId)) {
259       IntWritable x = failureCounts.get(mapId);
260       x.set(x.get() + 1);
261       failures = x.get();
262     } else {
263       failureCounts.put(mapId, new IntWritable(1));
264     }
265     String hostname = host.getHostName();
266     IntWritable hostFailedNum = hostFailures.get(hostname);
267     // MAPREDUCE-6361: hostname could get cleanup from hostFailures in another
268     // thread with copySucceeded.
269     // In this case, add back hostname to hostFailures to get rid of NPE issue.
270     if (hostFailedNum == null) {
271       hostFailures.put(hostname, new IntWritable(1));
272     }
273     //report failure if already retried maxHostFailures times
274     boolean hostFail = hostFailures.get(hostname).get() >
275         getMaxHostFailures() ? true : false;
276 
277     if (failures >= abortFailureLimit) {
278       try {
279         throw new IOException(failures + " failures downloading " + mapId);
280       } catch (IOException ie) {
281         reporter.reportException(ie);
282       }
283     }
284 
285     checkAndInformMRAppMaster(failures, mapId, readError, connectExcpt,
286         hostFail);
287 
288     checkReducerHealth();
289 
290     long delay = (long) (INITIAL_PENALTY *
291         Math.pow(PENALTY_GROWTH_RATE, failures));
292     if (delay > maxDelay) {
293       delay = maxDelay;
294     }
295 
296     penalties.add(new Penalty(host, delay));
297 
298     failedShuffleCounter.increment(1);
299   }
300 
reportLocalError(IOException ioe)301   public void reportLocalError(IOException ioe) {
302     try {
303       LOG.error("Shuffle failed : local error on this node: "
304           + InetAddress.getLocalHost());
305     } catch (UnknownHostException e) {
306       LOG.error("Shuffle failed : local error on this node");
307     }
308     reporter.reportException(ioe);
309   }
310 
311   // Notify the MRAppMaster
312   // after every read error, if 'reportReadErrorImmediately' is true or
313   // after every 'maxFetchFailuresBeforeReporting' failures
checkAndInformMRAppMaster( int failures, TaskAttemptID mapId, boolean readError, boolean connectExcpt, boolean hostFailed)314   private void checkAndInformMRAppMaster(
315       int failures, TaskAttemptID mapId, boolean readError,
316       boolean connectExcpt, boolean hostFailed) {
317     if (connectExcpt || (reportReadErrorImmediately && readError)
318         || ((failures % maxFetchFailuresBeforeReporting) == 0) || hostFailed) {
319       LOG.info("Reporting fetch failure for " + mapId + " to MRAppMaster.");
320       status.addFetchFailedMap((org.apache.hadoop.mapred.TaskAttemptID) mapId);
321     }
322   }
323 
checkReducerHealth()324   private void checkReducerHealth() {
325     final float MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT = 0.5f;
326     final float MIN_REQUIRED_PROGRESS_PERCENT = 0.5f;
327     final float MAX_ALLOWED_STALL_TIME_PERCENT = 0.5f;
328 
329     long totalFailures = failedShuffleCounter.getValue();
330     int doneMaps = totalMaps - remainingMaps;
331 
332     boolean reducerHealthy =
333       (((float)totalFailures / (totalFailures + doneMaps))
334           < MAX_ALLOWED_FAILED_FETCH_ATTEMPT_PERCENT);
335 
336     // check if the reducer has progressed enough
337     boolean reducerProgressedEnough =
338       (((float)doneMaps / totalMaps)
339           >= MIN_REQUIRED_PROGRESS_PERCENT);
340 
341     // check if the reducer is stalled for a long time
342     // duration for which the reducer is stalled
343     int stallDuration =
344       (int)(Time.monotonicNow() - lastProgressTime);
345 
346     // duration for which the reducer ran with progress
347     int shuffleProgressDuration =
348       (int)(lastProgressTime - startTime);
349 
350     // min time the reducer should run without getting killed
351     int minShuffleRunDuration =
352       Math.max(shuffleProgressDuration, maxMapRuntime);
353 
354     boolean reducerStalled =
355       (((float)stallDuration / minShuffleRunDuration)
356           >= MAX_ALLOWED_STALL_TIME_PERCENT);
357 
358     // kill if not healthy and has insufficient progress
359     if ((failureCounts.size() >= maxFailedUniqueFetches ||
360         failureCounts.size() == (totalMaps - doneMaps))
361         && !reducerHealthy
362         && (!reducerProgressedEnough || reducerStalled)) {
363       LOG.fatal("Shuffle failed with too many fetch failures " +
364       "and insufficient progress!");
365       String errorMsg = "Exceeded MAX_FAILED_UNIQUE_FETCHES; bailing-out.";
366       reporter.reportException(new IOException(errorMsg));
367     }
368 
369   }
370 
tipFailed(TaskID taskId)371   public synchronized void tipFailed(TaskID taskId) {
372     if (!finishedMaps[taskId.getId()]) {
373       finishedMaps[taskId.getId()] = true;
374       if (--remainingMaps == 0) {
375         notifyAll();
376       }
377       updateStatus();
378     }
379   }
380 
addKnownMapOutput(String hostName, String hostUrl, TaskAttemptID mapId)381   public synchronized void addKnownMapOutput(String hostName,
382                                              String hostUrl,
383                                              TaskAttemptID mapId) {
384     MapHost host = mapLocations.get(hostName);
385     if (host == null) {
386       host = new MapHost(hostName, hostUrl);
387       mapLocations.put(hostName, host);
388     }
389     host.addKnownMap(mapId);
390 
391     // Mark the host as pending
392     if (host.getState() == State.PENDING) {
393       pendingHosts.add(host);
394       notifyAll();
395     }
396   }
397 
398 
obsoleteMapOutput(TaskAttemptID mapId)399   public synchronized void obsoleteMapOutput(TaskAttemptID mapId) {
400     obsoleteMaps.add(mapId);
401   }
402 
putBackKnownMapOutput(MapHost host, TaskAttemptID mapId)403   public synchronized void putBackKnownMapOutput(MapHost host,
404                                                  TaskAttemptID mapId) {
405     host.addKnownMap(mapId);
406   }
407 
408 
getHost()409   public synchronized MapHost getHost() throws InterruptedException {
410       while(pendingHosts.isEmpty()) {
411         wait();
412       }
413 
414       MapHost host = null;
415       Iterator<MapHost> iter = pendingHosts.iterator();
416       int numToPick = random.nextInt(pendingHosts.size());
417       for (int i=0; i <= numToPick; ++i) {
418         host = iter.next();
419       }
420 
421       pendingHosts.remove(host);
422       host.markBusy();
423 
424       LOG.info("Assigning " + host + " with " + host.getNumKnownMapOutputs() +
425                " to " + Thread.currentThread().getName());
426       shuffleStart.set(Time.monotonicNow());
427 
428       return host;
429   }
430 
getMapsForHost(MapHost host)431   public synchronized List<TaskAttemptID> getMapsForHost(MapHost host) {
432     List<TaskAttemptID> list = host.getAndClearKnownMaps();
433     Iterator<TaskAttemptID> itr = list.iterator();
434     List<TaskAttemptID> result = new ArrayList<TaskAttemptID>();
435     int includedMaps = 0;
436     int totalSize = list.size();
437     // find the maps that we still need, up to the limit
438     while (itr.hasNext()) {
439       TaskAttemptID id = itr.next();
440       if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
441         result.add(id);
442         if (++includedMaps >= MAX_MAPS_AT_ONCE) {
443           break;
444         }
445       }
446     }
447     // put back the maps left after the limit
448     while (itr.hasNext()) {
449       TaskAttemptID id = itr.next();
450       if (!obsoleteMaps.contains(id) && !finishedMaps[id.getTaskID().getId()]) {
451         host.addKnownMap(id);
452       }
453     }
454     LOG.info("assigned " + includedMaps + " of " + totalSize + " to " +
455              host + " to " + Thread.currentThread().getName());
456     return result;
457   }
458 
freeHost(MapHost host)459   public synchronized void freeHost(MapHost host) {
460     if (host.getState() != State.PENALIZED) {
461       if (host.markAvailable() == State.PENDING) {
462         pendingHosts.add(host);
463         notifyAll();
464       }
465     }
466     LOG.info(host + " freed by " + Thread.currentThread().getName() + " in " +
467              (Time.monotonicNow()-shuffleStart.get()) + "ms");
468   }
469 
resetKnownMaps()470   public synchronized void resetKnownMaps() {
471     mapLocations.clear();
472     obsoleteMaps.clear();
473     pendingHosts.clear();
474   }
475 
476   /**
477    * Wait until the shuffle finishes or until the timeout.
478    * @param millis maximum wait time
479    * @return true if the shuffle is done
480    * @throws InterruptedException
481    */
482   @Override
waitUntilDone(int millis )483   public synchronized boolean waitUntilDone(int millis
484                                             ) throws InterruptedException {
485     if (remainingMaps > 0) {
486       wait(millis);
487       return remainingMaps == 0;
488     }
489     return true;
490   }
491 
492   /**
493    * A structure that records the penalty for a host.
494    */
495   private static class Penalty implements Delayed {
496     MapHost host;
497     private long endTime;
498 
Penalty(MapHost host, long delay)499     Penalty(MapHost host, long delay) {
500       this.host = host;
501       this.endTime = Time.monotonicNow() + delay;
502     }
503 
504     @Override
getDelay(TimeUnit unit)505     public long getDelay(TimeUnit unit) {
506       long remainingTime = endTime - Time.monotonicNow();
507       return unit.convert(remainingTime, TimeUnit.MILLISECONDS);
508     }
509 
510     @Override
compareTo(Delayed o)511     public int compareTo(Delayed o) {
512       long other = ((Penalty) o).endTime;
513       return endTime == other ? 0 : (endTime < other ? -1 : 1);
514     }
515 
516   }
517 
518   /**
519    * A thread that takes hosts off of the penalty list when the timer expires.
520    */
521   private class Referee extends Thread {
Referee()522     public Referee() {
523       setName("ShufflePenaltyReferee");
524       setDaemon(true);
525     }
526 
run()527     public void run() {
528       try {
529         while (true) {
530           // take the first host that has an expired penalty
531           MapHost host = penalties.take().host;
532           synchronized (ShuffleSchedulerImpl.this) {
533             if (host.markAvailable() == MapHost.State.PENDING) {
534               pendingHosts.add(host);
535               ShuffleSchedulerImpl.this.notifyAll();
536             }
537           }
538         }
539       } catch (InterruptedException ie) {
540         return;
541       } catch (Throwable t) {
542         reporter.reportException(t);
543       }
544     }
545   }
546 
547   @Override
close()548   public void close() throws InterruptedException {
549     referee.interrupt();
550     referee.join();
551   }
552 
getMaxHostFailures()553   public int getMaxHostFailures() {
554     return maxHostFailures;
555   }
556 
557   private static class CopyTimeTracker {
558     List<Interval> intervals;
559     long copyMillis;
CopyTimeTracker()560     public CopyTimeTracker() {
561       intervals = Collections.emptyList();
562       copyMillis = 0;
563     }
add(long s, long e)564     public void add(long s, long e) {
565       Interval interval = new Interval(s, e);
566       copyMillis = getTotalCopyMillis(interval);
567     }
568 
getCopyMillis()569     public long getCopyMillis() {
570       return copyMillis;
571     }
572     // This method captures the time during which any copy was in progress
573     // each copy time period is record in the Interval list
getTotalCopyMillis(Interval newInterval)574     private long getTotalCopyMillis(Interval newInterval) {
575       if (newInterval == null) {
576         return copyMillis;
577       }
578       List<Interval> result = new ArrayList<Interval>(intervals.size() + 1);
579       for (Interval interval: intervals) {
580         if (interval.end < newInterval.start) {
581           result.add(interval);
582         } else if (interval.start > newInterval.end) {
583           result.add(newInterval);
584           newInterval = interval;
585         } else {
586           newInterval = new Interval(
587               Math.min(interval.start, newInterval.start),
588               Math.max(newInterval.end, interval.end));
589         }
590       }
591       result.add(newInterval);
592       intervals = result;
593 
594       //compute total millis
595       long length = 0;
596       for (Interval interval : intervals) {
597         length += interval.getIntervalLength();
598       }
599       return length;
600     }
601 
602     private static class Interval {
603       final long start;
604       final long end;
Interval(long s, long e)605       public Interval(long s, long e) {
606         start = s;
607         end = e;
608       }
609 
getIntervalLength()610       public long getIntervalLength() {
611         return end - start;
612       }
613     }
614   }
615 }
616