1#!/usr/bin/env python3
2
3# calculates data likelihoods for sets of alleles
4from __future__ import print_function, division
5import multiset
6import sys
7import cjson
8import phred
9import json
10import math
11import operator
12from logsumexp import logsumexp
13from dirichlet import dirichlet_maximum_likelihood_ratio, dirichlet, multinomial, multinomialln
14from factorialln import factorialln
15
16"""
This module attempts to find the best method for approximating the integration
of data likelihoods for the Bayesian variant caller we are currently developing.
19
20stdin should be a stream of newline-delimited json records each encoding a list
21of alleles which have been parsed out of alignment records.  alleles.cpp in
22this distribution provides such a stream.
23
24Erik Garrison <erik.garrison@bc.edu> 2010-07-15
25"""
26
27#potential_alleles = [
28#        {'type':'reference'},
29#        {'type':'snp', 'alt':'A'},
30#        {'type':'snp', 'alt':'T'},
31#        {'type':'snp', 'alt':'G'},
32#        {'type':'snp', 'alt':'C'}
33#        ]
34
def list_genotypes_to_count_genotypes(genotypes):
    """Convert genotypes given as allele sequences into (allele, count) form.

    Each genotype, e.g. ('A', 'A', 'T'), becomes a list of (allele, count)
    pairs, e.g. [('A', 2), ('T', 1)].  On Python 3.7+ the pairs follow the
    order in which each allele first appears.

    genotypes: iterable of iterables of hashable alleles.
    Returns: list with one list of (allele, count) tuples per input genotype.
    """
    count_genotypes = []
    for genotype in genotypes:
        counts = {}
        for allele in genotype:
            counts[allele] = counts.get(allele, 0) + 1
        # materialize: on Python 3 dict.items() is a live view, not a list
        count_genotypes.append(list(counts.items()))
    return count_genotypes
46
47"""
48ploidy = 2
49potential_alleles = ['A','T','G','C']
50
51# genotypes are expressed as sets of allele frequencies
52genotypes = list_genotypes_to_count_genotypes(list(multiset.multichoose(ploidy, potential_alleles)))
53"""
54
55
56# TODO
57# update this so that we aren't just using the 'alternate' field from the alleles
58# and are also incorporating the type of allele (ins, deletion, ref, snp)
59
60
def group_alleles(alleles):
    """Group allele records by their 'alt' value.

    alleles: iterable of dicts that each have an 'alt' key.
    Returns: dict mapping each distinct 'alt' value to the list of records
    carrying it, in input order.
    """
    groups = {}
    for allele in alleles:
        # setdefault replaces the Python-2-only dict.has_key test
        groups.setdefault(allele['alt'], []).append(allele)
    return groups
70
def alleles_quality_to_lnprob(alleles):
    """Rewrite each record's 'quality' through phred.phred2ln, in place.

    The same list object is returned so the call can be chained.  Downstream
    code (likelihood_given_true_alleles) treats the converted value as a
    natural-log error probability.
    """
    for record in alleles:
        phred_quality = record['quality']
        record['quality'] = phred.phred2ln(phred_quality)
    return alleles
75
def fold(func, iterable, initial=None, reverse=False):
    """Left-fold ``func`` over ``iterable`` (right-to-left when ``reverse``).

    When ``initial`` is None the first element seeds the accumulator, so
    ``fold(f, [a, b, c])`` is ``f(f(a, b), c)``.  An empty iterable yields
    ``initial``.  Any accumulator value of None counts as "unseeded": the
    next element replaces it instead of being combined with it.
    """
    items = reversed(iterable) if reverse else iterable
    accumulator = initial
    for item in items:
        if accumulator is None:
            accumulator = item
        else:
            accumulator = func(accumulator, item)
    return accumulator
83
def product(listy):
    """Multiply together every element of ``listy``.

    Returns 1 (the multiplicative identity) for an empty iterable, so callers
    can keep doing arithmetic on the result instead of receiving None.
    """
    result = 1
    for factor in listy:
        result *= factor
    return result
86
def observed_alleles_in_genotype(genotype, allele_groups):
    """Partition allele_groups by whether each allele occurs in genotype.

    genotype: iterable of (allele, count) pairs.
    allele_groups: dict mapping allele -> observations (as built by
    group_alleles).
    Returns: (in_genotype, not_in_genotype), two dicts that together hold
    every entry of allele_groups.
    """
    present, absent = {}, {}
    for key in allele_groups.keys():
        if any(allele == key for allele, _count in genotype):
            present[key] = allele_groups[key]
        else:
            absent[key] = allele_groups[key]
    return present, absent
