Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
ray / purelib / ray / data / _internal / plan.py
Size: Mime:
import copy
import itertools
import uuid
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    Iterable,
    Iterator,
    List,
    Optional,
    Tuple,
    Union,
)

import ray
from ray.data._internal.block_list import BlockList
from ray.data._internal.compute import (
    UDF,
    ActorPoolStrategy,
    BlockTransform,
    CallableClass,
    ComputeStrategy,
    get_compute,
    is_task_compute,
)
from ray.data._internal.lazy_block_list import LazyBlockList
from ray.data._internal.stats import DatasetStats
from ray.data.block import Block
from ray.data.context import DatasetContext

if TYPE_CHECKING:
    import pyarrow


# Ray remote args that a downstream stage may inherit from the upstream stage it is
# fused with. When checking fusion compatibility, the upstream stage's values for
# these keys override the downstream stage's, and all-to-all stages only fuse with an
# upstream stage whose remote args are limited to these keys.
# Scheduling strategy can be inherited from prev stage if not specified.
INHERITABLE_REMOTE_ARGS = ["scheduling_strategy"]


class Stage:
    """Represents a Dataset transform stage (e.g., map or shuffle).

    Subclasses implement ``__call__`` to execute the stage, and ``can_fuse``/``fuse``
    to support merging with the preceding stage in a plan.
    """

    def __init__(self, name: str, num_blocks: Optional[int]):
        """Create a stage.

        Args:
            name: Human-readable stage name, used for stats and repr.
            num_blocks: Number of output blocks this stage is expected to produce,
                or None if it is not known ahead of execution.
        """
        self.name = name
        self.num_blocks = num_blocks

    def __call__(
        self,
        blocks: BlockList,
        clear_input_blocks: bool,
        run_by_consumer: bool = False,
    ) -> Tuple[BlockList, dict]:
        """Execute this stage against the given blocks.

        Args:
            blocks: The input blocks of this stage.
            clear_input_blocks: Whether the input blocks may be cleared once consumed.
            run_by_consumer: Whether the owning plan is run by a consumption API
                (e.g. ``.iter_batches()``). ``ExecutionPlan.execute`` always passes
                this argument; subclasses accept it as well.

        Returns:
            The output blocks and a (possibly empty) dict of per-stage stats info.
        """
        raise NotImplementedError

    def can_fuse(self, other: "Stage") -> bool:
        """Return whether the preceding stage ``other`` can be fused into this one."""
        raise NotImplementedError

    def fuse(self, other: "Stage") -> "Stage":
        """Fuse the compatible preceding stage ``other`` into this stage."""
        raise NotImplementedError

    def __repr__(self):
        return f'{type(self).__name__}("{self.name}")'

    def __str__(self):
        return repr(self)


