Source code for pyflowreg.util.io.mat

import os
from typing import Union, List
import warnings

import numpy as np
import scipy.io as sio
import h5py
import hdf5storage as h5s

from pyflowreg.util.io._base import VideoReader, VideoWriter
from pyflowreg.util.io._ds_io import DSFileReader, DSFileWriter


class MATFileReader(DSFileReader, VideoReader):
    """
    MAT video file reader with dataset discovery.

    Supports both traditional MAT files (v5, v7), loaded with
    ``scipy.io.loadmat``, and v7.3 (HDF5-based) files, opened with ``h5py``.
    Every selected 3D dataset is exposed as one channel of the video.

    ``dimension_ordering`` is ``[h_axis, w_axis, t_axis]``: the positions of
    the height, width and time axes inside each stored dataset. The default
    ``[0, 1, 2]`` means datasets are stored as (H, W, T).

    NOTE(review): MATLAB writes v7.3 arrays column-major, so h5py may expose
    the axes of MATLAB-written files in reversed order - confirm that the
    default ordering is right for such files.
    """

    def __init__(
        self, file_path: str, buffer_size: int = 500, bin_size: int = 1, **kwargs
    ):
        """
        Store reader options; the file itself is opened in ``_initialize``.

        Args:
            file_path: Path of the MAT file to read
            buffer_size: Stored on the instance as ``buffer_size``
            bin_size: Stored on the instance as ``bin_size``
            dataset_names: Optional dataset names to read (keyword only).
                When missing or empty, datasets are discovered on open.
            dimension_ordering: Optional ``[h_axis, w_axis, t_axis]``
                (keyword only). Default: [0, 1, 2]
        """
        # Initialize parent classes
        DSFileReader.__init__(self)
        VideoReader.__init__(self)

        self.file_path = file_path
        self.buffer_size = buffer_size
        self.bin_size = bin_size

        # MAT-specific
        self.mat_data = None
        self.is_v73 = False
        self.h5file = None  # For v7.3 files

        # Dataset options from kwargs
        self.dataset_names = kwargs.get("dataset_names")
        self.dimension_ordering = kwargs.get(
            "dimension_ordering", [0, 1, 2]
        )  # MATLAB default

        # Known dataset patterns from MATLAB version
        self.known_patterns = ["ch*_reg", "ch*", "buffer*", "mov", "data"]

    @staticmethod
    def _is_v73(path: str) -> bool:
        """Return True if ``path`` looks like a MAT v7.3 (HDF5) file."""
        try:
            with open(path, "rb") as f:
                head = f.read(128)
        except (OSError, ValueError):
            # Best effort only: if the header cannot be read, let h5py decide.
            pass
        else:
            if b"MATLAB 7.3 MAT-file" in head:
                return True
        return h5py.is_hdf5(path)

    def _initialize(self):
        """
        Open MAT file and set up properties.

        Raises:
            IOError: If the file cannot be opened
            ValueError: If no suitable datasets are found
        """
        if self._is_v73(self.file_path):
            self.is_v73 = True
            try:
                self.h5file = h5py.File(self.file_path, "r")
            except Exception as e:
                raise IOError(f"Cannot open MAT v7.3 (HDF5) file: {e}") from e
        else:
            try:
                self.mat_data = sio.loadmat(
                    self.file_path, verify_compressed_data_integrity=False
                )
                self.is_v73 = False
            except NotImplementedError:
                # Unexpected v7.3 despite header test
                self.is_v73 = True
                self.h5file = h5py.File(self.file_path, "r")
            except ValueError as e:
                # This branch handles the non-v7.3 loader, so the message must
                # not claim the file is v7.3.
                raise IOError(f"Cannot open MAT file: {e}") from e

        # Find datasets
        if not self.dataset_names:
            if self.is_v73:
                datasets_info = self._find_datasets_v73()
            else:
                datasets_info = self._find_datasets_regular()

            self.dataset_names = self._find_datasets(datasets_info)

            if not self.dataset_names:
                raise ValueError("No suitable datasets found in MAT file")

        # Verify and setup properties from first dataset
        self._setup_properties()

    def _find_datasets_regular(self):
        """
        Find datasets in regular MAT files.

        Returns:
            List of ``(name, shape)`` for every 3D array, skipping the
            ``__``-prefixed metadata keys.
        """
        return [
            (key, data.shape)
            for key, data in self.mat_data.items()
            if not key.startswith("__")
            and isinstance(data, np.ndarray)
            and data.ndim == 3
        ]

    def _find_datasets_v73(self):
        """
        Find datasets in v7.3 MAT files.

        Returns:
            List of ``(name, shape)`` for every 3D HDF5 dataset whose name
            does not start with ``#`` (MATLAB metadata).
        """
        datasets_info = []

        def visitor(name, obj):
            if isinstance(obj, h5py.Dataset) and len(obj.shape) == 3:
                # Skip MATLAB metadata
                if not name.startswith("#"):
                    datasets_info.append((name, obj.shape))

        self.h5file.visititems(visitor)
        return datasets_info

    def _setup_properties(self):
        """
        Setup reader properties from discovered datasets.

        Raises:
            ValueError: If there are no datasets, or if the datasets do not
                all share the shape of the first one
        """
        if not self.dataset_names:
            raise ValueError("No datasets to setup properties from")

        # Both containers support lookup by dataset name.
        source = self.h5file if self.is_v73 else self.mat_data

        # Get first dataset to determine properties
        first_ds = source[self.dataset_names[0]]
        shape = first_ds.shape
        self.dtype = first_ds.dtype

        # Map the stored axes to properties via [h_axis, w_axis, t_axis]
        self.height = shape[self.dimension_ordering[0]]
        self.width = shape[self.dimension_ordering[1]]
        self.frame_count = shape[self.dimension_ordering[2]]
        self.n_channels = len(self.dataset_names)

        # Legacy compatibility
        self.m = self.height
        self.n = self.width
        self.mat_data_type = str(self.dtype)

        # Verify all datasets have same shape
        for ds_name in self.dataset_names[1:]:
            ds_shape = source[ds_name].shape
            if ds_shape != shape:
                raise ValueError(
                    f"Dataset {ds_name} has different shape: {ds_shape} vs {shape}"
                )

    def _read_raw_frames(self, frame_indices: Union[slice, List[int]]) -> np.ndarray:
        """
        Read raw frames from MAT file.

        Args:
            frame_indices: Slice or list of frame indices to read

        Returns:
            Array with shape (T, H, W, C)
        """
        # Convert list to array for indexing
        if isinstance(frame_indices, list):
            if len(frame_indices) == 0:
                return np.empty(
                    (0, self.height, self.width, self.n_channels), dtype=self.dtype
                )
            indices = np.array(frame_indices)
        else:
            # Convert slice to indices
            start, stop, step = frame_indices.indices(self.frame_count)
            indices = np.arange(start, stop, step)

        n_frames = len(indices)
        output = np.zeros(
            (n_frames, self.height, self.width, self.n_channels), dtype=self.dtype
        )

        # Read from each dataset/channel
        read_dataset = (
            self._read_v73_dataset if self.is_v73 else self._read_regular_dataset
        )
        for ch_idx, ds_name in enumerate(self.dataset_names):
            output[:, :, :, ch_idx] = read_dataset(ds_name, indices)

        return output

    def _permute_to_thw(self, data: np.ndarray) -> np.ndarray:
        """
        Reorder a block in stored axis order into (T, H, W).

        The permutation is applied unconditionally. It is the identity only
        when ``dimension_ordering == [1, 2, 0]``, and ``np.transpose`` is a
        cheap view in that case, so no special-casing is needed.
        """
        h_axis = self.dimension_ordering[0]
        w_axis = self.dimension_ordering[1]
        t_axis = self.dimension_ordering[2]
        # Output axis 0 (T) comes from the stored time axis, and so on.
        return np.transpose(data, (t_axis, h_axis, w_axis))

    def _read_regular_dataset(self, ds_name: str, indices: np.ndarray) -> np.ndarray:
        """
        Read from regular MAT file dataset.

        Args:
            ds_name: Name of the dataset (channel) to read
            indices: Frame indices along the stored time axis

        Returns:
            Array with shape (T, H, W)
        """
        dataset = self.mat_data[ds_name]

        # Select the requested frames along the stored time axis only
        idx = [slice(None), slice(None), slice(None)]
        idx[self.dimension_ordering[2]] = indices

        return self._permute_to_thw(dataset[tuple(idx)])

    def _read_v73_dataset(self, ds_name: str, indices: np.ndarray) -> np.ndarray:
        """
        Read from v7.3 MAT file dataset.

        Args:
            ds_name: Name of the dataset (channel) to read
            indices: Frame indices along the stored time axis

        Returns:
            Array with shape (T, H, W)
        """
        dataset = self.h5file[ds_name]
        t_axis = self.dimension_ordering[2]

        # Check if indices are contiguous for efficient reading
        if len(indices) > 1 and np.all(np.diff(indices) == 1):
            # Contiguous - use slicing
            idx = [slice(None), slice(None), slice(None)]
            idx[t_axis] = slice(indices[0], indices[-1] + 1)
            data = dataset[tuple(idx)]
        else:
            # Non-contiguous - read individually
            shape = list(dataset.shape[:3])
            shape[t_axis] = len(indices)
            data = np.zeros(shape, dtype=self.dtype)

            for i, frame_idx in enumerate(indices):
                idx_src = [slice(None), slice(None), slice(None)]
                idx_src[t_axis] = frame_idx
                idx_dst = [slice(None), slice(None), slice(None)]
                idx_dst[t_axis] = i
                data[tuple(idx_dst)] = dataset[tuple(idx_src)]

        return self._permute_to_thw(data)

    def close(self):
        """Close MAT file."""
        if self.h5file:
            self.h5file.close()
            self.h5file = None
        self.mat_data = None
