1 /******************************************************************************* 2 * thrill/api/bernoulli_sample.hpp 3 * 4 * Part of Project Thrill - http://project-thrill.org 5 * 6 * Copyright (C) 2016 Lorenz Hübschle-Schneider <lorenz@4z2.de> 7 * 8 * All rights reserved. Published under the BSD-2 license in the LICENSE file. 9 ******************************************************************************/ 10 11 #pragma once 12 #ifndef THRILL_API_BERNOULLI_SAMPLE_HEADER 13 #define THRILL_API_BERNOULLI_SAMPLE_HEADER 14 15 #include <thrill/api/dia.hpp> 16 #include <thrill/common/functional.hpp> 17 18 #include <random> 19 20 namespace thrill { 21 namespace api { 22 23 /*! 24 * \ingroup api_layer 25 */ 26 template <typename ValueType> 27 class BernoulliSampleNode 28 { 29 static const bool debug = false; 30 31 using SkipDistValueType = int; 32 33 public: BernoulliSampleNode(double p)34 explicit BernoulliSampleNode(double p) 35 : p_(p), use_skip_(p < 0.1) { // use skip values if p < 0.1 36 assert(p >= 0.0 && p <= 1.0); 37 38 if (use_skip_) { 39 skip_dist_ = std::geometric_distribution<SkipDistValueType>(p); 40 skip_remaining_ = skip_dist_(rng_); 41 42 LOG << "Skip value initialised with " << skip_remaining_; 43 } 44 else { 45 simple_dist_ = std::bernoulli_distribution(p); 46 } 47 } 48 49 template <typename Emitter> operator ()(const ValueType & item,Emitter && emit)50 inline void operator () (const ValueType& item, Emitter&& emit) { 51 if (use_skip_) { 52 // use geometric distribution and skip values 53 if (skip_remaining_ == 0) { 54 // sample element 55 LOG << "sampled item " << item; 56 emit(item); 57 skip_remaining_ = skip_dist_(rng_); 58 } 59 else { 60 --skip_remaining_; 61 } 62 } 63 else { 64 // use bernoulli distribution 65 if (simple_dist_(rng_)) { 66 LOG << "sampled item " << item; 67 emit(item); 68 } 69 } 70 } 71 use_skip() const72 bool use_skip() const { 73 return use_skip_; 74 } 75 76 private: 77 // Sampling rate 78 const double p_; 79 // Whether to generate skip values with a geometric distribution or to use 80 // the naive method 81 const bool use_skip_; 82 // Random generator 83 std::default_random_engine rng_ { std::random_device { } () }; 84 std::bernoulli_distribution simple_dist_; 85 std::geometric_distribution<SkipDistValueType> skip_dist_; 86 SkipDistValueType skip_remaining_ = -1; 87 }; 88 89 template <typename ValueType, typename Stack> BernoulliSample(const double p) const90auto DIA<ValueType, Stack>::BernoulliSample(const double p) const { 91 assert(IsValid()); 92 93 size_t new_id = context().next_dia_id(); 94 95 node_->context().logger_ 96 << "dia_id" << new_id 97 << "label" << "BernoulliSample" 98 << "class" << "DIA" 99 << "event" << "create" 100 << "type" << "LOp" 101 << "parents" << (common::Array<size_t>{ dia_id_ }); 102 103 auto new_stack = stack_.push(BernoulliSampleNode<ValueType>(p)); 104 return DIA<ValueType, decltype(new_stack)>( 105 node_, new_stack, new_id, "BernoulliSample"); 106 } 107 108 } // namespace api 109 } // namespace thrill 110 111 #endif // !THRILL_API_BERNOULLI_SAMPLE_HEADER 112 113 /******************************************************************************/ 114