100
101#def scaled_sampling_prob(genotype, alleles):
102#    """The probability of drawing the observations in the allele_groups out of
103#    the given genotype, scaled by the number of possible multiset permutations
104#    of the genotype (we scale because we don't phase our genotypes under
105#    evaluation)."""
106#    allele_groups = group_alleles(alleles)
107#    if len(allele_groups.items()) == 0:
108#        return 0
109#    genotype_allele_frequencies = [x[1] for x in genotype]
110#    multiplicity = sum(genotype_allele_frequencies)
111#    genotype_allele_probabilities = [float(x)/multiplicity for x in genotype_allele_frequencies]
112#    observed_allele_frequencies = [len(x) for x in allele_groups.items()]
113#    observation_product = 1
114#    for allele, count in genotype:
115#        if allele_groups.has_key(allele):
116#            observation_product *= math.pow(float(count) / multiplicity, len(allele_groups[allele]))
117#    return float(math.pow(math.factorial(multiplicity), 2)) \
118#        / (product([math.factorial(x) for x in genotype_allele_frequencies]) *
119#                sum([math.factorial(x) for x in observed_allele_frequencies])) \
120#        * observation_product
121#
122
123
124# TODO XXX
125# Yes, this is the sampling probability.  It is the multinomial sampling
126# probability, which is the specific probability of a specific set of
127# categorical outcomes.  Unfortunately, this is not what we really want here.
128# What we want is the prior probability that a given set of draws come out of a
129# given multiset (genotype, in our case).  I believe that this is given by the
130# Dirichlet distribution.  Investigate.
def sampling_prob(genotype, alleles):
    """Multinomial probability of drawing the observed alleles from genotype.

    genotype: list of (allele, count) pairs; the counts sum to the ploidy.
    alleles: list of allele records keyed by 'alt'.

    With n observations, k_a of them showing allele a, returns
        n! / prod_a(k_a!) * prod_a (count_a / ploidy) ** k_a.
    An observed allele that the genotype cannot produce has sampling
    probability 0, so the whole result is 0.0.  No observations gives 1.0.
    """
    observed = {}
    for allele in alleles:
        alt = allele['alt']
        observed[alt] = observed.get(alt, 0) + 1

    genotype_alleles = set(allele for allele, _count in genotype)
    if any(alt not in genotype_alleles for alt in observed):
        # impossible draw under the multinomial model; skipping such alleles
        # (as before) produced "probabilities" above 1
        return 0.0

    multiplicity = sum(count for _allele, count in genotype)
    denominator = 1
    for k in observed.values():
        denominator *= math.factorial(k)
    draw_prob = 1.0
    for allele, count in genotype:
        if allele in observed:
            draw_prob *= math.pow(float(count) / multiplicity, observed[allele])
    return float(math.factorial(len(alleles))) / denominator * draw_prob
146
def likelihood_given_true_alleles(observed_alleles, true_alleles):
    """Log-likelihood of the observations given the true allele behind each one.

    Observations and true alleles are paired positionally (zip stops at the
    shorter list).  Each observation's 'quality' is used as the natural log
    of its error probability: a match contributes ln(1 - e**quality) and a
    mismatch contributes the quality itself.
    """
    total = 0
    for observed, true in zip(observed_alleles, true_alleles):
        ln_error = observed['quality']
        if observed['alt'] == true['alt']:
            total += math.log(1 - math.exp(ln_error))
        else:
            total += ln_error
    return total
155
def data_likelihood_exact(genotype, observed_alleles):
    """'Exact' log data likelihood of observed_alleles given genotype.

    Sums, in log space, the sampling probability times the joint
    quality-based likelihood of the observations over every assignment of
    underlying 'true' alleles (drawn from the genotype) to the observations.

    genotype: list of (allele, count) pairs.
    observed_alleles: list of allele records with an 'alt' key and a
    'quality' already converted to a natural-log error probability (as done
    by alleles_quality_to_lnprob in the __main__ block).
    Returns: the log-sum-exp of the per-assignment log probabilities.

    NOTE(review): every distinct ordering of each true-allele combination is
    visited.  If ``multinomialln`` includes the n!/prod(k!) coefficient,
    each ordering is counted more than once; confirm against
    dirichlet.multinomialln.
    """
    observation_count = len(observed_alleles)
    ploidy = sum(count for allele, count in genotype)
    allele_probs = [count / float(ploidy) for allele, count in genotype]
    genotype_alleles = [allele for allele, count in genotype]
    probs = []
    # for all true allele combinations X permutations
    for true_allele_combination in multiset.multichoose(observation_count, genotype_alleles):
        for true_allele_permutation in multiset.permutations(true_allele_combination):
            # same record shape as the JSON allele observations, so the shared helpers apply
            true_alleles = [{'alt': allele} for allele in true_allele_permutation]
            allele_groups = group_alleles(true_alleles)
            # how many true alleles of each genotype allele, in genotype order
            observations = [len(allele_groups.get(allele, ())) for allele in genotype_alleles]
            lnsampling_prob = multinomialln(allele_probs, observations)
            probs.append(lnsampling_prob + likelihood_given_true_alleles(observed_alleles, true_alleles))
    # sum the individual probability of all combinations (in log space)
    return logsumexp(probs)
