1 //
2 // Copyright 2018 The Abseil Authors.
3 //
4 // Licensed under the Apache License, Version 2.0 (the "License");
5 // you may not use this file except in compliance with the License.
6 // You may obtain a copy of the License at
7 //
8 //      https://www.apache.org/licenses/LICENSE-2.0
9 //
10 // Unless required by applicable law or agreed to in writing, software
11 // distributed under the License is distributed on an "AS IS" BASIS,
12 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 // See the License for the specific language governing permissions and
14 // limitations under the License.
15 //
16 #ifndef ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
17 #define ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
18 
#include <string>
#include <tuple>
#include <typeinfo>
#include <utility>

#include "absl/random/random.h"
#include "absl/strings/str_cat.h"
24 
25 namespace absl {
26 ABSL_NAMESPACE_BEGIN
27 namespace random_internal {
28 
29 class MockingBitGenBase {
30   template <typename>
31   friend struct DistributionCaller;
32   using generator_type = absl::BitGen;
33 
34  public:
35   // URBG interface
36   using result_type = generator_type::result_type;
result_type(min)37   static constexpr result_type(min)() { return (generator_type::min)(); }
result_type(max)38   static constexpr result_type(max)() { return (generator_type::max)(); }
operator()39   result_type operator()() { return gen_(); }
40 
41   virtual ~MockingBitGenBase() = default;
42 
43  protected:
44   // CallImpl is the type-erased virtual dispatch.
45   // The type of dist is always distribution<T>,
46   // The type of result is always distribution<T>::result_type.
47   virtual bool CallImpl(const std::type_info& distr_type, void* dist_args,
48                         void* result) = 0;
49 
50   template <typename DistrT, typename ArgTupleT>
GetTypeId()51   static const std::type_info& GetTypeId() {
52     return typeid(std::pair<absl::decay_t<DistrT>, absl::decay_t<ArgTupleT>>);
53   }
54 
55   // Call the generating distribution function.
56   // Invoked by DistributionCaller<>::Call<DistT>.
57   // DistT is the distribution type.
58   template <typename DistrT, typename... Args>
Call(Args &&...args)59   typename DistrT::result_type Call(Args&&... args) {
60     using distr_result_type = typename DistrT::result_type;
61     using ArgTupleT = std::tuple<absl::decay_t<Args>...>;
62 
63     ArgTupleT arg_tuple(std::forward<Args>(args)...);
64     auto dist = absl::make_from_tuple<DistrT>(arg_tuple);
65 
66     distr_result_type result{};
67     bool found_match =
68         CallImpl(GetTypeId<DistrT, ArgTupleT>(), &arg_tuple, &result);
69 
70     if (!found_match) {
71       result = dist(gen_);
72     }
73 
74     return result;
75   }
76 
77  private:
78   generator_type gen_;
79 };  // namespace random_internal
80 
81 }  // namespace random_internal
82 ABSL_NAMESPACE_END
83 }  // namespace absl
84 
85 #endif  // ABSL_RANDOM_INTERNAL_MOCKING_BIT_GEN_BASE_H_
86