class ExecutionPlan:
    """A lazy execution plan for a Dataset."""

    # Implementation Notes:
    #
    # This lazy execution plan takes in an input block list and builds up a chain of
    # BlockList --> BlockList stages. When execution is triggered, it tries to fuse
    # together stages in order to reduce Ray task overhead and data copies.
    #
    # Internally, the execution plan holds two block lists:
    #   * _in_blocks: The (possibly lazy) input block list. For non-lazy input this is
    #     set to None when a full execution starts (see
    #     _get_source_blocks_and_stages), so its memory can be reclaimed early.
    #   * _snapshot_blocks: A snapshot of a computed block list, where this snapshot
    #     is the cached output of executing some prefix in the stage chain.
    #
    # The stages in this execution plan are partitioned into two subchains: before the
    # snapshot and after the snapshot. When the snapshot exists from a previous
    # execution, any future executions will only have to execute the "after the
    # snapshot" subchain, using the snapshot as the input to that subchain.

    def __init__(
        self,
        in_blocks: BlockList,
        stats: DatasetStats,
        dataset_uuid=None,
        *,
        run_by_consumer: bool,
    ):
        """Create a plan with no transformation stages.

        Args:
            in_blocks: Base list of blocks.
            stats: Stats for the base blocks. If they carry no dataset UUID yet, they
                are stamped with this plan's UUID.
            dataset_uuid: Dataset's UUID. A random UUID is generated if not given.
            run_by_consumer: Whether this plan is invoked to run by the consumption
                APIs (e.g. .iter_batches()).
        """
        self._in_blocks = in_blocks
        self._in_stats = stats
        # A computed snapshot of some prefix of stages.
        self._snapshot_blocks = None
        self._snapshot_stats = None
        # Chains of stages.
        self._stages_before_snapshot = []
        self._stages_after_snapshot = []
        # Cache of optimized stages.
        self._last_optimized_stages = None

        self._dataset_uuid = dataset_uuid or uuid.uuid4().hex
        if not stats.dataset_uuid:
            stats.dataset_uuid = self._dataset_uuid

        self._run_by_consumer = run_by_consumer

    def __repr__(self) -> str:
        return (
            f"ExecutionPlan("
            f"dataset_uuid={self._dataset_uuid}, "
            f"run_by_consumer={self._run_by_consumer}, "
            f"in_blocks={self._in_blocks}, "
            f"stages_before_snapshot={self._stages_before_snapshot}, "
            f"stages_after_snapshot={self._stages_after_snapshot}, "
            f"snapshot_blocks={self._snapshot_blocks})"
        )

    def with_stage(self, stage: "Stage") -> "ExecutionPlan":
        """Return a copy of this plan with the given stage appended.

        Args:
            stage: The stage to append.

        Returns:
            A new ExecutionPlan with this stage appended.
        """
        # Named plan_copy so the module-level `copy` import is not shadowed.
        plan_copy = self.copy()
        plan_copy._stages_after_snapshot.append(stage)
        return plan_copy

    def copy(self) -> "ExecutionPlan":
        """Create a shallow copy of this execution plan.

        This copy can be executed without mutating the original, but clearing the copy
        will also clear the original. The copy gets a newly generated dataset UUID.

        Returns:
            A shallow copy of this execution plan.
        """
        plan_copy = ExecutionPlan(
            self._in_blocks, self._in_stats, run_by_consumer=self._run_by_consumer
        )
        if self._snapshot_blocks is not None:
            # Copy over the existing snapshot.
            plan_copy._snapshot_blocks = self._snapshot_blocks
            plan_copy._snapshot_stats = self._snapshot_stats
        plan_copy._stages_before_snapshot = self._stages_before_snapshot.copy()
        plan_copy._stages_after_snapshot = self._stages_after_snapshot.copy()
        return plan_copy

    def deep_copy(self, preserve_uuid: bool = False) -> "ExecutionPlan":
        """Create a deep copy of this execution plan.

        This copy can be executed AND cleared without mutating the original.

        Args:
            preserve_uuid: Whether to preserve the original UUID in the copy.

        Returns:
            A deep copy of this execution plan.
        """
        dataset_uuid = None
        if preserve_uuid:
            dataset_uuid = self._dataset_uuid
        in_blocks = self._in_blocks
        if isinstance(in_blocks, BlockList):
            in_blocks = in_blocks.copy()
        plan_copy = ExecutionPlan(
            in_blocks,
            copy.copy(self._in_stats),
            dataset_uuid=dataset_uuid,
            run_by_consumer=self._run_by_consumer,
        )
        # Use an explicit None check, matching copy(). The stages before the snapshot
        # are always copied below, so dropping the snapshot here would leave the copy
        # with pre-snapshot stages that are never executed.
        if self._snapshot_blocks is not None:
            # Copy over the existing snapshot.
            plan_copy._snapshot_blocks = self._snapshot_blocks.copy()
            plan_copy._snapshot_stats = copy.copy(self._snapshot_stats)
        plan_copy._stages_before_snapshot = self._stages_before_snapshot.copy()
        plan_copy._stages_after_snapshot = self._stages_after_snapshot.copy()
        return plan_copy

    def initial_num_blocks(self) -> Optional[int]:
        """Get the estimated number of blocks after applying all plan stages.

        Returns:
            The number of blocks of the latest stage (or block list) that knows its
            block count, or None if none of them does.
        """
        if self.has_computed_output():
            return self._snapshot_blocks.initial_num_blocks()
        for stage in self._stages_after_snapshot[::-1]:
            if stage.num_blocks is not None:
                return stage.num_blocks
        if self._snapshot_blocks is not None:
            return self._snapshot_blocks.initial_num_blocks()
        for stage in self._stages_before_snapshot[::-1]:
            if stage.num_blocks is not None:
                return stage.num_blocks
        if self._in_blocks is not None:
            return self._in_blocks.initial_num_blocks()
        return None

    def schema(
        self, fetch_if_missing: bool = False
    ) -> Optional[Union[type, "pyarrow.lib.Schema"]]:
        """Get the schema after applying all plan stages.

        Args:
            fetch_if_missing: Whether to execute the plan to fetch the schema.

        Returns:
            The schema of the output dataset, or None if it is not available.
        """
        from ray.data._internal.stage_impl import RandomizeBlocksStage

        if self._stages_after_snapshot:
            if fetch_if_missing:
                if isinstance(self._stages_after_snapshot[-1], RandomizeBlocksStage):
                    # TODO(ekl): this is a hack to optimize the case where we have a
                    # trailing randomize block stages. That stage has no effect and
                    # so we don't need to execute all blocks to get the schema.
                    a = self._stages_after_snapshot.pop()
                    try:
                        self.execute()
                    finally:
                        self._stages_after_snapshot.append(a)
                else:
                    self.execute()
            else:
                return None
        # Snapshot is now guaranteed to be the output of the final stage or None.
        blocks = self._snapshot_blocks
        if not blocks:
            return None
        # Don't force fetching in case it's a lazy block list, in which case we
        # don't want to trigger full execution for a schema read. If we want to
        # trigger execution to get schema, we'll trigger read tasks progressively
        # until a viable schema is available, below.
        metadata = blocks.get_metadata(fetch_if_missing=False)
        # Some blocks could be empty, in which case we cannot get their schema.
        # TODO(ekl) validate schema is the same across different blocks.
        for m in metadata:
            if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
                return m.schema
        if not fetch_if_missing:
            return None
        # Synchronously fetch the schema.
        # For lazy block lists, this launches read tasks and fetches block metadata
        # until we find valid block schema.
        for _, m in blocks.iter_blocks_with_metadata():
            if m.schema is not None and (m.num_rows is None or m.num_rows > 0):
                return m.schema
        return None

    def meta_count(self) -> Optional[int]:
        """Get the number of rows after applying all plan stages if possible.

        This method will never trigger any computation.

        Returns:
            The number of records of the result Dataset, or None.
        """
        if self._stages_after_snapshot:
            return None
        # Snapshot is now guaranteed to be the output of the final stage or None.
        blocks = self._snapshot_blocks
        metadata = blocks.get_metadata() if blocks else None
        if metadata and all(m.num_rows is not None for m in metadata):
            return sum(m.num_rows for m in metadata)
        else:
            return None

    def execute(
        self,
        allow_clear_input_blocks: bool = True,
        force_read: bool = False,
    ) -> BlockList:
        """Execute this plan.

        Args:
            allow_clear_input_blocks: Whether we should try to clear the input blocks
                for each stage.
            force_read: Whether to force the read stage to fully execute.

        Returns:
            The blocks of the output dataset.
        """
        if not self.has_computed_output():
            blocks, stats, stages = self._optimize()
            for stage_idx, stage in enumerate(stages):
                if allow_clear_input_blocks:
                    clear_input_blocks = self._should_clear_input_blocks(
                        blocks, stage_idx
                    )
                else:
                    clear_input_blocks = False
                stats_builder = stats.child_builder(stage.name)
                blocks, stage_info = stage(
                    blocks, clear_input_blocks, self._run_by_consumer
                )
                if stage_info:
                    stats = stats_builder.build_multistage(stage_info)
                else:
                    stats = stats_builder.build(blocks)
                stats.dataset_uuid = uuid.uuid4().hex
            # Set the snapshot to the output of the final stage.
            self._snapshot_blocks = blocks
            self._snapshot_stats = stats
            self._snapshot_stats.dataset_uuid = self._dataset_uuid
            self._stages_before_snapshot += self._stages_after_snapshot
            self._stages_after_snapshot = []
        if _is_lazy(self._snapshot_blocks) and force_read:
            self._snapshot_blocks = self._snapshot_blocks.compute_to_blocklist()
        return self._snapshot_blocks

    def clear_block_refs(self) -> None:
        """Clear all cached block references of this plan, including input blocks.

        This will render the plan un-executable unless the root is a LazyBlockList."""
        # Non-lazy input blocks are unlinked (set to None) when execution starts, so
        # there may be nothing left to clear here.
        if self._in_blocks is not None:
            self._in_blocks.clear()
        self._snapshot_blocks = None
        self._snapshot_stats = None
        # We're erasing the snapshot, so put all stages into the "after snapshot"
        # bucket.
        self._stages_after_snapshot = (
            self._stages_before_snapshot + self._stages_after_snapshot
        )
        self._stages_before_snapshot = []

    def stats(self) -> DatasetStats:
        """Return stats for this plan, forcing execution if needed."""
        self.execute()
        return self._snapshot_stats

    def _should_clear_input_blocks(
        self,
        blocks: BlockList,
        stage_idx: int,
    ) -> bool:
        """Whether the provided blocks should be cleared when passed into the stage.

        Args:
            blocks: The blocks that we may want to clear.
            stage_idx: The position of the stage in the optimized after-snapshot chain.

        Returns:
            True unless these are the plan's non-lazy source blocks feeding the first
            stage.
        """
        if stage_idx != 0 or self._stages_before_snapshot:
            # Not the first stage, always clear stage input blocks.
            return True
        elif isinstance(blocks, LazyBlockList):
            # Always clear lazy input blocks since they can be recomputed.
            return True
        else:
            # Otherwise, we have non-lazy input blocks that's the source of this
            # execution plan, so we don't clear these.
            return False

    def _optimize(self) -> Tuple[BlockList, DatasetStats, List[Stage]]:
        """Apply stage fusion optimizations, returning an updated source block list and
        associated stats, and a set of optimized stages.
        """
        context = DatasetContext.get_current()
        blocks, stats, stages = self._get_source_blocks_and_stages()
        if context.optimize_reorder_stages:
            stages = _reorder_stages(stages)
        if context.optimize_fuse_stages:
            if context.optimize_fuse_read_stages:
                # If using a lazy datasource, rewrite read stage into one-to-one stage
                # so it can be fused into downstream stages.
                blocks, stats, stages = _rewrite_read_stages(
                    blocks, stats, stages, self._dataset_uuid
                )
            stages = _fuse_one_to_one_stages(stages)
            self._last_optimized_stages = stages
        return blocks, stats, stages

    def _get_source_blocks_and_stages(
        self,
    ) -> Tuple[BlockList, DatasetStats, List[Stage]]:
        """Get the source blocks, corresponding stats, and the stages for plan
        execution.

        If a computed snapshot exists and has not been cleared, return the snapshot
        blocks and stats; otherwise, return the input blocks and stats that the plan was
        created with. Note that this unlinks the returned source blocks from the plan
        where possible, so their memory can be reclaimed during execution.
        """
        stages = self._stages_after_snapshot.copy()
        if self._snapshot_blocks is not None:
            if not self._snapshot_blocks.is_cleared():
                # If snapshot exists, we only have to execute the plan from the
                # snapshot.
                blocks = self._snapshot_blocks
                stats = self._snapshot_stats
                # Unlink the snapshot blocks from the plan so we can eagerly reclaim the
                # snapshot block memory after the first stage is done executing.
                self._snapshot_blocks = None
            else:
                # Snapshot exists but has been cleared, so we need to recompute from the
                # source (input blocks).
                blocks = self._in_blocks
                stats = self._in_stats
                stages = self._stages_before_snapshot + self._stages_after_snapshot
        else:
            # If no snapshot exists, we have to execute the full plan from the
            # beginning.
            blocks = self._in_blocks
            stats = self._in_stats
            if not self.has_lazy_input():
                # If not a lazy datasource, unlink the input blocks from the plan so we
                # can eagerly reclaim the input block memory after the first stage is
                # done executing.
                self._in_blocks = None
        return blocks, stats, stages

    def has_lazy_input(self) -> bool:
        """Return whether this plan has lazy input blocks."""
        return _is_lazy(self._in_blocks)

    def is_read_stage_equivalent(self) -> bool:
        """Return whether this plan can be executed as only a read stage."""
        from ray.data._internal.stage_impl import RandomizeBlocksStage

        context = DatasetContext.get_current()
        remaining_stages = self._stages_after_snapshot
        if (
            context.optimize_fuse_stages
            and remaining_stages
            and isinstance(remaining_stages[0], RandomizeBlocksStage)
        ):
            # A leading randomize stage gets fused into the read stage when stage
            # fusion is enabled, so it does not count as an extra stage.
            remaining_stages = remaining_stages[1:]
        return (
            self.has_lazy_input()
            and not self._stages_before_snapshot
            and not remaining_stages
            and (
                not self._snapshot_blocks
                or isinstance(self._snapshot_blocks, LazyBlockList)
            )
        )

    def has_computed_output(self) -> bool:
        """Whether this plan has a computed snapshot for the final stage, i.e. for the
        output of this plan.
        """
        return (
            self._snapshot_blocks is not None
            and not self._stages_after_snapshot
            and not self._snapshot_blocks.is_cleared()
        )


