1 //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 11 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 12 13 #include "llvm/ADT/STLExtras.h" 14 #include "llvm/ADT/iterator_range.h" 15 #include "llvm/Analysis/TensorSpec.h" 16 #include "llvm/Config/llvm-config.h" 17 18 #ifdef LLVM_HAVE_TFLITE 19 #include "llvm/Analysis/MLModelRunner.h" 20 #include "llvm/Analysis/Utils/TFUtils.h" 21 #include "llvm/IR/LLVMContext.h" 22 #include "llvm/IR/PassManager.h" 23 24 namespace llvm { 25 26 /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs 27 /// to dynamically load and evaluate a TF SavedModel 28 /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is 29 /// sacrificed for ease of use while training. 30 class ModelUnderTrainingRunner final : public MLModelRunner { 31 public: 32 // Disallows copy and assign. 33 ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete; 34 ModelUnderTrainingRunner & 35 operator=(const ModelUnderTrainingRunner &) = delete; 36 37 const std::vector<TensorSpec> &extraOutputsForLoggingSpecs() const { 38 return ExtraOutputsForLogging; 39 } 40 41 const void *getUntypedExtraOutputValue(size_t ExtraOutputIndex) const { 42 return lastEvaluationResult()->getUntypedTensorValue(ExtraOutputIndex + 1); 43 } 44 45 const std::optional<TFModelEvaluator::EvaluationResult> & 46 lastEvaluationResult() const { 47 return LastEvaluationResult; 48 } 49 static bool classof(const MLModelRunner *R) { 50 return R->getKind() == MLModelRunner::Kind::Development; 51 } 52 53 static std::unique_ptr<ModelUnderTrainingRunner> 54 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, 55 StringRef DecisionName, 56 const std::vector<TensorSpec> &InputSpecs, 57 StringRef OutputSpecsPathOverride = ""); 58 59 ModelUnderTrainingRunner( 60 LLVMContext &Ctx, const std::string &ModelPath, 61 const std::vector<TensorSpec> &InputSpecs, 62 const std::vector<TensorSpec> &OutputSpecs, 63 const std::vector<TensorSpec> &ExtraOutputsForLogging = {}); 64 65 bool isValid() const { return !!Evaluator; } 66 67 private: 68 std::unique_ptr<TFModelEvaluator> Evaluator; 69 const std::vector<TensorSpec> OutputSpecs; 70 const std::vector<TensorSpec> ExtraOutputsForLogging; 71 std::optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult; 72 void *evaluateUntyped() override; 73 }; 74 75 } // namespace llvm 76 #endif // define(LLVM_HAVE_TFLITE) 77 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 78