# Source code for pyflowreg.core.warping

"""
Utility functions for image warping and valid mask computation.

Provides helpers for computing backward warping validity masks and
binary mask warping with nearest-neighbor interpolation.
"""

import cv2
import numpy as np
from multiprocessing import Pool, cpu_count
from typing import Optional


def backward_valid_mask(u, v):
    """
    Flag the output pixels whose backward-warp source lies inside the image.

    Every output pixel (y, x) samples the input at (y + v, x + u). The
    pixel is marked valid when that sampling location falls inside
    [0, H) x [0, W), where (H, W) is taken from ``u.shape``.

    Parameters
    ----------
    u : ndarray, shape (H, W)
        Horizontal displacement field (x-direction)
    v : ndarray, shape (H, W)
        Vertical displacement field (y-direction)

    Returns
    -------
    valid_mask : ndarray, shape (H, W), dtype=bool
        True where the sampled location is in-bounds, False otherwise

    Notes
    -----
    Mirrors MATLAB idx_warp computation from imregister_wrapper_w.
    NaN displacements compare False against every bound and are
    therefore reported as invalid.

    Examples
    --------
    >>> u = np.ones((10, 10)) * 2.5  # sample 2.5 pixels to the right
    >>> v = np.zeros((10, 10))
    >>> mask = backward_valid_mask(u, v)
    >>> # Columns 8 and 9 are False (their source is out of bounds)
    """
    n_rows, n_cols = u.shape

    # Integer (row, column) index of every output pixel
    row_idx, col_idx = np.indices((n_rows, n_cols))

    # Location in the input image that each output pixel samples
    src_x = col_idx + u
    src_y = row_idx + v

    # Test each axis separately, then require both to hold
    inside_x = np.logical_and(src_x >= 0, src_x < n_cols)
    inside_y = np.logical_and(src_y >= 0, src_y < n_rows)

    return np.logical_and(inside_x, inside_y).astype(bool)


def imregister_binary(mask, u, v):
    """
    Backward-warp a binary mask with nearest-neighbor sampling.

    The mask is resampled at (y + v, x + u) and the result is ANDed with
    the in-bounds mask, so pixels whose source location lies outside the
    image are always False.

    Parameters
    ----------
    mask : ndarray, shape (H, W)
        Mask to warp. It is cast to float32 for resampling and
        thresholded at 0.5 afterwards.
    u : ndarray, shape (H, W)
        Horizontal displacement field (x-direction)
    v : ndarray, shape (H, W)
        Vertical displacement field (y-direction)

    Returns
    -------
    warped_mask : ndarray, shape (H, W), dtype=bool
        Warped mask with out-of-bounds pixels set to False

    Notes
    -----
    Mirrors MATLAB mask warping in get_session_valid_index_v3:
        [reg_m, idx_warp] = imregister_wrapper_w(
            double(m), w, zeros(size(m)), 'nearest')
        aligned_valid_masks{i} = (reg_m > 0.5) & idx_warp

    cv2.INTER_NEAREST is used so that no intermediate (non-binary)
    values are produced by the interpolation.

    Examples
    --------
    >>> mask = np.ones((10, 10), dtype=bool)
    >>> u = np.ones((10, 10)) * 2.0  # sample 2 pixels to the right
    >>> v = np.zeros((10, 10))
    >>> warped = imregister_binary(mask, u, v)
    >>> # Right 2 columns will be False
    """
    n_rows, n_cols = mask.shape

    # Pixel index grids in (row, column) layout
    row_idx, col_idx = np.meshgrid(
        np.arange(n_rows), np.arange(n_cols), indexing="ij"
    )

    # Sampling location of every output pixel (computed once, reused below)
    src_x = col_idx + u
    src_y = row_idx + v

    # In-bounds test is done on the unclipped coordinates
    inside = (src_x >= 0) & (src_x < n_cols) & (src_y >= 0) & (src_y < n_rows)

    # cv2.remap wants float32 maps; clip so every lookup is a real pixel
    map_x = np.clip(src_x.astype(np.float32), 0, n_cols - 1)
    map_y = np.clip(src_y.astype(np.float32), 0, n_rows - 1)

    resampled = cv2.remap(
        mask.astype(np.float32),
        map_x,
        map_y,
        cv2.INTER_NEAREST,
        borderMode=cv2.BORDER_CONSTANT,
        borderValue=0,
    )

    # Keep a pixel only if it is set in the warped mask AND in-bounds
    return inside & (resampled > 0.5)


def compute_batch_valid_masks(w):
    """
    Build the per-frame validity masks for a stack of displacement fields.

    Parameters
    ----------
    w : ndarray, shape (T, H, W, 2)
        Batch of displacement fields where w[..., 0] = u, w[..., 1] = v

    Returns
    -------
    valid_masks : ndarray, shape (T, H, W), dtype=uint8
        One mask per frame; 255 where the backward warp stays in-bounds
        (as reported by ``backward_valid_mask``), 0 elsewhere

    Notes
    -----
    Used in BatchMotionCorrector to persist per-frame validity.
    Returns uint8 for efficient HDF5 storage.
    """
    n_frames = w.shape[0]
    height, width = w.shape[1:3]

    # Start from "all invalid" and switch on the in-bounds pixels per frame
    masks = np.zeros((n_frames, height, width), dtype=np.uint8)

    for frame_idx in range(n_frames):
        in_bounds = backward_valid_mask(
            w[frame_idx, ..., 0], w[frame_idx, ..., 1]
        )
        masks[frame_idx, in_bounds] = 255

    return masks


