Skip to content

Commit 787e312

Browse files
committed
Rewrite Analysis_tools with publication-quality figures and bootstrap correlation
Analysis_tools.py:
- Add publication-quality matplotlib figures with precise cm-based dimensions, pt-based line widths, and 5-7pt Arial fonts for SVG output
- Add bootstrap correlation with confidence intervals (Pearson/Spearman)
- Add MMP inhibitor analysis per cell line (AICS-0036, AICS-0000)
- Add immunolabeling mean intensity analysis with statistical testing
- Add immunolabeling heatmap visualization
- Add Bland-Altman analysis for migration time agreement
- Fix inside-outside migration timing with bias-corrected scatter plots
- Use consistent INTENSITY_Y_CONFIG for y-axis ranges across all gene plots

plot_tools.py:
- Add publication-quality figure dimensions (cm-based) to plot_examples
- Add nearest-timepoint matching to handle floating-point precision
- Add configurable y-axis ranges via INTENSITY_Y_CONFIG
- Add formatted x/y tick labels at regular intervals for ZO1 heatmaps
- Set SVG text to editable (svg.fonttype = 'none')

const.py:
- Add INTENSITY_Y_CONFIG dictionary for per-gene y-axis settings

io.py:
- Add load_from_aws and local_path parameters to load_imaging_and_segmentation_dataset
- Fix load_image_analysis_extracted_features to properly respect load_from_aws flag

Dependencies:
- Add joblib>=1.3.0 to pyproject.toml
- Update pdm.lock and requirements.txt
1 parent accd29c commit 787e312

7 files changed

Lines changed: 2669 additions & 2711 deletions

File tree

EMT_data_analysis/analysis_scripts/Analysis_tools.py

Lines changed: 997 additions & 160 deletions
Large diffs are not rendered by default.

EMT_data_analysis/analysis_scripts/plot_tools.py

Lines changed: 109 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,23 +3,50 @@
33
import seaborn as sns
44
from scipy import stats
55
import scikit_posthocs as sp
6+
import matplotlib as mpl
67
import matplotlib.pyplot as plt
78

9+
mpl.rcParams['svg.fonttype'] = 'none' # Keep text editable in SVG
10+
811
from EMT_data_analysis.tools import const
912

13+
14+
def _find_nearest_timepoint(df, time_column, target_time):
    """
    Find the nearest timepoint value in a dataframe column to a target time.

    Missing (NaN) entries in the time column are ignored: ``argmin`` treats
    NaN as the minimum, so a single missing timepoint would otherwise be
    returned as the "nearest" value and never match any row when the caller
    filters on it.

    Parameters
    ----------
    df : DataFrame
        Dataframe containing the time column
    time_column : str
        Name of the column containing timepoint values
    target_time : float
        Target time value to find the nearest match for

    Returns
    -------
    float
        The nearest non-missing timepoint value from the dataframe. When two
        timepoints are equally close, the one appearing first in the column
        is returned.

    Raises
    ------
    ValueError
        If the column holds no non-missing timepoints.
    """
    # Drop NaN before searching so a missing value can never win the argmin.
    timepoints = df[time_column].dropna().unique()
    if len(timepoints) == 0:
        raise ValueError(
            f"Column '{time_column}' contains no valid timepoints to match against."
        )
    distances = np.abs(timepoints - target_time)
    return timepoints[distances.argmin()]
35+
36+
1037
def plot_examples(df_int, id_plf, id_2d, id_3d, gene, figs_dir, metric,variable='Mean Intensity', out_type='pdf'):
1138
'''
12-
This function plots one example for individual trajectories of mean intensity over time for each condition to represent how the gene metrics
39+
This function plots one example for individual trajectories of mean intensity over time for each condition to represent how the gene metrics
1340
(time at max EOMES expression, Time at inflection of E-Cad loss and Time at half maximal loss of SOX2 expression) were estimated.
1441
It is also used to plot migration time estimation example for area at glass over time.
1542
Parameters
1643
----------
1744
df_int: DataFrame
1845
Dataframe with mean intensity over time information for each movie in the dataset along with the respective gene metrics.
19-
46+
2047
id_plf: String
2148
Movie ID to plot the mean intensity trajectory for a movie with 2D PLF colony EMT condition
22-
49+
2350
id_2d: String
2451
Movie ID to plot the mean intensity trajectory for a movie with 2D colony EMT condition
2552
@@ -39,35 +66,84 @@ def plot_examples(df_int, id_plf, id_2d, id_3d, gene, figs_dir, metric,variable=
3966
-------
4067
saves plots in the figs_dir'''
4168