187
def data_likelihood_estimate(genotype, alleles):
    """Estimate the data likelihood of alleles given genotype.

    The intended estimate is a sum over the possible error profiles
    (underlying 'true alleles') that could explain the observations,
    considering up to some error depth.  Not implemented yet: always
    returns None.
    """
    return None
193
def genotype_combination_sampling_probability(genotype_combination, observed_alleles):
    """Log-space scaling term for a genotype combination and its observations.

    Accumulates, in natural-log space:
      * 1 - ln(ploidy * len(genotype_combination)),
      * ln(k!) for each group of k observed alleles sharing an 'alt' value,
      * ln(m! / prod(count!)) for each genotype nested in the combination,
        where m is the genotype's total allele count.

    NOTE(review): reads the module-level name ``ploidy``, which is only
    assigned inside the ``__main__`` guard; calling this from an importing
    module raises NameError.
    NOTE(review): the result starts at ``1 - multiplicity``.  In log space the
    neutral starting value is 0, so the ``1`` looks suspicious; confirm intent.
    NOTE(review): the double loop expects genotype_combination to be an
    iterable of iterables of genotypes (each genotype a list of
    (allele, count) pairs).  A flat list of genotypes would fail when
    indexing ``a[1]``.  No callers are visible here to confirm the shape.
    NOTE(review): dict.iteritems() exists only on Python 2.
    """
    multiplicity = math.log(ploidy * len(genotype_combination))
    result = 1 - multiplicity
    allele_groups = group_alleles(observed_alleles)
    for allele, observations in allele_groups.iteritems():
        # ln(k!) for the k observations of this allele
        result += math.log(math.factorial(len(observations)))
    # scale by product of multiset permutations of all genotypes in combo
    for combo in genotype_combination:
        for genotype in combo:
            m_i = sum([a[1] for a in genotype])
            # ln(m_i! / prod(count!)): distinct orderings of this genotype's alleles
            result += math.log(math.factorial(m_i))
            result -= sum([math.log(math.factorial(allele[1])) for allele in genotype])
    return result
207
def count_frequencies(genotype_combo):
    """Build the allele frequency spectrum of a combination of genotypes.

    genotype_combo: iterable of genotypes, each a list of (allele, count)
    pairs.
    Returns: dict mapping total count j -> number of distinct alleles whose
    count summed across the whole combination equals j.  This is the input
    shape expected by allele_frequency_probability(ln).
    """
    allele_totals = {}
    for genotype in genotype_combo:
        for allele, count in genotype:
            allele_totals[allele] = allele_totals.get(allele, 0) + count
    counts = {}
    for total in allele_totals.values():
        counts[total] = counts.get(total, 0) + 1
    return counts
223
def allele_frequency_probability(allele_frequency_counts, theta=0.001):
    """Implements Ewens' Sampling Formula.

    allele_frequency_counts is a dictionary mapping count -> number of
    alleles with this count in the population (j -> a_j).  With
    M = sum_j(j * a_j), returns

        M! / (theta (theta + 1) ... (theta + M - 1))
            * prod_j theta**a_j / (j**a_j * a_j!)

    Computed in linear space, so it overflows for large M (math.factorial(M)
    exceeds the float range beyond M = 170).  Use
    allele_frequency_probabilityln for deep coverage.
    """
    M = sum(frequency * count for frequency, count in allele_frequency_counts.items())
    # rising factorial theta (theta + 1) ... (theta + M - 1); empty product (M == 0) is 1
    rising_factorial = 1.0
    for h in range(M):
        rising_factorial *= theta + h
    partition_term = 1.0
    for frequency, count in allele_frequency_counts.items():
        # a_j! belongs in the denominator (it was previously multiplied in)
        partition_term *= math.pow(theta, count) / (math.pow(frequency, count) * math.factorial(count))
    return math.factorial(M) / rising_factorial * partition_term
