1#!/usr/local/bin/python3.8
2__version__="v2.5 beta6"
3welcome_block="""
4# Multi-Echo ICA, Version %s
5# See http://dx.doi.org/10.1016/j.neuroimage.2011.12.028
6# Kundu, P., Inati, S.J., Evans, J.W., Luh, W.M. & Bandettini, P.A. Differentiating
7#	BOLD and non-BOLD signals in fMRI time series using multi-echo EPI. NeuroImage (2011).
8#
9# Kundu, P., Inati, S.J., Evans, J.W., Luh, W.M. & Bandettini, P.A. Differentiating
10#   BOLD and non-BOLD signals in fMRI time series using multi-echo EPI. NeuroImage (2011).
11# http://dx.doi.org/10.1016/j.neuroimage.2011.12.028
12#
13# t2smap.py version 2.5 	(c) 2014 Prantik Kundu, Noah Brenowitz, Souheil Inati
14#
15#Computes T2* map
16"""
17
18import os
19from optparse import OptionParser
20import numpy as np
21import nibabel as nib
22from sys import stdout
23
def scoreatpercentile(a, per, limit=(), interpolation_method='lower'):
    """
    Return the score at percentile `per` of `a` (adapted from scipy).

    Parameters
    ----------
    a : array_like
        Values to rank; they are sorted along axis 0.
    per : float
        Percentile to look up, expressed on a 0-100 scale.
    limit : tuple, optional
        (lower, upper) bounds. When given, only values inside the closed
        interval are kept before the percentile is computed.
    interpolation_method : {'fraction', 'lower', 'higher'}, optional
        How to resolve a percentile that falls between two sorted values:
        'fraction' interpolates linearly between the two neighbours,
        'lower' takes the lower neighbour and 'higher' the upper one.

    Returns
    -------
    score
        The element (or interpolated value) at the requested percentile.

    Raises
    ------
    ValueError
        If the percentile falls between two values and
        `interpolation_method` is not one of the supported names.
    """
    values = np.sort(a, axis=0)
    if limit:
        values = values[(limit[0] <= values) & (values <= limit[1])]

    # Fractional position of the percentile inside the sorted values.
    idx = per / 100. * (values.shape[0] - 1)
    if idx % 1 == 0:
        # The percentile lands exactly on an element: no interpolation needed.
        return values[int(idx)]

    if interpolation_method == 'fraction':
        # Linear interpolation between the two neighbours. The original code
        # called an undefined helper (_interpolate) here.
        lo = int(idx)
        score = values[lo] + (values[lo + 1] - values[lo]) * (idx % 1)
    elif interpolation_method == 'lower':
        score = values[int(np.floor(idx))]
    elif interpolation_method == 'higher':
        score = values[int(np.ceil(idx))]
    else:
        raise ValueError("interpolation_method can only be 'fraction', "
                         "'lower' or 'higher'")
    return score
48
def niwrite(data,affine, name , header=None):
    """
    Write `data` to the NIfTI file `name` and return the image object.

    Input:
    data    array to store; NaNs are replaced by 0 IN PLACE, so the
            caller's array is modified
    affine  affine matrix passed to nib.Nifti1Image
    name    output file name
    header  optional header to use for the image. When omitted, a copy of
            the module-level `head` (set by the script section of this
            file) is used, with its data shape set to data.shape.

    Output:
    the nib.Nifti1Image that was written (stored as float64)
    """
    data[np.isnan(data)] = 0
    stdout.write(" + Writing file: %s ...." % name)

    thishead = header
    if thishead is None:
        # No header supplied: start from the header of the input dataset,
        # which the script section stores in the global `head`.
        thishead = head.copy()
        thishead.set_data_shape(list(data.shape))

    outni = nib.Nifti1Image(data, affine, header=thishead)
    outni.set_data_dtype('float64')
    outni.to_filename(name)

    print('done.')

    return outni
66
def cat2echos(data,Ne):
    """
    cat2echos(data,Ne)

    Split a dataset whose echoes are concatenated along the third axis
    into an array with a separate echo axis.

    Input:
    data shape is (nx,ny,Ne*nz,nt) or (nx,ny,Ne*nz)
    Ne is the number of echoes

    Output:
    array of shape (nx,ny,nz,Ne,nt); nt is 1 when data has no 4th axis

    Raises ValueError when the third axis is not a multiple of Ne.
    """
    nx, ny = data.shape[0:2]
    ncat = data.shape[2]

    # Integer division: a float dimension is rejected by np.reshape.
    nz, leftover = divmod(ncat, Ne)
    if leftover:
        raise ValueError("third axis of length %d cannot be split into %d echoes"
                         % (ncat, Ne))

    if len(data.shape) > 3:
        nt = data.shape[3]
    else:
        nt = 1

    # Fortran order keeps the slice index fastest within each echo, so
    # concatenated slice k maps to (z, echo) with k = z + nz*echo.
    return np.reshape(data, (nx, ny, nz, Ne, nt), order='F')
