1 //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 2 // 3 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 4 // See https://llvm.org/LICENSE.txt for license information. 5 // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 6 // 7 //===----------------------------------------------------------------------===// 8 // 9 10 #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 11 #define LLVM_ANALYSIS_MLMODELRUNNER_H 12 13 #include "llvm/Analysis/TensorSpec.h" 14 #include "llvm/IR/PassManager.h" 15 16 namespace llvm { 17 class LLVMContext; 18 19 /// MLModelRunner interface: abstraction of a mechanism for evaluating a 20 /// tensorflow "saved model". 21 /// NOTE: feature indices are expected to be consistent all accross 22 /// MLModelRunners (pertaining to the same model), and also Loggers (see 23 /// TFUtils.h) 24 class MLModelRunner { 25 public: 26 // Disallows copy and assign. 27 MLModelRunner(const MLModelRunner &) = delete; 28 MLModelRunner &operator=(const MLModelRunner &) = delete; 29 virtual ~MLModelRunner() = default; 30 31 template <typename T> T evaluate() { 32 return *reinterpret_cast<T *>(evaluateUntyped()); 33 } 34 35 template <typename T, typename I> T *getTensor(I FeatureID) { 36 return reinterpret_cast<T *>( 37 getTensorUntyped(static_cast<size_t>(FeatureID))); 38 } 39 40 template <typename T, typename I> const T *getTensor(I FeatureID) const { 41 return reinterpret_cast<const T *>( 42 getTensorUntyped(static_cast<size_t>(FeatureID))); 43 } 44 45 void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } 46 const void *getTensorUntyped(size_t Index) const { 47 return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 48 } 49 50 enum class Kind : int { Unknown, Release, Development, NoOp }; 51 Kind getKind() const { return Type; } 52 53 protected: 54 MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) 55 : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { 56 assert(Type != Kind::Unknown); 57 } 58 virtual void *evaluateUntyped() = 0; 59 60 void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 61 void *Buffer) { 62 if (!Buffer) { 63 OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 64 Buffer = OwnedBuffers.back().data(); 65 } 66 InputBuffers[Index] = Buffer; 67 } 68 69 LLVMContext &Ctx; 70 const Kind Type; 71 72 private: 73 std::vector<void *> InputBuffers; 74 std::vector<std::vector<char *>> OwnedBuffers; 75 }; 76 } // namespace llvm 77 78 #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 79