1 #include <stdint.h>
2 #include <string.h>
3 #include <stdio.h>
4 #include <assert.h>
5 #include "mmpriv.h"
6 #include "kalloc.h"
7 #include "krmq.h"
8 
mg_chain_bk_end(int32_t max_drop,const mm128_t * z,const int32_t * f,const int64_t * p,int32_t * t,int64_t k)9 static int64_t mg_chain_bk_end(int32_t max_drop, const mm128_t *z, const int32_t *f, const int64_t *p, int32_t *t, int64_t k)
10 {
11 	int64_t i = z[k].y, end_i = -1, max_i = i;
12 	int32_t max_s = 0;
13 	if (i < 0 || t[i] != 0) return i;
14 	do {
15 		int32_t s;
16 		t[i] = 2;
17 		end_i = i = p[i];
18 		s = i < 0? z[k].x : (int32_t)z[k].x - f[i];
19 		if (s > max_s) max_s = s, max_i = i;
20 		else if (max_s - s > max_drop) break;
21 	} while (i >= 0 && t[i] == 0);
22 	for (i = z[k].y; i >= 0 && i != end_i; i = p[i]) // reset modified t[]
23 		t[i] = 0;
24 	return max_i;
25 }
26 
mg_chain_backtrack(void * km,int64_t n,const int32_t * f,const int64_t * p,int32_t * v,int32_t * t,int32_t min_cnt,int32_t min_sc,int32_t max_drop,int32_t * n_u_,int32_t * n_v_)27 uint64_t *mg_chain_backtrack(void *km, int64_t n, const int32_t *f, const int64_t *p, int32_t *v, int32_t *t, int32_t min_cnt, int32_t min_sc, int32_t max_drop, int32_t *n_u_, int32_t *n_v_)
28 {
29 	mm128_t *z;
30 	uint64_t *u;
31 	int64_t i, k, n_z, n_v;
32 	int32_t n_u;
33 
34 	*n_u_ = *n_v_ = 0;
35 	for (i = 0, n_z = 0; i < n; ++i) // precompute n_z
36 		if (f[i] >= min_sc) ++n_z;
37 	if (n_z == 0) return 0;
38 	KMALLOC(km, z, n_z);
39 	for (i = 0, k = 0; i < n; ++i) // populate z[]
40 		if (f[i] >= min_sc) z[k].x = f[i], z[k++].y = i;
41 	radix_sort_128x(z, z + n_z);
42 
43 	memset(t, 0, n * 4);
44 	for (k = n_z - 1, n_v = n_u = 0; k >= 0; --k) { // precompute n_u
45 		if (t[z[k].y] == 0) {
46 			int64_t n_v0 = n_v, end_i;
47 			int32_t sc;
48 			end_i = mg_chain_bk_end(max_drop, z, f, p, t, k);
49 			for (i = z[k].y; i != end_i; i = p[i])
50 				++n_v, t[i] = 1;
51 			sc = i < 0? z[k].x : (int32_t)z[k].x - f[i];
52 			if (sc >= min_sc && n_v > n_v0 && n_v - n_v0 >= min_cnt)
53 				++n_u;
54 			else n_v = n_v0;
55 		}
56 	}
57 	KMALLOC(km, u, n_u);
58 	memset(t, 0, n * 4);
59 	for (k = n_z - 1, n_v = n_u = 0; k >= 0; --k) { // populate u[]
60 		if (t[z[k].y] == 0) {
61 			int64_t n_v0 = n_v, end_i;
62 			int32_t sc;
63 			end_i = mg_chain_bk_end(max_drop, z, f, p, t, k);
64 			for (i = z[k].y; i != end_i; i = p[i])
65 				v[n_v++] = i, t[i] = 1;
66 			sc = i < 0? z[k].x : (int32_t)z[k].x - f[i];
67 			if (sc >= min_sc && n_v > n_v0 && n_v - n_v0 >= min_cnt)
68 				u[n_u++] = (uint64_t)sc << 32 | (n_v - n_v0);
69 			else n_v = n_v0;
70 		}
71 	}
72 	kfree(km, z);
73 	assert(n_v < INT32_MAX);
74 	*n_u_ = n_u, *n_v_ = n_v;
75 	return u;
76 }
77 
/* Build the final anchor array from the chains found by mg_chain_backtrack().
 *
 * Inputs:
 *   - v[] holds the anchor indices of n_u chains back to back, each chain stored
 *     last anchor first.
 *   - The low 32 bits of u[c] give the length of chain c.
 *
 * The result is a newly allocated array of n_v anchors. Each chain appears in
 * forward order, and chains are ordered by the target position (x) of their first
 * anchor. u[] is permuted in place to match that order. a[] and v[] are freed. */
