Source code for pyflowreg.util.io._ds_io

import numpy as np
import re
from collections import defaultdict


[docs] class DSFileReader: """ A mixin class that provides a generic, multi-pass heuristic for finding the most likely data-containing datasets within a file. """ def _find_datasets(self, datasets_with_info: list[tuple]) -> list: """ Heuristic to find datasets based on a list of names and their shapes. Args: datasets_with_info (list[tuple]): A list where each element is a tuple containing (dataset_name: str, dataset_shape: tuple). Returns: A list of strings with the names of the selected datasets. """ all_names = [info[0] for info in datasets_with_info] # --- Pass 1: Find datasets with channel conventions (e.g., 'ch1', 'channel_2') --- # This regex captures (prefix)(channel_word)(separator)(number) pattern = re.compile( r"^(.*?)((?:ch|channel|chan))([_.\s]*)(\d+)", re.IGNORECASE ) channel_groups = defaultdict(list) for name in all_names: match = pattern.match(name) if match: prefix = match.group(1) channel_num = int(match.group(4)) channel_groups[prefix].append((channel_num, name)) if channel_groups: # Find the group with the most channels that also have consistent shapes # This is a crucial check to ensure we're getting a real channel group valid_groups = {} for prefix, channels in channel_groups.items(): # Get shapes of all datasets in this group shapes = { info[1] for name in channels for info in datasets_with_info if info[0] == name[1] } if ( len(shapes) == 1 ): # All datasets in the group must have the same shape valid_groups[prefix] = channels if valid_groups: best_prefix = max(valid_groups, key=lambda k: len(valid_groups[k])) sorted_channels = sorted( valid_groups[best_prefix], key=lambda item: item[0] ) print( f"Heuristic Pass 1: Found channel group with prefix '{best_prefix}'." ) return [name for num, name in sorted_channels] # --- Pass 2: Find datasets with common generic names --- common_names = ["mov", "data", "dataset"] for name in all_names: sanitized_name = name.lower().lstrip("/") if sanitized_name in common_names: print(f"Heuristic Pass 2: Found common dataset '{name}'.") return [name] # --- Pass 3: Fallback to guessing based on dimensions --- print("Heuristic Pass 1 & 2 failed. Falling back to dimension-based guessing.") candidate_shapes = defaultdict(list) for name, shape in datasets_with_info: if len(shape) in [3, 4]: candidate_shapes[shape].append(name) if candidate_shapes: best_shape = max(candidate_shapes, key=lambda s: np.prod(s)) print( f"Warning: Guessing video data based on dimensions. " f"Selected {len(candidate_shapes[best_shape])} dataset(s) with shape {best_shape}." ) return candidate_shapes[best_shape] return []
[docs] class DSFileWriter: """ A mixin class that provides logic for generating dataset names for writers. This is a direct port of the DS_file_writer.m functionality. """ def __init__(self, **kwargs): # Default dimension ordering for writers: (height, width, time) self.dimension_ordering = kwargs.get("dimension_ordering", (0, 1, 2)) self.dataset_names = kwargs.get("dataset_names", None) # Sanitize dataset names by removing any leading slashes if self.dataset_names: if isinstance(self.dataset_names, list): self.dataset_names = [name.lstrip("/") for name in self.dataset_names] elif isinstance(self.dataset_names, str): self.dataset_names = self.dataset_names.lstrip("/")
[docs] def get_ds_name(self, channel_id: int, n_channels: int) -> str: """ Gets the dataset name for a specific channel. Args: channel_id (int): The 1-based index of the channel. n_channels (int): The total number of channels being written. Returns: The dataset name as a string. """ if self.dataset_names: if isinstance(self.dataset_names, list): if len(self.dataset_names) != n_channels: raise ValueError( "The number of provided dataset names must match the number of channels." ) return self.dataset_names[channel_id - 1] # Handle string patterns like 'ch*_reg' if "*" in self.dataset_names: return self.dataset_names.replace("*", str(channel_id)) # If it's a single name for a single channel, or a prefix for multiple if n_channels == 1: return self.dataset_names else: return f"{self.dataset_names}{channel_id}" else: # Default naming convention if none is provided return f"ch{channel_id}"