diff --git a/src/eaa/image_proc.py b/src/eaa/image_proc.py index a18be28..6962dfd 100644 --- a/src/eaa/image_proc.py +++ b/src/eaa/image_proc.py @@ -5,6 +5,7 @@ import matplotlib.pyplot as plt import scipy.ndimage as ndi from scipy import optimize +from scipy.special import erf from skimage.metrics import normalized_mutual_information from skimage.registration import phase_cross_correlation as skimage_phase_cross_correlation from sciagent.message_proc import generate_openai_message @@ -117,10 +118,30 @@ def add_frame_to_pil(image: Image.Image) -> Image.Image: return buffer +def _gaussian_rect_window(shape: tuple[int, int], decay_fraction: float = 0.2) -> np.ndarray: + """2D Gaussian-softened rectangle window. + + The mask is 1 in the center and decays smoothly to 0 at each edge. + The decay is the convolution of a step function with a Gaussian, implemented + via the error function. The transition from 1 to 0 spans ``decay_fraction`` + of the image size on each side. + """ + def _win_1d(size: int) -> np.ndarray: + decay_len = decay_fraction * size + sigma = decay_len / 4.0 + s2 = sigma * np.sqrt(2.0) + x = np.arange(size, dtype=float) + left = 0.5 * (1.0 + erf((x - decay_len / 2.0) / s2)) + right = 0.5 * (1.0 - erf((x - (size - 1.0 - decay_len / 2.0)) / s2)) + return np.minimum(left, right) + + return np.outer(_win_1d(shape[0]), _win_1d(shape[1])) + + def phase_cross_correlation( - moving: np.ndarray, - ref: np.ndarray, - use_hanning_window: bool = True, + moving: np.ndarray, + ref: np.ndarray, + filtering_method: Optional[Literal["hanning", "gaussian"]] = "hanning", upsample_factor: int = 1, ) -> np.ndarray | Tuple[np.ndarray, float]: """Phase correlation with windowing. The result gives @@ -133,9 +154,12 @@ def phase_cross_correlation( A 2D image. ref : np.ndarray A 2D image. - use_hanning_window : bool, optional - If True, a Hanning window is used to smooth the images before the - correlation is computed. 
+ filtering_method : {"hanning", "gaussian"} or None, optional + Window function applied to both images before phase correlation to + reduce spectral leakage. ``"hanning"`` uses a standard Hanning window. + ``"gaussian"`` uses a Gaussian-softened rectangle that is 1 in the + centre and decays to 0 at each edge over 20 % of the image size. + Pass ``None`` to disable windowing. upsample_factor : int, optional Upsampling factor for subpixel accuracy in phase correlation. A value of 1 yields pixel-level precision. @@ -150,15 +174,24 @@ def phase_cross_correlation( ) moving = moving - moving.mean() ref = ref - ref.mean() - if use_hanning_window: + if filtering_method == "hanning": win_y = np.hanning(moving.shape[0]) win_x = np.hanning(moving.shape[1]) win = np.outer(win_y, win_x) moving_for_registration = moving * win ref_for_registration = ref * win - else: + elif filtering_method == "gaussian": + win = _gaussian_rect_window(moving.shape, decay_fraction=0.2) + moving_for_registration = moving * win + ref_for_registration = ref * win + elif filtering_method is None: moving_for_registration = moving ref_for_registration = ref + else: + raise ValueError( + f"Unknown filtering_method {filtering_method!r}. " + "Use 'hanning', 'gaussian', or None." + ) shift, _, _ = skimage_phase_cross_correlation( ref_for_registration, @@ -168,6 +201,124 @@ def phase_cross_correlation( return shift +def error_minimization_registration( + moving: np.ndarray, + ref: np.ndarray, + y_valid_fraction: float = 0.8, + x_valid_fraction: float = 0.8, + subpixel: bool = True, +) -> np.ndarray: + """Image registration by exhaustive integer-shift MSE search with quadratic + subpixel refinement. + + A central window of size ``(y_valid_fraction * h, x_valid_fraction * w)`` + is fixed in the reference image. 
The moving image is sampled at the same + window position for every integer shift (dy, dx) within the margins + ``[-max_dy, max_dy] × [-max_dx, max_dx]``, where the margins are the pixel + gaps between the valid window and the image boundary. No wrap-around pixels + are ever included: the valid window is identical for all shifts. + + The resulting 2-D MSE map is fitted with a 2-D quadratic polynomial. The + analytic minimum of that polynomial is returned as the sub-pixel shift. + + Parameters + ---------- + moving : np.ndarray + 2-D image to register. + ref : np.ndarray + 2-D reference image with the same shape as *moving*. + y_valid_fraction : float + Fraction of the image height occupied by the comparison window. + Values close to 1 leave little margin and therefore a small search range. + x_valid_fraction : float + Same as *y_valid_fraction* along the x (column) axis. + subpixel : bool + If True, perform subpixel refinement using a 2D quadratic fit. + + Returns + ------- + np.ndarray + Estimated (dy, dx) shift to apply to *moving* so that it aligns with + *ref*. + """ + assert moving.shape == ref.shape, ( + "The shapes of the moving and reference images must be the same." 
+ ) + h, w = ref.shape + + vh = int(round(y_valid_fraction * h)) + vw = int(round(x_valid_fraction * w)) + + # Centre the valid window; margin on each side = max search range + r0 = (h - vh) // 2 + c0 = (w - vw) // 2 + r1, c1 = r0 + vh, c0 + vw + max_dy, max_dx = r0, c0 + + if max_dy == 0 and max_dx == 0: + return np.zeros(2) + + dy_vals = np.arange(-max_dy, max_dy + 1) + dx_vals = np.arange(-max_dx, max_dx + 1) + + ref_crop = ref[r0:r1, c0:c1].astype(float) + moving_f = moving.astype(float) + + # Exhaustive integer-shift MSE map + error_map = np.empty((len(dy_vals), len(dx_vals))) + for i, dy in enumerate(dy_vals): + for j, dx in enumerate(dx_vals): + diff = moving_f[r0 + dy : r1 + dy, c0 + dx : c1 + dx] - ref_crop + error_map[i, j] = np.mean(diff * diff) + + # Fit quadratic in a local neighbourhood around the integer minimum. + # Neighbourhood half-width: 10% of image size / 2, at least 1 (→ 3×3 minimum). + min_i, min_j = np.unravel_index(np.argmin(error_map), error_map.shape) + if not subpixel: + return -np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + + half_y = max(1, int(round(0.05 * h))) + half_x = max(1, int(round(0.05 * w))) + i_lo = max(0, min_i - half_y) + i_hi = min(len(dy_vals) - 1, min_i + half_y) + j_lo = max(0, min_j - half_x) + j_hi = min(len(dx_vals) - 1, min_j + half_x) + local_dy = dy_vals[i_lo : i_hi + 1] + local_dx = dx_vals[j_lo : j_hi + 1] + local_err = error_map[i_lo : i_hi + 1, j_lo : j_hi + 1] + + # The 2-D quadratic has 6 parameters; require ≥3 points in each dimension so + # the design matrix is well-determined and the Hessian is not rank-deficient. 
+ if len(local_dy) >= 3 and len(local_dx) >= 3: + # Fit: f(dy, dx) = a*dy² + b*dx² + c*dy*dx + d*dy + e*dx + g + dy_mesh, dx_mesh = np.meshgrid(local_dy, local_dx, indexing="ij") + dy_f = dy_mesh.ravel() + dx_f = dx_mesh.ravel() + design = np.column_stack( + [dy_f**2, dx_f**2, dy_f * dx_f, dy_f, dx_f, np.ones(len(dy_f))] + ) + coeffs, _, _, _ = np.linalg.lstsq(design, local_err.ravel(), rcond=None) + a, b, c, d, e, _ = coeffs + + # Analytic minimum: solve Hessian @ [dy_min, dx_min]ᵀ = -gradient + # Hessian = [[2a, c], [c, 2b]]; gradient at origin = [d, e] + hess = np.array([[2.0 * a, c], [c, 2.0 * b]]) + try: + if np.all(np.linalg.eigvalsh(hess) > 0): + shift = np.linalg.solve(hess, np.array([-d, -e])) + else: + raise np.linalg.LinAlgError("Hessian not positive definite") + except np.linalg.LinAlgError: + shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + else: + shift = np.array([float(dy_vals[min_i]), float(dx_vals[min_j])]) + + # Negate: the MSE is minimised at the offset where moving[r0+dy:] matches + # ref[r0:], but the caller wants the shift to apply to moving so that + # roll(moving, shift) ≈ ref, which is the opposite direction. + return -shift + + def normalize_image_01(image: np.ndarray) -> np.ndarray: """Normalize image intensities to [0, 1].""" image = np.nan_to_num(image, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32) diff --git a/src/eaa/maths.py b/src/eaa/maths.py index c798aba..922dd0f 100644 --- a/src/eaa/maths.py +++ b/src/eaa/maths.py @@ -1,6 +1,7 @@ import logging import numpy as np +import scipy.ndimage import scipy.optimize logger = logging.getLogger(__name__) @@ -35,7 +36,7 @@ def fit_gaussian_1d( y: np.ndarray, y_threshold: float = 0, ) -> tuple[float, float, float, float, float, float, float]: - """Fit a 1D Gaussian to the data after subtracting a linear background. + """Fit a 1D Gaussian to 1D data. 
Parameters ---------- @@ -58,31 +59,111 @@ def fit_gaussian_1d( """ x_data = np.array(x, dtype=float) y_data = np.array(y, dtype=float) + finite_mask = np.isfinite(x_data) & np.isfinite(y_data) + x_data = x_data[finite_mask] + y_data = y_data[finite_mask] + if x_data.size < 5: + logger.error("Too few finite data points for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, np.nan, np.nan + + order = np.argsort(x_data) + x_data = x_data[order] + y_data = y_data[order] x_min = float(np.min(x_data)) x_max_input = float(np.max(x_data)) + x_span = x_max_input - x_min + if not np.isfinite(x_span) or x_span <= 0: + logger.error("Invalid x range for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + y_max, y_min = np.max(y_data), np.min(y_data) - x_max = x_data[np.argmax(y_data)] - offset = x_max + y_range = float(y_max - y_min) + if not np.isfinite(y_range) or np.isclose(y_range, 0.0): + logger.error("Input data are too flat for Gaussian fitting. Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + + smooth_sigma = max(1.0, x_data.size * 0.02) + y_smooth = scipy.ndimage.gaussian_filter1d(y_data, sigma=smooth_sigma, mode="nearest") + y_smooth_max = float(np.max(y_smooth)) + y_smooth_min = float(np.min(y_smooth)) + y_smooth_range = y_smooth_max - y_smooth_min + if not np.isfinite(y_smooth_range) or np.isclose(y_smooth_range, 0.0): + logger.error("Smoothed data are too flat for Gaussian fitting. 
Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input + + x_peak = float(x_data[np.argmax(y_smooth)]) + offset = x_peak x = x_data - offset - x_max = 0 - mask = y_data >= y_min + y_threshold * (y_max - y_min) - a_guess = y_max - y_min - mu_guess = x_max - x_above_thresh = x[y_data > y_min + a_guess * 0.2] - if len(x_above_thresh) >= 3: - sigma_guess = (x_above_thresh.max() - x_above_thresh.min()) / 2 + fit_threshold = y_smooth_min + y_threshold * y_smooth_range + mask = y_smooth >= fit_threshold + if int(np.count_nonzero(mask)) < 5: + mask = np.ones_like(y_data, dtype=bool) + + positive_weight = np.clip(y_smooth - y_smooth_min, a_min=0.0, a_max=None) + if np.sum(positive_weight) > 0: + mu_guess = float(np.sum(x * positive_weight) / np.sum(positive_weight)) else: - sigma_guess = (x.max() - x.min()) / 2 - c_guess = y_min + mu_guess = 0.0 + + width_mask = y_smooth >= (y_smooth_min + 0.5 * y_smooth_range) + x_above_half = x[width_mask] + if x_above_half.size >= 2: + sigma_guess = float((x_above_half.max() - x_above_half.min()) / 2.355) + else: + sigma_guess = x_span / 6.0 + sigma_guess = float(np.clip(sigma_guess, x_span / 100.0, x_span)) + + a_guess = max(y_smooth_range, np.finfo(float).eps) + c_guess = float(np.median(y_data[~mask])) if np.any(~mask) else float(y_smooth_min) p0 = [a_guess, mu_guess, sigma_guess, c_guess] - try: - popt, _ = scipy.optimize.curve_fit(gaussian_1d, x[mask], y_data[mask], p0=p0) - except RuntimeError: + lower_bounds = [0.0, float(np.min(x)), max(x_span / 1000.0, 1e-12), float(y_min - y_range)] + upper_bounds = [float(2 * y_range + abs(y_max)), float(np.max(x)), float(2 * x_span), float(y_max + y_range)] + + def run_fit(current_mask: np.ndarray, current_p0: list[float]) -> np.ndarray | None: + if int(np.count_nonzero(current_mask)) < 5: + return None + try: + popt, _ = scipy.optimize.curve_fit( + gaussian_1d, + x[current_mask], + y_data[current_mask], + p0=current_p0, + bounds=(lower_bounds, 
upper_bounds), + maxfev=20000, + ) + return popt + except (RuntimeError, ValueError): + return None + + popt = run_fit(mask, p0) + if popt is None: + y_smooth_retry = scipy.ndimage.gaussian_filter1d( + y_data, + sigma=max(2.0, x_data.size * 0.05), + mode="nearest", + ) + retry_weight = np.clip(y_smooth_retry - np.min(y_smooth_retry), a_min=0.0, a_max=None) + if np.sum(retry_weight) > 0: + mu_retry = float(np.sum(x * retry_weight) / np.sum(retry_weight)) + else: + mu_retry = 0.0 + retry_mask = y_smooth_retry >= ( + np.min(y_smooth_retry) + max(0.1, y_threshold) * (np.max(y_smooth_retry) - np.min(y_smooth_retry)) + ) + retry_p0 = [a_guess, mu_retry, sigma_guess, c_guess] + popt = run_fit(retry_mask, retry_p0) + + if popt is None: logger.error("Failed to fit Gaussian to data. Returning NaN values.") return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input y_fit = gaussian_1d(x, *popt) amplitude = float(popt[0]) + sigma = float(popt[2]) + mu = float(popt[1]) + if sigma <= 0 or not (float(np.min(x)) <= mu <= float(np.max(x))): + logger.error("Gaussian fit parameters are invalid. 
Returning NaN values.") + return np.nan, np.nan, np.nan, np.nan, np.nan, x_min, x_max_input if np.isclose(amplitude, 0.0): normalized_residual = np.nan else: diff --git a/src/eaa/task_manager/imaging/analytical_feature_tracking.py b/src/eaa/task_manager/imaging/analytical_feature_tracking.py index 50b7947..ad3da6e 100644 --- a/src/eaa/task_manager/imaging/analytical_feature_tracking.py +++ b/src/eaa/task_manager/imaging/analytical_feature_tracking.py @@ -188,7 +188,7 @@ def run( reference_image, psize_t=self.image_acquisition_tool.psize_k, psize_r=self.image_registration_tool.reference_pixel_size, - registration_algorithm_kwargs={"use_hanning_window": True}, + registration_algorithm_kwargs={"filtering_method": "hanning"}, ) if check_feature_presence_llm( task_manager=self, diff --git a/src/eaa/task_manager/tuning/analytical_focusing.py b/src/eaa/task_manager/tuning/analytical_focusing.py index 2b58ed6..e298829 100644 --- a/src/eaa/task_manager/tuning/analytical_focusing.py +++ b/src/eaa/task_manager/tuning/analytical_focusing.py @@ -18,7 +18,10 @@ from sciagent.message_proc import print_message from eaa.tool.imaging.acquisition import AcquireImage -from eaa.tool.imaging.line_scan_predictor import LineScanPredictor +from eaa.tool.imaging.aps_mic.test_target_landmark_fitting import ( + TestPatternLandmarkFitting, +) +from eaa.tool.imaging.nn_registration import NNRegistration from eaa.tool.imaging.param_tuning import SetParameters from eaa.task_manager.tuning.base import BaseParameterTuningTaskManager from eaa.tool.imaging.registration import ImageRegistration @@ -32,6 +35,9 @@ logger = logging.getLogger(__name__) +RegistrationToolType = ImageRegistration | TestPatternLandmarkFitting | NNRegistration + + class LineScanValidationFailed(RuntimeError): pass @@ -57,13 +63,12 @@ def __init__( line_scan_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), image_acquisition_tool_x_coordinate_args: Tuple[str, ...] 
= ("x_center",), image_acquisition_tool_y_coordinate_args: Tuple[str, ...] = ("y_center",), - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "phase_correlation", - registration_algorithm_kwargs: Optional[dict[str, Any]] = None, + registration_target: Literal["previous", "initial"] = "previous", run_line_scan_checker: bool = True, run_offset_calibration: bool = True, - use_linear_drift_prediction: bool = False, - n_parameter_drift_points_before_prediction: int = 3, - line_scan_predictor_tool: Optional[LineScanPredictor] = None, + registration_tools: Optional[list[RegistrationToolType]] = None, + registration_selection_priming_iterations: int = 3, + primary_registration_tool_index: int = 0, *args, **kwargs ): """Analytical scanning microscope focusing task manager driven @@ -72,9 +77,12 @@ def __init__( The workflow is as follows: 1. Acquire a 2D image in the user-specified region of interest. 2. Run a line scan at user-specified coordinates and record the FWHM of the Gaussian fit. - 3. Change parameter and acquire a new 2D image. - 4. Run image registration to get the offset and adjust 1D/2D scan coordinates. - 5. Repeat 1 - 3 a few times to collect initial data for Bayesian optimization. + 3. Change parameter and acquire a new 2D image. The change of parameter causes the sample + to drift relative to the beam. + 4. Register the acquired image with the reference image (previous or initial) to estimate + the drift correction that should be applied to the line scan and image acquisition tools. + Update the positions for line scan and image acquisition. + 5. Repeat 1 - 4 a few times to collect initial data for Bayesian optimization. 6. Use Bayesian optimization to suggest new parameters. 7. Change parameter. 8. Run image registration or feature tracking as in 4. @@ -122,9 +130,14 @@ def __init__( See `line_scan_tool_x_coordinate_args`. image_acquisition_tool_y_coordinate_args: Tuple[str, ...] 
See `line_scan_tool_y_coordinate_args`. - registration_algorithm_kwargs : Optional[dict[str, Any]] - Optional keyword arguments forwarded to the selected image - registration algorithm when aligning consecutive 2D scans. + registration_target : Literal["previous", "initial"], optional + The reference image used by the registration branch of drift + correction. "previous" (default) registers each new 2D scan + against the immediately preceding one; small registration errors + therefore accumulate over many iterations. "initial" registers + every new 2D scan against the very first scan, which prevents + error accumulation at the cost of requiring sufficient overlap + between the current and the initial image. run_line_scan_checker : bool, optional If True, run the LLM-based line-scan quality checker and allow it to request scan-argument adjustments before accepting a line scan. @@ -132,22 +145,22 @@ def __init__( If True, run 2D image acquisition and image-registration-based offset calibration. If False, the loop only performs parameter setting, line scan, and optimization updates/suggestions. - use_linear_drift_prediction : bool, optional - If True, fit linear models to predict image-acquisition drift - (y and x separately) as a function of optics parameters and use - the predicted positions for subsequent 2D scans once enough - parameter-drift samples have been collected. - n_parameter_drift_points_before_prediction : int, optional - Number of parameter-drift samples to collect before using linear - drift prediction for image acquisitions. - line_scan_predictor_tool : LineScanPredictor, optional - If provided, this tool is used instead of image registration to - update scan positions after each 2D acquisition. It predicts the - optimal line scan center from the reference image, reference line - scan position, and current image, then shifts both line scan and - image acquisition kwargs by the predicted drift so they stay in - sync. 
Requires ``run_offset_calibration=True`` so that a 2D image - is acquired before the prediction is made. + registration_tools : list[RegistrationToolType], optional + Registration tools to use for drift estimation. Each tool must + provide ``get_offset(target=...)`` returning the shift to apply to + the current image so it aligns with the selected reference image. + Supported tool classes are :class:`ImageRegistration`, + :class:`TestPatternLandmarkFitting`, and :class:`NNRegistration`. + The caller is responsible for instantiating these tools with the + desired configuration. + registration_selection_priming_iterations : int, optional + Number of parameter-drift samples to collect before the linear + model is used to arbitrate among multiple registration tools. + During this warm-up phase the tool at + ``primary_registration_tool_index`` is used. + primary_registration_tool_index : int, optional + Index of the registration tool to trust during the warm-up phase + before the linear drift model has enough samples to arbitrate. """ if acquisition_tool is None: raise ValueError("`acquisition_tool` must be provided.") @@ -157,14 +170,11 @@ def __init__( self.optimization_tool = self.create_bo_tool(parameter_ranges) else: self.optimization_tool = optimization_tool - self.image_registration_tool = self.create_image_registration_tool( - acquisition_tool, - llm_config=llm_config, - registration_method=registration_method, - ) - self.registration_algorithm_kwargs = copy.deepcopy( - registration_algorithm_kwargs or {} - ) + if registration_target not in ("previous", "initial"): + raise ValueError( + f"`registration_target` must be 'previous' or 'initial', got {registration_target!r}." 
+ ) + self.registration_target = registration_target if hasattr(acquisition_tool, "line_scan_return_gaussian_fit"): acquisition_tool.line_scan_return_gaussian_fit = True @@ -186,24 +196,47 @@ def __init__( self.run_line_scan_checker = run_line_scan_checker self.run_offset_calibration = run_offset_calibration - self.use_linear_drift_prediction = use_linear_drift_prediction - if line_scan_predictor_tool is not None and not run_offset_calibration: + self.registration_tools = list(registration_tools or []) + if self.run_offset_calibration and len(self.registration_tools) == 0: raise ValueError( - "`line_scan_predictor_tool` requires `run_offset_calibration=True` " - "because a 2D image must be acquired before the predictor can run." + "`registration_tools` must be provided when `run_offset_calibration=True`." ) - - self.line_scan_predictor_tool = line_scan_predictor_tool - self.n_parameter_drift_points_before_prediction = ( - n_parameter_drift_points_before_prediction + if len(self.registration_tools) > 1 and not run_offset_calibration: + raise ValueError( + "Multiple `registration_tools` require `run_offset_calibration=True` " + "because a 2D image must be acquired before the tool can run." + ) + for registration_tool in self.registration_tools: + if not hasattr(registration_tool, "get_offset"): + raise ValueError( + "Each registration tool must provide a `get_offset(target=...)` method, " + f"got {type(registration_tool).__name__}." 
+ ) + if hasattr(registration_tool, "image_acquisition_tool") and getattr( + registration_tool, + "image_acquisition_tool", + ) is None: + registration_tool.image_acquisition_tool = acquisition_tool + if isinstance(registration_tool, ImageRegistration): + registration_tool.llm_config = llm_config + registration_tool.memory_config = memory_config + + self.registration_selection_priming_iterations = ( + registration_selection_priming_iterations ) - if self.n_parameter_drift_points_before_prediction < 1: + self.primary_registration_tool_index = primary_registration_tool_index + if self.registration_selection_priming_iterations < 1: raise ValueError( - "`n_parameter_drift_points_before_prediction` must be >= 1." + "`registration_selection_priming_iterations` must be >= 1." + ) + if not (0 <= self.primary_registration_tool_index < len(self.registration_tools)): + raise ValueError( + "`primary_registration_tool_index` is out of range for `registration_tools`." ) self.drift_model_y = MultivariateLinearRegression() self.drift_model_x = MultivariateLinearRegression() self.initial_image_acquisition_position: np.ndarray | None = None + self.initial_line_scan_position: np.ndarray | None = None super().__init__( llm_config=llm_config, @@ -226,23 +259,6 @@ def create_bo_tool(self, parameter_ranges: list[tuple[float, ...], tuple[float, ) return bo_tool - def create_image_registration_tool( - self, - acquisition_tool: AcquireImage, - llm_config: Optional[LLMConfig] = None, - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "llm", - ): - image_registration_tool = ImageRegistration( - image_acquisition_tool=acquisition_tool, - llm_config=llm_config, - reference_image=None, - reference_pixel_size=1.0, - image_coordinates_origin="top_left", - registration_method=registration_method, - log_scale=True - ) - return image_registration_tool - def prerun_check( self, initial_sampling_range: Optional[Tuple[float, float]], @@ -316,7 +332,7 @@ def run( 
self.run_line_scan() # Initialize optimization tool. - self.collect_initial_data_optimization_tool( + self.collect_initial_data_for_optimization_tool( current_x=np.array(list(self.initial_parameters.values())), sampling_range=initial_sampling_window_size, n=n_initial_points, @@ -376,6 +392,8 @@ def initialize_kwargs_buffers( ): self.line_scan_kwargs = copy.deepcopy(initial_line_scan_kwargs) self.image_acquisition_kwargs = copy.deepcopy(initial_2d_scan_kwargs) + if self.line_scan_kwargs is not None: + self.initial_line_scan_position = self.extract_line_scan_position(self.line_scan_kwargs) def run_line_scan(self) -> float: """Run a line scan and return the FWHM of the Gaussian fit. @@ -441,7 +459,7 @@ def parse_json_from_response(self, response_text: str) -> dict[str, Any]: raise ValueError(f"Unable to parse JSON from response: {response_text}") return json.loads(match.group(0)) - def extract_scan_position(self, kwargs: dict[str, float]) -> np.ndarray: + def extract_line_scan_position(self, kwargs: dict[str, float]) -> np.ndarray: if len(self.line_scan_tool_x_coordinate_args) == 0 or len(self.line_scan_tool_y_coordinate_args) == 0: raise ValueError("Line scan coordinate args must not be empty.") x_arg = self.line_scan_tool_x_coordinate_args[0] @@ -467,34 +485,22 @@ def extract_image_acquisition_position(self, kwargs: dict[str, float]) -> np.nda ) return np.array([float(kwargs[y_arg]), float(kwargs[x_arg])], dtype=float) - def should_apply_linear_drift_prediction(self) -> bool: - if not self.run_offset_calibration: - return False - if not self.use_linear_drift_prediction: - return False - if self.initial_image_acquisition_position is None: - return False - return ( - self.drift_model_y.get_n_parameter_drift_points_collected() - >= self.n_parameter_drift_points_before_prediction - ) - def update_linear_drift_models( self, parameters: np.ndarray, current_position_yx: np.ndarray | None = None, ): - if not self.use_linear_drift_prediction: + if 
len(self.registration_tools) <= 1: return - if self.initial_image_acquisition_position is None: + if self.initial_line_scan_position is None: return if current_position_yx is None: - current_position_yx = self.extract_image_acquisition_position( - self.image_acquisition_kwargs + current_position_yx = self.extract_line_scan_position( + self.line_scan_kwargs ) - delta_yx = current_position_yx - self.initial_image_acquisition_position + delta_yx = current_position_yx - self.initial_line_scan_position x_train = np.array(parameters, dtype=float).reshape(1, -1).tolist() self.drift_model_y.update(x=x_train, y=[[float(delta_yx[0])]]) self.drift_model_x.update(x=x_train, y=[[float(delta_yx[1])]]) @@ -505,6 +511,43 @@ def update_linear_drift_models( f"n_samples={self.drift_model_y.get_n_parameter_drift_points_collected()}.```" ) + def snapshot_acquisition_state(self) -> dict[str, Any]: + """Capture mutable acquisition-tool state for retry rollback.""" + state = { + "image_0": copy.deepcopy(self.acquisition_tool.image_0), + "image_km1": copy.deepcopy(self.acquisition_tool.image_km1), + "image_k": copy.deepcopy(self.acquisition_tool.image_k), + "psize_0": copy.deepcopy(self.acquisition_tool.psize_0), + "psize_km1": copy.deepcopy(self.acquisition_tool.psize_km1), + "psize_k": copy.deepcopy(self.acquisition_tool.psize_k), + "image_acquisition_call_history": copy.deepcopy( + self.acquisition_tool.image_acquisition_call_history + ), + "line_scan_call_history": copy.deepcopy( + self.acquisition_tool.line_scan_call_history + ), + } + for attr in ["blur", "offset", "line_scan_candidates"]: + if hasattr(self.acquisition_tool, attr): + state[attr] = copy.deepcopy(getattr(self.acquisition_tool, attr)) + return state + + def restore_acquisition_state(self, state: dict[str, Any]) -> None: + """Restore acquisition-tool state captured by snapshot_acquisition_state.""" + self.acquisition_tool.image_0 = state["image_0"] + self.acquisition_tool.image_km1 = state["image_km1"] + 
self.acquisition_tool.image_k = state["image_k"] + self.acquisition_tool.psize_0 = state["psize_0"] + self.acquisition_tool.psize_km1 = state["psize_km1"] + self.acquisition_tool.psize_k = state["psize_k"] + self.acquisition_tool.image_acquisition_call_history = state[ + "image_acquisition_call_history" + ] + self.acquisition_tool.line_scan_call_history = state["line_scan_call_history"] + for attr in ["blur", "offset", "line_scan_candidates"]: + if attr in state: + setattr(self.acquisition_tool, attr, state[attr]) + def record_linear_drift_model_visualizations(self) -> None: image_paths = [] for axis_name, model in [("y", self.drift_model_y), ("x", self.drift_model_x)]: @@ -536,24 +579,6 @@ def record_linear_drift_model_visualizations(self) -> None: image_path=image_paths, ) - def apply_predicted_image_acquisition_position(self, parameters: np.ndarray): - current_pos = self.extract_image_acquisition_position(self.image_acquisition_kwargs) - x_in = np.array(parameters, dtype=float).reshape(1, -1).tolist() - delta_y = float(self.drift_model_y.predict(x_in)[0][0]) - delta_x = float(self.drift_model_x.predict(x_in)[0][0]) - predicted_pos = self.initial_image_acquisition_position + np.array( - [delta_y, delta_x], - dtype=float, - ) - delta_pos = predicted_pos - current_pos - self.apply_offset_to_image_acquisition_kwargs(-delta_pos) - self.apply_offset_to_line_scan_kwargs(-delta_pos) - self.record_system_message( - "Using linear drift prediction for acquisition positions: " - f"```Predicted 2D scan yx = {predicted_pos.tolist()}" - f"Offset from last = {delta_pos.tolist()}```" - ) - def build_line_scan_precheck_message( self, line_scan_result: dict[str, Any], @@ -723,8 +748,8 @@ def check_line_scan( old_line_scan_kwargs = copy.deepcopy(self.line_scan_kwargs) merged_line_scan_kwargs = copy.deepcopy(self.line_scan_kwargs) merged_line_scan_kwargs.update(new_line_scan_kwargs) - old_scan_position = self.extract_scan_position(old_line_scan_kwargs) - new_scan_position = 
self.extract_scan_position(merged_line_scan_kwargs) + old_scan_position = self.extract_line_scan_position(old_line_scan_kwargs) + new_scan_position = self.extract_line_scan_position(merged_line_scan_kwargs) offset = new_scan_position - old_scan_position self.line_scan_kwargs = merged_line_scan_kwargs if self.run_offset_calibration: @@ -776,53 +801,115 @@ def get_suggested_next_parameters(self, step_size_limit: Optional[float | Tuple[ p_suggested = p_current + signs * step_sizes return p_suggested - def find_offset(self) -> np.ndarray: - """Find the offset between the latest image and the previous image. + def get_registration_tool_name(self, registration_tool: RegistrationToolType) -> str: + return getattr(registration_tool, "name", type(registration_tool).__name__) + + def record_registration_tool_result( + self, + registration_tool: RegistrationToolType, + target: Literal["previous", "initial"], + alignment_offset: np.ndarray, + ) -> None: + if not hasattr(registration_tool, "plot_last_fit"): + return + + registration_fig = None + try: + registration_fig = registration_tool.plot_last_fit() + registration_fig_path = BaseTool.save_image_to_temp_dir( + fig=registration_fig, + filename=f"{self.get_registration_tool_name(registration_tool)}_overlay.png", + add_timestamp=True, + ) + self.record_system_message( + content=( + f"Registration result from {self.get_registration_tool_name(registration_tool)}:\n" + f"```target={target}\n" + f"alignment_offset={alignment_offset.tolist()}```" + ), + image_path=registration_fig_path, + ) + except Exception as exc: + logger.warning( + "Failed to render registration overlay for %s: %s", + self.get_registration_tool_name(registration_tool), + exc, + ) + finally: + if registration_fig is not None: + plt.close(registration_fig) + + def find_position_correction( + self, + registration_tool: RegistrationToolType, + target: Literal["previous", "initial"] = "previous", + ) -> tuple[np.ndarray, np.ndarray]: + """Find the correction 
implied by a registration tool result. + + Parameters + ---------- + registration_tool : RegistrationToolType + Registration tool that provides ``get_offset(target=...)``, where + the returned offset is the shift to apply to the current/test image + so it aligns with the selected reference image. + target : Literal["previous", "initial"], optional + The reference image to register against. + "previous": register the current image against the immediately + preceding image. The returned corrections are relative to the + previous step. + "initial": register the current image against the very first + acquired image. The returned corrections are cumulative from the + initial position, which prevents per-step error accumulation. Returns ------- - np.ndarray - The offset between the latest image and the previous image. - Offset is in physical units, i.e., pixel size is already accounted for. + tuple[np.ndarray, np.ndarray] + Two arrays in physical units: + ``(line_scan_correction, registration_correction)``. + ``line_scan_correction`` includes both the pure registration + correction and the intentional scan-position difference between the + current and reference 2D scans. + ``registration_correction`` is the pure stage-position correction + to apply to image acquisition coordinates. """ - alignment_offset = np.array( - self.image_registration_tool.register_images( - image_t=self.image_registration_tool.process_image(self.acquisition_tool.image_k), - image_r=self.image_registration_tool.process_image(self.acquisition_tool.image_km1), - psize_t=self.acquisition_tool.psize_k, - psize_r=self.acquisition_tool.psize_km1, - registration_algorithm_kwargs=self.registration_algorithm_kwargs, - ), - dtype=float, - ) - - # Count in the difference of scan positions. + if target not in ("previous", "initial"): + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + + # Scan-position difference depends only on the target image. 
+ ref_history_idx = -2 if target == "previous" else 0 scan_pos_diff = np.array([ float(self.acquisition_tool.image_acquisition_call_history[-1][f"{dir}_center"]) - - float(self.acquisition_tool.image_acquisition_call_history[-2][f"{dir}_center"]) + - float(self.acquisition_tool.image_acquisition_call_history[ref_history_idx][f"{dir}_center"]) for dir in ["y", "x"] ]).astype(float) - offset_to_subtract = alignment_offset - scan_pos_diff - self.record_system_message( - f"Pure image registration offset (to apply to current image for alignment) is " - f"{alignment_offset}. Counting in scan-position difference {scan_pos_diff}, " - f"the offset to subtract from the next line scan positions is {offset_to_subtract}." + + alignment_offset = np.array( + registration_tool.get_offset(target=target), + dtype=float, ) - return offset_to_subtract, alignment_offset + self.record_registration_tool_result(registration_tool, target, alignment_offset) + + # Image registration offset is the offset by which the moving image should be rolled to match + # the reference. We want acquisition position correction here, which is the negation of it. + registration_correction = -alignment_offset + + # Count in the difference of scan positions. 
+ line_scan_correction = registration_correction + scan_pos_diff + return line_scan_correction, registration_correction def apply_offset_to_line_scan_kwargs(self, offset: np.ndarray): for arg in self.line_scan_tool_x_coordinate_args: - self.line_scan_kwargs[arg] -= offset[1] + self.line_scan_kwargs[arg] += offset[1] for arg in self.line_scan_tool_y_coordinate_args: - self.line_scan_kwargs[arg] -= offset[0] + self.line_scan_kwargs[arg] += offset[0] def apply_offset_to_image_acquisition_kwargs(self, offset: np.ndarray): for arg in self.image_acquisition_tool_x_coordinate_args: - self.image_acquisition_kwargs[arg] -= offset[1] + self.image_acquisition_kwargs[arg] += offset[1] for arg in self.image_acquisition_tool_y_coordinate_args: - self.image_acquisition_kwargs[arg] -= offset[0] + self.image_acquisition_kwargs[arg] += offset[0] - def collect_initial_data_optimization_tool( + def collect_initial_data_for_optimization_tool( self, current_x: np.ndarray, sampling_range: np.ndarray, @@ -861,6 +948,7 @@ def run_tuning_iteration(self, x: np.ndarray): x_current = np.array(x, dtype=float) line_scan_kwargs_before = copy.deepcopy(self.line_scan_kwargs) image_acquisition_kwargs_before = copy.deepcopy(self.image_acquisition_kwargs) + acquisition_state_before = self.snapshot_acquisition_state() def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: for parameter_name in self.param_setting_tool.parameter_names: @@ -868,6 +956,7 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: self.param_setting_tool.parameter_history[parameter_name].pop() self.line_scan_kwargs = copy.deepcopy(line_scan_kwargs_before) self.image_acquisition_kwargs = copy.deepcopy(image_acquisition_kwargs_before) + self.restore_acquisition_state(acquisition_state_before) delta = x_current - x_original x_next = x_original + delta / 2 self.record_system_message( @@ -881,21 +970,10 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: while True: 
self.record_system_message(f"Setting parameters to new value:```{x_current}```") self.param_setting_tool.set_parameters(x_current) + chosen_line_scan_correction_wrt_initial = None if self.run_offset_calibration: - if self.should_apply_linear_drift_prediction(): - self.apply_predicted_image_acquisition_position(x_current) self.run_2d_scan() - if self.line_scan_predictor_tool is not None: - self.apply_line_scan_predictor_offset() - else: - line_scan_pos_offset, alignment_offset = self.find_offset() - if np.any(np.isnan(line_scan_pos_offset)): - x_current = rollback_and_shrink_delta("Image registration failed (NaN offset).") - continue - self.apply_offset_to_line_scan_kwargs(line_scan_pos_offset) - self.apply_offset_to_image_acquisition_kwargs(alignment_offset) - self.update_linear_drift_models(x_current) - self.record_linear_drift_model_visualizations() + chosen_line_scan_correction_wrt_initial = self.apply_drift_correction(x_current) try: fwhm = self.run_line_scan() if np.isnan(fwhm): @@ -903,30 +981,180 @@ def rollback_and_shrink_delta(message_prefix: str) -> np.ndarray: except LineScanValidationFailed: x_current = rollback_and_shrink_delta("Line scan validation failed.") continue + if ( + self.run_offset_calibration + and chosen_line_scan_correction_wrt_initial is not None + ): + self.update_linear_drift_models( + x_current, + current_position_yx=chosen_line_scan_correction_wrt_initial, + ) + self.record_linear_drift_model_visualizations() self.update_optimization_model(fwhm) return - def apply_line_scan_predictor_offset(self) -> None: - """Update scan positions using the line scan predictor. 
+ def _select_drift( + self, + candidate_drifts: dict[str, np.ndarray], + x_current: np.ndarray, + ) -> tuple[np.ndarray, str]: + """Select the best cumulative drift estimate from the available tools.""" + n_collected = self.drift_model_y.get_n_parameter_drift_points_collected() + n_needed = self.registration_selection_priming_iterations + + if n_collected < n_needed: + preferred_tool = self.registration_tools[self.primary_registration_tool_index] + preferred_name = self.get_registration_tool_name(preferred_tool) + if preferred_name in candidate_drifts: + chosen_source = preferred_name + else: + chosen_source = next(iter(candidate_drifts)) + self.record_system_message( + "Primary registration tool did not yield a valid result; " + f"falling back to {chosen_source}." + ) + chosen_drift = candidate_drifts[chosen_source] + candidate_drift_lines = "\n".join( + f"{name}: {drift.tolist()}" + for name, drift in candidate_drifts.items() + ) + self.record_system_message( + f"Registration tool selection (primary phase, n={n_collected}/{n_needed}): " + f"```using {chosen_source}\n" + f"candidate_drifts:\n{candidate_drift_lines}```" + ) + else: + x_in = np.array(x_current, dtype=float).reshape(1, -1).tolist() + model_drift = np.array( + [ + float(self.drift_model_y.predict(x_in)[0][0]), + float(self.drift_model_x.predict(x_in)[0][0]), + ], + dtype=float, + ) + candidate_distances = { + name: float(np.linalg.norm(drift - model_drift)) + for name, drift in candidate_drifts.items() + } + chosen_source = min(candidate_distances, key=candidate_distances.get) + chosen_drift = candidate_drifts[chosen_source] + candidate_result_lines = "\n".join( + f"{name}: drift={candidate_drifts[name].tolist()}, dist={candidate_distances[name]:.4f}" + for name in candidate_drifts + ) + self.record_system_message( + f"Registration tool selection (arbitration phase):\n" + f"```model_drift={model_drift.tolist()}\n" + f"candidates:\n{candidate_result_lines}\n" + f"Chosen: {chosen_source}```" + ) + + 
return chosen_drift, chosen_source + + def apply_drift_correction(self, x_current: np.ndarray) -> np.ndarray: + """Run the configured registration tools, select a drift estimate, and apply it. - Calls :meth:`LineScanPredictor.predict_line_scan_position` to obtain - the predicted line scan center in the current image, computes the - drift relative to the current line scan center, then shifts both - ``line_scan_kwargs`` and ``image_acquisition_kwargs`` by that drift so - that the two sets of coordinates stay in sync. + Parameters + ---------- + x_current : np.ndarray + Current optics parameter vector. """ - current_center = self.extract_scan_position(self.line_scan_kwargs) - result = json.loads(self.line_scan_predictor_tool.predict_line_scan_position()) - predicted_center = np.array([result["center_y"], result["center_x"]], dtype=float) - drift = predicted_center - current_center + if self.initial_line_scan_position is None or self.initial_image_acquisition_position is None: + raise RuntimeError( + "apply_drift_correction requires initial scan positions to be set. " + "Ensure initialize_kwargs_buffers and run_2d_scan have been called first." 
+ ) + + current_line_scan_position = self.extract_line_scan_position(self.line_scan_kwargs) + current_image_acq_position = self.extract_image_acquisition_position( + self.image_acquisition_kwargs + ) + + candidate_results: dict[str, dict[str, np.ndarray]] = {} + for registration_tool in self.registration_tools: + tool_name = self.get_registration_tool_name(registration_tool) + try: + line_scan_correction, image_acq_correction = self.find_position_correction( + registration_tool=registration_tool, + target=self.registration_target, + ) + except Exception as exc: + if len(self.registration_tools) == 1: + raise + logger.warning("Registration tool %s failed: %s", tool_name, exc) + self.record_system_message( + f"Registration tool {tool_name} failed and will be skipped: {exc}" + ) + continue + + if self.registration_target == "previous": + line_scan_correction_wrt_prev = line_scan_correction + line_scan_correction_wrt_initial = ( + current_line_scan_position + + line_scan_correction + - self.initial_line_scan_position + ) + image_acq_correction_wrt_prev = image_acq_correction + image_acq_correction_wrt_initial = ( + image_acq_correction + + current_image_acq_position + - self.initial_image_acquisition_position + ) + else: + line_scan_correction_wrt_initial = line_scan_correction + line_scan_correction_wrt_prev = ( + self.initial_line_scan_position + + line_scan_correction + - current_line_scan_position + ) + image_acq_correction_wrt_initial = image_acq_correction + image_acq_correction_wrt_prev = ( + self.initial_image_acquisition_position + + image_acq_correction + - current_image_acq_position + ) + + candidate_results[tool_name] = { + "line_scan_correction_wrt_prev": line_scan_correction_wrt_prev, + "line_scan_correction_wrt_initial": line_scan_correction_wrt_initial, + "image_acq_correction_wrt_prev": image_acq_correction_wrt_prev, + "image_acq_correction_wrt_initial": image_acq_correction_wrt_initial, + } + + if len(candidate_results) == 0: + raise RuntimeError("No 
registration tool produced a valid drift estimate.") + + if len(candidate_results) == 1: + chosen_source = next(iter(candidate_results)) + else: + chosen_line_scan_correction_wrt_initial, chosen_source = self._select_drift( + { + name: result["line_scan_correction_wrt_initial"] + for name, result in candidate_results.items() + }, + x_current, + ) + _ = chosen_line_scan_correction_wrt_initial + + chosen_result = candidate_results[chosen_source] + chosen_line_scan_correction_wrt_prev = chosen_result["line_scan_correction_wrt_prev"] + chosen_line_scan_correction_wrt_initial = chosen_result["line_scan_correction_wrt_initial"] + chosen_image_acq_correction_wrt_prev = chosen_result["image_acq_correction_wrt_prev"] + chosen_image_acq_correction_wrt_initial = chosen_result["image_acq_correction_wrt_initial"] + + self.apply_offset_to_line_scan_kwargs(chosen_line_scan_correction_wrt_prev) + self.apply_offset_to_image_acquisition_kwargs(chosen_image_acq_correction_wrt_prev) + self.record_system_message( - f"Line scan predictor: current center={current_center.tolist()}, " - f"predicted center={predicted_center.tolist()}, " - f"drift={drift.tolist()}" + f"Applied drift correction:\n" + f"```source = {chosen_source}\n" + f"chosen_line_scan_correction_wrt_prev = {chosen_line_scan_correction_wrt_prev.tolist()}\n" + f"chosen_line_scan_correction_wrt_initial = {chosen_line_scan_correction_wrt_initial.tolist()}\n" + f"chosen_image_acq_correction_wrt_prev = {chosen_image_acq_correction_wrt_prev.tolist()}\n" + f"chosen_image_acq_correction_wrt_initial = {chosen_image_acq_correction_wrt_initial.tolist()}```" ) - # apply_offset_to_*_kwargs(o) does position -= o; passing -drift gives position += drift. 
- self.apply_offset_to_line_scan_kwargs(-drift) - self.apply_offset_to_image_acquisition_kwargs(-drift) + + return chosen_line_scan_correction_wrt_initial def apply_user_correction_offset(self) -> bool: message = ( @@ -945,8 +1173,8 @@ def apply_user_correction_offset(self) -> bool: except ValueError: logger.info("Invalid offset values. Use numeric values like 'y,x'.") continue - self.apply_offset_to_line_scan_kwargs(offset) - self.apply_offset_to_image_acquisition_kwargs(offset) + self.apply_offset_to_line_scan_kwargs(-offset) + self.apply_offset_to_image_acquisition_kwargs(-offset) correction_message = f"Applied user correction offset: {offset.tolist()}" logger.info(correction_message) self.record_system_message(correction_message) diff --git a/src/eaa/tool/imaging/acquisition.py b/src/eaa/tool/imaging/acquisition.py index e0335d6..cc36dff 100644 --- a/src/eaa/tool/imaging/acquisition.py +++ b/src/eaa/tool/imaging/acquisition.py @@ -519,7 +519,7 @@ def acquire_line_scan( fwhm = np.nan else: val_gauss = eaa.maths.gaussian_1d(ds, a, mu, sigma, c) - fwhm = 2.35 * sigma + fwhm = 2.35 * np.abs(sigma) show_scan_line = self.image_k is not None and len(self.image_acquisition_call_history) > 0 show_first_scan_line = ( diff --git a/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py new file mode 100644 index 0000000..a478e9d --- /dev/null +++ b/src/eaa/tool/imaging/aps_mic/test_target_landmark_fitting.py @@ -0,0 +1,430 @@ +from typing import Annotated, Literal, Optional +import logging + +import matplotlib.pyplot as plt +import numpy as np +import scipy.ndimage as ndi +from skimage import filters, measure +from sciagent.tool.base import BaseTool, ToolReturnType, check, tool + +from eaa.tool.imaging.acquisition import AcquireImage + +logger = logging.getLogger(__name__) + + +class TestPatternLandmarkFitting(BaseTool): + """Fit the circular landmark on the right side of an APS-MIC test image. 
+ + The expected landmark is a bright or dark disk-shaped feature that may be + only partially visible near the right image boundary. The fitting workflow + is: + + 1. Apply Gaussian smoothing and normalize the image to ``[0, 1]``. + 2. Segment the landmark candidate with binary thresholding. + 3. Apply binary erosion followed by binary dilation with a 3x3 structure + element for 3 iterations to suppress small islands and weak bridges. + 4. Run connected-component analysis and keep the component whose right-most + pixel has the largest x coordinate. + 5. Extract the boundary pixels of that component, excluding pixels on the + outer image border, and fit a circle to those arc pixels with RANSAC. + + The returned center is expressed in the coordinate system of the original, + uncropped image. + """ + + name: str = "test_pattern_landmark_fitting" + + @check + def __init__( + self, + image_acquisition_tool: Optional[AcquireImage] = None, + zoom: float = 4.0, + gaussian_sigma_fraction: float = 0.03, + ransac_residual_threshold_fraction: float = 0.015, + ransac_max_trials: int = 1000, + require_approval: bool = False, + *args, + **kwargs, + ): + """Initialize the landmark fitting tool. + + Parameters + ---------- + image_acquisition_tool : Optional[AcquireImage], optional + Acquisition tool that provides ``image_k`` when no image is passed + directly to :meth:`fit_landmark_center`. + zoom : float, optional + Zoom factor applied to the image before segmentation and fitting. + Returned coordinates are converted back to the original image scale. + gaussian_sigma_fraction : float, optional + Gaussian blur sigma expressed as a fraction of the original image + width in pixels. + ransac_residual_threshold_fraction : float, optional + RANSAC inlier threshold expressed as a fraction of the cropped + image size. + ransac_max_trials : int, optional + Maximum number of RANSAC iterations used for the circle fit. 
+ require_approval : bool, optional + Whether tool execution requires explicit approval in the agent + framework. + """ + super().__init__(*args, require_approval=require_approval, **kwargs) + + self.image_acquisition_tool = image_acquisition_tool + self.zoom = zoom + self.gaussian_sigma_fraction = gaussian_sigma_fraction + self.ransac_residual_threshold_fraction = ransac_residual_threshold_fraction + self.ransac_max_trials = ransac_max_trials + self.latest_image: Optional[np.ndarray] = None + self.latest_circle_model: Optional[measure.CircleModel] = None + self.latest_circle_inliers: Optional[np.ndarray] = None + self.latest_processed_image: Optional[np.ndarray] = None + + def preprocess_image(self, image: np.ndarray) -> np.ndarray: + """Convert the input to a finite 2D float image. + + Parameters + ---------- + image : np.ndarray + Input image. If the input is 3D, the last axis is averaged. + + Returns + ------- + np.ndarray + A 2D floating-point image with non-finite values replaced by the + mean of the finite pixels. + """ + arr = np.asarray(image, dtype=float) + if arr.ndim == 3: + arr = np.mean(arr, axis=-1) + if arr.ndim != 2: + raise ValueError(f"Expected a 2D image, got shape {arr.shape}.") + + if not np.isfinite(arr).all(): + finite_mask = np.isfinite(arr) + if not finite_mask.any(): + raise ValueError("Input image does not contain any finite pixels.") + arr = arr.copy() + arr[~finite_mask] = float(np.mean(arr[finite_mask])) + return arr + + def get_input_image(self, image: Optional[np.ndarray]) -> np.ndarray: + """Return the image to process. + + Parameters + ---------- + image : Optional[np.ndarray] + Explicit image to fit. If omitted, ``image_acquisition_tool.image_k`` + is used. + + Returns + ------- + np.ndarray + Preprocessed 2D image. 
+ """ + if image is not None: + return self.preprocess_image(image) + if self.image_acquisition_tool is None or self.image_acquisition_tool.image_k is None: + raise ValueError( + "No image was provided and image_acquisition_tool.image_k is not available." + ) + return self.preprocess_image(self.image_acquisition_tool.image_k) + + def zoom_image(self, image: np.ndarray) -> np.ndarray: + """Zoom the image before processing. + + Parameters + ---------- + image : np.ndarray + Preprocessed 2D image in original coordinates. + + Returns + ------- + np.ndarray + Image in the processing scale. + """ + if self.zoom <= 0: + raise ValueError("zoom must be positive.") + if self.zoom == 1.0: + return image + return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") + + def resolve_pixel_size( + self, + pixel_size: Optional[float], + image_role: Literal["current", "previous", "initial"], + ) -> float: + """Resolve the pixel size for converting pixel coordinates to physical units.""" + if pixel_size is not None: + return float(pixel_size) + if self.image_acquisition_tool is None: + raise ValueError( + "pixel_size must be provided when image_acquisition_tool is unavailable." + ) + + if image_role == "current": + resolved = self.image_acquisition_tool.psize_k + elif image_role == "previous": + resolved = self.image_acquisition_tool.psize_km1 + elif image_role == "initial": + resolved = self.image_acquisition_tool.psize_0 + else: + raise ValueError(f"Unsupported image_role: {image_role!r}.") + + if resolved is None: + raise ValueError( + f"Pixel size for image_role={image_role!r} is unavailable." 
+ ) + return float(resolved) + + @staticmethod + def normalize_image(image: np.ndarray) -> np.ndarray: + """Robustly normalize an image to ``[0, 1]`` using percentiles.""" + lo = float(np.percentile(image, 1)) + hi = float(np.percentile(image, 99)) + if hi <= lo: + lo = float(np.min(image)) + hi = float(np.max(image)) + if hi <= lo: + return np.zeros_like(image, dtype=float) + return np.clip((image - lo) / (hi - lo), 0.0, 1.0) + + def detect_edge_points(self, image: np.ndarray) -> np.ndarray: + """Detect candidate circle-arc points from the segmented landmark. + + Parameters + ---------- + image : np.ndarray + Preprocessed 2D image in original coordinates. + + Returns + ------- + np.ndarray + Boundary points as ``(x, y)`` pairs in original image coordinates. + """ + sigma = max(0.5, self.gaussian_sigma_fraction * image.shape[1]) + blurred = ndi.gaussian_filter(image, sigma=sigma, mode="nearest") + normalized = self.normalize_image(blurred) + threshold = filters.threshold_otsu(normalized) + binary = normalized >= threshold + + structure = np.ones((3, 3), dtype=bool) + # binary = ndi.binary_erosion(binary, structure=structure, iterations=1) + # binary = ndi.binary_dilation(binary, structure=structure, iterations=1) + # binary = ndi.binary_closing(binary, structure=structure, iterations=1) + + labels, num_labels = ndi.label(binary, structure=structure) + if num_labels == 0: + raise ValueError("No connected component was found in the segmented image.") + + selected_label = None + rightmost_x = -1 + for label_id in range(1, num_labels + 1): + label_y, label_x = np.nonzero(labels == label_id) + if label_x.size == 0: + continue + label_rightmost_x = int(np.max(label_x)) + if label_rightmost_x > rightmost_x: + rightmost_x = label_rightmost_x + selected_label = label_id + + if selected_label is None: + raise ValueError("Failed to select a segmented landmark component.") + + component = labels == selected_label + boundary = component & ~ndi.binary_erosion(component, 
structure=structure, iterations=1) + valid_support = np.zeros_like(component, dtype=bool) + valid_support[1:-1, 1:-1] = True + boundary &= valid_support + + edge_y, edge_x = np.nonzero(boundary) + if edge_x.size < 3: + raise ValueError("Segmented landmark boundary has fewer than 3 candidate points.") + + points = np.column_stack((edge_x.astype(float), edge_y.astype(float))) + return points + + def fit_circle( + self, + points: np.ndarray, + cropped_shape: tuple[int, int], + ) -> tuple[measure.CircleModel, np.ndarray]: + """Fit a circle with RANSAC and return its parameters. + + Parameters + ---------- + points : np.ndarray + Candidate edge points as ``(x, y)`` pairs in the original image + frame. + cropped_shape : tuple[int, int] + Shape of the image as ``(height, width)``. + + Returns + ------- + tuple[measure.CircleModel, np.ndarray] + The fitted circle model and its boolean inlier mask. + """ + residual_threshold = max( + 1.0, + self.ransac_residual_threshold_fraction * max(cropped_shape), + ) + model, inliers = measure.ransac( + points, + measure.CircleModel, + min_samples=3, + residual_threshold=residual_threshold, + max_trials=self.ransac_max_trials, + ) + if model is None or inliers is None or int(np.count_nonzero(inliers)) < 3: + raise ValueError("RANSAC could not find a valid circle from the detected edges.") + + center_x, center_y, radius = model.params + logger.debug( + "Circle fit: center=(%.3f, %.3f), radius=%.3f, inliers=%d/%d", + center_x, + center_y, + radius, + int(np.count_nonzero(inliers)), + len(points), + ) + return model, inliers + + def plot_last_fit(self) -> plt.Figure: + """Plot the most recent image with the fitted circle overlaid. + + Returns + ------- + matplotlib.figure.Figure + Figure showing the stored image and the last fitted circle. + """ + if self.latest_image is None or self.latest_circle_model is None: + raise ValueError("No fitted landmark is available. 
Call fit_landmark_center first.") + + center_x, center_y, radius = self.latest_circle_model.params + theta = np.linspace(0.0, 2.0 * np.pi, 512) + circle_x = center_x + radius * np.cos(theta) + circle_y = center_y + radius * np.sin(theta) + + fig, ax = plt.subplots(1, 1, squeeze=True) + ax.imshow(self.latest_image, cmap="viridis", origin="upper") + ax.plot(circle_x, circle_y, color="cyan", linewidth=1.5) + ax.scatter([center_x], [center_y], color="red", s=30) + ax.set_title("Landmark Circle Fit") + ax.set_xlim(-0.5, self.latest_image.shape[1] - 0.5) + ax.set_ylim(self.latest_image.shape[0] - 0.5, -0.5) + return fig + + @tool(name="fit_landmark_center", return_type=ToolReturnType.LIST) + def fit_landmark_center( + self, + image: Annotated[ + Optional[np.ndarray], + "Optional 2D image array. When omitted, the tool uses image_acquisition_tool.image_k.", + ] = None, + pixel_size: Annotated[ + Optional[float], + "Pixel size used to convert the fitted center from pixels to physical units.", + ] = None, + image_role: Annotated[ + Literal["current", "previous", "initial"], + "Which acquisition buffer the image corresponds to when pixel_size is omitted.", + ] = "current", + ) -> Annotated[ + list[float], + "The fitted landmark center as [center_y, center_x] in physical units.", + ]: + """Detect the right-side disk feature and return its center. + + Parameters + ---------- + image : Optional[np.ndarray], optional + Explicit image array to process. If omitted, the latest image from + ``image_acquisition_tool`` is used. + pixel_size : Optional[float], optional + Pixel size in physical units per pixel. If omitted, the value is + taken from ``image_acquisition_tool`` according to ``image_role``. + image_role : {"current", "previous", "initial"}, optional + Acquisition-buffer role used to resolve the pixel size when + ``pixel_size`` is omitted. 
+ + Returns + ------- + list[float] + Landmark center as ``[center_y, center_x]`` in the coordinates of + the original image, expressed in physical units. + """ + image_arr = self.get_input_image(image) + resolved_pixel_size = self.resolve_pixel_size(pixel_size, image_role) + processed_image = self.zoom_image(image_arr) + points = self.detect_edge_points(processed_image) + model, inliers = self.fit_circle( + points, + cropped_shape=processed_image.shape, + ) + center_x_px, center_y_px, radius_px = model.params + circle_model_px = measure.CircleModel() + circle_model_px.params = ( + float(center_x_px / self.zoom), + float(center_y_px / self.zoom), + float(radius_px / self.zoom), + ) + self.latest_image = image_arr + self.latest_processed_image = processed_image + self.latest_circle_model = circle_model_px + self.latest_circle_inliers = inliers + return [ + float(circle_model_px.params[1] * resolved_pixel_size), + float(circle_model_px.params[0] * resolved_pixel_size), + ] + + @tool(name="get_offset", return_type=ToolReturnType.LIST) + def get_offset( + self, + target: Annotated[ + Literal["previous", "initial"], + "Reference image buffer against which the current image is compared.", + ] = "initial", + ) -> Annotated[ + list[float], + "The landmark-based registration offset [dy, dx] in physical units.", + ]: + """Return the shift to apply to the test image to match the reference. + + The returned offset follows the same convention as the other image + registration tools in this codebase: it is the physical-space + translation ``[dy, dx]`` that should be applied to the current (test) + image so it aligns with the selected reference image. + """ + if self.image_acquisition_tool is None: + raise RuntimeError( + "image_acquisition_tool is required to compare current and reference images." 
+ ) + if self.image_acquisition_tool.image_k is None: + raise RuntimeError("Current image buffer (image_k) is not populated.") + + if target == "previous": + reference_image = self.image_acquisition_tool.image_km1 + reference_role: Literal["previous", "initial"] = "previous" + elif target == "initial": + reference_image = self.image_acquisition_tool.image_0 + reference_role = "initial" + else: + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + + if reference_image is None: + raise RuntimeError( + f"Reference image buffer for target={target!r} is not populated." + ) + + center_ref = np.array( + self.fit_landmark_center(image=reference_image, image_role=reference_role), + dtype=float, + ) + center_test = np.array( + self.fit_landmark_center( + image=self.image_acquisition_tool.image_k, + image_role="current", + ), + dtype=float, + ) + return (center_ref - center_test).tolist() diff --git a/src/eaa/tool/imaging/nn_registration.py b/src/eaa/tool/imaging/nn_registration.py new file mode 100644 index 0000000..15b2d5a --- /dev/null +++ b/src/eaa/tool/imaging/nn_registration.py @@ -0,0 +1,141 @@ +import io +import logging +from typing import Literal + +import numpy as np +import requests +import tifffile +from sciagent.tool.base import BaseTool, check, ToolReturnType, tool + +from eaa.tool.imaging.acquisition import AcquireImage + +logger = logging.getLogger(__name__) + + +class NNRegistration(BaseTool): + """Tool that queries a NN server to obtain a registration offset between + a reference image and the current image. + + The server must serve a model with ``prediction_type="offset"``, which + accepts ``ref_image`` and ``test_image`` and returns + ``{"offset_y": float, "offset_x": float}`` as fractions of the reference + image size. + The offset convention matches :class:`~eaa.tool.imaging.registration.ImageRegistration`: + the returned values are the translation to apply to the test image so it + aligns with the reference. 
+ """ + + name: str = "nn_registration" + + @check + def __init__( + self, + server_url: str, + image_acquisition_tool: AcquireImage, + require_approval: bool = False, + *args, + **kwargs, + ): + """Initialize the NN registration tool. + + Parameters + ---------- + server_url : str + Base URL of the inference server, e.g. ``"http://localhost:8090"``. + image_acquisition_tool : AcquireImage + The image acquisition tool instance. Must be the same object used + by the task manager so that its image buffers and call history + reflect the current state. + """ + super().__init__(*args, require_approval=require_approval, **kwargs) + self.server_url = server_url.rstrip("/") + self.image_acquisition_tool = image_acquisition_tool + + @staticmethod + def _encode_as_tiff(image: np.ndarray) -> bytes: + buf = io.BytesIO() + tifffile.imwrite(buf, image.astype(np.float32)) + return buf.getvalue() + + @tool(name="get_offset", return_type=ToolReturnType.LIST) + def get_offset(self, target: Literal["previous", "initial"] = "initial") -> np.ndarray: + """Query the server and return the registration offset in physical units. + + Parameters + ---------- + target : Literal["previous", "initial"] + The reference image to register against. + "previous": register the current image (``image_k``) against the + immediately preceding image (``image_km1``). + "initial": register the current image (``image_k``) against the + first acquired image (``image_0``), giving the cumulative drift from + the initial acquisition. + + Returns + ------- + np.ndarray + ``[offset_y, offset_x]`` in physical coordinate units (same units + as the image acquisition positions). The convention matches + :meth:`~eaa.tool.imaging.registration.ImageRegistration.register_images`: + this is the translation to apply to the current image so it aligns + with the reference image. + + Raises + ------ + RuntimeError + If the required image buffers or acquisition history are not populated. 
+ requests.HTTPError + If the server returns a non-2xx response. + """ + acq = self.image_acquisition_tool + + if acq.image_k is None: + raise RuntimeError( + "Current image buffer (image_k) is not populated. " + "Acquire at least one image before calling get_offset." + ) + if not acq.image_acquisition_call_history: + raise RuntimeError("No image acquisition history found.") + + if target == "previous": + if acq.image_km1 is None: + raise RuntimeError( + "Previous image buffer (image_km1) is not populated. " + "Acquire at least two images before calling get_offset with target='previous'." + ) + ref_image = acq.image_km1 + ref_img_info = acq.image_acquisition_call_history[-2] + elif target == "initial": + if acq.image_0 is None: + raise RuntimeError( + "Initial image buffer (image_0) is not populated. " + "Acquire at least one image before calling get_offset with target='initial'." + ) + ref_image = acq.image_0 + ref_img_info = acq.image_acquisition_call_history[0] + else: + raise ValueError(f"`target` must be 'previous' or 'initial', got {target!r}.") + + ref_tiff = self._encode_as_tiff(ref_image) + test_tiff = self._encode_as_tiff(acq.image_k) + + response = requests.post( + f"{self.server_url}/predict", + files={ + "ref_image": ("ref_image.tif", ref_tiff, "image/tiff"), + "test_image": ("test_image.tif", test_tiff, "image/tiff"), + }, + ) + response.raise_for_status() + result = response.json() + + # Convert fractions of the reference image size to physical units. 
+ offset_y_phys = float(result["offset_y"]) * float(ref_img_info["size_y"]) + offset_x_phys = float(result["offset_x"]) * float(ref_img_info["size_x"]) + + logger.debug( + "NNRegistration offset (target=%s): frac=(%.4f, %.4f), phys=(%.4f, %.4f)", + target, result["offset_y"], result["offset_x"], + offset_y_phys, offset_x_phys, + ) + return np.array([offset_y_phys, offset_x_phys], dtype=float) diff --git a/src/eaa/tool/imaging/registration.py b/src/eaa/tool/imaging/registration.py index 80f52ae..c6212bb 100644 --- a/src/eaa/tool/imaging/registration.py +++ b/src/eaa/tool/imaging/registration.py @@ -1,4 +1,4 @@ -from typing import Annotated, Any, List, Literal, Optional, Tuple +from typing import Annotated, Any, List, Literal, Optional, Tuple, Dict import logging import re from pathlib import Path @@ -12,10 +12,13 @@ from sciagent.message_proc import generate_openai_message from sciagent.skill import SkillMetadata from sciagent.api.llm_config import LLMConfig -from sciagent.api.memory import MemoryManagerConfig from eaa.tool.imaging.acquisition import AcquireImage -from eaa.image_proc import phase_cross_correlation, translation_nmi_registration +from eaa.image_proc import ( + error_minimization_registration, + phase_cross_correlation, + translation_nmi_registration, +) logger = logging.getLogger(__name__) @@ -33,11 +36,14 @@ def __init__( self, image_acquisition_tool: AcquireImage, llm_config: Optional[LLMConfig] = None, - memory_config: Optional[MemoryManagerConfig] = None, reference_image: np.ndarray = None, reference_pixel_size: float = 1.0, image_coordinates_origin: Literal["top_left", "center"] = "top_left", - registration_method: Literal["phase_correlation", "sift", "mutual_information", "llm"] = "phase_correlation", + registration_method: Literal[ + "phase_correlation", "sift", "mutual_information", "llm", "error_minimization" + ] = "phase_correlation", + registration_algorithm_kwargs: Optional[Dict[str, Any]] = None, + zoom: float = 1.0, log_scale: bool = 
False, require_approval: bool = False, *args, @@ -66,10 +72,17 @@ def __init__( When this argument is set to "center", the test image is padded/cropped centrally. When it is set to "top_left", the test image is on the bottom and right sides. - registration_method : Literal["phase_correlation", "sift", "mutual_information"], optional + registration_method : Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"], optional The method used to estimate translational offsets. "phase_correlation" - uses phase correlation, "sift" uses feature matching, and - "mutual_information" uses pyramid-based normalized mutual information. + uses phase correlation, "sift" uses feature matching, + "mutual_information" uses pyramid-based normalized mutual information, + and "error_minimization" uses exhaustive integer-shift MSE search with + local quadratic subpixel refinement. + registration_algorithm_kwargs : Optional[Dict[str, Any]], optional + Keyword arguments to pass to the registration algorithm. + zoom : float, optional + Zoom factor applied to both images before registration. Returned + offsets are scaled back to the original image coordinates. log_scale : bool, optional If True, images are transformed as `log10(x + 1)` before registration. 
""" @@ -77,11 +90,12 @@ def __init__( self.image_acquisition_tool = image_acquisition_tool self.llm_config = llm_config - self.memory_config = memory_config self.reference_image = reference_image self.reference_pixel_size = reference_pixel_size self.image_coordinates_origin = image_coordinates_origin self.registration_method = registration_method + self.registration_algorithm_kwargs = registration_algorithm_kwargs or {} + self.zoom = zoom self.log_scale = log_scale def set_reference_image( @@ -117,32 +131,35 @@ def process_image(self, image: np.ndarray) -> np.ndarray: image = np.log10(image + 1) return image - @tool(name="get_offset_of_latest_image", return_type=ToolReturnType.LIST) - def get_offset_of_latest_image( + def zoom_image(self, image: np.ndarray) -> np.ndarray: + """Apply the configured registration zoom factor to an image.""" + if self.zoom <= 0: + raise ValueError("zoom must be positive.") + if self.zoom == 1.0: + return image + return ndi.zoom(image, zoom=self.zoom, order=1, mode="nearest") + + @tool(name="get_offset", return_type=ToolReturnType.LIST) + def get_offset( self, - register_with: Annotated[ - Literal["previous", "first", "reference"], - "The image to register the latest image with. " - "Can be 'previous', 'first', or 'reference'. " - "'previous': register with the image collected by the acquisition tool before the latest. " - "'first': register with the first image collected by the acquisition tool. " - "'reference': register with the reference image provided to the tool. ", - ], + target: Annotated[ + Literal["previous", "initial", "reference"], + "Reference image buffer against which the current image is compared.", + ] = "initial", ) -> Annotated[ List[float], - "The translational offset [dy (vertical), dx (horizontal)] to apply to the latest " - "acquired image so it aligns with the reference image. Positive y means shifting the " - "latest image downward; positive x means shifting it rightward. 
The returned values are " - "in physical units, i.e., pixel size is already accounted for.", + "The translational offset [dy, dx] to apply to the latest image " + "so it aligns with the selected reference image.", ]: - """ - Register the latest image collected by the image acquisition tool + """Register the latest image collected by the image acquisition tool and the reference image. """ + register_with = target if target in ("previous", "reference") else "first" image_t, image_r, psize_t, psize_r = self.get_registration_inputs(register_with) - - offset = self.register_images(image_t, image_r, psize_t, psize_r) - return offset + return self.register_images( + image_t, image_r, psize_t, psize_r, + registration_algorithm_kwargs=self.registration_algorithm_kwargs + ) def get_registration_inputs( self, @@ -329,12 +346,12 @@ def register_images_llm(self, image_t: np.ndarray, image_r: np.ndarray) -> np.nd ], dtype=float) def register_images( - self, - image_t: np.ndarray, - image_r: np.ndarray, - psize_t: float, + self, + image_t: np.ndarray, + image_r: np.ndarray, + psize_t: float, psize_r: float, - registration_method: Optional[Literal["phase_correlation", "sift", "mutual_information", "llm"]] = None, + registration_method: Optional[Literal["phase_correlation", "sift", "mutual_information", "llm", "error_minimization"]] = None, registration_algorithm_kwargs: Optional[dict[str, Any]] = None, ) -> np.ndarray | Tuple[np.ndarray, float] | str: """ @@ -369,6 +386,11 @@ def register_images( - `max_iter` (int, default: `60`) - `tol` (float, default: `1e-4`) + - `registration_method="error_minimization"`: + - `y_valid_fraction` (float, default: `0.8`) + - `x_valid_fraction` (float, default: `0.8`) + - `subpixel` (bool, default: `True`) + - `registration_method="sift"` or `"llm"`: - No algorithm kwargs are currently supported; pass `None` or `{}`. 
@@ -387,12 +409,15 @@ def register_images( if psize_t != psize_r: # Resize the target image to have the same pixel size as the reference image image_t = ndi.zoom(image_t, psize_t / psize_r) - - if method in {"phase_correlation", "mutual_information"}: + + image_t = self.zoom_image(image_t) + image_r = self.zoom_image(image_r) + + if method in {"phase_correlation", "mutual_information", "error_minimization"}: image_t = self.reconcile_image_shape(image_t, image_r.shape) if method == "phase_correlation": - phase_kwargs = {"use_hanning_window": True} + phase_kwargs = {"filtering_method": "hanning"} phase_kwargs.update(algorithm_kwargs) offset = phase_cross_correlation( image_t, @@ -414,6 +439,10 @@ def register_images( ref=image_r, **mi_kwargs, ) + elif method == "error_minimization": + em_kwargs = {"y_valid_fraction": 0.8, "x_valid_fraction": 0.8, "subpixel": True} + em_kwargs.update(algorithm_kwargs) + offset = error_minimization_registration(image_t, image_r, **em_kwargs) elif method == "sift": if len(algorithm_kwargs) > 0: raise ValueError( @@ -430,7 +459,7 @@ def register_images( offset = self.register_images_llm(image_t=image_t, image_r=image_r) else: raise ValueError(f"Invalid registration method: {method}") - return offset + return np.array(offset, dtype=float) / self.zoom def reconcile_image_shape( self, diff --git a/tests/test_analytical_feature_tracking.py b/tests/test_analytical_feature_tracking.py index 32de788..c48aa85 100644 --- a/tests/test_analytical_feature_tracking.py +++ b/tests/test_analytical_feature_tracking.py @@ -1,6 +1,5 @@ import os import argparse -import pytest import numpy as np import tifffile diff --git a/tests/test_analytical_focusing.py b/tests/test_analytical_focusing.py index eae5c89..92c67fe 100644 --- a/tests/test_analytical_focusing.py +++ b/tests/test_analytical_focusing.py @@ -3,20 +3,28 @@ import numpy as np import tifffile -from matplotlib.figure import Figure from eaa.task_manager.tuning.analytical_focusing import ( 
AnalyticalScanningMicroscopeFocusingTaskManager, ) from eaa.tool.imaging.acquisition import SimulatedAcquireImage from eaa.tool.imaging.param_tuning import SimulatedSetParameters -from sciagent.tool.base import BaseTool +from eaa.tool.imaging.registration import ImageRegistration import test_utils as tutils +class DummyRegistrationTool: + def __init__(self, name: str, offset=(0.0, 0.0)): + self.name = name + self.offset = np.array(offset, dtype=float) + + def get_offset(self, target="previous"): + return self.offset.tolist() + + class TestAnalyticalFocusing(tutils.BaseTester): - def _build_task_manager(self): + def _build_task_manager(self, registration_tools=None): image_path = os.path.join( self.get_ci_input_data_dir(), "simulated_images", @@ -42,12 +50,17 @@ def _build_task_manager(self): parameter_ranges=[(0.0,), (10.0,)], drift_factor=10, ) + if registration_tools is None: + registration_tools = [ + ImageRegistration(image_acquisition_tool=acquisition_tool) + ] task_manager = AnalyticalScanningMicroscopeFocusingTaskManager( param_setting_tool=param_setting_tool, acquisition_tool=acquisition_tool, initial_parameters={"z": 10.0}, parameter_ranges=[(0.0,), (10.0,)], + registration_tools=registration_tools, line_scan_tool_x_coordinate_args=("x_center",), line_scan_tool_y_coordinate_args=("y_center",), image_acquisition_tool_x_coordinate_args=("x_center",), @@ -104,46 +117,38 @@ def test_task_manager_runs_without_offset_calibration(self, monkeypatch): ) assert acquisition_tool.counter_acquire_image == 0 - def test_linear_drift_prediction_fit_and_apply(self): - task_manager, _ = self._build_task_manager() - task_manager.use_linear_drift_prediction = True - task_manager.n_parameter_drift_points_before_prediction = 3 - task_manager.initial_image_acquisition_position = np.array([10.0, 20.0], dtype=float) - task_manager.image_acquisition_kwargs = { - "y_center": 0.0, - "x_center": 0.0, - "size_y": 128, - "size_x": 128, - } - task_manager.line_scan_kwargs = { - 
"x_center": 1.0, - "y_center": 2.0, - "length": 10.0, - "scan_step": 1.0, - } - - # Drift model: delta_y = 2 * z + 1, delta_x = -z + 3 + def test_select_drift_uses_linear_model_after_priming(self): + registration_tools = [ + DummyRegistrationTool("primary"), + DummyRegistrationTool("secondary"), + ] + task_manager, _ = self._build_task_manager( + registration_tools=registration_tools + ) + task_manager.registration_selection_priming_iterations = 3 + task_manager.initial_line_scan_position = np.array([10.0, 20.0], dtype=float) + for z in [0.0, 1.0, 2.0]: drift = np.array([2.0 * z + 1.0, -z + 3.0], dtype=float) - current_position = task_manager.initial_image_acquisition_position + drift + current_position = task_manager.initial_line_scan_position + drift task_manager.update_linear_drift_models( parameters=np.array([z], dtype=float), current_position_yx=current_position, ) - assert task_manager.should_apply_linear_drift_prediction() - task_manager.apply_predicted_image_acquisition_position(np.array([4.0], dtype=float)) + chosen_drift, chosen_source = task_manager._select_drift( + candidate_drifts={ + "primary": np.array([0.0, 0.0], dtype=float), + "secondary": np.array([9.0, -1.0], dtype=float), + }, + x_current=np.array([4.0], dtype=float), + ) - assert np.isclose(task_manager.image_acquisition_kwargs["y_center"], 19.0) - assert np.isclose(task_manager.image_acquisition_kwargs["x_center"], 19.0) - assert np.isclose(task_manager.line_scan_kwargs["y_center"], 21.0) - assert np.isclose(task_manager.line_scan_kwargs["x_center"], 20.0) + assert chosen_source == "secondary" + assert np.allclose(chosen_drift, np.array([9.0, -1.0], dtype=float)) def test_run_iteration_applies_registration_offset_and_updates_model(self, monkeypatch): task_manager, acquisition_tool = self._build_task_manager() - task_manager.use_linear_drift_prediction = True - task_manager.n_parameter_drift_points_before_prediction = 1 - task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], 
dtype=float) task_manager.initialize_kwargs_buffers( initial_line_scan_kwargs={ "x_center": 160.0, @@ -153,10 +158,7 @@ def test_run_iteration_applies_registration_offset_and_updates_model(self, monke }, initial_2d_scan_kwargs={"y_center": 0.0, "x_center": 0.0, "size_y": 200, "size_x": 200}, ) - task_manager.update_linear_drift_models( - parameters=np.array([0.0], dtype=float), - current_position_yx=np.array([0.0, 0.0], dtype=float), - ) + task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], dtype=float) def fake_run_2d_scan(): kwargs = task_manager.image_acquisition_kwargs @@ -172,8 +174,11 @@ def fake_run_2d_scan(): monkeypatch.setattr(task_manager, "run_2d_scan", fake_run_2d_scan) monkeypatch.setattr( task_manager, - "find_offset", - lambda: (np.array([0.0, 0.0]), np.array([100.0, -50.0])), + "find_position_correction", + lambda registration_tool, target: ( + np.array([0.0, 0.0]), + np.array([100.0, -50.0]), + ), ) monkeypatch.setattr(task_manager, "run_line_scan", lambda: 1.0) monkeypatch.setattr(task_manager, "update_optimization_model", lambda fwhm: None) @@ -200,12 +205,13 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke }, initial_2d_scan_kwargs={"y_center": 0.0, "x_center": 0.0, "size_y": 200, "size_x": 200}, ) + task_manager.initial_image_acquisition_position = np.array([0.0, 0.0], dtype=float) monkeypatch.setattr(task_manager, "run_2d_scan", lambda: None) monkeypatch.setattr( task_manager, - "find_offset", - lambda: (np.array([0.0, 0.0]), np.array([0.0, 0.0])), + "apply_drift_correction", + lambda x_current: np.array([0.0, 0.0], dtype=float), ) monkeypatch.setattr(task_manager, "run_line_scan", lambda: 1.0) monkeypatch.setattr(task_manager, "update_optimization_model", lambda fwhm: None) @@ -214,7 +220,7 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke monkeypatch.setattr( task_manager, "update_linear_drift_models", - lambda parameters: call_order.append("update"), + 
lambda parameters, current_position_yx=None: call_order.append("update"), ) monkeypatch.setattr( task_manager, @@ -226,6 +232,71 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke assert call_order == ["update", "visualize"] + def test_registration_and_scan_position_corrections_across_two_iterations(self): + registration_tool = DummyRegistrationTool("dummy", offset=(2.5, -4.0)) + task_manager, acquisition_tool = self._build_task_manager( + registration_tools=[registration_tool] + ) + task_manager.initialize_kwargs_buffers( + initial_line_scan_kwargs={ + "x_center": 160.0, + "y_center": 170.0, + "length": 60.0, + "scan_step": 1.0, + }, + initial_2d_scan_kwargs={ + "y_center": 175.0, + "x_center": 175.0, + "size_y": 350, + "size_x": 350, + }, + ) + + task_manager.run_2d_scan() + task_manager.run_line_scan() + + task_manager.param_setting_tool.set_parameters(np.array([9.0], dtype=float)) + task_manager.image_acquisition_kwargs["y_center"] += 7.0 + task_manager.image_acquisition_kwargs["x_center"] -= 6.0 + task_manager.run_2d_scan() + + line_scan_correction, image_acquisition_correction = ( + task_manager.find_position_correction(registration_tool, target="previous") + ) + assert np.allclose(line_scan_correction, np.array([4.5, -2.0], dtype=float)) + assert np.allclose( + image_acquisition_correction, + np.array([-2.5, 4.0], dtype=float), + ) + + chosen_line_scan_correction = task_manager.apply_drift_correction( + np.array([9.0], dtype=float) + ) + assert np.allclose( + chosen_line_scan_correction, + np.array([4.5, -2.0], dtype=float), + ) + assert np.allclose( + task_manager.extract_line_scan_position(task_manager.line_scan_kwargs), + np.array([174.5, 158.0], dtype=float), + ) + assert np.allclose( + task_manager.extract_image_acquisition_position( + task_manager.image_acquisition_kwargs + ), + np.array([179.5, 173.0], dtype=float), + ) + + task_manager.run_line_scan() + assert len(acquisition_tool.line_scan_call_history) == 2 + assert 
np.allclose( + [ + acquisition_tool.line_scan_call_history[-1]["y_center"], + acquisition_tool.line_scan_call_history[-1]["x_center"], + ], + np.array([174.5, 158.0], dtype=float), + ) + if __name__ == "__main__": parser = argparse.ArgumentParser() @@ -241,3 +312,4 @@ def test_run_tuning_iteration_calls_drift_visualization_after_update(self, monke ) tester.test_task_manager_runs() tester.test_task_manager_runs_without_offset_calibration() + tester.test_registration_and_scan_position_corrections_across_two_iterations() diff --git a/tests/test_image_registration_tool.py b/tests/test_image_registration_tool.py index 81509f1..c7f315d 100644 --- a/tests/test_image_registration_tool.py +++ b/tests/test_image_registration_tool.py @@ -33,7 +33,7 @@ def test_image_registration(self): y_center=164, x_center=164, size_y=128, size_x=128 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset: ", offset) @@ -69,7 +69,7 @@ def test_image_registration_diff_size(self): y_center=175, x_center=175, size_y=150, size_x=150 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset: ", offset) @@ -108,7 +108,7 @@ def test_image_registration_mutual_information(self): y_center=164, x_center=164, size_y=128, size_x=128 ) - offset = registration_tool.get_offset_of_latest_image(register_with="previous") + offset = registration_tool.get_offset(target="previous") if self.debug: print("Offset (mutual information): ", offset)