class MATFileWriter(DSFileWriter, VideoWriter):
    """
    MAT video file writer with MATLAB compatibility.

    Creates MAT files with separate 3D datasets per channel, stored in
    MATLAB-compatible dimension ordering. Frames are accumulated in memory by
    ``write_frames`` and the file is only written by ``close``.
    """

    def __init__(self, file_path: str, **kwargs):
        """
        Initialize MAT writer.

        Args:
            file_path: Output file path
            dataset_names: Optional dataset naming pattern or list
                Default: 'ch*' (produces ch1, ch2, etc.)
            dimension_ordering: Storage order for MATLAB compatibility, given
                as ``[h_axis, w_axis, t_axis]``
                Default: [0, 1, 2] for (H, W, T) in MATLAB
            use_v73: Force v7.3 format (HDF5-based) for large files
        """
        # Initialize parent classes
        DSFileWriter.__init__(self, **kwargs)
        VideoWriter.__init__(self)

        self.file_path = file_path
        self.use_v73 = kwargs.get("use_v73", False)
        self._data_dict = {}
        self._frame_counter = 0

        # MATLAB compatibility options
        self.dimension_ordering = kwargs.get("dimension_ordering", [0, 1, 2])

        # Dataset naming
        if not self.dataset_names:
            self.dataset_names = "ch*"

    def write_frames(self, frames: np.ndarray):
        """
        Write frames to MAT file buffers.

        Args:
            frames: Array with shape (T, H, W, C) or (T, H, W) or (H, W)

        Raises:
            ValueError: If the input is not 2D, 3D or 4D, or if its frame size
                or channel count differs from the first write
        """
        # Normalize input to 4D (T, H, W, C)
        if frames.ndim == 2:
            # Single frame, single channel
            frames = frames[np.newaxis, :, :, np.newaxis]
        elif frames.ndim == 3:
            if frames.shape[0] == self.height and frames.shape[1] == self.width:
                # Single frame, multiple channels (H, W, C)
                frames = frames[np.newaxis, :, :, :]
            else:
                # Multiple frames, single channel (T, H, W)
                frames = frames[:, :, :, np.newaxis]
        elif frames.ndim != 4:
            raise ValueError(f"Expected 2D, 3D or 4D input, got {frames.ndim}D")

        T, H, W, C = frames.shape

        # Initialize on first write
        if not self.initialized:
            self.height = H
            self.width = W
            self.n_channels = C
            self.dtype = frames.dtype
            self.initialized = True

            # Initialize data buffers for each channel
            for ch_idx in range(self.n_channels):
                ds_name = self.get_ds_name(ch_idx + 1, self.n_channels)
                self._data_dict[ds_name] = []

        # Validate shape
        if H != self.height or W != self.width:
            raise ValueError(
                f"Frame size mismatch. Expected ({self.height}, {self.width}), "
                f"got ({H}, {W})"
            )
        if C != self.n_channels:
            raise ValueError(
                f"Channel count mismatch. Expected {self.n_channels}, got {C}"
            )

        # Permutation from (T, H, W) to the stored layout: stored axis
        # dimension_ordering[k] receives H, W, T for k = 0, 1, 2. It is applied
        # unconditionally because it is the identity only for [1, 2, 0], where
        # np.transpose is a cheap view anyway.
        perm = [None, None, None]
        perm[self.dimension_ordering[0]] = 1  # H position
        perm[self.dimension_ordering[1]] = 2  # W position
        perm[self.dimension_ordering[2]] = 0  # T position

        # Accumulate frames for each channel
        for ch_idx in range(self.n_channels):
            ds_name = self.get_ds_name(ch_idx + 1, self.n_channels)
            channel_data = frames[:, :, :, ch_idx]  # (T, H, W)
            self._data_dict[ds_name].append(np.transpose(channel_data, perm))

        self._frame_counter += T

    def close(self):
        """
        Close and write the MAT file.

        Does nothing when no frames have been buffered. Otherwise the buffered
        chunks of each channel are concatenated along the stored time axis and
        saved, and the buffers are released afterwards.
        """
        if not self._data_dict:
            return

        # Concatenate accumulated frames for each channel
        concat_axis = self.dimension_ordering[2]
        final_dict = {
            ds_name: np.concatenate(frame_list, axis=concat_axis)
            for ds_name, frame_list in self._data_dict.items()
            if frame_list
        }

        # Add metadata
        final_dict["__pyflowreg_metadata__"] = {
            "n_channels": self.n_channels,
            "frame_count": self._frame_counter,
            "height": self.height,
            "width": self.width,
            "dimension_ordering": self.dimension_ordering,
            "format": "pyflowreg_mat_v1",
        }

        # Write MAT file
        if self.use_v73:
            h5s.savemat(self.file_path, final_dict)
        else:
            try:
                sio.savemat(
                    self.file_path, final_dict, do_compression=True, format="5"
                )
            except ValueError:
                # Only the v5 writer gets the v7.3 fallback; an error from the
                # v7.3 writer itself is not retried.
                warnings.warn("Switching to v7.3 (file too large for v5).")
                h5s.savemat(self.file_path, final_dict)

        print(f"MAT file written: {self.file_path}")
        self._data_dict = {}

    def __enter__(self):
        """Return the writer itself for use in a ``with`` block."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Write the file on leaving the ``with`` block."""
        self.close()
