Source code for pyflowreg.motion_correction.flow_reg_live

"""
FlowRegLive - Real-time motion compensation with adaptive reference.

Implements fast online registration with robust preregistered reference
and temporal filtering using half kernels for speed.
"""

from collections import deque
from typing import Optional, Tuple
import numpy as np

from pyflowreg.motion_correction.OF_options import OFOptions, QualitySetting
from pyflowreg.motion_correction.compensate_arr import compensate_arr
from pyflowreg.core.warping import imregister_wrapper
from pyflowreg.util.image_processing import (
    normalize,
    apply_gaussian_filter,
    gaussian_filter_1d_half_kernel,
)


[docs] class FlowRegLive: """ Real-time motion compensation with adaptive reference updating. Optimized for speed using: - Fast quality setting for optical flow - Half kernel temporal filtering - Circular buffer for temporal history - Weighted reference updates every 20 frames """
[docs] def __init__( self, options: Optional[OFOptions] = None, reference_buffer_size: int = 50, reference_update_interval: int = 20, reference_update_weight: float = 0.2, truncate: float = 4.0, **kwargs, ): """ Initialize FlowRegLive for real-time motion compensation. Args: options: OFOptions for configuration. If None, uses fast defaults. reference_buffer_size: Size of buffer for reference initialization reference_update_interval: Update reference every N frames reference_update_weight: Weight for mixing new frames into reference truncate: Truncate filter at this many standard deviations **kwargs: Additional parameters to override OFOptions """ # Create options with fast defaults for real-time if options is None: options = OFOptions( quality_setting=QualitySetting.FAST, **kwargs, ) else: # Copy and override for speed options = options.model_copy( update={ "quality_setting": QualitySetting.FAST, **kwargs, } ) self.options = options self.truncate = truncate # Resolve displacement function from options self._get_disp = self.options.resolve_get_displacement() # Reference buffer self.reference_buffer = deque(maxlen=reference_buffer_size) # Get sigma values and determine temporal buffer size self.sigma = np.asarray(self.options.sigma) if self.sigma.ndim == 1: self.sigma_2d = self.sigma[:2] # [sy, sx] self.sigma_t = self.sigma[2] if len(self.sigma) > 2 else 0.0 else: # Per-channel, find max temporal sigma self.sigma_2d = self.sigma[0, :2] # Use first channel for 2D # Get max temporal sigma across all channels self.sigma_t = np.max(self.sigma[:, 2]) if self.sigma.shape[1] > 2 else 0.0 # Calculate temporal buffer size from sigma_t # Buffer size = kernel radius + 1, where radius = truncate * sigma_t temporal_buffer_size = max(1, int(self.truncate * self.sigma_t + 0.5) + 1) self.temporal_buffer = deque(maxlen=temporal_buffer_size) # Reference management self.reference_raw = ( None # Original reference (for warping unfiltered frames if needed) ) self.reference_proc = ( None # Filtered reference (main reference for OF and normalization) ) self.reference_update_interval = reference_update_interval self.reference_update_weight = reference_update_weight self.frame_count = 0 # Flow initialization self.last_flow = None # Normalization parameters from filtered reference self.norm_min = None self.norm_max = None
[docs] def set_reference(self, frames: Optional[np.ndarray] = None): """ Initialize reference from frames or buffer. Args: frames: Optional array of frames (T,H,W,C). If None, uses buffer. """ if frames is not None: # Use provided frames if frames.ndim == 2: # Single 2D frame (H,W) - convert to 3D self.reference_raw = frames[..., np.newaxis].copy() elif frames.ndim == 3: # Single frame (H,W,C) self.reference_raw = frames.copy() else: # Multiple frames - compensate and average print(f"Preregistering {frames.shape[0]} frames for reference...") # Use balanced quality for reference ref_options = self.options.model_copy( update={"quality_setting": QualitySetting.BALANCED} ) # Take mean as initial reference initial_ref = np.mean(frames, axis=0) # Compensate frames registered, _ = compensate_arr(frames, initial_ref, ref_options) # Average compensated frames as reference self.reference_raw = np.mean(registered, axis=0) print("Reference initialized from compensated frames") else: # Use buffer if len(self.reference_buffer) == 0: raise ValueError( "Reference buffer is empty, cannot initialize reference" ) frames = np.array(list(self.reference_buffer)) self.set_reference(frames) return # Ensure 3D reference (H,W,C) if self.reference_raw.ndim == 2: self.reference_raw = self.reference_raw[..., np.newaxis] # Preprocess reference (normalize and filter) self._preprocess_reference() # Store normalization parameters from FILTERED reference if self.options.channel_normalization.value == "separate": self.norm_min = [] self.norm_max = [] for c in range(self.reference_proc.shape[2]): self.norm_min.append(self.reference_proc[..., c].min()) self.norm_max.append(self.reference_proc[..., c].max()) else: self.norm_min = self.reference_proc.min() self.norm_max = self.reference_proc.max() # Reset flow initialization self.last_flow = None # Clear temporal buffer self.temporal_buffer.clear()
def _preprocess_reference(self): """Preprocess reference frame (normalize and filter).""" # First normalize ref_norm = normalize( self.reference_raw, channel_normalization=self.options.channel_normalization.value, ) # Apply 2D Gaussian filter self.reference_proc = apply_gaussian_filter( ref_norm, self.sigma_2d, mode="reflect", truncate=self.truncate ) def _normalize_with_reference(self, frame: np.ndarray) -> np.ndarray: """Normalize frame using filtered reference min/max values.""" if self.norm_min is None or self.norm_max is None: # Fallback to frame's own normalization return normalize( frame, channel_normalization=self.options.channel_normalization.value ) eps = 1e-8 if self.options.channel_normalization.value == "separate": result = np.zeros_like(frame, dtype=np.float64) for c in range(frame.shape[-1]): if c < len(self.norm_min): result[..., c] = (frame[..., c] - self.norm_min[c]) / ( self.norm_max[c] - self.norm_min[c] + eps ) else: # Fallback for extra channels ch = frame[..., c] result[..., c] = (ch - ch.min()) / (ch.max() - ch.min() + eps) return result else: return (frame - self.norm_min) / (self.norm_max - self.norm_min + eps) def __call__( self, frame: np.ndarray, normalize: bool = True ) -> Tuple[np.ndarray, np.ndarray]: """ Process single frame with motion compensation. Args: frame: Input frame (H,W,C) or (H,W) normalize: Whether to normalize (unused, always normalizes) Returns: Tuple of (registered_frame, flow_field) with shapes (H,W,C) and (H,W,2) """ if self.reference_raw is None: # Add to buffer and return original self.reference_buffer.append(frame.copy()) return frame, np.zeros(frame.shape[:2] + (2,), dtype=np.float32) # Ensure 3D frame if frame.ndim == 2: frame = frame[..., np.newaxis] self.frame_count += 1 # Normalize using filtered reference parameters frame_norm = self._normalize_with_reference(frame) # Apply 2D Gaussian filter frame_2d_filtered = apply_gaussian_filter( frame_norm, self.sigma_2d, mode="reflect", truncate=self.truncate ) # Add to temporal buffer self.temporal_buffer.append(frame_2d_filtered) # Apply temporal filtering with half kernel if self.sigma_t > 0 and len(self.temporal_buffer) > 1: frame_proc = gaussian_filter_1d_half_kernel( self.temporal_buffer, self.sigma_t, truncate=self.truncate ) else: frame_proc = frame_2d_filtered # Compute optical flow with initialization flow_params = self.options.to_dict() # Ensure weight is numpy array if "weight" in flow_params and isinstance(flow_params["weight"], list): flow_params["weight"] = np.array(flow_params["weight"]) flow = self._get_disp( self.reference_proc, frame_proc, uv=self.last_flow, # Use previous flow as initialization **flow_params, ) # Store flow for next frame self.last_flow = flow.copy() # Apply warping to original frame registered = imregister_wrapper( frame, flow[..., 0], flow[..., 1], self.reference_raw, interpolation_method=self.options.interpolation_method.value, ) # Ensure registered maintains 3D shape if registered.ndim == 2: registered = registered[..., np.newaxis] # Update reference every N frames if self.frame_count % self.reference_update_interval == 0: # Warp the processed frame and mix into processed reference registered_proc = imregister_wrapper( frame_proc, flow[..., 0], flow[..., 1], self.reference_proc, interpolation_method=self.options.interpolation_method.value, ) # Ensure 3D shape for consistency if registered_proc.ndim == 2: registered_proc = registered_proc[..., np.newaxis] # Mix into processed reference self.reference_proc = ( 1 - self.reference_update_weight ) * self.reference_proc + self.reference_update_weight * registered_proc # Also update raw reference for consistency self.reference_raw = ( 1 - self.reference_update_weight ) * self.reference_raw + self.reference_update_weight * registered # Update normalization parameters from new filtered reference if self.options.channel_normalization.value == "separate": self.norm_min = [] self.norm_max = [] for c in range(self.reference_proc.shape[2]): self.norm_min.append(self.reference_proc[..., c].min()) self.norm_max.append(self.reference_proc[..., c].max()) else: self.norm_min = self.reference_proc.min() self.norm_max = self.reference_proc.max() return registered, flow
[docs] def register_frames(self, frames: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: """ Register multiple frames in sequence. Args: frames: Array of frames (T,H,W,C) Returns: Tuple of (registered_frames, flow_fields) """ registered = np.empty_like(frames) flows = np.empty(frames.shape[:3] + (2,), dtype=np.float32) for i in range(frames.shape[0]): registered[i], flows[i] = self(frames[i], False) return registered, flows
[docs] def reset_reference(self, new_reference: np.ndarray): """ Set new reference and reset flow. Args: new_reference: New reference frame (H,W,C) """ self.set_reference(new_reference)
[docs] def get_current_flow(self) -> Optional[np.ndarray]: """Get current flow field state.""" return self.last_flow.copy() if self.last_flow is not None else None
[docs] def set_flow_init(self, w_init: np.ndarray): """Set flow field initialization for next frame.""" self.last_flow = w_init.copy()