Skip to content
4 changes: 4 additions & 0 deletions docs/references/localization.rst
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,10 @@ Convert from FSAverage to MNI152

.. autofunction:: fsaverage_to_mni152

Convert from any source space to any target space
------------------------------------------------

.. autofunction:: src_to_dst

Localization on a Freesurfer Brain
----------------------------------
Expand Down
4 changes: 2 additions & 2 deletions naplib/localization/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from .freesurfer import Brain, find_closest_vertices
from .coordinate_conversions import mni152_to_fsaverage, fsaverage_to_mni152
from .coordinate_conversions import mni152_to_fsaverage, fsaverage_to_mni152, src_to_dst

__all__ = ['Brain', 'find_closest_vertices', 'mni152_to_fsaverage', 'fsaverage_to_mni152']
__all__ = ['Brain', 'find_closest_vertices', 'mni152_to_fsaverage', 'fsaverage_to_mni152', 'src_to_dst']
84 changes: 84 additions & 0 deletions naplib/localization/coordinate_conversions.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import numpy as np
import nibabel.freesurfer.io as fsio
from scipy.spatial import cKDTree

def mni152_to_fsaverage(coords):
"""
Expand Down Expand Up @@ -38,4 +40,86 @@ def fsaverage_to_mni152(coords):
new_coords = (xform @ old_coords).T
return new_coords

def src_to_dst(coords, src_pial, src_sphere, dst_pial, dst_sphere, verbose=False):
"""
Convert 3D coordinates from any space to another space.
Each subject comes with a bunch of MRI files; In this function these files are used:
1. lh.pial file of the source space ==> SRC_PATH/surf/lh.pial
2. lh.sphere.reg file of the source ==> SRC_PATH/surf/lh.sphere.reg
3. lh.pial file of the destination ==> DST_PATH/surf/lh.pial
4. lh.sphere.reg file of the destination ==> DST_PATH/surf/lh.sphere.reg

Provide LH files, the function assumes the RH ones are in the same directory.

NOTE: In case of converting to an atlas space, the files we need are accessible
by installing freesurfer: https://surfer.nmr.mgh.harvard.edu/fswiki/DownloadAndInstall
Path of different atlas spaces: PATH_freesurfer/8.0.0/subjects/

Parameters
----------
coords : np.ndarray (elecs, 3)
Coordinates in source space. Can be in both hemispheres.
src_pial : str/dict{'vert_lh', 'vert_rh'}
Path to the source pial surface file (e.g., 'lh.pial'). In case of a mat file for pial surfaces,
provide a dictionary with keys 'vert_lh' and 'vert_rh' containing the vertices for each hemisphere.
src_sphere : str
Path to the source sphere file (e.g., 'lh.sphere.reg')
dst_pial : str
Path to the destination pial surface file (e.g., 'lh.pial')
dst_sphere : str
Path to the destination sphere file (e.g., 'lh.sphere.reg')
verbose : bool, optional
If True, prints additional information about the conversion process. Default is False.

Returns
-------
new_coords : np.ndarray (elecs, 3)
Coordinates in target space
"""
src_sphere_lh, _ = fsio.read_geometry(src_sphere)
src_sphere_rh, _ = fsio.read_geometry(src_sphere.replace('lh', 'rh'))
src_sphere = np.vstack((src_sphere_lh, src_sphere_rh))

tgt_sphere_lh, _ = fsio.read_geometry(dst_sphere)
tgt_sphere_rh, _ = fsio.read_geometry(dst_sphere.replace('lh', 'rh'))

tree_lh = cKDTree(tgt_sphere_lh)
tree_rh = cKDTree(tgt_sphere_rh)

if np.isnan(coords).any():
print("WARNING: NaN values found in coordinates. Replacing with zeros.")
coords = np.nan_to_num(coords)

if isinstance(src_pial, str):
lh_verts_sub, _ = fsio.read_geometry(src_pial)
rh_verts_sub = fsio.read_geometry(src_pial.replace('lh', 'rh'))[0]
lh_threshold = lh_verts_sub.shape[0]
lh_verts_sub = np.vstack((lh_verts_sub, rh_verts_sub))
else:
lh_verts_sub = src_pial['vert_lh']
rh_verts_sub = src_pial['vert_rh']
lh_threshold = lh_verts_sub.shape[0]

lh_verts_sub_fs, _ = fsio.read_geometry(dst_pial)
rh_verts_sub_fs, _ = fsio.read_geometry(dst_pial.replace('lh', 'rh'))

tree_elecs = cKDTree(lh_verts_sub)
_, mapping_indices_elecs = tree_elecs.query(coords, k=1)

if verbose:
print(f"#Electrodes in LH: {np.sum(mapping_indices_elecs < lh_threshold)}, RH: {np.sum(mapping_indices_elecs >= lh_threshold)}")

mapping_indices_elecs_lh = mapping_indices_elecs[mapping_indices_elecs < lh_threshold]
_, mapping_indices_elecs_warped_lh = tree_lh.query(src_sphere[mapping_indices_elecs_lh], k=1)

mapping_indices_elecs_rh = mapping_indices_elecs[mapping_indices_elecs >= lh_threshold]
_, mapping_indices_elecs_warped_rh = tree_rh.query(src_sphere[mapping_indices_elecs_rh - lh_threshold], k=1)

new_coords_lh = lh_verts_sub_fs[mapping_indices_elecs_warped_lh]
new_coords_rh = rh_verts_sub_fs[mapping_indices_elecs_warped_rh]

new_coords = np.zeros((coords.shape[0], 3))
new_coords[mapping_indices_elecs < lh_threshold] = new_coords_lh
new_coords[mapping_indices_elecs >= lh_threshold] = new_coords_rh

return new_coords
19 changes: 18 additions & 1 deletion tests/test_coordinate_conversions.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import numpy as np
from naplib.localization import mni152_to_fsaverage, fsaverage_to_mni152
import os
import mne
from naplib.localization import mni152_to_fsaverage, fsaverage_to_mni152, src_to_dst

def test_mni152_fsaverage_conversions():
coords_tmp = np.array([[13.987, 36.5, 10.067], [-10.54, 24.5, 15.555]])
Expand All @@ -12,3 +14,18 @@ def test_mni152_fsaverage_conversions():

coords_tmp3 = fsaverage_to_mni152(coords_tmp2)
assert np.allclose(coords_tmp3, coords_tmp, rtol=1e-3)

def test_src_to_dst():
coords = np.random.rand(2, 3) * 5

os.makedirs('./.fsaverage_tmp', exist_ok=True)
mne.datasets.fetch_fsaverage('./.fsaverage_tmp/')

src_pial = './.fsaverage_tmp/fsaverage/surf/lh.pial'
src_sphere = './.fsaverage_tmp/fsaverage/surf/lh.sphere.reg'
dst_pial = './.fsaverage_tmp/fsaverage/surf/lh.inflated'
dst_sphere = './.fsaverage_tmp/fsaverage/surf/lh.sphere.reg'
Comment thread
gavinmischler marked this conversation as resolved.
Outdated

inflated_coords = src_to_dst(coords, src_pial, src_sphere, dst_pial, dst_sphere)

assert inflated_coords.shape[0] == coords.shape[0]