1#!/usr/local/bin/python3.8
2#
3#  The MIT License
4#
5#  Copyright (c) 2017-2019 Genome Research Ltd.
6#
7#  Author: Petr Danecek <pd3@sanger.ac.uk>
8#
9#  Permission is hereby granted, free of charge, to any person obtaining a copy
10#  of this software and associated documentation files (the "Software"), to deal
11#  in the Software without restriction, including without limitation the rights
12#  to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
13#  copies of the Software, and to permit persons to whom the Software is
14#  furnished to do so, subject to the following conditions:
15#
16#  The above copyright notice and this permission notice shall be included in
17#  all copies or substantial portions of the Software.
18#
19#  THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
20#  IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
21#  FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
22#  AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
23#  LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
24#  OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
25#  THE SOFTWARE.
26
27import glob, gzip, csv, sys, os, copy, re
28csv.register_dialect('tab', delimiter='\t', quoting=csv.QUOTE_NONE)
29
def usage(msg=None):
    """Print the usage text, or `msg` when one is given, and exit with status 1.

    The function never returns: sys.exit(1) is called in both cases.
    """
    help_text = """\
Usage: plot-roh.py [OPTIONS] <dir>
Options:
   -H, --highlight +group1,-group2       Highlight calls shared within group1 but not present in group2
   -i, --interactive                     Run interactively
   -l, --min-length <num>                Filter input regions shorter than this [0]
   -n, --min-markers <num>               Filter input regions with fewer marker than this [0]
   -o, --outfile <file>                  Output file name [plot.png]
   -q, --min-qual <num>                  Filter input regions with quality smaller than this [0]
   -r, --region [^]<chr|chr:beg-end>     Plot this chromosome/region only
   -s, --samples <file>                  List of samples to show, rename or group: "name[\\tnew_name[\\tgroup]]"
   -h, --help                            This usage text
Matplotlib options:
   +adj, --adjust <str>          Set plot adjust [bottom=0.18,left=0.07,right=0.98]
   +dpi, --dpi <num>             Set bitmap DPI [150]
   +sxt, --show-xticks           Show x-ticks (genomic coordinate)
   +twh, --track-wh <num,num>    Set track width and height [20,1]
   +xlb, --xlabel <str>          Set x-label
   +xli, --xlimit <num>          Extend x-range by this fraction [0.05]"""
    # A custom message replaces the help text, it is not printed in addition to it
    print(help_text if msg is None else msg)
    sys.exit(1)
54
# Default values of the command-line options; the parser below overrides them.
dir          = None     # the only positional argument, the last one given wins
plot_regions = None     # -r, --region
min_length   = 0        # -l, --min-length
min_markers  = 0        # -n, --min-markers
min_qual     = 0        # -q, --min-qual
interactive  = False    # -i, --interactive
sample_file  = None     # -s, --samples
highlight    = None     # -H, --highlight
outfile      = None     # -o, --outfile
adjust       = 'bottom=0.18,left=0.07,right=0.98'   # +adj, --adjust
dpi          = 150      # +dpi, --dpi
xlim         = 0.05     # +xli, --xlimit
show_xticks  = False    # +sxt, --show-xticks
xlabel       = None     # +xlb, --xlabel
track_width  = 20       # +twh, --track-wh (first number)
track_height = 1        # +twh, --track-wh (second number)

# Options which consume the following argument as their value
opts_with_value = (
    '-r','--region','-l','--min-length','-n','--min-markers','-o','--outfile',
    '-q','--min-qual','-H','--highlight','-s','--samples',
    '+adj','--adjust','+dpi','--dpi','+twh','--track-wh',
    '+xlb','--xlabel','+xli','--xlimit',
)

