1 // -*-mode:c++; c-style:k&r; c-basic-offset:4;-*-
2 //
3 // Copyright 2010-2012, Julian Catchen <jcatchen@uoregon.edu>
4 //
5 // This file is part of Stacks.
6 //
7 // Stacks is free software: you can redistribute it and/or modify
8 // it under the terms of the GNU General Public License as published by
9 // the Free Software Foundation, either version 3 of the License, or
10 // (at your option) any later version.
11 //
12 // Stacks is distributed in the hope that it will be useful,
13 // but WITHOUT ANY WARRANTY; without even the implied warranty of
14 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
15 // GNU General Public License for more details.
16 //
17 // You should have received a copy of the GNU General Public License
18 // along with Stacks.  If not, see <http://www.gnu.org/licenses/>.
19 //
20 
21 #ifndef __MODELS_H__
22 #define __MODELS_H__
23 
24 #include "constants.h"
25 #include "utils.h"
26 #include "DNASeq4.h"
27 #include "stacks.h"
28 #include "locus.h"
29 #include "mstack.h"
30 
31 //
32 // Possible models for calling nucleotide positions as fixed or variable
33 //
34 enum modelt {fixed, snp, bounded, marukihigh, marukilow};
35 
36 //
37 // For use with the multinomial model to call fixed nucleotides.
38 //
39 extern int barcode_size;
40 extern double barcode_err_freq;
41 
42 extern double heterozygote_limit;
43 extern double homozygote_limit;
44 extern double bound_low;  // For the bounded-snp model.
45 extern double bound_high; // For the bounded-snp model.
46 extern double p_freq;     // For the fixed model.
47 
48 extern const std::array<double,40> chisq1ddf_gq;
49 
50 
bool lrtest(double lnl_althyp, double lnl_nullhyp, double threshold); // Where threshold is a value in the Chi2 distribution.
long vcf_lnl2gq(double lnl_gt, double lnl_gt_alt); // Returns VCF's GQ field `floor(-10 * log10(p-value))`

// Model configuration and reporting helpers (not defined in this header).
// NOTE(review): qchisq presumably returns the chi-square quantile matching significance level
// `alpha` with `df` degrees of freedom, as used for the lrtest() thresholds; confirm.
double qchisq(double alpha, size_t df);
modelt parse_model_type(const string& arg);
void set_model_thresholds(double alpha);
void report_model(ostream& os, modelt model_type);
void report_alpha(ostream& os, double alpha);
string to_string(modelt model_type);

// Genotype calls either from a likelihood ratio, or from the hom/het log-likelihoods
// (in which case the ratio is 2 * (lnl_hom - lnl_het)). Defined inline below.
snp_type call_snp (double l_ratio);
snp_type call_snp (double lnl_hom, double lnl_het);

// Multinomial log-likelihoods and hom-vs-het likelihood ratios. Defined inline below.
double   lnl_multinomial_model_hom (double total, double n1);
double   lnl_multinomial_model_het (double total, double n1n2);
double   lr_multinomial_model         (double nuc_1, double nuc_2, double nuc_3, double nuc_4);
double   lr_bounded_multinomial_model (double nuc_1, double nuc_2, double nuc_3, double nuc_4);

// Per-position calling on MergedStack / Locus objects (not defined in this header).
void call_bounded_multinomial_snp(MergedStack *, int, map<char, int> &, bool);
void call_bounded_multinomial_snp(Locus *, int, map<char, int> &);
void call_multinomial_snp(MergedStack *, int, map<char, int> &, bool);
void call_multinomial_snp(Locus *, int, map<char, int> &);
void call_multinomial_fixed(MergedStack *, int, map<char, int> &);