static mm128_t *compact_a(void *km, int32_t n_u, uint64_t *u, int32_t n_v, int32_t *v, mm128_t *a)
{
	mm128_t *b, *order;
	uint64_t *u_sorted;
	int64_t c, off;

	// gather each chain's anchors into b[], reversing v[]'s last-to-first order
	KMALLOC(km, b, n_v);
	for (c = 0, off = 0; c < n_u; ++c) {
		const int32_t len = (int32_t)u[c];
		int32_t r;
		for (r = 0; r < len; ++r)
			b[off + r] = a[v[off + len - 1 - r]];
		off += len;
	}
	kfree(km, v);

	// sort chains by the target position of their first anchor so adjacent chains may be joined;
	// order[c].y packs the chain's offset in b[] (high 32 bits) and its index in u[] (low 32 bits)
	KMALLOC(km, order, n_u);
	for (c = 0, off = 0; c < n_u; ++c) {
		order[c].x = b[off].x;
		order[c].y = (uint64_t)off << 32 | c;
		off += (int32_t)u[c];
	}
	radix_sort_128x(order, order + n_u);

	// write the chains back into a[] in sorted order and permute u[] to match
	KMALLOC(km, u_sorted, n_u);
	for (c = 0, off = 0; c < n_u; ++c) {
		const int32_t src = (int32_t)order[c].y;
		const int32_t len = (int32_t)u[src];
		u_sorted[c] = u[src];
		memcpy(&a[off], &b[order[c].y >> 32], len * sizeof(mm128_t));
		off += len;
	}
	memcpy(u, u_sorted, n_u * sizeof(uint64_t));

	// a[] may be much larger than needed; keep the exactly sized b[] instead
	memcpy(b, a, off * sizeof(mm128_t));
	kfree(km, a); kfree(km, order); kfree(km, u_sorted);
	return b;
}
112 
comput_sc(const mm128_t * ai,const mm128_t * aj,int32_t max_dist_x,int32_t max_dist_y,int32_t bw,float chn_pen_gap,float chn_pen_skip,int is_cdna,int n_seg)113 static inline int32_t comput_sc(const mm128_t *ai, const mm128_t *aj, int32_t max_dist_x, int32_t max_dist_y, int32_t bw, float chn_pen_gap, float chn_pen_skip, int is_cdna, int n_seg)
114 {
115 	int32_t dq = (int32_t)ai->y - (int32_t)aj->y, dr, dd, dg, q_span, sc;
116 	int32_t sidi = (ai->y & MM_SEED_SEG_MASK) >> MM_SEED_SEG_SHIFT;
117 	int32_t sidj = (aj->y & MM_SEED_SEG_MASK) >> MM_SEED_SEG_SHIFT;
118 	if (dq <= 0 || dq > max_dist_x) return INT32_MIN;
119 	dr = (int32_t)(ai->x - aj->x);
120 	if (sidi == sidj && (dr == 0 || dq > max_dist_y)) return INT32_MIN;
121 	dd = dr > dq? dr - dq : dq - dr;
122 	if (sidi == sidj && dd > bw) return INT32_MIN;
123 	if (n_seg > 1 && !is_cdna && sidi == sidj && dr > max_dist_y) return INT32_MIN;
124 	dg = dr < dq? dr : dq;
125 	q_span = aj->y>>32&0xff;
126 	sc = q_span < dg? q_span : dg;
127 	if (dd || dg > q_span) {
128 		float lin_pen, log_pen;
129 		lin_pen = chn_pen_gap * (float)dd + chn_pen_skip * (float)dg;
130 		log_pen = dd >= 1? mg_log2(dd + 1) : 0.0f; // mg_log2() only works for dd>=2
131 		if (is_cdna || sidi != sidj) {
132 			if (sidi != sidj && dr == 0) ++sc; // possibly due to overlapping paired ends; give a minor bonus
133 			else if (dr > dq || sidi != sidj) sc -= (int)(lin_pen < log_pen? lin_pen : log_pen); // deletion or jump between paired ends
134 			else sc -= (int)(lin_pen + .5f * log_pen);
135 		} else sc -= (int)(lin_pen + .5f * log_pen);
136 	}
137 	return sc;
138 }
139 
140 /* Input:
141  *   a[].x: tid<<33 | rev<<32 | tpos
142  *   a[].y: flags<<40 | q_span<<32 | q_pos
143  * Output:
144  *   n_u: #chains
145  *   u[]: score<<32 | #anchors (sum of lower 32 bits of u[] is the returned length of a[])
146  * input a[] is deallocated on return
147  */
mm128_t *mg_lchain_dp(int max_dist_x, int max_dist_y, int bw, int max_skip, int max_iter, int min_cnt, int min_sc, float chn_pen_gap, float chn_pen_skip,
					  int is_cdna, int n_seg, int64_t n, mm128_t *a, int *n_u_, uint64_t **_u, void *km)
{ // TODO: make sure this works when n has more than 32 bits
	/* Chaining by dynamic programming over a window of preceding anchors.
	 *   f[i]: best chain score ending at anchor i
	 *   p[i]: its predecessor (-1 if none)
	 *   v[i]: peak score along that chain
	 *   t[]:  scratch marks for the skip heuristic
	 * n_u_ and _u are optional outputs and are checked independently. When _u is
	 * NULL, the chain table u[] is freed here instead of being returned. */
	int32_t *f, *t, *v, n_u, n_v, max_drop = bw;
	int64_t *p, i, j, max_ii, st = 0;
	uint64_t *u;
	mm128_t *b;

	if (n_u_) *n_u_ = 0;
	if (_u) *_u = 0;
	if (n == 0 || a == 0) {
		kfree(km, a);
		return 0;
	}
	if (max_dist_x < bw) max_dist_x = bw;
	if (max_dist_y < bw && !is_cdna) max_dist_y = bw;
	if (is_cdna) max_drop = INT32_MAX; // backtracking never cuts on a score drop in cDNA mode
	KMALLOC(km, p, n);
	KMALLOC(km, f, n);
	KMALLOC(km, v, n);
	KCALLOC(km, t, n);

	// fill the score and backtrack arrays
	for (i = 0, max_ii = -1; i < n; ++i) {
		int64_t max_j = -1, end_j;
		int32_t max_f = a[i].y>>32&0xff, n_skip = 0; // an anchor on its own scores its query span
		// move the window start past anchors on another target/strand or too far away on the target
		while (st < i && (a[i].x>>32 != a[st].x>>32 || a[i].x > a[st].x + max_dist_x)) ++st;
		if (i - st > max_iter) st = i - max_iter;
		for (j = i - 1; j >= st; --j) {
			int32_t sc;
			sc = comput_sc(&a[i], &a[j], max_dist_x, max_dist_y, bw, chn_pen_gap, chn_pen_skip, is_cdna, n_seg);
			if (sc == INT32_MIN) continue;
			sc += f[j];
			if (sc > max_f) {
				max_f = sc, max_j = j;
				if (n_skip > 0) --n_skip;
			} else if (t[j] == (int32_t)i) {
				// j is the predecessor of a candidate already tried for i; stop after max_skip such hits
				if (++n_skip > max_skip)
					break;
			}
			if (p[j] >= 0) t[p[j]] = i;
		}
		end_j = j;
		// max_ii tracks the best-scoring anchor in the window; recompute it once it is out of range
		if (max_ii < 0 || a[i].x - a[max_ii].x > (int64_t)max_dist_x) {
			int32_t max = INT32_MIN;
			max_ii = -1;
			for (j = i - 1; j >= st; --j)
				if (max < f[j]) max = f[j], max_ii = j;
		}
		// the scan above may have stopped before reaching max_ii; still try it as a predecessor
		if (max_ii >= 0 && max_ii < end_j) {
			int32_t tmp;
			tmp = comput_sc(&a[i], &a[max_ii], max_dist_x, max_dist_y, bw, chn_pen_gap, chn_pen_skip, is_cdna, n_seg);
			if (tmp != INT32_MIN && max_f < tmp + f[max_ii])
				max_f = tmp + f[max_ii], max_j = max_ii;
		}
		f[i] = max_f, p[i] = max_j;
		v[i] = max_j >= 0 && v[max_j] > max_f? v[max_j] : max_f; // v[] keeps the peak score up to i; f[] is the score ending at i, not always the peak
		if (max_ii < 0 || (a[i].x - a[max_ii].x <= (int64_t)max_dist_x && f[max_ii] < f[i]))
			max_ii = i;
	}

	u = mg_chain_backtrack(km, n, f, p, v, t, min_cnt, min_sc, max_drop, &n_u, &n_v);
	kfree(km, p); kfree(km, f); kfree(km, t);
	if (n_u_) *n_u_ = n_u;
	if (n_u == 0) {
		kfree(km, a); kfree(km, v);
		b = 0;
	} else {
		b = compact_a(km, n_u, u, n_v, v, a); // frees a[] and v[]; reorders u[] in place
	}
	if (_u) *_u = u; // ownership of u[] passes to the caller; NB: u[] is not sorted by score
	else kfree(km, u); // caller did not ask for the chain table: do not leak it
	return b;
}
218 
/* Tree node used by mg_lchain_rmq(): one node per anchor inserted into a search tree. */
typedef struct lc_elem_s {
	int32_t y;  // query position of the anchor, stored as (int32_t)a[].y
	int64_t i;  // index of the anchor in a[]
	double pri; // set to -(f[i] + 0.5 * chn_pen_gap * (tpos + qpos)) by mg_lchain_rmq(); compared by lc_elem_lt2
	KRMQ_HEAD(struct lc_elem_s) head; // intrusive tree header from krmq.h
} lc_elem_t;

