1 /*
2 * Copyright (c) Facebook, Inc. and its affiliates.
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 */
16
17 #include <thrift/conformance/cpp2/AnyRegistry.h>
18
19 #include <glog/logging.h>
20
21 #include <folly/CppAttributes.h>
22 #include <folly/Demangle.h>
23 #include <folly/Singleton.h>
24 #include <folly/String.h>
25 #include <folly/io/Cursor.h>
26 #include <thrift/conformance/cpp2/Any.h>
27 #include <thrift/lib/cpp2/type/UniversalName.h>
28
29 namespace apache::thrift::conformance {
30 using type::containsUniversalHash;
31 using type::findByUniversalHash;
32 using type::getUniversalHash;
33 using type::getUniversalHashPrefix;
34 using type::hash_size_t;
35 using type::maybeGetUniversalHashPrefix;
36 namespace detail {
37
getGeneratedAnyRegistry()38 AnyRegistry& getGeneratedAnyRegistry() {
39 struct GeneratedTag {};
40 return folly::detail::createGlobal<AnyRegistry, GeneratedTag>();
41 }
42
43 } // namespace detail
44
45 namespace {
46
maybeGetTypeHash(const ThriftTypeInfo & type,hash_size_t defaultTypeHashBytes=kDefaultTypeHashBytes)47 folly::fbstring maybeGetTypeHash(
48 const ThriftTypeInfo& type,
49 hash_size_t defaultTypeHashBytes = kDefaultTypeHashBytes) {
50 if (type.typeHashBytes_ref().has_value()) {
51 // Use the custom size.
52 defaultTypeHashBytes = type.typeHashBytes_ref().value_unchecked();
53 }
54 return maybeGetUniversalHashPrefix(
55 type::UniversalHashAlgorithm::Sha2_256,
56 type.get_uri(),
57 defaultTypeHashBytes);
58 }
59
60 } // namespace
61
TypeEntry(const std::type_info & typeInfo,ThriftTypeInfo type)62 AnyRegistry::TypeEntry::TypeEntry(
63 const std::type_info& typeInfo, ThriftTypeInfo type)
64 : typeInfo(typeInfo),
65 typeHash(maybeGetTypeHash(type)),
66 type(std::move(type)) {}
67
registerType(const std::type_info & typeInfo,ThriftTypeInfo type)68 bool AnyRegistry::registerType(
69 const std::type_info& typeInfo, ThriftTypeInfo type) {
70 return registerTypeImpl(typeInfo, std::move(type)) != nullptr;
71 }
72
registerSerializer(const std::type_info & type,const AnySerializer * serializer)73 bool AnyRegistry::registerSerializer(
74 const std::type_info& type, const AnySerializer* serializer) {
75 return registerSerializerImpl(
76 serializer, ®istry_.at(std::type_index(type)));
77 }
78
registerSerializer(const std::type_info & type,std::unique_ptr<AnySerializer> serializer)79 bool AnyRegistry::registerSerializer(
80 const std::type_info& type, std::unique_ptr<AnySerializer> serializer) {
81 return registerSerializerImpl(
82 std::move(serializer), ®istry_.at(std::type_index(type)));
83 }
84
getTypeUri(const std::type_info & type) const85 std::string_view AnyRegistry::getTypeUri(
86 const std::type_info& type) const noexcept {
87 const auto* entry = getTypeEntry(type);
88 if (entry == nullptr) {
89 return {};
90 }
91 return entry->type.get_uri();
92 }
93
getTypeUri(const Any & value) const94 std::string_view AnyRegistry::getTypeUri(const Any& value) const noexcept {
95 const auto* entry = getTypeEntryFor(value);
96 if (entry == nullptr) {
97 return {};
98 }
99 return entry->type.get_uri();
100 }
101
getTypeId(const Any & value) const102 const std::type_info& AnyRegistry::getTypeId(const Any& value) const {
103 return getAndCheckTypeEntryFor(value).typeInfo;
104 }
105
106 // Same as above, except returns nullptr if the type has not been registered.
tryGetTypeId(const Any & value) const107 const std::type_info* AnyRegistry::tryGetTypeId(
108 const Any& value) const noexcept {
109 const auto* entry = getTypeEntryFor(value);
110 if (entry == nullptr) {
111 return nullptr;
112 }
113 return &entry->typeInfo;
114 }
115
getSerializer(const std::type_info & type,const Protocol & protocol) const116 const AnySerializer* AnyRegistry::getSerializer(
117 const std::type_info& type, const Protocol& protocol) const noexcept {
118 return getSerializer(getTypeEntry(type), protocol);
119 }
120
getSerializerByUri(const std::string_view uri,const Protocol & protocol) const121 const AnySerializer* AnyRegistry::getSerializerByUri(
122 const std::string_view uri, const Protocol& protocol) const noexcept {
123 return getSerializer(getTypeEntryByUri(uri), protocol);
124 }
125
getSerializerByHash(type::UniversalHashAlgorithm alg,const folly::fbstring & typeHash,const Protocol & protocol) const126 const AnySerializer* AnyRegistry::getSerializerByHash(
127 type::UniversalHashAlgorithm alg,
128 const folly::fbstring& typeHash,
129 const Protocol& protocol) const {
130 if (alg != type::UniversalHashAlgorithm::Sha2_256) {
131 folly::throw_exception<std::runtime_error>(
132 "Unsupported hash algorithm: " + std::to_string(static_cast<int>(alg)));
133 }
134 return getSerializer(getTypeEntryByHash(typeHash), protocol);
135 }
136
store(any_ref value,const Protocol & protocol) const137 Any AnyRegistry::store(any_ref value, const Protocol& protocol) const {
138 if (value.type() == typeid(Any)) {
139 // Use the Any specific overload.
140 return store(any_cast<const Any&>(value), protocol);
141 }
142
143 const auto& entry = getAndCheckTypeEntry(value.type());
144 const auto& serializer = getAndCheckSerializer(entry, protocol);
145
146 folly::IOBufQueue queue(folly::IOBufQueue::cacheChainLength());
147 // Allocate 16KB at a time; leave some room for the IOBuf overhead
148 constexpr size_t kDesiredGrowth = (1 << 14) - 64;
149 serializer.encode(value, folly::io::QueueAppender(&queue, kDesiredGrowth));
150
151 Any result;
152 if (entry.typeHash.empty()) {
153 result.set_type(entry.type.get_uri());
154 } else {
155 result.set_typeHashPrefixSha2_256(entry.typeHash);
156 }
157 setProtocol(protocol, result);
158 result.set_data(queue.moveAsValue());
159 return result;
160 }
161
// Returns `value` encoded with `protocol`.
//
// When `value` already uses `protocol` a copy is returned unchanged; otherwise
// it is loaded and stored again with the requested protocol.
Any AnyRegistry::store(const Any& value, const Protocol& protocol) const {
  if (!hasProtocol(value, protocol)) {
    return store(load(value), protocol);
  }
  return value;
}
168
load(const Any & value,any_ref out) const169 void AnyRegistry::load(const Any& value, any_ref out) const {
170 const auto& entry = getAndCheckTypeEntryFor(value);
171 const auto& serializer = getAndCheckSerializer(entry, getProtocol(value));
172 folly::io::Cursor cursor(&*value.data_ref());
173 serializer.decode(entry.typeInfo, cursor, out);
174 }
175
load(const Any & value) const176 std::any AnyRegistry::load(const Any& value) const {
177 std::any out;
178 load(value, out);
179 return out;
180 }
181
debugString() const182 std::string AnyRegistry::debugString() const {
183 std::string result = "AnyRegistry[\n";
184 // Using the sorted map, hashIndex_, to produce stable results.
185 for (const auto& indx : hashIndex_) {
186 const TypeEntry& entry = *indx.second;
187 result += " ";
188 result += entry.type.get_uri();
189 result += " (";
190 result += folly::hexlify(indx.first);
191 result += ")";
192 if (!entry.serializers.empty()) {
193 result += ":\n";
194 // Convert to a set, so output is deterministic.
195 std::set<Protocol> protocols;
196 for (const auto& ser : entry.serializers) {
197 protocols.emplace(ser.first);
198 }
199 for (const auto& protocol : protocols) {
200 result += " ";
201 result += protocol.name();
202 result += ",\n";
203 }
204 } else {
205 result += ",\n";
206 }
207 }
208 result += "]";
209 return result;
210 }
211
forceRegisterType(const std::type_info & typeInfo,std::string type)212 bool AnyRegistry::forceRegisterType(
213 const std::type_info& typeInfo, std::string type) {
214 if (getTypeEntryByUri(type) != nullptr) {
215 return false;
216 }
217
218 ThriftTypeInfo info;
219 info.set_uri(std::move(type));
220 info.set_typeHashBytes(0);
221
222 auto result = registry_.emplace(
223 std::type_index(typeInfo), TypeEntry(typeInfo, std::move(info)));
224 if (!result.second) {
225 return false;
226 }
227
228 TypeEntry* entry = &result.first->second;
229 indexUri(*entry->type.uri_ref(), entry);
230 return true;
231 }
232
registerTypeImpl(const std::type_info & typeInfo,ThriftTypeInfo type)233 auto AnyRegistry::registerTypeImpl(
234 const std::type_info& typeInfo, ThriftTypeInfo type) -> TypeEntry* {
235 validateThriftTypeInfo(type);
236 std::vector<folly::fbstring> typeHashs;
237 typeHashs.reserve(type.altUris_ref()->size() + 1);
238 if (!genTypeHashsAndCheckForConflicts(type, &typeHashs)) {
239 return nullptr;
240 }
241
242 auto result = registry_.emplace(
243 std::type_index(typeInfo), TypeEntry(typeInfo, std::move(type)));
244 if (!result.second) {
245 return nullptr;
246 }
247
248 TypeEntry* entry = &result.first->second;
249
250 // Add to secondary indexes.
251 indexUri(*entry->type.uri_ref(), entry);
252 for (const auto& alias : *entry->type.altUris_ref()) {
253 indexUri(alias, entry);
254 }
255
256 for (auto& hash : typeHashs) {
257 indexHash(std::move(hash), entry);
258 }
259 return &result.first->second;
260 }
261
registerSerializerImpl(const AnySerializer * serializer,TypeEntry * entry)262 bool AnyRegistry::registerSerializerImpl(
263 const AnySerializer* serializer, TypeEntry* entry) {
264 if (serializer == nullptr) {
265 return false;
266 }
267 validateProtocol(serializer->getProtocol());
268 return entry->serializers.emplace(serializer->getProtocol(), serializer)
269 .second;
270 }
271
registerSerializerImpl(std::unique_ptr<AnySerializer> serializer,TypeEntry * entry)272 bool AnyRegistry::registerSerializerImpl(
273 std::unique_ptr<AnySerializer> serializer, TypeEntry* entry) {
274 if (!registerSerializerImpl(serializer.get(), entry)) {
275 return false;
276 }
277 ownedSerializers_.emplace_front(std::move(serializer));
278 return true;
279 }
280
genTypeHashsAndCheckForConflicts(std::string_view uri,std::vector<folly::fbstring> * typeHashs) const281 bool AnyRegistry::genTypeHashsAndCheckForConflicts(
282 std::string_view uri,
283 std::vector<folly::fbstring>* typeHashs) const noexcept {
284 if (uri.empty() || uriIndex_.contains(uri)) {
285 return false; // Already exists.
286 }
287
288 auto typeHash = getUniversalHash(type::UniversalHashAlgorithm::Sha2_256, uri);
289 // Find shortest valid type hash prefix.
290 folly::fbstring minTypeHash(
291 getUniversalHashPrefix(typeHash, kMinTypeHashBytes));
292 // Check if the minimum type hash would be ambiguous.
293 if (containsUniversalHash(hashIndex_, minTypeHash)) {
294 return false; // Ambigous with another typeHash.
295 }
296 typeHashs->emplace_back(std::move(typeHash));
297 return true;
298 }
299
genTypeHashsAndCheckForConflicts(const ThriftTypeInfo & type,std::vector<folly::fbstring> * typeHashs) const300 bool AnyRegistry::genTypeHashsAndCheckForConflicts(
301 const ThriftTypeInfo& type,
302 std::vector<folly::fbstring>* typeHashs) const noexcept {
303 // Ensure uri and all aliases are availabile.
304 if (!genTypeHashsAndCheckForConflicts(*type.uri_ref(), typeHashs)) {
305 return false;
306 }
307 for (const auto& alias : *type.altUris_ref()) {
308 if (!genTypeHashsAndCheckForConflicts(alias, typeHashs)) {
309 return false;
310 }
311 }
312 return true;
313 }
314
indexUri(std::string_view uri,TypeEntry * entry)315 void AnyRegistry::indexUri(std::string_view uri, TypeEntry* entry) noexcept {
316 auto res = uriIndex_.emplace(uri, entry);
317 DCHECK(res.second);
318 }
319
indexHash(folly::fbstring && typeHash,TypeEntry * entry)320 void AnyRegistry::indexHash(
321 folly::fbstring&& typeHash, TypeEntry* entry) noexcept {
322 auto res = hashIndex_.emplace(std::move(typeHash), entry);
323 DCHECK(res.second);
324 }
325
getTypeEntry(const std::type_index & typeIndex) const326 auto AnyRegistry::getTypeEntry(const std::type_index& typeIndex) const noexcept
327 -> const TypeEntry* {
328 auto itr = registry_.find(typeIndex);
329 if (itr == registry_.end()) {
330 return nullptr;
331 }
332 return &itr->second;
333 }
334
getTypeEntryByHash(const folly::fbstring & typeHash) const335 auto AnyRegistry::getTypeEntryByHash(
336 const folly::fbstring& typeHash) const noexcept -> const TypeEntry* {
337 if (typeHash.size() < kMinTypeHashBytes) {
338 return nullptr;
339 }
340 auto itr = findByUniversalHash(hashIndex_, typeHash);
341 if (itr == hashIndex_.end()) {
342 // No match.
343 return nullptr;
344 }
345 return itr->second;
346 }
347
getTypeEntryByUri(std::string_view uri) const348 auto AnyRegistry::getTypeEntryByUri(std::string_view uri) const noexcept
349 -> const TypeEntry* {
350 auto itr = uriIndex_.find(uri);
351 if (itr == uriIndex_.end()) {
352 return nullptr;
353 }
354 return itr->second;
355 }
356
getTypeEntryFor(const Any & value) const357 auto AnyRegistry::getTypeEntryFor(const Any& value) const noexcept
358 -> const TypeEntry* {
359 if (value.type_ref().has_value() && !value.type_ref()->empty()) {
360 return getTypeEntryByUri(value.type_ref().value_unchecked());
361 }
362 if (value.typeHashPrefixSha2_256_ref().has_value()) {
363 return getTypeEntryByHash(
364 value.typeHashPrefixSha2_256_ref().value_unchecked());
365 }
366 return nullptr;
367 }
368
getAndCheckTypeEntryFor(const Any & value) const369 auto AnyRegistry::getAndCheckTypeEntryFor(const Any& value) const
370 -> const TypeEntry& {
371 if (value.type_ref().has_value() &&
372 !value.type_ref().value_unchecked().empty()) {
373 return getAndCheckTypeEntryByUri(value.type_ref().value_unchecked());
374 }
375 if (value.typeHashPrefixSha2_256_ref().has_value()) {
376 return getAndCheckTypeEntryByHash(
377 value.typeHashPrefixSha2_256_ref().value_unchecked());
378 }
379 throw std::invalid_argument("any must have a type");
380 }
381
getSerializer(const TypeEntry * entry,const Protocol & protocol) const382 const AnySerializer* AnyRegistry::getSerializer(
383 const TypeEntry* entry, const Protocol& protocol) const noexcept {
384 if (entry == nullptr) {
385 return nullptr;
386 }
387
388 auto itr = entry->serializers.find(protocol);
389 if (itr == entry->serializers.end()) {
390 return nullptr;
391 }
392 return itr->second;
393 }
394
getAndCheckTypeEntry(const std::type_info & typeInfo) const395 auto AnyRegistry::getAndCheckTypeEntry(const std::type_info& typeInfo) const
396 -> const TypeEntry& {
397 const TypeEntry* result = getTypeEntry(typeInfo);
398 if (result == nullptr) {
399 throw std::out_of_range(
400 fmt::format("Type not registered: {}", folly::demangle(typeInfo)));
401 }
402 return *result;
403 }
404
getAndCheckTypeEntryByUri(std::string_view uri) const405 auto AnyRegistry::getAndCheckTypeEntryByUri(std::string_view uri) const
406 -> const TypeEntry& {
407 const TypeEntry* result = getTypeEntryByUri(uri);
408 if (result == nullptr) {
409 throw std::out_of_range(fmt::format("Type uri not registered: {}", uri));
410 }
411 return *result;
412 }
413
getAndCheckTypeEntryByHash(const folly::fbstring & typeHash) const414 auto AnyRegistry::getAndCheckTypeEntryByHash(
415 const folly::fbstring& typeHash) const -> const TypeEntry& {
416 const TypeEntry* result = getTypeEntryByHash(typeHash);
417 if (result == nullptr) {
418 throw std::out_of_range(
419 fmt::format("Type hash not registered: {}", folly::hexlify(typeHash)));
420 }
421 return *result;
422 }
423
getAndCheckSerializer(const TypeEntry & entry,const Protocol & protocol) const424 const AnySerializer& AnyRegistry::getAndCheckSerializer(
425 const TypeEntry& entry, const Protocol& protocol) const {
426 auto itr = entry.serializers.find(protocol);
427 if (itr == entry.serializers.end()) {
428 folly::throw_exception<std::out_of_range>(fmt::format(
429 "Serializer not found: {}#{}", entry.type.get_uri(), protocol.name()));
430 }
431 return *itr->second;
432 }
433
434 } // namespace apache::thrift::conformance
435