233
def powln(n, m):
    """Power of a number in log space.

    Given n = ln(x), returns ln(x**m) = m * n.  One multiplication is exactly
    rounded and O(1), unlike summing m copies of n, which drifts and costs
    O(m).  It also gives the correct result for negative m.
    """
    return n * m
237
def allele_frequency_probabilityln(allele_frequency_counts, theta=0.001):
    """Log-space version, to avoid the overflows that are inevitable with
    coverage > 100.

    Implements Ewens' Sampling Formula.  allele_frequency_counts is a
    dictionary mapping count -> number of alleles with this count in the
    population (j -> a_j).  With M = sum_j(j * a_j), returns

        ln M! - sum_{h=0}^{M-1} ln(theta + h)
              + sum_j (a_j ln theta - a_j ln j - ln a_j!)
    """
    thetaln = math.log(theta)
    M = sum(frequency * count for frequency, count in allele_frequency_counts.items())
    # ln of the rising factorial theta (theta + 1) ... (theta + M - 1); 0 when M == 0
    rising_factorialln = sum(math.log(theta + h) for h in range(M))
    # ln a_j! is subtracted: a_j! sits in the denominator of Ewens' formula
    partition_termln = sum(count * thetaln - count * math.log(frequency) - math.lgamma(count + 1)
                           for frequency, count in allele_frequency_counts.items())
    # math.lgamma(n + 1) == ln(n!) without building the huge integer n!
    return math.lgamma(M + 1) - rising_factorialln + partition_termln
249
def genotype_probabilities(genotypes, alleles):
    """Score every genotype against the observed alleles.

    Returns a list of [str(genotype), log data likelihood] pairs in the same
    order as genotypes, using data_likelihood_exact for the likelihood.
    """
    scored = []
    for genotype in genotypes:
        label = str(genotype)
        scored.append([label, data_likelihood_exact(genotype, alleles)])
    return scored
252
def genotype_probabilities_heuristic(genotypes, alleles):
    """Placeholder for a heuristic alternative to genotype_probabilities.

    Planned approach: group genotypes relative to the groups of observed
    alleles, apply the data likelihood calculation to the first member of
    each group, then reuse it for the rest.  Not implemented yet: always
    returns None.
    """
    # number of distinct observed 'alt' values; compared with ==, not `is`,
    # since int identity is an implementation detail
    distinct_alleles = len(set(allele['alt'] for allele in alleles))
    if distinct_alleles == 1:
        # we can cleanly do all-right, part-right, all-wrong
        pass
    elif distinct_alleles == 2:
        # we can do all-right, two types of 'part-right', and all-wrong
        pass
    return None
264
def multiset_banded_genotype_combinations(sample_genotypes, bandwidth):
    """Yield one genotype per sample for every assignment of genotype ranks
    0..bandwidth-1 produced by multiset.multichoose + multiset.permutations.

    sample_genotypes: list of per-sample genotype lists, each indexed
    directly by rank.  The sample count is taken from this argument; it was
    previously read from an undefined global ``samples`` (NameError).

    NOTE(review): the commented-out call in the __main__ block passes
    (name, genotypes) pairs, which this function would index as genotype lists.
    """
    sample_count = len(sample_genotypes)
    for index_combo in multiset.multichoose(sample_count, range(bandwidth)):
        for index_permutation in multiset.permutations(index_combo):
            yield [genotypes[index] for index, genotypes in zip(index_permutation, sample_genotypes)]
269
# TODO: implement Gabor's banding solution; the multiset method above generates
# far too many combinations and still produces incorrect results.
def banded_genotype_combinations(sample_genotypes, bandwidth, band_depth):
    """Yield genotype combinations close to the per-sample optimum.

    sample_genotypes: list of (sample, genotypes) pairs, genotypes ranked
    best-first.  The first combination yielded takes every sample's top
    genotype.  Then, for each rank 1..bandwidth-1 and each depth
    1..band_depth-1, every distinct arrangement (via multiset.permutations)
    of `depth` samples at that rank with the rest at rank 0 is yielded.
    Each combination is a list of (sample, genotype) pairs.
    """
    yield [(sample, genotypes[0]) for sample, genotypes in sample_genotypes]
    for rank in range(1, bandwidth):
        for depth in range(1, band_depth):
            indexes = [rank] * depth + [0] * (len(sample_genotypes) - depth)
            for index_permutation in multiset.permutations(indexes):
                yield [(sample, genotypes[index])
                       for index, (sample, genotypes) in zip(index_permutation, sample_genotypes)]
