Source code for mlresearch.utils._parallelize

import os
import contextlib
from joblib import Parallel, delayed
from ._utils import _optional_import


def _get_n_jobs(n_jobs):
    """
    Resolve the number of jobs to run in parallel.

    Parameters
    ----------
    n_jobs : int or None
        Requested number of jobs. ``None`` means 1 and ``-1`` means one job
        per CPU. Any other value is converted with ``int`` and returned
        unchanged.

    Returns
    -------
    n_jobs : int
        The resolved number of jobs.

    Raises
    ------
    RuntimeError
        If more jobs are requested than there are CPUs.
    """
    # os.cpu_count() returns None when the number of CPUs cannot be
    # determined; fall back to a single CPU so the comparison below cannot
    # fail with a TypeError.
    max_jobs = os.cpu_count() or 1
    n_jobs = 1 if n_jobs is None else int(n_jobs)
    if n_jobs > max_jobs:
        raise RuntimeError("Cannot assign more jobs than the number of CPUs.")
    if n_jobs == -1:
        return max_jobs
    return n_jobs


@contextlib.contextmanager
def _tqdm_joblib(tqdm_object):
    """
    Route joblib progress reports into a tqdm progress bar.

    While the context is active, ``Parallel.print_progress`` is replaced on
    the class itself (so every ``Parallel`` instance is affected) by a hook
    that advances ``tqdm_object`` up to the number of completed tasks. On
    exit the original method is restored and the progress bar is closed.

    Parameters
    ----------
    tqdm_object : tqdm progress bar
        The progress bar to update. It is yielded back to the caller.
    """
    saved_print_progress = Parallel.print_progress

    def _report_to_tqdm(parallel):
        # Only ever move the bar forward, by the number of newly finished tasks.
        if parallel.n_completed_tasks <= tqdm_object.n:
            return
        newly_done = parallel.n_completed_tasks - tqdm_object.n
        tqdm_object.update(n=newly_done)

    Parallel.print_progress = _report_to_tqdm

    try:
        yield tqdm_object
    finally:
        # Undo the patch even if the wrapped block raised.
        Parallel.print_progress = saved_print_progress
        tqdm_object.close()


def parallel_loop(
    function, iterable, n_jobs=None, progress_bar=False, description=None
):
    """
    Parallelize a loop and optionally add a progress bar.

    .. warning::
        The progress bar tracks job starts, not completions.

    Parameters
    ----------
    function : function
        The function to which the elements in the iterable will be passed.
        Must have a single parameter.

    iterable : iterable
        Object to be looped over. If it has no length (e.g. a generator), the
        progress bar is created without a known total.

    n_jobs : int, default=None
        Number of jobs to run in parallel. None means 1. -1 means using all
        processors.

    progress_bar : bool, default=False
        Whether to display a tqdm progress bar. tqdm is loaded on demand from
        ``tqdm.auto`` only when this is enabled.

    description : str, default=None
        Description passed to the progress bar. Ignored when ``progress_bar``
        is disabled.

    Returns
    -------
    output : list
        The list with the results produced using ``function`` across
        ``iterable``.

    Raises
    ------
    RuntimeError
        If ``n_jobs`` is larger than the number of CPUs.
    """
    n_jobs = _get_n_jobs(n_jobs)
    tasks = (delayed(function)(item) for item in iterable)

    if not progress_bar:
        return Parallel(n_jobs=n_jobs)(tasks)

    tqdm = _optional_import("tqdm.auto").tqdm

    # Unsized iterables (generators, iterators) do not support len(); leave
    # the total unknown instead of failing before any work is dispatched.
    try:
        total = len(iterable)
    except TypeError:
        total = None

    with _tqdm_joblib(tqdm(desc=description, total=total)):
        return Parallel(n_jobs=n_jobs)(tasks)