@@ -116,7 +116,7 @@ def _rereference(data_arr, method='avg', return_ref=False):
116116 return data_rereferenced
117117
118118
119- def make_contact_rereference_arr (channelnames , extent = None ):
119+ def make_contact_rereference_arr (channelnames , extent = None , grid_sizes = {} ):
120120 """
121121 Create grid which defines re-referencing scheme based on electrodes being on the same contact as
122122 each other.
@@ -128,13 +128,17 @@ def make_contact_rereference_arr(channelnames, extent=None):
128128 be alphanumeric, with any numbers only being on the right. 2) The numeric portion specifies a
129129 different electrode number, while the character portion in the left of the channelname specifies the
130130 contact name. E.g. ['RT1','RT2','RT3','Ls1','Ls2'] indicates two contacts, the first with 3 electrodes
131- and the second with 2 electrodes. 3) Electrodes from the same contact must be contiguous.
131+ and the second with 2 electrodes.
132132 extent : int, optional, default=None
133133 If provided, then only contacts from the same group which are within ``extent`` electrodes away
134- from each other (inclusive) are still grouped together. Only used if ``method='contact'``. For
135- example, if ``extent=1``, only the nearest electrode on either side of a given electrode on the
136- same contact is still grouped with it. For example, extent=1 produces the traditional local
137- average reference scheme.
134+ from each other (inclusive) are still grouped together. For example, if ``extent=1``, only the
135+ nearest electrode on either side of a given electrode on the same contact is still grouped with it.
136+ This ``extent=1`` produces the traditional local average reference scheme.
137+ The default ``extent=None`` produces the traditional common average reference scheme.
138+ grid_sizes : dict, optional, default={}
139+ If provided, contains {'contact_name': (nrow, ncol)} values for any known ECoG grid sizes.
140+ E.g. {'GridA': (8, 16)} indicates that electrodes on contact 'GridA' are arranged in an 8 x 16 grid,
141+ which is needed to determine adjacent electrodes for local average referencing with ``extent >= 1``.
138142
139143 Returns
140144 -------
@@ -145,18 +149,89 @@ def make_contact_rereference_arr(channelnames, extent=None):
145149 --------
146150 rereference
147151 """
148- contact_arrays = pd .Series ([x .rstrip ('0123456789' ) for x in channelnames ])
149- connections = np .zeros ((len (contact_arrays ),) * 2 , dtype = float )
150- for _ , inds in contact_arrays .groupby (contact_arrays ):
151- for i in inds .index :
152- connections [i , inds .index ] = 1.0
152+ def _find_adjacent_numbers (a , b , number , extent ):
153+ '''
154+ Used to determine electrodes for local averaging ECoG grid"
155+ '''
156+ # Validate if the number is within the valid range
157+ if number < 1 or number > a * b :
158+ raise ValueError ("The number is outside the range of the grid." )
153159
154- # remove longer than extent if desired
155- if extent is not None :
156- if extent < 1 :
157- raise ValueError (f'Invalid extent. Must be no less than 1 but got extent={ extent } ' )
158- connections *= np .tri (* connections .shape , k = extent )
159- connections *= np .fliplr (np .flipud (np .tri (* connections .shape , k = extent )))
160- connections = connections
161-
160+ # Calculate the row and column of the given number
161+ row = (number - 1 ) // b
162+ col = (number - 1 ) % b
163+
164+ # Find all adjacent numbers within the extent
165+ adjacent_numbers = []
166+ for dr in range (- extent , extent + 1 ): # Rows within the extent
167+ for dc in range (- extent , extent + 1 ): # Columns within the extent
168+ if dr == 0 and dc == 0 :
169+ continue # Skip the number itself
170+ new_row , new_col = row + dr , col + dc
171+ if 0 <= new_row < a and 0 <= new_col < b :
172+ adjacent_num = new_row * b + new_col + 1
173+ adjacent_numbers .append (adjacent_num )
174+
175+ return np .array (adjacent_numbers , dtype = int )
176+
177+ connections = np .zeros ((len (channelnames ),) * 2 , dtype = float )
178+ channelnames = np .array (channelnames )
179+ contact_arrays = np .array ([x .rstrip ('0123456789' ) for x in channelnames ])
180+ contacts = np .unique (contact_arrays )
181+ # Determine the channel numbers on each contact
182+ ch_per_contact = {contact :[int (x .replace (contact ,'' )) for x in channelnames
183+ if x .rstrip ('0123456789' )== contact ]
184+ for contact in contacts }
185+
186+ if extent is None :
187+ # Common average referencing per electrode array (ECoG grid or sEEG shank)
188+ # CAR will end up subtracting parts of channel ch from itself
189+ for contact in contacts :
190+ for ch in ch_per_contact [contact ]:
191+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
192+ inds = np .where (contact_arrays == contact )[0 ]
193+ connections [curr ,inds ] = 1
194+ elif extent < 1 :
195+ raise ValueError (f'Invalid extent. Must be no less than 1 but got extent={ extent } ' )
196+ else :
197+ # Local average referencing within each electrode array
198+ # LAR will NOT subtract parts of channel ch from itself
199+ for contact in contacts :
200+ for ch in ch_per_contact [contact ]:
201+ # Local referencing for ECoG grids
202+ if 'grid' in contact .lower ():
203+ num_ch = len (ch_per_contact [contact ])
204+ side = np .sqrt (num_ch )
205+ half_side = np .sqrt (num_ch / 2 )
206+ # Check grid_sizes dict
207+ if contact in grid_sizes :
208+ nrows , ncols = grid_sizes [contact ]
209+ # Assume a square
210+ elif np .isclose (side , int (side )):
211+ nrows , ncols = side , side
212+ # Assume a 1 x 2 rectangle
213+ elif np .isclose (half_side , int (half_side )):
214+ nrows , ncols = half_side , half_side * 2
215+ else :
216+ raise Exception (f'Cannot determine { contact } layout. Please include layout in `grid_sizes`' )
217+ adjacent = _find_adjacent_numbers (nrows , ncols , ch , extent )
218+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
219+ inds = []
220+ for adj in adjacent :
221+ inds .append (np .where (channelnames == f'{ contact } { adj } ' )[0 ])
222+
223+ # Local referencing for sEEG shanks and strips
224+ else :
225+ curr = np .where (channelnames == f'{ contact } { ch } ' )[0 ]
226+ inds = []
227+ for cc in range (ch - extent , ch + extent + 1 ):
228+ if cc != ch :
229+ inds .append (np .where (channelnames == f'{ contact } { cc } ' )[0 ])
230+
231+ inds = np .concatenate (inds )
232+ if len (inds ) < 1 :
233+ print (f'{ contact } { cc } has no re-references.' )
234+ else :
235+ connections [curr ,inds ] = 1
236+
162237 return connections
0 commit comments