Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
117 changes: 56 additions & 61 deletions matchmaker/apply_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,22 +60,21 @@ def apply_transform(moving_img, moving_resolution, parameter_object, interpolati


@click.command()
@click.option("-mp", "--moving_paths", required=True, multiple=True, help="Paths to moving inputs")
@click.option("-mk", "--moving_keys", required=True, multiple=True, help="Keys of moving inputs")
@click.option("-mr", "--moving_resolutions", required=True, multiple=True, help="Resolutions of moving inputs")
@click.option("-op", "--output_paths", required=True, multiple=True, help="Paths to save warped images",)
@click.option("-ok", "--output_keys", required=True, multiple=True, help="Keys of moving outputs",)
@click.option("-io", "--interpolation_orders", required=True, multiple=True, help="Orders of interpolation",)
@click.option("-mp", "--moving_path", required=True, help="Path to moving input")
@click.option("-mk", "--moving_key", required=True, help="Key of moving input")
@click.option("-mr", "--moving_resolution", required=True, help="Resolution of moving input")
@click.option("-op", "--output_path", required=True, help="Path to save warped image")
@click.option("-ok", "--output_key", required=True, help="Key of moving output")
@click.option("-io", "--interpolation_order", required=True, help="Order of interpolation")
@click.option("-ld", "--log_dir", required=True, help="Log directory")
@click.option("-pm", "--parameter_map_path", required=True, help="Path to the parameter map",)
@click.option("-pt", "--prealignment_transform_path", default=None, help="Prealignment transform path",)
@click.option("-fi", "--fixed_path", default=None, help="Fixed input .n5 file")
@click.option("-fk", "--fixed_key", default=None, help="Fixed input key")
@click.option("-vb", "--verbose", is_flag=True, default=False, help="Show verbose logs")
def apply_transforms(moving_paths, moving_keys, moving_resolutions, output_paths, output_keys,
interpolation_orders, log_dir, parameter_map_path, prealignment_transform_path,
fixed_path, fixed_key, verbose):
assert len(moving_paths) == len(moving_keys) == len(moving_resolutions) == len(output_paths) == len(output_keys) == len(interpolation_orders)
def apply_transforms(moving_path, moving_key, moving_resolution, output_path, output_key,
interpolation_order, log_dir, parameter_map_path,
prealignment_transform_path, fixed_path, fixed_key, verbose):
log_dir = Path(log_dir)
log_dir.mkdir(exist_ok=True)

Expand All @@ -102,61 +101,57 @@ def apply_transforms(moving_paths, moving_keys, moving_resolutions, output_paths
logging.info("Rotate fixed image using prealignment transform")
fixed_prealigned = rotate_img(fixed_img, T_fixed, output_shape=output_shape)

for i, (moving_path, moving_key) in enumerate(zip(moving_paths, moving_keys)):
print("\n")
logging.info(f"[{i+1}/{len(moving_paths)}] Start processing moving image: {moving_path}")
moving_resolution = json.loads(moving_resolutions[i])
output_path, output_key = output_paths[i], output_keys[i]
interpolation_order = interpolation_orders[i]

moving_name = Path(moving_path).stem
logging.info("Read moving image")
moving_img = load_data(moving_path, moving_key)
if moving_path.endswith(".n5"):
if list(moving_resolution) != get_attrs(moving_path, moving_key)["resolution"]:
raise ValueError("Moving resolution from config is different from n5 file")

if moving_img.ndim == 3:
moving_img = moving_img[None, ...]
chunks = (128, 512, 512)
else:
chunks = (1, 128, 512, 512)

logging.info("Start transformation")
warped, warp_prealigned = apply_transform(moving_img,
moving_resolution,
parameter_object,
interpolation_order,
T_fixed=T_fixed,
output_shape=output_shape)

resolution = [float(res) for res in parameter_object.GetParameter(0, "Spacing")]

save_attrs = {}
if moving_path.endswith(".n5"):
attributes = dict(get_attrs(moving_path, moving_key))
attributes["resolution"] = resolution
save_attrs["chunks"] = chunks
save_attrs["attrs"] = attributes

logging.info("Plot warped image")
plot_three_slices(warped, save_path=log_dir / f"{moving_name}_warped.png")
if fixed_path:
logging.info("Plot overlay image")
plot_overlay(fixed_img, warped, log_dir / f"{moving_name}_warped_overlay.png",)
logging.info(f"Start processing moving image: {moving_path}")
moving_resolution = json.loads(moving_resolution)

if T_fixed is not None:
logging.info("Plot warped moving image after pre-alignment")
plot_three_slices(warp_prealigned, save_path=log_dir / f"{moving_name}_warp_prealigned.png")
moving_name = Path(moving_path).stem
logging.info("Read moving image")
moving_img = load_data(moving_path, moving_key)
if moving_path.endswith(".n5"):
if list(moving_resolution) != get_attrs(moving_path, moving_key)["resolution"]:
raise ValueError("Moving resolution from config is different from n5 file")

if fixed_path:
logging.info("Plot overlay image after pre-alignment")
plot_overlay(fixed_prealigned, warp_prealigned, log_dir / f"{moving_name}_warp_prealigned_overlay.png",)
if moving_img.ndim == 3:
moving_img = moving_img[None, ...]
chunks = (128, 512, 512)
else:
chunks = (1, 128, 512, 512)

logging.info("Start transformation")
warped, warp_prealigned = apply_transform(moving_img,
moving_resolution,
parameter_object,
interpolation_order,
T_fixed=T_fixed,
output_shape=output_shape)

resolution = [float(res) for res in parameter_object.GetParameter(0, "Spacing")]

save_attrs = {}
if moving_path.endswith(".n5"):
attributes = dict(get_attrs(moving_path, moving_key))
attributes["resolution"] = resolution
save_attrs["chunks"] = chunks
save_attrs["attrs"] = attributes

logging.info("Plot warped image")
plot_three_slices(warped, save_path=log_dir / f"{moving_name}_warped.png")
if fixed_path:
logging.info("Plot overlay image")
plot_overlay(fixed_img, warped, log_dir / f"{moving_name}_warped_overlay.png",)

save_data(warp_prealigned, output_path, output_key=output_key, **save_attrs)
if T_fixed is not None:
logging.info("Plot warped moving image after pre-alignment")
plot_three_slices(warp_prealigned, save_path=log_dir / f"{moving_name}_warp_prealigned.png")

else:
save_data(warped, output_path, output_key=output_key, **save_attrs)
if fixed_path:
logging.info("Plot overlay image after pre-alignment")
plot_overlay(fixed_prealigned, warp_prealigned, log_dir / f"{moving_name}_warp_prealigned_overlay.png",)

save_data(warp_prealigned, output_path, output_key=output_key, **save_attrs)

else:
save_data(warped, output_path, output_key=output_key, **save_attrs)


if __name__ == "__main__":
Expand Down
78 changes: 57 additions & 21 deletions workflows/apply_transform.smk
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import re

configfile: "examples/register_config_test_rigid_apply_transform.yaml"

Expand All @@ -18,9 +19,9 @@ prealignment_transform_path = config["prealignment_transform_path"]
fixed_path = config["fixed_image"]["input_path"]
fixed_key = config["fixed_image"]["input_key"]


def expand_flag(flag, values):
return " ".join(f"{flag} {v}" for v in values)
TARGET_OUTPUTS = [f"{p}/{k}" if p.endswith(".n5") else p for p, k in zip(output_paths, output_keys)]
FILE_OUTPUTS = [p for p in TARGET_OUTPUTS if not ".n5/" in p]
N5_OUTPUTS = [p for p in TARGET_OUTPUTS if ".n5/" in p]


def get_all_opts(d):
Expand All @@ -33,41 +34,76 @@ def get_all_opts(d):

rule all:
input:
output_paths,
TARGET_OUTPUTS,


rule apply_transform:
rule apply_transform_file:
input:
parameter_map_path = parameter_map_path,
moving_img = lambda w: moving_paths[TARGET_OUTPUTS.index(w.out_file)]
output:
[directory(f"{path}/{key}") if path.endswith(".n5") else path for path, key in zip(output_paths, output_keys)],
out_file = "{out_file}"
wildcard_constraints:
out_file = "^(" + "|".join(re.escape(p) for p in FILE_OUTPUTS) + ")$" if FILE_OUTPUTS else "$^"
params:
opts = lambda w: get_all_opts({
"prealignment_transform_path": prealignment_transform_path,
"fixed_path": fixed_path,
"fixed_key": fixed_key,
}),
moving_path = lambda w: moving_paths[TARGET_OUTPUTS.index(w.out_file)],
moving_key = lambda w: moving_keys[TARGET_OUTPUTS.index(w.out_file)],
moving_resolution = lambda w: moving_resolutions[TARGET_OUTPUTS.index(w.out_file)],
output_path = lambda w: output_paths[TARGET_OUTPUTS.index(w.out_file)],
output_key = lambda w: output_keys[TARGET_OUTPUTS.index(w.out_file)],
interpolation_order = lambda w: interpolation_orders[TARGET_OUTPUTS.index(w.out_file)],
conda: "matchmaker_env"
shell:
"""
python matchmaker/apply_transform.py \
{params.opts} \
--moving_path {params.moving_path} \
--moving_key {params.moving_key} \
--moving_resolution '{params.moving_resolution}' \
--output_path {params.output_path} \
--output_key {params.output_key} \
--interpolation_order {params.interpolation_order} \
--log_dir {log_dir} \
--parameter_map_path {input.parameter_map_path}
"""

moving_paths = expand_flag("--moving_paths", moving_paths),
moving_keys = expand_flag("--moving_keys", moving_keys),
moving_resolutions = expand_flag("--moving_resolutions", moving_resolutions),
output_paths = expand_flag("--output_paths", output_paths),
output_keys = expand_flag("--output_keys", output_keys),
interpolation_orders = expand_flag("--interpolation_orders", interpolation_orders),
log_dir = log_dir,

log: f"{log_dir}/matchmaker.log"
rule apply_transform_n5:
input:
parameter_map_path = parameter_map_path,
moving_img = lambda w: moving_paths[TARGET_OUTPUTS.index(w.out_dir)]
output:
out_dir = directory("{out_dir}")
wildcard_constraints:
out_dir = "^(" + "|".join(re.escape(p) for p in N5_OUTPUTS) + ")$" if N5_OUTPUTS else "$^"
params:
opts = lambda w: get_all_opts({
"prealignment_transform_path": prealignment_transform_path,
"fixed_path": fixed_path,
"fixed_key": fixed_key,
}),
moving_path = lambda w: moving_paths[TARGET_OUTPUTS.index(w.out_dir)],
moving_key = lambda w: moving_keys[TARGET_OUTPUTS.index(w.out_dir)],
moving_resolution = lambda w: moving_resolutions[TARGET_OUTPUTS.index(w.out_dir)],
output_path = lambda w: output_paths[TARGET_OUTPUTS.index(w.out_dir)],
output_key = lambda w: output_keys[TARGET_OUTPUTS.index(w.out_dir)],
interpolation_order = lambda w: interpolation_orders[TARGET_OUTPUTS.index(w.out_dir)],
conda: "matchmaker_env"
shell:
"""
python matchmaker/apply_transform.py \
{params.opts} \
{params.moving_paths} \
{params.moving_keys} \
{params.moving_resolutions} \
{params.output_paths} \
{params.output_keys} \
{params.interpolation_orders} \
--log_dir {params.log_dir} \
--moving_path {params.moving_path} \
--moving_key {params.moving_key} \
--moving_resolution '{params.moving_resolution}' \
--output_path {params.output_path} \
--output_key {params.output_key} \
--interpolation_order {params.interpolation_order} \
--log_dir {log_dir} \
--parameter_map_path {input.parameter_map_path}
"""