double   heterozygous_likelihood(int, map<char, int> &);
double   homozygous_likelihood(int, map<char, int> &);
76 double   homozygous_likelihood(int, map<char, int> &);
77 
78 class SampleCall {
79     GtLiks lnls_;
80     // The genotype call and the corresponding nucleotides.
81     // hom {nt, Nt2()} | het {min_nt, max_nt} | unk {Nt2(), Nt2()}
82     // For hets, the two nucleotides are sorted lexically (A<C<G<T).
83     snp_type call_;
84     array<Nt2,2> nts_;
85     long gq_;
86 public:
SampleCall()87     SampleCall() : lnls_(), call_(snp_type_unk), nts_{{Nt2(),Nt2()}}, gq_(-1) {}
88 
lnls()89     const GtLiks& lnls() const {return lnls_;}
lnls()90           GtLiks& lnls()       {return lnls_;}
91 
call()92     snp_type call() const {return call_;}
nt0()93     Nt2 nt0() const {assert(call_==snp_type_hom || call_==snp_type_het); return nts_[0];}
nt1()94     Nt2 nt1() const {assert(call_==snp_type_het); return nts_[1];}
nts()95     array<Nt2,2> nts() const {assert(call_==snp_type_het); return nts_;}
gq()96     long gq() const {assert(call_==snp_type_hom || call_==snp_type_het); return gq_;}
97 
98     void set_call(snp_type c, Nt2 rank0_nt, Nt2 rank1_nt, long gq);
discard()99     void discard() {call_ = snp_type_discarded; nts_= {{Nt2(), Nt2()}}; gq_ = -1;}
100 
101     // For debugging.
102     friend ostream& operator<<(ostream& os, const SampleCall& sc);
103 };
104 
105 class SiteCall {
106     map<Nt2,double> alleles_;
107     long snp_qual_;
108     vector<SampleCall> sample_calls_; // Empty if alleles_.size() < 2.
109 public:
SiteCall(map<Nt2,double> && alleles,long snp_qual,vector<SampleCall> && sample_calls)110     SiteCall(
111             map<Nt2,double>&& alleles,
112             long snp_qual,
113             vector<SampleCall>&& sample_calls
114         )
115             : alleles_(move(alleles))
116             , snp_qual_(snp_qual)
117             , sample_calls_(move(sample_calls))
118         {}
SiteCall()119     SiteCall()
120         : SiteCall({}, -1, {}) {}
SiteCall(Nt2 fixed_nt)121     SiteCall(Nt2 fixed_nt)
122         : SiteCall({{fixed_nt, 1.0}}, -1, {}) {}
SiteCall(map<Nt2,double> && alleles,vector<SampleCall> && sample_calls)123     SiteCall(map<Nt2,double>&& alleles, vector<SampleCall>&& sample_calls)
124         : SiteCall(move(alleles), -1, move(sample_calls)) {}
125 
alleles()126     const map<Nt2,double>& alleles() const {return alleles_;}
snp_qual()127     long snp_qual() const {return snp_qual_;}
sample_calls()128     const vector<SampleCall>& sample_calls() const {return sample_calls_;}
129 
130     Nt2 most_frequent_allele() const;
131 
132     void filter_mac(size_t min_mac);
discard_sample(size_t sample_i)133     void discard_sample(size_t sample_i)
134         {assert(!sample_calls_.empty()); sample_calls_[sample_i].discard();}
135 
136     static Counts<Nt2> tally_allele_counts(const vector<SampleCall>& spldata);
137     static map<Nt2,double> tally_allele_freqs(const vector<SampleCall>& spldata);
138 
139     // For debugging.
140     void print(ostream& os, const SiteCounts& depths);
141 };
142 
143 class Model {
144 public:
~Model()145     virtual ~Model() {}
146     virtual SiteCall call(const SiteCounts& depths) const = 0;
147     virtual void print(ostream& os) const = 0;
148     friend ostream& operator<< (ostream& os, const Model& m) {m.print(os); return os;}
149 };
150 
151 //
// MultinomialModel: the standard Stacks v.1 model described in Hohenlohe et al. (2010).
153 //
154 class MultinomialModel : public Model {
155     double alpha_;
156 public:
MultinomialModel(double gt_alpha)157     MultinomialModel(double gt_alpha) : alpha_(gt_alpha) {set_model_thresholds(alpha_);}
158     SiteCall call(const SiteCounts& depths) const;
print(ostream & os)159     void print(ostream& os) const
160         {os << to_string(modelt::snp) << " (alpha: "  << alpha_ << ")";}
161 };
162 
163 //
164 // MarukiHighModel: the model of Maruki & Lynch (2017) for high-coverage data.
165 //
166 class MarukiHighModel : public Model {
167     double gt_alpha_;
168     double gt_threshold_;
169     double var_alpha_;
170     double var_threshold_;
171     double calc_hom_lnl(double n, double n1) const;
172     double calc_het_lnl(double n, double n1n2) const;
173 public:
MarukiHighModel(double gt_alpha,double var_alpha)174     MarukiHighModel(double gt_alpha, double var_alpha)
175         : gt_alpha_(gt_alpha), gt_threshold_(qchisq(gt_alpha_,1)),
176           var_alpha_(var_alpha), var_threshold_(qchisq(var_alpha_,1))
177         {}
178     SiteCall call(const SiteCounts& depths) const;
print(ostream & os)179     void print(ostream& os) const
180         {os << to_string(modelt::marukihigh) << " (var_alpha: "  << var_alpha_ << ", gt_alpha: " << gt_alpha_ << ")";}
181 };
182 
183 //
184 // MarukiLowModel: the model of Maruki & Lynch (2015,2017) for low-coverage data.
185 //
186 class MarukiLowModel : public Model {
187     struct LikData {
188         bool has_data;
189         double lnl_MM;
190         double lnl_Mm;
191         double lnl_mm;
192         double l_MM;
193         double l_Mm;
194         double l_mm;
LikDataLikData195         LikData() : has_data(false), lnl_MM(0.0), lnl_Mm(0.0), lnl_mm(0.0), l_MM(1.0), l_Mm(1.0), l_mm(1.0) {}
LikDataLikData196         LikData(double lnl_MM_, double lnl_Mm_, double lnl_mm_)
197             : has_data(true),
198               lnl_MM(lnl_MM_), lnl_Mm(lnl_Mm_), lnl_mm(lnl_mm_),
199               l_MM(exp(lnl_MM)), l_Mm(exp(lnl_Mm)), l_mm(exp(lnl_mm))
200             {assert(std::isfinite(lnl_MM) && std::isfinite(lnl_Mm) && std::isfinite(lnl_mm));}
201     };
202 
203     double gt_alpha_;
204     double gt_threshold_;
205     double var_alpha_;
206     double var_threshold_;
207     double calc_fixed_lnl(double n_tot, double n_M_tot) const;
208     double calc_dimorph_lnl(double freq_MM, double freq_Mm, double freq_mm, const vector<LikData>& liks) const;
209     double calc_ln_weighted_sum(double freq_MM, double freq_Mm, double freq_mm, const LikData& s_liks) const;
210     double calc_ln_weighted_sum_safe(double freq_MM, double freq_Mm, double freq_mm, const LikData& s_liks) const;
211 
212     mutable size_t n_wsum_tot_;
213     mutable size_t n_wsum_underflows_;
214     mutable size_t n_called_sites_;
215     mutable double sum_site_err_rates_;
216 
217 public:
MarukiLowModel(double gt_alpha,double var_alpha)218     MarukiLowModel(double gt_alpha, double var_alpha)
219         : gt_alpha_(gt_alpha), gt_threshold_(qchisq(gt_alpha_,1)),
220           var_alpha_(var_alpha), var_threshold_(qchisq(var_alpha_,2)), // df=2
221           n_wsum_tot_(0), n_wsum_underflows_(0), n_called_sites_(0), sum_site_err_rates_(0.0)
222         {}
223 
224     SiteCall call(const SiteCounts& depths) const;
print(ostream & os)225     void print(ostream& os) const
226         {os << to_string(modelt::marukilow) << " (var_alpha: "  << var_alpha_ << ", gt_alpha: " << gt_alpha_ << ")";}
n_wsum_tot()227     size_t n_wsum_tot() const {return n_wsum_tot_;}
n_wsum_underflows()228     size_t n_wsum_underflows() const {return n_wsum_underflows_;}
mean_err_rate()229     double mean_err_rate() const {return sum_site_err_rates_ / n_called_sites_;}
230 };
231 
232 //
233 // ==================
234 // Inline definitions
235 // ==================
236 //
237 
inline
bool lrtest(double lnl_althyp, double lnl_nullhyp, double threshold) {
    // Likelihood-ratio test: the statistic 2 * (lnL_alt - lnL_null) must strictly exceed
    // `threshold` (a chi-square critical value) for the alternative to be accepted.
    const double statistic = 2.0 * (lnl_althyp - lnl_nullhyp);
    return threshold < statistic;
}
242 
243 inline
vcf_lnl2gq(double lnl_gt,double lnl_gt_alt)244 long vcf_lnl2gq(double lnl_gt, double lnl_gt_alt) {
245     double lr = 2.0 * (lnl_gt - lnl_gt_alt);
246     auto itr = std::upper_bound(chisq1ddf_gq.begin(), chisq1ddf_gq.end(), lr);
247     return itr - chisq1ddf_gq.begin();
248 }
249 
250 inline
call_snp(double l_ratio)251 snp_type call_snp (double l_ratio) {
252     if (l_ratio <= heterozygote_limit)
253         return snp_type_het;
254     else if (l_ratio >= homozygote_limit)
255         return snp_type_hom;
256     else
257         return snp_type_unk;
258 }
259 
260 inline
call_snp(double lnl_hom,double lnl_het)261 snp_type call_snp (double lnl_hom, double lnl_het) {
262     return call_snp(2.0 * (lnl_hom - lnl_het));
263 }
264 
inline
double lnl_multinomial_model_hom (double total, double n1) {
    // Log-likelihood of `total` reads, `n1` of which support the homozygous nucleotide,
    // with the error rate at its maximum-likelihood estimate.

    // Every read supports the nucleotide: likelihood 1.
    if (n1 == total)
        return 0.0;

    // The error-rate estimate would exceed 1.0; cap it, making each nucleotide equiprobable.
    if (n1 < 0.25 * total)
        return total * log(0.25);

    const double n_err = total - n1;
    return n1 * log(n1/total) + n_err * log( n_err/(3.0*total) );
}
274 
inline
double lnl_multinomial_model_het (double total, double n1n2) {
    // Log-likelihood of `total` reads, `n1n2` of which support either heterozygous
    // nucleotide, with the error rate at its maximum-likelihood estimate.

    // Every read supports one of the two nucleotides: each read has probability 1/2.
    if (n1n2 == total)
        return total * log(0.5);

    // The error-rate estimate would exceed 1.0; cap it, making each nucleotide equiprobable.
    if (n1n2 < 0.5 * total)
        return total * log(0.25);

    const double n_err = total - n1n2;
    const double two_total = 2.0 * total;
    return n1n2 * log( n1n2/two_total ) + n_err * log( n_err/two_total );
}
284 
285 inline
lr_multinomial_model_legacy(double nuc_1,double nuc_2,double nuc_3,double nuc_4)286 double lr_multinomial_model_legacy (double nuc_1, double nuc_2, double nuc_3, double nuc_4) {
287     //
288     // This function is to check that the refactored function gives the same
289     // results as the original code (i.e. this code).
290     //
291 
292     double total = nuc_1 + nuc_2 + nuc_3 + nuc_4;
293     assert(total > 0.0);
294 
295     double l_ratio = (nuc_1 * log(nuc_1 / total));
296 
297     if (total - nuc_1 > 0.0)
298         l_ratio += ((total - nuc_1) * log((total - nuc_1) / (3.0 * total)));
299 
300     if (nuc_1 + nuc_2 > 0.0)
301         l_ratio -= ((nuc_1 + nuc_2) * log((nuc_1 + nuc_2) / (2.0 * total)));
302 
303     if (nuc_3 + nuc_4 > 0.0)
304         l_ratio -= ((nuc_3 + nuc_4) * log((nuc_3 + nuc_4) / (2.0 * total)));
305 
306     l_ratio *= 2.0;
307 
308     return l_ratio;
309 }
310 
311 inline
lr_multinomial_model(double nuc_1,double nuc_2,double nuc_3,double nuc_4)312 double lr_multinomial_model (double nuc_1, double nuc_2, double nuc_3, double nuc_4) {
313     //
314     // Method of Paul Hohenlohe <hohenlohe@uidaho.edu>, personal communication.
315     //
316     // For a diploid individual, there are ten possible genotypes
317     // (four homozygous and six heterozygous genotypes).  We calculate
318     // the likelihood of each possible genotype by using a multinomial
319     // sampling distribution, which gives the probability of observing
320     // a set of read counts (n1,n2,n3,n4) given a particular genotype.
321     //
322 
323     double total = nuc_1 + nuc_2 + nuc_3 + nuc_4;
324     assert(total > 0.0);
325     assert(nuc_1 >= nuc_2 && nuc_2 >= nuc_3 && nuc_3 >= nuc_4);
326 
327     double l_ratio = 2.0 * (lnl_multinomial_model_hom(total, nuc_1) - lnl_multinomial_model_het(total, nuc_1+nuc_2));
328 
329     #ifdef DEBUG
330     double l_ratio_legacy = lr_multinomial_model_legacy(nuc_1,nuc_2,nuc_3,nuc_4);
331     assert( (l_ratio == 0.0 && l_ratio_legacy == 0.0) || almost_equal(l_ratio, l_ratio_legacy));
332     #endif
333     return l_ratio;
334 }
335 
336 inline
lr_bounded_multinomial_model(double nuc_1,double nuc_2,double nuc_3,double nuc_4)337 double lr_bounded_multinomial_model (double nuc_1, double nuc_2, double nuc_3, double nuc_4) {
338 
339     //
340     // Method of Paul Hohenlohe <hohenlohe@uidaho.edu>, personal communication.
341     //
342 
343     double total = nuc_1 + nuc_2 + nuc_3 + nuc_4;
344     assert(total > 0.0);
345 
346     //
347     // Calculate the site specific error rate for homozygous and heterozygous genotypes.
348     //
349     double epsilon_hom  = (4.0 / 3.0) * ((total - nuc_1) / total);
350     double epsilon_het  = 2.0 * ((nuc_3 + nuc_4) / total);
351 
352     //
353     // Check if the error rate is above or below the specified bound.
354     //
355     if (epsilon_hom < bound_low)
356         epsilon_hom = bound_low;
357     else if (epsilon_hom > bound_high)
358         epsilon_hom = bound_high;
359 
360     if (epsilon_het < bound_low)
361         epsilon_het = bound_low;
362     else if (epsilon_het > bound_high)
363         epsilon_het = bound_high;
364 
365     //
366     // Calculate the log likelihood for the homozygous and heterozygous genotypes.
367     //
368     double ln_L_hom = nuc_1 * log(1 - ((3.0/4.0) * epsilon_hom));
369     ln_L_hom += epsilon_hom > 0.0 ? ((nuc_2 + nuc_3 + nuc_4) * log(epsilon_hom / 4.0)) : 0.0;
370 
371     double ln_L_het = (nuc_1 + nuc_2) * log(0.5 - (epsilon_het / 4.0));
372     ln_L_het += epsilon_het > 0.0 ? ((nuc_3 + nuc_4) * log(epsilon_het / 4.0)) : 0.0;
373 
374     //
375     // Calculate the likelihood ratio.
376     //
377     double l_ratio  = 2.0 * (ln_L_hom - ln_L_het);
378 
379     // cerr << "  Nuc_1: " << nuc_1 << " Nuc_2: " << nuc_2 << " Nuc_3: " << nuc_3 << " Nuc_4: " << nuc_4
380     //   << " epsilon homozygote: " << epsilon_hom
381     //   << " epsilon heterozygote: " << epsilon_het
382     //   << " Log likelihood hom: " << ln_L_hom
383     //   << " Log likelihood het: " << ln_L_het
384     //   << " Likelihood ratio: " << l_ratio << "\n";
385 
386     return l_ratio;
387 }
388 
389 inline
set_call(snp_type c,Nt2 rank0_nt,Nt2 rank1_nt,long gq)390 void SampleCall::set_call(snp_type c, Nt2 rank0_nt, Nt2 rank1_nt, long gq) {
391     assert(gq >= 0 && gq <= 40);
392     call_ = c;
393     if (call_ == snp_type_hom) {
394         nts_[0] = rank0_nt;
395         gq_ = gq;
396     } else if (call_ == snp_type_het) {
397         if (rank0_nt < rank1_nt)
398             nts_ = {{rank0_nt, rank1_nt}};
399         else
400             nts_ = {{rank1_nt, rank0_nt}};
401         gq_ = gq;
402     }
403 }
404 
405 #endif // __MODELS_H__
406