Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
pytorch-lightning / strategies / launchers / multiprocessing.py
Size: Mime:
# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import os
import queue
import tempfile
from contextlib import suppress
from dataclasses import dataclass
from typing import Any, Callable, Dict, List, Literal, NamedTuple, Optional, Union

import torch
import torch.backends.cudnn
import torch.multiprocessing as mp
from lightning_utilities.core.apply_func import apply_to_collection
from torch import Tensor
from typing_extensions import override

import pytorch_lightning as pl
from lightning_fabric.strategies.launchers.multiprocessing import (
    _check_bad_cuda_fork,
    _check_missing_main_guard,
    _disable_module_memory_sharing,
)
from lightning_fabric.utilities import move_data_to_device
from lightning_fabric.utilities.distributed import _set_num_threads_if_needed
from lightning_fabric.utilities.seed import _collect_rng_states, _set_rng_states
from lightning_fabric.utilities.types import _PATH
from pytorch_lightning.accelerators import CPUAccelerator
from pytorch_lightning.strategies.launchers.launcher import _Launcher
from pytorch_lightning.trainer.connectors.signal_connector import _SIGNUM
from pytorch_lightning.trainer.states import TrainerFn, TrainerState
from pytorch_lightning.utilities.rank_zero import rank_zero_debug

log = logging.getLogger(__name__)


