Skip to content

Commit 7fb75af

Browse files
committed
update plot style
1 parent b200de8 commit 7fb75af

3 files changed

Lines changed: 211 additions & 110 deletions

File tree

examples/session_analysis_example.ipynb

Lines changed: 169 additions & 78 deletions
Large diffs are not rendered by default.

src/ethopy_analysis/config/styles.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -39,26 +39,26 @@ def __init__(self):
3939
# Core settings
4040
self.font_size = 12
4141
self.title_size = 16
42-
self.label_size = 14
42+
self.label_size = 10
4343
self.tick_size = 10
44-
self.legend_size = 11
45-
self.font_family = "DejaVu Sans"
44+
self.legend_size = 12
45+
self.font_family = "Arial"
4646

4747
# Figure settings
4848
self.figure_size = (10, 6)
4949
self.dpi = 300
5050
self.background_color = "white"
5151

5252
# Colors
53-
self.primary_color = "#4169E1" # Sea green
54-
self.secondary_color = "#2E8B57" # Royal blue
55-
self.accent_color = "#DC143C" # Crimson
53+
self.primary_color = "#1f77b4" # Sea green
54+
self.secondary_color = "#2ca02c" # Royal blue
55+
self.accent_color = "#d62728" # Crimson
5656

5757
# Color palette for multiple categories
5858
self.color_palette = [
59-
"#4169E1",
60-
"#2E8B57",
61-
"#DC143C",
59+
"#1f77b4",
60+
"#2ca02c",
61+
"#d62728",
6262
"#FF8C00",
6363
"#8A2BE2",
6464
"#00CED1",
@@ -78,7 +78,7 @@ def __init__(self):
7878

7979
# Line and marker settings
8080
self.line_width = 2.0
81-
self.marker_size = 6.0
81+
self.marker_size = 10.0
8282

8383
def apply(self):
8484
"""Apply this style to matplotlib."""

src/ethopy_analysis/plots/session.py

Lines changed: 32 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -109,7 +109,7 @@ def plot_trials(trial_df: pd.DataFrame, params: Dict[str, Any], **kwargs) -> Non
109109
)
110110

111111

