1#!/usr/bin/env python
2
3"""
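From a set of bounding regions, a mask, and two sets of intervals inside
those regions, compute the overlap between the first set of intervals and
each of one or more second sets, and the overlap in `nsamples` random
coverings of the regions with intervals having the same lengths as the
second set. Positions covered by the mask are removed from both sets of
intervals, and the random coverings avoid them. Overlaps are computed region
by region and summed over all regions. Prints the overlap fractions, the
z-score relative to the mean and sample stdev of the random coverings, and
the fraction of random coverings with a smaller overlap than observed.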
9
10Currently intervals must be in bed 3+ format.
11
12TODO: There are a few versions of this floating around, including a
13      better/faster one using gap lists instead of bitsets. Need to track
14      that down and merge as necessary.
15
usage: %prog bounding_region_file mask_file nsamples intervals1 intervals2 [intervals2 ...]
17"""
18
19import sys
20
21from numpy import zeros
22
23from bx.bitset import BitSet
24from bx.intervals.random_intervals import throw_random_bits
25from bx_extras import stats
26
# Total number of placement attempts throw_random makes before giving up.
maxtries: int = 10
28
29
class MaxtriesException(Exception):
    """
    Failure signal for random interval placement after exhausting retries.

    NOTE(review): nothing in this file raises this class; throw_random only
    catches it. Confirm whether it is meant to mirror the exception raised
    by bx.intervals.random_intervals.
    """
32
33
def bit_clone(bits):
    """
    Return an independent copy of `bits`.

    The copy is a new BitSet with the same size as `bits` and exactly the
    same bits set; later changes to either one do not affect the other.
    """
    duplicate = BitSet(bits.size)
    # OR-ing into an all-clear set of equal size reproduces `bits` exactly.
    duplicate.ior(bits)
    return duplicate
41
42
def throw_random(lengths, mask):
    """
    Run `throw_random_bits(lengths, mask)`, retrying when placement fails.

    Up to `maxtries` attempts are made in total (at least one). Failures
    signalled by a MaxtriesException are retried. If every attempt fails,
    the exception from the final attempt propagates unchanged.

    :param lengths: list of interval lengths to place at random.
    :param mask: BitSet passed straight through to throw_random_bits; its
        `size` also sizes the result when `lengths` is empty.
    :returns: the BitSet returned by the first successful attempt, or an
        all-clear BitSet of `mask.size` when `lengths` is empty.
    """
    if not lengths:
        # Nothing to place. Skip the library call, which may not accept an
        # empty length list, and return the equivalent empty covering.
        return BitSet(mask.size)
    # Nothing in this script raises the local MaxtriesException, so catching
    # only that class never retries. throw_random_bits presumably raises the
    # MaxtriesException defined in its own module (looked up defensively;
    # confirm against the installed bx version), so catch both.
    import bx.intervals.random_intervals as _random_intervals

    retryable = (
        MaxtriesException,
        getattr(_random_intervals, "MaxtriesException", MaxtriesException),
    )
    for _ in range(maxtries - 1):
        try:
            return throw_random_bits(lengths, mask)
        except retryable:
            continue
    # Final attempt outside the try so a failure surfaces with its own
    # traceback (and never as `raise None` when maxtries <= 0).
    return throw_random_bits(lengths, mask)
55
56
def as_bits(region_start, region_length, intervals):
    """
    Build a BitSet of `region_length` bits marking the coverage of `intervals`.

    Each interval is a (chrom, start, stop) triple in chromosome coordinates.
    It is shifted by `region_start`, so bit 0 corresponds to `region_start`.
    The chromosome field is ignored; callers are expected to pass only
    intervals from the region's chromosome.
    """
    covered = BitSet(region_length)
    for _chrom, iv_start, iv_stop in intervals:
        offset = iv_start - region_start
        covered.set_range(offset, iv_stop - iv_start)
    return covered
67
68
def interval_lengths(bits):
    """
    Yield the length of every maximal run of consecutive set bits in `bits`.

    Runs are reported left to right. `bits` only needs a `size` attribute
    and `next_set(pos)` / `next_clear(pos)` methods. The scan ends once
    `next_set` returns `bits.size`.
    """
    pos = bits.next_set(0)
    while pos != bits.size:
        run_end = bits.next_clear(pos)
        yield run_end - pos
        pos = bits.next_set(run_end)
80
81
def count_overlap(bits1, bits2):
    """
    Return the number of positions set in both `bits1` and `bits2`.

    The intersection is built in a scratch BitSet sized like `bits1`, so
    neither argument is modified.
    """
    scratch = BitSet(bits1.size)
    scratch |= bits1
    scratch &= bits2
    n_bits = scratch.size
    return scratch.count_range(0, n_bits)
90
91
def overlapping_in_bed(fname, r_chr, r_start, r_stop):
    """
    Get from a BED file all intervals that overlap the region defined by
    r_chr, r_start, r_stop, clipped to that region.

    Coordinates are treated as half-open [start, stop), as in BED, so an
    interval that merely touches a region boundary is not included.
    Header lines (starting with "#", "track" or "browser") and blank lines
    are skipped. Only the first three columns of each line are read.

    :param fname: path of a BED 3+ file.
    :returns: list of (chrom, start, stop) tuples, in file order, with
        r_start <= start and stop <= r_stop.
    """
    rval = []
    with open(fname) as bed:
        for line in bed:
            if line.startswith(("#", "track", "browser")):
                continue
            fields = line.split()
            if not fields:
                # Blank or whitespace-only line.
                continue
            chrom, start, stop = fields[0], int(fields[1]), int(fields[2])
            # Strict inequalities: half-open intervals that only touch the
            # region share no bases with it.
            if chrom == r_chr and start < r_stop and stop > r_start:
                rval.append((chrom, max(start, r_start), min(stop, r_stop)))
    return rval
106
107
def _process_region(fields, mask_fname, intervals1_fname, intervals2_fnames,
                    nsamples, total_actual, total_lengths2, total_samples):
    """
    Add one region's observed and randomly sampled overlaps to the totals.

    `fields` is the whitespace-split region line (BED 3+). The three
    `total_*` numpy arrays are updated in place, indexed by feature (the
    position of the file in `intervals2_fnames`).
    """
    r_chr, r_start, r_stop = fields[0], int(fields[1]), int(fields[2])
    # Region files may be plain BED3; fall back to a coordinate label.
    name = fields[3] if len(fields) > 3 else "%s:%d-%d" % (r_chr, r_start, r_stop)
    print("Processing region:", name, file=sys.stderr)
    r_length = r_stop - r_start
    # Bits set in bits_mask are removed from both interval sets and passed
    # to throw_random as the mask for random placement.
    mask = overlapping_in_bed(mask_fname, r_chr, r_start, r_stop)
    bits_mask = as_bits(r_start, r_length, mask)
    bits_not_masked = bit_clone(bits_mask)
    bits_not_masked.invert()
    # First interval set, restricted to unmasked positions.
    intervals1 = overlapping_in_bed(intervals1_fname, r_chr, r_start, r_stop)
    bits1 = as_bits(r_start, r_length, intervals1)
    bits1.iand(bits_not_masked)
    # Sanity check: masking must have removed every masked bit.
    assert count_overlap(bits1, bits_mask) == 0
    for featnum, intervals2_fname in enumerate(intervals2_fnames):
        print(intervals2_fname, file=sys.stderr)
        intervals2 = overlapping_in_bed(intervals2_fname, r_chr, r_start, r_stop)
        bits2 = as_bits(r_start, r_length, intervals2)
        bits2.iand(bits_not_masked)
        assert count_overlap(bits2, bits_mask) == 0
        # Observed value.
        total_actual[featnum] += count_overlap(bits1, bits2)
        # Random placements reuse the run lengths of the masked second set.
        lengths2 = list(interval_lengths(bits2))
        total_lengths2[featnum] += sum(lengths2)
        if not lengths2:
            # Nothing to place here: every random sample overlaps by 0.
            continue
        for i in range(nsamples):
            random2 = throw_random(lengths2, bits_mask)
            random2 &= bits1
            total_samples[i, featnum] += random2.count_range(0, random2.size)
            # Running (cumulative over regions) sample total, for progress.
            print(total_samples[i, featnum], file=sys.stderr)


def _report(intervals2_fnames, nsamples, total_actual, total_lengths2, total_samples):
    """Print the overlap-fraction table and a per-feature summary to stdout."""
    print("\t".join(intervals2_fnames))
    print("\t".join(map(str, total_actual / total_lengths2)))
    # One row per random sample: sampled overlap as a fraction of each
    # feature's total (unmasked) interval length.
    for row in total_samples / total_lengths2:
        print("\t".join(map(str, row)))
    # Statistics are taken per feature (one column of total_samples) so that
    # several intervals2 files can be summarised; scalar "%d" formatting of a
    # whole array only works for a single file.
    for featnum, fname in enumerate(intervals2_fnames):
        observed = total_actual[featnum]
        samples = total_samples[:, featnum]
        mean = stats.amean(samples)
        stdev = stats.asamplestdev(samples)
        print("feature:", fname)
        print("observed overlap: %d, sample mean: %d, sample stdev: %d" % (observed, mean, stdev))
        # A constant sample has zero spread; report NaN instead of dividing by 0.
        zscore = (observed - mean) / stdev if stdev else float("nan")
        print("z-score:", zscore)
        print("percentile:", (samples < observed).sum() / nsamples)


def main():
    """
    Command-line entry point.

    Positional arguments (from ``sys.argv``)::

        bounding_region_file mask_file nsamples intervals1 intervals2 [intervals2 ...]

    Each region's observed overlap between intervals1 and every intervals2
    file (masked positions removed) is compared with `nsamples` random
    placements of the intervals2 run lengths. Totals are summed over all
    regions and reported on stdout; progress goes to stderr.
    """
    if len(sys.argv) < 6:
        sys.exit(
            "usage: %s bounding_region_file mask_file nsamples intervals1 intervals2 [intervals2 ...]"
            % sys.argv[0]
        )
    region_fname = sys.argv[1]
    mask_fname = sys.argv[2]
    nsamples = int(sys.argv[3])
    if nsamples < 1:
        sys.exit("nsamples must be a positive integer")
    intervals1_fname = sys.argv[4]
    intervals2_fnames = sys.argv[5:]
    nfeatures = len(intervals2_fnames)
    # Per-feature accumulators, summed over all regions.
    total_actual = zeros(nfeatures)
    total_lengths2 = zeros(nfeatures)
    total_samples = zeros((nsamples, nfeatures))
    with open(region_fname) as region_file:
        for line in region_file:
            fields = line.split()
            # Skip blank lines and BED header/comment lines.
            if not fields or line.startswith(("#", "track", "browser")):
                continue
            _process_region(fields, mask_fname, intervals1_fname, intervals2_fnames,
                            nsamples, total_actual, total_lengths2, total_samples)
    _report(intervals2_fnames, nsamples, total_actual, total_lengths2, total_samples)
166
167
# Run the command-line entry point only when executed as a script, not on import.
if __name__ == "__main__":
    main()
170