1 /*******************************************************************************
2  * thrill/api/bernoulli_sample.hpp
3  *
4  * Part of Project Thrill - http://project-thrill.org
5  *
6  * Copyright (C) 2016 Lorenz Hübschle-Schneider <lorenz@4z2.de>
7  *
8  * All rights reserved. Published under the BSD-2 license in the LICENSE file.
9  ******************************************************************************/
10 
11 #pragma once
12 #ifndef THRILL_API_BERNOULLI_SAMPLE_HEADER
13 #define THRILL_API_BERNOULLI_SAMPLE_HEADER
14 
15 #include <thrill/api/dia.hpp>
16 #include <thrill/common/functional.hpp>
17 
18 #include <random>
19 
20 namespace thrill {
21 namespace api {
22 
23 /*!
24  * \ingroup api_layer
25  */
26 template <typename ValueType>
27 class BernoulliSampleNode
28 {
29     static const bool debug = false;
30 
31     using SkipDistValueType = int;
32 
33 public:
BernoulliSampleNode(double p)34     explicit BernoulliSampleNode(double p)
35         : p_(p), use_skip_(p < 0.1) { // use skip values if p < 0.1
36         assert(p >= 0.0 && p <= 1.0);
37 
38         if (use_skip_) {
39             skip_dist_ = std::geometric_distribution<SkipDistValueType>(p);
40             skip_remaining_ = skip_dist_(rng_);
41 
42             LOG << "Skip value initialised with " << skip_remaining_;
43         }
44         else {
45             simple_dist_ = std::bernoulli_distribution(p);
46         }
47     }
48 
49     template <typename Emitter>
operator ()(const ValueType & item,Emitter && emit)50     inline void operator () (const ValueType& item, Emitter&& emit) {
51         if (use_skip_) {
52             // use geometric distribution and skip values
53             if (skip_remaining_ == 0) {
54                 // sample element
55                 LOG << "sampled item " << item;
56                 emit(item);
57                 skip_remaining_ = skip_dist_(rng_);
58             }
59             else {
60                 --skip_remaining_;
61             }
62         }
63         else {
64             // use bernoulli distribution
65             if (simple_dist_(rng_)) {
66                 LOG << "sampled item " << item;
67                 emit(item);
68             }
69         }
70     }
71 
use_skip() const72     bool use_skip() const {
73         return use_skip_;
74     }
75 
76 private:
77     // Sampling rate
78     const double p_;
79     // Whether to generate skip values with a geometric distribution or to use
80     // the naive method
81     const bool use_skip_;
82     // Random generator
83     std::default_random_engine rng_ { std::random_device { } () };
84     std::bernoulli_distribution simple_dist_;
85     std::geometric_distribution<SkipDistValueType> skip_dist_;
86     SkipDistValueType skip_remaining_ = -1;
87 };
88 
89 template <typename ValueType, typename Stack>
BernoulliSample(const double p) const90 auto DIA<ValueType, Stack>::BernoulliSample(const double p) const {
91     assert(IsValid());
92 
93     size_t new_id = context().next_dia_id();
94 
95     node_->context().logger_
96         << "dia_id" << new_id
97         << "label" << "BernoulliSample"
98         << "class" << "DIA"
99         << "event" << "create"
100         << "type" << "LOp"
101         << "parents" << (common::Array<size_t>{ dia_id_ });
102 
103     auto new_stack = stack_.push(BernoulliSampleNode<ValueType>(p));
104     return DIA<ValueType, decltype(new_stack)>(
105         node_, new_stack, new_id, "BernoulliSample");
106 }
107 
108 } // namespace api
109 } // namespace thrill
110 
111 #endif // !THRILL_API_BERNOULLI_SAMPLE_HEADER
112 
113 /******************************************************************************/
114