class _MultiProcessingLauncher(_Launcher):
    r"""Launches processes that run a given function in parallel, and joins them all at the end.

    The main process in which this launcher is invoked creates N so-called worker processes (using
    :func:`torch.multiprocessing.start_processes`) that run the given function.
    Worker processes have a rank that ranges from 0 to N - 1.

    Note:
        - This launcher requires all objects to be pickleable.
        - It is important that the entry point to the program/script is guarded by ``if __name__ == "__main__"``.
        - With start method 'fork' the user must ensure that no CUDA context gets created in the main process before
          the launcher is invoked. E.g., one should avoid creating cuda tensors or calling ``torch.cuda.*`` functions
          before calling ``Trainer.fit``.

    Args:
        strategy: A reference to the strategy that is used together with this launcher.
        start_method: The method how to start the processes.
            - 'spawn': The default start method. Requires all objects to be pickleable.
            - 'fork': Preferable for IPython/Jupyter environments where 'spawn' is not available. Not available on
              the Windows platform for example.
            - 'forkserver': Alternative implementation to 'fork'.

    Raises:
        ValueError: If ``start_method`` is not one of :func:`torch.multiprocessing.get_all_start_methods` on this
            platform.

    """

    def __init__(
        self, strategy: "pl.strategies.ParallelStrategy", start_method: Literal["spawn", "fork", "forkserver"] = "spawn"
    ) -> None:
        self._strategy = strategy
        self._start_method = start_method
        if start_method not in mp.get_all_start_methods():
            raise ValueError(
                f"The start method '{self._start_method}' is not available on this platform. Available methods are:"
                f" {', '.join(mp.get_all_start_methods())}"
            )
        # Process handles from the most recent `launch`; `kill` forwards signals to them.
        self.procs: List[mp.Process] = []
        # Set once a fitting run has completed; `launch` rejects a second `fit` through the same launcher.
        self._already_fit = False

    @property
    @override
    def is_interactive_compatible(self) -> bool:
        # The start method 'spawn' is not supported in interactive environments
        # The start method 'fork' is the only one supported in Jupyter environments, with constraints around CUDA
        # initialization. For more context, see https://github.com/Lightning-AI/lightning/issues/7550
        return self._start_method == "fork"

    @override
    def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] = None, **kwargs: Any) -> Any:
        """Launches processes that run the given function in parallel.

        The function is allowed to have a return value. However, when all processes join, only the return value
        of worker process 0 gets returned from this `launch` method in the main process.

        Arguments:
            function: The entry point for all launched processes.
            *args: Optional positional arguments to be passed to the given function.
            trainer: Optional reference to the :class:`~pytorch_lightning.trainer.trainer.Trainer` for which
                a selected set of attributes get restored in the main process after processes join.
            **kwargs: Optional keyword arguments to be passed to the given function.

        Returns:
            Without a ``trainer``: the value worker 0 put on the return queue (its result moved to CPU).
            With a ``trainer``: the ``trainer_results`` field of the worker output, after the trainer state, the
            best checkpoint path, the last weights and the callback metrics have been restored in this process.

        Raises:
            NotImplementedError: If ``trainer.fit()`` is called a second time through this launcher.

        """
        if self._start_method in ("fork", "forkserver"):
            _check_bad_cuda_fork()
        if self._start_method == "spawn":
            _check_missing_main_guard()
        if self._already_fit and trainer is not None and trainer.state.fn == TrainerFn.FITTING:
            # resolving https://github.com/Lightning-AI/lightning/issues/18775 will lift this restriction
            raise NotImplementedError(
                "Calling `trainer.fit()` twice on the same Trainer instance using a spawn-based strategy is not"
                " supported. You can work around this limitation by creating a new Trainer instance and passing the"
                " `fit(ckpt_path=...)` argument."
            )

        # The default cluster environment in Lightning chooses a random free port number
        # This needs to be done in the main process here before starting processes to ensure each rank will connect
        # through the same port
        assert self._strategy.cluster_environment is not None
        os.environ["MASTER_PORT"] = str(self._strategy.cluster_environment.main_port)

        context = mp.get_context(self._start_method)
        return_queue = context.SimpleQueue()

        # Only 'spawn' starts from a fresh interpreter, so only then do globals need to be shipped to the workers.
        if self._start_method == "spawn":
            global_states = _GlobalStateSnapshot.capture()
            process_args = [trainer, function, args, kwargs, return_queue, global_states]
        else:
            process_args = [trainer, function, args, kwargs, return_queue]

        process_context = mp.start_processes(
            self._wrapping_function,
            args=process_args,
            nprocs=self._strategy.num_processes,
            start_method=self._start_method,
            join=False,  # we will join ourselves to get the process references
        )
        self.procs = process_context.processes
        while not process_context.join():
            pass

        # Exactly one item is expected: `_wrapping_function` only puts from worker 0.
        worker_output = return_queue.get()
        if trainer is None:
            return worker_output

        self._already_fit |= trainer.state.fn == TrainerFn.FITTING
        self._recover_results_in_main_process(worker_output, trainer)
        return worker_output.trainer_results

    def _wrapping_function(
        self,
        process_idx: int,
        trainer: Optional["pl.Trainer"],
        function: Callable,
        args: Any,
        kwargs: Any,
        return_queue: Union[mp.SimpleQueue, queue.Queue],
        global_states: Optional["_GlobalStateSnapshot"] = None,
    ) -> None:
        """Entry point executed inside every worker process.

        Restores the captured globals (only passed for 'spawn'), sets ``LOCAL_RANK`` to ``process_idx``, runs
        ``function`` and, on worker 0 only, puts the result (moved to CPU) on ``return_queue``.

        """
        if global_states is not None:
            global_states.restore()
        if self._start_method == "spawn" and isinstance(self._strategy.accelerator, CPUAccelerator):
            args, kwargs = _disable_module_memory_sharing((args, kwargs))

        _set_num_threads_if_needed(num_processes=self._strategy.num_processes)

        os.environ["LOCAL_RANK"] = str(process_idx)
        results = function(*args, **kwargs)

        if trainer is not None:
            results = self._collect_rank_zero_results(trainer, results)

        if process_idx == 0:
            return_queue.put(move_data_to_device(results, "cpu"))

    def _recover_results_in_main_process(self, worker_output: "_WorkerOutput", trainer: "pl.Trainer") -> None:
        """Copy the state reported by worker 0 onto the main-process ``trainer``.

        The temporary weights file written by the worker (and its scratch directory) are removed even if
        loading the weights fails.

        """
        # transfer back the best path to the trainer
        if trainer.checkpoint_callback and hasattr(trainer.checkpoint_callback, "best_model_path"):
            trainer.checkpoint_callback.best_model_path = str(worker_output.best_model_path)

        # TODO: pass also best score
        # load last weights
        if worker_output.weights_path is not None:
            try:
                ckpt = self._strategy.checkpoint_io.load_checkpoint(worker_output.weights_path)
                # choose non-strict loading of parameters on the main process, because the model's composition
                # could have changed in the worker process (layers added or removed)
                trainer.lightning_module.load_state_dict(ckpt, strict=False)
            finally:
                self._remove_temporary_weights(worker_output.weights_path)

        trainer.state = worker_output.trainer_state

        # get the `callback_metrics` and set it to the trainer
        self.update_main_process_results(trainer, worker_output.extra)

    def _remove_temporary_weights(self, weights_path: _PATH) -> None:
        """Delete the worker's temporary weights file and the ``tempfile.mkdtemp()`` directory that held it."""
        self._strategy.checkpoint_io.remove_checkpoint(weights_path)
        temp_dir = os.path.dirname(os.fspath(weights_path))
        # `mkdtemp()` creates its directory directly under the system temp dir; only such a directory is removed,
        # and `os.rmdir` refuses non-empty directories, so nothing else can be deleted by accident.
        if os.path.dirname(temp_dir) == tempfile.gettempdir():
            with suppress(OSError):
                os.rmdir(temp_dir)

    def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Optional["_WorkerOutput"]:
        """Package the worker-side trainer state for the main process.

        Must run on every rank (the state dict is computed collectively); only local rank 0 returns a
        :class:`_WorkerOutput`, the other ranks return ``None``.

        """
        rank_zero_debug("Collecting results from rank 0 process.")
        checkpoint_callback = trainer.checkpoint_callback
        best_model_path = (
            checkpoint_callback.best_model_path
            if checkpoint_callback and hasattr(checkpoint_callback, "best_model_path")
            else None
        )

        # requires to compute the state_dict on all processes in case Metrics are present
        state_dict = trainer.lightning_module.state_dict()

        if self._strategy.local_rank != 0:
            return None

        # save the last weights
        weights_path = None
        if trainer.state.fn == TrainerFn.FITTING:
            # use tempdir here to avoid race conditions because the filesystem may be shared between nodes;
            # the main process deletes both the file and the directory in `_remove_temporary_weights`
            weights_path = os.path.join(tempfile.mkdtemp(), ".temp.ckpt")
            self._strategy.checkpoint_io.save_checkpoint(state_dict, weights_path)

        # add extra result data from trainer to send to main process
        extra = self.get_extra_results(trainer)

        return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

    def get_extra_results(self, trainer: "pl.Trainer") -> Dict[str, Any]:
        """Gather extra state from the Trainer and return it as a dictionary for sending back to the main process. To
        avoid issues with memory sharing, we convert tensors to bytes.

        Args:
            trainer: reference to the Trainer.

        Returns:
            A dictionary with items to send back to the main process where :meth:`update_main_process_results` will
            process this output.

        """
        callback_metrics = apply_to_collection(trainer.callback_metrics, Tensor, lambda t: t.cpu())
        buffer = io.BytesIO()
        torch.save(callback_metrics, buffer)
        # send tensors as bytes to avoid issues with memory sharing
        return {"callback_metrics_bytes": buffer.getvalue()}

    def update_main_process_results(self, trainer: "pl.Trainer", extra: Dict[str, Any]) -> None:
        """Restore :attr:`trainer.callback_metrics` from the ``extra`` dictionary sent by the worker process. The
        metrics arrive serialized as bytes and are deserialized back to ``torch.Tensor`` values.

        Args:
            trainer: reference to the Trainer.
            extra: A dictionary with trainer state that was sent from the worker process and needs to be restored
                on the current trainer.

        """
        # NOTE: `get_extra_results` needs to be called before
        callback_metrics_bytes = extra["callback_metrics_bytes"]
        callback_metrics = torch.load(io.BytesIO(callback_metrics_bytes), weights_only=True)
        trainer.callback_metrics.update(callback_metrics)

    @override
    def kill(self, signum: _SIGNUM) -> None:
        """Send ``signum`` to every launched worker process that is still alive."""
        for proc in self.procs:
            if proc.is_alive() and proc.pid is not None:
                log.debug(f"Process {os.getpid()} is terminating {proc.pid} with {signum}")
                # the process may exit between the liveness check and the signal
                with suppress(ProcessLookupError):
                    os.kill(proc.pid, signum)

    def __getstate__(self) -> Dict:
        """Return the picklable state of the launcher, dropping the process handles."""
        state = self.__dict__.copy()
        state["procs"] = []  # SpawnProcess can't be pickled
        return state