def _pack_args(
    self_fn_args: Iterable[Any],
    self_fn_kwargs: Dict[str, Any],
    prev_fn_args: Iterable[Any],
    prev_fn_kwargs: Dict[str, Any],
) -> Tuple[
    Tuple[Any],
    Callable[
        [Tuple[Any]],
        Tuple[
            Tuple[Any],
            Dict[str, Any],
            Tuple[Any],
            Dict[str, Any],
        ],
    ],
]:
    """Pack the (kw)args from two stages into a single, flat positional args tuple that
    can be given to a Ray task, ensuring resoultion of each argument.
    This function returns this args tuple along with a function that will unpack this
    flat args tuple back into it's original args and kwargs structure.
    """
    if not self_fn_args:
        self_fn_args = tuple()
    if not self_fn_kwargs:
        self_fn_kwargs = {}
    if not prev_fn_args:
        prev_fn_args = tuple()
    if not prev_fn_kwargs:
        prev_fn_kwargs = {}
    # Offsets into flat args tuple.
    offsets = list(
        itertools.accumulate(
            [
                len(self_fn_args),
                len(prev_fn_args),
                len(self_fn_kwargs),
                len(prev_fn_kwargs),
            ]
        )
    )
    # Keys for the kwargs.
    keys = list(self_fn_kwargs.keys()) + list(prev_fn_kwargs.keys())

    fn_args = (
        self_fn_args
        + prev_fn_args
        + tuple(self_fn_kwargs.values())
        + tuple(prev_fn_kwargs.values())
    )

    def unpack(
        fn_args: List[Any],
    ) -> Tuple[List[Any], Dict[str, Any], List[Any], Dict[str, Any]]:
        self_fn_args = fn_args[: offsets[0]]
        prev_fn_args = fn_args[offsets[0] : offsets[1]]
        self_fn_kwargs = fn_args[offsets[1] : offsets[2]]
        prev_fn_kwargs = fn_args[offsets[2] :]
        prev_key_offset = offsets[2] - offsets[1]
        self_fn_kwargs = {k: v for k, v in zip(keys[:prev_key_offset], self_fn_kwargs)}
        prev_fn_kwargs = {k: v for k, v in zip(keys[prev_key_offset:], prev_fn_kwargs)}
        return self_fn_args, self_fn_kwargs, prev_fn_args, prev_fn_kwargs

    return fn_args, unpack


