1 /**
2  * Licensed to the Apache Software Foundation (ASF) under one
3  * or more contributor license agreements.  See the NOTICE file
4  * distributed with this work for additional information
5  * regarding copyright ownership.  The ASF licenses this file
6  * to you under the Apache License, Version 2.0 (the
7  * "License"); you may not use this file except in compliance
8  * with the License.  You may obtain a copy of the License at
9  *
10  *     http://www.apache.org/licenses/LICENSE-2.0
11  *
12  * Unless required by applicable law or agreed to in writing, software
13  * distributed under the License is distributed on an "AS IS" BASIS,
14  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15  * See the License for the specific language governing permissions and
16  * limitations under the License.
17  */
18 
19 #include "hadoop/Pipes.hh"
20 #include "hadoop/SerialUtils.hh"
21 #include "hadoop/StringUtils.hh"
22 
23 #include <map>
24 #include <vector>
25 
26 #include <errno.h>
27 #include <netinet/in.h>
28 #include <stdint.h>
29 #include <stdio.h>
30 #include <stdlib.h>
31 #include <string.h>
32 #include <strings.h>
33 #include <unistd.h>
34 #include <sys/socket.h>
35 #include <pthread.h>
36 #include <iostream>
37 #include <fstream>
38 
39 #include <openssl/hmac.h>
40 #include <openssl/buffer.h>
41 
42 using std::map;
43 using std::string;
44 using std::vector;
45 
46 using namespace HadoopUtils;
47 
48 namespace HadoopPipes {
49 
50   class JobConfImpl: public JobConf {
51   private:
52     map<string, string> values;
53   public:
set(const string & key,const string & value)54     void set(const string& key, const string& value) {
55       values[key] = value;
56     }
57 
hasKey(const string & key) const58     virtual bool hasKey(const string& key) const {
59       return values.find(key) != values.end();
60     }
61 
get(const string & key) const62     virtual const string& get(const string& key) const {
63       map<string,string>::const_iterator itr = values.find(key);
64       if (itr == values.end()) {
65         throw Error("Key " + key + " not found in JobConf");
66       }
67       return itr->second;
68     }
69 
getInt(const string & key) const70     virtual int getInt(const string& key) const {
71       const string& val = get(key);
72       return toInt(val);
73     }
74 
getFloat(const string & key) const75     virtual float getFloat(const string& key) const {
76       const string& val = get(key);
77       return toFloat(val);
78     }
79 
getBoolean(const string & key) const80     virtual bool getBoolean(const string&key) const {
81       const string& val = get(key);
82       return toBool(val);
83     }
84   };
85 
  /**
   * Commands delivered to the task. A Protocol implementation decodes each
   * incoming message in nextEvent() and invokes the matching method here;
   * TaskContextImpl is the implementation used in this file.
   */
  class DownwardProtocol {
  public:
    /** Handshake carrying the requested protocol version. */
    virtual void start(int protocol) = 0;
    /** Job configuration as a flat list of alternating keys and values. */
    virtual void setJobConf(vector<string> values) = 0;
    /** Input key and value type identifiers, passed as strings. */
    virtual void setInputTypes(string keyType, string valueType) = 0;
    /**
     * Start a map task over inputSplit. When pipedInput is true, records
     * arrive through mapItem() instead of a locally created RecordReader.
     */
    virtual void runMap(string inputSplit, int numReduces, bool pipedInput)= 0;
    /** One piped input record for the map task. */
    virtual void mapItem(const string& key, const string& value) = 0;
    /**
     * Start reduce task number `reduce`. pipedOutput == true means no local
     * RecordWriter is expected (see TaskContextImpl::runReduce).
     */
    virtual void runReduce(int reduce, bool pipedOutput) = 0;
    /** Begin a new key group for the reducer. */
    virtual void reduceKey(const string& key) = 0;
    /** Next value within the current key group. */
    virtual void reduceValue(const string& value) = 0;
    /** No more input will be sent; the task should finish. */
    virtual void close() = 0;
    /** The driver asks the task to stop. */
    virtual void abort() = 0;
    virtual ~DownwardProtocol() {}
  };
100 
  /**
   * Messages sent from the task back to its driver. TextUpwardProtocol and
   * BinaryUpwardProtocol in this file provide the two wire encodings.
   */
  class UpwardProtocol {
  public:
    /** Emit an output record. */
    virtual void output(const string& key, const string& value) = 0;
    /** Emit an output record already assigned to partition `reduce`. */
    virtual void partitionedOutput(int reduce, const string& key,
                                   const string& value) = 0;
    /** Report a free-form status message. */
    virtual void status(const string& message) = 0;
    /** Report task progress. */
    virtual void progress(float progress) = 0;
    /** Signal that the task has finished. */
    virtual void done() = 0;
    /** Announce a counter (group, name) under numeric id. */
    virtual void registerCounter(int id, const string& group,
                                 const string& name) = 0;
    /** Add amount to a previously registered counter. */
    virtual void
      incrementCounter(const TaskContext::Counter* counter, uint64_t amount) = 0;
    virtual ~UpwardProtocol() {}
  };
115 
  /**
   * A bidirectional channel to the driver. nextEvent() reads and dispatches
   * one incoming message; getUplink() exposes the outgoing side.
   */
  class Protocol {
  public:
    /** Read one message and dispatch it to the downward handler. */
    virtual void nextEvent() = 0;
    /** The outgoing message writer; owned by the Protocol. */
    virtual UpwardProtocol* getUplink() = 0;
    virtual ~Protocol() {}
  };
122 
123   class TextUpwardProtocol: public UpwardProtocol {
124   private:
125     FILE* stream;
126     static const char fieldSeparator = '\t';
127     static const char lineSeparator = '\n';
128 
writeBuffer(const string & buffer)129     void writeBuffer(const string& buffer) {
130       fputs(quoteString(buffer, "\t\n").c_str(), stream);
131     }
132 
133   public:
TextUpwardProtocol(FILE * _stream)134     TextUpwardProtocol(FILE* _stream): stream(_stream) {}
135 
output(const string & key,const string & value)136     virtual void output(const string& key, const string& value) {
137       fprintf(stream, "output%c", fieldSeparator);
138       writeBuffer(key);
139       fprintf(stream, "%c", fieldSeparator);
140       writeBuffer(value);
141       fprintf(stream, "%c", lineSeparator);
142     }
143 
partitionedOutput(int reduce,const string & key,const string & value)144     virtual void partitionedOutput(int reduce, const string& key,
145                                    const string& value) {
146       fprintf(stream, "parititionedOutput%c%d%c", fieldSeparator, reduce,
147               fieldSeparator);
148       writeBuffer(key);
149       fprintf(stream, "%c", fieldSeparator);
150       writeBuffer(value);
151       fprintf(stream, "%c", lineSeparator);
152     }
153 
status(const string & message)154     virtual void status(const string& message) {
155       fprintf(stream, "status%c%s%c", fieldSeparator, message.c_str(),
156               lineSeparator);
157     }
158 
progress(float progress)159     virtual void progress(float progress) {
160       fprintf(stream, "progress%c%f%c", fieldSeparator, progress,
161               lineSeparator);
162     }
163 
registerCounter(int id,const string & group,const string & name)164     virtual void registerCounter(int id, const string& group,
165                                  const string& name) {
166       fprintf(stream, "registerCounter%c%d%c%s%c%s%c", fieldSeparator, id,
167               fieldSeparator, group.c_str(), fieldSeparator, name.c_str(),
168               lineSeparator);
169     }
170 
incrementCounter(const TaskContext::Counter * counter,uint64_t amount)171     virtual void incrementCounter(const TaskContext::Counter* counter,
172                                   uint64_t amount) {
173       fprintf(stream, "incrCounter%c%d%c%ld%c", fieldSeparator, counter->getId(),
174               fieldSeparator, (long)amount, lineSeparator);
175     }
176 
done()177     virtual void done() {
178       fprintf(stream, "done%c", lineSeparator);
179     }
180   };
181 
182   class TextProtocol: public Protocol {
183   private:
184     FILE* downStream;
185     DownwardProtocol* handler;
186     UpwardProtocol* uplink;
187     string key;
188     string value;
189 
readUpto(string & buffer,const char * limit)190     int readUpto(string& buffer, const char* limit) {
191       int ch;
192       buffer.clear();
193       while ((ch = getc(downStream)) != -1) {
194         if (strchr(limit, ch) != NULL) {
195           return ch;
196         }
197         buffer += ch;
198       }
199       return -1;
200     }
201 
202     static const char* delim;
203   public:
204 
TextProtocol(FILE * down,DownwardProtocol * _handler,FILE * up)205     TextProtocol(FILE* down, DownwardProtocol* _handler, FILE* up) {
206       downStream = down;
207       uplink = new TextUpwardProtocol(up);
208       handler = _handler;
209     }
210 
getUplink()211     UpwardProtocol* getUplink() {
212       return uplink;
213     }
214 
nextEvent()215     virtual void nextEvent() {
216       string command;
217       string arg;
218       int sep;
219       sep = readUpto(command, delim);
220       if (command == "mapItem") {
221         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
222         sep = readUpto(key, delim);
223         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
224         sep = readUpto(value, delim);
225         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
226         handler->mapItem(key, value);
227       } else if (command == "reduceValue") {
228         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
229         sep = readUpto(value, delim);
230         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
231         handler->reduceValue(value);
232       } else if (command == "reduceKey") {
233         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
234         sep = readUpto(key, delim);
235         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
236         handler->reduceKey(key);
237       } else if (command == "start") {
238         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
239         sep = readUpto(arg, delim);
240         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
241         handler->start(toInt(arg));
242       } else if (command == "setJobConf") {
243         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
244         sep = readUpto(arg, delim);
245         int len = toInt(arg);
246         vector<string> values(len);
247         for(int i=0; i < len; ++i) {
248           HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
249           sep = readUpto(arg, delim);
250           values.push_back(arg);
251         }
252         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
253         handler->setJobConf(values);
254       } else if (command == "setInputTypes") {
255         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
256         sep = readUpto(key, delim);
257         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
258         sep = readUpto(value, delim);
259         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
260         handler->setInputTypes(key, value);
261       } else if (command == "runMap") {
262         string split;
263         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
264         sep = readUpto(split, delim);
265         string reduces;
266         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
267         sep = readUpto(reduces, delim);
268         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
269         sep = readUpto(arg, delim);
270         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
271         handler->runMap(split, toInt(reduces), toBool(arg));
272       } else if (command == "runReduce") {
273         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
274         sep = readUpto(arg, delim);
275         HADOOP_ASSERT(sep == '\t', "Short text protocol command " + command);
276         string piped;
277         sep = readUpto(piped, delim);
278         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
279         handler->runReduce(toInt(arg), toBool(piped));
280       } else if (command == "abort") {
281         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
282         handler->abort();
283       } else if (command == "close") {
284         HADOOP_ASSERT(sep == '\n', "Long text protocol command " + command);
285         handler->close();
286       } else {
287         throw Error("Illegal text protocol command " + command);
288       }
289     }
290 
~TextProtocol()291     ~TextProtocol() {
292       delete uplink;
293     }
294   };
295   const char* TextProtocol::delim = "\t\n";
296 
297   enum MESSAGE_TYPE {START_MESSAGE, SET_JOB_CONF, SET_INPUT_TYPES, RUN_MAP,
298                      MAP_ITEM, RUN_REDUCE, REDUCE_KEY, REDUCE_VALUE,
299                      CLOSE, ABORT, AUTHENTICATION_REQ,
300                      OUTPUT=50, PARTITIONED_OUTPUT, STATUS, PROGRESS, DONE,
301                      REGISTER_COUNTER, INCREMENT_COUNTER, AUTHENTICATION_RESP};
302 
303   class BinaryUpwardProtocol: public UpwardProtocol {
304   private:
305     FileOutStream* stream;
306   public:
BinaryUpwardProtocol(FILE * _stream)307     BinaryUpwardProtocol(FILE* _stream) {
308       stream = new FileOutStream();
309       HADOOP_ASSERT(stream->open(_stream), "problem opening stream");
310     }
311 
authenticate(const string & responseDigest)312     virtual void authenticate(const string &responseDigest) {
313       serializeInt(AUTHENTICATION_RESP, *stream);
314       serializeString(responseDigest, *stream);
315       stream->flush();
316     }
317 
output(const string & key,const string & value)318     virtual void output(const string& key, const string& value) {
319       serializeInt(OUTPUT, *stream);
320       serializeString(key, *stream);
321       serializeString(value, *stream);
322     }
323 
partitionedOutput(int reduce,const string & key,const string & value)324     virtual void partitionedOutput(int reduce, const string& key,
325                                    const string& value) {
326       serializeInt(PARTITIONED_OUTPUT, *stream);
327       serializeInt(reduce, *stream);
328       serializeString(key, *stream);
329       serializeString(value, *stream);
330     }
331 
status(const string & message)332     virtual void status(const string& message) {
333       serializeInt(STATUS, *stream);
334       serializeString(message, *stream);
335     }
336 
progress(float progress)337     virtual void progress(float progress) {
338       serializeInt(PROGRESS, *stream);
339       serializeFloat(progress, *stream);
340       stream->flush();
341     }
342 
done()343     virtual void done() {
344       serializeInt(DONE, *stream);
345     }
346 
registerCounter(int id,const string & group,const string & name)347     virtual void registerCounter(int id, const string& group,
348                                  const string& name) {
349       serializeInt(REGISTER_COUNTER, *stream);
350       serializeInt(id, *stream);
351       serializeString(group, *stream);
352       serializeString(name, *stream);
353     }
354 
incrementCounter(const TaskContext::Counter * counter,uint64_t amount)355     virtual void incrementCounter(const TaskContext::Counter* counter,
356                                   uint64_t amount) {
357       serializeInt(INCREMENT_COUNTER, *stream);
358       serializeInt(counter->getId(), *stream);
359       serializeLong(amount, *stream);
360     }
361 
~BinaryUpwardProtocol()362     ~BinaryUpwardProtocol() {
363       delete stream;
364     }
365   };
366 
367   class BinaryProtocol: public Protocol {
368   private:
369     FileInStream* downStream;
370     DownwardProtocol* handler;
371     BinaryUpwardProtocol * uplink;
372     string key;
373     string value;
374     string password;
375     bool authDone;
getPassword(string & password)376     void getPassword(string &password) {
377       const char *passwordFile = getenv("hadoop.pipes.shared.secret.location");
378       if (passwordFile == NULL) {
379         return;
380       }
381       std::ifstream fstr(passwordFile, std::fstream::binary);
382       if (fstr.fail()) {
383         std::cerr << "Could not open the password file" << std::endl;
384         return;
385       }
386       unsigned char * passBuff = new unsigned char [512];
387       fstr.read((char *)passBuff, 512);
388       int passwordLength = fstr.gcount();
389       fstr.close();
390       passBuff[passwordLength] = 0;
391       password.replace(0, passwordLength, (const char *) passBuff, passwordLength);
392       delete [] passBuff;
393       return;
394     }
395 
verifyDigestAndRespond(string & digest,string & challenge)396     void verifyDigestAndRespond(string& digest, string& challenge) {
397       if (password.empty()) {
398         //password can be empty if process is running in debug mode from
399         //command file.
400         authDone = true;
401         return;
402       }
403 
404       if (!verifyDigest(password, digest, challenge)) {
405         std::cerr << "Server failed to authenticate. Exiting" << std::endl;
406         exit(-1);
407       }
408       authDone = true;
409       string responseDigest = createDigest(password, digest);
410       uplink->authenticate(responseDigest);
411     }
412 
verifyDigest(string & password,string & digest,string & challenge)413     bool verifyDigest(string &password, string& digest, string& challenge) {
414       string expectedDigest = createDigest(password, challenge);
415       if (digest == expectedDigest) {
416         return true;
417       } else {
418         return false;
419       }
420     }
421 
createDigest(string & password,string & msg)422     string createDigest(string &password, string& msg) {
423 #if OPENSSL_VERSION_NUMBER < 0x10100000L
424       HMAC_CTX ctx;
425       unsigned char digest[EVP_MAX_MD_SIZE];
426       HMAC_Init(&ctx, (const unsigned char *)password.c_str(),
427           password.length(), EVP_sha1());
428       HMAC_Update(&ctx, (const unsigned char *)msg.c_str(), msg.length());
429       unsigned int digestLen;
430       HMAC_Final(&ctx, digest, &digestLen);
431       HMAC_cleanup(&ctx);
432 #else
433       HMAC_CTX *ctx = HMAC_CTX_new();
434       unsigned char digest[EVP_MAX_MD_SIZE];
435       HMAC_Init_ex(ctx, (const unsigned char *)password.c_str(),
436           password.length(), EVP_sha1(), NULL);
437       HMAC_Update(ctx, (const unsigned char *)msg.c_str(), msg.length());
438       unsigned int digestLen;
439       HMAC_Final(ctx, digest, &digestLen);
440       HMAC_CTX_free(ctx);
441 #endif
442       //now apply base64 encoding
443       BIO *bmem, *b64;
444       BUF_MEM *bptr;
445 
446       b64 = BIO_new(BIO_f_base64());
447       bmem = BIO_new(BIO_s_mem());
448       b64 = BIO_push(b64, bmem);
449       BIO_write(b64, digest, digestLen);
450       BIO_flush(b64);
451       BIO_get_mem_ptr(b64, &bptr);
452 
453       char digestBuffer[bptr->length];
454       memcpy(digestBuffer, bptr->data, bptr->length-1);
455       digestBuffer[bptr->length-1] = 0;
456       BIO_free_all(b64);
457 
458       return string(digestBuffer);
459     }
460 
461   public:
BinaryProtocol(FILE * down,DownwardProtocol * _handler,FILE * up)462     BinaryProtocol(FILE* down, DownwardProtocol* _handler, FILE* up) {
463       downStream = new FileInStream();
464       downStream->open(down);
465       uplink = new BinaryUpwardProtocol(up);
466       handler = _handler;
467       authDone = false;
468       getPassword(password);
469     }
470 
getUplink()471     UpwardProtocol* getUplink() {
472       return uplink;
473     }
474 
nextEvent()475     virtual void nextEvent() {
476       int32_t cmd;
477       cmd = deserializeInt(*downStream);
478       if (!authDone && cmd != AUTHENTICATION_REQ) {
479         //Authentication request must be the first message if
480         //authentication is not complete
481         std::cerr << "Command:" << cmd << "received before authentication. "
482             << "Exiting.." << std::endl;
483         exit(-1);
484       }
485       switch (cmd) {
486       case AUTHENTICATION_REQ: {
487         string digest;
488         string challenge;
489         deserializeString(digest, *downStream);
490         deserializeString(challenge, *downStream);
491         verifyDigestAndRespond(digest, challenge);
492         break;
493       }
494       case START_MESSAGE: {
495         int32_t prot;
496         prot = deserializeInt(*downStream);
497         handler->start(prot);
498         break;
499       }
500       case SET_JOB_CONF: {
501         int32_t entries;
502         entries = deserializeInt(*downStream);
503         vector<string> result(entries);
504         for(int i=0; i < entries; ++i) {
505           string item;
506           deserializeString(item, *downStream);
507           result.push_back(item);
508         }
509         handler->setJobConf(result);
510         break;
511       }
512       case SET_INPUT_TYPES: {
513         string keyType;
514         string valueType;
515         deserializeString(keyType, *downStream);
516         deserializeString(valueType, *downStream);
517         handler->setInputTypes(keyType, valueType);
518         break;
519       }
520       case RUN_MAP: {
521         string split;
522         int32_t numReduces;
523         int32_t piped;
524         deserializeString(split, *downStream);
525         numReduces = deserializeInt(*downStream);
526         piped = deserializeInt(*downStream);
527         handler->runMap(split, numReduces, piped);
528         break;
529       }
530       case MAP_ITEM: {
531         deserializeString(key, *downStream);
532         deserializeString(value, *downStream);
533         handler->mapItem(key, value);
534         break;
535       }
536       case RUN_REDUCE: {
537         int32_t reduce;
538         int32_t piped;
539         reduce = deserializeInt(*downStream);
540         piped = deserializeInt(*downStream);
541         handler->runReduce(reduce, piped);
542         break;
543       }
544       case REDUCE_KEY: {
545         deserializeString(key, *downStream);
546         handler->reduceKey(key);
547         break;
548       }
549       case REDUCE_VALUE: {
550         deserializeString(value, *downStream);
551         handler->reduceValue(value);
552         break;
553       }
554       case CLOSE:
555         handler->close();
556         break;
557       case ABORT:
558         handler->abort();
559         break;
560       default:
561         HADOOP_ASSERT(false, "Unknown binary command " + toString(cmd));
562       }
563     }
564 
~BinaryProtocol()565     virtual ~BinaryProtocol() {
566       delete downStream;
567       delete uplink;
568     }
569   };
570 
571   /**
572    * Define a context object to give to combiners that will let them
573    * go through the values and emit their results correctly.
574    */
575   class CombineContext: public ReduceContext {
576   private:
577     ReduceContext* baseContext;
578     Partitioner* partitioner;
579     int numReduces;
580     UpwardProtocol* uplink;
581     bool firstKey;
582     bool firstValue;
583     map<string, vector<string> >::iterator keyItr;
584     map<string, vector<string> >::iterator endKeyItr;
585     vector<string>::iterator valueItr;
586     vector<string>::iterator endValueItr;
587 
588   public:
CombineContext(ReduceContext * _baseContext,Partitioner * _partitioner,int _numReduces,UpwardProtocol * _uplink,map<string,vector<string>> & data)589     CombineContext(ReduceContext* _baseContext,
590                    Partitioner* _partitioner,
591                    int _numReduces,
592                    UpwardProtocol* _uplink,
593                    map<string, vector<string> >& data) {
594       baseContext = _baseContext;
595       partitioner = _partitioner;
596       numReduces = _numReduces;
597       uplink = _uplink;
598       keyItr = data.begin();
599       endKeyItr = data.end();
600       firstKey = true;
601       firstValue = true;
602     }
603 
getJobConf()604     virtual const JobConf* getJobConf() {
605       return baseContext->getJobConf();
606     }
607 
getInputKey()608     virtual const std::string& getInputKey() {
609       return keyItr->first;
610     }
611 
getInputValue()612     virtual const std::string& getInputValue() {
613       return *valueItr;
614     }
615 
emit(const std::string & key,const std::string & value)616     virtual void emit(const std::string& key, const std::string& value) {
617       if (partitioner != NULL) {
618         uplink->partitionedOutput(partitioner->partition(key, numReduces),
619                                   key, value);
620       } else {
621         uplink->output(key, value);
622       }
623     }
624 
progress()625     virtual void progress() {
626       baseContext->progress();
627     }
628 
setStatus(const std::string & status)629     virtual void setStatus(const std::string& status) {
630       baseContext->setStatus(status);
631     }
632 
nextKey()633     bool nextKey() {
634       if (firstKey) {
635         firstKey = false;
636       } else {
637         ++keyItr;
638       }
639       if (keyItr != endKeyItr) {
640         valueItr = keyItr->second.begin();
641         endValueItr = keyItr->second.end();
642         firstValue = true;
643         return true;
644       }
645       return false;
646     }
647 
nextValue()648     virtual bool nextValue() {
649       if (firstValue) {
650         firstValue = false;
651       } else {
652         ++valueItr;
653       }
654       return valueItr != endValueItr;
655     }
656 
getCounter(const std::string & group,const std::string & name)657     virtual Counter* getCounter(const std::string& group,
658                                const std::string& name) {
659       return baseContext->getCounter(group, name);
660     }
661 
incrementCounter(const Counter * counter,uint64_t amount)662     virtual void incrementCounter(const Counter* counter, uint64_t amount) {
663       baseContext->incrementCounter(counter, amount);
664     }
665   };
666 
667   /**
668    * A RecordWriter that will take the map outputs, buffer them up and then
669    * combine then when the buffer is full.
670    */
671   class CombineRunner: public RecordWriter {
672   private:
673     map<string, vector<string> > data;
674     int64_t spillSize;
675     int64_t numBytes;
676     ReduceContext* baseContext;
677     Partitioner* partitioner;
678     int numReduces;
679     UpwardProtocol* uplink;
680     Reducer* combiner;
681   public:
CombineRunner(int64_t _spillSize,ReduceContext * _baseContext,Reducer * _combiner,UpwardProtocol * _uplink,Partitioner * _partitioner,int _numReduces)682     CombineRunner(int64_t _spillSize, ReduceContext* _baseContext,
683                   Reducer* _combiner, UpwardProtocol* _uplink,
684                   Partitioner* _partitioner, int _numReduces) {
685       numBytes = 0;
686       spillSize = _spillSize;
687       baseContext = _baseContext;
688       partitioner = _partitioner;
689       numReduces = _numReduces;
690       uplink = _uplink;
691       combiner = _combiner;
692     }
693 
emit(const std::string & key,const std::string & value)694     virtual void emit(const std::string& key,
695                       const std::string& value) {
696       numBytes += key.length() + value.length();
697       data[key].push_back(value);
698       if (numBytes >= spillSize) {
699         spillAll();
700       }
701     }
702 
close()703     virtual void close() {
704       spillAll();
705     }
706 
707   private:
spillAll()708     void spillAll() {
709       CombineContext context(baseContext, partitioner, numReduces,
710                              uplink, data);
711       while (context.nextKey()) {
712         combiner->reduce(context);
713       }
714       data.clear();
715       numBytes = 0;
716     }
717   };
718 
719   class TaskContextImpl: public MapContext, public ReduceContext,
720                          public DownwardProtocol {
721   private:
722     bool done;
723     JobConf* jobConf;
724     string key;
725     const string* newKey;
726     const string* value;
727     bool hasTask;
728     bool isNewKey;
729     bool isNewValue;
730     string* inputKeyClass;
731     string* inputValueClass;
732     string status;
733     float progressFloat;
734     uint64_t lastProgress;
735     bool statusSet;
736     Protocol* protocol;
737     UpwardProtocol *uplink;
738     string* inputSplit;
739     RecordReader* reader;
740     Mapper* mapper;
741     Reducer* reducer;
742     RecordWriter* writer;
743     Partitioner* partitioner;
744     int numReduces;
745     const Factory* factory;
746     pthread_mutex_t mutexDone;
747     std::vector<int> registeredCounterIds;
748 
749   public:
750 
TaskContextImpl(const Factory & _factory)751     TaskContextImpl(const Factory& _factory) {
752       statusSet = false;
753       done = false;
754       newKey = NULL;
755       factory = &_factory;
756       jobConf = NULL;
757       inputKeyClass = NULL;
758       inputValueClass = NULL;
759       inputSplit = NULL;
760       mapper = NULL;
761       reducer = NULL;
762       reader = NULL;
763       writer = NULL;
764       partitioner = NULL;
765       protocol = NULL;
766       isNewKey = false;
767       isNewValue = false;
768       lastProgress = 0;
769       progressFloat = 0.0f;
770       hasTask = false;
771       pthread_mutex_init(&mutexDone, NULL);
772     }
773 
setProtocol(Protocol * _protocol,UpwardProtocol * _uplink)774     void setProtocol(Protocol* _protocol, UpwardProtocol* _uplink) {
775 
776       protocol = _protocol;
777       uplink = _uplink;
778     }
779 
start(int protocol)780     virtual void start(int protocol) {
781       if (protocol != 0) {
782         throw Error("Protocol version " + toString(protocol) +
783                     " not supported");
784       }
785     }
786 
setJobConf(vector<string> values)787     virtual void setJobConf(vector<string> values) {
788       int len = values.size();
789       JobConfImpl* result = new JobConfImpl();
790       HADOOP_ASSERT(len % 2 == 0, "Odd length of job conf values");
791       for(int i=0; i < len; i += 2) {
792         result->set(values[i], values[i+1]);
793       }
794       jobConf = result;
795     }
796 
    /**
     * Record the input key and value type strings for this task.
     * NOTE(review): a second call would leak the strings allocated by the
     * first. Presumably the driver sends this at most once; confirm.
     */
    virtual void setInputTypes(string keyType, string valueType) {
      inputKeyClass = new string(keyType);
      inputValueClass = new string(valueType);
    }
801 
runMap(string _inputSplit,int _numReduces,bool pipedInput)802     virtual void runMap(string _inputSplit, int _numReduces, bool pipedInput) {
803       inputSplit = new string(_inputSplit);
804       reader = factory->createRecordReader(*this);
805       HADOOP_ASSERT((reader == NULL) == pipedInput,
806                     pipedInput ? "RecordReader defined when not needed.":
807                     "RecordReader not defined");
808       if (reader != NULL) {
809         value = new string();
810       }
811       mapper = factory->createMapper(*this);
812       numReduces = _numReduces;
813       if (numReduces != 0) {
814         reducer = factory->createCombiner(*this);
815         partitioner = factory->createPartitioner(*this);
816       }
817       if (reducer != NULL) {
818         int64_t spillSize = 100;
819         if (jobConf->hasKey("mapreduce.task.io.sort.mb")) {
820           spillSize = jobConf->getInt("mapreduce.task.io.sort.mb");
821         }
822         writer = new CombineRunner(spillSize * 1024 * 1024, this, reducer,
823                                    uplink, partitioner, numReduces);
824       }
825       hasTask = true;
826     }
827 
mapItem(const string & _key,const string & _value)828     virtual void mapItem(const string& _key, const string& _value) {
829       newKey = &_key;
830       value = &_value;
831       isNewKey = true;
832     }
833 
runReduce(int reduce,bool pipedOutput)834     virtual void runReduce(int reduce, bool pipedOutput) {
835       reducer = factory->createReducer(*this);
836       writer = factory->createRecordWriter(*this);
837       HADOOP_ASSERT((writer == NULL) == pipedOutput,
838                     pipedOutput ? "RecordWriter defined when not needed.":
839                     "RecordWriter not defined");
840       hasTask = true;
841     }
842 
reduceKey(const string & _key)843     virtual void reduceKey(const string& _key) {
844       isNewKey = true;
845       newKey = &_key;
846     }
847 
reduceValue(const string & _value)848     virtual void reduceValue(const string& _value) {
849       isNewValue = true;
850       value = &_value;
851     }
852 
isDone()853     virtual bool isDone() {
854       pthread_mutex_lock(&mutexDone);
855       bool doneCopy = done;
856       pthread_mutex_unlock(&mutexDone);
857       return doneCopy;
858     }
859 
close()860     virtual void close() {
861       pthread_mutex_lock(&mutexDone);
862       done = true;
863       pthread_mutex_unlock(&mutexDone);
864     }
865 
abort()866     virtual void abort() {
867       throw Error("Aborted by driver");
868     }
869 
waitForTask()870     void waitForTask() {
871       while (!done && !hasTask) {
872         protocol->nextEvent();
873       }
874     }
875 
nextKey()876     bool nextKey() {
877       if (reader == NULL) {
878         while (!isNewKey) {
879           nextValue();
880           if (done) {
881             return false;
882           }
883         }
884         key = *newKey;
885       } else {
886         if (!reader->next(key, const_cast<string&>(*value))) {
887           pthread_mutex_lock(&mutexDone);
888           done = true;
889           pthread_mutex_unlock(&mutexDone);
890           return false;
891         }
892         progressFloat = reader->getProgress();
893       }
894       isNewKey = false;
895       if (mapper != NULL) {
896         mapper->map(*this);
897       } else {
898         reducer->reduce(*this);
899       }
900       return true;
901     }
902 
903     /**
904      * Advance to the next value.
905      */
nextValue()906     virtual bool nextValue() {
907       if (isNewKey || done) {
908         return false;
909       }
910       isNewValue = false;
911       progress();
912       protocol->nextEvent();
913       return isNewValue;
914     }
915 
916     /**
917      * Get the JobConf for the current task.
918      */
getJobConf()919     virtual JobConf* getJobConf() {
920       return jobConf;
921     }
922 
923     /**
924      * Get the current key.
925      * @return the current key or NULL if called before the first map or reduce
926      */
getInputKey()927     virtual const string& getInputKey() {
928       return key;
929     }
930 
931     /**
932      * Get the current value.
933      * @return the current value or NULL if called before the first map or
934      *    reduce
935      */
getInputValue()936     virtual const string& getInputValue() {
937       return *value;
938     }
939 
940     /**
941      * Mark your task as having made progress without changing the status
942      * message.
943      */
progress()944     virtual void progress() {
945       if (uplink != 0) {
946         uint64_t now = getCurrentMillis();
947         if (now - lastProgress > 1000) {
948           lastProgress = now;
949           if (statusSet) {
950             uplink->status(status);
951             statusSet = false;
952           }
953           uplink->progress(progressFloat);
954         }
955       }
956     }
957 
958     /**
959      * Set the status message and call progress.
960      */
setStatus(const string & status)961     virtual void setStatus(const string& status) {
962       this->status = status;
963       statusSet = true;
964       progress();
965     }
966 
967     /**
968      * Get the name of the key class of the input to this task.
969      */
getInputKeyClass()970     virtual const string& getInputKeyClass() {
971       return *inputKeyClass;
972     }
973 
974     /**
975      * Get the name of the value class of the input to this task.
976      */
getInputValueClass()977     virtual const string& getInputValueClass() {
978       return *inputValueClass;
979     }
980 
981     /**
982      * Access the InputSplit of the mapper.
983      */
getInputSplit()984     virtual const std::string& getInputSplit() {
985       return *inputSplit;
986     }
987 
emit(const string & key,const string & value)988     virtual void emit(const string& key, const string& value) {
989       progress();
990       if (writer != NULL) {
991         writer->emit(key, value);
992       } else if (partitioner != NULL) {
993         int part = partitioner->partition(key, numReduces);
994         uplink->partitionedOutput(part, key, value);
995       } else {
996         uplink->output(key, value);
997       }
998     }
999 
1000     /**
1001      * Register a counter with the given group and name.
1002      */
getCounter(const std::string & group,const std::string & name)1003     virtual Counter* getCounter(const std::string& group,
1004                                const std::string& name) {
1005       int id = registeredCounterIds.size();
1006       registeredCounterIds.push_back(id);
1007       uplink->registerCounter(id, group, name);
1008       return new Counter(id);
1009     }
1010 
1011     /**
1012      * Increment the value of the counter with the given amount.
1013      */
incrementCounter(const Counter * counter,uint64_t amount)1014     virtual void incrementCounter(const Counter* counter, uint64_t amount) {
1015       uplink->incrementCounter(counter, amount);
1016     }
1017 
closeAll()1018     void closeAll() {
1019       if (reader) {
1020         reader->close();
1021       }
1022       if (mapper) {
1023         mapper->close();
1024       }
1025       if (reducer) {
1026         reducer->close();
1027       }
1028       if (writer) {
1029         writer->close();
1030       }
1031     }
1032 
/**
 * Release everything this context owns and destroy mutexDone.
 * value is only freed when a RecordReader exists. Only then was it
 * allocated by this context; otherwise it points at strings passed to
 * mapItem()/reduceValue().
 */
virtual ~TaskContextImpl() {
  // Configuration and task description strings.
  delete jobConf;
  delete inputKeyClass;
  delete inputValueClass;
  delete inputSplit;
  // Value buffer, owned only in RecordReader mode.
  if (reader != NULL) {
    delete value;
  }
  // Task components, released in their original order.
  delete reader;
  delete mapper;
  delete reducer;
  delete writer;
  delete partitioner;
  pthread_mutex_destroy(&mutexDone);
}
1048   };
1049 
1050   /**
1051    * Ping the parent every 5 seconds to know if it is alive
1052    */
ping(void * ptr)1053   void* ping(void* ptr) {
1054     TaskContextImpl* context = (TaskContextImpl*) ptr;
1055     char* portStr = getenv("mapreduce.pipes.command.port");
1056     int MAX_RETRIES = 3;
1057     int remaining_retries = MAX_RETRIES;
1058     while (!context->isDone()) {
1059       try{
1060         sleep(5);
1061         int sock = -1;
1062         if (portStr) {
1063           sock = socket(PF_INET, SOCK_STREAM, 0);
1064           HADOOP_ASSERT(sock != - 1,
1065                         string("problem creating socket: ") + strerror(errno));
1066           sockaddr_in addr;
1067           addr.sin_family = AF_INET;
1068           addr.sin_port = htons(toInt(portStr));
1069           addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
1070           HADOOP_ASSERT(connect(sock, (sockaddr*) &addr, sizeof(addr)) == 0,
1071                         string("problem connecting command socket: ") +
1072                         strerror(errno));
1073 
1074         }
1075         if (sock != -1) {
1076           int result = shutdown(sock, SHUT_RDWR);
1077           HADOOP_ASSERT(result == 0, "problem shutting socket");
1078           result = close(sock);
1079           HADOOP_ASSERT(result == 0, "problem closing socket");
1080         }
1081         remaining_retries = MAX_RETRIES;
1082       } catch (Error& err) {
1083         if (!context->isDone()) {
1084           fprintf(stderr, "Hadoop Pipes Exception: in ping %s\n",
1085                 err.getMessage().c_str());
1086           remaining_retries -= 1;
1087           if (remaining_retries == 0) {
1088             exit(1);
1089           }
1090         } else {
1091           return NULL;
1092         }
1093       }
1094     }
1095     return NULL;
1096   }
1097 
1098   /**
1099    * Run the assigned task in the framework.
1100    * The user's main function should set the various functions using the
1101    * set* functions above and then call this.
1102    * @return true, if the task succeeded.
1103    */
runTask(const Factory & factory)1104   bool runTask(const Factory& factory) {
1105     try {
1106       TaskContextImpl* context = new TaskContextImpl(factory);
1107       Protocol* connection;
1108       char* portStr = getenv("mapreduce.pipes.command.port");
1109       int sock = -1;
1110       FILE* stream = NULL;
1111       FILE* outStream = NULL;
1112       char *bufin = NULL;
1113       char *bufout = NULL;
1114       if (portStr) {
1115         sock = socket(PF_INET, SOCK_STREAM, 0);
1116         HADOOP_ASSERT(sock != - 1,
1117                       string("problem creating socket: ") + strerror(errno));
1118         sockaddr_in addr;
1119         addr.sin_family = AF_INET;
1120         addr.sin_port = htons(toInt(portStr));
1121         addr.sin_addr.s_addr = htonl(INADDR_LOOPBACK);
1122         HADOOP_ASSERT(connect(sock, (sockaddr*) &addr, sizeof(addr)) == 0,
1123                       string("problem connecting command socket: ") +
1124                       strerror(errno));
1125 
1126         stream = fdopen(sock, "r");
1127         outStream = fdopen(sock, "w");
1128 
1129         // increase buffer size
1130         int bufsize = 128*1024;
1131         int setbuf;
1132         bufin = new char[bufsize];
1133         bufout = new char[bufsize];
1134         setbuf = setvbuf(stream, bufin, _IOFBF, bufsize);
1135         HADOOP_ASSERT(setbuf == 0, string("problem with setvbuf for inStream: ")
1136                                      + strerror(errno));
1137         setbuf = setvbuf(outStream, bufout, _IOFBF, bufsize);
1138         HADOOP_ASSERT(setbuf == 0, string("problem with setvbuf for outStream: ")
1139                                      + strerror(errno));
1140         connection = new BinaryProtocol(stream, context, outStream);
1141       } else if (getenv("mapreduce.pipes.commandfile")) {
1142         char* filename = getenv("mapreduce.pipes.commandfile");
1143         string outFilename = filename;
1144         outFilename += ".out";
1145         stream = fopen(filename, "r");
1146         outStream = fopen(outFilename.c_str(), "w");
1147         connection = new BinaryProtocol(stream, context, outStream);
1148       } else {
1149         connection = new TextProtocol(stdin, context, stdout);
1150       }
1151       context->setProtocol(connection, connection->getUplink());
1152       pthread_t pingThread;
1153       pthread_create(&pingThread, NULL, ping, (void*)(context));
1154       context->waitForTask();
1155       while (!context->isDone()) {
1156         context->nextKey();
1157       }
1158       context->closeAll();
1159       connection->getUplink()->done();
1160       pthread_join(pingThread,NULL);
1161       delete context;
1162       delete connection;
1163       if (stream != NULL) {
1164         fflush(stream);
1165       }
1166       if (outStream != NULL) {
1167         fflush(outStream);
1168       }
1169       fflush(stdout);
1170       if (sock != -1) {
1171         int result = shutdown(sock, SHUT_RDWR);
1172         HADOOP_ASSERT(result == 0, "problem shutting socket");
1173         result = close(sock);
1174         HADOOP_ASSERT(result == 0, "problem closing socket");
1175       }
1176       if (stream != NULL) {
1177         //fclose(stream);
1178       }
1179       if (outStream != NULL) {
1180         //fclose(outStream);
1181       }
1182       delete bufin;
1183       delete bufout;
1184       return true;
1185     } catch (Error& err) {
1186       fprintf(stderr, "Hadoop Pipes Exception: %s\n",
1187               err.getMessage().c_str());
1188       return false;
1189     }
1190   }
1191 }
1192 
1193