"""Utilities for running parallel computations with processes and shared memory."""
import logging
from collections.abc import Callable, Generator
from contextlib import contextmanager
from multiprocessing.managers import SharedMemoryManager
from multiprocessing.shared_memory import SharedMemory
from typing import Any, NamedTuple
import numpy as np
from joblib import Parallel, delayed
from numpy.typing import DTypeLike, NDArray
log = logging.getLogger(__name__)
[docs]
class ArrayProps(NamedTuple):
    """Picklable description of an array that lives in shared memory.

    Attributes
    ----------
    name : str
        Name of the shared memory block backing the array.
    shape : tuple of int
        Shape used to rebuild the array on top of the block.
    dtype : DTypeLike
        Data type used to rebuild the array on top of the block.
    """

    name: str
    shape: tuple[int, ...]
    dtype: DTypeLike
[docs]
class SHM_Wrapper:
    """Callable adapter that feeds shared-memory arrays to a function.

    Parameters
    ----------
    func : Callable
        Function to be called with the shared memory arrays. It is invoked
        as ``func(input_array, output_array, *args, **kwargs)``.
    """

    # This is a class rather than a decorator: a decorator would not work
    # here because of the way joblib works.
    def __init__(self, func: Callable):
        self.func = func

    def __call__(
        self,
        input_props: ArrayProps,
        output_props: ArrayProps,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Attach to both shared arrays and call the wrapped function on them.

        Parameters
        ----------
        input_props : ArrayProps
            Description of the shared input array.
        output_props : ArrayProps
            Description of the shared output array.
        *args, **kwargs
            Forwarded unchanged to the wrapped function after the two arrays.
        """
        with array_from_shm(input_props, output_props) as shared_arrays:
            in_array, out_array = shared_arrays
            self.func(in_array, out_array, *args, **kwargs)
[docs]
def run_parallel(
    func: Callable,
    input_array: NDArray,
    output_array: NDArray,
    n_jobs: int = -1,
    parallel_axis: int = 0,
    *args: Any,
    **kwargs: Any,
) -> NDArray:
    """Run a function in parallel with shared memory.

    Both arrays are copied into shared memory blocks, ``func`` is called once
    per index along ``parallel_axis`` of ``input_array`` in joblib worker
    processes, and the shared output is finally copied back into
    ``output_array``.

    Parameters
    ----------
    func : Callable
        Called as ``func(input, output, i, *args, **kwargs)`` where ``input``
        and ``output`` are arrays backed by the shared memory blocks and ``i``
        is the index of the task along ``parallel_axis``.
    input_array : NDArray
        Data made available to every task.
    output_array : NDArray
        Array receiving the results; its current content is copied to shared
        memory before the tasks start, and it is overwritten in place at the
        end.
    n_jobs : int, default -1
        Forwarded to :class:`joblib.Parallel`.
    parallel_axis : int, default 0
        Axis of ``input_array`` whose length gives the number of tasks.
    *args, **kwargs
        Extra arguments forwarded to ``func``.

    Returns
    -------
    NDArray
        ``output_array`` itself, after being updated in place.
    """
    n_tasks = input_array.shape[parallel_axis]
    if n_tasks == 0:
        # Nothing to schedule; also avoids requesting a zero-byte shared
        # memory block, which SharedMemory rejects with ValueError.
        return output_array

    # Leaving the ``with`` block shuts the manager down and releases the
    # shared memory blocks, so no explicit ``smm.shutdown()`` is needed.
    with (
        SharedMemoryManager() as smm,
        Parallel(n_jobs=n_jobs, backend="multiprocessing") as parallel,
    ):
        # array_to_shm already copies the data into the shared block.
        # The returned arrays and shm handles are kept referenced for the
        # whole block so the underlying memory is not garbage collected.
        input_prop, _input_array_sm, _input_shm = array_to_shm(input_array, smm)
        output_prop, output_array_sm, _output_shm = array_to_shm(output_array, smm)

        worker = SHM_Wrapper(func)
        parallel(
            delayed(worker)(input_prop, output_prop, i, *args, **kwargs)
            for i in range(n_tasks)
        )
        # Copy back while the shared block is still alive.
        output_array[...] = output_array_sm
    return output_array
[docs]
@contextmanager
def array_from_shm(
    *array_props: ArrayProps,
) -> Generator[list[NDArray], None, None]:
    """Get arrays from shared memory.

    Attaches to the existing shared memory block named by each ``ArrayProps``
    and builds an array of the given shape and dtype on top of its buffer.

    Parameters
    ----------
    *array_props : ArrayProps
        One description (block name, shape, dtype) per array to attach.

    Yields
    ------
    list of NDArray
        The arrays, in the same order as ``array_props``. The attached
        handles are closed when the context exits, so the arrays should not
        be accessed afterwards.
    """
    shms: list[SharedMemory] = []
    arrays: list[NDArray] = []
    try:
        for prop in array_props:
            # ``size`` is ignored when attaching to an existing block, so
            # only the name is needed here.
            shm = SharedMemory(name=prop.name)
            shms.append(shm)
            arrays.append(
                np.ndarray(shape=prop.shape, dtype=prop.dtype, buffer=shm.buf)
            )
        yield arrays
    finally:
        # Runs on normal exit, when the ``with`` body raises, and when
        # attaching a later block fails, so no handle is left open.
        del arrays
        for shm in shms:
            shm.close()
[docs]
def array_to_shm(
    array: NDArray, smm: SharedMemoryManager
) -> tuple[ArrayProps, NDArray, SharedMemory]:
    """Move an array to shared memory.

    Allocates a block through ``smm``, builds an array of the same shape and
    dtype on top of it, and copies the content of ``array`` into it.

    Parameters
    ----------
    array : NDArray
        Array to copy into shared memory.
    smm : SharedMemoryManager
        Started manager used to allocate the shared memory block.

    Returns
    -------
    props : ArrayProps
        Block name, shape and dtype (as a string) of the shared array.
    array_sm : NDArray
        Array backed by the shared memory block, holding a copy of ``array``.
    shm : SharedMemory
        Handle of the allocated block.
    """
    # SharedMemory rejects a size of zero, so reserve at least one byte to
    # support empty arrays.
    shm = smm.SharedMemory(size=max(array.nbytes, 1))
    array_sm: NDArray = np.ndarray(array.shape, dtype=array.dtype, buffer=shm.buf)
    # Ellipsis assignment copies the data and, unlike ``[:]``, also works for
    # 0-d arrays.
    array_sm[...] = array
    # Returning the shm object is required to avoid garbage collection (and segfault)
    return ArrayProps(shm.name, array.shape, str(array.dtype)), array_sm, shm