1 //===- ModelUnderTrainingRunner.cpp - 'development' mode runner -----------===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
// Implementation of an MLModelRunner for 'development' mode, i.e. evaluation
// is performed by interpreting a model whose path is given on the command line.
11 //
12 //===----------------------------------------------------------------------===//
13 
#include "llvm/ADT/STLExtras.h"
#include "llvm/Config/config.h"
#if defined(LLVM_HAVE_TFLITE)
#include "llvm/Analysis/ModelUnderTrainingRunner.h"
#include "llvm/Support/Error.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include <optional>
21 
22 using namespace llvm;
23 namespace {
24 struct LoggedFeatureSpec {
25   TensorSpec Spec;
26   std::optional<std::string> LoggingName;
27 };
28 
29 std::optional<std::vector<LoggedFeatureSpec>>
30 loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
31                 StringRef ModelPath, StringRef SpecFileOverride) {
32   SmallVector<char, 128> OutputSpecsPath;
33   StringRef FileName = SpecFileOverride;
34   if (FileName.empty()) {
35     llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
36     FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
37   }
38 
39   auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
40   if (!BufferOrError) {
41     Ctx.emitError("Error opening output specs file: " + FileName + " : " +
42                   BufferOrError.getError().message());
43     return std::nullopt;
44   }
45   auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
46   if (!ParsedJSONValues) {
47     Ctx.emitError("Could not parse specs file: " + FileName);
48     return std::nullopt;
49   }
50   auto ValuesArray = ParsedJSONValues->getAsArray();
51   if (!ValuesArray) {
52     Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
53                   "logging_name:<name>} dictionaries");
54     return std::nullopt;
55   }
56   std::vector<LoggedFeatureSpec> Ret;
57   for (const auto &Value : *ValuesArray)
58     if (const auto *Obj = Value.getAsObject())
59       if (const auto *SpecPart = Obj->get("tensor_spec"))
60         if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
61           if (auto LoggingName = Obj->getString("logging_name")) {
62             if (!TensorSpec->isElementType<int64_t>() &&
63                 !TensorSpec->isElementType<int32_t>() &&
64                 !TensorSpec->isElementType<float>()) {
65               Ctx.emitError(
66                   "Only int64, int32, and float tensors are supported. "
67                   "Found unsupported type for tensor named " +
68                   TensorSpec->name());
69               return std::nullopt;
70             }
71             Ret.push_back({*TensorSpec, LoggingName->str()});
72           }
73 
74   if (ValuesArray->size() != Ret.size()) {
75     Ctx.emitError(
76         "Unable to parse output spec. It should be a json file containing an "
77         "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
78         "with a json object describing a TensorSpec; and a 'logging_name' key, "
79         "which is a string to use as name when logging this tensor in the "
80         "training log.");
81     return std::nullopt;
82   }
83   if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
84     Ctx.emitError("The first output spec must describe the decision tensor, "
85                   "and must have the logging_name " +
86                   StringRef(ExpectedDecisionName));
87     return std::nullopt;
88   }
89   return Ret;
90 }
91 } // namespace
92 
93 ModelUnderTrainingRunner::ModelUnderTrainingRunner(
94     LLVMContext &Ctx, const std::string &ModelPath,
95     const std::vector<TensorSpec> &InputSpecs,
96     const std::vector<TensorSpec> &OutputSpecs,
97     const std::vector<TensorSpec> &ExtraOutputsForLogging)
98     : MLModelRunner(Ctx, MLModelRunner::Kind::Development, InputSpecs.size()),
99       OutputSpecs(OutputSpecs), ExtraOutputsForLogging(ExtraOutputsForLogging) {
100   Evaluator =
101       std::make_unique<TFModelEvaluator>(ModelPath, InputSpecs, OutputSpecs);
102   if (!Evaluator || !Evaluator->isValid()) {
103     Ctx.emitError("Failed to create saved model evaluator");
104     Evaluator.reset();
105     return;
106   }
107 
108   for (size_t I = 0, E = InputSpecs.size(); I < E; ++I) {
109     setUpBufferForTensor(I, InputSpecs[I], Evaluator->getUntypedInput(I));
110   }
111 }
112 
113 void *ModelUnderTrainingRunner::evaluateUntyped() {
114   LastEvaluationResult = Evaluator->evaluate();
115   if (!LastEvaluationResult.has_value()) {
116     Ctx.emitError("Error evaluating model.");
117     return nullptr;
118   }
119   return LastEvaluationResult->getUntypedTensorValue(0);
120 }
121 
122 std::unique_ptr<ModelUnderTrainingRunner>
123 ModelUnderTrainingRunner::createAndEnsureValid(
124     LLVMContext &Ctx, const std::string &ModelPath, StringRef DecisionName,
125     const std::vector<TensorSpec> &InputSpecs,
126     StringRef OutputSpecsPathOverride) {
127   if (auto MaybeOutputSpecs = loadOutputSpecs(Ctx, DecisionName, ModelPath,
128                                               OutputSpecsPathOverride)) {
129     std::unique_ptr<ModelUnderTrainingRunner> MUTR;
130     std::vector<TensorSpec> OutputSpecs;
131     std::vector<TensorSpec> ExtraOutputsForLogging;
132     append_range(OutputSpecs,
133                  map_range(*MaybeOutputSpecs, [](const LoggedFeatureSpec &LFS) {
134                    return LFS.Spec;
135                  }));
136     append_range(ExtraOutputsForLogging,
137                  map_range(drop_begin(*MaybeOutputSpecs),
138                            [](const LoggedFeatureSpec &LFS) {
139                              return TensorSpec(LFS.LoggingName
140                                                    ? *LFS.LoggingName
141                                                    : LFS.Spec.name(),
142                                                LFS.Spec);
143                            }));
144 
145     MUTR.reset(new ModelUnderTrainingRunner(
146         Ctx, ModelPath, InputSpecs, OutputSpecs, ExtraOutputsForLogging));
147     if (MUTR && MUTR->isValid())
148       return MUTR;
149 
150     Ctx.emitError("Could not load or create model evaluator.");
151     return nullptr;
152   }
153   Ctx.emitError("Could not load the policy model from the provided path");
154   return nullptr;
155 }
156 
157 #endif // defined(LLVM_HAVE_TFLITE)
158