Skip to content
Merged
102 changes: 83 additions & 19 deletions naplib/preprocessing/rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False):
return data_rereferenced


def make_contact_rereference_arr(channelnames, extent=None):
def make_contact_rereference_arr(channelnames, extent=None, grid_sizes={}):
"""
Create grid which defines re-referencing scheme based on electrodes being on the same contact as
each other.
Expand All @@ -128,13 +128,16 @@ def make_contact_rereference_arr(channelnames, extent=None):
be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a
different electrode number, while the character portion in the left of the channelname specifies the
contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes
and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous.
and the second with 2 electrodes.
extent : int, optional, default=None
If provided, then only contacts from the same group which are within ``extent`` electrodes away
from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For
example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the
same contact is still grouped with it. For example, extent=1 produces the traditional local
average reference scheme.
from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the
nearest electrode on either side of a given electrode on the same contact is still grouped with it.
This ``extent=1`` produces the traditional local average reference scheme.
grid_sizes : dict, optional, default={}
If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes.
E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid,
which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``.

Returns
-------
Expand All @@ -145,18 +148,79 @@ def make_contact_rereference_arr(channelnames, extent=None):
--------
rereference
"""
contact_arrays = pd.Series([x.rstrip('0123456789') for x in channelnames])
connections = np.zeros((len(contact_arrays),) * 2, dtype=float)
for _, inds in contact_arrays.groupby(contact_arrays):
for i in inds.index:
connections[i, inds.index] = 1.0
def _find_adjacent_numbers(a, b, number, extent):
'''
Used to determine electrodes for local averaging ECoG grid"
'''
# Validate if the number is within the valid range
if number < 1 or number > a * b:
raise ValueError("The number is outside the range of the grid.")

# remove longer than extent if desired
if extent is not None:
if extent < 1:
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
connections *= np.tri(*connections.shape, k=extent)
connections *= np.fliplr(np.flipud(np.tri(*connections.shape, k=extent)))
connections = connections

# Calculate the row and column of the given number
row = (number - 1) // b
col = (number - 1) % b

# Find all adjacent numbers within the extent
adjacent_numbers = []
for dr in range(-extent, extent + 1): # Rows within the extent
for dc in range(-extent, extent + 1): # Columns within the extent
if dr == 0 and dc == 0:
continue # Skip the number itself
new_row, new_col = row + dr, col + dc
if 0 <= new_row < a and 0 <= new_col < b:
adjacent_num = new_row * b + new_col + 1
adjacent_numbers.append(adjacent_num)

return np.array(adjacent_numbers, dtype=int)
connections = np.zeros((len(channelnames),) * 2, dtype=float)
channelnames = np.array(channelnames)
contact_arrays = np.array([x.rstrip('0123456789') for x in channelnames])
contacts, ch_per_contact = np.unique([x.rstrip('0123456789') for x in channelnames], return_counts=True)
if extent is None:
# Common average referencing per electrode array (ECoG grid or sEEG shank)
# CAR will end up subtracting parts of channel ch from itself
for contact, num_ch in zip(contacts, ch_per_contact):
for ch in range(1,num_ch+1):
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = np.where(contact_arrays==contact)[0]
connections[curr,inds] = 1
elif extent < 1:
raise ValueError(f'Invalid extent. Must be no less than 1 but got extent={extent}')
else:
# Local average referencing within each electrode array
# LAR will NOT subtract parts of channel ch from itself
for contact, num_ch in zip(contacts, ch_per_contact):
for ch in range(1,num_ch+1):
# Local referencing for ECoG grids
if 'grid' in contact.lower():
side = np.sqrt(num_ch)
half_side = np.sqrt(num_ch/2)
# Check grid_sizes dict
if contact in grid_sizes:
nrows, ncols = grid_sizes[contact]
# Assume a square
elif np.isclose(side, int(side)):
nrows, ncols = side, side
# Assume a 1 x 2 rectangle
elif np.isclose(half_side, int(half_side)):
nrows, ncols = half_side, half_side*2
else:
raise Exception(f'Cannot determine {contact} layout. Please include layout in `grid_sizes`')
adjacent = _find_adjacent_numbers(nrows, ncols, ch, extent)
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = []
for adj in adjacent:
inds.append(np.where(channelnames==f'{contact}{adj}')[0])

# Local referencing for sEEG shanks and strips
else:
curr = np.where(channelnames==f'{contact}{ch}')[0]
inds = []
for cc in range(ch-extent, ch+extent+1):
if cc != ch:
inds.append(np.where(channelnames==f'{contact}{cc}')[0])

inds = np.concatenate(inds)
connections[curr,inds] = 1

return connections
25 changes: 22 additions & 3 deletions tests/preprocessing/test_rereference.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,29 @@


def test_create_contact_rereference_arr():
Comment thread
vinaysraghavan marked this conversation as resolved.
expected = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])
g = ['LT1','LT2','RT1','RT2']
arr = make_contact_rereference_arr(g)
expected = np.array([[0,1,0,0,0,0,0,0],
[1,0,0,0,0,0,0,0],
[0,0,0,1,0,0,0,0],
[0,0,1,0,0,0,0,0],
[0,0,0,0,0,1,1,1],
[0,0,0,0,1,0,1,1],
[0,0,0,0,1,1,0,1],
[0,0,0,0,1,1,1,0],
])
expected1 = np.array([[1,1,0,0,0,0,0,0],
[1,1,0,0,0,0,0,0],
[0,0,1,1,0,0,0,0],
[0,0,1,1,0,0,0,0],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
[0,0,0,0,1,1,1,1],
])
g = ['LT1','LT2','GridA1','GridA2'] + [f'GridB{n}' for n in range(1,5)]
arr = make_contact_rereference_arr(g, extent=1, grid_sizes={'GridA':(1,2)})
arr1 = make_contact_rereference_arr(g)
assert np.allclose(expected, arr)
assert np.allclose(expected1, arr1)

def test_rereference_avg():
arr = np.array([[1,1,0,0],[1,1,0,0],[0,0,1,1],[0,0,1,1]])
Expand Down