1from __future__ import absolute_import
2from __future__ import print_function
3
4import re
5import os
6import sys
7import time
8from collections import defaultdict
9
10from .utils import log
11from . import db
12from .errors import ConfigError, DataError
13
14
def iter_fasta_seqs(source):
    """Iterate over the records in a FASTA source.

    `source` may be a path to a (optionally gzip-compressed) FASTA file,
    or a raw FASTA-formatted string. Yields (name, sequence) tuples,
    where `name` is the header text up to the first tab and `sequence`
    has all whitespace removed.

    Raises ValueError if a header has no sequence, and a generic
    Exception if sequence data appears before any header.
    """
    if os.path.isfile(source):
        if source.endswith('.gz'):
            import gzip
            # 'rt' is required: the default binary mode yields bytes on
            # Python 3, which breaks the str.startswith() checks below.
            _source = gzip.open(source, "rt")
        else:
            _source = open(source, "r")
        _needs_close = True
    else:
        # Not a file: treat the argument itself as FASTA text.
        _source = iter(source.split("\n"))
        _needs_close = False

    seq_chunks = []
    seq_name = None
    try:
        for line in _source:
            line = line.strip()
            if line.startswith('#') or not line:
                continue
            elif line.startswith('>'):
                # yield seq if finished
                if seq_name and not seq_chunks:
                    raise ValueError("Error parsing fasta file. %s has no sequence" %seq_name)
                elif seq_name:
                    yield seq_name, ''.join(seq_chunks)

                seq_name = line[1:].split('\t')[0].strip()
                seq_chunks = []
            else:
                if seq_name is None:
                    raise Exception("Error reading sequences: Wrong format.")
                seq_chunks.append(line.replace(" ",""))

        # return last sequence
        if seq_name and not seq_chunks:
            raise ValueError("Error parsing fasta file. %s has no sequence" %seq_name)
        elif seq_name:
            yield seq_name, ''.join(seq_chunks)
    finally:
        # Close real file handles (leak fix); string sources need no cleanup.
        if _needs_close:
            _source.close()
52
53
def load_sequences(args, seqtype, target_seqs, target_species, cached_seqs):
    """Load sequences of `seqtype` ("aa" or "nt") from the configured seed file
    into the database.

    - Filters by `target_seqs` (exact names) and/or `target_species`
      (prefix before `args.spname_delimiter`).
    - Optionally extracts names with the `args.seq_name_parser` regex and
      sanitizes problematic residue symbols (unless `args.no_seq_correct`).
    - Duplicate names either abort (DataError) or get a numeric suffix
      when `args.rename_dup_seqnames` is set.
    - If `cached_seqs` is given, names must map to existing seq ids there;
      otherwise fresh "S%09d" ids are generated and registered.

    Returns a dict mapping sequence name -> sequence id.
    Raises ConfigError if a name cannot be parsed, DataError on missing
    cached entries or duplicate names.
    """
    seqfile = getattr(args, "%s_seed_file" %seqtype)
    skipped_seqs = 0
    loaded_seqs = {}

    log.log(28, "Reading %s sequences from %s...", seqtype, seqfile)
    fix_dups = bool(args.rename_dup_seqnames)
    if args.seq_name_parser:
        NAME_PARSER = re.compile(args.seq_name_parser)

    # Build an optional one-pass symbol-replacement table; compiled into a
    # single alternation regex so each sequence is scanned only once.
    seq_repl = {}
    # Clear problematic symbols
    if not args.no_seq_correct:
        seq_repl["."] = "-"
        seq_repl["*"] = "X"
        if seqtype == "aa":
            seq_repl["J"] = "X" # phyml fails with J
            seq_repl["O"] = "X" # mafft fails with O
            seq_repl["U"] = "X" # selenocysteines
    if args.dealign:
        seq_repl["-"] = ""
        seq_repl["."] = ""
    if seq_repl:
        SEQ_PARSER = re.compile("(%s)" %('|'.join(map(re.escape, seq_repl.keys()))))

    start_time = time.time()
    dupnames = defaultdict(int)
    for c1, (raw_seqname, seq) in enumerate(iter_fasta_seqs(seqfile)):
        # Periodic progress line to stderr.
        if c1 % 10000 == 0 and c1:
            if loaded_seqs and target_seqs:  # only works when workflow is supermatrix
                estimated_time = ((len(target_seqs) - len(loaded_seqs)) * (time.time() - start_time)) / float(c1)
                percent = (len(loaded_seqs) / float(len(target_seqs))) * 100.0
            else:
                percent = 0
                estimated_time = -1
            print("loaded:%07d skipped:%07d scanned:%07d %0.1f%%" %\
                  (len(loaded_seqs), skipped_seqs, c1, percent), end='\n', file=sys.stderr)

        if args.seq_name_parser:
            name_match = NAME_PARSER.search(raw_seqname)
            if name_match:
                seqname = name_match.groups()[0]
            else:
                raise ConfigError("Could not parse sequence name: %s" %raw_seqname)
        else:
            # BUGFIX: previously assigned to `seq_name`, leaving `seqname`
            # undefined for the rest of the loop when no parser is set.
            seqname = raw_seqname

        # BUGFIX: previously compared the dict itself to an int
        # (`loaded_seqs == len(target_seqs)`), so this early exit never fired.
        if target_seqs and len(loaded_seqs) == len(target_seqs):
            break
        elif target_seqs and seqname not in target_seqs:
            skipped_seqs += 1
            continue
        elif target_species and seqname.split(args.spname_delimiter, 1)[0] not in target_species:
            skipped_seqs += 1
            continue

        if seq_repl:
            seq = SEQ_PARSER.sub(lambda m: seq_repl[m.groups()[0]], seq)

        if cached_seqs:
            try:
                seqid = cached_seqs[seqname]
            except KeyError:
                raise DataError("%s sequence not found in %s sequence file" %(seqname, seqtype))
        else:
            seqid = "S%09d" %(len(loaded_seqs)+1)

        if seqname in loaded_seqs:
            if fix_dups:
                dupnames[seqname] += 1
                seqname = seqname + "_%d"%dupnames[seqname]
            else:
                raise DataError("Duplicated sequence name [%s] found. Fix manually or use --rename-dup-seqnames to continue" %(seqname))

        loaded_seqs[seqname] = seqid
        db.add_seq(seqid, seq, seqtype)
        if not cached_seqs:
            db.add_seq_name(seqid, seqname)
    print('\n', file=sys.stderr)
    db.seqconn.commit()
    return loaded_seqs
