1"""
2
3Registration API: simplified API for registration of MRI data and of
4streamlines.
5
6
7"""
8import collections
9import numbers
10import numpy as np
11import nibabel as nib
12from dipy.align.metrics import CCMetric, EMMetric, SSDMetric
13from dipy.align.imwarp import (SymmetricDiffeomorphicRegistration,
14                               DiffeomorphicMap)
15
16from dipy.align.imaffine import (transform_centers_of_mass,
17                                 AffineMap,
18                                 MutualInformationMetric,
19                                 AffineRegistration)
20
21from dipy.align.transforms import (TranslationTransform3D,
22                                   RigidTransform3D,
23                                   AffineTransform3D)
24
25
26import dipy.core.gradients as dpg
27import dipy.data as dpd
28from dipy.align.streamlinear import StreamlineLinearRegistration
29from dipy.tracking.streamline import set_number_of_points
30from dipy.tracking.utils import transform_tracking_output
31from dipy.io.streamline import load_trk
32from dipy.io.utils import read_img_arr_or_path
33from dipy.io.image import load_nifti, save_nifti
34
__all__ = [
    "syn_registration", "register_dwi_to_template",
    "write_mapping", "read_mapping", "resample",
    "center_of_mass", "translation", "rigid", "affine",
    "affine_registration", "register_series",
    "register_dwi_series", "streamline_registration",
]

# Module-level lookup tables: user-facing metric name -> metric class.
# syn_registration upper-cases its `metric` argument before the lookup.
syn_metric_dict = dict(CC=CCMetric,
                       EM=EMMetric,
                       SSD=SSDMetric)

affine_metric_dict = dict(MI=MutualInformationMetric)
47
48
def _handle_pipeline_inputs(moving, static, static_affine=None,
                            moving_affine=None, starting_affine=None):
    """
    Load the static/moving inputs and fill in a default starting affine.

    Parameters
    ----------
    moving, static : array, nifti image or str
        Either a 3D/4D array, a nifti image object, or a string containing
        the full path to a nifti file.

    static_affine, moving_affine : 4x4 arrays, optional
        The affines associated with the static/moving images. Each is handed
        to ``read_img_arr_or_path`` as its ``affine`` argument.

    starting_affine : 2D array, optional
        The registration matrix inherited from previous steps in the
        pipeline. Default: 4-by-4 identity matrix.

    Returns
    -------
    static, static_affine, moving, moving_affine, starting_affine
        The loaded data and affines, followed by the starting affine.
    """
    # The static image is read before the moving one.
    static_data, static_aff = read_img_arr_or_path(static,
                                                   affine=static_affine)
    moving_data, moving_aff = read_img_arr_or_path(moving,
                                                   affine=moving_affine)
    start = np.eye(4) if starting_affine is None else starting_affine
    return static_data, static_aff, moving_data, moving_aff, start
