1 /*
2  * Copyright (c) Facebook, Inc. and its affiliates.
3  * All rights reserved.
4  *
5  * This source code is licensed under the BSD-style license found in the
6  * LICENSE file in the root directory of this source tree.
7  */
8 
9 #pragma once
10 
11 #include <fizz/client/PskCache.h>
12 #include <fizz/client/PskSerializationUtils.h>
13 #include <fizz/protocol/Factory.h>
14 #include <fizz/protocol/OpenSSLFactory.h>
15 #include <wangle/client/persistence/FilePersistentCache.h>
16 
17 namespace proxygen {
18 
/**
 * Record stored in the persistent PSK cache: one serialized PSK together
 * with a counter of how many times it has been handed out.
 */
struct PersistentCachedPsk {
  // PSK encoded with fizz::client::serializePsk (see putPsk below).
  std::string serialized;
  // Times this PSK has been returned from getPsk since it was stored;
  // putPsk writes it as 0.
  size_t uses = 0;
};
23 
24 class PersistentFizzPskCache : public fizz::client::PskCache {
25  public:
26   ~PersistentFizzPskCache() override = default;
27 
28   PersistentFizzPskCache(const std::string& filename,
29                          wangle::PersistentCacheConfig config,
30                          std::unique_ptr<fizz::Factory> factory =
31                              std::make_unique<fizz::OpenSSLFactory>())
cache_(filename,std::move (config))32       : cache_(filename, std::move(config)), factory_(std::move(factory)) {
33   }
34 
setMaxPskUses(size_t maxUses)35   void setMaxPskUses(size_t maxUses) {
36     maxPskUses_ = maxUses;
37   }
38 
39   /**
40    * Returns number of times the psk has been used.
41    */
getPskUses(const std::string & identity)42   folly::Optional<size_t> getPskUses(const std::string& identity) {
43     auto serialized = cache_.get(identity);
44     if (serialized) {
45       return serialized->uses;
46     }
47     return folly::none;
48   }
49 
getPsk(const std::string & identity)50   folly::Optional<fizz::client::CachedPsk> getPsk(
51       const std::string& identity) override {
52     auto serialized = cache_.get(identity);
53     if (serialized) {
54       try {
55         auto deserialized =
56             fizz::client::deserializePsk(serialized->serialized, *factory_);
57         serialized->uses++;
58         if (maxPskUses_ != 0 && serialized->uses >= maxPskUses_) {
59           cache_.remove(identity);
60         } else {
61           cache_.put(identity, *serialized);
62         }
63         return deserialized;
64       } catch (const std::exception& ex) {
65         LOG(ERROR) << "Error deserializing PSK: " << ex.what();
66         cache_.remove(identity);
67       }
68     }
69     return folly::none;
70   }
71 
putPsk(const std::string & identity,fizz::client::CachedPsk psk)72   void putPsk(const std::string& identity,
73               fizz::client::CachedPsk psk) override {
74     PersistentCachedPsk serialized;
75     serialized.serialized = fizz::client::serializePsk(psk);
76     serialized.uses = 0;
77     cache_.put(identity, std::move(serialized));
78   }
79 
removePsk(const std::string & identity)80   void removePsk(const std::string& identity) override {
81     cache_.remove(identity);
82   }
83 
84  private:
85   wangle::FilePersistentCache<std::string, PersistentCachedPsk> cache_;
86 
87   size_t maxPskUses_{5};
88 
89   std::unique_ptr<fizz::Factory> factory_;
90 };
91 } // namespace proxygen
92 
93 namespace folly {
94 
95 template <>
96 dynamic toDynamic(const proxygen::PersistentCachedPsk& cached);
97 template <>
98 proxygen::PersistentCachedPsk convertTo(const dynamic& d);
99 } // namespace folly
100