69+
# Publication figure dimensions
70+
cm_to_inch = 1 / 2.54
71+
fig_width_cm = 2.8846 # x-axis width
72+
fig_height_cm = 1.889 # y-axis height
73+
pad_left = 0.55
74+
pad_bottom = 0.45
75+
pad_right = 0.05
76+
pad_top = 0.05
77+
total_w = fig_width_cm * cm_to_inch + pad_left + pad_right
78+
total_h = fig_height_cm * cm_to_inch + pad_bottom + pad_top
79+
pt_to_inch = 1 / 72.0
80+
81+
y_cfg = const.INTENSITY_Y_CONFIG.get(gene, None)
82+
83+
# Colors
84+
color_orange = (255/255, 165/255, 0/255)
85+
color_blue = (0/255, 191/255, 255/255)
86+
color_purple = (139/255, 0/255, 139/255)
87+
trace_lw = 0.75 * pt_to_inch * 72 # 0.75 pt
88+
4289
df_plf=df_int[df_int['Data ID']==id_plf]
4390
df_2d=df_int[df_int['Data ID']==id_2d]
4491
df_3d=df_int[df_int['Data ID']==id_3d]
4592

46-
fig,ax=plt.subplots(1,1,figsize=(8,6))
93+
fig, ax = plt.subplots(1, 1)
4794

95+
# Use nearest timepoint matching to handle floating point precision differences
4896
x_metric_2d=df_2d[metric].values[0]
49-
y_metric_2d=df_2d[variable][df_2d['Timepoint (h)']==x_metric_2d].values[0]
50-
ax.plot(df_2d['Timepoint (h)'],df_2d[variable], c='deepskyblue', linewidth=3)
51-
ax.scatter(x_metric_2d,y_metric_2d,c='black', marker='D', s=100)
52-
97+
nearest_tp_2d = _find_nearest_timepoint(df_2d, 'Timepoint (h)', x_metric_2d)
98+
y_metric_2d=df_2d[variable][df_2d['Timepoint (h)']==nearest_tp_2d].values[0]
99+
ax.plot(df_2d['Timepoint (h)'],df_2d[variable], c=color_blue, linewidth=trace_lw)
100+
ax.scatter(x_metric_2d,y_metric_2d,c='black', marker='D', s=8, zorder=5)
53101

54102
x_metric_plf=df_plf[metric].values[0]
55-
y_metric_plf=df_plf[variable][df_plf['Timepoint (h)']==x_metric_plf].values[0]
56-
ax.plot(df_plf['Timepoint (h)'],df_plf[variable], c='darkmagenta', linewidth=3)
57-
ax.scatter(x_metric_plf,y_metric_plf,c='black', marker='D', s=100)
103+
nearest_tp_plf = _find_nearest_timepoint(df_plf, 'Timepoint (h)', x_metric_plf)
104+
y_metric_plf=df_plf[variable][df_plf['Timepoint (h)']==nearest_tp_plf].values[0]
105+
ax.plot(df_plf['Timepoint (h)'],df_plf[variable], c=color_purple, linewidth=trace_lw)
106+
ax.scatter(x_metric_plf,y_metric_plf,c='black', marker='D', s=8, zorder=5)
58107

