1 //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===//
2 //
3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 // See https://llvm.org/LICENSE.txt for license information.
5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 //
7 //===----------------------------------------------------------------------===//
8 //
9 
10 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
11 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
12 
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Config/llvm-config.h"
#include <cassert>
17 
18 #ifdef LLVM_HAVE_TFLITE
19 #include "llvm/Analysis/MLModelRunner.h"
20 #include "llvm/Analysis/Utils/TFUtils.h"
21 #include "llvm/IR/LLVMContext.h"
22 #include "llvm/IR/PassManager.h"
23 
24 namespace llvm {
25 
/// ModelUnderTrainingRunner - training ('development') mode implementation. It
/// uses TFLite to dynamically load and evaluate a TF SavedModel
/// (https://www.tensorflow.org/guide/saved_model) converted to TFLite. See
/// lib/Analysis/models/saved-model-to-tflite.py. Runtime performance is
/// sacrificed for ease of use while training.
31 class ModelUnderTrainingRunner final : public MLModelRunner {
32 public:
33   // Disallows copy and assign.
34   ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete;
35   ModelUnderTrainingRunner &
36   operator=(const ModelUnderTrainingRunner &) = delete;
37 
extraOutputsForLoggingSpecs()38   const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const {
39     return ExtraOutputsForLogging;
40   }
41 
getUntypedExtraOutputValue(size_t ExtraOutputIndex)42   const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const {
43     return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1);
44   }
45 
46   const std::optional<TFModelEvaluator::EvaluationResult> &
lastEvaluationResult()47   lastEvaluationResult() const {
48     return LastEvaluationResult;
49   }
classof(const MLModelRunner * R)50   static bool classof(const MLModelRunner *R) {
51     return R->getKind() == MLModelRunner::Kind::Development;
52   }
53 
54   static std::unique_ptr<ModelUnderTrainingRunner>
55   createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath,
56                        StringRef DecisionName,
57                        const std::vector<TensorSpec> &InputSpecs,
58                        StringRef OutputSpecsPathOverride = "");
59 
60   ModelUnderTrainingRunner(
61       LLVMContext &Ctx, const std::string &ModelPath,
62       const std::vector<TensorSpec> &InputSpecs,
63       const std::vector<TensorSpec> &OutputSpecs,
64       const std::vector<TensorSpec> &ExtraOutputsForLogging = {});
65 
isValid()66   bool isValid() const { return !!Evaluator; }
67 
68 private:
69   std::unique_ptr<TFModelEvaluator> Evaluator;
70   const std::vector<TensorSpec> OutputSpecs;
71   const std::vector<TensorSpec> ExtraOutputsForLogging;
72   std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult;
73   void *evaluateUntyped() override;
74 };
75 
76 } // namespace llvm
#endif // LLVM_HAVE_TFLITE
78 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H
79