1 //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 11 #define LLVM_ANALYSIS_MLMODELRUNNER_H 12 13 #include "llvm/Analysis/TensorSpec.h" 14 #include "llvm/IR/PassManager.h" 15 16 namespace llvm { 17 class LLVMContext; 18 19 /// MLModelRunner interface: abstraction of a mechanism for evaluating a 20 /// tensorflow "saved model". 21 /// NOTE: feature indices are expected to be consistent all accross 22 /// MLModelRunners (pertaining to the same model), and also Loggers (see 23 /// TFUtils.h) 24 class MLModelRunner { 25 public: 26 // Disallows copy and assign. 27 MLModelRunner(const MLModelRunner &) = delete; 28 MLModelRunner &operator=(const MLModelRunner &) = delete; 29 virtual ~MLModelRunner() = default; 30 31 template <typename T> T evaluate() { 32 return *reinterpret_cast<T *>(evaluateUntyped()); 33 } 34 35 template <typename T, typename I> T *getTensor(I FeatureID) { 36 return reinterpret_cast<T *>( 37 getTensorUntyped(static_cast<size_t>(FeatureID))); 38 } 39 40 template <typename T, typename I> const T *getTensor(I FeatureID) const { 41 return reinterpret_cast<const T *>( 42 getTensorUntyped(static_cast<size_t>(FeatureID))); 43 } 44 45 void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } 46 const void *getTensorUntyped(size_t Index) const { 47 return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 48 } 49 50 enum class Kind : int { Unknown, Release, Development, NoOp, Interactive }; 51 Kind getKind() const { return Type; } 52 virtual void switchContext(StringRef Name) {} 53 54 protected: 55 MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) 56 : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { 57 assert(Type != Kind::Unknown); 58 } 59 virtual void *evaluateUntyped() = 0; 60 61 void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 62 void *Buffer) { 63 if (!Buffer) { 64 OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 65 Buffer = OwnedBuffers.back().data(); 66 } 67 InputBuffers[Index] = Buffer; 68 } 69 70 LLVMContext &Ctx; 71 const Kind Type; 72 73 private: 74 std::vector<void *> InputBuffers; 75 std::vector<std::vector<char *>> OwnedBuffers; 76 }; 77 } // namespace llvm 78 79 #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 80