112-
def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
112+
def difficultyPlot(animal_id: int, session: int, save_path=None, params=None) -> None:
113113
"""Create a comprehensive difficulty plot for an animal session.
114114
115115
Generates a visualization showing trial outcomes (reward, punish, abort) across
@@ -145,20 +145,27 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
145145
trials_beh = get_trial_behavior(animal_id, session).drop(["time"], axis=1)
146146
ports_selection_corr_df = pd.merge(trials_beh, correct_trials_df, how="inner")
147147
perf_difficulty(animal_id, session)
148-
params = {
148+
149+
default_params = {
149150
"probe_colors": {
150151
1: [1, 0, 0],
151-
2: [0, 0.5, 1],
152+
2: [0.12156863, 0.46666667, 0.70588235],
152153
-1: [1, 0, 0],
153154
}, # colors for correct
154155
"trial_bins": 10, # how many trials on y axis
155156
"range": 0.9, # define offset range(diff is int so offset range(0,1))
156157
"xlim": (-2,), # plot lims
157158
"ylim": (min_difficulty - 0.6,),
158-
"figsize": (16, 6),
159-
# **kwargs,
159+
"figsize": (12, 10),
160+
"marker_size": 10,
160161
}
161162

163+
if params is None:
164+
params = default_params
165+
else:
166+
# Merge user params with defaults, user params take priority
167+
params = {**default_params, **params}
168+
162169
# create an array with colors for every correct trial based on the selected port
163170
clr_index_corr = np.array(
164171
[
@@ -170,18 +177,21 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
170177
)
171178

172179
plt.figure(figsize=params["figsize"], tight_layout=True)
173-
plot_trials(correct_trials_df, params, s=10, c=clr_index_corr, label="reward")
174-
plot_trials(incorrect_trials_df, params, s=10, c="black", label="punish")
175-
plot_trials(missed_trials_df, params, s=1, c="black", label="abort")
180+
plot_trials(correct_trials_df, params, s=params['marker_size'], c=clr_index_corr, label="reward")
181+
plot_trials(incorrect_trials_df, params, s=params['marker_size'], label="punish",
182+
facecolor="none", edgecolor="black", marker="o", linewidth=0.5)
183+
plot_trials(missed_trials_df, params, s=params['marker_size']*0.2, c="black", label="abort")
176184

177-
plt.ylabel("Difficulty")
185+
plt.ylabel("difficulty levels", fontsize=12)
186+
plt.xlabel("trials", fontsize=12)
178187
plt.title(
179188
f"Animal:{animal_id}, Session:{session} \n\
180189
Reward: {len(correct_trials_df)}, Punish: {len(incorrect_trials_df)}, Abort: {len(missed_trials_df)}"
181190
)
182191
plt.ylim(params["ylim"][0])
183192
plt.xlim(params["xlim"][0])
184-
plt.yticks(np.unique(difficulties))
193+
plt.yticks(np.unique(difficulties), fontsize=10)
194+
plt.xticks(fontsize=10)
185195
plt.box(False)
186196
legend_elements = [
187197
Line2D(
@@ -190,26 +200,27 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
190200
marker="o",
191201
color="w",
192202
label="punish",
193-
markerfacecolor="black",
194-
markersize=8,
203+
markerfacecolor="none",
204+
markeredgecolor="black",
205+
markersize=params['marker_size'],
195206
),
196207
Line2D(
197208
[0],
198209
[0],
199210
marker="o",
200211
color="w",
201212
label="reward (port 1)",
202-
markerfacecolor="red",
203-
markersize=8,
213+
markerfacecolor="tab:red",
214+
markersize=params['marker_size'],
204215
),
205216
Line2D(
206217
[0],
207218
[0],
208219
marker="o",
209220
color="w",
210221
label="reward (port 2)",
211-
markerfacecolor="dodgerblue",
212-
markersize=8,
222+
markerfacecolor="tab:blue",
223+
markersize=params['marker_size'],
213224
),
214225
Line2D(
215226
[0],
@@ -218,11 +229,10 @@ def difficultyPlot(animal_id: int, session: int, save_path=None) -> None:
218229
color="w",
219230
label="abort",
220231
markerfacecolor="black",
221-
markersize=4,
232+
markersize=params['marker_size']*0.5,
222233
),
223234
]
224235
plt.legend(handles=legend_elements, bbox_to_anchor=(1.04, 1), loc="upper left")
225-
# plt.legend(bbox_to_anchor=(1.04, 1), loc="upper left")
226236

227237
if save_path:
228238
save_plot(plt.gcf(), save_path)
@@ -333,7 +343,7 @@ def LickPlot(
333343

334344
key_animal_session = {"animal_id": animal_id, "session": session}
335345
params = {
336-
"port_colors": ["red", "blue"], # set function parameters with defaults
346+
"port_colors": ["tab:red", "tab:blue"], # set function parameters with defaults
337347
"xlim": [-500, 10000],
338348
"figsize": (15, 15),
339349
"dotsize": 3,
@@ -437,7 +447,7 @@ def LickPlot(
437447
marker="o",
438448
color="w",
439449
label="Reward",
440-
markerfacecolor="green",
450+
markerfacecolor="tab:green",
441451
markersize=8,
442452
),
443453
Line2D(
@@ -446,7 +456,7 @@ def LickPlot(
446456
marker="o",
447457
color="w",
448458
label="Punish",
449-
markerfacecolor="red",
459+
markerfacecolor="tab:red",
450460
markersize=8,
451461
),
452462
]
@@ -458,7 +468,7 @@ def LickPlot(
458468
marker="o",
459469
color="w",
460470
label="lick port 1",
461-
markerfacecolor="red",
471+
markerfacecolor="tab:red",
462472
markersize=8,
463473
),
464474
Line2D(

0 commit comments

Comments
 (0)