74
75
def syn_registration(moving, static,
                     moving_affine=None,
                     static_affine=None,
                     step_length=0.25,
                     metric='CC',
                     dim=3,
                     level_iters=None,
                     prealign=None,
                     **metric_kwargs):
    """Register a 2D/3D source image (moving) to a 2D/3D target image (static).

    Parameters
    ----------
    moving, static : array or nib.Nifti1Image or str.
        Either as a 2D/3D array or as a nifti image object, or as
        a string containing the full path to a nifti file.
    moving_affine, static_affine : 4x4 array, optional.
        Must be provided for `data` provided as an array. If provided together
        with Nifti1Image or str `data`, this input will over-ride the affine
        that is stored in the `data` input. Default: use the affine stored
        in `data`.
    step_length : float, optional
        Passed as ``step_length`` to
        :class:`SymmetricDiffeomorphicRegistration`. Default: 0.25.
    metric : string, optional
        The metric to be optimized. One of `CC`, `EM`, `SSD` (case
        insensitive). Default: 'CC' => CCMetric.
    dim: int (either 2 or 3), optional
       The dimensions of the image domain. Default: 3
    level_iters : list of int, optional
        the number of iterations at each level of the Gaussian Pyramid (the
        length of the list defines the number of pyramid levels to be
        used). Default: [10, 10, 5].
    prealign : 4x4 array, optional
        Passed as ``prealign`` to
        ``SymmetricDiffeomorphicRegistration.optimize``. Default: None.
    metric_kwargs : dict, optional
        Parameters for initialization of the metric object. If not provided,
        uses the default settings of each metric.

    Returns
    -------
    warped_moving : ndarray
        The data in `moving`, warped towards the `static` data.
    mapping : DiffeomorphicMap
        The mapping returned by the optimizer. Its ``transform`` method is
        what produced `warped_moving`, and it carries the ``forward`` and
        ``backward`` displacement fields.
    """
    # Explicit check rather than `level_iters or [...]`: the latter raises
    # "truth value ... is ambiguous" for multi-element numpy arrays.
    if level_iters is None or len(level_iters) == 0:
        level_iters = [10, 10, 5]

    static, static_affine, moving, moving_affine, _ = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=None)

    use_metric = syn_metric_dict[metric.upper()](dim, **metric_kwargs)

    sdr = SymmetricDiffeomorphicRegistration(use_metric, level_iters,
                                             step_length=step_length)

    mapping = sdr.optimize(static, moving,
                           static_grid2world=static_affine,
                           moving_grid2world=moving_affine,
                           prealign=prealign)

    warped_moving = mapping.transform(moving)
    return warped_moving, mapping
141
142
def register_dwi_to_template(dwi, gtab, dwi_affine=None, template=None,
                             template_affine=None, reg_method="syn",
                             **reg_kwargs):
    """
    Register DWI data to a template through the mean of its b0 volumes.

    Parameters
    ----------
    dwi : 4D array, nifti image or str
        Containing the DWI data, or full path to a nifti file with DWI.
    gtab : GradientTable or sequence of strings
        The gradients associated with the DWI data, or a sequence with
        (fbval, fbvec), full paths to bvals and bvecs files.
    dwi_affine : 4x4 array, optional
        An affine transformation associated with the DWI. Required if data
        is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.
    template : 3D array, nifti image or str, optional
        Containing the data for the template, or full path to a nifti file
        with the template data. Default: the template returned by
        ``dipy.data.read_mni_template()``.
    template_affine : 4x4 array, optional
        An affine transformation associated with the template. Required if data
        is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.
    reg_method : str, optional
        One of "syn" or "aff" (case insensitive), selecting
        :func:`syn_registration` or :func:`affine_registration` respectively.
        Default: "syn".
    reg_kwargs : key-word arguments for :func:`syn_registration` or
        :func:`affine_registration`

    Returns
    -------
    warped_b0, mapping: The first is an array with the mean b0 volume warped
    to the template. If reg_method is "syn", the second is a DiffeomorphicMap
    class instance that can be used to transform between the two spaces.
    Otherwise, if reg_method is "aff", it is a 4x4 matrix encoding the affine
    transform.

    Raises
    ------
    ValueError
        If `reg_method` is neither "syn" nor "aff".

    Notes
    -----
    This function assumes that the DWI data is already internally registered.
    See :func:`register_dwi_series`.
    """
    dwi_data, dwi_affine = read_img_arr_or_path(dwi, affine=dwi_affine)

    template = dpd.read_mni_template() if template is None else template
    template_data, template_affine = read_img_arr_or_path(
        template, affine=template_affine)

    if not isinstance(gtab, dpg.GradientTable):
        gtab = dpg.gradient_table(*gtab)

    # Collapse all b0 volumes into a single 3D moving image.
    mean_b0 = np.mean(dwi_data[..., gtab.b0s_mask], -1)

    method = reg_method.lower()
    if method == "syn":
        register = syn_registration
    elif method == "aff":
        register = affine_registration
    else:
        raise ValueError("reg_method should be one of 'aff' or 'syn', but you"
                         " provided %s" % reg_method)

    warped_b0, mapping = register(mean_b0, template_data,
                                  moving_affine=dwi_affine,
                                  static_affine=template_affine,
                                  **reg_kwargs)
    return warped_b0, mapping