if len(sys.argv) < 2: usage()
args = sys.argv[1:]
while len(args):
    # Report a missing value with a readable message instead of failing
    # with an IndexError on args[0] further down
    if args[0] in opts_with_value and len(args) < 2:
        usage("Missing value for the option "+args[0])
    if args[0]=='-r' or args[0]=='--region':
        args = args[1:]
        plot_regions = args[0]
    elif args[0]=='-i' or args[0]=='--interactive':
        interactive = True
    elif args[0]=='-l' or args[0]=='--min-length':
        args = args[1:]
        min_length = float(args[0])
    elif args[0]=='-n' or args[0]=='--min-markers':
        args = args[1:]
        min_markers = float(args[0])
    elif args[0]=='-o' or args[0]=='--outfile':
        args = args[1:]
        outfile = args[0]
    elif args[0]=='-q' or args[0]=='--min-qual':
        args = args[1:]
        min_qual = float(args[0])
    elif args[0]=='-H' or args[0]=='--highlight':
        args = args[1:]
        highlight = args[0]
    elif args[0]=='-s' or args[0]=='--samples':
        args = args[1:]
        sample_file = args[0]
    elif args[0]=='-?' or args[0]=='-h' or args[0]=='--help':
        usage()
    elif args[0]=='+adj' or args[0]=='--adjust':
        args = args[1:]
        adjust = args[0]
    elif args[0]=='+dpi' or args[0]=='--dpi':
        args = args[1:]
        dpi = float(args[0])
    elif args[0]=='+sxt' or args[0]=='--show-xticks':
        show_xticks = True
    elif args[0]=='+twh' or args[0]=='--track-wh':
        args = args[1:]
        (track_width,track_height) = args[0].split(',')
        # A user-supplied height is stored negated: the figure-size code
        # treats a negative value as "use as given, per sample" and skips
        # the automatic height limit
        track_height = -float(track_height)
        track_width  = float(track_width)
    elif args[0]=='+xlb' or args[0]=='--xlabel':
        args = args[1:]
        xlabel = args[0]
    elif args[0]=='+xli' or args[0]=='--xlimit':
        args = args[1:]
        xlim = float(args[0])
    else:
        # Anything not recognised as an option is taken as the input directory
        dir = args[0]
    args = args[1:]

if interactive and outfile is not None: usage("Use -i, --interactive or -o, --outfile, but not both")
if not interactive and outfile is None: outfile = 'plot.png'
125
def wrap_hash(**args):
    """Return the keyword arguments as a dictionary."""
    return args

def parse_adjust(spec):
    """Parse a "key=value,key=value" string into a dictionary of floats.

    Empty items are skipped. Exits via usage() when an item has no "=", the
    key is not a valid keyword name, or the value is not a number.
    """
    out = {}
    for item in spec.split(','):
        if item.strip()=='': continue
        key,sep,val = item.partition('=')
        key = key.strip()
        if sep=='' or not key.isidentifier():
            usage("Could not parse the --adjust setting: "+item)
        try:
            out[key] = float(val)
        except ValueError:
            usage("Could not parse the --adjust setting: "+item)
    return out

# NOTE: this used to be eval("wrap_hash("+adjust+")"). The string comes straight
# from the command line, so it is now parsed explicitly rather than executed.
# The result is later passed as keyword arguments to plt.subplots_adjust().
adjust = parse_adjust(adjust)
128
129
import matplotlib as mpl
if not interactive:
    # Writing image files needs no display, use the non-interactive backend
    mpl.use('Agg')
    import matplotlib.pyplot as plt
    import matplotlib.patches as patches
else:
    # Take the first GUI backend which can be loaded on this system. The
    # former `warn` argument of mpl.use() is not passed: recent matplotlib
    # versions reject it, which made every backend in the list fail silently.
    for gui in ['TKAgg','QtAgg','Qt5Agg','Qt4Agg','GTK3Agg','GTKAgg','WXAgg','MacOSX']:
        try:
            mpl.use(gui, force=True)
            import matplotlib.pyplot as plt
            import matplotlib.patches as patches
            break
        except Exception:
            # Backend unknown to this matplotlib or its toolkit is not installed
            continue
    else:
        # Without this `plt` would be undefined and the script would die
        # later with a NameError
        usage("No interactive matplotlib backend could be loaded, use -o, --outfile instead")
144
# Colours cycled through when plotting the genotypes of the individual samples
cols = [ '#337ab7', '#5cb85c', '#5bc0de', '#f0ad4e', '#d9534f', 'grey', 'black' ]

