# Source code for pyflowreg.util.io.multifile_wrappers

"""
Wrapper classes for video file I/O operations.
Provides multi-file, multi-channel, and subset reading/writing capabilities.
"""

from typing import Union, List
from pathlib import Path

import numpy as np

from pyflowreg.util.io._base import VideoReader, VideoWriter
from pyflowreg.util.io.factory import get_video_file_reader, get_video_file_writer


class MULTIFILEFileWriter(VideoWriter):
    """
    File writer that writes one file per channel.

    Each channel is saved to a separate file named
    ``{file_name}_ch{N}.{file_type}`` (N is 1-based) inside ``folder``.
    The per-channel writers are created lazily on the first call to
    :meth:`write_frames`, once the channel count is known.
    """

    def __init__(self, filename: str, file_type: str = "TIFF", **kwargs):
        """
        Initialize multi-file writer.

        Args:
            filename: Base output filename or directory. If it has a suffix,
                its parent is the output folder and its stem the base name;
                otherwise it is used as the output folder and the base name
                defaults to ``"compensated"``.
            file_type: Output format for each channel file (also used
                verbatim as the file extension)
            **kwargs: Additional parameters passed to individual writers
        """
        super().__init__()

        # Parse filename
        path = Path(filename)
        if path.suffix:
            self.folder = path.parent
            self.file_name = path.stem
        else:
            self.folder = path
            self.file_name = "compensated"

        # Create output directory if needed
        self.folder.mkdir(parents=True, exist_ok=True)

        self.file_type = file_type
        self.writer_parameters = kwargs
        self.file_writers = []

    def _as_4d(self, frames: np.ndarray) -> np.ndarray:
        """Return ``frames`` reshaped (as a view) to (T, H, W, C)."""
        if frames.ndim == 4:
            return frames
        if frames.ndim == 2:
            # Single frame, single channel
            return frames[np.newaxis, :, :, np.newaxis]
        if frames.ndim == 3:
            # Once the geometry is known, a 3D array is a single (H, W, C)
            # frame only if ALL three dimensions match. Checking the channel
            # dimension too keeps a single-channel batch with T == H on a
            # square image from being mistaken for one multi-channel frame.
            if self.file_writers and frames.shape == (
                self.height,
                self.width,
                self.n_channels,
            ):
                return frames[np.newaxis, :, :, :]
            # Otherwise (including the first write) assume (T, H, W),
            # i.e. a batch of single-channel frames.
            return frames[:, :, :, np.newaxis]
        raise ValueError(
            f"frames must have 2, 3 or 4 dimensions, got shape {frames.shape}"
        )

    def write_frames(self, frames: np.ndarray):
        """
        Write frames to multiple files (one per channel).

        Args:
            frames: Array with shape (T, H, W, C) or compatible: (H, W) for
                a single single-channel frame, (T, H, W) for single-channel
                batches, or (H, W, C) for a single frame once the writer has
                been initialized with matching H, W and C.

        Raises:
            ValueError: If ``frames`` is not 2D-4D, or if its channel count
                differs from the one established by the first write.
        """
        # Normalize input to 4D
        frames = self._as_4d(np.asarray(frames))

        # Initialize on first write
        if not self.initialized:
            _, height, width, n_channels = frames.shape
            self.height = height
            self.width = width
            self.n_channels = n_channels
            self.dtype = frames.dtype
            self.bit_depth = frames.dtype.itemsize * 8
            self.initialized = True

            # Create a writer for each channel
            for ch_idx in range(self.n_channels):
                ch_filename = (
                    self.folder / f"{self.file_name}_ch{ch_idx + 1}.{self.file_type}"
                )
                writer = get_video_file_writer(
                    str(ch_filename), self.file_type, **self.writer_parameters
                )
                self.file_writers.append(writer)

        # Without this check, extra channels would be silently dropped and
        # missing channels would hand empty arrays to the channel writers.
        if frames.shape[3] != self.n_channels:
            raise ValueError(
                f"Expected {self.n_channels} channel(s), got {frames.shape[3]} "
                f"(normalized frames shape {frames.shape})"
            )

        # Write each channel to its file
        for ch_idx in range(self.n_channels):
            channel_frames = frames[:, :, :, ch_idx : ch_idx + 1]  # Keep 4D shape
            self.file_writers[ch_idx].write_frames(channel_frames)

    def close(self):
        """Close all channel writers."""
        for writer in self.file_writers:
            writer.close()
        self.file_writers = []
