from __future__ import absolute_import
from __future__ import print_function

import re
import os
import sys
import time
from collections import defaultdict

import six

from .utils import log
# Assumed to be provided by this package's utils module, as used below:
# GLOBALS (shared state), logindent, encode_seqname (name hashing),
# GENCODE (codon table) and the _max/_min/_mean/_std numeric helpers.
from .utils import (GLOBALS, logindent, encode_seqname, GENCODE,
                    _max, _min, _mean, _std)
from . import db
from .errors import ConfigError, DataError


def iter_fasta_seqs(source):
    """Iterate over the records of a FASTA file or a raw FASTA string,
    yielding (name, sequence) tuples."""

    if os.path.isfile(source):
        if source.endswith('.gz'):
            import gzip
            # "rt" so gzipped files yield text lines, like the plain open() below
            _source = gzip.open(source, "rt")
        else:
            _source = open(source, "r")
    else:
        # Not a file: treat the argument itself as raw FASTA text
        _source = iter(source.split("\n"))

    seq_chunks = []
    seq_name = None
    for line in _source:
        line = line.strip()
        if line.startswith('#') or not line:
            continue
        elif line.startswith('>'):
            # yield the previous record, if finished
            if seq_name and not seq_chunks:
                raise ValueError("Error parsing fasta file. %s has no sequence" % seq_name)
            elif seq_name:
                yield seq_name, ''.join(seq_chunks)

            seq_name = line[1:].split('\t')[0].strip()
            seq_chunks = []
        else:
            if seq_name is None:
                raise ValueError("Error reading sequences: wrong format.")
            seq_chunks.append(line.replace(" ", ""))

    # yield the last record
    if seq_name and not seq_chunks:
        raise ValueError("Error parsing fasta file. %s has no sequence" % seq_name)
    elif seq_name:
        yield seq_name, ''.join(seq_chunks)


def load_sequences(args, seqtype, target_seqs, target_species, cached_seqs):
    """Load the seed sequences of the given seqtype into the database and
    return a {seqname: seqid} mapping for the loaded records."""
    seqfile = getattr(args, "%s_seed_file" % seqtype)
    skipped_seqs = 0
    loaded_seqs = {}

    log.log(28, "Reading %s sequences from %s...", seqtype, seqfile)
    fix_dups = bool(args.rename_dup_seqnames)
    if args.seq_name_parser:
        NAME_PARSER = re.compile(args.seq_name_parser)

    seq_repl = {}
    # Clear problematic symbols
    if not args.no_seq_correct:
        seq_repl["."] = "-"
        seq_repl["*"] = "X"
        if seqtype == "aa":
            seq_repl["J"] = "X"  # phyml fails with J
            seq_repl["O"] = "X"  # mafft fails with O
            seq_repl["U"] = "X"  # selenocysteines
    if args.dealign:
        seq_repl["-"] = ""
        seq_repl["."] = ""
    if seq_repl:
        SEQ_PARSER = re.compile("(%s)" % '|'.join(map(re.escape, seq_repl.keys())))

    start_time = time.time()
    dupnames = defaultdict(int)
    for c1, (raw_seqname, seq) in enumerate(iter_fasta_seqs(seqfile)):
        if c1 % 10000 == 0 and c1:
            if loaded_seqs and target_seqs:  # only works when workflow is supermatrix
                estimated_time = ((len(target_seqs) - len(loaded_seqs)) *
                                  (time.time() - start_time)) / float(c1)
                percent = (len(loaded_seqs) / float(len(target_seqs))) * 100.0
            else:
                percent = 0
                estimated_time = -1
            print("loaded:%07d skipped:%07d scanned:%07d %0.1f%%" %
                  (len(loaded_seqs), skipped_seqs, c1, percent), file=sys.stderr)

        if args.seq_name_parser:
            name_match = re.search(NAME_PARSER, raw_seqname)
            if name_match:
                seqname = name_match.groups()[0]
            else:
                raise ConfigError("Could not parse sequence name: %s" % raw_seqname)
        else:
            seqname = raw_seqname

        if target_seqs and len(loaded_seqs) == len(target_seqs):
            break
        elif target_seqs and seqname not in target_seqs:
            skipped_seqs += 1
            continue
        elif target_species and seqname.split(args.spname_delimiter, 1)[0] not in target_species:
            skipped_seqs += 1
            continue

        if seq_repl:
            seq = SEQ_PARSER.sub(lambda m: seq_repl[m.groups()[0]], seq)

        if cached_seqs:
            try:
                seqid = cached_seqs[seqname]
            except KeyError:
                raise DataError("%s sequence not found in %s sequence file" % (seqname, seqtype))
        else:
            seqid = "S%09d" % (len(loaded_seqs) + 1)

        if seqname in loaded_seqs:
            if fix_dups:
                dupnames[seqname] += 1
                seqname = seqname + "_%d" % dupnames[seqname]
            else:
                raise DataError("Duplicated sequence name [%s] found. Fix manually or use --rename-dup-seqnames to continue" % seqname)

        loaded_seqs[seqname] = seqid
        db.add_seq(seqid, seq, seqtype)
        if not cached_seqs:
            db.add_seq_name(seqid, seqname)
    print('\n', file=sys.stderr)
    db.seqconn.commit()
    return loaded_seqs
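

# Illustrative sketch (not part of the pipeline; record names are made up):
# iter_fasta_seqs() also accepts raw FASTA text, since non-file sources are
# split on newlines.
def _demo_iter_fasta():
    fasta_text = ">seqA\nMKT LV\n>seqB\n# comment lines are skipped\nMAPLE"
    records = list(iter_fasta_seqs(fasta_text))
    # Spaces inside sequences are removed, '#' lines and blanks are ignored:
    # records == [('seqA', 'MKTLV'), ('seqB', 'MAPLE')]
    return records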


# Legacy per-sequence bookkeeping from an earlier version, kept for
# reference: it shows how the seq2unknown/seq2seq/seq2length structures
# consumed by check_seq_integrity() below used to be populated.

# if not args.no_seq_checks:
#     # Load unknown symbol inconsistencies
#     if seqtype == "nt" and set(seq) - NT:
#         seq2unknown[seqtype][seqname] = set(seq) - NT
#     elif seqtype == "aa" and set(seq) - AA:
#         seq2unknown[seqtype][seqname] = set(seq) - AA

# seq2seq[seqtype][seqname] = seq
# seq2length[seqtype][seqname] = len(seq)

# Initialize target sets using aa as source
# if not target_seqs:  # and seqtype == "aa":
#     target_seqs = set(visited_seqs[source_seqtype])

# if skipped_seqs:
#     log.warning("%d sequences will not be used since they are"
#                 " not present in the aa seed file." % skipped_seqs)

# return target_seqs, visited_seqs, seq2length, seq2unknown, seq2seq


def check_seq_integrity(args, target_seqs, visited_seqs, seq2length, seq2unknown, seq2seq):
    """Cross-check the loaded data and return an error string describing
    duplicated names, missing targets, unknown symbols and aa/cds
    inconsistencies (empty string if all checks pass)."""
    log.log(28, "Checking data consistency ...")
    source_seqtype = "aa" if "aa" in GLOBALS["seqtypes"] else "nt"
    error = ""

    # Check for duplicated ids
    if not args.ignore_dup_seqnames:
        seq_number = len(set(visited_seqs[source_seqtype]))
        if len(visited_seqs[source_seqtype]) != seq_number:
            counter = defaultdict(int)
            for seqname in visited_seqs[source_seqtype]:
                counter[seqname] += 1
            duplicates = ["%s\thas %d copies" % (key, value)
                          for key, value in six.iteritems(counter) if value > 1]
            error += "\nDuplicate sequence names.\n"
            error += '\n'.join(duplicates)

    # Check that the sequence of every target is available
    if target_seqs:
        for seqtype in GLOBALS["seqtypes"]:
            missing_seq = target_seqs - set(seq2seq[seqtype].keys())
            if missing_seq:
                error += "\nThe following %s sequences are missing:\n" % seqtype
                error += '\n'.join(missing_seq)

    # Check for unknown characters
    for seqtype in GLOBALS["seqtypes"]:
        if seq2unknown[seqtype]:
            error += "\nThe following %s sequences contain unknown symbols:\n" % seqtype
            error += '\n'.join(["%s\tcontains:\t%s" % (k, ' '.join(v))
                                for k, v in six.iteritems(seq2unknown[seqtype])])

    # Check aa/cds consistency
    REAL_NT = set('ACTG')
    if GLOBALS["seqtypes"] == set(["aa", "nt"]):
        inconsistent_cds = set()
        for seqname, ntlen in six.iteritems(seq2length["nt"]):
            if seqname in seq2length["aa"]:
                aa_len = seq2length["aa"][seqname]
                if ntlen / 3.0 != aa_len:
                    inconsistent_cds.add("%s\tExpected:%d\tFound:%d" %
                                         (seqname, aa_len * 3, ntlen))
                else:
                    if not args.no_seq_checks:
                        # Lengths match, so also verify that every clean codon
                        # translates into the aligned amino acid
                        for i, aa in enumerate(seq2seq["aa"][seqname]):
                            codon = seq2seq["nt"][seqname][i * 3:(i * 3) + 3]
                            if not (set(codon) - REAL_NT):
                                if GENCODE[codon] != aa:
                                    log.warning('@@2:Unmatching codon in seq:%s, aa pos:%s (%s != %s)@@1: Use --no-seq-checks to skip' % (seqname, i, codon, aa))
                                    inconsistent_cds.add('Unmatching codon in seq:%s, aa pos:%s (%s != %s)' % (seqname, i, codon, aa))

        if inconsistent_cds:
            error += "\nUnexpected coding sequence length for the following ids:\n"
            error += '\n'.join(inconsistent_cds)

    # Show some stats
    all_len = list(seq2length[source_seqtype].values())
    max_len = _max(all_len)
    min_len = _min(all_len)
    mean_len = _mean(all_len)
    std_len = _std(all_len)
    outliers = []
    for v in all_len:
        if abs(mean_len - v) > (3 * std_len):
            outliers.append(v)
    log.log(28, "Total sequences: %d" % len(all_len))
    log.log(28, "Average sequence length: %d +- %0.1f" % (mean_len, std_len))
    log.log(28, "Max sequence length: %d" % max_len)
    log.log(28, "Min sequence length: %d" % min_len)

    if outliers:
        log.warning("%d sequence lengths look like outliers" % len(outliers))

    return error
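

# Illustrative sketch (hypothetical two-residue example) of the aa/cds rule
# enforced above: a protein of length N must pair with a CDS of exactly 3*N
# nucleotides, and each in-frame codon must translate to the aligned residue.
def _demo_cds_consistency():
    aa, nt = "MK", "ATGAAA"  # Met ('ATG'), Lys ('AAA')
    assert len(nt) / 3.0 == len(aa)
    codons = [nt[i * 3:(i * 3) + 3] for i in range(len(aa))]
    return codons  # ['ATG', 'AAA']; GENCODE['ATG'] == 'M', GENCODE['AAA'] == 'K'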


def hash_names(target_names):
    """Given a collection of strings of variable length, return their
    conversion to fixed-length, safe hash strings.
    """
    # An example of hash name collision:
    # test = ['4558_15418', '9600_21104', '7222_13002', '3847_37647', '412133_16266']
    # hash_names(test)

    log.log(28, "Generating safe sequence names...")
    hash2name = defaultdict(list)
    for c1, name in enumerate(target_names):
        print(c1, "\r", end=' ', file=sys.stderr)
        sys.stderr.flush()
        hash_name = encode_seqname(name)
        hash2name[hash_name].append(name)

    collisions = [(k, v) for k, v in six.iteritems(hash2name) if len(v) > 1]
    if collisions:
        visited = set(hash2name.keys())
        for old_hash, colliding_names in collisions:
            logindent(2)
            log.log(20, "Collision found when hash-encoding the following gene names: %s", colliding_names)
            niter = 1
            valid = False
            while not valid or len(new_hashes) < len(colliding_names):
                # Re-hash each name concatenated with itself niter times until
                # every name gets a distinct, unused hash
                niter += 1
                new_hashes = defaultdict(list)
                for name in colliding_names:
                    hash_name = encode_seqname(name * niter)
                    new_hashes[hash_name].append(name)
                valid = set(new_hashes.keys()).isdisjoint(visited)

            log.log(20, "Fixed with %d concatenations! %s", niter,
                    ', '.join(['%s=%s' % (e[1][0], e[0]) for e in six.iteritems(new_hashes)]))
            del hash2name[old_hash]
            hash2name.update(new_hashes)
            # Reserve the new hashes so later collision groups cannot reuse them
            visited.update(new_hashes.keys())
            logindent(-2)

    hash2name = {k: v[0] for k, v in six.iteritems(hash2name)}
    name2hash = {v: k for k, v in six.iteritems(hash2name)}
    return name2hash, hash2name
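

# Minimal smoke test (illustrative), using the documented collision example
# from the comment in hash_names(). Because of the relative imports above,
# run it as a module, e.g. `python -m <package>.seqio`.
if __name__ == "__main__":
    test = ['4558_15418', '9600_21104', '7222_13002', '3847_37647', '412133_16266']
    name2hash, hash2name = hash_names(test)
    assert len(name2hash) == len(test)  # every name received a unique hash
    assert all(hash2name[h] == n for n, h in name2hash.items())  # exact inverse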