15ffd83dbSDimitry Andric //===- MLModelRunner.h ---- ML model runner interface -----------*- C++ -*-===// 25ffd83dbSDimitry Andric // 35ffd83dbSDimitry Andric // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. 45ffd83dbSDimitry Andric // See https://llvm.org/LICENSE.txt for license information. 55ffd83dbSDimitry Andric // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception 65ffd83dbSDimitry Andric // 75ffd83dbSDimitry Andric //===----------------------------------------------------------------------===// 85ffd83dbSDimitry Andric // 95ffd83dbSDimitry Andric 105ffd83dbSDimitry Andric #ifndef LLVM_ANALYSIS_MLMODELRUNNER_H 115ffd83dbSDimitry Andric #define LLVM_ANALYSIS_MLMODELRUNNER_H 125ffd83dbSDimitry Andric 1381ad6265SDimitry Andric #include "llvm/Analysis/TensorSpec.h" 145ffd83dbSDimitry Andric #include "llvm/IR/PassManager.h" 155ffd83dbSDimitry Andric 165ffd83dbSDimitry Andric namespace llvm { 1781ad6265SDimitry Andric class LLVMContext; 185ffd83dbSDimitry Andric 195ffd83dbSDimitry Andric /// MLModelRunner interface: abstraction of a mechanism for evaluating a 20*5f757f3fSDimitry Andric /// ML model. More abstractly, evaluating a function that has as tensors as 21*5f757f3fSDimitry Andric /// arguments, described via TensorSpecs, and returns a tensor. Currently, the 22*5f757f3fSDimitry Andric /// latter is assumed to be a scalar, in absence of more elaborate scenarios. 230eae32dcSDimitry Andric /// NOTE: feature indices are expected to be consistent all accross 240eae32dcSDimitry Andric /// MLModelRunners (pertaining to the same model), and also Loggers (see 250eae32dcSDimitry Andric /// TFUtils.h) 265ffd83dbSDimitry Andric class MLModelRunner { 275ffd83dbSDimitry Andric public: 285ffd83dbSDimitry Andric // Disallows copy and assign. 295ffd83dbSDimitry Andric MLModelRunner(const MLModelRunner &) = delete; 305ffd83dbSDimitry Andric MLModelRunner &operator=(const MLModelRunner &) = delete; 315ffd83dbSDimitry Andric virtual ~MLModelRunner() = default; 325ffd83dbSDimitry Andric evaluate()330eae32dcSDimitry Andric template <typename T> T evaluate() { 340eae32dcSDimitry Andric return *reinterpret_cast<T *>(evaluateUntyped()); 350eae32dcSDimitry Andric } 360eae32dcSDimitry Andric getTensor(I FeatureID)370eae32dcSDimitry Andric template <typename T, typename I> T *getTensor(I FeatureID) { 380eae32dcSDimitry Andric return reinterpret_cast<T *>( 390eae32dcSDimitry Andric getTensorUntyped(static_cast<size_t>(FeatureID))); 400eae32dcSDimitry Andric } 410eae32dcSDimitry Andric getTensor(I FeatureID)420eae32dcSDimitry Andric template <typename T, typename I> const T *getTensor(I FeatureID) const { 430eae32dcSDimitry Andric return reinterpret_cast<const T *>( 440eae32dcSDimitry Andric getTensorUntyped(static_cast<size_t>(FeatureID))); 450eae32dcSDimitry Andric } 465ffd83dbSDimitry Andric getTensorUntyped(size_t Index)4781ad6265SDimitry Andric void *getTensorUntyped(size_t Index) { return InputBuffers[Index]; } getTensorUntyped(size_t Index)480eae32dcSDimitry Andric const void *getTensorUntyped(size_t Index) const { 490eae32dcSDimitry Andric return (const_cast<MLModelRunner *>(this))->getTensorUntyped(Index); 500eae32dcSDimitry Andric } 515ffd83dbSDimitry Andric 5206c3fb27SDimitry Andric enum class Kind : int { Unknown, Release, Development, NoOp, Interactive }; getKind()5304eeddc0SDimitry Andric Kind getKind() const { return Type; } switchContext(StringRef Name)5406c3fb27SDimitry Andric virtual void switchContext(StringRef Name) {} 5504eeddc0SDimitry Andric 5604eeddc0SDimitry Andric protected: MLModelRunner(LLVMContext & Ctx,Kind Type,size_t NrInputs)5781ad6265SDimitry Andric MLModelRunner(LLVMContext &Ctx, Kind Type, size_t NrInputs) 5881ad6265SDimitry Andric : Ctx(Ctx), Type(Type), InputBuffers(NrInputs) { 5904eeddc0SDimitry Andric assert(Type != Kind::Unknown); 6004eeddc0SDimitry Andric } 6104eeddc0SDimitry Andric virtual void *evaluateUntyped() = 0; 6204eeddc0SDimitry Andric setUpBufferForTensor(size_t Index,const TensorSpec & Spec,void * Buffer)6381ad6265SDimitry Andric void setUpBufferForTensor(size_t Index, const TensorSpec &Spec, 6481ad6265SDimitry Andric void *Buffer) { 6581ad6265SDimitry Andric if (!Buffer) { 6681ad6265SDimitry Andric OwnedBuffers.emplace_back(Spec.getTotalTensorBufferSize()); 6781ad6265SDimitry Andric Buffer = OwnedBuffers.back().data(); 6881ad6265SDimitry Andric } 6981ad6265SDimitry Andric InputBuffers[Index] = Buffer; 7081ad6265SDimitry Andric } 7181ad6265SDimitry Andric 725ffd83dbSDimitry Andric LLVMContext &Ctx; 7304eeddc0SDimitry Andric const Kind Type; 7481ad6265SDimitry Andric 7581ad6265SDimitry Andric private: 7681ad6265SDimitry Andric std::vector<void *> InputBuffers; 7781ad6265SDimitry Andric std::vector<std::vector<char *>> OwnedBuffers; 785ffd83dbSDimitry Andric }; 795ffd83dbSDimitry Andric } // namespace llvm 805ffd83dbSDimitry Andric 815ffd83dbSDimitry Andric #endif // LLVM_ANALYSIS_MLMODELRUNNER_H 82