81
def uncat2echos(data,Ne):
    """
    uncat2echos(data,Ne)

    Merge the echo axis back into the third axis, i.e. concatenate the
    echoes along the slice direction.

    Input:
    data shape is (nx,ny,nz,Ne,nt) or (nx,ny,nz,Ne)
    Ne is the number of echoes

    Output:
    array of shape (nx,ny,nz*Ne,nt); nt is 1 when data has no 5th axis
    """
    nx, ny = data.shape[0:2]
    nz = data.shape[2] * Ne
    if len(data.shape) > 4:
        nt = data.shape[4]
    else:
        nt = 1
    # Fortran order: merged slice index is z + nz_per_echo*echo.
    return np.reshape(data, (nx, ny, nz, nt), order='F')
96
def makemask(cdat):
    """
    makemask(cdat)

    Build a boolean mask of the voxels that are non-zero in every echo
    and at every time point.

    Input:
    cdat shape is (nx,ny,nz,Ne,nt)

    Output:
    boolean mask of shape (nx,ny,nz)

    Raises ValueError when cdat is not 5-dimensional.
    """
    if len(cdat.shape) != 5:
        raise ValueError("expected a 5-D array (nx,ny,nz,Ne,nt), got shape %s"
                         % (cdat.shape,))

    # A voxel is kept only if no sample along the echo (3) or time (4)
    # axis is zero.
    return np.all(cdat != 0, axis=(3, 4))
108
def fmask(data,mask):
    """
    fmask(data,mask)

    Keep only the voxels selected by `mask` (entries > 0), flattening the
    three spatial axes into one.

    Input:
    data shape is (nx,ny,nz,...)
    mask shape is (nx,ny,nz)

    Output:
    out shape is (Nm,...), with every length-1 axis removed
    """
    dims = data.shape

    # Collapse the spatial axes, leaving any trailing axes untouched.
    flat_shape = [dims[0] * dims[1] * dims[2]] + list(dims[3:])
    flattened = np.reshape(data, flat_shape)

    keep = (mask > 0).ravel()
    selected = flattened.compress(keep, axis=0)

    return selected.squeeze()
135
def unmask (data,mask):
    """
    unmask (data,mask)

    Scatter masked rows back into a full volume; voxels outside the mask
    are filled with zeros.

    Input:

    data has shape (Nm,nt) or (Nm,)
    mask has shape (nx,ny,nz)

    Output:

    array of shape (nx,ny,nz,nt) with the dtype of data (nt is 1 for
    1-d data)
    """
    nx, ny, nz = mask.shape

    selected = np.ravel(mask != 0)
    n_selected = selected.sum()

    n_cols = data.shape[1] if data.ndim > 1 else 1

    volume = np.zeros((nx * ny * nz, n_cols), dtype=data.dtype)
    volume[selected] = np.reshape(data, (n_selected, n_cols))

    return volume.reshape((nx, ny, nz, n_cols))
160
def t2smap(catd,mask,tes):
    """
    t2smap(catd,mask,tes)

    Fit log|S| = log(S0) - te/T2* by least squares for every masked
    voxel, using all echoes and all time points of that voxel.

    Input:

    catd  has shape (nx,ny,nz,Ne,nt)
    mask  has shape (nx,ny,nz)
    tes   is a 1d numpy array of Ne echo times

    Output:

    tuple (t2s, s0), each of shape (nx,ny,nz,1) as returned by unmask();
    NaNs in t2s are replaced by 0
    """
    nx,ny,nz,Ne,nt = catd.shape

    # NOTE(review): fmask() squeezes singleton axes; the reshape below
    # assumes the first axis of echodata is still the voxel axis.
    echodata = fmask(catd,mask)
    Nm = echodata.shape[0]

    #Do Log Linear fit
    # Rows of B run echo-major (all time points of echo 0, then echo 1, ...).
    B = np.log(np.reshape(np.abs(echodata), (Nm,Ne*nt)).transpose())

    # Design matrix in the same row order as B: a constant column and -te
    # repeated nt times per echo. np.repeat keeps each te aligned with its
    # echo whatever the order of `tes`; the former sort-based construction
    # was only aligned when the echo times were ascending.
    neg_te = np.repeat(-np.asarray(tes, dtype=float), nt)
    X = np.column_stack((np.ones(Ne*nt), neg_te))

    # rcond=-1 is NumPy's historical default; passing it explicitly keeps
    # the original numerical behavior and avoids the FutureWarning.
    beta = np.linalg.lstsq(X, B, rcond=-1)[0]
    t2s = 1/beta[1,:].transpose()
    s0  = np.exp(beta[0,:]).transpose()

    out = unmask(t2s,mask),unmask(s0,mask)
    out[0][np.isnan(out[0])]=0.

    return out