# The input directory is required; without this check os.path.join(None,..)
# fails with a TypeError when only options were given
if dir is None: usage("Missing the <dir> argument")
globstr = os.path.join(dir, '*.txt.gz')
fnames = glob.glob(globstr)
if len(fnames)==0: usage("No data files found in \""+dir+"\"")
150
def parse_regions(str):
    """Parse the -r, --region argument into regions to include or exclude.

    The argument is a comma-separated list of "chr" or "chr:beg-end" items.
    A "^" in front of the first item turns the whole list into an exclusion
    list. Empty items (and therefore an empty string) are skipped.

    Returns None when `str` is None. Otherwise returns a dictionary
    {'inc':[...], 'exc':[...]} of {'chr','beg','end'} records, where a bare
    "chr" item spans 0 .. 2^32-1 and explicit coordinates are floats.
    Raises ValueError when the "beg-end" part cannot be parsed.
    """
    if str is None: return None
    regs = { 'inc':[], 'exc':[] }
    items = str.split(',')
    key = 'inc'
    # startswith() is safe on an empty first item, unlike indexing [0][0]
    if items[0].startswith('^'):
        key = 'exc'
        items[0] = items[0][1:]
    for item in items:
        if item=='': continue
        name,sep,span = item.partition(':')
        beg = 0
        end = (1<<32)-1
        if sep!='':
            (beg,end) = span.split('-')
            beg = float(beg)
            end = float(end)
        regs[key].append({'chr':name,'beg':beg,'end':end})
    return regs
170
def region_overlap(regs,chr,beg,end):
    """Intersect the interval chr:beg-end with the regions selected by the user.

    regs is the structure returned by parse_regions() or None. Returns the
    (beg,end) tuple to use, or None when the interval is filtered out:
      - no regions or empty lists: the interval is returned unchanged
      - exclusion list: None when the chromosome is listed (the coordinates
        of excluded regions are not consulted), the interval otherwise
      - inclusion list: the interval clipped to the first overlapping region
        on the same chromosome, None when there is no such region
    """
    if regs is None:
        return (beg,end)

    excluded = regs['exc']
    if len(excluded)>0:
        if any(chr==reg['chr'] for reg in excluded):
            return None
        return (beg,end)

    included = regs['inc']
    if len(included)==0:
        return (beg,end)

    for reg in included:
        if chr!=reg['chr'] or beg>reg['end'] or end<reg['beg']:
            continue
        clip_beg = reg['beg'] if beg<reg['beg'] else beg
        clip_end = reg['end'] if end>reg['end'] else end
        return (clip_beg,clip_end)
    return None
186
def parse_outfile(fname):
    """Expand "path/name.ext1,ext2,..." into the list of output file names.

    The first item is used as given; each following item is taken as an
    extension which replaces the extension of the first one, for example
    "plot.png,pdf" gives ["plot.png","plot.pdf"].
    """
    files = fname.split(',')
    # splitext() inspects the last path component only, so a dot in a
    # directory name is not mistaken for the extension, and a name with no
    # extension is used whole instead of raising on a failed regex match
    bname = os.path.splitext(files[0])[0]
    return [files[0]] + [bname+"."+ext for ext in files[1:]]
193
def next_region(rgs):
    """Return the next [beg,end] segment on which the set of regions is constant.

    rgs maps a sample to its list of [beg,end] regions; only the first region
    of each sample is inspected. The segment begins at the smallest start
    among these regions and ends at the smallest end, or just before the
    next start of a region which begins later, whichever comes first.
    Returns None when all lists are empty.
    """
    heads = [regs[0] for regs in rgs.values() if len(regs)!=0]
    if not heads: return None
    beg = min(reg[0] for reg in heads)
    end = min(reg[1] for reg in heads)
    for reg in heads:
        # a region starting inside the segment cuts it short
        if reg[0]!=beg and end > reg[0]-1: end = reg[0]-1
    return [beg,end]
