"""
Multiprocessing executor - processes frames in parallel using shared memory.
"""
from multiprocessing import shared_memory
from concurrent.futures import ProcessPoolExecutor, as_completed
from typing import Callable, Tuple, Optional, Dict
import numpy as np
from .base import BaseExecutor
# Helper functions for dimension conversion (needed in worker processes)
def _as2d(x):
"""Convert single-channel 3D array to 2D."""
return x[..., 0] if x.ndim == 3 and x.shape[2] == 1 else x
def _as3d(x):
"""Ensure array has channel dimension (H,W,C)."""
return x[..., None] if x.ndim == 2 else x
# Global dictionary to store shared memory references in worker processes
_SHM: Dict[str, Tuple[shared_memory.SharedMemory, np.ndarray]] = {}
def _init_shared(shm_specs: Dict[str, Tuple[str, tuple, str]]):
"""
Initialize shared memory in worker process.
Args:
shm_specs: Dictionary mapping names to (shm_name, shape, dtype_str) tuples
"""
# Limit thread usage in numerical libraries to avoid oversubscription
# Each worker process should use only 1 thread to avoid N processes × M threads saturation
import os
os.environ.update(
{
"OMP_NUM_THREADS": "1",
"MKL_NUM_THREADS": "1",
"OPENBLAS_NUM_THREADS": "1",
"NUMEXPR_NUM_THREADS": "1",
}
)
global _SHM
_SHM = {}
for key, (name, shape, dtype_str) in shm_specs.items():
shm = shared_memory.SharedMemory(name=name)
arr = np.ndarray(shape, dtype=np.dtype(dtype_str), buffer=shm.buf)
_SHM[key] = (shm, arr)
def _process_frame_worker(
t: int, interpolation_method: str, flow_param_scalars: dict
) -> int:
"""
Worker function to process a single frame using shared memory.
Args:
t: Frame index
interpolation_method: Interpolation method for registration
flow_param_scalars: Dictionary of scalar flow parameters (non-array)
Returns:
Frame index (for tracking completion)
"""
# Import functions inside worker to avoid pickling issues with Numba
from pyflowreg.core.optical_flow import get_displacement
from ...core.warping import imregister_wrapper
# Get arrays from shared memory
batch = _SHM["batch"][1]
batch_proc = _SHM["batch_proc"][1]
registered = _SHM["registered"][1]
w_out = _SHM["flow_fields"][1]
reference_proc = _SHM["reference_proc"][1]
reference_raw = _SHM["reference_raw"][1]
w_init = _SHM["w_init"][1]
# Extract CC parameters and remove them from flow_params
use_cc = bool(flow_param_scalars.get("cc_initialization", False))
cc_hw = flow_param_scalars.get("cc_hw", 256)
cc_up = int(flow_param_scalars.get("cc_up", 1))
# Import prealignment function only if needed
if use_cc:
from pyflowreg.util.xcorr_prealignment import estimate_rigid_xcorr_2d
# Create flow_params without CC parameters
flow_params = {
k: v
for k, v in flow_param_scalars.items()
if k not in ["cc_initialization", "cc_hw", "cc_up"]
}
if "weight" in _SHM:
flow_params["weight"] = _SHM["weight"][1]
# Check if cross-correlation initialization is enabled
if use_cc:
target_hw = cc_hw
if isinstance(target_hw, int):
target_hw = (target_hw, target_hw)
weight = _SHM["weight"][1] if "weight" in _SHM else None
# Step 1: Backward warp mov by w_init to get partially aligned
mov_partial = imregister_wrapper(
batch_proc[t],
w_init[..., 0], # dx
w_init[..., 1], # dy
reference_proc,
interpolation_method="linear",
)
# Use 2D views for CC
ref_for_cc = _as2d(reference_proc)
mov_for_cc = _as2d(mov_partial)
# Step 2: Estimate rigid residual between ref and partially aligned mov
w_cross = estimate_rigid_xcorr_2d(
ref_for_cc, mov_for_cc, target_hw=target_hw, up=cc_up, weight=weight
)
# Step 3: Combine w_init + w_cross
w_combined = w_init.copy()
w_combined[..., 0] += w_cross[0]
w_combined[..., 1] += w_cross[1]
# Step 4: Backward warp original mov by combined field
mov_aligned = imregister_wrapper(
batch_proc[t],
w_combined[..., 0],
w_combined[..., 1],
reference_proc,
interpolation_method="linear",
)
# Ensure mov_aligned has channel dimension (imregister_wrapper strips it for single channel)
mov_aligned = _as3d(mov_aligned)
# Step 5: Get residual non-rigid displacement
w_residual = get_displacement(
reference_proc, mov_aligned, uv=np.zeros_like(w_init), **flow_params
)
# Step 6: Total flow is w_combined + w_residual
flow = (w_combined + w_residual).astype(np.float32, copy=False)
else:
# Compute optical flow without prealignment
flow = get_displacement(
reference_proc, batch_proc[t], uv=w_init.copy(), **flow_params
).astype(np.float32, copy=False)
# Apply flow field to register the frame
reg_frame = imregister_wrapper(
batch[t],
flow[..., 0],
flow[..., 1],
reference_raw,
interpolation_method=interpolation_method,
)
# Store results directly in shared memory
w_out[t] = flow
# Handle case where registered frame might have fewer channels
if reg_frame.ndim < registered.ndim - 1:
registered[t, ..., 0] = reg_frame
else:
registered[t] = reg_frame
return t
[docs]
class MultiprocessingExecutor(BaseExecutor):
"""
Multiprocessing executor using shared memory for zero-copy data sharing.
This is the most efficient executor for CPU-bound operations as it:
1. Uses multiple CPU cores in parallel
2. Avoids data serialization overhead with shared memory
3. Bypasses the GIL completely
"""
[docs]
def __init__(self, n_workers: Optional[int] = None):
"""
Initialize multiprocessing executor.
Args:
n_workers: Number of worker processes. If None, uses RuntimeContext default.
"""
super().__init__(n_workers)
self.shm_handles = {}
self.executor = None
[docs]
def setup(self):
"""Create the process pool executor."""
if self.executor is None:
# We'll create the executor with initializer when processing
pass
[docs]
def cleanup(self):
"""Cleanup shared memory and shutdown executor."""
# Cleanup shared memory
for shm in self.shm_handles.values():
shm.close()
shm.unlink()
self.shm_handles = {}
# Shutdown executor
if self.executor is not None:
self.executor.shutdown(wait=True)
self.executor = None
def _create_shared_input(self, name: str, arr: np.ndarray, shm_specs: dict):
"""
Create shared memory for input array.
Args:
name: Name for the shared memory
arr: Array to share
shm_specs: Dictionary to store shared memory specifications
"""
shm = shared_memory.SharedMemory(create=True, size=arr.nbytes)
shared_arr = np.ndarray(arr.shape, dtype=arr.dtype, buffer=shm.buf)
shared_arr[:] = arr
shm_specs[name] = (shm.name, arr.shape, str(arr.dtype))
self.shm_handles[name] = shm
def _create_shared_output(
self, name: str, shape: tuple, dtype: np.dtype, shm_specs: dict
) -> np.ndarray:
"""
Create shared memory for output array.
Args:
name: Name for the shared memory
shape: Shape of the array
dtype: Data type of the array
shm_specs: Dictionary to store shared memory specifications
Returns:
Numpy array view of the shared memory
"""
nbytes = int(np.prod(shape) * np.dtype(dtype).itemsize)
shm = shared_memory.SharedMemory(create=True, size=nbytes)
arr = np.ndarray(shape, dtype=dtype, buffer=shm.buf)
shm_specs[name] = (shm.name, shape, str(np.dtype(dtype)))
self.shm_handles[name] = shm
return arr
[docs]
def process_batch(
self,
batch: np.ndarray,
batch_proc: np.ndarray,
reference_raw: np.ndarray,
reference_proc: np.ndarray,
w_init: np.ndarray,
get_displacement_func: Callable,
imregister_func: Callable,
interpolation_method: str = "cubic",
progress_callback: Optional[Callable[[int], None]] = None,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray]:
"""
Process frames in parallel using multiprocessing with shared memory.
Args:
batch: Raw frames to register, shape (T, H, W, C)
batch_proc: Preprocessed frames for flow computation, shape (T, H, W, C)
reference_raw: Raw reference frame, shape (H, W, C)
reference_proc: Preprocessed reference frame, shape (H, W, C)
w_init: Initial flow field, shape (H, W, 2)
get_displacement_func: Ignored (functions imported in worker)
imregister_func: Ignored (functions imported in worker)
interpolation_method: Interpolation method for registration
**kwargs: Additional parameters including 'flow_params' dict
Returns:
Tuple of (registered_frames, flow_fields)
"""
# Normalize inputs to ensure consistent dimensions
batch, batch_proc, reference_raw, reference_proc, w_init = (
self._normalize_inputs(
batch, batch_proc, reference_raw, reference_proc, w_init
)
)
T, H, W, C = batch.shape
# Get flow parameters from kwargs
flow_params = kwargs.get("flow_params", {})
# Create shared memory for all arrays
shm_specs = {}
# Input arrays (read-only in workers)
self._create_shared_input("batch", batch, shm_specs)
self._create_shared_input("batch_proc", batch_proc, shm_specs)
self._create_shared_input("reference_raw", reference_raw, shm_specs)
self._create_shared_input("reference_proc", reference_proc, shm_specs)
self._create_shared_input("w_init", w_init.astype(np.float32), shm_specs)
# Handle weight array separately if present in flow_params
if isinstance(flow_params.get("weight", None), np.ndarray):
self._create_shared_input("weight", flow_params["weight"], shm_specs)
# Create scalar-only params dict (without weight array)
flow_param_scalars = {k: v for k, v in flow_params.items() if k != "weight"}
else:
flow_param_scalars = dict(flow_params)
# Output arrays (written by workers)
reg_arr = self._create_shared_output(
"registered", batch.shape, batch.dtype, shm_specs
)
flow_arr = self._create_shared_output(
"flow_fields", (T, H, W, 2), np.float32, shm_specs
)
# Create process pool with shared memory initialization
with ProcessPoolExecutor(
max_workers=self.n_workers, initializer=_init_shared, initargs=(shm_specs,)
) as executor:
# Submit all frames for processing
futures = [
executor.submit(
_process_frame_worker, t, interpolation_method, flow_param_scalars
)
for t in range(T)
]
# Wait for all frames to complete
for future in as_completed(futures):
future.result() # This will raise any exceptions that occurred
# Copy results from shared memory (important to copy before cleanup!)
registered = np.array(reg_arr, copy=True)
flow_fields = np.array(flow_arr, copy=True)
# Call progress callback for entire batch (multiprocessing processes batch in parallel)
if progress_callback is not None:
progress_callback(T)
# Cleanup shared memory
for shm in self.shm_handles.values():
shm.close()
shm.unlink()
self.shm_handles = {}
return registered, flow_fields
[docs]
def get_info(self) -> dict:
"""Get information about this executor."""
info = super().get_info()
info.update(
{
"parallel": True,
"description": f"Multiprocessing with shared memory, {self.n_workers} workers",
"features": ["zero-copy", "shared-memory", "true-parallelism"],
}
)
return info
# Register this executor with RuntimeContext on import
MultiprocessingExecutor.register()