def imregister_wrapper(f2_level, u, v, f1_level, interpolation_method="cubic"):
    """
    Backward warp of moving image using displacement field.

    Performs backward registration by warping f2_level toward f1_level
    using displacement field (u, v) with bicubic or bilinear interpolation
    via cv2.remap. Out-of-bounds pixels are replaced with the corresponding
    pixels from f1_level.

    Parameters
    ----------
    f2_level : np.ndarray
        Moving image to warp, shape (H, W) or (H, W, C)
    u : np.ndarray
        Horizontal displacement field, shape (H, W)
    v : np.ndarray
        Vertical displacement field, shape (H, W)
    f1_level : np.ndarray
        Fixed (reference) image, shape (H, W) or (H, W, C). A 2-D
        reference is promoted to (H, W, 1) independently of the rank of
        f2_level.
    interpolation_method : str, default='cubic'
        Interpolation method: 'cubic' (bicubic) or 'linear' (bilinear),
        case-insensitive. Defaults to bicubic following Sun et al. best
        practices.

    Returns
    -------
    warped : np.ndarray, dtype=float32
        Backward-warped image, shape (H, W, C). A single-channel result is
        returned as (H, W), also when the input was given as (H, W, 1).

    Raises
    ------
    ValueError
        If interpolation_method is neither 'linear' nor 'cubic'.

    Notes
    -----
    The displacement convention is: warped_pos = original_pos + (u, v)

    Out-of-bounds regions use values from f1_level to maintain continuity.
    A pixel counts as out-of-bounds when its float32 sampling location is
    not inside [0, W) x [0, H); NaN displacements therefore also count as
    out-of-bounds.

    Bicubic interpolation is more accurate than bilinear for optical flow
    estimation and is the recommended default.

    References
    ----------
    .. [1] Sun, D., Roth, S., and Black, M. J. "Secrets of Optical Flow
       Estimation and Their Principles", CVPR 2010.
    """
    # Validate the cheap argument first so a typo fails before any array work
    method = interpolation_method.lower()
    if method == "cubic":
        interp = cv2.INTER_CUBIC
    elif method == "linear":
        interp = cv2.INTER_LINEAR
    else:
        raise ValueError("Unsupported interpolation method. Use 'linear' or 'cubic'.")

    # Work on (H, W, C) internally; promote each image on its own so a 2-D
    # reference can be combined with an (H, W, 1) moving image and vice versa
    if f2_level.ndim == 2:
        f2_level = f2_level[:, :, None]
    if f1_level.ndim == 2:
        f1_level = f1_level[:, :, None]

    H, W, C = f2_level.shape

    grid_y, grid_x = np.meshgrid(np.arange(H), np.arange(W), indexing="ij")

    # cv2.remap requires float32 coordinate maps
    map_x = (grid_x + u).astype(np.float32)
    map_y = (grid_y + v).astype(np.float32)

    # Negating the in-bounds test (instead of OR-ing the four violations)
    # makes NaN coordinates count as out-of-bounds as well
    in_bounds = (map_x >= 0) & (map_x < W) & (map_y >= 0) & (map_y < H)
    out_of_bounds = ~in_bounds

    # Clip so interpolation only ever reads real pixels; the affected
    # output pixels are overwritten from the reference below anyway
    map_x_clipped = np.clip(map_x, 0, W - 1).astype(np.float32, copy=False)
    map_y_clipped = np.clip(map_y, 0, H - 1).astype(np.float32, copy=False)

    warped = np.empty_like(f2_level, dtype=np.float32)
    for c in range(C):
        # Channels are remapped one at a time, each as a single-channel image
        warped[:, :, c] = cv2.remap(
            f2_level[:, :, c],
            map_x_clipped,
            map_y_clipped,
            interpolation=interp,
            borderMode=cv2.BORDER_REPLICATE,
        )
        # Fill pixels that sampled outside the image from the reference
        warped[:, :, c][out_of_bounds] = f1_level[:, :, c][out_of_bounds]

    if warped.shape[2] == 1:
        warped = warped[:, :, 0]

    return warped