211
def merge_regions(rg):
    """Find the segments which are covered by regions of more than one sample.

    rg maps a sample to its list of [beg,end] regions and is left untouched;
    the work is done on a deep copy. The regions are cut at the breakpoints
    reported by next_region() and each resulting segment which is covered in
    at least two samples is appended to the output lists of those samples.
    Presumably the per-sample lists are expected to be sorted by position;
    this is not checked here.

    Returns a dictionary sample -> list of [beg,end] segments. A sample whose
    regions were visited but never shared is present with an empty list.
    """
    todo = copy.deepcopy(rg)
    shared = {}
    while True:
        span = next_region(todo)
        if span is None: return shared
        beg,end = span
        hits = []
        for smpl in todo:
            queue = todo[smpl]
            if len(queue)==0 or queue[0][0] > end: continue
            if queue[0][1] > end:
                # only a part was consumed, keep the rest for the next round
                queue[0][0] = end + 1
            else:
                todo[smpl] = queue[1:]
            shared.setdefault(smpl, [])
            hits.append(smpl)
        if len(hits)>1:
            for smpl in hits: shared[smpl].append([beg,end])
234
def prune_regions(groups,regions):
    """Keep only the segments present in all '+' samples and in no '-' sample.

    groups maps a sample to '+' or '-', regions maps a sample to its list of
    [beg,end] segments. Segments are matched by their exact "beg-end" key.
    A segment survives only if no '-' sample has it and the number of '+'
    samples having it equals the number of '+' samples in `groups`.

    The lists in `regions` are modified in place and `regions` is returned.
    Raises KeyError for a sample in `regions` which is missing from `groups`.
    """
    def seg_key(reg):
        # reg=[beg,end] -> "beg-end"
        return str(reg[0])+"-"+str(reg[1])

    # number of samples having each segment, separately for the two groups
    counts = {'+':{},'-':{}}
    for smpl,segs in regions.items():
        grp = groups[smpl]
        for reg in segs:
            key = seg_key(reg)
            counts[grp][key] = counts[grp].get(key,0) + 1

    nexp = sum(1 for sign in groups.values() if sign=='+')

    for segs in regions.values():
        segs[:] = [reg for reg in segs
                   if seg_key(reg) not in counts['-']
                   and counts['+'].get(seg_key(reg))==nexp]
    return regions
256
def parse_samples(fname,highlight):
    """Read the list of samples to show, their display names and groups.

    fname is a whitespace-delimited file with the columns "name [new_name
    [group]]", highlight is the -H, --highlight string "+group1,-group2" or
    None. Blank lines and leading or trailing whitespace are ignored.

    Returns the tuple (samples,groups,smpl2y):
      samples .. sample -> display name (the sample name when not renamed)
      groups  .. sample -> '+' or '-' for rows with exactly three columns;
                 groups not named in `highlight` count as '+'. None when
                 `highlight` is None.
      smpl2y  .. sample -> 0-based index in the order of first appearance
    With fname None nothing is read and (None,None,{}) is returned.
    Exits via usage() on a highlight item not starting with + or -.
    """
    if fname is None: return (None,None,{})
    samples = {}
    groups  = {}
    grp2sgn = {}
    smpl2y  = {}
    # parse "+name" to create a map "name":"+"
    if highlight is not None:
        for grp in highlight.split(','):
            # slicing keeps an empty item from raising an IndexError
            if grp[:1] not in ('+','-'): usage("Expected + or - before the group name: "+grp)
            grp2sgn[grp[1:]] = grp[0]
    # read samples, renaming them
    with open(fname) as f:
        for line in f:
            # split() drops surrounding whitespace, so a trailing blank
            # does not show up as an empty extra column
            row = line.split()
            if len(row)==0: continue
            smpl = row[0]
            samples[smpl] = row[1] if len(row)>1 else smpl
            if len(row)==3:
                groups[smpl] = grp2sgn.get(row[2],'+')
            # a sample listed twice keeps its first index; reassigning it
            # would leave a gap and push the later indexes out of range
            if smpl not in smpl2y:
                smpl2y[smpl] = len(smpl2y)
    if highlight is None: groups = None
    return (samples,groups,smpl2y)
287
plot_regions = parse_regions(plot_regions)
(samples,groups,smpl2y) = parse_samples(sample_file,highlight)