class OneToOneStage(Stage):
    """A stage that transforms blocks independently (e.g., map or filter)."""

    def __init__(
        self,
        name: str,
        block_fn: BlockTransform,
        compute: Union[str, ComputeStrategy],
        ray_remote_args: dict,
        fn: Optional[UDF] = None,
        fn_args: Optional[Iterable[Any]] = None,
        fn_kwargs: Optional[Dict[str, Any]] = None,
        fn_constructor_args: Optional[Iterable[Any]] = None,
        fn_constructor_kwargs: Optional[Dict[str, Any]] = None,
    ):
        """Create a one-to-one stage.

        Args:
            name: Stage name.
            block_fn: Block transform applied to each block.
            compute: Compute strategy (or its string name); defaults to "tasks" when
                falsy.
            ray_remote_args: Ray remote args for the stage; defaults to {} when falsy.
            fn: Optional UDF passed to ``block_fn``.
            fn_args: Positional args for the UDF.
            fn_kwargs: Keyword args for the UDF.
            fn_constructor_args: Constructor args, when ``fn`` is a callable class.
            fn_constructor_kwargs: Constructor kwargs, when ``fn`` is a callable class.
        """
        super().__init__(name, None)
        self.block_fn = block_fn
        self.compute = compute or "tasks"
        self.ray_remote_args = ray_remote_args or {}
        self.fn = fn
        self.fn_args = fn_args
        self.fn_kwargs = fn_kwargs
        self.fn_constructor_args = fn_constructor_args
        self.fn_constructor_kwargs = fn_constructor_kwargs

    def can_fuse(self, prev: Stage) -> bool:
        """Return whether the preceding stage ``prev`` can be fused into this one."""
        if not isinstance(prev, OneToOneStage):
            return False
        # Allow fusing tasks->actors if the resources are compatible (read->map), but
        # not the other way around. The latter will be used as the compute if fused.
        if is_task_compute(self.compute) and prev.compute != self.compute:
            return False
        self_is_callable_class = isinstance(self.fn, CallableClass)
        prev_is_callable_class = isinstance(prev.fn, CallableClass)
        if prev_is_callable_class and not self_is_callable_class:
            # The fused stage only carries this stage's UDF and constructor args, and
            # fuse() hands `prev.fn` as-is to the upstream block function. A callable
            # class on the upstream stage would therefore be passed as a class
            # object rather than an instance, so refuse to fuse.
            return False
        if (
            self_is_callable_class
            and prev_is_callable_class
            and (
                prev.fn != self.fn
                or (
                    prev.fn_constructor_args != self.fn_constructor_args
                    or prev.fn_constructor_kwargs != self.fn_constructor_kwargs
                )
            )
        ):
            # Fusing callable classes is only supported if they are the same function
            # AND their construction arguments are the same.
            # TODO(Clark): Support multiple callable classes instantiating in the same
            # actor worker constructor.
            return False
        if not _are_remote_args_compatible(prev.ray_remote_args, self.ray_remote_args):
            return False
        return True

    def fuse(self, prev: Stage) -> "OneToOneStage":
        """Fuse ``prev`` into this stage.

        The returned stage runs prev's block function and then feeds each output block
        into this stage's block function. It uses this stage's compute and UDF, and
        prev's remote args.

        Raises:
            ValueError: If the stages are not fusable.
        """
        if not self.can_fuse(prev):
            raise ValueError(
                f"Tried to fuse {prev} with {self}, but these are not fusable."
            )
        name = prev.name + "->" + self.name
        prev_fn = prev.fn
        if isinstance(self.fn, CallableClass) and isinstance(prev_fn, CallableClass):
            assert self.fn == prev_fn
            assert (
                prev.fn_constructor_args == self.fn_constructor_args
                and prev.fn_constructor_kwargs == self.fn_constructor_kwargs
            )
            # If both UDFs are callable classes, they must be equal and have the same
            # construction args, so we tell the previous stage to reuse the passed
            # (instantiated) callable class UDF that's provided to the block function.
            use_outer_fn = True
            prev_fn = None
        else:
            # Otherwise, we're either fusing two non-callable class UDFs, or a
            # non-callable class UDF with a callable class UDF. In either case, prev
            # will be a non-callable class UDF (can_fuse() rejects a callable class
            # prev with a non-callable class self), so we use it within the block
            # function.
            use_outer_fn = False

        # Package args into a flat positional args list.
        fn_args, unpack_args = _pack_args(
            self.fn_args,
            self.fn_kwargs,
            prev.fn_args,
            prev.fn_kwargs,
        )

        block_fn1 = prev.block_fn
        block_fn2 = self.block_fn

        def block_fn(
            block: Block,
            fn: UDF,
            *fn_args,
            **fn_kwargs,
        ) -> Iterable[Block]:
            assert not fn_kwargs, fn_kwargs
            # Unpack flat positional args list into each stage's original args and
            # kwargs.
            self_fn_args, self_fn_kwargs, prev_fn_args, prev_fn_kwargs = unpack_args(
                fn_args
            )
            self_fn_args = self_fn_args if fn is None else (fn,) + self_fn_args
            if use_outer_fn:
                prev_fn_ = fn
            else:
                prev_fn_ = prev_fn
            prev_fn_args = (
                prev_fn_args if prev_fn_ is None else (prev_fn_,) + prev_fn_args
            )
            for tmp1 in block_fn1(block, *prev_fn_args, **prev_fn_kwargs):
                for tmp2 in block_fn2(tmp1, *self_fn_args, **self_fn_kwargs):
                    yield tmp2

        return OneToOneStage(
            name,
            block_fn,
            self.compute,
            prev.ray_remote_args,
            fn=self.fn,
            fn_args=fn_args,
            fn_kwargs={},
            fn_constructor_args=self.fn_constructor_args,
            fn_constructor_kwargs=self.fn_constructor_kwargs,
        )

    def __call__(
        self, blocks: BlockList, clear_input_blocks: bool, run_by_consumer: bool
    ) -> Tuple[BlockList, dict]:
        """Apply the block function via the stage's compute strategy.

        Returns:
            The output blocks (marked as owned by the consumer iff
            ``run_by_consumer``) and an empty stage-info dict.
        """
        compute = get_compute(self.compute)
        assert (
            self.fn_constructor_args is None and self.fn_constructor_kwargs is None
        ) or isinstance(compute, ActorPoolStrategy)

        if blocks._owned_by_consumer:
            assert (
                run_by_consumer
            ), "Blocks owned by consumer can only be consumed by consumer"

        blocks = compute._apply(
            self.block_fn,
            self.ray_remote_args,
            blocks,
            clear_input_blocks,
            name=self.name,
            fn=self.fn,
            fn_args=self.fn_args,
            fn_kwargs=self.fn_kwargs,
            fn_constructor_args=self.fn_constructor_args,
            fn_constructor_kwargs=self.fn_constructor_kwargs,
        )
        assert isinstance(blocks, BlockList), blocks
        blocks._owned_by_consumer = run_by_consumer
        return blocks, {}