class MULTICHANNELFileReader(VideoReader):
    """
    Generic multichannel reader that reads from multiple video files
    and combines them into a single multichannel output.

    The channels of each input file are concatenated along the last axis in
    the order the files are given. All files must agree in height, width and
    frame count.
    """

    def __init__(
        self,
        input_files: List[str],
        buffer_size: int = 500,
        bin_size: int = 1,
        **kwargs,
    ):
        """
        Initialize multichannel reader.

        Args:
            input_files: List of input file paths
            buffer_size: Buffer size for batch reading
            bin_size: Temporal binning factor
            **kwargs: Additional parameters passed to individual readers
        """
        super().__init__()

        self.buffer_size = buffer_size
        self.bin_size = bin_size
        self.filereaders = []
        self.reader_kwargs = kwargs

        # Store file list for initialization
        self.input_files = input_files

    def _initialize(self):
        """
        Initialize all file readers and set properties.

        Raises:
            ValueError: If a file's resolution or frame count differs from
                the first file's.
        """
        # Create readers for all input files
        different_bits = False
        max_dtype = None

        for i, file_path in enumerate(self.input_files):
            reader = get_video_file_reader(
                file_path, self.buffer_size, self.bin_size, **self.reader_kwargs
            )
            # Ensure the reader is initialized
            if hasattr(reader, "_ensure_initialized"):
                reader._ensure_initialized()
            self.filereaders.append(reader)

            if i == 0:
                # Set properties from first reader
                self.height = reader.height
                self.width = reader.width
                self.frame_count = reader.frame_count
                self.dtype = reader.dtype
                self.n_channels = reader.n_channels
                max_dtype = self.dtype
                continue

            # Validate consistency
            if self.height != reader.height or self.width != reader.width:
                raise ValueError(f"Resolution mismatch in file {file_path}")
            if self.frame_count != reader.frame_count:
                raise ValueError(f"Frame count mismatch in file {file_path}")

            # Accumulate channels
            self.n_channels += reader.n_channels

            # Handle different data types
            if reader.dtype != self.dtype:
                different_bits = True
            # Compare against the running widest dtype, not the first file's
            # dtype; otherwise a later, narrower file could overwrite a dtype
            # that an earlier file had already widened.
            if reader.dtype != max_dtype:
                if np.can_cast(max_dtype, reader.dtype):
                    max_dtype = reader.dtype
                elif not np.can_cast(reader.dtype, max_dtype):
                    # Neither type can hold the other: fall back to float64
                    max_dtype = np.float64

        if different_bits:
            print(f"Warning: Different data types in channels, using {max_dtype}")
            self.dtype = max_dtype

        # Create combined name
        self.input_file_name = "_".join([Path(f).stem for f in self.input_files])

    def _read_raw_frames(self, frame_indices: Union[slice, List[int]]) -> np.ndarray:
        """
        Read frames from all files and combine channels.

        Args:
            frame_indices: Slice or list of frame indices to read

        Returns:
            Array with shape (T, H, W, C_total)
        """
        # Convert indices to list for consistent handling
        if isinstance(frame_indices, slice):
            start, stop, step = frame_indices.indices(self.frame_count)
            indices = list(range(start, stop, step))
        else:
            indices = list(frame_indices)

        if len(indices) == 0:
            return np.empty(
                (0, self.height, self.width, self.n_channels), dtype=self.dtype
            )

        # Allocate output array
        n_frames = len(indices)
        output = np.zeros(
            (n_frames, self.height, self.width, self.n_channels), dtype=self.dtype
        )

        # Read from each file and combine
        ch_offset = 0
        for reader in self.filereaders:
            # Use reader's indexing directly; a single index is requested as a
            # one-element slice rather than a one-element list.
            frames = (
                reader[indices]
                if len(indices) > 1
                else reader[indices[0] : indices[0] + 1]
            )

            # Ensure 4D
            if frames.ndim == 3:
                frames = frames[np.newaxis, ...]

            # Copy to output (copy=False avoids a redundant intermediate copy
            # when the reader already delivers the target dtype)
            n_ch = reader.n_channels
            output[:, :, :, ch_offset : ch_offset + n_ch] = frames.astype(
                self.dtype, copy=False
            )
            ch_offset += n_ch

        return output

    def close(self):
        """Close all file readers."""
        for reader in self.filereaders:
            reader.close()
        self.filereaders = []