# dat_gt[chr][smpl] .. list of [pos,dsg] for the plotted genotypes, dsg is 1
#                      when the two alleles differ and 2 when they are equal
# dat_rg[chr][smpl] .. list of [beg,end] regions which passed the filters
# chrs              .. chromosomes in the order they first appear in GT rows
dat_gt = {}
dat_rg = {}
chrs   = []
for fname in fnames:
    # The context manager closes the file also when a malformed row raises
    with gzip.open(fname, 'rt') as f:
        reader = csv.reader(f, 'tab')
        for row in reader:
            if len(row)==0: continue            # blank line
            if row[0]=='GT':
                chr  = row[1]
                pos  = int(row[2])
                reg  = region_overlap(plot_regions,chr,pos,pos)
                if reg is None: continue
                # the rest of the row are sample,genotype pairs; an unpaired
                # trailing field is ignored
                for i in range(3,len(row)-1,2):
                    smpl = row[i]
                    if samples is not None and smpl not in samples: continue
                    gt   = row[i+1]
                    x = re.split(r'[/|]', gt)       # unphased a/b as well as phased a|b
                    if x[0]=='.': continue          # missing genotype ./.
                    if len(x)<2: continue           # haploid genotype, no second allele to compare
                    dsg = 2
                    if x[0]!=x[1]: dsg = 1
                    elif x[0]=='0': continue        # skip HomRef 0/0 genotypes
                    if chr not in dat_gt:
                        dat_gt[chr] = {}
                        chrs.append(chr)
                    if smpl not in dat_gt[chr]:
                        dat_gt[chr][smpl] = []
                    if smpl not in smpl2y:
                        y = len(smpl2y)
                        smpl2y[smpl] = y
                    dat_gt[chr][smpl].append([pos,dsg])
            elif row[0]=='RG':
                smpl  = row[1]
                if samples is not None and smpl not in samples: continue
                chr   = row[2]
                beg   = int(row[3])
                end   = int(row[4])
                length= int(row[5])
                nmark = int(row[6])
                qual  = float(row[7])
                if length < min_length: continue
                if nmark < min_markers : continue
                if qual < min_qual : continue
                reg = region_overlap(plot_regions,chr,beg,end)
                if reg is None: continue
                if chr not in dat_rg: dat_rg[chr] = {}
                if smpl not in dat_rg[chr]: dat_rg[chr][smpl] = []
                # region_overlap() returns the interval already clipped to
                # the requested region
                (beg,end) = reg
                dat_rg[chr][smpl].append([beg,end])
341
# Without a sample file every sample seen in the data is shown under its own name
if samples is None:
    samples = {smpl: smpl for smpl in smpl2y}

# Reverse the indexes so that the first sample encountered gets the highest
# y-value, i.e. the samples are listed in the order of appearance from top to bottom
smpl2y = {smpl: len(smpl2y) - idx - 1 for smpl,idx in smpl2y.items()}
349
# Chromosomes are laid out side by side on the x-axis:
#   off_hash[chr] .. x-offset at which the chromosome starts
#   off_list[i]   .. x-offset at which the i-th chromosome block ends, gap included
#   dat_rg1[chr]  .. regions shared between samples (and, with -H, pruned by group)
off_list = []
off_hash = {}
dat_rg1  = {}
off      = 0
off_sep  = 0
for chr in chrs:
    if chr in dat_rg:
        rg1 = merge_regions(dat_rg[chr])
        if groups is not None:
            rg1 = prune_regions(groups,rg1)
        if rg1: dat_rg1[chr] = rg1
    off_hash[chr] = off
    # the largest position is read from the last genotype of each sample
    max_pos = max([0] + [calls[-1][0] for calls in dat_gt[chr].values()])
    # the gap between chromosomes is set once, to 10% of the first chromosome
    if off_sep==0: off_sep = max_pos*0.1
    off += max_pos + off_sep
    off_list.append(off)
368
# Figure width and height. A negative track_height marks a height given by
# the user, it is applied per sample; otherwise each sample gets one unit of
# height and the total is capped at 5.
if track_height < 0:
    wh = track_width,-track_height*len(smpl2y)
elif len(smpl2y)>5:
    wh = track_width,5
else:
    wh = track_width,len(smpl2y)
374
def bignum(num):
    """Return the number formatted with commas as thousands separators.

    For example 1234567 gives "1,234,567". The sign of a negative number is
    kept in front of the digits and is not counted as a digit, so that -123
    gives "-123".
    """
    return '{:,}'.format(num)