def warpingDepth(eta, levels, m, n):
    """
    Calculate the pyramid depth for given dimensions and warping factor.

    Starting from dim = min(m, n), the size is multiplied by eta once per
    level. Levels are counted until the rounded size first drops below
    10 pixels (that level is included in the count) or until `levels`
    levels have been counted, whichever comes first.

    Parameters
    ----------
    eta : float
        Pyramid downsampling factor per level (0 < eta <= 1)
    levels : int
        Maximum number of levels to attempt
    m : int
        First dimension
    n : int
        Second dimension

    Returns
    -------
    warpingdepth : int
        The smallest i in 1..levels for which round(dim * eta^i) < 10,
        where dim = min(m, n); `levels` if no such i exists; 0 when
        `levels` <= 0.

    Notes
    -----
    The returned depth includes the first level that falls below the
    10-pixel threshold, so for dim = 100 and eta = 0.5 the result is 4
    (sizes 50, 25, 12, 6), not 3.

    When called from get_displacement with (m, m) for height and (n, n)
    for width, this enables independent pyramid depth computation per
    dimension, allowing narrow ROIs to achieve large displacements along
    their longer dimension without being limited by the shorter dimension.

    Examples
    --------
    >>> warpingDepth(0.5, 10, 100, 100)
    4
    >>> warpingDepth(0.5, 2, 100, 100)
    2
    """
    min_dim = min(m, n)
    warpingdepth = 0
    for _ in range(levels):
        # Count the level first: the level that crosses the threshold
        # is part of the returned depth
        warpingdepth += 1
        min_dim *= eta
        if round(min_dim) < 10:
            break
    return warpingdepth
def align_sequence(
    batch: np.ndarray,
    displacement: np.ndarray,
    reference: np.ndarray,
    interpolation_method: str = "cubic",
    n_workers: Optional[int] = None,
) -> np.ndarray:
    """
    Apply displacement field to align a batch of frames to a reference.

    Every frame of the batch is warped with the same displacement field via
    ``imregister_wrapper``. Frames are processed in a multiprocessing pool
    when more than one worker is used, sequentially otherwise.

    Parameters
    ----------
    batch : ndarray, shape (T, H, W, C) or (T, H, W)
        Batch of frames to align
    displacement : ndarray, shape (H, W, 2)
        Displacement field where displacement[..., 0] = u (horizontal),
        displacement[..., 1] = v (vertical)
    reference : ndarray, shape (H, W, C) or (H, W)
        Reference image used to fill out-of-bounds regions
    interpolation_method : str, default='cubic'
        Interpolation method: 'cubic' or 'linear'
    n_workers : int, optional
        Number of parallel workers. If None, uses cpu_count(). The value
        is always capped at the number of frames T; a resulting value of
        1 or less selects sequential processing.

    Returns
    -------
    aligned_batch : ndarray
        Aligned frames with same shape and dtype as input batch

    Notes
    -----
    For integer batch dtypes the interpolated values are rounded to the
    nearest integer and clipped to the dtype's range before the cast.
    A plain cast would truncate toward zero and, because bicubic
    interpolation can overshoot the input range, could wrap around
    (e.g. a slightly negative value turning into a large uint8 value).

    Examples
    --------
    >>> batch = np.random.rand(100, 512, 512, 1)
    >>> displacement = np.zeros((512, 512, 2))  # No displacement
    >>> reference = np.mean(batch, axis=0)
    >>> aligned = align_sequence(batch, displacement, reference)
    """
    from functools import partial

    # Track if we need to add/remove channel dimension
    added_channel_dim = False
    if batch.ndim == 3:
        batch = batch[..., np.newaxis]
        added_channel_dim = True

    if reference.ndim == 2:
        reference = reference[..., np.newaxis]

    T = batch.shape[0]

    u = displacement[..., 0]
    v = displacement[..., 1]

    # The warp runs on float64 inputs
    reference_f64 = reference.astype(np.float64, copy=False)

    # More workers than frames only adds process start-up cost
    if n_workers is None:
        n_workers = cpu_count()
    n_workers = min(n_workers, T)

    # Pool.map passes exactly one argument per task, so bind the arguments
    # that are shared by all frames up front (a partial is picklable)
    warp_func = partial(
        imregister_wrapper,
        u=u,
        v=v,
        f1_level=reference_f64,
        interpolation_method=interpolation_method,
    )

    if n_workers > 1:
        frames_f64 = [batch[t].astype(np.float64, copy=False) for t in range(T)]
        with Pool(processes=n_workers) as pool:
            warped_frames = pool.map(warp_func, frames_f64)
    else:
        # Sequential processing for small batches / single worker
        warped_frames = [
            warp_func(batch[t].astype(np.float64, copy=False)) for t in range(T)
        ]

    # Stack results and preserve dtype
    out_dtype = batch.dtype
    is_integer_output = np.issubdtype(out_dtype, np.integer)
    if is_integer_output:
        dtype_limits = np.iinfo(out_dtype)

    aligned_batch = np.empty_like(batch)
    for t, warped in enumerate(warped_frames):
        # imregister_wrapper drops a singleton channel axis; restore it
        if warped.ndim == 2:
            warped = warped[..., np.newaxis]
        if is_integer_output:
            # Round and saturate instead of truncating / wrapping around
            warped = np.clip(
                np.rint(warped.astype(np.float64, copy=False)),
                dtype_limits.min,
                dtype_limits.max,
            )
        aligned_batch[t] = warped.astype(out_dtype, copy=False)

    # Remove singleton channel dimension only if we added it
    if added_channel_dim and aligned_batch.shape[-1] == 1:
        aligned_batch = aligned_batch[..., 0]

    return aligned_batch