// Tree key order: by query position y, ties broken by anchor index i.
#define lc_elem_cmp(a, b) ((a)->y < (b)->y? -1 : (a)->y > (b)->y? 1 : ((a)->i > (b)->i) - ((a)->i < (b)->i))
// Priority comparison handed to krmq. Presumably krmq_rmq() returns the node with the
// smallest pri in a key range, i.e. the largest adjusted score -- confirm in krmq.h.
#define lc_elem_lt2(a, b) ((a)->pri < (b)->pri)
// Generates the krmq_* operations used by mg_lchain_rmq() (insert/find/erase/rmq/interval/iterators) for lc_elem_t.
KRMQ_INIT(lc_elem, lc_elem_t, head, lc_elem_cmp, lc_elem_lt2)

// Node pool for lc_elem_t: provides kmp_rmq_t, kmp_init_rmq(), kmp_alloc_rmq() and kmp_free_rmq().
KALLOC_POOL_INIT(rmq, lc_elem_t)
231 
232 static inline int32_t comput_sc_simple(const mm128_t *ai, const mm128_t *aj, float chn_pen_gap, float chn_pen_skip, int32_t *exact, int32_t *width)
233 {
234 	int32_t dq = (int32_t)ai->y - (int32_t)aj->y, dr, dd, dg, q_span, sc;
235 	dr = (int32_t)(ai->x - aj->x);
236 	*width = dd = dr > dq? dr - dq : dq - dr;
237 	dg = dr < dq? dr : dq;
238 	q_span = aj->y>>32&0xff;
239 	sc = q_span < dg? q_span : dg;
240 	if (exact) *exact = (dd == 0 && dg <= q_span);
241 	if (dd || dq > q_span) {
242 		float lin_pen, log_pen;
243 		lin_pen = chn_pen_gap * (float)dd + chn_pen_skip * (float)dg;
244 		log_pen = dd >= 1? mg_log2(dd + 1) : 0.0f; // mg_log2() only works for dd>=2
245 		sc -= (int)(lin_pen + .5f * log_pen);
246 	}
247 	return sc;
248 }
249 
mm128_t *mg_lchain_rmq(int max_dist, int max_dist_inner, int bw, int max_chn_skip, int cap_rmq_size, int min_cnt, int min_sc, float chn_pen_gap, float chn_pen_skip,
					   int64_t n, mm128_t *a, int *n_u_, uint64_t **_u, void *km)
{
	/* Chaining with range-maximum queries over a search tree keyed by query position.
	 *
	 * root holds the anchors in the current target window. For each anchor i, one
	 * range query over query positions [a[i].y - max_dist, a[i].y] proposes a
	 * predecessor.
	 *
	 * When max_dist_inner > 0, root_inner holds a shorter window. If the proposed
	 * link is not exact, root_inner is also scanned backwards, with the same skip
	 * heuristic as mg_lchain_dp().
	 *
	 * Same outputs as mg_lchain_dp(). n_u_ and _u are optional and checked
	 * independently. When _u is NULL, the chain table u[] is freed here. */
	int32_t *f, *t, *v, n_u, n_v, max_drop = bw;
	int64_t *p, i, i0, st = 0, st_inner = 0;
	uint64_t *u;
	mm128_t *b;
	lc_elem_t *root = 0, *root_inner = 0;
	void *mem_mp = 0;
	kmp_rmq_t *mp;

	if (n_u_) *n_u_ = 0;
	if (_u) *_u = 0;
	if (n == 0 || a == 0) {
		kfree(km, a);
		return 0;
	}
	if (max_dist < bw) max_dist = bw;
	if (max_dist_inner <= 0 || max_dist_inner >= max_dist) max_dist_inner = 0;
	KMALLOC(km, p, n);
	KMALLOC(km, f, n);
	KCALLOC(km, t, n);
	KMALLOC(km, v, n);
	mem_mp = km_init2(km, 0x10000);
	mp = kmp_init_rmq(mem_mp);

	// fill the score and backtrack arrays
	for (i = i0 = 0; i < n; ++i) {
		int64_t max_j = -1;
		int32_t q_span = a[i].y>>32&0xff, max_f = q_span;
		lc_elem_t s, *q, *r, lo, hi;
		// add in-range anchors; anchors sharing a[i].x are deferred until x changes
		if (i0 < i && a[i0].x != a[i].x) {
			int64_t j;
			for (j = i0; j < i; ++j) {
				q = kmp_alloc_rmq(mp);
				q->y = (int32_t)a[j].y, q->i = j, q->pri = -(f[j] + 0.5 * chn_pen_gap * ((int32_t)a[j].x + (int32_t)a[j].y));
				krmq_insert(lc_elem, &root, q, 0);
				if (max_dist_inner > 0) {
					r = kmp_alloc_rmq(mp);
					*r = *q;
					krmq_insert(lc_elem, &root_inner, r, 0);
				}
			}
			i0 = i;
		}
		// get rid of active chains out of range
		while (st < i && (a[i].x>>32 != a[st].x>>32 || a[i].x > a[st].x + max_dist || krmq_size(head, root) > cap_rmq_size)) {
			s.y = (int32_t)a[st].y, s.i = st;
			if ((q = krmq_find(lc_elem, root, &s, 0)) != 0) {
				q = krmq_erase(lc_elem, &root, q, 0);
				kmp_free_rmq(mp, q);
			}
			++st;
		}
		if (max_dist_inner > 0)  { // similar to the block above, but applied to the inner tree
			while (st_inner < i && (a[i].x>>32 != a[st_inner].x>>32 || a[i].x > a[st_inner].x + max_dist_inner || krmq_size(head, root_inner) > cap_rmq_size)) {
				s.y = (int32_t)a[st_inner].y, s.i = st_inner;
				if ((q = krmq_find(lc_elem, root_inner, &s, 0)) != 0) {
					q = krmq_erase(lc_elem, &root_inner, q, 0);
					kmp_free_rmq(mp, q);
				}
				++st_inner;
			}
		}
		// RMQ; lo.i is the largest index so that ties on lo.y are excluded for any n (i is int64_t)
		lo.i = INT64_MAX, lo.y = (int32_t)a[i].y - max_dist;
		hi.i = 0, hi.y = (int32_t)a[i].y;
		if ((q = krmq_rmq(lc_elem, root, &lo, &hi)) != 0) {
			int32_t sc, exact, width, n_skip = 0;
			int64_t j = q->i;
			assert(q->y >= lo.y && q->y <= hi.y);
			sc = f[j] + comput_sc_simple(&a[i], &a[j], chn_pen_gap, chn_pen_skip, &exact, &width);
			if (width <= bw && sc > max_f) max_f = sc, max_j = j;
			if (!exact && root_inner && (int32_t)a[i].y > 0) {
				lc_elem_t *lo, *hi;
				s.y = (int32_t)a[i].y - 1, s.i = n;
				krmq_interval(lc_elem, root_inner, &s, &lo, &hi);
				if (lo) {
					const lc_elem_t *q;
					int32_t width;
					krmq_itr_t(lc_elem) itr;
					// walk the inner tree backwards from the closest query position
					krmq_itr_find(lc_elem, root_inner, lo, &itr);
					while ((q = krmq_at(&itr)) != 0) {
						if (q->y < (int32_t)a[i].y - max_dist_inner) break;
						j = q->i;
						sc = f[j] + comput_sc_simple(&a[i], &a[j], chn_pen_gap, chn_pen_skip, 0, &width);
						if (width <= bw) {
							if (sc > max_f) {
								max_f = sc, max_j = j;
								if (n_skip > 0) --n_skip;
							} else if (t[j] == (int32_t)i) {
								// j is the predecessor of a candidate already tried for i
								if (++n_skip > max_chn_skip)
									break;
							}
							if (p[j] >= 0) t[p[j]] = i;
						}
						if (!krmq_itr_prev(lc_elem, &itr)) break;
					}
				}
			}
		}
		// set max
		assert(max_j < 0 || (a[max_j].x < a[i].x && (int32_t)a[max_j].y < (int32_t)a[i].y));
		f[i] = max_f, p[i] = max_j;
		v[i] = max_j >= 0 && v[max_j] > max_f? v[max_j] : max_f; // v[] keeps the peak score up to i; f[] is the score ending at i, not always the peak
	}
	km_destroy(mem_mp);

	u = mg_chain_backtrack(km, n, f, p, v, t, min_cnt, min_sc, max_drop, &n_u, &n_v);
	kfree(km, p); kfree(km, f); kfree(km, t);
	if (n_u_) *n_u_ = n_u;
	if (n_u == 0) {
		kfree(km, a); kfree(km, v);
		b = 0;
	} else {
		b = compact_a(km, n_u, u, n_v, v, a); // frees a[] and v[]; reorders u[] in place
	}
	if (_u) *_u = u; // ownership of u[] passes to the caller; NB: u[] is not sorted by score
	else kfree(km, u); // caller did not ask for the chain table: do not leak it
	return b;
}
370