381
def format_coord(x, y):
    """Translate the plot x-coordinate into a "chr<name>:<position>" string.

    Installed as the axes' format_coord callback. The chromosome is the first
    one whose block ends beyond x (the last one when there is none) and the
    position is x minus the end offset of the last block passed over.
    The y-coordinate is not used.
    """
    name  = None
    start = 0
    for name,stop in zip(chrs,off_list):
        if stop > x: break
        start = stop
    return 'chr%s:%s'%(name,bignum(int(x - start)))
390
391
# A single panel; the input directory is passed as the figure identifier
fig, ax1 = plt.subplots(1, 1, figsize=wh, num=dir)
ax1.yaxis.set_ticks_position('none')
ax1.format_coord = format_coord
xtick_lbl = []
xtick_pos = []
max_x = 0
min_x = -1      # -1 .. not set yet
for chr in dat_gt:
    off  = off_hash[chr]
    icol = 0
    for smpl in dat_gt[chr]:
        # each sample occupies three units of the y-axis: the genotypes are
        # drawn at 3*y+1 (dsg=1) and 3*y+2 (dsg=2), the region boxes span
        # 3*y+0.5 .. 3*y+2.5
        y = smpl2y[smpl]
        if chr in dat_rg and smpl in dat_rg[chr]:
            # all regions of the sample in light grey
            for rg in dat_rg[chr][smpl]:
                rect = patches.Rectangle((rg[0]+off,3*y+0.5), rg[1]-rg[0]+1, 2, color='#dddddd')
                ax1.add_patch(rect)
        if chr in dat_rg1 and smpl in dat_rg1[chr]:
            # the shared/highlighted segments are drawn over them in red
            for rg in dat_rg1[chr][smpl]:
                rect = patches.Rectangle((rg[0]+off,3*y+0.5), rg[1]-rg[0]+1, 2, color='#d9534f')
                ax1.add_patch(rect)
        ax1.plot([x[0]+off for x in dat_gt[chr][smpl]],[x[1]+3*y for x in dat_gt[chr][smpl]],'.',color=cols[icol%len(cols)])
        if min_x==-1 or min_x > dat_gt[chr][smpl][0][0]+off: min_x = dat_gt[chr][smpl][0][0]+off
        if max_x < dat_gt[chr][smpl][-1][0]+off: max_x = dat_gt[chr][smpl][-1][0]+off
        # the colours repeat via the modulo above, no reset of icol is needed
        icol += 1
    xtick_lbl.append(chr)
    xtick_pos.append(off)
# the y-tick labels are taken from the samples of the first chromosome only
ytick_lbl = []
ytick_pos = []
for chr in dat_gt:
    for smpl in dat_gt[chr]:
        ytick_lbl.append(samples[smpl])
        ytick_pos.append(3*smpl2y[smpl]+1)
    break
if xlim!=0:
    if min_x==-1: min_x = 0
    ax1.set_xlim(min_x,max_x+xlim*max_x)
# NOTE(review): lbl_pos-1 and lbl_pos-2 are the dsg=2 and dsg=1 rows of the
# second sample from the top, not of the top one - confirm this is intended
lbl_pos = 3*(len(smpl2y)-1)
ax1.annotate('   HomAlt ',xy=(max_x,lbl_pos-1),xycoords='data',va='center')
ax1.annotate('   Het',xy=(max_x,lbl_pos-2),xycoords='data',va='center')
if not show_xticks:
    # label the start of each chromosome instead of showing coordinates
    ax1.set_xticks(xtick_pos)
    ax1.set_xticklabels(xtick_lbl)
if xlabel is not None:
    ax1.set_xlabel(xlabel)
ax1.set_yticks(ytick_pos)
ax1.set_yticklabels(ytick_lbl)
ax1.set_ylim(0,3*len(smpl2y)+0.5)
plt.subplots_adjust(**adjust)
if interactive:
    plt.show()
else:
    # one image per requested file name / extension
    files = parse_outfile(outfile)
    for file in files:
        plt.savefig(file,dpi=dpi)
    plt.close()
450
451
452