Skip to content
This repository was archived by the owner on Feb 8, 2025. It is now read-only.
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions AFQ/_fixes.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,9 @@ def __init__(self, *args, min_length=10, max_length=1000, **kwargs):
self.max_length = max_length
_generate_streamlines = _verbose_generate_streamlines

def __reduce__(self):
return (self.__init__, ())


class VerboseParticleFilteringTracking(ParticleFilteringTracking):
def __init__(self, *args, min_length=10, max_length=1000, **kwargs):
Expand All @@ -90,6 +93,9 @@ def __init__(self, *args, min_length=10, max_length=1000, **kwargs):
self.max_length = max_length
_generate_streamlines = _verbose_generate_streamlines

def __reduce__(self):
return (self.__init__, ())


def in_place_norm(vec, axis=-1, keepdims=False, delvec=True):
""" Return Vectors with Euclidean (L2) norm
Expand Down
4 changes: 2 additions & 2 deletions AFQ/models/csd.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
from dipy.reconst import csdeconv as csd
from dipy.reconst import mcsd
from dipy.reconst import shm
import dipy.data as dpd
import AFQ.utils.models as ut

# Monkey patch fixed spherical harmonics for conda and fixed solve_qp from
Expand Down Expand Up @@ -112,7 +111,8 @@ def fit_csd(data_files, bval_files, bvec_files, mask=None, response=None,
fname : the full path to the file containing the SH coefficients.
"""
img, data, gtab, mask = ut.prepare_data(data_files, bval_files, bvec_files,
b0_threshold=b0_threshold)
b0_threshold=b0_threshold,
mask=mask)

csdfit = _fit(gtab, data, mask, response=response, sh_order=sh_order,
lambda_=lambda_, tau=tau, msmt=msmt)
Expand Down
2 changes: 1 addition & 1 deletion AFQ/models/dti.py
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,7 @@ def fit_dti(data_files, bval_files, bvec_files, mask=None,
b0_threshold=b0_threshold)

# In this case, we dump the fit object
dtf = _fit(gtab, data, mask=None)
dtf = _fit(gtab, data, mask=mask)
FA, MD, AD, RD, params = dtf.fa, dtf.md, dtf.ad, dtf.rd, dtf.model_params

maps = [FA, MD, AD, RD, params]
Expand Down
2 changes: 1 addition & 1 deletion AFQ/tests/test_tractography.py
Original file line number Diff line number Diff line change
Expand Up @@ -123,4 +123,4 @@ def test_pft_tracking():
n_seeds=1,
step_size=step_size,
min_length=min_length,
tracker="pft")
tracker="pft")
89 changes: 64 additions & 25 deletions AFQ/tractography.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from collections.abc import Iterable
from itertools import chain
import numpy as np
import nibabel as nib
import dipy.reconst.shm as shm
import logging

import dipy.data as dpd
Expand All @@ -15,15 +15,35 @@
ActStoppingCriterion)


class PklDeterministicDG(DeterministicMaximumDirectionGetter):
def __reduce__(self):
return (self.__init__, ())


class PklProbabilisticDG(ProbabilisticDirectionGetter):
def __reduce__(self):
return (self.__init__, ())


class PklThresholdStoppingCriterion(ThresholdStoppingCriterion):
def __reduce__(self):
return (self.__init__, ())


from AFQ._fixes import (VerboseLocalTracking, VerboseParticleFilteringTracking,
tensor_odf)

from AFQ.utils.parallel import parfor


