Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/register_config_test_elastic.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ coherent_point_drift:
w: 0.00001
lmd: 0.1
beta: 100
maxiter: 150 # Low for debugging purpose, should be 100-150
maxiter: 100 # Low for debugging purpose, should be 100-150

ILP:
min_neighbours: 10
Expand Down
2 changes: 1 addition & 1 deletion examples/register_config_test_rigid.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,6 @@ ILP:
min_neighbours: 10
max_dist: 30

mobie_export: True
mobie_export: False
semantic_seg: True
mobie_dataset_name: "platy1_muscles_stardist"
15 changes: 9 additions & 6 deletions matchmaker/cpd_nonrigid_registration.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys
import click
import logging
Expand All @@ -22,17 +23,17 @@ def cpd_from_images(fixed_img, fixed_resolution, moving_img, moving_resolution,
fixed_pcd = create_pcd(fixed_center_coords, fixed_labels)
moving_pcd = create_pcd(moving_center_coord, moving_labels)

overlay_pcds(fixed_pcd, moving_pcd, projection="xz", save_path = output_dir / "pcds_before_registration_xz.png")
overlay_pcds(fixed_pcd, moving_pcd, projection="yz", save_path = output_dir / "pcds_before_registration_yz.png")
overlay_pcds(fixed_pcd, moving_pcd, projection="xz", save_path = output_dir / "plots/pcds_before_registration_xz.png")
overlay_pcds(fixed_pcd, moving_pcd, projection="yz", save_path = output_dir / "plots/pcds_before_registration_yz.png")

logging.info(f"Point cloud registration with parameters w={w}, beta={beta}, lmd={lmd}, maxiter={maxiter}")

registered_pcd = run_cpd(fixed_pcd, moving_pcd, w, beta, lmd, maxiter)

overlay_pcds(fixed_pcd, registered_pcd, projection="xz", save_path = output_dir / "pcds_after_registration_xz.png")
overlay_pcds(fixed_pcd, registered_pcd, projection="yz", save_path = output_dir / "pcds_after_registration_yz.png")
overlay_pcds(fixed_pcd, registered_pcd, projection="xz", save_path = output_dir / "plots/pcds_after_registration_xz.png")
overlay_pcds(fixed_pcd, registered_pcd, projection="yz", save_path = output_dir / "plots/pcds_after_registration_yz.png")

visualize_displacement_field(moving_pcd, registered_pcd, projection="xz", save_path = output_dir / "displacement_field.png")
visualize_displacement_field(moving_pcd, registered_pcd, projection="xz", save_path = output_dir / "plots/displacement_field.pdf")

o3d.t.io.write_point_cloud(str(output_dir / "fixed_pcd.pcd"), fixed_pcd, write_ascii=True)
o3d.t.io.write_point_cloud(str(output_dir / "moving_pcd.pcd"), moving_pcd, write_ascii=True)
Expand All @@ -56,13 +57,15 @@ def main(fixed_path, fixed_key, moving_path, moving_key, output_dir, w, beta, lm
level=logging.INFO,
format="%(asctime)s [%(levelname)s] %(message)s",
handlers=[
logging.FileHandler(f"{output_dir}/rigid_alignment.log", mode="w"),
logging.FileHandler(f"{output_dir}/cpd_nonrigid_registration.log", mode="w"),
logging.StreamHandler(sys.stdout),
],
datefmt="%Y-%m-%d %H:%M:%S",
)

output_dir = Path(output_dir)
os.makedirs(output_dir / "plots", exist_ok=True)

logging.info("Reading fixed image")
fixed_img = read_volume(fixed_path, fixed_key)
fixed_resolution = get_attrs(fixed_path, fixed_key)["resolution"]
Expand Down
14 changes: 7 additions & 7 deletions matchmaker/elastix_deformable_pointset_registration.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
read_volume,
write_volume,
get_attrs,
plot_overlay,
itk_scalar_img,
itk_to_np_order,
apply_transform_chanwise,
Expand Down Expand Up @@ -114,7 +113,7 @@ def elastix_deformable_pointset_alignment(
plot_overlay(
itk_to_np_order(itk.GetArrayFromImage(fixed_img)),
itk_to_np_order(itk.GetArrayFromImage(moving_img)),
output_dir / f"deformable_pointset_alignment_before.png",
f"{output_dir}/plots/deformable_pointset_alignment_before.pdf",
)

SCRIPT_DIR = Path(__file__).resolve().parent
Expand All @@ -140,7 +139,7 @@ def elastix_deformable_pointset_alignment(
plot_overlay(
itk_to_np_order(itk.GetArrayFromImage(fixed_img)),
result_img_np,
output_dir / f"deformable_pointset_alignment_semantic.png",
f"{output_dir}/plots/deformable_pointset_alignment_semantic.pdf",
)

logging.info(f"Apply transform to all channels")
Expand All @@ -151,7 +150,7 @@ def elastix_deformable_pointset_alignment(
plot_overlay(
fixed_img_np,
result_img_np,
output_dir / f"deformable_pointset_alignment_final.png",
f"{output_dir}/plots/deformable_pointset_alignment_final.pdf",
)
logging.info(f"Result image shape {result_img_np.shape}")

Expand All @@ -163,8 +162,8 @@ def elastix_deformable_pointset_alignment(
transformed_grid_np = apply_transform_chanwise(
result_transform_parameters, grid_img_np.astype(np.float32), moving_resolution
)
plot_overlay(fixed_img_np, grid_img_np, output_dir / f"grid_before.png")
plot_overlay(fixed_img_np, transformed_grid_np, output_dir / f"grid_after.png")
plot_overlay(fixed_img_np, grid_img_np, f"{output_dir}/plots/grid_before.png")
plot_overlay(fixed_img_np, transformed_grid_np, f"{output_dir}/plots/grid_after.png")

return result_img_np

Expand Down Expand Up @@ -236,6 +235,7 @@ def main(
)

output_dir = Path(output_dir)
os.makedirs(output_dir / "plots", exist_ok=True)

logging.info("Reading fixed image")
fixed_img = read_volume(fixed_path, fixed_key)
Expand Down Expand Up @@ -282,7 +282,7 @@ def main(
plot_overlay(
prealigned_fixed,
prealigned_moving_aligned,
output_dir / f"deformable_pointset_alignment_prealigned.png",
f"{output_dir}/plots/deformable_pointset_alignment_prealigned.pdf",
)


Expand Down
7 changes: 5 additions & 2 deletions matchmaker/match_pointclouds.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import sys
import click
import logging
Expand Down Expand Up @@ -43,12 +44,12 @@ def match_points(fixed_pcd, registered_pcd, output_dir):

if swap_order:
plot_matching_qc(
pos_2, pos_1, output_dir / "point_matching.png", pairs=matched_idx_pairs
pos_2, pos_1, f"{output_dir}/plots/point_matching.pdf", pairs=matched_idx_pairs
)

else:
plot_matching_qc(
pos_1, pos_2, output_dir / "point_matching.png", pairs=matched_idx_pairs
pos_1, pos_2, f"{output_dir}/plots/point_matching.pdf", pairs=matched_idx_pairs
)

return matched_idx_pairs, matched_label_pairs
Expand Down Expand Up @@ -87,6 +88,8 @@ def main(fixed_pcd, moving_pcd, output_dir, min_neighbours, max_dist):
)

output_dir = Path(output_dir)
os.makedirs(output_dir / "plots", exist_ok=True)

logging.info("Reading fixed point cloud")
fixed_pcd = o3d.t.io.read_point_cloud(fixed_pcd)

Expand Down
46 changes: 36 additions & 10 deletions matchmaker/prealignment.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
write_transform_dict, plot_three_slices, plot_overlay, setup_logging,
get_axis_orient_matrix, resample_volume, transform_axes_vis)

from matchmaker.utils.vis import CYAN_HEX, PINK, CYAN, PINK_HEX


def get_SVD_transform(img, spacing, save_path=None):
"""Convert image to point cloud by thresholding, then run SVD on resulting point cloud.
Expand Down Expand Up @@ -76,24 +78,24 @@ def orient_axis(fixed_prealigned, moving_prealigned, output_dir):
int_prof_x_fixed = np.sum(fixed_prealigned > 0, axis=(1, 0)) / np.sum(fixed_prealigned > 0)

plt.figure()
plt.plot(int_prof_z, label="moving")
plt.plot(int_prof_z_fixed, label="fixed")
plt.plot(int_prof_z_fixed, label="fixed", color=PINK_HEX)
plt.plot(int_prof_z, label="moving", color=CYAN_HEX)
plt.xlabel(f"Axis Z Coordinate")
plt.ylabel(f"Sum intensity along axis = Z")
plt.legend()
plt.savefig(f"{output_dir}/plots/axis_int_profile_Z.png", dpi=300)

plt.figure()
plt.plot(int_prof_y, label="moving")
plt.plot(int_prof_y_fixed, label="fixed")
plt.plot(int_prof_y_fixed, label="fixed", color=PINK_HEX)
plt.plot(int_prof_y, label="moving", color=CYAN_HEX)
plt.xlabel(f"Axis Y Coordinate")
plt.ylabel(f"Sum intensity along axis = Y")
plt.legend()
plt.savefig(f"{output_dir}/plots/axis_int_profile_Y.png", dpi=300)

plt.figure()
plt.plot(int_prof_x, label="moving")
plt.plot(int_prof_x_fixed, label="fixed")
plt.plot(int_prof_x_fixed, label="fixed", color=PINK_HEX)
plt.plot(int_prof_x, label="moving", color=CYAN_HEX)
plt.xlabel(f"Axis X Coordinate")
plt.ylabel(f"Sum intensity along axis = X")
plt.legend()
Expand Down Expand Up @@ -313,22 +315,42 @@ def run_prealignment(

plot_three_slices(
fixed_img,
save_path=f"{output_dir}/plots/fixed_input.png",
save_path=f"{output_dir}/plots/fixed_input.pdf",
gc=gc_fixed,
Vt=Vt_fixed,
cmap="gnuplot2_r"
)

plot_three_slices(
moving_img,
save_path=f"{output_dir}/plots/moving_input.png",
save_path=f"{output_dir}/plots/moving_input.pdf",
gc=gc_moving,
Vt=Vt_moving,
cmap="gnuplot2_r"
)

plot_three_slices(
fixed_img,
save_path=f"{output_dir}/plots/fixed_input_semantic.pdf",
gc=gc_fixed,
Vt=Vt_fixed,
cmap=PINK,

)

plot_three_slices(
moving_img,
save_path=f"{output_dir}/plots/moving_input_semantic.pdf",
gc=gc_moving,
Vt=Vt_moving,
cmap=CYAN,

)

plot_overlay(
fixed_img,
moving_img,
save_path=f"{output_dir}/plots/overlay_input.png",
save_path=f"{output_dir}/plots/overlay_input.pdf",
gc1=gc_fixed,
Vt1=Vt_fixed,
gc2=gc_moving,
Expand Down Expand Up @@ -404,13 +426,17 @@ def run_prealignment(
save_path=f"{output_dir}/plots/fixed_prealigned.png",
gc = (np.linalg.inv(T_fixed) @ np.append(gc_fixed, 1))[:3],
Vt = transform_axes_vis(Vt_fixed, T_fixed),
cmap=PINK,

)

plot_three_slices(
moving_prealigned,
save_path=f"{output_dir}/plots/moving_prealigned.png",
save_path=f"{output_dir}/plots/moving_prealigned.pdf",
gc = (np.linalg.inv(T_moving) @ np.append(gc_moving, 1))[:3],
Vt = transform_axes_vis(Vt_moving, T_moving),
cmap=CYAN,

)

plot_overlay(
Expand Down
29 changes: 22 additions & 7 deletions matchmaker/raw_to_n5.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import tifffile as tif
from pathlib import Path

from matchmaker.utils import (read_volume, write_volume, get_attrs, set_attrs, plot_three_slices, convert_to_int)
from matchmaker.utils import (read_volume, write_volume, plot_three_slices, convert_to_int)


def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_res, z_res):
Expand All @@ -17,7 +17,7 @@ def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_

assert (image.ndim == 3) or (
image.ndim == 4
), "Currently pipeline only works with ZYX or CZYX images, input has {image.ndim} dimensions"
), f"Currently pipeline only works with ZYX or CZYX images, input has {image.ndim} dimensions"

image = convert_to_int(image)
logging.info(f"Writing output image to {output_path}")
Expand All @@ -28,11 +28,19 @@ def preprocess_tif_input(input_path, output_path, output_key, log_dir, x_res, y_
if image.ndim == 4:
chunks = (1, 128, 512, 512)
for chan in range(image.shape[0]):
plot_three_slices(image[chan], log_dir / f"input_image_{Path(input_path).stem}_{chan}.png")
plot_three_slices(
image[chan],
log_dir / f"input_image_{Path(input_path).stem}_{chan}.png",
cmap="gnuplot2_r",
)

else:
chunks = (128, 512, 512)
plot_three_slices(image, log_dir / f"input_image_{Path(input_path).stem}.png")
plot_three_slices(
image,
log_dir / f"input_image_{Path(input_path).stem}.png",
cmap="gnuplot2_r",
)

write_volume(output_path, image, output_key, chunks=chunks, attrs=attrs)

Expand All @@ -58,16 +66,23 @@ def preprocess_n5_input(input_path, input_key, output_path, output_key, log_dir,
if image.ndim == 4:
chunks = (1, 128, 512, 512)
for chan in range(image.shape[0]):
plot_three_slices(image[chan], log_dir / f"input_image_{Path(input_path).stem}_{chan}.png")
plot_three_slices(
image[chan],
log_dir / f"input_image_{Path(input_path).stem}_{chan}.png",
cmap="gnuplot2_r",
)

else:
chunks = (128, 512, 512)
plot_three_slices(image, log_dir / f"input_image_{Path(input_path).stem}.png")
plot_three_slices(
image,
save_path=log_dir / f"input_image_{Path(input_path).stem}.png",
cmap="gnuplot2_r",
)

write_volume(output_path, image, output_key, chunks=chunks, attrs=attrs)



@click.command()
@click.option("-in", "--input_path", required=True, help="Path of the input image in .tif format")
@click.option("-ink", "--input_key", required=False, default=None, help="Key of the input image in .n5 format")
Expand Down
2 changes: 1 addition & 1 deletion matchmaker/rigid_alignment_elastix.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ def elastix_segm_rigid_alignment(
plot_overlay(
itk_to_np_order(itk.GetArrayFromImage(fixed_img)),
result_img_np,
f"{output_dir}/plots/overlay_after_rigid_alignment.png",
f"{output_dir}/plots/overlay_after_rigid_alignment.pdf",
)

logging.info("Apply transform to all channels")
Expand Down
Loading