216
217
def write_mapping(mapping, fname):
    """
    Write out a syn registration mapping to a nifti file

    Parameters
    ----------
    mapping : DiffeomorphicMap
        A mapping derived from :func:`syn_registration`.
    fname : str
        Full path to the nifti file storing the mapping.

    Notes
    -----
    The forward and backward displacement fields are stacked along a new
    trailing axis. For 3D fields of shape (X, Y, Z, 3) the stored data has
    shape (X, Y, Z, 3, 2): the forward displacement of voxel (i, j, k) is
    `data[i, j, k, :, 0]` and the backward displacement is
    `data[i, j, k, :, 1]`.

    The affine written to the file is ``mapping.codomain_world2grid``;
    :func:`read_mapping` inverts the stored affine when reading.
    """
    # Equivalent to np.array([fwd.T, bwd.T]).T, stated directly.
    mapping_data = np.stack([mapping.forward, mapping.backward], axis=-1)
    save_nifti(fname, mapping_data, mapping.codomain_world2grid)
236
237
def read_mapping(disp, domain_img, codomain_img, prealign=None):
    """
    Read a syn registration mapping from a nifti file

    Parameters
    ----------
    disp : str or Nifti1Image
        A file or image containing the mapping displacement field in each
        voxel, shape (x, y, z, 3, 2): `[..., 0]` is the forward and
        `[..., 1]` the backward displacement. The inverse of its affine is
        used as the displacement grid-to-world affine.

    domain_img : str or Nifti1Image
        Image (or path to one) whose spatial shape and affine define the
        domain of the mapping.

    codomain_img : str or Nifti1Image
        Image (or path to one) whose spatial shape and affine define the
        codomain of the mapping.

    prealign : 4x4 array, optional
        Passed unchanged to :class:`DiffeomorphicMap`. Default: None.

    Returns
    -------
    A :class:`DiffeomorphicMap` object.

    Notes
    -----
    See :func:`write_mapping` for the data format expected.
    """
    if isinstance(disp, str):
        disp_data, disp_affine = load_nifti(disp)
    else:
        # An image object was given: take the array and affine from it.
        disp_data = np.asanyarray(disp.dataobj)
        disp_affine = disp.affine

    if isinstance(domain_img, str):
        domain_img = nib.load(domain_img)

    if isinstance(codomain_img, str):
        codomain_img = nib.load(codomain_img)

    mapping = DiffeomorphicMap(3, disp_data.shape[:3],
                               disp_grid2world=np.linalg.inv(disp_affine),
                               domain_shape=domain_img.shape[:3],
                               domain_grid2world=domain_img.affine,
                               # Spatial dims only, matching domain_shape.
                               codomain_shape=codomain_img.shape[:3],
                               codomain_grid2world=codomain_img.affine,
                               prealign=prealign)

    mapping.forward = disp_data[..., 0]
    mapping.backward = disp_data[..., 1]
    # NOTE(review): flag kept as in the original implementation; its effect
    # is defined by DiffeomorphicMap -- confirm against write_mapping.
    mapping.is_inverse = True

    return mapping
282
283
def resample(moving, static, moving_affine=None, static_affine=None,
             between_affine=None):
    """Resample an image (moving) from one space to another (static).

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    between_affine: 4x4 array, optional
        An additional affine between the two spaces, if one is needed.
        Default: identity (no additional registration).

    Returns
    -------
    A Nifti1Image class instance with the data from the moving object
    resampled into the space of the static object (and carrying the static
    affine).
    """
    static_data, static_aff, moving_data, moving_aff, prealign = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=between_affine)
    resampler = AffineMap(prealign,
                          static_data.shape, static_aff,
                          moving_data.shape, moving_aff)
    return nib.Nifti1Image(resampler.transform(moving_data), static_aff)
