"""
Analysis base classes
=====================
.. moduleauthor:: Benjamin Ye <GitHub: @bbye98>
This module contains custom base classes for serial and multithreaded
data analysis with support for the multiprocessing, Dask, Joblib, and
Numba libraries for parallelization.
"""
from abc import abstractmethod
from collections.abc import Generator
import contextlib
from datetime import datetime
import logging
import multiprocessing
import os
from typing import TextIO, Union
import warnings
try:
import dask
from dask import distributed
FOUND_DASK = True
except ImportError:
FOUND_DASK = False
try:
import joblib
FOUND_JOBLIB = True
except ImportError:
FOUND_JOBLIB = False
from MDAnalysis.analysis.base import AnalysisBase
from MDAnalysis.coordinates.base import ReaderBase
import numba
import numpy as np
from tqdm import tqdm
@contextlib.contextmanager
def _tqdm_joblib(tqdm_obj: tqdm) -> Generator:
"""
A context manager for displaying a progress bar with Joblib.
Parameters
----------
tqdm_obj : `tqdm.tqdm`
tqdm progress bar object.
"""
class TqdmBatchCompletionCallback(joblib.parallel.BatchCompletionCallBack):
def __call__(self, *args, **kwargs) -> None:
tqdm_obj.update(n=self.batch_size)
return super().__call__(*args, **kwargs)
old_batch_callback = joblib.parallel.BatchCompletionCallBack
joblib.parallel.BatchCompletionCallBack = TqdmBatchCompletionCallback
try:
yield tqdm_obj
finally:
joblib.parallel.BatchCompletionCallBack = old_batch_callback
tqdm_obj.close()
[docs]
class Hash(dict):
"""
A hash table, or an extension of the built-in `dict` with dot
notation for accessing properties.
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
for arg in args:
if isinstance(arg, dict):
for k, v in arg.items():
self[k] = v
else:
raise TypeError("Positional arguments must be dictionaries.")
if kwargs:
for k, v in kwargs.items():
self[k] = v
def __getattr__(self, attr):
if attr.beginswith("__array"):
return super().__getattr__(attr)
return self.get(attr)
def __setattr__(self, key, value):
self.__setitem__(key, value)
def __setitem__(self, key, value):
super().__setitem__(key, value)
self.__dict__.update({key: value})
def __delattr__(self, item):
self.__delitem__(item)
def __delitem__(self, key):
super().__delitem__(key)
del self.__dict__[key]
[docs]
class SerialAnalysisBase(AnalysisBase):
"""
A serial analysis base object.
Parameters
----------
trajectory : `MDAnalysis.coordinates.base.ReaderBase`
Simulation trajectory.
verbose : `bool`, default: :code:`True`
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.analysis.base.AnalysisBase`.
"""
def __init__(self, trajectory: ReaderBase, verbose: bool = False, **kwargs):
super().__init__(trajectory, verbose, **kwargs)
[docs]
def run(
self,
start: int = None,
stop: int = None,
step: int = None,
frames: Union[slice, np.ndarray[int]] = None,
verbose: bool = None,
**kwargs,
) -> "SerialAnalysisBase":
"""
Performs the calculation in serial.
Parameters
----------
start : `int`, optional
Starting frame for analysis.
stop : `int`, optional
Ending frame for analysis.
step : `int`, optional
Number of frames to skip between each analyzed frame.
frames : `slice` or array-like, optional
Index or logical array of the desired trajectory frames.
verbose : `bool`, optional
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.lib.log.ProgressBar`.
Returns
-------
self : `SerialAnalysisBase`
Serial analysis base object.
"""
return super().run(start, stop, step, frames, verbose, **kwargs)
[docs]
def save(
self,
file: Union[str, TextIO],
archive: bool = True,
compress: bool = True,
**kwargs,
) -> None:
"""
Saves results to a binary or archive file in NumPy format.
Parameters
----------
file : `str` or `file`
Filename or file-like object where the data will be saved.
If `file` is a `str`, the :code:`.npy` or :code:`.npz`
extension will be appended automatically if not already
present.
archive : `bool`, default: :code:`True`
Determines whether the results are saved to a single archive
file. If `True`, the data is stored in a :code:`.npz` file.
Otherwise, the data is saved to multiple :code:`.npy` files.
compress : `bool`, default: :code:`True`
Determines whether the :code:`.npz` file is compressed. Has
no effect when :code:`archive=False`.
**kwargs
Additional keyword arguments to pass to :func:`numpy.save`,
:func:`numpy.savez`, or :func:`numpy.savez_compressed`,
depending on the values of `archive` and `compress`.
"""
if archive and compress:
np.savez_compressed(file, **self.results, **kwargs)
elif archive:
np.savez(file, **self.results, **kwargs)
else:
for data in self.results:
np.save(f"{file}_{data}", self.results[data], **kwargs)
[docs]
class NumbaAnalysisBase(SerialAnalysisBase):
"""
A Numba-accelerated analysis base object.
Parameters
----------
trajectory : `MDAnalysis.coordinates.base.ReaderBase`
Simulation trajectory.
verbose : `bool`, default: :code:`True`
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.analysis.base.AnalysisBase`.
"""
def __init__(self, trajectory: ReaderBase, verbose: bool = False, **kwargs):
super().__init__(trajectory, verbose, **kwargs)
[docs]
def run(
self,
start: int = None,
stop: int = None,
step: int = None,
frames: Union[slice, np.ndarray[int]] = None,
verbose: bool = None,
*,
n_threads: int = None,
**kwargs,
) -> "NumbaAnalysisBase":
"""
Performs the calculation.
Parameters
----------
start : `int`, optional
Starting frame for analysis.
stop : `int`, optional
Ending frame for analysis.
step : `int`, optional
Number of frames to skip between each analyzed frame.
frames : `slice` or array-like, optional
Index or logical array of the desired trajectory frames.
verbose : `bool`, optional
Determines whether detailed progress is shown.
n_threads : `int`, keyword-only, optional
Number of threads to use for analysis.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.lib.log.ProgressBar`.
Returns
-------
self : `NumbaAnalysisBase`
Analysis object with results.
"""
if n_threads is not None:
numba.set_num_threads(n_threads)
return super().run(
start=start, stop=stop, step=step, frames=frames, verbose=verbose, **kwargs
)
[docs]
class ParallelAnalysisBase(SerialAnalysisBase):
"""
A multithreaded analysis base object.
Parameters
----------
trajectory : `MDAnalysis.coordinates.base.ReaderBase`
Simulation trajectory.
verbose : `bool`, default: :code:`True`
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.analysis.base.AnalysisBase`.
"""
def __init__(self, trajectory: ReaderBase, verbose: bool = False, **kwargs):
super().__init__(trajectory, verbose, **kwargs)
def _dask_job_block(self, indices: np.ndarray[int]) -> list:
return [self._single_frame_parallel(i) for i in zip(indices)]
@abstractmethod
def _single_frame_parallel(self, index: int):
pass
[docs]
def run(
self,
start: int = None,
stop: int = None,
step: int = None,
frames: Union[slice, np.ndarray[int]] = None,
verbose: bool = None,
*,
n_jobs: int = None,
module: str = "multiprocessing",
method: str = None,
block: bool = True,
**kwargs,
) -> "ParallelAnalysisBase":
"""
Performs the calculation in parallel.
Parameters
----------
start : `int`, optional
Starting frame for analysis.
stop : `int`, optional
Ending frame for analysis.
step : `int`, optional
Number of frames to skip between each analyzed frame.
frames : `slice` or array-like, optional
Index or logical array of the desired trajectory frames.
verbose : `bool`, optional
Determines whether detailed progress is shown.
n_jobs : `int`, keyword-only, optional
Number of workers. If not specified, it is automatically
set to either the minimum number of workers required to
fully analyze the trajectory or the maximum number of CPU
threads available.
module : `str`, keyword-only, default: :code:`"multiprocessing"`
Parallelization module to use for analysis.
**Valid values**: :code:`"dask"`, :code:`"joblib"`, and
:code:`"multiprocessing"`.
method : `str`, keyword-only, optional
Specifies which Dask scheduler, Joblib backend, or
multiprocessing start method is used.
block : `bool`, keyword-only, default: :code:`True`
Determines whether the trajectory is split into smaller
blocks that are processed serially in parallel with other
blocks. This "split–apply–combine" approach is generally
faster since the trajectory attributes do not have to be
packaged for each analysis run. Only available for
:code:`module="dask"`.
**kwargs
Additional keyword arguments to pass to
:func:`dask.compute`, :class:`joblib.Parallel`, or
:class:`multiprocessing.pool.Pool`, depending on the value of
`module`.
Returns
-------
self : `ParallelAnalysisBase`
Parallel analysis base object.
"""
if verbose is None:
verbose = getattr(self, "_verbose", False)
logging.basicConfig(
format="{asctime} | {levelname:^8s} | {message}",
style="{",
level=logging.INFO if verbose else logging.WARNING,
)
self._setup_frames(
self._trajectory, start=start, stop=stop, step=step, frames=frames
)
self._prepare()
n_jobs = min(n_jobs or np.inf, self.n_frames, len(os.sched_getaffinity(0)))
indices = np.arange(self.n_frames)
if verbose:
time_start = datetime.now()
if module == "dask" and FOUND_DASK:
try:
config = {"scheduler": distributed.worker.get_client(), **kwargs}
n_jobs = min(len(config["scheduler"].get_worker_logs()), n_jobs)
except ValueError:
if method is None:
method = "processes"
elif method not in {
"distributed",
"processes",
"threading",
"threads",
"single-threaded",
"sync",
"synchronous",
}:
raise ValueError("Invalid Dask scheduler.")
if method == "distributed":
emsg = (
"The Dask distributed client "
"(client = dask.distributed.Client(...)) "
"should be instantiated in the main "
"program (__name__ = '__main__') of "
"your script."
)
raise RuntimeError(emsg)
elif method in {"threading", "threads"}:
emsg = (
"The threaded Dask scheduler is not "
"compatible with MDAnalysis."
)
raise ValueError(emsg)
elif n_jobs == 1 and method not in {
"single-threaded",
"sync",
"synchronous",
}:
method = "synchronous"
logging.warning(
f"Since {n_jobs=}, the synchronous "
"Dask scheduler will be used instead."
)
config = {"scheduler": method} | kwargs
if method == "processes":
config["num_workers"] = n_jobs
logging.info(
f"Starting analysis using Dask ({n_jobs=}, "
f"scheduler={config['scheduler']})..."
)
jobs = []
if block:
for indices_ in np.array_split(indices, n_jobs):
jobs.append(dask.delayed(self._dask_job_block)(indices_))
else:
for index in indices:
jobs.append(dask.delayed(self._single_frame_parallel)(index))
blocks = dask.delayed(jobs).persist(**config)
if verbose:
distributed.progress(blocks)
self._results = blocks.compute(**config)
if block:
self._results = [r for b in self._results for r in b]
elif module == "joblib" and FOUND_JOBLIB:
if method is not None and method not in {
"loky",
"multiprocessing",
"threading",
None,
}:
raise ValueError("Invalid Joblib backend.")
logging.info(
"Starting analysis using Joblib " f"({n_jobs=}, backend={method})..."
)
with (
_tqdm_joblib(tqdm(total=self.n_frames))
if verbose
else contextlib.suppress()
):
self._results = joblib.Parallel(
n_jobs=n_jobs, backend=method, **kwargs
)(joblib.delayed(self._single_frame_parallel)(i) for i in indices)
else:
if module != "multiprocessing":
wmsg = (
"The Dask or Joblib library was not found, so "
"the native multiprocessing module will be"
"used instead."
)
warnings.warn(wmsg)
if method is None:
method = multiprocessing.get_start_method()
elif method not in {"fork", "forkserver", "spawn"}:
raise ValueError("Invalid multiprocessing start method.")
logging.info(
"Starting analysis using multiprocessing " f"({n_jobs=}, {method=})..."
)
with multiprocessing.get_context(method).Pool(n_jobs, **kwargs) as p:
self._results = (
tuple(
tqdm(
p.imap(self._single_frame_parallel, indices),
total=self.n_frames,
)
)
if verbose
else p.map(self._single_frame_parallel, indices)
)
if verbose:
logging.info(f"Analysis finished in {datetime.now() - time_start}.")
self._conclude()
return self
[docs]
class DynamicAnalysisBase(ParallelAnalysisBase, SerialAnalysisBase):
"""
A dynamic analysis base object.
Parameters
----------
trajectory : `MDAnalysis.coordinates.base.ReaderBase`
Simulation trajectory.
parallel : `bool`
Determines whether the analysis is performed in parallel.
verbose : `bool`, default: :code:`True`
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.analysis.base.AnalysisBase`.
"""
def __init__(
self, trajectory: ReaderBase, parallel: bool, verbose: bool = False, **kwargs
) -> None:
self._parallel = parallel
(ParallelAnalysisBase if parallel else SerialAnalysisBase).__init__(
self, trajectory, verbose=verbose, **kwargs
)
[docs]
def run(
self,
start: int = None,
stop: int = None,
step: int = None,
frames: Union[slice, np.ndarray[int]] = None,
verbose: bool = None,
**kwargs,
) -> Union[SerialAnalysisBase, ParallelAnalysisBase]:
"""
Performs the calculation.
.. seealso::
For parallel-specific keyword arguments, see
:meth:`ParallelAnalysisBase.run`.
Parameters
----------
start : `int`, optional
Starting frame for analysis.
stop : `int`, optional
Ending frame for analysis.
step : `int`, optional
Number of frames to skip between each analyzed frame.
frames : `slice` or array-like, optional
Index or logical array of the desired trajectory frames.
verbose : `bool`, optional
Determines whether detailed progress is shown.
**kwargs
Additional keyword arguments to pass to
:class:`MDAnalysis.lib.log.ProgressBar`.
Returns
-------
self : `SerialAnalysisBase` or `ParallelAnalysisBase`
Analysis object with results.
"""
return (ParallelAnalysisBase if self._parallel else SerialAnalysisBase).run(
self,
start=start,
stop=stop,
step=step,
frames=frames,
verbose=verbose,
**kwargs,
)