280
def genotype_str(genotype):
    """Render a genotype of (allele, count) pairs as a string.

    e.g. [('A', 1), ('T', 1)] -> 'AT', [('G', 2)] -> 'GG'.  An empty
    genotype renders as '' (previously None).
    """
    return ''.join(allele * count for allele, count in genotype)
283
if __name__ == '__main__':
    # Script entry point.  For each newline-delimited JSON position record on
    # stdin:
    #   1. filter every sample's alleles to reference/SNP observations;
    #   2. compute each sample's log data likelihood for every candidate genotype;
    #   3. score a banded set of cross-sample genotype combinations
    #      (data likelihood + Ewens'-formula prior) and derive per-sample marginals;
    #   4. write the annotated record to stdout as one JSON line.

    ploidy = 2  # assume ploidy 2 for all individuals and all positions

    potential_alleles = ['A', 'T', 'G', 'C']

    # genotypes are expressed as lists of (allele, count) pairs
    genotypes = list_genotypes_to_count_genotypes(list(multiset.multichoose(ploidy, potential_alleles)))

    for line in sys.stdin:
        position = cjson.decode(line)
        samples = position['samples']

        position['coverage'] = sum(len(sample['alleles']) for sample in samples.values())

        for sample in samples.values():
            # only process snps and reference alleles; their phred qualities
            # are converted to log probabilities in place
            alleles = [allele for allele in sample['alleles'] if allele['type'] in ['reference', 'snp']]
            sample['alleles'] = alleles_quality_to_lnprob(alleles)

        position['filtered coverage'] = sum(len(sample['alleles']) for sample in samples.values())

        # log data likelihood of every candidate genotype, per sample
        for sample in samples.values():
            alleles = sample['alleles']
            sample['genotypes'] = [[genotype, data_likelihood_exact(genotype, alleles)] for genotype in genotypes]

        # (name, genotypes ranked most to least likely) for every sample
        sample_genotypes = [(name, sorted(sample['genotypes'], key=lambda genotype: genotype[1], reverse=True))
                            for name, sample in samples.items()]

        # sample name -> genotype string -> log probabilities of the combos containing it
        marginals = {name: {} for name in samples}
        genotype_combo_probs = []
        combos_tested = 0
        # estimate the posterior over genotype combinations within a band of the optimum
        for combo in banded_genotype_combinations(sample_genotypes, min(len(genotypes), 2), 2):
            combos_tested += 1
            probability_observations_given_genotypes = sum(prob for name, (genotype, prob) in combo)
            frequency_counts = count_frequencies([genotype for name, (genotype, prob) in combo])
            prior_probability_of_genotype = allele_frequency_probabilityln(frequency_counts)
            combo_prob = prior_probability_of_genotype + probability_observations_given_genotypes
            for name, (genotype, prob) in combo:
                marginals[name].setdefault(genotype_str(genotype), []).append(combo_prob)
            genotype_combo_probs.append([combo, combo_prob])

        genotype_combo_probs.sort(key=lambda c: c[1], reverse=True)

        # apply Bayes' rule: normalize by the log-sum of all tested combo probabilities
        posterior_normalizer = logsumexp([prob for combo, prob in genotype_combo_probs])

        # marginal of each genotype = normalized sum over the combos containing it
        for genotype_probs in marginals.values():
            for gstr, probs in genotype_probs.items():
                genotype_probs[gstr] = logsumexp(probs) - posterior_normalizer

        best_genotype_combo, best_genotype_combo_prob = genotype_combo_probs[0]
        best_genotype_probability = math.exp(best_genotype_combo_prob - posterior_normalizer)
        position['best_genotype_combo'] = [[name, genotype_str(genotype), math.exp(marginals[name][genotype_str(genotype)])]
                                           for name, (genotype, prob) in best_genotype_combo]
        position['best_genotype_combo_prob'] = best_genotype_probability
        position['posterior_normalizer'] = math.exp(posterior_normalizer)
        position['combos_tested'] = combos_tested

        # cast the per-sample genotype likelihoods from log space into float space
        for sample in samples.values():
            sample['genotypes'] = sorted([[genotype_str(genotype), math.exp(prob)] for genotype, prob in sample['genotypes']],
                                         key=lambda c: c[1], reverse=True)

        print(cjson.encode(position))
384