Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
40 changes: 34 additions & 6 deletions Python/shapeworks/shapeworks/stats.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,14 +92,18 @@ def lda_loadings(group1_data, group2_data):
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)


def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
def _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=True):
"""Shared logic for projecting groups onto a discriminant direction and fitting PDFs.

Args:
diffVect: Discriminant direction vector (features,)
group1_data: PCA loadings for group 1 (features x samples)
group2_data: PCA loadings for group 2 (features x samples)
combined_data: Concatenation of group1_data and group2_data (features x all_samples)
normalize_projections: If True, normalize so group means map to -1 and +1.
This works well when diffVect is aligned with the mean difference (e.g. LDA).
Set to False for directions that may not be aligned with the mean difference
(e.g. DWD), which would cause the normalization to produce extreme values.

Returns: 6-tuple (group1_x, group2_x, group1_pdf, group2_pdf, group1_map, group2_map)
"""
Expand All @@ -122,12 +126,14 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
for ii in range(group1_num):
subjDiff = group1_data[:, ii] - overall_mean
group1_map[ii] = np.dot(diffVect, subjDiff)
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)
if normalize_projections:
group1_map[ii] = normalize(group1_map[ii], group1_mean_map, group2_mean_map)

for ii in range(group2_num):
subjDiff = group2_data[:, ii] - overall_mean
group2_map[ii] = np.dot(diffVect, subjDiff)
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)
if normalize_projections:
group2_map[ii] = normalize(group2_map[ii], group1_mean_map, group2_mean_map)

group1_map_mean = group1_map.mean()
group2_map_mean = group2_map.mean()
Expand All @@ -142,8 +148,20 @@ def _project_and_pdf(diffVect, group1_data, group2_data, combined_data):
if group2_map_std < min_std:
group2_map_std = min_std

group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
if normalize_projections:
group1_x = np.linspace(group1_map_mean - 6, group1_map_mean + 6, num=300)
group2_x = np.linspace(group2_map_mean - 6, group2_map_mean + 6, num=300)
else:
# Common x-range covering both groups and all shape mappings so PDF
# tails extend smoothly across the full plot
all_maps = np.concatenate([group1_map, group2_map])
max_std = max(group1_map_std, group2_map_std)
x_min = min(all_maps.min(), group1_map_mean - 6 * group1_map_std,
group2_map_mean - 6 * group2_map_std) - max_std
x_max = max(all_maps.max(), group1_map_mean + 6 * group1_map_std,
group2_map_mean + 6 * group2_map_std) + max_std
group1_x = np.linspace(x_min, x_max, num=300)
group2_x = np.linspace(x_min, x_max, num=300)

group1_pdf = stats.norm.pdf(group1_x, group1_map_mean, group1_map_std)
group2_pdf = stats.norm.pdf(group2_x, group2_map_mean, group2_map_std)
Expand Down Expand Up @@ -172,7 +190,17 @@ def dwd_loadings(group1_data, group2_data):

diffVect = model.coef_.flatten()

return _project_and_pdf(diffVect, group1_data, group2_data, combined_data)
# Normalize to unit length so projections reflect data geometry, not solver scale
norm = np.linalg.norm(diffVect)
if norm > 1e-12:
diffVect = diffVect / norm

# DWD's direction optimizes for margin, not mean separation, so it may be
# nearly orthogonal to the mean difference. The mean-based normalization in
# _project_and_pdf divides by the projection of the mean difference onto
# diffVect, which can be near-zero, producing extreme values.
# Use raw projections with adaptive PDF ranges instead.
return _project_and_pdf(diffVect, group1_data, group2_data, combined_data, normalize_projections=False)


def lda(data):
Expand Down
6 changes: 5 additions & 1 deletion Studio/Job/StatsGroupDWDJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,15 @@ void StatsGroupDWDJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);

// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
double scatter_y = peak_pdf * 0.03;

auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
QVector<double> x, y;
for (int i = 0; i < map.size(); i++) {
x << map(i);
y << 0.01;
y << scatter_y;
}

int column_x = ds->addCopiedColumn(x, name + "scatter x");
Expand Down
6 changes: 5 additions & 1 deletion Studio/Job/StatsGroupLDAJob.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,11 +111,15 @@ void StatsGroupLDAJob::plot(JKQTPlotter* plot, QString group_1_name, QString gro
draw_line_plot(group1_x_, group1_pdf_, group_1_name, QColor(239, 133, 54));
draw_line_plot(group2_x_, group2_pdf_, group_2_name, Qt::blue);

// Place shape mapping dots near the bottom of the plot, scaled to peak PDF height
double peak_pdf = std::max(group1_pdf_.maxCoeff(), group2_pdf_.maxCoeff());
double scatter_y = peak_pdf * 0.03;

auto draw_scatter_plot = [&](Eigen::MatrixXd map, QString name, QColor color) {
QVector<double> x, y;
for (int i = 0; i < map.size(); i++) {
x << map(i);
y << 0.01;
y << scatter_y;
}

int column_x = ds->addCopiedColumn(x, name + "scatter x");
Expand Down
Loading