class AllToAllStage(Stage):
    """A stage that transforms blocks holistically (e.g., shuffle)."""

    def __init__(
        self,
        name: str,
        num_blocks: Optional[int],
        fn: Callable[
            [BlockList, bool, Optional[BlockTransform], Dict[str, Any]],
            Tuple[BlockList, dict],
        ],
        supports_block_udf: bool = False,
        block_udf: Optional[BlockTransform] = None,
        remote_args: Optional[Dict[str, Any]] = None,
    ):
        """Create an all-to-all stage.

        Args:
            name: Stage name.
            num_blocks: Expected number of output blocks, if known.
            fn: Called as ``fn(blocks, clear_input_blocks, block_udf, remote_args)``
                and returns the output blocks plus a stage-info dict.
            supports_block_udf: Whether a preceding one-to-one stage may be fused in
                as a per-block UDF.
            block_udf: Optional per-block transform handed to ``fn``.
            remote_args: Ray remote args handed to ``fn``; defaults to {}.
        """
        super().__init__(name, num_blocks)
        self.fn = fn
        self.supports_block_udf = supports_block_udf
        self.block_udf = block_udf
        self.ray_remote_args = remote_args or {}

    def can_fuse(self, prev: Stage) -> bool:
        """Return whether the preceding stage ``prev`` can be fused in as a block UDF."""
        context = DatasetContext.get_current()
        # TODO(ekl) also support fusing shuffle stages to subsequent 1:1 stages.
        if not context.optimize_fuse_shuffle_stages:
            return False
        if not self.supports_block_udf:
            return False
        if not isinstance(prev, OneToOneStage):
            return False
        if not is_task_compute(prev.compute):
            return False
        if any(k not in INHERITABLE_REMOTE_ARGS for k in prev.ray_remote_args):
            return False
        return True

    def fuse(self, prev: Stage) -> "AllToAllStage":
        """Fuse the one-to-one stage ``prev`` into this stage's block UDF.

        Raises:
            ValueError: If the stages are not fusable.
        """
        if not self.can_fuse(prev):
            raise ValueError(
                f"Tried to fuse {prev} with {self}, but these are not fusable."
            )
        assert self.supports_block_udf
        assert prev.fn_constructor_args is None and prev.fn_constructor_kwargs is None
        name = prev.name + "->" + self.name
        # Normalize to a tuple: fn_args may be any iterable (e.g. a list), and it is
        # concatenated with a tuple below.
        prev_fn_args = tuple(prev.fn_args or ())
        prev_fn_args = prev_fn_args if prev.fn is None else (prev.fn,) + prev_fn_args
        prev_fn_kwargs = prev.fn_kwargs or {}
        prev_block_fn = prev.block_fn
        if self.block_udf is None:

            def block_udf(block: Block) -> Iterable[Block]:
                yield from prev_block_fn(block, *prev_fn_args, **prev_fn_kwargs)

        else:
            self_block_udf = self.block_udf

            def block_udf(block: Block) -> Iterable[Block]:
                for tmp1 in prev_block_fn(
                    block,
                    *prev_fn_args,
                    **prev_fn_kwargs,
                ):
                    for tmp2 in self_block_udf(tmp1):
                        yield tmp2

        return AllToAllStage(
            name, self.num_blocks, self.fn, True, block_udf, prev.ray_remote_args
        )

    def __call__(
        self, blocks: BlockList, clear_input_blocks: bool, run_by_consumer: bool
    ) -> Tuple[BlockList, dict]:
        """Run ``self.fn`` over the blocks and set ownership on the output blocks."""
        from ray.data._internal.stage_impl import RandomizeBlocksStage

        in_blocks_owned_by_consumer = blocks._owned_by_consumer
        if in_blocks_owned_by_consumer:
            assert (
                run_by_consumer
            ), "Blocks owned by consumer can only be consumed by consumer"
        blocks, stage_info = self.fn(
            blocks, clear_input_blocks, self.block_udf, self.ray_remote_args
        )
        assert isinstance(blocks, BlockList), blocks

        # RandomizeBlocksStage is an in-place transformation, so the ownership
        # of blocks doesn't change.
        if isinstance(self, RandomizeBlocksStage):
            blocks._owned_by_consumer = in_blocks_owned_by_consumer
        else:
            blocks._owned_by_consumer = run_by_consumer

        return blocks, stage_info