def main():
    """
    Test MAT file I/O.

    Writes random two-channel frames to a temporary v7.3 MAT file in two
    chunks, reads them back through several access patterns and prints
    whether the data round-trips unchanged. The reader is closed and the
    temporary file removed even if a step fails.
    """
    import tempfile

    # Create test data
    test_frames = np.random.randint(0, 255, (100, 128, 128, 2), dtype=np.uint8)

    # Reserve a path only; the writer creates the actual file content.
    with tempfile.NamedTemporaryFile(suffix=".mat", delete=False) as f:
        mat_path = f.name

    try:
        # Test writing
        print(f"Writing test MAT file: {mat_path}")
        with MATFileWriter(mat_path, use_v73=True) as writer:
            writer.write_frames(test_frames[:50])
            writer.write_frames(test_frames[50:])

        # Test reading
        print(f"Reading test MAT file: {mat_path}")
        reader = MATFileReader(mat_path, buffer_size=10, bin_size=1)
        try:
            print(f"Shape: {reader.shape}")
            print(f"Channels: {reader.n_channels}")
            print(f"Frame count: {reader.frame_count}")

            # Test different access patterns
            single_frame = reader[0]
            print(f"Single frame shape: {single_frame.shape}")

            frame_slice = reader[10:20]
            print(f"Slice shape: {frame_slice.shape}")

            # Test batch reading
            reader.reset()
            batch = reader.read_batch()
            print(f"Batch shape: {batch.shape}")

            # Verify data integrity
            all_frames = reader[:]
            if np.array_equal(all_frames, test_frames):
                print("✓ Data integrity verified")
            else:
                print("✗ Data mismatch!")
        finally:
            # Release the file handle before the file is deleted.
            reader.close()
    finally:
        # Cleanup
        os.unlink(mat_path)

    print("Test complete")
# Run the MAT I/O self-test only when this module is executed as a script.
if __name__ == "__main__":
    main()