//===- TensorSpec.cpp - tensor type abstraction ---------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
//
// Implementation file for the abstraction of a tensor type, and JSON loading
// utils.
//
//===----------------------------------------------------------------------===//
#include "llvm/Config/config.h"

#include "llvm/ADT/Twine.h"
#include "llvm/Analysis/TensorSpec.h"
#include "llvm/Support/CommandLine.h"
#include "llvm/Support/Debug.h"
#include "llvm/Support/JSON.h"
#include "llvm/Support/ManagedStatic.h"
#include "llvm/Support/MemoryBuffer.h"
#include "llvm/Support/Path.h"
#include "llvm/Support/raw_ostream.h"
#include <cassert>
#include <numeric>

using namespace llvm;

namespace llvm {

#define TFUTILS_GETDATATYPE_IMPL(T, E)                                         \
  template <> TensorType TensorSpec::getDataType<T>() { return TensorType::E; }

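// Instantiate one getDataType specialization per supported type. For example,
// assuming SUPPORTED_TENSOR_TYPES (from TensorSpec.h) includes the pair
// (float, Float), the expansion below produces:
//   template <> TensorType TensorSpec::getDataType<float>() {
//     return TensorType::Float;
//   }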
SUPPORTED_TENSOR_TYPES(TFUTILS_GETDATATYPE_IMPL)

#undef TFUTILS_GETDATATYPE_IMPL

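// ElementCount is the product of the dimensions in Shape; for an empty Shape
// (a scalar), std::accumulate yields the initial value of 1.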
TensorSpec::TensorSpec(const std::string &Name, int Port, TensorType Type,
                       size_t ElementSize, const std::vector<int64_t> &Shape)
    : Name(Name), Port(Port), Type(Type), Shape(Shape),
      ElementCount(std::accumulate(Shape.begin(), Shape.end(), 1,
                                   std::multiplies<int64_t>())),
      ElementSize(ElementSize) {}

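// Parses a single tensor spec from a JSON dictionary. Illustrative input (the
// values are examples only; "type" must spell one of the C++ types listed in
// SUPPORTED_TENSOR_TYPES, e.g. "int64_t" or "float"):
//   {"name": "reward", "port": 0, "type": "float", "shape": [1]}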
Optional<TensorSpec> getTensorSpecFromJSON(LLVMContext &Ctx,
                                           const json::Value &Value) {
  auto EmitError = [&](const llvm::Twine &Message) -> Optional<TensorSpec> {
    std::string S;
    llvm::raw_string_ostream OS(S);
    OS << Value;
    Ctx.emitError("Unable to parse JSON Value as spec (" + Message + "): " + S);
    return None;
  };
  // FIXME: accept a Path as a parameter, and use it for error reporting.
  json::Path::Root Root("tensor_spec");
  json::ObjectMapper Mapper(Value, Root);
  if (!Mapper)
    return EmitError("Value is not a dict");

  std::string TensorName;
  int TensorPort = -1;
  std::string TensorType;
  std::vector<int64_t> TensorShape;

  if (!Mapper.map<std::string>("name", TensorName))
    return EmitError("'name' property not present or not a string");
  if (!Mapper.map<std::string>("type", TensorType))
    return EmitError("'type' property not present or not a string");
  if (!Mapper.map<int>("port", TensorPort))
    return EmitError("'port' property not present or not an int");
  if (!Mapper.map<std::vector<int64_t>>("shape", TensorShape))
    return EmitError("'shape' property not present or not an int array");

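// PARSE_TYPE compares the parsed "type" string against the stringified C++
// name of each supported element type (e.g. "float", "int64_t") and, on a
// match, creates the spec with that element type.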
#define PARSE_TYPE(T, E)                                                       \
  if (TensorType == #T)                                                        \
    return TensorSpec::createSpec<T>(TensorName, TensorShape, TensorPort);
  SUPPORTED_TENSOR_TYPES(PARSE_TYPE)
#undef PARSE_TYPE
  return None;
}

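// Loads the output specs, either from SpecFileOverride or, if that is empty,
// from <ModelPath>/output_spec.json. Illustrative file contents (the names are
// examples only; the first entry's logging_name must match
// ExpectedDecisionName, and only int64, int32 and float element types are
// accepted):
//   [
//     {"logging_name": "decision",
//      "tensor_spec": {"name": "output_0", "port": 0, "type": "int64_t",
//                      "shape": [1]}}
//   ]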
Optional<std::vector<LoggedFeatureSpec>>
loadOutputSpecs(LLVMContext &Ctx, StringRef ExpectedDecisionName,
                StringRef ModelPath, StringRef SpecFileOverride) {
  SmallVector<char, 128> OutputSpecsPath;
  StringRef FileName = SpecFileOverride;
  if (FileName.empty()) {
    llvm::sys::path::append(OutputSpecsPath, ModelPath, "output_spec.json");
    FileName = {OutputSpecsPath.data(), OutputSpecsPath.size()};
  }

  auto BufferOrError = MemoryBuffer::getFileOrSTDIN(FileName);
  if (!BufferOrError) {
    Ctx.emitError("Error opening output specs file: " + FileName + " : " +
                  BufferOrError.getError().message());
    return None;
  }
  auto ParsedJSONValues = json::parse(BufferOrError.get()->getBuffer());
  if (!ParsedJSONValues) {
    Ctx.emitError("Could not parse specs file: " + FileName);
    return None;
  }
  auto ValuesArray = ParsedJSONValues->getAsArray();
  if (!ValuesArray) {
    Ctx.emitError("Expected an array of {tensor_spec:<TensorSpec>, "
                  "logging_name:<name>} dictionaries");
    return None;
  }
  std::vector<LoggedFeatureSpec> Ret;
  for (const auto &Value : *ValuesArray)
    if (const auto *Obj = Value.getAsObject())
      if (const auto *SpecPart = Obj->get("tensor_spec"))
        if (auto TensorSpec = getTensorSpecFromJSON(Ctx, *SpecPart))
          if (auto LoggingName = Obj->getString("logging_name")) {
            if (!TensorSpec->isElementType<int64_t>() &&
                !TensorSpec->isElementType<int32_t>() &&
                !TensorSpec->isElementType<float>()) {
              Ctx.emitError(
                  "Only int64, int32, and float tensors are supported. "
                  "Found unsupported type for tensor named " +
                  TensorSpec->name());
              return None;
            }
            Ret.push_back({*TensorSpec, LoggingName->str()});
          }

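  // Entries that failed any of the nested checks above were skipped without an
  // error, so a size mismatch here indicates a malformed file.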
  if (ValuesArray->size() != Ret.size()) {
    Ctx.emitError(
        "Unable to parse output spec. It should be a json file containing an "
        "array of dictionaries. Each dictionary must have a 'tensor_spec' key, "
        "with a json object describing a TensorSpec; and a 'logging_name' key, "
        "which is a string to use as name when logging this tensor in the "
        "training log.");
    return None;
  }
  if (Ret.empty() || *Ret[0].LoggingName != ExpectedDecisionName) {
    Ctx.emitError("The first output spec must describe the decision tensor, "
                  "and must have the logging_name " +
                  StringRef(ExpectedDecisionName));
    return None;
  }
  return Ret;
}
} // namespace llvm