1 //===- TrainingLogger.cpp - mlgo feature/reward logging -------------------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 // This file implements logging infrastructure for extracting features and
10 // rewards for mlgo policy training.
11 //
12 //===----------------------------------------------------------------------===//
13 #include "llvm/Analysis/TensorSpec.h"
14 #include "llvm/Config/config.h"
15 
16 #include "llvm/ADT/Twine.h"
17 #include "llvm/Analysis/Utils/TrainingLogger.h"
18 #include "llvm/Support/CommandLine.h"
19 #include "llvm/Support/Debug.h"
20 #include "llvm/Support/JSON.h"
21 #include "llvm/Support/MemoryBuffer.h"
22 #include "llvm/Support/Path.h"
23 #include "llvm/Support/raw_ostream.h"
24 
25 #include <cassert>
26 #include <numeric>
27 
28 using namespace llvm;
29 
30 // FIXME(mtrofin): remove the flag altogether
31 static cl::opt<bool>
32     UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
33                     cl::desc("Output simple (non-protobuf) log."));
34 
35 void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
36   json::OStream JOS(*OS);
37   JOS.object([&]() {
38     JOS.attributeArray("features", [&]() {
39       for (const auto &TS : FeatureSpecs)
40         TS.toJSON(JOS);
41     });
42     if (IncludeReward) {
43       JOS.attributeBegin("score");
44       RewardSpec.toJSON(JOS);
45       JOS.attributeEnd();
46     }
47     if (AdviceSpec.has_value()) {
48       JOS.attributeBegin("advice");
49       AdviceSpec->toJSON(JOS);
50       JOS.attributeEnd();
51     }
52   });
53   *OS << "\n";
54 }
55 
56 void Logger::switchContext(StringRef Name) {
57   CurrentContext = Name.str();
58   json::OStream JOS(*OS);
59   JOS.object([&]() { JOS.attribute("context", Name); });
60   *OS << "\n";
61 }
62 
63 void Logger::startObservation() {
64   auto I = ObservationIDs.insert({CurrentContext, 0});
65   size_t NewObservationID = I.second ? 0 : ++I.first->second;
66   json::OStream JOS(*OS);
67   JOS.object([&]() {
68     JOS.attribute("observation", static_cast<int64_t>(NewObservationID));
69   });
70   *OS << "\n";
71 }
72 
73 void Logger::endObservation() { *OS << "\n"; }
74 
75 void Logger::logRewardImpl(const char *RawData) {
76   assert(IncludeReward);
77   json::OStream JOS(*OS);
78   JOS.object([&]() {
79     JOS.attribute("outcome", static_cast<int64_t>(
80                                  ObservationIDs.find(CurrentContext)->second));
81   });
82   *OS << "\n";
83   writeTensor(RewardSpec, RawData);
84   *OS << "\n";
85 }
86 
87 Logger::Logger(std::unique_ptr<raw_ostream> OS,
88                const std::vector<TensorSpec> &FeatureSpecs,
89                const TensorSpec &RewardSpec, bool IncludeReward,
90                std::optional<TensorSpec> AdviceSpec)
91     : OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
92       IncludeReward(IncludeReward) {
93   writeHeader(AdviceSpec);
94 }
95