55
66import gdist
77import matplotlib .pyplot as plt
8+ from matplotlib .colors import LightSource
89import numpy as np
910from nibabel .freesurfer .io import read_annot
1011
@@ -124,17 +125,19 @@ def surfdist_viz(
124125 alpha = "auto" ,
125126 bg_map = None ,
126127 bg_on_stat = False ,
128+ bg_alpha = 1.0 ,
127129 figsize = None ,
128130 ax = None ,
129131 vmin = None ,
130132 vmax = None ,
133+ light_source = None ,
131134):
132135 """Visualize results on cortical surface using matplotlib.
133136
134137 Parameters
135138 ----------
136139 coords : numpy array of shape (n_nodes,3), each row specifying the x,y,z
137- coordinates of one node of surface mesh
140+ coordinates of one node of surface mesh
138141 faces : numpy array of shape (n_faces, 3), each row specifying the indices
139142 of the three nodes building one node of the surface mesh
140143 stat_map : numpy array of shape (n_nodes,) containing the values to be
@@ -158,9 +161,16 @@ def surfdist_viz(
158161 multiplied with the background map for shadowing. Otherwise,
159162 only areas that are not covered by the statsitical map after
160163 thresholding will show shadows.
164+ bg_alpha : float, determines the opacity of the background map.
165+ bg_alpha defaults to 1.0 and is only relevant if bg_on_stat
161166 figsize : tuple of intergers, dimensions of the figure that is produced.
162167 ax : Axis
163168 Axis to plot on, with 3d projection.
169+ light_source: None, bool, or tuple of int, optional
170+ Whether to apply a light source for shading. If True, the light
171+ source position is inferred from `elev` and `azim`. If a tuple of
172+ (alt, az), these values will be used to specify the light source
173+ position. If None or False, no shading is applied. Default is None.
164174
165175 Returns
166176 -------
@@ -226,7 +236,7 @@ def surfdist_viz(
226236 bg_faces = np .mean (bg_data [faces ], axis = 1 )
227237 bg_faces = bg_faces - bg_faces .min ()
228238 bg_faces = bg_faces / bg_faces .max ()
229- face_colors = plt .cm .gray_r (bg_faces )
239+ face_colors = plt .cm .gray_r (bg_faces * bg_alpha )
230240
231241 # modify alpha values of background
232242 face_colors [:, 3 ] = alpha * face_colors [:, 3 ]
@@ -260,6 +270,41 @@ def surfdist_viz(
260270 else :
261271 face_colors = cmap (stat_map_faces )
262272
273+ if light_source :
274+ if hasattr (light_source , '__len__' ):
275+ if len (light_source ) == 2 :
276+ ls = LightSource (azdeg = light_source [1 ], altdeg = light_source [0 ])
277+ else :
278+ # Apply lighting to the face colors for shading
279+ ls = LightSource (azdeg = azim , altdeg = elev )
280+
281+ # Manually calculate the light vector since the 'light_vector'
282+ # attribute is not accessible in some matplotlib versions.
283+ az = np .radians (ls .azdeg )
284+ alt = np .radians (ls .altdeg )
285+ light_vec = np .array ([
286+ np .cos (az ) * np .cos (alt ),
287+ np .sin (az ) * np .cos (alt ),
288+ np .sin (alt )
289+ ])
290+
291+ # Calculate face normals
292+ v0 = coords [faces [:, 0 ]]
293+ v1 = coords [faces [:, 1 ]]
294+ v2 = coords [faces [:, 2 ]]
295+ face_normals = np .cross (v1 - v0 , v2 - v0 )
296+ face_normals /= np .linalg .norm (face_normals , axis = 1 )[:, np .newaxis ]
297+
298+ # The shade is the dot product of the light vector and face normals
299+ shade = np .dot (face_normals , light_vec )
300+
301+ # Modulate the RGB colors by the shade, keeping the alpha channel
302+ # Use np.clip to keep shade values between 0 and 1
303+ illuminated_rgb = face_colors [:, :3 ] * np .clip (shade , 0 , 1 )[:, np .newaxis ]
304+
305+ # Combine illuminated RGB with the original alpha channel
306+ face_colors = np .hstack ((illuminated_rgb , face_colors [:, 3 :]))
307+
263308 p3dcollec .set_facecolors (face_colors )
264309
265310 if not premade_ax :
0 commit comments