def track(params_file, directions="det", max_angle=30., sphere=None,
seed_mask=None, seed_threshold=0, n_seeds=1, random_seeds=False,
rng_seed=None, stop_mask=None, stop_threshold=0, step_size=0.5,
min_length=10, max_length=1000, odf_model="DTI",
tracker="local"):
tracker="local",
parallel_kwargs={"n_jobs": -1,
"engine": "joblib",
"backend": "loky"}):
"""
Tractography

Expand All @@ -32,7 +52,7 @@ def track(params_file, directions="det", max_angle=30., sphere=None,
params_file : str, nibabel img.
Full path to a nifti file containing CSD spherical harmonic
coefficients, or nibabel img with model params.
directions : str
directions : str or initialized direction-getter object
How tracking directions are determined.
One of: {"det" | "prob"}
max_angle : float, optional.
Expand Down Expand Up @@ -110,7 +130,6 @@ def track(params_file, directions="det", max_angle=30., sphere=None,
model_params = params_img.get_fdata()
affine = params_img.affine
odf_model = odf_model.upper()
directions = directions.lower()

logger.info("Generating Seeds...")
if isinstance(n_seeds, int):
Expand All @@ -134,28 +153,38 @@ def track(params_file, directions="det", max_angle=30., sphere=None,
sphere = dpd.default_sphere

logger.info("Getting Directions...")
if directions == "det":
dg = DeterministicMaximumDirectionGetter
elif directions == "prob":
dg = ProbabilisticDirectionGetter

if odf_model == "DTI" or odf_model == "DKI":
evals = model_params[..., :3]
evecs = model_params[..., 3:12].reshape(params_img.shape[:3] + (3, 3))
odf = tensor_odf(evals, evecs, sphere)
dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
elif odf_model == "CSD" or "MSMT":
dg = dg.from_shcoeff(model_params, max_angle=max_angle, sphere=sphere)
if isinstance(directions, str):
directions = directions.lower()
if directions == "det":
dg = PklDeterministicDG
elif directions == "prob":
dg = PklProbabilisticDG

if odf_model == "DTI" or odf_model == "DKI":
evals = model_params[..., :3]
evecs = model_params[..., 3:12].reshape(
params_img.shape[:3] + (3, 3))
odf = tensor_odf(evals, evecs, sphere)
dg = dg.from_pmf(odf, max_angle=max_angle, sphere=sphere)
elif odf_model == "CSD" or "MSMT":
dg = dg.from_shcoeff(model_params,
max_angle=max_angle,
sphere=sphere)
else:
# Assume it's an already-initialized dg
dg = directions

if tracker == "local":
if stop_mask is None:
stop_mask = np.ones(params_img.shape[:3])

if stop_mask.dtype == 'bool':
stopping_criterion = ThresholdStoppingCriterion(stop_mask,
0.5)
stopping_criterion = PklThresholdStoppingCriterion(
stop_mask,
0.5)
else:
stopping_criterion = ThresholdStoppingCriterion(stop_mask,
stopping_criterion = PklThresholdStoppingCriterion(
stop_mask,
stop_threshold)

my_tracker = VerboseLocalTracking
Expand Down Expand Up @@ -213,14 +242,24 @@ def track(params_file, directions="det", max_angle=30., sphere=None,
pve_gm_data,
pve_csf_data)

logger.info("Tracking...")
logger.info("Tracking!")
seeds_list = [seeds[ss * 100:(ss + 1) * 100] for ss in
range((seeds.shape[0] // 100) + 1)]

results = parfor(
_tracking, seeds_list,
func_args=[
my_tracker, dg, stopping_criterion, params_img],
func_kwargs=dict(
step_size=step_size, min_length=min_length,
max_length=max_length, random_seed=rng_seed),
**parallel_kwargs)

return _tracking(my_tracker, seeds, dg, stopping_criterion, params_img,
step_size=step_size, min_length=min_length,
max_length=max_length, random_seed=rng_seed)
return StatefulTractogram(chain.from_iterable(results),
params_img, Space.RASMM)


def _tracking(tracker, seeds, dg, stopping_criterion, params_img,
def _tracking(seeds, tracker, dg, stopping_criterion, params_img,
step_size=0.5, min_length=10, max_length=1000,
random_seed=None):
"""
Expand All @@ -239,4 +278,4 @@ def _tracking(tracker, seeds, dg, stopping_criterion, params_img,
max_length=max_length,
random_seed=random_seed)

return StatefulTractogram(tracker, params_img, Space.RASMM)
return tracker
4 changes: 3 additions & 1 deletion AFQ/utils/parallel.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ def parfor(func, in_list, out_shape=None, n_jobs=-1, engine="joblib",

"""
if engine == "joblib":
from joblib.externals.loky import set_loky_pickler

p = joblib.Parallel(n_jobs=n_jobs, backend=backend)
d = joblib.delayed(func)
d_l = []
for in_element in in_list:
d_l.append(d(in_element, *func_args, **func_kwargs))
results = p(tqdm(d_l))
results = p(d_l)

elif engine == "dask":
if n_jobs == -1:
Expand Down