1 //===- ModelUnderTrainingRunner.h -- 'development' mode runner --*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 11 #define LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 12 13 #include "llvm/Analysis/TensorSpec.h" 14 #include "llvm/Config/llvm-config.h" 15 16 #ifdef LLVM_HAVE_TF_API 17 #include "llvm/Analysis/MLModelRunner.h" 18 #include "llvm/Analysis/Utils/TFUtils.h" 19 #include "llvm/IR/LLVMContext.h" 20 #include "llvm/IR/PassManager.h" 21 22 namespace llvm { 23 24 /// ModelUnderTrainingRunner - training mode implementation. It uses TF C APIs 25 /// to dynamically load and evaluate a TF SavedModel 26 /// (https://www.tensorflow.org/guide/saved_model). Runtime performance is 27 /// sacrificed for ease of use while training. 28 class ModelUnderTrainingRunner final : public MLModelRunner { 29 public: 30 // Disallows copy and assign. 31 ModelUnderTrainingRunner(const ModelUnderTrainingRunner &) = delete; 32 ModelUnderTrainingRunner & 33 operator=(const ModelUnderTrainingRunner &) = delete; 34 35 const std::vector<LoggedFeatureSpec> &outputLoggedFeatureSpecs() const { 36 return OutputSpecs; 37 } 38 39 const Optional<TFModelEvaluator::EvaluationResult> & 40 lastEvaluationResult() const { 41 return LastEvaluationResult; 42 } 43 static bool classof(const MLModelRunner *R) { 44 return R->getKind() == MLModelRunner::Kind::Development; 45 } 46 47 static std::unique_ptr<ModelUnderTrainingRunner> 48 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, 49 StringRef DecisionName, 50 const std::vector<TensorSpec> &InputSpecs, 51 StringRef OutputSpecsPathOverride = ""); 52 static std::unique_ptr<ModelUnderTrainingRunner> 53 createAndEnsureValid(LLVMContext &Ctx, const std::string &ModelPath, 54 StringRef DecisionName, 55 const std::vector<TensorSpec> &InputSpecs, 56 const std::vector<LoggedFeatureSpec> &OutputSpecs); 57 58 private: 59 ModelUnderTrainingRunner(LLVMContext &Ctx, const std::string &ModelPath, 60 const std::vector<TensorSpec> &InputSpecs, 61 const std::vector<LoggedFeatureSpec> &OutputSpecs); 62 63 std::unique_ptr<TFModelEvaluator> Evaluator; 64 const std::vector<LoggedFeatureSpec> OutputSpecs; 65 Optional<TFModelEvaluator::EvaluationResult> LastEvaluationResult; 66 void *evaluateUntyped() override; 67 bool isValid() const { return !!Evaluator; } 68 }; 69 70 } // namespace llvm 71 #endif // define(LLVM_HAVE_TF_API) 72 #endif // LLVM_ANALYSIS_MODELUNDERTRAININGRUNNER_H 73