def _rewrite_read_stages(
    blocks: BlockList,
    stats: DatasetStats,
    stages: List[Stage],
    dataset_uuid: str,
) -> Tuple[BlockList, DatasetStats, List[Stage]]:
    """Rewrites read stages into one-to-one stages, if needed.

    The rewrite only happens when ``blocks`` is lazy and there is at least one
    downstream stage; the new stats are then stamped with ``dataset_uuid``.
    Otherwise the inputs are returned unchanged.
    """
    if not (_is_lazy(blocks) and stages):
        return blocks, stats, stages
    new_blocks, new_stats, new_stages = _rewrite_read_stage(blocks, stages)
    new_stats.dataset_uuid = dataset_uuid
    return new_blocks, new_stats, new_stages


def _rewrite_read_stage(
    in_blocks: LazyBlockList, stages: List[Stage]
) -> Tuple[BlockList, DatasetStats, List[Stage]]:
    """Rewrite the read stage to a OneToOne stage over read tasks as input.

    For example, suppose the plan was [Read -> MapBatches(Fn)]. These stages cannot
    be fused, since read stages are handled specially.
    After rewriting to [GetReadTasks -> MapBatches(DoRead) -> MapBatches(Fn)],
    now we can fuse the latter two MapBatches stages into a single OneToOne stage:
    [GetReadTasks -> MapBatches(DoRead -> Fn)].

    Args:
        in_blocks: Lazy block list representing read stage.
        stages: List of current stages. This list is not mutated.

    Returns:
        Non-lazy block list containing read tasks for not-yet-read block partitions,
        new stats for the block list, and the new list of stages.
    """
    from ray.data._internal.stage_impl import RandomizeBlocksStage

    # Generate the "GetReadTasks" stage blocks: one object per read task holding its
    # read function, plus that task's metadata.
    remote_args = in_blocks._remote_args
    blocks, metadata = [], []
    for read_task in in_blocks._tasks:
        blocks.append(ray.put(read_task._read_fn))
        metadata.append(read_task.get_metadata())
    block_list = BlockList(
        blocks, metadata, owned_by_consumer=in_blocks._owned_by_consumer
    )

    def block_fn(read_fn: Callable[[], Iterator[Block]]) -> Iterator[Block]:
        # Run the read function and emit every block it produces.
        yield from read_fn()

    name = "read"
    # Work on a copy so the caller's stage list is never mutated.
    remaining_stages = list(stages)

    # Fuse downstream randomize stage with the read stage if possible. This is needed
    # when .window() is called right after read->randomize, since it forces execution.
    if remaining_stages and isinstance(remaining_stages[0], RandomizeBlocksStage):
        block_list, _ = remaining_stages[0].do_randomize(block_list)
        remaining_stages = remaining_stages[1:]
        name += "->randomize_block_order"

    stage = OneToOneStage(
        name,
        block_fn,
        "tasks",
        remote_args,
    )
    stats = DatasetStats(stages={}, parent=None)
    return block_list, stats, [stage] + remaining_stages