328
329
def center_of_mass(moving, static, static_affine=None, moving_affine=None,
                   starting_affine=None, reg=None):
    """
    Implements a center of mass transform

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    starting_affine: 4x4 array, optional
        Accepted so that this step shares the signature of the other pipeline
        steps; it is not used by this transform.

    reg : not needed here. Use None
        Accepted for the same reason; ignored.

    Returns
    -------
    affine : 4x4 array
        The affine of the transform computed by
        :func:`dipy.align.imaffine.transform_centers_of_mass`.
    """
    # The starting affine is not needed: the transform depends only on the
    # two images and their affines.
    static, static_affine, moving, moving_affine, _ = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=starting_affine)
    transform = transform_centers_of_mass(static, static_affine,
                                          moving, moving_affine)
    return transform.affine
374
375
def translation(moving, static, static_affine=None, moving_affine=None,
                starting_affine=None, reg=None):
    """
    Implements a translation transform

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    starting_affine: 4x4 array, optional
        Initial guess for the transformation between the spaces.
        Default: identity.

    reg : AffineRegistration class instance, optional
        The object whose ``optimize`` method fits the transform. If None, an
        ``AffineRegistration()`` with default settings is created.

    Returns
    -------
    affine : 4x4 array
        The affine associated with the fitted translation transform.
    """
    static, static_affine, moving, moving_affine, starting_affine = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=starting_affine)
    if reg is None:
        # The default used to crash with AttributeError on `None.optimize`.
        reg = AffineRegistration()
    transform = TranslationTransform3D()
    xform = reg.optimize(static, moving, transform, None,
                         static_affine, moving_affine,
                         starting_affine=starting_affine)

    return xform.affine
423
424
def rigid(moving, static, static_affine=None, moving_affine=None,
          starting_affine=None, reg=None):
    """
    Implements a rigid transform

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    starting_affine: 4x4 array, optional
        Initial guess for the transformation between the spaces.
        Default: identity.

    reg : AffineRegistration class instance, optional
        The object whose ``optimize`` method fits the transform. If None, an
        ``AffineRegistration()`` with default settings is created.

    Returns
    -------
    affine : 4x4 array
        The affine associated with the fitted rigid transform.
    """
    static, static_affine, moving, moving_affine, starting_affine = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=starting_affine)
    if reg is None:
        # The default used to crash with AttributeError on `None.optimize`.
        reg = AffineRegistration()
    transform = RigidTransform3D()
    xform = reg.optimize(static, moving, transform, None,
                         static_affine, moving_affine,
                         starting_affine=starting_affine)
    return xform.affine
471
472
def affine(moving, static, static_affine=None, moving_affine=None,
           starting_affine=None, reg=None):
    """
    Implements an affine transform

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    starting_affine: 4x4 array, optional
        Initial guess for the transformation between the spaces.
        Default: identity.

    reg : AffineRegistration class instance, optional
        The object whose ``optimize`` method fits the transform. If None, an
        ``AffineRegistration()`` with default settings is created.

    Returns
    -------
    affine : 4x4 array
        The affine associated with the fitted (full) affine transform.
    """
    static, static_affine, moving, moving_affine, starting_affine = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=starting_affine)
    if reg is None:
        # The default used to crash with AttributeError on `None.optimize`.
        reg = AffineRegistration()
    transform = AffineTransform3D()
    xform = reg.optimize(static, moving, transform, None,
                         static_affine, moving_affine,
                         starting_affine=starting_affine)

    return xform.affine