class SUBSETFileReader(VideoReader):
    """
    Reader that provides a subset of frames from another video reader.
    Useful for reading non-contiguous frame indices or reordering frames.

    The source reader is borrowed, not owned: :meth:`close` leaves it open.
    """

    def __init__(
        self, video_file_reader: VideoReader, indices: Union[List[int], np.ndarray]
    ):
        """
        Initialize subset reader.

        Args:
            video_file_reader: Source video reader
            indices: Frame indices to include in subset (0-based)
        """
        super().__init__()

        self.video_file_reader = video_file_reader
        self.indices = np.array(indices, dtype=np.int64)

        # Inherit buffer settings
        self.buffer_size = video_file_reader.buffer_size
        self.bin_size = 1  # Disable binning for subset reading initially

        # Will be set in _initialize
        self._original_bin_size = video_file_reader.bin_size

    def _initialize(self):
        """
        Initialize properties from source reader.

        Raises:
            ValueError: If the largest index is not below the source reader's
                frame count.
        """
        source = self.video_file_reader

        # Ensure source is initialized
        if hasattr(source, "_ensure_initialized"):
            source._ensure_initialized()

        # Validate indices. An empty subset is valid (reads return an empty
        # array), and np.max would raise on a zero-size array, so skip it.
        if self.indices.size > 0:
            max_idx = np.max(self.indices)
            if max_idx >= source.frame_count:
                raise ValueError(
                    f"Index {max_idx} exceeds source frame count {source.frame_count}"
                )

        # Copy properties from source
        self.height = source.height
        self.width = source.width
        self.n_channels = source.n_channels
        self.dtype = source.dtype

        # Set our frame count to the subset size
        self.frame_count = len(self.indices)

        # Remember the source's bin size as seen at initialization time
        self._original_bin_size = source.bin_size

    def _read_raw_frames(self, frame_indices: Union[slice, List[int]]) -> np.ndarray:
        """
        Read frames from subset.

        Args:
            frame_indices: Indices into the subset (not the original video)

        Returns:
            Array with shape (T, H, W, C)
        """
        # Convert subset indices to original indices
        if isinstance(frame_indices, slice):
            start, stop, step = frame_indices.indices(self.frame_count)
            subset_indices = list(range(start, stop, step))
        else:
            subset_indices = list(frame_indices)

        if len(subset_indices) == 0:
            return np.empty(
                (0, self.height, self.width, self.n_channels), dtype=self.dtype
            )

        # Map subset indices to original video indices
        original_indices = self.indices[subset_indices]

        # Temporarily disable binning in source reader
        source = self.video_file_reader
        old_bin = source.bin_size
        source.bin_size = 1

        try:
            # Read frames from source one at a time using mapped indices
            frames = []
            for idx in original_indices:
                frame = source[int(idx)]
                # A 3D result is a single frame; give it a leading time axis
                frames.append(frame[np.newaxis, ...] if frame.ndim == 3 else frame)

            # Non-empty here: the empty request returned early above
            result = np.concatenate(frames, axis=0)
        finally:
            # Restore original binning even if a read failed
            source.bin_size = old_bin

        return result

    def close(self):
        """No-op as we don't own the source reader."""
        pass  # Don't close the source reader as we don't own it
def main():
    """
    Test wrapper implementations.

    Writes a random two-channel video through MULTIFILEFileWriter into a
    temporary directory, reads it back through MULTICHANNELFileReader and
    SUBSETFileReader, and checks the resulting shapes and frame selection.
    """
    import tempfile

    # Create test data
    test_frames = np.random.randint(0, 255, (20, 64, 64, 2), dtype=np.uint8)

    # Test MULTIFILE writer
    print("Testing MULTIFILE writer...")
    with tempfile.TemporaryDirectory() as tmpdir:
        multifile_path = Path(tmpdir) / "test_multi"

        # Use HDF5 format since we have that implemented
        with MULTIFILEFileWriter(str(multifile_path), "HDF5") as writer:
            writer.write_frames(test_frames[:10])
            writer.write_frames(test_frames[10:])

        # Check files were created
        # When path has no extension, it's treated as folder with 'compensated' as default name
        ch1_file = multifile_path / "compensated_ch1.HDF5"
        ch2_file = multifile_path / "compensated_ch2.HDF5"
        assert ch1_file.exists(), "Channel 1 file not created"
        assert ch2_file.exists(), "Channel 2 file not created"
        print("✓ MULTIFILE writer test passed")

        # Test MULTICHANNEL reader
        print("\nTesting MULTICHANNEL reader...")
        reader = MULTICHANNELFileReader([str(ch1_file), str(ch2_file)])
        # Close the reader even when a check below fails, so no file handle
        # is still open when the temporary directory is removed.
        try:
            print(f"Shape: {reader.shape}")
            print(f"Channels: {reader.n_channels}")

            # Read all frames
            all_frames = reader[:]
            assert all_frames.shape == (
                20,
                64,
                64,
                2,
            ), f"Shape mismatch: {all_frames.shape}"
            print("✓ MULTICHANNEL reader test passed")

            # Test SUBSET reader
            print("\nTesting SUBSET reader...")
            subset_indices = [0, 5, 10, 15, 19]
            subset_reader = SUBSETFileReader(reader, subset_indices)
            print(f"Subset shape: {subset_reader.shape}")
            assert subset_reader.frame_count == 5, "Subset frame count incorrect"

            subset_frames = subset_reader[:]
            assert subset_frames.shape == (
                5,
                64,
                64,
                2,
            ), f"Subset shape mismatch: {subset_frames.shape}"

            # Verify correct frames were selected
            for i, orig_idx in enumerate(subset_indices):
                np.testing.assert_array_equal(
                    subset_frames[i],
                    all_frames[orig_idx],
                    err_msg=f"Frame {i} (original {orig_idx}) mismatch",
                )
            print("✓ SUBSET reader test passed")
        finally:
            reader.close()

    # Test that factory functions are properly imported
    print("\nTesting factory function imports...")
    try:
        from pyflowreg.util.io.factory import get_video_file_reader as factory_reader
        from pyflowreg.util.io.factory import get_video_file_writer as factory_writer

        assert factory_reader == get_video_file_reader
        assert factory_writer == get_video_file_writer
        print("✓ Factory functions properly imported")
    except ImportError as e:
        print(f"✗ Factory import failed: {e}")

    print("\n✓ All wrapper tests passed!")
# Run the wrapper self-tests only when this module is executed as a script.
if __name__ == "__main__": main()