def _reorder_stages(stages: List[Stage]) -> List[Stage]:
    """Move randomize stages later in the chain so more stages can be fused.

    RandomizeBlockOrder stages (issue #26057) are deferred until just before the next
    all-to-all stage, which acts as a barrier, or to the end of the chain if no such
    barrier follows. The relative order of all other stages is unchanged.

    Args:
        stages: Stages to try to reorder.

    Returns:
        Reordered stages, as a new list.
    """
    from ray.data._internal.stage_impl import RandomizeBlocksStage

    reordered: List[Stage] = []
    pending_randomize: List[RandomizeBlocksStage] = []

    for current in stages:
        if isinstance(current, RandomizeBlocksStage):
            # Hold randomize stages back until a barrier (or the end) is reached.
            pending_randomize.append(current)
            continue
        if isinstance(current, AllToAllStage):
            # Barrier: emit the held-back randomize stages before it.
            reordered += pending_randomize
            pending_randomize = []
        reordered.append(current)

    reordered += pending_randomize
    return reordered


def _fuse_one_to_one_stages(stages: List[Stage]) -> List[Stage]:
    """Fuses compatible one-to-one stages.

    Args:
        stages: Stages to try to fuse.

    Returns:
        Fused stages.
    """
    fused_stages = []
    prev_stage = None
    for idx, stage in enumerate(stages):
        if prev_stage is None:
            prev_stage = stage
        elif stage.can_fuse(prev_stage):
            prev_stage = stage.fuse(prev_stage)
        else:
            fused_stages.append(prev_stage)
            prev_stage = stage
    if prev_stage:
        fused_stages.append(prev_stage)
        prev_stage = None
    return fused_stages


