# Source code for pyflowreg.core.diso_optical_flow

"""
Dense Inverse Search Optical Flow (DIS) backend using OpenCV.
Provides an alternative optical flow computation method to the variational approach.
"""

import cv2
import numpy as np
from typing import Optional, Dict, Any


class DisoOF:
    """
    Dense Inverse Search Optical Flow (DIS) implementation using OpenCV.

    This class provides a callable interface compatible with get_displacement
    for computing optical flow between fixed and moving images.

    The OpenCV DIS object is created lazily on first use. It is never included
    in the pickled/copied state, which keeps instances picklable for
    multiprocessing.

    Parameters
    ----------
    preset : int
        DIS preset mode:
        - cv2.DISOPTICAL_FLOW_PRESET_ULTRAFAST: Ultrafast preset
        - cv2.DISOPTICAL_FLOW_PRESET_FAST: Fast preset
        - cv2.DISOPTICAL_FLOW_PRESET_MEDIUM: Medium preset (default)
    finest_scale : int
        Finest scale for the image pyramid (0 = original scale; default: 2)
    gradient_descent_iterations : int
        Number of gradient descent iterations at each pyramid level
        (default: 12)
    patch_size : int
        Size of the patch for matching (default: 8)
    patch_stride : int
        Stride between neighbor patches (default: 4)
    use_mean_normalization : bool
        Whether to use mean normalization (default: True)
    use_spatial_propagation : bool
        Whether to use spatial propagation (default: True)
    """

    def __init__(
        self,
        preset: int = cv2.DISOPTICAL_FLOW_PRESET_MEDIUM,
        finest_scale: int = 2,
        gradient_descent_iterations: int = 12,
        patch_size: int = 8,
        patch_stride: int = 4,
        use_mean_normalization: bool = True,
        use_spatial_propagation: bool = True,
    ):
        self._cfg = dict(
            preset=preset,
            finest_scale=finest_scale,
            gradient_descent_iterations=gradient_descent_iterations,
            patch_size=patch_size,
            patch_stride=patch_stride,
            use_mean_normalization=use_mean_normalization,
            use_spatial_propagation=use_spatial_propagation,
        )
        # OpenCV object, built on demand by _ensure(); None = not built yet.
        self._dis = None

    def __getstate__(self):
        """Support pickling/copying for multiprocessing (config only)."""
        # Hand out a private copy of the config. Returning the live dict would
        # let copy.copy() clones share (and mutate) one configuration.
        return {"_cfg": dict(self._cfg), "_dis": None}

    def __setstate__(self, state):
        """Support unpickling for multiprocessing."""
        self._cfg = dict(state["_cfg"])
        self._dis = None

    def _ensure(self):
        """Lazy initialization of OpenCV DIS object."""
        if self._dis is not None:
            return
        d = cv2.DISOpticalFlow_create(self._cfg["preset"])
        d.setFinestScale(self._cfg["finest_scale"])
        d.setGradientDescentIterations(self._cfg["gradient_descent_iterations"])
        d.setPatchSize(self._cfg["patch_size"])
        d.setPatchStride(self._cfg["patch_stride"])
        d.setUseMeanNormalization(bool(self._cfg["use_mean_normalization"]))
        d.setUseSpatialPropagation(bool(self._cfg["use_spatial_propagation"]))
        self._dis = d

    def _to_gray(
        self, img: np.ndarray, weight: Optional[np.ndarray] = None
    ) -> np.ndarray:
        """
        Convert image to grayscale using weights if provided.

        Parameters
        ----------
        img : np.ndarray
            Input image of shape (H, W) or (H, W, C)
        weight : np.ndarray, optional
            Weights, interpreted by dimensionality:
            - 1D array of C channel weights (normalized to sum to 1)
            - 2D array (H, W) of spatial weights, applied to every channel
              before averaging over channels
            - 3D array (H, W, C) of full spatial-channel weights, used as is
            Missing, mis-shaped or zero-sum weights fall back to equal
            channel weights.

        Returns
        -------
        np.ndarray
            Grayscale image of shape (H, W). For uint8 input the result is
            uint8 as well (same [0, 255] scale). Otherwise it is floating
            point on the input's scale.

        Raises
        ------
        ValueError
            If img is neither 2D nor 3D.
        """
        if img.ndim == 2:
            return img
        if img.ndim != 3:
            raise ValueError(f"Unexpected image shape: {img.shape}")
        if img.shape[2] == 1:
            return img[:, :, 0]

        gray = self._combine_channels(img, weight)

        # Channel weighting yields floats on the input's own scale. For uint8
        # input that scale is [0, 255], so convert back to uint8. Otherwise
        # _normalize would treat the result as [0, 1] data and clip nearly
        # every pixel to white.
        if img.dtype == np.uint8:
            return np.clip(np.rint(gray), 0, 255).astype(np.uint8)
        return gray

    @staticmethod
    def _combine_channels(
        img: np.ndarray, weight: Optional[np.ndarray]
    ) -> np.ndarray:
        """Collapse the channel axis of an (H, W, C) image (C > 1) to (H, W)."""
        n_channels = img.shape[2]
        # Equal weights: the default and the fallback for unusable weights.
        channel_weights = np.full(n_channels, 1.0 / n_channels)

        if weight is not None:
            # float64 avoids integer overflow (e.g. uint8 masks times uint8 data).
            weight = np.asarray(weight, dtype=np.float64)
            if weight.ndim == 1:
                total = weight.sum() if weight.shape[0] == n_channels else 0.0
                # A zero sum would divide by zero; keep equal weights instead.
                if total != 0:
                    channel_weights = weight / total
            elif weight.ndim == 2:
                if weight.shape == img.shape[:2]:
                    # Apply spatial weights equally to all channels, then average
                    return np.mean(img * weight[:, :, np.newaxis], axis=2)
            elif weight.ndim == 3:
                if weight.shape == img.shape:
                    return np.sum(img * weight, axis=2)

        # Broadcasting over the trailing channel axis avoids materialising an
        # (H, W, C) weight array.
        return np.sum(img * channel_weights, axis=2)

    def _normalize(self, a: np.ndarray, b: np.ndarray) -> tuple:
        """
        Convert images to uint8 [0,255] range for OpenCV.

        uint8 images are passed through unchanged. Any other dtype is assumed
        to be in [0, 1]; it is clipped to that range and quantized to uint8.
        Each image is handled independently, so uint8 and float inputs may be
        mixed.

        Parameters
        ----------
        a : np.ndarray
            First image
        b : np.ndarray
            Second image

        Returns
        -------
        tuple
            Both images converted to uint8 [0,255] range
        """
        return self._to_uint8(a), self._to_uint8(b)

    @staticmethod
    def _to_uint8(img: np.ndarray) -> np.ndarray:
        """Map a single uint8 or [0, 1]-scaled image to uint8 [0, 255]."""
        if img.dtype == np.uint8:
            # Already in OpenCV's range. Avoid a lossy float round trip.
            return img
        # Clip absorbs slight overflow from preprocessing. Rounding (instead of
        # truncation) gives unbiased quantization.
        return np.rint(np.clip(img, 0, 1) * 255).astype(np.uint8)

    def __call__(
        self,
        fixed: np.ndarray,
        moving: np.ndarray,
        w: Optional[np.ndarray] = None,
        weight: Optional[np.ndarray] = None,
        **kwargs,
    ) -> np.ndarray:
        """
        Compute optical flow between fixed and moving images.

        Parameters
        ----------
        fixed : np.ndarray
            Reference/fixed image of shape (H, W) or (H, W, C)
            Expected to be normalized to [0,1] range or uint8 [0,255]
        moving : np.ndarray
            Moving image of shape (H, W) or (H, W, C)
            Expected to be normalized to [0,1] range or uint8 [0,255]
        w : np.ndarray, optional
            Initial flow field of shape (H, W, 2) for warm start. It is used
            only when its spatial shape matches the images, and it is never
            modified.
        weight : np.ndarray, optional
            Channel weights for multi-channel images (see _to_gray)
        **kwargs : dict
            Additional parameters (for compatibility, not used)

        Returns
        -------
        np.ndarray
            Displacement field of shape (H, W, 2) as float32
        """
        self._ensure()

        # Convert to grayscale using weights
        a = self._to_gray(fixed, weight)
        b = self._to_gray(moving, weight)

        # Normalize to [0,255] uint8
        A, B = self._normalize(a, b)

        # Prepare initial flow if provided
        init = None
        if (
            isinstance(w, np.ndarray)
            and w.ndim == 3
            and w.shape[2] == 2
            and w.shape[:2] == A.shape
        ):
            # DIS.calc uses its flow argument as in/out storage and writes the
            # result into it. Always pass a private C-contiguous float32 copy.
            # This keeps the caller's array intact and accepts non-contiguous
            # input.
            init = np.array(w, dtype=np.float32, order="C")

        # Compute optical flow
        flow = self._dis.calc(A, B, init)

        # Return as float32
        return flow.astype(np.float32, copy=False)

    def set_preset(self, preset: int):
        """
        Update the DIS preset configuration.

        The OpenCV object is rebuilt lazily on the next call.

        Parameters
        ----------
        preset : int
            One of cv2.DISOPTICAL_FLOW_PRESET_*
        """
        cfg = dict(self._cfg)
        cfg["preset"] = preset
        self._cfg = cfg
        self._dis = None

    def get_params(self) -> Dict[str, Any]:
        """
        Get current DIS parameters.

        Returns
        -------
        dict
            Copy of the current parameter values (mutating it has no effect)
        """
        return dict(self._cfg)

    def set_params(self, **params):
        """
        Update DIS parameters.

        The OpenCV object is rebuilt lazily on the next call. Keys that
        _ensure does not read are stored but have no effect.

        Parameters
        ----------
        **params : dict
            Parameters to update
        """
        # Copy-on-write (as in set_preset) so no other holder of the old dict,
        # e.g. a shallow copy of this instance, is affected.
        cfg = dict(self._cfg)
        cfg.update(params)
        self._cfg = cfg
        self._dis = None
def _diso_factory(**kwargs):
    """
    Build a new DisoOF backend for use in worker processes.

    DisoOF defers creating its OpenCV object until first use, so the
    returned instance can be pickled and shipped to another process.

    Parameters
    ----------
    **kwargs : dict
        Keyword arguments forwarded unchanged to the DisoOF constructor

    Returns
    -------
    DisoOF
        A freshly constructed, not-yet-initialized DisoOF instance
    """
    flow_backend = DisoOF(**kwargs)
    return flow_backend