521
522
def affine_registration(moving, static,
                        moving_affine=None,
                        static_affine=None,
                        pipeline=None,
                        starting_affine=None,
                        metric='MI',
                        level_iters=None,
                        sigmas=None,
                        factors=None,
                        **metric_kwargs):
    """
    Find the affine transformation between two 3D images.

    Parameters
    ----------
    moving : array, nifti image or str
        Containing the data for the moving object, or full path to a nifti file
        with the moving data.

    moving_affine : 4x4 array, optional
        An affine transformation associated with the moving object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    static : array, nifti image or str
        Containing the data for the static object, or full path to a nifti file
        with the static data.

    static_affine : 4x4 array, optional
        An affine transformation associated with the static object. Required if
        data is provided as an array. If provided together with nifti/path,
        will over-ride the affine that is in the nifti.

    pipeline : sequence, optional
        Sequence of transforms to use in the gradual fitting of the full
        affine. Default: (executed from left to right):
        `[center_of_mass, translation, rigid, affine]`

    starting_affine: 4x4 array, optional
        Initial guess for the transformation between the spaces.
        Default: identity.

    metric : str, optional.
        Currently only supports 'MI' (case insensitive) for
        MutualInformationMetric.

    nbins : int, optional
        MutualInformationMetric key-word argument: the number of bins to be
        used for computing the intensity histograms. The default is 32.

    sampling_proportion : None or float in interval (0, 1], optional
        MutualInformationMetric key-word argument: There are two types of
        sampling: dense and sparse. Dense sampling uses all voxels for
        estimating the (joint and marginal) intensity histograms, while
        sparse sampling uses a subset of them. If `sampling_proportion` is
        None, then dense sampling is used. If `sampling_proportion` is a
        floating point value in (0,1] then sparse sampling is used,
        where `sampling_proportion` specifies the proportion of voxels to
        be used. The default is None (dense sampling).

    level_iters : sequence, optional
        AffineRegistration key-word argument: the number of iterations at each
        scale of the scale space. `level_iters[0]` corresponds to the coarsest
        scale, `level_iters[-1]` the finest. By default, a 3-level scale space
        with iterations sequence equal to [10000, 1000, 100] will be used.

    sigmas : sequence of floats, optional
        AffineRegistration key-word argument: custom smoothing parameter to
        build the scale space (one parameter for each scale). By default,
        the sequence of sigmas will be [3, 1, 0].

    factors : sequence of floats, optional
        AffineRegistration key-word argument: custom scale factors to build the
        scale space (one factor for each scale). By default, the sequence of
        factors will be [4, 2, 1].

    Returns
    -------
    transformed, affine : array with moving data resampled to the static space
    after computing the affine transformation and the affine 4x4
    associated with the transformation.

    Notes
    -----
    Performs a gradual registration between the two inputs, using a pipeline
    that gradually approximates the final registration. If the final default
    step (`affine`) is omitted, the resulting affine may not have all 12
    degrees of freedom adjusted.
    """
    pipeline = pipeline or [center_of_mass, translation, rigid, affine]
    # Explicit checks instead of `x or default`: the latter raises
    # "truth value ... is ambiguous" for multi-element numpy arrays.
    if level_iters is None or len(level_iters) == 0:
        level_iters = [10000, 1000, 100]
    if sigmas is None or len(sigmas) == 0:
        sigmas = [3, 1, 0.0]
    if factors is None or len(factors) == 0:
        factors = [4, 2, 1]

    static, static_affine, moving, moving_affine, starting_affine = \
        _handle_pipeline_inputs(moving, static,
                                moving_affine=moving_affine,
                                static_affine=static_affine,
                                starting_affine=starting_affine)

    # Define the Affine registration object we'll use with the chosen metric.
    # For now, there is only one metric (mutual information). The name is
    # upper-cased, as syn_registration does for its metric.
    use_metric = affine_metric_dict[metric.upper()](**metric_kwargs)

    affreg = AffineRegistration(metric=use_metric,
                                level_iters=level_iters,
                                sigmas=sigmas,
                                factors=factors)

    # Each step refines the affine produced by the previous one:
    for func in pipeline:
        starting_affine = func(moving, static,
                               static_affine=static_affine,
                               moving_affine=moving_affine,
                               starting_affine=starting_affine,
                               reg=affreg)

    # After doing all that, resample once at the end:
    affine_map = AffineMap(starting_affine,
                           static.shape, static_affine,
                           moving.shape, moving_affine)

    resampled = affine_map.transform(moving)

    return resampled, starting_affine