class _WorkerOutput(NamedTuple):
    """Results that worker process 0 sends back to the main process through the launcher's return queue."""

    # best checkpoint path reported by the worker's checkpoint callback, or None if it has none
    best_model_path: Optional[_PATH]
    # temporary file holding the worker's final state dict; only set when the trainer was fitting
    weights_path: Optional[_PATH]
    # the worker trainer's state, copied onto the main-process trainer
    trainer_state: TrainerState
    # return value of the launched function; this is what `launch` returns to the caller
    trainer_results: Any
    # extra serialized data (see `get_extra_results`), restored by `update_main_process_results`
    extra: Dict[str, Any]


@dataclass
class _GlobalStateSnapshot:
    """Captures a hand-selected set of (global) variables in modules and provides a way to restore them.

    It facilitates and encapsulates the transfer of globals like PyTorch's deterministic flags or random generator state
    across process boundaries when launching processes with :func:`torch.multiprocessing.spawn`.

    Example:

        .. code-block:: python

            # in main process
            snapshot = _GlobalStateSnapshot.capture()

            # in worker process
            snapshot.restore()

    """

    use_deterministic_algorithms: bool
    use_deterministic_algorithms_warn_only: bool
    cudnn_benchmark: bool
    rng_states: Dict[str, Any]

    @classmethod
    def capture(cls) -> "_GlobalStateSnapshot":
        """Capture a few global states from torch, numpy, etc., that we want to restore in a spawned worker process."""
        return cls(
            use_deterministic_algorithms=torch.are_deterministic_algorithms_enabled(),
            use_deterministic_algorithms_warn_only=torch.is_deterministic_algorithms_warn_only_enabled(),
            cudnn_benchmark=torch.backends.cudnn.benchmark,
            rng_states=_collect_rng_states(),
        )

    def restore(self) -> None:
        """Restores all globals to the values captured in the :meth:`capture` method."""
        torch.use_deterministic_algorithms(
            self.use_deterministic_algorithms, warn_only=self.use_deterministic_algorithms_warn_only
        )
        torch.backends.cudnn.benchmark = self.cudnn_benchmark
        _set_rng_states(self.rng_states)