//===- TFUtils.cpp - tensorflow evaluation utilities ----------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// This file implements utilities for interfacing with tensorflow C APIs.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TF_API)

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/Utils/TFUtils.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"

#include "google/protobuf/text_format.h"
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/c_api_experimental.h"
#include "tensorflow/core/example/example.pb.h"
#include <cassert>
#include <cstring>
#include <numeric>

using namespace llvm;

using google::protobuf::Message;
using google::protobuf::TextFormat;

static cl::opt<bool>
    ProtobufTextMode("tfutils-text-log", cl::init(false), cl::Hidden,
                     cl::desc("Output textual (human-readable) protobuf."));

namespace {

using TFGraphPtr = std::unique_ptr<TF_Graph, decltype(&TF_DeleteGraph)>;
using TFSessionOptionsPtr =
    std::unique_ptr<TF_SessionOptions, decltype(&TF_DeleteSessionOptions)>;
using TFStatusPtr = std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)>;
struct TFInitializer {
  TFInitializer() {
    assert(!IsInitialized && "TFInitializer should only be created once");
    int Argc = 1;
    const char *Name = "";
    const char **NamePtr = &Name;
    TF_InitMain(Name, &Argc, const_cast<char ***>(&NamePtr));
    IsInitialized = true;
  }
  bool IsInitialized = false;
};

llvm::ManagedStatic<TFInitializer> TFLibInitializer;

bool ensureInitTF() { return TFLibInitializer->IsInitialized; }

TFGraphPtr createTFGraph() {
  return TFGraphPtr(TF_NewGraph(), &TF_DeleteGraph);
}

TFStatusPtr createTFStatus() {
  return TFStatusPtr(TF_NewStatus(), &TF_DeleteStatus);
}

TFSessionOptionsPtr createTFSessionOptions() {
  return TFSessionOptionsPtr(TF_NewSessionOptions(), &TF_DeleteSessionOptions);
}
} // namespace

namespace llvm {
class EvaluationResultImpl {
public:
  EvaluationResultImpl(size_t OutputSize)
      : OutputSize(OutputSize), Output(OutputSize) {}

  ~EvaluationResultImpl() {
    for (auto *P : Output)
      if (P)
        TF_DeleteTensor(P);
  }

  EvaluationResultImpl(const EvaluationResultImpl &) = delete;
  EvaluationResultImpl(EvaluationResultImpl &&Other) = delete;
  std::vector<TF_Tensor *> &getOutput() { return Output; }

private:
  const size_t OutputSize;
  std::vector<TF_Tensor *> Output;
};

size_t TensorSpec::getElementByteSize() const {
  return TF_DataTypeSize(static_cast<TF_DataType>(TypeIndex));
}

TensorSpec::TensorSpec(const std::string &Name, int Port, int TypeIndex,
                       const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), TypeIndex(TypeIndex), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), int64_t{1},
                                   std::multiplies<int64_t>())) {}

Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                           const json::Value &Value) {
  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return None;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  TFUTILS_SUPPORTED_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return None;
}
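
// Illustrative example (an assumption, not part of the original source): a
// JSON value getTensorSpecFromJSON accepts. The "type" string must spell one
// of the C++ element type names from TFUTILS_SUPPORTED_TYPES, e.g. "float" or
// "int64_t":
//
//   {"name": "input", "port": 0, "type": "int64_t", "shape": [1, 4]}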

Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return None;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return None;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return None;
  }
  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return None;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return None;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return None;
  }
  return Ret;
}
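
// A sketch (an assumption, not from the original source) of an
// output_spec.json that loadOutputSpecs accepts, for a hypothetical decision
// tensor whose ExpectedDecisionName is "decision"; the first entry must carry
// that name as its logging_name:
//
//   [{"logging_name": "decision",
//     "tensor_spec": {"name": "output", "port": 0, "type": "int64_t",
//                     "shape": [1]}}]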

class TFModelEvaluatorImpl {
public:
  TFModelEvaluatorImpl(StringRef SavedModelPath,
                       const std::vector<TensorSpec> &InputSpecs,
                       function_ref<TensorSpec(size_t)> GetOutputSpecs,
                       size_t OutputSpecsSize, const char *Tags);

  bool isValid() const { return IsValid; }
  size_t OutputSize() const { return OutputFeed.size(); }

  void evaluate(TF_Tensor **Output, TF_Status *Status) {
    TF_SessionRun(Session, nullptr, InputFeed.data(), Input.data(),
                  Input.size(), OutputFeed.data(), Output, OutputFeed.size(),
                  nullptr, 0, nullptr, Status);
  }

  void initInput(size_t Index, TF_DataType Type,
                 const std::vector<int64_t> &Dimensions);
  const std::vector<TF_Tensor *> &getInput() const { return Input; }

  ~TFModelEvaluatorImpl();

private:
  /// The objects necessary for carrying out an evaluation of the SavedModel.
  /// They are expensive to set up, and we maintain them across all the
  /// evaluations of the model.
  TF_Session *Session = nullptr;
  TFGraphPtr Graph;
  TFSessionOptionsPtr Options;

  /// The specification of the input nodes.
  std::vector<TF_Output> InputFeed;

  /// The input tensors. They must match the corresponding InputFeed value by
  /// index. We set up the tensors once and just mutate their scalars before
  /// each evaluation. The input tensors keep their value after an evaluation.
  std::vector<TF_Tensor *> Input;

  /// The specification of the output nodes. When evaluating, the tensors in
  /// the output tensor vector must match the corresponding element of
  /// OutputFeed by index.
  std::vector<TF_Output> OutputFeed;

  void invalidate() { IsValid = false; }

  bool IsValid = true;

  /// Reusable utility for ensuring we can bind the requested Name to a node in
  /// the SavedModel Graph.
  bool checkReportAndInvalidate(const TF_Output &Output,
                                const TensorSpec &OutputSpec);
};

class LoggerDataImpl {
  const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
  const TensorSpec RewardSpec;
  const bool IncludeReward;

  std::vector<tensorflow::FeatureList> FeatureLists;
  tensorflow::FeatureList Reward;

  bool isSelfConsistent(const tensorflow::SequenceExample &SE,
                        size_t NrRecords) const {
    bool Ret = true;
    for (const auto &TSpecs : LoggedFeatureSpecs) {
      const auto &Name = TSpecs.getLoggingName();
      const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
      if (NrRecords != static_cast<size_t>(FL.size())) {
        dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
               << NrRecords << " got " << FL.size() << "\n";
        Ret = false;
      }
    }
    if (IncludeReward && static_cast<size_t>(SE.feature_lists()
                                                 .feature_list()
                                                 .at(RewardSpec.name())
                                                 .feature()
                                                 .size()) != NrRecords) {
      dbgs() << "[TF-UTILS]: reward is missing records.\n";
      Ret = false;
    }
    return Ret;
  }

  void transferLog(tensorflow::SequenceExample &SE) {
    auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
    if (IncludeReward)
      (*FL)[RewardSpec.name()] = std::move(Reward);
    assert(FeatureLists.size() == LoggedFeatureSpecs.size());
    for (size_t I = 0; I < FeatureLists.size(); ++I) {
      const auto &LFS = LoggedFeatureSpecs[I];
      (*FL)[LFS.getLoggingName()] = std::move(FeatureLists[I]);
    }
  }

public:
  LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
                 const TensorSpec &RewardSpec, bool IncludeReward)
      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}

  // Flush the logged info to a stream and clear the log contents.
  void flush(raw_ostream &OS) {
    size_t NrRecords = getNrRecords();
    (void)NrRecords;
    tensorflow::SequenceExample SE;
    transferLog(SE);
    assert(isSelfConsistent(SE, NrRecords));
    std::string OutStr;
    if (ProtobufTextMode)
      google::protobuf::TextFormat::PrintToString(SE, &OutStr);
    else
      OutStr = SE.SerializeAsString();

    OS << OutStr;
  }
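
  // For reference, a flushed tensorflow::SequenceExample has roughly this
  // textual shape (names illustrative), with one feature per logged record:
  //
  //   feature_lists {
  //     feature_list {
  //       key: "some_logging_name"
  //       value { feature { int64_list { value: 1 } }
  //               feature { int64_list { value: 0 } } }
  //     }
  //   }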

  char *addNewTensor(size_t FeatureID) {
    const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
    if (Spec.isElementType<float>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_float_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0.0);
      return reinterpret_cast<char *>(RF->mutable_data());
    } else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
      auto *RF = FeatureLists[FeatureID]
                     .add_feature()
                     ->mutable_int64_list()
                     ->mutable_value();
      RF->Resize(Spec.getElementCount(), 0);
      return reinterpret_cast<char *>(RF->mutable_data());
    }
    llvm_unreachable("Unsupported tensor type.");
  }

  template <typename T> void logReward(T Value) {
    assert(IncludeReward);
    if (RewardSpec.isElementType<float>())
      Reward.add_feature()->mutable_float_list()->add_value(Value);
    else if (RewardSpec.isElementType<int32_t>() ||
             RewardSpec.isElementType<int64_t>())
      Reward.add_feature()->mutable_int64_list()->add_value(Value);
    else
      llvm_unreachable("Unsupported tensor type.");
  }

  size_t getNrRecords() const {
    return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
  }
};
} // namespace llvm

TFModelEvaluatorImpl::TFModelEvaluatorImpl(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags = "serve")
    : Graph(createTFGraph()), Options(createTFSessionOptions()),
      InputFeed(InputSpecs.size()), Input(InputSpecs.size()),
      OutputFeed(OutputSpecsSize) {
  if (!ensureInitTF()) {
    errs() << "TensorFlow should have been initialized";
    return;
  }
  auto Status = createTFStatus();

  Session = TF_LoadSessionFromSavedModel(Options.get(), nullptr,
                                         SavedModelPath.str().c_str(), &Tags, 1,
                                         Graph.get(), nullptr, Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    invalidate();
  }
  for (size_t I = 0; I < InputSpecs.size(); ++I) {
    auto &InputSpec = InputSpecs[I];
    InputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (InputSpec.name()).c_str()),
        InputSpec.port()};
    if (!checkReportAndInvalidate(InputFeed[I], InputSpec))
      return;
    initInput(I, static_cast<TF_DataType>(InputSpec.typeIndex()),
              InputSpec.shape());
  }
  for (size_t I = 0; I < OutputSpecsSize; ++I) {
    auto OutputSpec = GetOutputSpecs(I);
    OutputFeed[I] = {
        TF_GraphOperationByName(Graph.get(), (OutputSpec.name()).c_str()),
        OutputSpec.port()};
    if (!checkReportAndInvalidate(OutputFeed[I], OutputSpec))
      return;
  }
}

TFModelEvaluator::TFModelEvaluator(
    StringRef SavedModelPath, const std::vector<TensorSpec> &InputSpecs,
    function_ref<TensorSpec(size_t)> GetOutputSpecs, size_t OutputSpecsSize,
    const char *Tags)
    : Impl(new TFModelEvaluatorImpl(SavedModelPath, InputSpecs, GetOutputSpecs,
                                    OutputSpecsSize, Tags)) {
  if (!Impl->isValid())
    Impl.reset();
}

TFModelEvaluator::TFModelEvaluator(StringRef SavedModelPath,
                                   const std::vector<TensorSpec> &InputSpecs,
                                   const std::vector<TensorSpec> &OutputSpecs,
                                   const char *Tags)
    : TFModelEvaluator(
          SavedModelPath, InputSpecs, [&](size_t I) { return OutputSpecs[I]; },
          OutputSpecs.size(), Tags) {}
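
// A minimal usage sketch (illustrative; the model path and tensor names are
// assumptions, not from the original source): evaluate a saved model with one
// int64 input named "input" and one int64 output named "output". getInput<T>
// and getTensorValue<T> are the typed accessors declared in TFUtils.h.
//
//   std::vector<TensorSpec> Inputs{
//       TensorSpec::createSpec<int64_t>("input", {1})};
//   std::vector<TensorSpec> Outputs{
//       TensorSpec::createSpec<int64_t>("output", {1})};
//   TFModelEvaluator Evaluator("path/to/saved_model", Inputs, Outputs);
//   if (Evaluator.isValid()) {
//     *Evaluator.getInput<int64_t>(0) = 42;
//     if (auto ER = Evaluator.evaluate()) {
//       int64_t Result = *ER->getTensorValue<int64_t>(0);
//       (void)Result;
//     }
//   }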

TFModelEvaluatorImpl::~TFModelEvaluatorImpl() {
  for (auto *T : Input) {
    TF_DeleteTensor(T);
  }
  if (Session == nullptr)
    return;
  auto Status = createTFStatus();
  TF_DeleteSession(Session, Status.get());
  Session = nullptr;
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK)
    errs() << "Could not delete TF session";
}

bool TFModelEvaluatorImpl::checkReportAndInvalidate(
    const TF_Output &Output, const TensorSpec &OutputSpec) {
  if (Output.oper)
    return true;
  errs() << "Could not find TF_Output named: " + OutputSpec.name();
  IsValid = false;
  return IsValid;
}

Optional<TFModelEvaluator::EvaluationResult> TFModelEvaluator::evaluate() {
  if (!isValid())
    return None;
  std::unique_ptr<EvaluationResultImpl> Ret =
      std::make_unique<EvaluationResultImpl>(Impl->OutputSize());
  auto Status = createTFStatus();
  Impl->evaluate(Ret->getOutput().data(), Status.get());
  if (TF_GetCode(Status.get()) != TF_Code::TF_OK) {
    errs() << TF_Message(Status.get());
    Impl.reset();
    return None;
  }
  return EvaluationResult(std::move(Ret));
}

void TFModelEvaluatorImpl::initInput(size_t Index, TF_DataType Type,
                                     const std::vector<int64_t> &Dimensions) {
  int64_t TotalSize = TF_DataTypeSize(Type);
  for (auto &D : Dimensions)
    TotalSize *= D;

  Input[Index] =
      TF_AllocateTensor(Type, Dimensions.data(), Dimensions.size(), TotalSize);
  std::memset(TF_TensorData(Input[Index]), 0, TotalSize);
}

void *TFModelEvaluator::getUntypedInput(size_t Index) {
  return TF_TensorData(Impl->getInput()[Index]);
}

TFModelEvaluator::EvaluationResult::EvaluationResult(
    std::unique_ptr<EvaluationResultImpl> Impl)
    : Impl(std::move(Impl)) {}

TFModelEvaluator::EvaluationResult::EvaluationResult(EvaluationResult &&Other)
    : Impl(std::move(Other.Impl)) {}

TFModelEvaluator::EvaluationResult &
TFModelEvaluator::EvaluationResult::operator=(EvaluationResult &&Other) {
  Impl = std::move(Other.Impl);
  return *this;
}

void *TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) {
  return TF_TensorData(Impl->getOutput()[Index]);
}

const void *
TFModelEvaluator::EvaluationResult::getUntypedTensorValue(size_t Index) const {
  return TF_TensorData(Impl->getOutput()[Index]);
}

#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> int TensorSpec::getDataType<T>() { return E; }

TFUTILS_SUPPORTED_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

TFModelEvaluator::EvaluationResult::~EvaluationResult() {}
TFModelEvaluator::~TFModelEvaluator() {}

Logger::Logger(const std::vector<LoggedFeatureSpec> &FeatureSpecs,
               const TensorSpec &RewardSpec, bool IncludeReward)
    : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
      IncludeReward(IncludeReward),
      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
                                                  IncludeReward)) {}

Logger::~Logger() {}

#define LOG_REWARD(NAME, TYPE)                                                 \
  void Logger::log##NAME##Reward(TYPE Value) {                                 \
    assert(IncludeReward);                                                     \
    LoggerData->logReward(Value);                                              \
  }

LOG_REWARD(Float, float)
LOG_REWARD(Int32, int32_t)
LOG_REWARD(Int64, int64_t)
#undef LOG_REWARD

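// Logging a "final" reward pads the reward feature list with zeros for every
// record but the last, so it stays aligned with getNrRecords() entries, the
// last of which carries Value.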
#define LOG_FINAL_REWARD(NAME, TYPE)                                           \
  void Logger::log##NAME##FinalReward(TYPE Value) {                            \
    assert(RewardSpec.isElementType<TYPE>());                                  \
    for (size_t I = 1; I < LoggerData->getNrRecords(); ++I)                    \
      log##NAME##Reward(0);                                                    \
    log##NAME##Reward(Value);                                                  \
  }

LOG_FINAL_REWARD(Float, float)
LOG_FINAL_REWARD(Int32, int32_t)
LOG_FINAL_REWARD(Int64, int64_t)
#undef LOG_FINAL_REWARD

void Logger::logFloatValue(size_t FeatureID, const float *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<float>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt64Value(size_t FeatureID, const int64_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int64_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logInt32Value(size_t FeatureID, const int32_t *Value) {
  assert(FeatureSpecs[FeatureID].Spec.isElementType<int32_t>());
  logSpecifiedTensorValue(FeatureID, reinterpret_cast<const char *>(Value));
}

void Logger::logSpecifiedTensorValue(size_t FeatureID, const char *RawData) {
  const auto &Spec = FeatureSpecs[FeatureID].Spec;
  char *Buff = addEntryAndGetFloatOrInt64Buffer(FeatureID);
  if (Spec.isElementType<int32_t>())
    for (size_t I = 0; I < Spec.getElementCount(); ++I)
      (reinterpret_cast<int64_t *>(Buff))[I] =
          static_cast<int64_t>((reinterpret_cast<const int32_t *>(RawData))[I]);
  else if (Spec.isElementType<int64_t>() || Spec.isElementType<float>())
    std::memcpy(Buff, RawData,
                Spec.getElementCount() * Spec.getElementByteSize());
  else
    llvm_unreachable("Unsupported tensor type");
}

char *Logger::addEntryAndGetFloatOrInt64Buffer(size_t FeatureID) {
  return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
}

void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); }
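
// A minimal Logger usage sketch (illustrative; the feature and reward names
// are hypothetical, not from the original source):
//
//   std::vector<LoggedFeatureSpec> Specs{
//       {TensorSpec::createSpec<int64_t>("feature0", {1}), None}};
//   Logger L(Specs, TensorSpec::createSpec<float>("reward", {1}),
//            /*IncludeReward=*/true);
//   int64_t V = 1;
//   L.logInt64Value(0, &V);        // one record for feature index 0
//   L.logFloatFinalReward(3.14f);  // reward for the (single) record
//   L.flush(outs());               // serialized SequenceExample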
#endif // defined(LLVM_HAVE_TF_API)