650
651
def register_series(series, ref, pipeline=None, series_affine=None,
                    ref_affine=None):
    """Register a series to a reference image.

    Parameters
    ----------
    series : 4D array or nib.Nifti1Image class instance or str
        The data is 4D with the last dimension separating different 3D volumes

    ref : int or 3D array or nib.Nifti1Image class instance or str
        If this is an int, this is the index of the reference image within the
        series (negative indices count from the end). Otherwise it is an array
        of data to register with (associated with a `ref_affine` required) or
        a nifti img or full path to a file containing one.

    pipeline : sequence, optional
        Sequence of transforms to do for each volume in the series.
        Default: (executed from left to right):
        `[center_of_mass, translation, rigid, affine]`

    series_affine, ref_affine : 4x4 arrays, optional.
        The affine. If provided, this input will over-ride the affine provided
        together with the nifti img or file.

    Returns
    -------
    xformed, affines : 4D array with transformed data and a (4,4,n) array
    with 4x4 matrices associated with each of the volumes of the input moving
    data that was used to transform it into register with the static data.
    When `ref` is an index, that volume is copied unchanged with an identity
    affine.

    Raises
    ------
    ValueError
        If `ref` is an image that is not a single 3D volume.
    IndexError
        If `ref` is an index outside the series.
    """
    pipeline = pipeline or [center_of_mass, translation, rigid, affine]

    series, series_affine = read_img_arr_or_path(series,
                                                 affine=series_affine)
    n_volumes = series.shape[-1]
    if isinstance(ref, numbers.Number):
        # Normalize (and bounds-check) the index so that negative indices
        # are also recognized as the reference inside the loop below.
        ref_as_idx = np.arange(n_volumes)[ref]
        ref = series[..., ref_as_idx]
        ref_affine = series_affine
    else:
        # `None` rather than `False`: bool is a numbers.Number and
        # False == 0, which would wrongly skip registering volume 0.
        ref_as_idx = None
        ref, ref_affine = read_img_arr_or_path(ref, affine=ref_affine)
        if len(ref.shape) != 3:
            raise ValueError("The reference image should be a single volume"
                             " or the index of one or more volumes")

    xformed = np.zeros(series.shape)
    affines = np.zeros((4, 4, n_volumes))
    for ii in range(n_volumes):
        this_moving = series[..., ii]
        if ref_as_idx is not None and ii == ref_as_idx:
            # This is the reference! No need to move and the xform is I(4):
            xformed[..., ii] = this_moving
            affines[..., ii] = np.eye(4)
        else:
            transformed, reg_affine = affine_registration(
                this_moving, ref,
                moving_affine=series_affine,
                static_affine=ref_affine,
                pipeline=pipeline)
            xformed[..., ii] = transformed
            affines[..., ii] = reg_affine

    return xformed, affines
