1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  *
4  * Licensed under the Apache License, Version 2.0 (the "License");
5  * you may not use this file except in compliance with the License.
6  * You may obtain a copy of the License at
7  *
8  *     http://www.apache.org/licenses/LICENSE-2.0
9  *
10  * Unless required by applicable law or agreed to in writing, software
11  * distributed under the License is distributed on an "AS IS" BASIS,
12  * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13  * See the License for the specific language governing permissions and
14  * limitations under the License.
15  */
16 
17 #include <thrift/conformance/cpp2/AnyRegistry.h>
18 
19 #include <glog/logging.h>
20 
21 #include <folly/CppAttributes.h>
22 #include <folly/Demangle.h>
23 #include <folly/Singleton.h>
24 #include <folly/String.h>
25 #include <folly/io/Cursor.h>
26 #include <thrift/conformance/cpp2/Any.h>
27 #include <thrift/lib/cpp2/type/UniversalName.h>
28 
29 namespace apache::thrift::conformance {
30 using type::containsUniversalHash;
31 using type::findByUniversalHash;
32 using type::getUniversalHash;
33 using type::getUniversalHashPrefix;
34 using type::hash_size_t;
35 using type::maybeGetUniversalHashPrefix;
36 namespace detail {
37 
getGeneratedAnyRegistry()38 AnyRegistry& getGeneratedAnyRegistry() {
39   struct GeneratedTag {};
40   return folly::detail::createGlobal<AnyRegistry, GeneratedTag>();
41 }
42 
43 } // namespace detail
44 
45 namespace {
46 
maybeGetTypeHash(const ThriftTypeInfo & type,hash_size_t defaultTypeHashBytes=kDefaultTypeHashBytes)47 folly::fbstring maybeGetTypeHash(
48     const ThriftTypeInfo& type,
49     hash_size_t defaultTypeHashBytes = kDefaultTypeHashBytes) {
50   if (type.typeHashBytes_ref().has_value()) {
51     // Use the custom size.
52     defaultTypeHashBytes = type.typeHashBytes_ref().value_unchecked();
53   }
54   return maybeGetUniversalHashPrefix(
55       type::UniversalHashAlgorithm::Sha2_256,
56       type.get_uri(),
57       defaultTypeHashBytes);
58 }
59 
60 } // namespace
61 
TypeEntry(const std::type_info & typeInfo,ThriftTypeInfo type)62 AnyRegistry::TypeEntry::TypeEntry(
63     const std::type_info& typeInfo, ThriftTypeInfo type)
64     : typeInfo(typeInfo),
65       typeHash(maybeGetTypeHash(type)),
66       type(std::move(type)) {}
67 
registerType(const std::type_info & typeInfo,ThriftTypeInfo type)68 bool AnyRegistry::registerType(
69     const std::type_info& typeInfo, ThriftTypeInfo type) {
70   return registerTypeImpl(typeInfo, std::move(type)) != nullptr;
71 }
72 
registerSerializer(const std::type_info & type,const AnySerializer * serializer)73 bool AnyRegistry::registerSerializer(
74     const std::type_info& type, const AnySerializer* serializer) {
75   return registerSerializerImpl(
76       serializer, &registry_.at(std::type_index(type)));
77 }
78 
registerSerializer(const std::type_info & type,std::unique_ptr<AnySerializer> serializer)79 bool AnyRegistry::registerSerializer(
80     const std::type_info& type, std::unique_ptr<AnySerializer> serializer) {
81   return registerSerializerImpl(
82       std::move(serializer), &registry_.at(std::type_index(type)));
83 }
84 
getTypeUri(const std::type_info & type) const85 std::string_view AnyRegistry::getTypeUri(
86     const std::type_info& type) const noexcept {
87   const auto* entry = getTypeEntry(type);
88   if (entry == nullptr) {
89     return {};
90   }
91   return entry->type.get_uri();
92 }
93 
getTypeUri(const Any & value) const94 std::string_view AnyRegistry::getTypeUri(const Any& value) const noexcept {
95   const auto* entry = getTypeEntryFor(value);
96   if (entry == nullptr) {
97     return {};
98   }
99   return entry->type.get_uri();
100 }
101 
getTypeId(const Any & value) const102 const std::type_info& AnyRegistry::getTypeId(const Any& value) const {
103   return getAndCheckTypeEntryFor(value).typeInfo;
104 }
105 
106 // Same as above, except returns nullptr if the type has not been registered.
tryGetTypeId(const Any & value) const107 const std::type_info* AnyRegistry::tryGetTypeId(
108     const Any& value) const noexcept {
109   const auto* entry = getTypeEntryFor(value);
110   if (entry == nullptr) {
111     return nullptr;
112   }
113   return &entry->typeInfo;
114 }
115 
getSerializer(const std::type_info & type,const Protocol & protocol) const116 const AnySerializer* AnyRegistry::getSerializer(
117     const std::type_info& type, const Protocol& protocol) const noexcept {
118   return getSerializer(getTypeEntry(type), protocol);
119 }
120 
getSerializerByUri(const std::string_view uri,const Protocol & protocol) const121 const AnySerializer* AnyRegistry::getSerializerByUri(
122     const std::string_view uri, const Protocol& protocol) const noexcept {
123   return getSerializer(getTypeEntryByUri(uri), protocol);
124 }
125 
getSerializerByHash(type::UniversalHashAlgorithm alg,const folly::fbstring & typeHash,const Protocol & protocol) const126 const AnySerializer* AnyRegistry::getSerializerByHash(
127     type::UniversalHashAlgorithm alg,
128     const folly::fbstring& typeHash,
129     const Protocol& protocol) const {
130   if (alg != type::UniversalHashAlgorithm::Sha2_256) {
131     folly::throw_exception<std::runtime_error>(
132         "Unsupported hash algorithm: " + std::to_string(static_cast<int>(alg)));
133   }
134   return getSerializer(getTypeEntryByHash(typeHash), protocol);
135 }
136 
store(any_ref value,const Protocol & protocol) const137 Any AnyRegistry::store(any_ref value, const Protocol& protocol) const {
138   if (value.type() == typeid(Any)) {
139     // Use the Any specific overload.
140     return store(any_cast<const Any&>(value), protocol);
141   }
142 
143   const auto& entry = getAndCheckTypeEntry(value.type());
144   const auto& serializer = getAndCheckSerializer(entry, protocol);
145 
146   folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength());
147   // Allocate 16KB at a time; leave some room for the IOBuf overhead
148   constexpr size_t kDesiredGrowth = (1 << 14) - 64;
149   serializer.encode(value, folly::io::QueueAppender(&queue, kDesiredGrowth));
150 
151   Any result;
152   if (entry.typeHash.empty()) {
153     result.set_type(entry.type.get_uri());
154   } else {
155     result.set_typeHashPrefixSha2_256(entry.typeHash);
156   }
157   setProtocol(protocol, result);
158   result.set_data(queue.moveAsValue());
159   return result;
160 }
161 
// Returns `value` re-encoded in `protocol`.
//
// When hasProtocol reports that `value` already uses `protocol`, a copy is
// returned unchanged; otherwise the value is loaded and stored again with the
// requested protocol.
Any AnyRegistry::store(const Any& value, const Protocol& protocol) const {
  if (!hasProtocol(value, protocol)) {
    return store(load(value), protocol);
  }
  return value;
}
168 
load(const Any & value,any_ref out) const169 void AnyRegistry::load(const Any& value, any_ref out) const {
170   const auto& entry = getAndCheckTypeEntryFor(value);
171   const auto& serializer = getAndCheckSerializer(entry, getProtocol(value));
172   folly::io::Cursor cursor(&*value.data_ref());
173   serializer.decode(entry.typeInfo, cursor, out);
174 }
175 
load(const Any & value) const176 std::any AnyRegistry::load(const Any& value) const {
177   std::any out;
178   load(value, out);
179   return out;
180 }
181 
debugString() const182 std::string AnyRegistry::debugString() const {
183   std::string result = "AnyRegistry[\n";
184   // Using the sorted map, hashIndex_, to produce stable results.
185   for (const auto& indx : hashIndex_) {
186     const TypeEntry& entry = *indx.second;
187     result += "  ";
188     result += entry.type.get_uri();
189     result += " (";
190     result += folly::hexlify(indx.first);
191     result += ")";
192     if (!entry.serializers.empty()) {
193       result += ":\n";
194       // Convert to a set, so output is deterministic.
195       std::set<Protocol> protocols;
196       for (const auto& ser : entry.serializers) {
197         protocols.emplace(ser.first);
198       }
199       for (const auto& protocol : protocols) {
200         result += "    ";
201         result += protocol.name();
202         result += ",\n";
203       }
204     } else {
205       result += ",\n";
206     }
207   }
208   result += "]";
209   return result;
210 }
211 
forceRegisterType(const std::type_info & typeInfo,std::string type)212 bool AnyRegistry::forceRegisterType(
213     const std::type_info& typeInfo, std::string type) {
214   if (getTypeEntryByUri(type) != nullptr) {
215     return false;
216   }
217 
218   ThriftTypeInfo info;
219   info.set_uri(std::move(type));
220   info.set_typeHashBytes(0);
221 
222   auto result = registry_.emplace(
223       std::type_index(typeInfo), TypeEntry(typeInfo, std::move(info)));
224   if (!result.second) {
225     return false;
226   }
227 
228   TypeEntry* entry = &result.first->second;
229   indexUri(*entry->type.uri_ref(), entry);
230   return true;
231 }
232 
registerTypeImpl(const std::type_info & typeInfo,ThriftTypeInfo type)233 auto AnyRegistry::registerTypeImpl(
234     const std::type_info& typeInfo, ThriftTypeInfo type) -> TypeEntry* {
235   validateThriftTypeInfo(type);
236   std::vector<folly::fbstring> typeHashs;
237   typeHashs.reserve(type.altUris_ref()->size() + 1);
238   if (!genTypeHashsAndCheckForConflicts(type, &typeHashs)) {
239     return nullptr;
240   }
241 
242   auto result = registry_.emplace(
243       std::type_index(typeInfo), TypeEntry(typeInfo, std::move(type)));
244   if (!result.second) {
245     return nullptr;
246   }
247 
248   TypeEntry* entry = &result.first->second;
249 
250   // Add to secondary indexes.
251   indexUri(*entry->type.uri_ref(), entry);
252   for (const auto& alias : *entry->type.altUris_ref()) {
253     indexUri(alias, entry);
254   }
255 
256   for (auto& hash : typeHashs) {
257     indexHash(std::move(hash), entry);
258   }
259   return &result.first->second;
260 }
261 
registerSerializerImpl(const AnySerializer * serializer,TypeEntry * entry)262 bool AnyRegistry::registerSerializerImpl(
263     const AnySerializer* serializer, TypeEntry* entry) {
264   if (serializer == nullptr) {
265     return false;
266   }
267   validateProtocol(serializer->getProtocol());
268   return entry->serializers.emplace(serializer->getProtocol(), serializer)
269       .second;
270 }
271 
registerSerializerImpl(std::unique_ptr<AnySerializer> serializer,TypeEntry * entry)272 bool AnyRegistry::registerSerializerImpl(
273     std::unique_ptr<AnySerializer> serializer, TypeEntry* entry) {
274   if (!registerSerializerImpl(serializer.get(), entry)) {
275     return false;
276   }
277   ownedSerializers_.emplace_front(std::move(serializer));
278   return true;
279 }
280 
genTypeHashsAndCheckForConflicts(std::string_view uri,std::vector<folly::fbstring> * typeHashs) const281 bool AnyRegistry::genTypeHashsAndCheckForConflicts(
282     std::string_view uri,
283     std::vector<folly::fbstring>* typeHashs) const noexcept {
284   if (uri.empty() || uriIndex_.contains(uri)) {
285     return false; // Already exists.
286   }
287 
288   auto typeHash = getUniversalHash(type::UniversalHashAlgorithm::Sha2_256, uri);
289   // Find shortest valid type hash prefix.
290   folly::fbstring minTypeHash(
291       getUniversalHashPrefix(typeHash, kMinTypeHashBytes));
292   // Check if the minimum type hash would be ambiguous.
293   if (containsUniversalHash(hashIndex_, minTypeHash)) {
294     return false; // Ambigous with another typeHash.
295   }
296   typeHashs->emplace_back(std::move(typeHash));
297   return true;
298 }
299 
genTypeHashsAndCheckForConflicts(const ThriftTypeInfo & type,std::vector<folly::fbstring> * typeHashs) const300 bool AnyRegistry::genTypeHashsAndCheckForConflicts(
301     const ThriftTypeInfo& type,
302     std::vector<folly::fbstring>* typeHashs) const noexcept {
303   // Ensure uri and all aliases are availabile.
304   if (!genTypeHashsAndCheckForConflicts(*type.uri_ref(), typeHashs)) {
305     return false;
306   }
307   for (const auto& alias : *type.altUris_ref()) {
308     if (!genTypeHashsAndCheckForConflicts(alias, typeHashs)) {
309       return false;
310     }
311   }
312   return true;
313 }
314 
indexUri(std::string_view uri,TypeEntry * entry)315 void AnyRegistry::indexUri(std::string_view uri, TypeEntry* entry) noexcept {
316   auto res = uriIndex_.emplace(uri, entry);
317   DCHECK(res.second);
318 }
319 
indexHash(folly::fbstring && typeHash,TypeEntry * entry)320 void AnyRegistry::indexHash(
321     folly::fbstring&& typeHash, TypeEntry* entry) noexcept {
322   auto res = hashIndex_.emplace(std::move(typeHash), entry);
323   DCHECK(res.second);
324 }
325 
getTypeEntry(const std::type_index & typeIndex) const326 auto AnyRegistry::getTypeEntry(const std::type_index& typeIndex) const noexcept
327     -> const TypeEntry* {
328   auto itr = registry_.find(typeIndex);
329   if (itr == registry_.end()) {
330     return nullptr;
331   }
332   return &itr->second;
333 }
334 
getTypeEntryByHash(const folly::fbstring & typeHash) const335 auto AnyRegistry::getTypeEntryByHash(
336     const folly::fbstring& typeHash) const noexcept -> const TypeEntry* {
337   if (typeHash.size() < kMinTypeHashBytes) {
338     return nullptr;
339   }
340   auto itr = findByUniversalHash(hashIndex_, typeHash);
341   if (itr == hashIndex_.end()) {
342     // No match.
343     return nullptr;
344   }
345   return itr->second;
346 }
347 
getTypeEntryByUri(std::string_view uri) const348 auto AnyRegistry::getTypeEntryByUri(std::string_view uri) const noexcept
349     -> const TypeEntry* {
350   auto itr = uriIndex_.find(uri);
351   if (itr == uriIndex_.end()) {
352     return nullptr;
353   }
354   return itr->second;
355 }
356 
getTypeEntryFor(const Any & value) const357 auto AnyRegistry::getTypeEntryFor(const Any& value) const noexcept
358     -> const TypeEntry* {
359   if (value.type_ref().has_value() && !value.type_ref()->empty()) {
360     return getTypeEntryByUri(value.type_ref().value_unchecked());
361   }
362   if (value.typeHashPrefixSha2_256_ref().has_value()) {
363     return getTypeEntryByHash(
364         value.typeHashPrefixSha2_256_ref().value_unchecked());
365   }
366   return nullptr;
367 }
368 
getAndCheckTypeEntryFor(const Any & value) const369 auto AnyRegistry::getAndCheckTypeEntryFor(const Any& value) const
370     -> const TypeEntry& {
371   if (value.type_ref().has_value() &&
372       !value.type_ref().value_unchecked().empty()) {
373     return getAndCheckTypeEntryByUri(value.type_ref().value_unchecked());
374   }
375   if (value.typeHashPrefixSha2_256_ref().has_value()) {
376     return getAndCheckTypeEntryByHash(
377         value.typeHashPrefixSha2_256_ref().value_unchecked());
378   }
379   throw std::invalid_argument("any must have a type");
380 }
381 
getSerializer(const TypeEntry * entry,const Protocol & protocol) const382 const AnySerializer* AnyRegistry::getSerializer(
383     const TypeEntry* entry, const Protocol& protocol) const noexcept {
384   if (entry == nullptr) {
385     return nullptr;
386   }
387 
388   auto itr = entry->serializers.find(protocol);
389   if (itr == entry->serializers.end()) {
390     return nullptr;
391   }
392   return itr->second;
393 }
394 
getAndCheckTypeEntry(const std::type_info & typeInfo) const395 auto AnyRegistry::getAndCheckTypeEntry(const std::type_info& typeInfo) const
396     -> const TypeEntry& {
397   const TypeEntry* result = getTypeEntry(typeInfo);
398   if (result == nullptr) {
399     throw std::out_of_range(
400         fmt::format("Type not registered: {}", folly::demangle(typeInfo)));
401   }
402   return *result;
403 }
404 
getAndCheckTypeEntryByUri(std::string_view uri) const405 auto AnyRegistry::getAndCheckTypeEntryByUri(std::string_view uri) const
406     -> const TypeEntry& {
407   const TypeEntry* result = getTypeEntryByUri(uri);
408   if (result == nullptr) {
409     throw std::out_of_range(fmt::format("Type uri not registered: {}", uri));
410   }
411   return *result;
412 }
413 
getAndCheckTypeEntryByHash(const folly::fbstring & typeHash) const414 auto AnyRegistry::getAndCheckTypeEntryByHash(
415     const folly::fbstring& typeHash) const -> const TypeEntry& {
416   const TypeEntry* result = getTypeEntryByHash(typeHash);
417   if (result == nullptr) {
418     throw std::out_of_range(
419         fmt::format("Type hash not registered: {}", folly::hexlify(typeHash)));
420   }
421   return *result;
422 }
423 
getAndCheckSerializer(const TypeEntry & entry,const Protocol & protocol) const424 const AnySerializer& AnyRegistry::getAndCheckSerializer(
425     const TypeEntry& entry, const Protocol& protocol) const {
426   auto itr = entry.serializers.find(protocol);
427   if (itr == entry.serializers.end()) {
428     folly::throw_exception<std::out_of_range>(fmt::format(
429         "Serializer not found: {}#{}", entry.type.get_uri(), protocol.name()));
430   }
431   return *itr->second;
432 }
433 
434 } // namespace apache::thrift::conformance
435