59108
x_metric_3d=df_3d[metric].values[0]
60-
y_metric_3d=df_3d[variable][df_3d['Timepoint (h)']==x_metric_3d].values[0]
61-
ax.plot(df_3d['Timepoint (h)'],df_3d[variable], c='orange', linewidth=3)
62-
ax.scatter(x_metric_3d,y_metric_3d,c='black', marker='D', s=100)
109+
nearest_tp_3d = _find_nearest_timepoint(df_3d, 'Timepoint (h)', x_metric_3d)
110+
y_metric_3d=df_3d[variable][df_3d['Timepoint (h)']==nearest_tp_3d].values[0]
111+
ax.plot(df_3d['Timepoint (h)'],df_3d[variable], c=color_orange, linewidth=trace_lw)
112+
ax.scatter(x_metric_3d,y_metric_3d,c='black', marker='D', s=8, zorder=5)
113+
114+
# Y-axis
115+
if y_cfg is not None:
116+
ymin, ymax = y_cfg['ylim']
117+
y_pad = y_cfg['ytick_interval'] * 0.3
118+
ax.set_ylim(ymin - y_pad, ymax + y_pad)
119+
ax.set_yticks(np.arange(ymin, ymax + 1, y_cfg['ytick_interval']))
120+
ylabel = y_cfg.get('ylabel', 'Mean intensity (AU)')
121+
ax.set_ylabel(ylabel, fontsize=5, fontfamily='Arial')
122+
else:
123+
ax.set_ylabel(variable, fontsize=5, fontfamily='Arial')
63124

125+
# X-axis: 0-50, interval 10, with padding
126+
ax.set_xlim(-2, 52)
127+
ax.set_xticks(np.arange(0, 51, 10))
128+
ax.set_xlabel('Time (h)', fontsize=5, fontfamily='Arial')
64129

65-
plt.ylabel(f'{variable}', fontsize=16)
66-
plt.xlabel('Time (h)', fontsize=16)
130+
# Tick label styling
131+
ax.tick_params(axis='both', labelsize=5, width=0.5 * pt_to_inch * 72,
132+
length=3, direction='out', pad=2)
133+
for label in ax.get_xticklabels() + ax.get_yticklabels():
134+
label.set_fontfamily('Arial')
135+
label.set_fontsize(5)
136+
137+
# Axis line (spine) width: 0.5 pt
138+
for spine in ax.spines.values():
139+
spine.set_linewidth(0.5 * pt_to_inch * 72)
140+
141+
# Remove top and right spines
142+
ax.spines['top'].set_visible(False)
143+
ax.spines['right'].set_visible(False)
67144

68-
plt.xlim(-1,50)
69-
plt.tight_layout()
70145
plt.savefig(rf'{figs_dir}/Example_{gene}_{metric}.{out_type}', dpi=600, transparent=True)
146+
plt.close(fig)
71147