717
718
def register_dwi_series(data, gtab, affine=None, b0_ref=0, pipeline=None):
    """
    Register a DWI series to the mean of the B0 images in that series (all
    first registered to one reference B0 volume)

    Parameters
    ----------
    data : 4D array or nibabel Nifti1Image class instance or str
        Diffusion data. Either as a 4D array or as a nifti image object, or as
        a string containing the full path to a nifti file.

    gtab : a GradientTable class instance or tuple of strings
        If provided as a tuple of strings, these are assumed to be full paths
        to the bvals and bvecs files (in that order).

    affine : 4x4 array, optional.
        Must be provided for `data` provided as an array. If provided together
        with Nifti1Image or str `data`, this input will over-ride the affine
        that is stored in the `data` input. Default: use the affine stored
        in `data`.

    b0_ref : int, optional.
        Index, counted among the b0 volumes only, of the b0 used as reference
        when there is more than one b0. Default: 0

    pipeline : list of callables, optional.
        The transformations to perform in sequence (from left to right).
        Default (None): the default pipeline of :func:`register_series`,
        `[center_of_mass, translation, rigid, affine]`.

    Returns
    -------
    xform_img, affine_array: a Nifti1Image containing the registered data
    (using the affine of the original data), and a (4, 4, n_volumes) array
    with the affine transform associated with each volume of the input
    (identity for the b0 when there is only one).

    Raises
    ------
    ValueError
        If `gtab` contains no b0 volume.
    """
    # `collections.Sequence` was removed in Python 3.10.
    from collections.abc import Sequence

    # NOTE: `pipeline=None` is passed through; register_series supplies the
    # default. The module-level `affine` function cannot be named here
    # because the `affine` parameter shadows it.

    data, affine = read_img_arr_or_path(data, affine=affine)
    if isinstance(gtab, Sequence):
        gtab = dpg.gradient_table(*gtab)

    n_b0 = np.sum(gtab.b0s_mask)
    if n_b0 == 0:
        raise ValueError("register_dwi_series requires at least one b0 "
                         "volume in the data")

    if n_b0 > 1:
        # First, register the b0s into one image and average:
        b0_img = nib.Nifti1Image(data[..., gtab.b0s_mask], affine)
        trans_b0, b0_affines = register_series(b0_img, ref=b0_ref,
                                               pipeline=pipeline)
        # keepdims: the mean must stay 4D to be concatenated below.
        ref_data = np.mean(trans_b0, -1, keepdims=True)
    else:
        # There's only one b0 and we register everything to it
        trans_b0 = ref_data = data[..., gtab.b0s_mask]
        b0_affines = np.eye(4)[..., np.newaxis]

    # Construct a series out of the DWI and the registered mean B0:
    moving_data = data[..., ~gtab.b0s_mask]
    series_arr = np.concatenate([ref_data, moving_data], -1)
    series = nib.Nifti1Image(series_arr, affine)

    xformed, affines = register_series(series, ref=0, pipeline=pipeline)
    # Cut out the part pertaining to that first (reference) volume:
    affines = affines[..., 1:]
    xformed = xformed[..., 1:]
    affine_array = np.zeros((4, 4, data.shape[-1]))
    affine_array[..., gtab.b0s_mask] = b0_affines
    affine_array[..., ~gtab.b0s_mask] = affines

    data_array = np.zeros(data.shape)
    data_array[..., gtab.b0s_mask] = trans_b0
    data_array[..., ~gtab.b0s_mask] = xformed

    return nib.Nifti1Image(data_array, affine), affine_array
791
792
def streamline_registration(moving, static, n_points=100,
                            native_resampled=False):
    """
    Register two collections of streamlines ('bundles') to each other

    Parameters
    ----------
    moving, static : sequences of streamlines, or str
        The two bundles to be registered. Given either as sequences of arrays
        with 3D coordinates, or strings containing full paths to trk files.

    n_points : int, optional
        Number of points each streamline is resampled to before fitting.
        Default: 100.

    native_resampled : bool, optional
        If True, the aligned streamlines are resampled to `n_points` and then
        mapped back with the inverse of the registration matrix. Default:
        False.

    Returns
    -------
    aligned : list
        Streamlines from the moving group, moved to be closely matched to
        the static group.

    matrix : array (4, 4)
        The affine transformation that takes us from 'moving' to 'static'
    """
    def _as_streamlines(bundle):
        # File names are read as trk files; anything else is used as-is.
        if isinstance(bundle, str):
            return load_trk(bundle, 'same',
                            bbox_valid_check=False).streamlines
        return bundle

    moving = _as_streamlines(moving)
    static = _as_streamlines(static)

    registration = StreamlineLinearRegistration()
    static_resampled = set_number_of_points(static, n_points)
    moving_resampled = set_number_of_points(moving, n_points)
    srm = registration.optimize(static=static_resampled,
                                moving=moving_resampled)

    aligned = srm.transform(moving)
    if native_resampled:
        aligned = transform_tracking_output(
            set_number_of_points(aligned, n_points),
            np.linalg.inv(srm.matrix))

    return aligned, srm.matrix
836