# Source code for pyflowreg.util.io._base

from abc import ABC, abstractmethod
from typing import Optional, Union, List, Tuple

import numpy as np


class VideoReader(ABC):
    """
    Abstract base class for all video file readers.

    Data is returned in (T, H, W, C) format:
    - T: Time/frames
    - H: Height
    - W: Width
    - C: Channels

    This format is optimal for OpenCV operations and can be easily
    converted to PyTorch format (T, C, H, W) when needed.

    Subclasses implement ``_initialize``, ``_read_raw_frames`` and ``close``.
    Initialization is lazy: ``_initialize`` runs on the first operation that
    needs file metadata (via ``_ensure_initialized``).
    """

    def __init__(self):
        # Core properties - set by _initialize()
        self.height: int = 0
        self.width: int = 0
        self.frame_count: int = 0
        self.n_channels: int = 0
        self.dtype: Optional[np.dtype] = None

        # Reader configuration
        self.buffer_size: int = 500  # binned frames returned per read_batch()
        self.bin_size: int = 1  # raw frames averaged into one output frame

        # State tracking
        self.current_frame: int = 0  # next raw frame read_batch() will read
        self._initialized: bool = False

    @abstractmethod
    def _initialize(self):
        """
        Initialize file-specific properties.

        Must set: height, width, frame_count, n_channels, dtype
        """
        pass

    @abstractmethod
    def _read_raw_frames(self, frame_indices: Union[slice, List[int]]) -> np.ndarray:
        """
        Read raw frames from the underlying file.

        Args:
            frame_indices: Either a slice object or list of 0-based indices

        Returns:
            Array with shape (T, H, W, C) containing raw frames
        """
        pass

    @abstractmethod
    def close(self):
        """Close file handles and clean up resources."""
        pass

    def _ensure_initialized(self):
        """Ensure the reader is initialized before operations."""
        if not self._initialized:
            self._initialize()
            self._initialized = True

    def _binned_count(self) -> int:
        """Number of binned frames; a trailing partial bin counts as one."""
        return (self.frame_count + self.bin_size - 1) // self.bin_size

    def _read_bin(self, bin_idx: int) -> np.ndarray:
        """
        Read and average the raw frames that make up one binned frame.

        Args:
            bin_idx: Non-negative, in-range index of the binned frame.

        Returns:
            Array with shape (1, H, W, C), cast back to the raw frames' dtype.
            A trailing partial bin is averaged over only the frames that exist.
        """
        frame_start = bin_idx * self.bin_size
        frame_end = min(frame_start + self.bin_size, self.frame_count)
        raw_frames = self._read_raw_frames(slice(frame_start, frame_end))
        return raw_frames.mean(axis=0, keepdims=True).astype(raw_frames.dtype)

    def _empty_result(self) -> np.ndarray:
        """Zero-length (0, H, W, C) array used for empty selections."""
        return np.empty(
            (0, self.height, self.width, self.n_channels), dtype=self.dtype
        )

    def bin_frames(self, frames: np.ndarray) -> np.ndarray:
        """
        Apply temporal binning to reduce frame count.

        The frame count is padded up to a multiple of ``bin_size`` by
        repeating the last frame (``np.pad(mode="edge")``). Each group of
        ``bin_size`` frames is then averaged and cast back to the input
        dtype. For integer dtypes that cast truncates the fractional part.

        NOTE(review): ``__getitem__`` averages a trailing partial bin over only
        the frames that exist. Edge padding here weights the last frame more
        heavily, so the final binned frame can differ between ``read_batch``
        and indexing. Confirm which behavior is intended.

        Args:
            frames: Input array with shape (T, H, W, C)

        Returns:
            Binned array with shape (ceil(T / bin_size), H, W, C)

        Raises:
            ValueError: If ``bin_size > 1`` and ``frames`` is not 4D.
        """
        if self.bin_size == 1:
            return frames

        input_dtype = frames.dtype
        if frames.ndim != 4:
            raise ValueError(f"Expected 4D array (T, H, W, C), got {frames.ndim}D")

        T, H, W, C = frames.shape

        # Pad to make divisible by bin_size
        pad = (-T) % self.bin_size
        if pad:
            frames = np.pad(frames, [(0, pad), (0, 0), (0, 0), (0, 0)], mode="edge")
            T = frames.shape[0]

        # Reshape and average
        frames = frames.reshape(T // self.bin_size, self.bin_size, H, W, C)
        frames = frames.mean(axis=1).astype(input_dtype)

        return frames

    def __getitem__(
        self, key: Union[int, np.integer, slice, List[int], np.ndarray, Tuple]
    ) -> np.ndarray:
        """
        Array-like indexing with automatic binning.

        With bin_size > 1, indices refer to binned frames:
        - reader[0] returns average of first bin_size frames
        - reader[1] returns average of next bin_size frames

        Supported keys: Python/NumPy integers, slices (any step, including
        negative), lists/arrays of integers, and tuples whose first element
        selects frames and whose remaining elements index (H, W, C).

        Returns:
            Single frame: (H, W, C)
            Multiple frames: (T, H, W, C)
            Tuple key with an integer frame index: a length-1 T axis is kept
            before the remaining keys are applied.

        Raises:
            IndexError: If a frame index is out of range.
            TypeError: For unsupported key types.
        """
        self._ensure_initialized()
        binned_count = self._binned_count()

        # Single integer (NumPy integer scalars, e.g. from iterating an
        # index array, are accepted as well)
        if isinstance(key, (int, np.integer)):
            idx = int(key)
            if idx < 0:
                idx += binned_count
            if idx < 0 or idx >= binned_count:
                raise IndexError(
                    f"Index {idx} out of range for {binned_count} binned frames"
                )
            return self._read_bin(idx)[0]

        # Slice: let range() handle the step sign so reversed slices work
        if isinstance(key, slice):
            bins = range(*key.indices(binned_count))
            if len(bins) == 0:
                return self._empty_result()
            return np.concatenate([self._read_bin(b) for b in bins], axis=0)

        # List or numpy array (fancy indexing)
        if isinstance(key, (list, np.ndarray)):
            indices = np.asarray(key, dtype=np.int64)
            # Handle negative indices
            indices = np.where(indices < 0, binned_count + indices, indices)
            if np.any(indices < 0) or np.any(indices >= binned_count):
                raise IndexError(f"Index out of range for {binned_count} binned frames")
            if indices.size == 0:
                # np.concatenate([]) would raise; mirror the empty-slice result
                return self._empty_result()
            # Return (T, H, W, C) for consistency with slice indexing
            return np.concatenate(
                [self._read_bin(int(idx)) for idx in indices], axis=0
            )

        # Tuple for advanced indexing: first element selects frames
        if isinstance(key, tuple):
            frame_key, *rest = key
            frames = self[frame_key]
            if isinstance(frame_key, (int, np.integer)):
                frames = frames[np.newaxis, ...]  # Add T dimension back

            if rest:
                if frames.ndim == 4:
                    # Has time dimension: keep every selected frame
                    full_key = (slice(None),) + tuple(rest)
                else:
                    # No time dimension (e.g. from a nested tuple key)
                    full_key = tuple(rest)
                frames = frames[full_key]

            return frames

        raise TypeError(f"Invalid index type: {type(key)}")

    def read_batch(self) -> Optional[np.ndarray]:
        """
        Read next batch of frames with binning.

        Reads up to ``buffer_size * bin_size`` raw frames starting at
        ``current_frame``, advances ``current_frame``, and bins the result
        with ``bin_frames``.

        Returns:
            Array with shape (T, H, W, C) or None if no more frames
        """
        self._ensure_initialized()

        if not self.has_batch():
            return None

        # Calculate frames to read
        frames_to_read = self.buffer_size * self.bin_size
        end_frame = min(self.current_frame + frames_to_read, self.frame_count)

        # Read raw frames
        raw_frames = self._read_raw_frames(slice(self.current_frame, end_frame))
        self.current_frame = end_frame

        # Apply binning
        return self.bin_frames(raw_frames)

    def has_batch(self) -> bool:
        """Check if more frames are available."""
        return self.current_frame < self.frame_count

    def reset(self):
        """Reset to beginning of file."""
        self.current_frame = 0

    def __len__(self) -> int:
        """Number of frames after binning."""
        self._ensure_initialized()
        return self._binned_count()

    def __iter__(self):
        """Make reader iterable (restarts from the first frame)."""
        self.reset()
        return self

    def __next__(self) -> np.ndarray:
        """Iterator protocol: yield successive read_batch() results."""
        if not self.has_batch():
            raise StopIteration
        return self.read_batch()

    @property
    def shape(self) -> Tuple[int, int, int, int]:
        """
        Shape after binning.

        Returns
        -------
        tuple of int
            Shape as (T_binned, H, W, C)
        """
        self._ensure_initialized()
        return (len(self), self.height, self.width, self.n_channels)

    @property
    def unbinned_shape(self) -> Tuple[int, int, int, int]:
        """
        Original shape before binning.

        Returns
        -------
        tuple of int
            Shape as (T_original, H, W, C)
        """
        self._ensure_initialized()
        return (self.frame_count, self.height, self.width, self.n_channels)

    def to_pytorch(self, frames: np.ndarray) -> np.ndarray:
        """
        Convert from OpenCV (T, H, W, C) to PyTorch (T, C, H, W) format.

        Also accepts a single (H, W, C) frame, returned as (C, H, W).
        The result is a transposed view, not a copy.

        Raises:
            ValueError: If ``frames`` is neither 3D nor 4D.
        """
        if frames.ndim == 3:
            # Single frame (H, W, C)
            return np.transpose(frames, (2, 0, 1))
        elif frames.ndim == 4:
            # Multiple frames (T, H, W, C)
            return np.transpose(frames, (0, 3, 1, 2))
        else:
            raise ValueError(f"Expected 3D or 4D array, got {frames.ndim}D")

    def __repr__(self):
        self._ensure_initialized()
        return (
            f"{self.__class__.__name__}(shape={self.shape}, "
            f"dtype={self.dtype}, bin_size={self.bin_size})"
        )

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()
class VideoWriter(ABC):
    """
    Abstract base class for all video file writers.

    Defines a common interface for writing frames. Subclasses implement
    ``write_frames`` and ``close``. ``init`` records frame geometry and dtype
    from the first data written.
    """

    def __init__(self):
        self.initialized = False
        self.height = 0
        self.width = 0
        self.n_channels = 0
        self.bit_depth = 0  # bits per sample, derived from dtype.itemsize
        self.dtype = None

    def init(self, first_frame_batch: np.ndarray):
        """
        Initializes writer properties based on the first batch of frames.

        Accepted layouts:
        - (T, H, W, C): a batch of frames (the format VideoReader returns)
        - (H, W, C): a single multi-channel frame
        - (H, W): a single single-channel frame

        A 3D array is always interpreted as (H, W, C), not as a (T, H, W)
        grayscale batch.

        Args:
            first_frame_batch: First frames to be written.

        Raises:
            ValueError: If the array is not 2D, 3D or 4D.
        """
        shape = first_frame_batch.shape
        ndim = len(shape)
        if ndim == 4:
            # Batch in (T, H, W, C): the time axis carries no geometry
            shape = shape[1:]
        elif ndim not in (2, 3):
            raise ValueError(f"Expected 2D, 3D or 4D array, got {ndim}D")

        self.height = shape[0]
        self.width = shape[1]
        self.n_channels = shape[2] if len(shape) > 2 else 1
        self.dtype = first_frame_batch.dtype
        self.bit_depth = self.dtype.itemsize * 8
        self.initialized = True

    @abstractmethod
    def write_frames(self, frames: np.ndarray):
        """Writes a batch of frames to the file."""
        pass

    @abstractmethod
    def close(self):
        """Closes the writer and finalizes the file."""
        pass

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.close()