@@ -109,7 +109,7 @@ def plot_trials(trial_df: pd.DataFrame, params: Dict[str, Any], **kwargs) -> Non
109109 )
110110
111111
112- def difficultyPlot (animal_id : int , session : int , save_path = None ) -> None :
112+ def difficultyPlot (animal_id : int , session : int , save_path = None , params = None ) -> None :
113113 """Create a comprehensive difficulty plot for an animal session.
114114
115115 Generates a visualization showing trial outcomes (reward, punish, abort) across
@@ -145,20 +145,27 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
145145 trials_beh = get_trial_behavior (animal_id , session ).drop (["time" ], axis = 1 )
146146 ports_selection_corr_df = pd .merge (trials_beh , correct_trials_df , how = "inner" )
147147 perf_difficulty (animal_id , session )
148- params = {
148+
149+ default_params = {
149150 "probe_colors" : {
150151 1 : [1 , 0 , 0 ],
151- 2 : [0 , 0.5 , 1 ],
152+ 2 : [0.12156863 , 0.46666667 , 0.70588235 ],
152153 - 1 : [1 , 0 , 0 ],
153154 }, # colors for correct
154155 "trial_bins" : 10 , # how many trials on y axis
155156 "range" : 0.9 , # define offset range(diff is int so offset range(0,1))
156157 "xlim" : (- 2 ,), # plot lims
157158 "ylim" : (min_difficulty - 0.6 ,),
158- "figsize" : (16 , 6 ),
159- # **kwargs ,
159+ "figsize" : (12 , 10 ),
160+ "marker_size" : 10 ,
160161 }
161162
163+ if params is None :
164+ params = default_params
165+ else :
166+ # Merge user params with defaults, user params take priority
167+ params = {** default_params , ** params }
168+
162169 # create an array with colors for every correct trial based on the selected port
163170 clr_index_corr = np .array (
164171 [
@@ -170,18 +177,21 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
170177 )
171178
172179 plt .figure (figsize = params ["figsize" ], tight_layout = True )
173- plot_trials (correct_trials_df , params , s = 10 , c = clr_index_corr , label = "reward" )
174- plot_trials (incorrect_trials_df , params , s = 10 , c = "black" , label = "punish" )
175- plot_trials (missed_trials_df , params , s = 1 , c = "black" , label = "abort" )
180+ plot_trials (correct_trials_df , params , s = params ['marker_size' ], c = clr_index_corr , label = "reward" )
181+ plot_trials (incorrect_trials_df , params , s = params ['marker_size' ], label = "punish" ,
182+ facecolor = "none" , edgecolor = "black" , marker = "o" , linewidth = 0.5 )
183+ plot_trials (missed_trials_df , params , s = params ['marker_size' ]* 0.2 , c = "black" , label = "abort" )
176184
177- plt .ylabel ("Difficulty" )
185+ plt .ylabel ("difficulty levels" , fontsize = 12 )
186+ plt .xlabel ("trials" , fontsize = 12 )
178187 plt .title (
179188 f"Animal:{ animal_id } , Session:{ session } \n \
180189 Reward: { len (correct_trials_df )} , Punish: { len (incorrect_trials_df )} , Abort: { len (missed_trials_df )} "
181190 )
182191 plt .ylim (params ["ylim" ][0 ])
183192 plt .xlim (params ["xlim" ][0 ])
184- plt .yticks (np .unique (difficulties ))
193+ plt .yticks (np .unique (difficulties ), fontsize = 10 )
194+ plt .xticks (fontsize = 10 )
185195 plt .box (False )
186196 legend_elements = [
187197 Line2D (
@@ -190,26 +200,27 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
190200 marker = "o" ,
191201 color = "w" ,
192202 label = "punish" ,
193- markerfacecolor = "black" ,
194- markersize = 8 ,
203+ markerfacecolor = "none" ,
204+ markeredgecolor = "black" ,
205+ markersize = params ['marker_size' ],
195206 ),
196207 Line2D (
197208 [0 ],
198209 [0 ],
199210 marker = "o" ,
200211 color = "w" ,
201212 label = "reward (port 1)" ,
202- markerfacecolor = "red" ,
203- markersize = 8 ,
213+ markerfacecolor = "tab: red" ,
214+ markersize = params [ 'marker_size' ] ,
204215 ),
205216 Line2D (
206217 [0 ],
207218 [0 ],
208219 marker = "o" ,
209220 color = "w" ,
210221 label = "reward (port 2)" ,
211- markerfacecolor = "dodgerblue " ,
212- markersize = 8 ,
222+ markerfacecolor = "tab:blue " ,
223+ markersize = params [ 'marker_size' ] ,
213224 ),
214225 Line2D (
215226 [0 ],
@@ -218,11 +229,10 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
218229 color = "w" ,
219230 label = "abort" ,
220231 markerfacecolor = "black" ,
221- markersize = 4 ,
232+ markersize = params [ 'marker_size' ] * 0.5 ,
222233 ),
223234 ]
224235 plt .legend (handles = legend_elements , bbox_to_anchor = (1.04 , 1 ), loc = "upper left" )
225- # plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
226236
227237 if save_path :
228238 save_plot (plt .gcf (), save_path )
@@ -333,7 +343,7 @@ def LickPlot(
333343
334344 key_animal_session = {"animal_id" : animal_id , "session" : session }
335345 params = {
336- "port_colors" : ["red" , "blue" ], # set function parameters with defaults
346+ "port_colors" : ["tab: red" , "tab: blue" ], # set function parameters with defaults
337347 "xlim" : [- 500 , 10000 ],
338348 "figsize" : (15 , 15 ),
339349 "dotsize" : 3 ,
@@ -437,7 +447,7 @@ def LickPlot(
437447 marker = "o" ,
438448 color = "w" ,
439449 label = "Reward" ,
440- markerfacecolor = "green" ,
450+ markerfacecolor = "tab: green" ,
441451 markersize = 8 ,
442452 ),
443453 Line2D (
@@ -446,7 +456,7 @@ def LickPlot(
446456 marker = "o" ,
447457 color = "w" ,
448458 label = "Punish" ,
449- markerfacecolor = "red" ,
459+ markerfacecolor = "tab: red" ,
450460 markersize = 8 ,
451461 ),
452462 ]
@@ -458,7 +468,7 @@ def LickPlot(
458468 marker = "o" ,
459469 color = "w" ,
460470 label = "lick port 1" ,
461- markerfacecolor = "red" ,
471+ markerfacecolor = "tab: red" ,
462472 markersize = 8 ,
463473 ),
464474 Line2D (
0 commit comments