136
137
138
139
140        # if not args.no_seq_checks:
141        #     # Load unknown symbol inconsistencies
142        #     if seqtype == "nt" and set(seq) - NT:
143        #         seq2unknown[seqtype][seqname] = set(seq) - NT
144        #     elif seqtype == "aa" and set(seq) - AA:
145        #         seq2unknown[seqtype][seqname] = set(seq) - AA
146
147        # seq2seq[seqtype][seqname] = seq
148        # seq2length[seqtype][seqname] = len(seq)
149
150
151    # Initialize target sets using aa as source
152    # if not target_seqs: # and seqtype == "aa":
153    #     target_seqs = set(visited_seqs[source_seqtype])
154
155    # if skipped_seqs:
156    #     log.warning("%d sequences will not be used since they are"
157    #                 "  not present in the aa seed file." %skipped_seqs)
158
159    # return target_seqs, visited_seqs, seq2length, seq2unknown, seq2seq
160
161
162
163
164
165
166
167
def check_seq_integrity(args, target_seqs, visited_seqs, seq2length, seq2unknown, seq2seq):
    """Run consistency checks over loaded sequences and return an error report.

    Accumulates all problems found (duplicate names, missing targets,
    unknown residue symbols, aa/nt length and codon mismatches) into a
    single string, which is returned ("" means no errors). Also logs
    summary length statistics.

    NOTE(review): relies on several names not defined in this module's
    visible header — `GLOBALS`, `six`, `GENCODE`, `_max`, `_min`, `_mean`,
    `_std` — presumably imported/defined elsewhere; verify before use.
    """
    log.log(28, "Checking data consistency ...")
    # Use aa as the reference seqtype when available, otherwise nt.
    source_seqtype = "aa" if "aa" in GLOBALS["seqtypes"] else "nt"
    error = ""

    # Check for duplicate ids
    if not args.ignore_dup_seqnames:
        # A set smaller than the list means at least one repeated name.
        seq_number = len(set(visited_seqs[source_seqtype]))
        if len(visited_seqs[source_seqtype]) != seq_number:
            counter = defaultdict(int)
            for seqname in visited_seqs[source_seqtype]:
                counter[seqname] += 1
            duplicates = ["%s\thas %d copies" %(key, value) for key, value in six.iteritems(counter) if value > 1]
            error += "\nDuplicate sequence names.\n"
            error += '\n'.join(duplicates)

    # check that the seq of all targets is available
    if target_seqs:
        for seqtype in GLOBALS["seqtypes"]:
            missing_seq = target_seqs - set(seq2seq[seqtype].keys())
            if missing_seq:
                error += "\nThe following %s sequences are missing:\n" %seqtype
                error += '\n'.join(missing_seq)

    # check for unknown characters
    for seqtype in GLOBALS["seqtypes"]:
        if seq2unknown[seqtype]:
            error += "\nThe following %s sequences contain unknown symbols:\n" %seqtype
            error += '\n'.join(["%s\tcontains:\t%s" %(k,' '.join(v)) for k,v in six.iteritems(seq2unknown[seqtype])] )

    # check for aa/cds consistency
    REAL_NT = set('ACTG')
    if GLOBALS["seqtypes"] == set(["aa", "nt"]):
        inconsistent_cds = set()
        for seqname, ntlen in six.iteritems(seq2length["nt"]):
            if seqname in seq2length["aa"]:
                aa_len = seq2length["aa"][seqname]
                # CDS length must be exactly 3x the protein length.
                if ntlen / 3.0 != aa_len:
                    inconsistent_cds.add("%s\tExpected:%d\tFound:%d" %\
                                         (seqname,
                                         aa_len*3,
                                         ntlen))
                else:
                    if not args.no_seq_checks:
                        # Verify each codon translates to the matching aa;
                        # codons with ambiguity codes (non-ACTG) are skipped.
                        for i, aa in enumerate(seq2seq["aa"][seqname]):
                            codon = seq2seq["nt"][seqname][i*3:(i*3)+3]
                            if not (set(codon) - REAL_NT):
                                # NOTE(review): GENCODE is presumably a
                                # codon->aa table defined elsewhere — confirm.
                                if GENCODE[codon] != aa:
                                    log.warning('@@2:Unmatching codon in seq:%s, aa pos:%s (%s != %s)@@1: Use --no-seq-checks to skip' %(seqname, i, codon, aa))
                                    inconsistent_cds.add('Unmatching codon in seq:%s, aa pos:%s (%s != %s)' %(seqname, i, codon, aa))

        if inconsistent_cds:
            error += "\nUnexpected coding sequence length for the following ids:\n"
            error += '\n'.join(inconsistent_cds)

    # Show some stats
    all_len = list(seq2length[source_seqtype].values())
    max_len = _max(all_len)
    min_len = _min(all_len)
    mean_len = _mean(all_len)
    std_len = _std(all_len)
    # Flag lengths more than 3 standard deviations from the mean.
    outliers = []
    for v in all_len:
        if abs(mean_len - v) >  (3 * std_len):
            outliers.append(v)
    log.log(28, "Total sequences:  %d" %len(all_len))
    log.log(28, "Average sequence length: %d +- %0.1f " %(mean_len, std_len))
    log.log(28, "Max sequence length:  %d" %max_len)
    log.log(28, "Min sequence length:  %d" %min_len)

    if outliers:
        log.warning("%d sequence lengths look like outliers" %len(outliers))

    return error