192
def optcom(data,t2s,tes,mask):
    """
    out = optcom(data,t2s,tes,mask)

    Combine the echoes of every masked voxel into a single time series,
    as a weighted average in which echo `te` gets weight
    te*exp(-te/T2*).

    Input:

    data.shape = (nx,ny,nz,Ne,Nt)
    t2s.shape  = (nx,ny,nz) or (nx,ny,nz,1)
    tes.shape  = (Ne,)
    mask.shape = (nx,ny,nz)

    Output:

    out.shape = (nx,ny,nz,Nt)
    """
    nx,ny,nz,Ne,Nt = data.shape

    # fmask() squeezes length-1 axes, so restore the expected ranks
    # explicitly; this is a no-op for the usual multi-voxel, multi-volume
    # input and keeps single-volume / single-voxel input working.
    fdat = np.reshape(fmask(data,mask), (-1,Ne,Nt))
    ft2s = np.reshape(fmask(t2s,mask), (-1,1))

    tes = tes[np.newaxis,:]

    # Per-voxel, per-echo weights, replicated over time because np.average
    # needs weights with the same shape as the data.
    alpha = tes * np.exp(-tes /ft2s)
    alpha = np.tile(alpha[:,:,np.newaxis],(1,1,Nt))

    fout  = np.average(fdat,axis = 1,weights=alpha)
    out = unmask(fout,mask)
    print('Out shape is ', out.shape)
    return out
223
224###################################################################################################
225# 						Begin Main
226###################################################################################################
227
if __name__=='__main__':

    # Only `stdout` is imported from sys at file level; sys.exit needs the
    # module itself.
    import sys

    parser=OptionParser()
    parser.add_option('-d',"--orig_data",dest='data',help="Spatially Concatenated Multi-Echo Dataset",default=None)
    parser.add_option('-l',"--label",dest='label',help="Optional label to tag output files with",default=None)
    parser.add_option('-e',"--TEs",dest='tes',help="Echo times (in ms) ex: 15,39,63",default=None)

    (options,args) = parser.parse_args()

    print("-- T2* Map Component for ME-ICA v2.0 --")

    if options.tes is None or options.data is None:
        print("*+ Need at least data and TEs, use -h for help.")
        sys.exit(1)

    print("++ Loading Data")
    tes = np.fromstring(options.tes,sep=',',dtype=np.float32)
    ne = tes.shape[0]
    catim  = nib.load(options.data)
    # `head` must stay a module-level name: niwrite() copies it when no
    # header is passed.
    head   = catim.header
    head.extensions = []
    head.set_sform(head.get_sform(),code=1)
    aff = catim.affine
    # np.asanyarray(img.dataobj) is nibabel's documented replacement for
    # the removed img.get_data().
    catd = cat2echos(np.asanyarray(catim.dataobj),ne)
    nx,ny,nz,Ne,nt = catd.shape

    print("++ Computing Mask")
    mask  = makemask(catd)

    print("++ Computing T2* map")
    t2s,s0   = np.array(t2smap(catd,mask,tes),dtype=float)
    t2s[t2s>500] = 500
    t2sm = t2s.copy()

    # Masked T2* map: zero the voxels whose S0 is below a tenth of the
    # 98th percentile of the distinct S0 values.
    s0_maskmin = scoreatpercentile(np.unique(s0),98)/10
    t2sm[s0<s0_maskmin] = 0

    print("++ Computing optimal combination")
    tsoc = np.array(optcom(catd,t2s,tes,mask),dtype=float)

    if options.label is not None:
        suf='_%s' % str(options.label)
    else:
        suf=''

    #Clean up numerical errors
    tsoc[np.isnan(tsoc)]=0
    s0[np.isnan(s0)]=0
    s0[s0<0]=0
    t2s[np.isnan(t2s)]=0
    t2s[t2s<0]=0
    t2sm[np.isnan(t2sm)]=0
    t2sm[t2sm<0]=0

    niwrite(tsoc,aff,'ocv%s.nii' % suf)
    niwrite(s0,aff,'s0v%s.nii' % suf)
    niwrite(t2s,aff,'t2sv%s.nii' % suf )
    niwrite(t2sm,aff,'t2svm%s.nii' % suf )
290
291
292
293
294
295
296
297
298
299
300
301