def _are_remote_args_compatible(prev_args, next_args):
    """Check if Ray remote arguments are compatible for merging.

    Both arg dicts are canonicalized. The next stage's args are then overlaid with the
    inheritable keys (INHERITABLE_REMOTE_ARGS) present in the previous stage's args,
    and the stages are compatible iff the result equals the previous stage's
    canonical args.
    """
    canonical_prev = _canonicalize(prev_args)
    merged = _canonicalize(next_args).copy()
    inherited = {
        key: canonical_prev[key]
        for key in INHERITABLE_REMOTE_ARGS
        if key in canonical_prev
    }
    merged.update(inherited)
    return canonical_prev == merged


def _canonicalize(remote_args: dict) -> dict:
    """Returns canonical form of given remote args."""
    remote_args = remote_args.copy()
    if "num_cpus" not in remote_args or remote_args["num_cpus"] is None:
        remote_args["num_cpus"] = 1
    if "num_gpus" not in remote_args or remote_args["num_gpus"] is None:
        remote_args["num_gpus"] = 0
    resources = remote_args.get("resources", {})
    for k, v in list(resources.items()):
        if v is None or v == 0.0:
            del resources[k]
    remote_args["resources"] = resources
    return remote_args


def _is_lazy(blocks: BlockList) -> bool:
    """Return True iff ``blocks`` is a LazyBlockList (False for None or other types)."""
    is_lazy_list = isinstance(blocks, LazyBlockList)
    return is_lazy_list