242
243
def hash_names(target_names):
    """Given a set of strings of variable lengths, it returns their
    conversion to fixed and safe hash-strings.
    """
    # An example of hash name collision
    #test= ['4558_15418', '9600_21104', '7222_13002', '3847_37647', '412133_16266']
    #hash_names(test)

    log.log(28, "Generating safe sequence names...")
    hash2name = defaultdict(list)
    for idx, original in enumerate(target_names):
        print(idx, "\r", end=' ', file=sys.stderr)
        sys.stderr.flush()
        hash2name[encode_seqname(original)].append(original)

    # Hashes mapping to more than one name must be disambiguated.
    collisions = [(h, names) for h, names in six.iteritems(hash2name) if len(names) > 1]
    if collisions:
        taken = set(hash2name.keys())
        for old_hash, coliding_names in collisions:
            logindent(2)
            log.log(20, "Collision found when hash-encoding the following gene names: %s", coliding_names)
            # Re-hash each clashing name concatenated with itself `niter`
            # times, until every name gets a unique, previously unused hash.
            niter = 1
            valid = False
            while not valid or len(new_hashes) < len(coliding_names):
                niter += 1
                new_hashes = defaultdict(list)
                for name in coliding_names:
                    new_hashes[encode_seqname(name * niter)].append(name)
                valid = set(new_hashes.keys()).isdisjoint(taken)

            log.log(20, "Fixed with %d concatenations! %s", niter, ', '.join(['%s=%s' %(e[1][0], e[0]) for e in  six.iteritems(new_hashes)]))
            del hash2name[old_hash]
            hash2name.update(new_hashes)
            logindent(-2)

    # Flatten: every hash now maps to exactly one name; build both directions.
    hash2name = {h: names[0] for h, names in six.iteritems(hash2name)}
    name2hash = {v: k for k, v in six.iteritems(hash2name)}
    return name2hash, hash2name
287
288