72148
def run_statistics (x,y,z):
73149
'''
@@ -201,5 +277,20 @@ def Intensity_over_z(df, figs_dir, color_map='coolwarm', out_type='pdf'):
201277

202278
ax= sns.heatmap(df_nanmerge, cmap=color_map, vmin=color_min, vmax=color_max )
203279
ax.invert_yaxis()
280+
281+
# Set x-axis ticks at increments of 4 hours
282+
x_cols = df_nanmerge.columns.tolist()
283+
xtick_positions = [i for i, v in enumerate(x_cols) if v % 4 == 0]
284+
xtick_labels = [int(x_cols[i]) for i in xtick_positions]
285+
ax.set_xticks([p + 0.5 for p in xtick_positions])
286+
ax.set_xticklabels(xtick_labels)
287+
288+
# Set y-axis ticks at increments of 4 z-planes
289+
y_rows = df_nanmerge.index.tolist()
290+
ytick_positions = [i for i, v in enumerate(y_rows) if v % 4 == 0]
291+
ytick_labels = [int(y_rows[i]) for i in ytick_positions]
292+
ax.set_yticks([p + 0.5 for p in ytick_positions])
293+
ax.set_yticklabels(ytick_labels)
294+
204295
plt.title(f'Condition={c}, Data ID={id}')
205296
fig.savefig(rf'{figs_dir}/Histogram_zo1_{c}_{id}.{out_type}', dpi=600)

EMT_data_analysis/tools/const.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,13 @@
4848
EXAMPLE_ACM_IDS = [
4949
'3500005824_36',
5050
'3500006256_12'
51-
]
51+
]
52+
53+
# Y-axis configuration for mean intensity plots (shared by summary and example plots)
54+
INTENSITY_Y_CONFIG = {
55+
'SOX2': {'ylim': (100, 170), 'ytick_interval': 10},
56+
'TBXT': {'ylim': (100, 400), 'ytick_interval': 50},
57+
'EOMES': {'ylim': (100, 155), 'ytick_interval': 10},
58+
'CDH1': {'ylim': (100, 145), 'ytick_interval': 10},
59+
'HIST1H2BJ': {'ylim': (0, 170000), 'ytick_interval': 20000, 'ylabel': 'Colony area over bottom 2 Z (µm²)'},
60+
}

EMT_data_analysis/tools/io.py

Lines changed: 34 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,45 @@
55
def convert_to_windows_path(linux_path: Path):
66
return PurePosixPath(linux_path)
77

8-
def load_imaging_and_segmentation_dataset():
9-
df = pd.read_csv("https://allencell.s3.amazonaws.com/aics/emt_timelapse_dataset/manifests/imaging_and_segmentation_data.csv")
8+
def load_imaging_and_segmentation_dataset(load_from_aws: bool = True, local_path: str = None):
9+
"""
10+
Load the imaging and segmentation dataset.
11+
12+
Parameters
13+
----------
14+
load_from_aws : bool, default True
15+
If True, load from AWS S3. If False, load from local file.
16+
local_path : str, optional
17+
Path to local CSV file. If not provided and load_from_aws=False,
18+
will look for 'imaging_and_segmentation_data.csv' in the project root.
19+
20+
Returns
21+
-------
22+
df : DataFrame
23+
The imaging and segmentation dataset
24+
"""
25+
if load_from_aws:
26+
path = "https://allencell.s3.amazonaws.com/aics/emt_timelapse_dataset/manifests/imaging_and_segmentation_data.csv"
27+
else:
28+
if local_path is not None:
29+
path = local_path
30+
else:
31+
# Default local path: project root (parent of EMT_data_analysis package)
32+
project_root = Path(__file__).parent.parent.parent
33+
path = project_root / "imaging_and_segmentation_data.csv"
34+
print(f'Loading from local file: {path}')
35+
36+
df = pd.read_csv(path)
1037
n_movies = df['Data ID'].nunique()
1138
print(f'Total number of movies in the dataset: {n_movies}')
1239
return df
1340

1441
def load_image_analysis_extracted_features(load_from_aws: bool = True):
15-
metric_comp_results_dir = get_results_directory_name() / "metric_computation"
16-
path = metric_comp_results_dir / "Image_analysis_extracted_features.csv"
17-
try:
18-
print('Trying to load features from local path.')
19-
df = pd.read_csv(path)
20-
except Exception:
21-
print(f'Features not found at {path}. Loading from AWS instead. This may take a while...')
22-
path = "https://allencell.s3.amazonaws.com/aics/emt_timelapse_dataset/manifests/Image_analysis_extracted_features.csv?versionId=ehxRXxC0FpidcpgXU_z.51T.nkWB0Yuj"
23-
df = pd.read_csv(path)
42+
path = "https://allencell.s3.amazonaws.com/aics/emt_timelapse_dataset/manifests/Image_analysis_extracted_features.csv?versionId=ehxRXxC0FpidcpgXU_z.51T.nkWB0Yuj"
43+
if not load_from_aws:
44+
metric_comp_results_dir = get_results_directory_name() / "metric_computation"
45+
path = metric_comp_results_dir / "Image_analysis_extracted_features.csv"
46+
df = pd.read_csv(path)
2447
return df
2548

2649
def load_inside_outside_classification(load_from_aws: bool = True):

0 commit comments

Comments
 (0)