# Why Gemfury? Push, build, and install RubyGems, npm packages, Python
# packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM
# packages, and NuGet packages.
#
# Repository URL to install this package:
#
# Details
# ray / purelib / ray / dag / input_node.py
# Size: Mime:
from typing import Any, Dict, List, Union

from ray.dag import DAGNode
from ray.dag.format_utils import get_dag_node_str
from ray.dag.constants import DAGNODE_TYPE_KEY
from ray.util.annotations import DeveloperAPI

# Key stored in an InputNode's ``_bound_other_args_to_resolve`` dict. It is set
# to True by ``InputNode.__enter__`` and checked at execution time to verify
# that the node was created through a ``with InputNode() as ...`` block.
IN_CONTEXT_MANAGER = "__in_context_manager__"


@DeveloperAPI
class InputNode(DAGNode):
    r"""Ray dag node used in DAG building API to mark entrypoints of a DAG.

    Should only be function or class method. A DAG can have multiple
    entrypoints, but only one instance of InputNode exists per DAG, shared
    among all DAGNodes.

    Example:
                   m1.forward
                /            \
        dag_input              ensemble -> dag_output
                \            /
                   m2.forward

    In this pipeline, each user input is broadcasted to both m1.forward and
    m2.forward as first stop of the DAG, and authored like
    >>> @ray.remote
    ... class Model:
    ...     def __init__(self, val):
    ...         self.val = val
    ...     def forward(self, input):
    ...         return self.val * input

    >>> @ray.remote
    ... def combine(a, b):
    ...     return a + b

    >>> with InputNode() as dag_input:
    ...     m1 = Model.bind(1)
    ...     m2 = Model.bind(2)
    ...     m1_output = m1.forward.bind(dag_input[0])
    ...     m2_output = m2.forward.bind(dag_input.x)
    ...     ray_dag = combine.bind(m1_output, m2_output)

    >>> # Pass mix of args and kwargs as input.
    >>> ray_dag.execute(1, x=2) # 1 sent to m1, 2 sent to m2

    >>> # Alternatively user can also pass single data object, list or dict
    >>> # and access them via list index, object attribute or dict key str.
    >>> ray_dag.execute(UserDataObject(m1=1, m2=2))
    ...     # dag_input.m1, dag_input.m2
    >>> ray_dag.execute([1, 2])
    ...     # dag_input[0], dag_input[1]
    >>> ray_dag.execute({"m1": 1, "m2": 2})
    ...     # dag_input["m1"], dag_input["m2"]
    """

    # NOTE: the class docstring above is a raw string on purpose. The ASCII
    # diagram contains backslashes; in a regular string the one at the end of
    # a line is a line continuation and the other is an invalid escape.

    def __init__(self, *args, _other_args_to_resolve=None, **kwargs):
        """InputNode should only take attributes of validating and converting
        input data rather than the input data itself. User input should be
        provided via `ray_dag.execute(user_input)`.

        Args:
            _other_args_to_resolve: Internal only to keep InputNode's execution
                context throughout pickling, replacement and serialization.
                User should not use or pass this field.

        Raises:
            ValueError: If any positional or keyword argument other than
                ``_other_args_to_resolve`` is supplied.
        """
        if args or kwargs:
            raise ValueError("InputNode should not take any args or kwargs.")

        super().__init__([], {}, {}, other_args_to_resolve=_other_args_to_resolve)

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new InputNode carrying only ``new_other_args_to_resolve``.

        ``new_args``, ``new_kwargs`` and ``new_options`` are accepted to match
        the copy signature but are not used: an InputNode never binds them.
        """
        return InputNode(_other_args_to_resolve=new_other_args_to_resolve)

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputNode.

        Returns:
            The single positional argument itself when exactly one positional
            argument and no keyword arguments were given; otherwise a
            ``DAGInputData`` wrapping all args and kwargs.
        """
        # Catch and assert singleton context at dag execution time.
        assert self._in_context_manager(), (
            "InputNode is a singleton instance that should be only used in "
            "context manager for dag building and execution. See the docstring "
            "of class InputNode for examples."
        )
        # If user only passed in one value, for simplicity we just return it.
        if len(args) == 1 and not kwargs:
            return args[0]

        return DAGInputData(*args, **kwargs)

    def _in_context_manager(self) -> bool:
        """Return if InputNode is created in context manager."""
        resolved = self._bound_other_args_to_resolve
        # An empty/None mapping or a missing flag both mean "not in context".
        if not resolved or IN_CONTEXT_MANAGER not in resolved:
            return False
        return resolved[IN_CONTEXT_MANAGER]

    def set_context(self, key: str, val: Any):
        """Set field in parent DAGNode attribute that can be resolved in both
        pickle and JSON serialization
        """
        self._bound_other_args_to_resolve[key] = val

    def __str__(self) -> str:
        return get_dag_node_str(self, "__InputNode__")

    def __getattr__(self, key: str):
        """Build an InputAttributeNode for attribute-style access to the input.

        Python only calls ``__getattr__`` after normal attribute lookup has
        failed, so attributes that really exist on the node are unaffected.
        """
        assert isinstance(
            key, str
        ), "Please only access dag input attributes with str key."
        return InputAttributeNode(self, key, "__getattr__")

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Build an InputAttributeNode for index/key-style access to the input."""
        assert isinstance(key, (str, int)), (
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )
        return InputAttributeNode(self, key, "__getitem__")

    def __enter__(self):
        """Mark this node as created inside a context manager and return it."""
        self.set_context(IN_CONTEXT_MANAGER, True)
        return self

    def __exit__(self, *args):
        """No cleanup is performed; the context flag stays set after exit."""
        pass

    def to_json(self) -> Dict[str, Any]:
        """Return a dict with the node type, resolved other-args and uuid."""
        return {
            DAGNODE_TYPE_KEY: InputNode.__name__,
            "other_args_to_resolve": self.get_other_args_to_resolve(),
            "uuid": self.get_stable_uuid(),
        }

    @classmethod
    def from_json(cls, input_json):
        """Rebuild an InputNode from the dict layout produced by ``to_json``."""
        assert input_json[DAGNODE_TYPE_KEY] == InputNode.__name__
        node = cls(_other_args_to_resolve=input_json["other_args_to_resolve"])
        node._stable_uuid = input_json["uuid"]
        return node


@DeveloperAPI
class InputAttributeNode(DAGNode):
    """Represents partial access of user input based on an index (int),
    object attribute or dict key (str).

    Examples:
        >>> with InputNode() as dag_input:
        ...     a = dag_input[0]
        ...     b = dag_input.x
        ...     ray_dag = add.bind(a, b)

        >>> # This makes a = 1 and b = 2
        >>> ray_dag.execute(1, x=2)

        >>> with InputNode() as dag_input:
        ...     a = dag_input[0]
        ...     b = dag_input[1]
        ...     ray_dag = add.bind(a, b)

        >>> # This makes a = 2 and b = 3
        >>> ray_dag.execute(2, 3)

        >>> # Alternatively, you can input a single object
        >>> # and the inputs are automatically indexed from the object:
        >>> # This makes a = 2 and b = 3
        >>> ray_dag.execute([2, 3])
    """

    def __init__(
        self,
        dag_input_node: InputNode,
        key: Union[int, str],
        accessor_method: str,
    ):
        """Record which field of the DAG input this node refers to.

        Args:
            dag_input_node: The InputNode (or its resolved value) to read from.
            key: Int index or str attribute name / dict key to look up.
            accessor_method: ``"__getitem__"`` or ``"__getattr__"``; selects
                how a str ``key`` is applied to a raw user input object.
        """
        self._dag_input_node = dag_input_node
        self._key = key
        self._accessor_method = accessor_method
        super().__init__(
            [],
            {},
            {},
            {
                "dag_input_node": dag_input_node,
                "key": key,
                "accessor_method": accessor_method,
            },
        )

    def _copy_impl(
        self,
        new_args: List[Any],
        new_kwargs: Dict[str, Any],
        new_options: Dict[str, Any],
        new_other_args_to_resolve: Dict[str, Any],
    ):
        """Return a new node built solely from ``new_other_args_to_resolve``."""
        return InputAttributeNode(
            new_other_args_to_resolve["dag_input_node"],
            new_other_args_to_resolve["key"],
            new_other_args_to_resolve["accessor_method"],
        )

    def _execute_impl(self, *args, **kwargs):
        """Executor of InputAttributeNode.

        Args and kwargs are to match base class signature, but not in the
        implementation. All args and kwargs should be resolved and replaced
        with value in bound_args and bound_kwargs via bottom-up recursion when
        current node is executed.

        Raises:
            ValueError: If the key is neither int nor str, or if a str key is
                paired with an accessor method other than ``"__getitem__"`` or
                ``"__getattr__"``.
        """
        if isinstance(self._dag_input_node, DAGInputData):
            return self._dag_input_node[self._key]

        # dag.execute() is called with only one arg, thus when an
        # InputAttributeNode is executed, its dependent InputNode is
        # resolved with original user input python object.
        user_input_python_object = self._dag_input_node
        if isinstance(self._key, str):
            if self._accessor_method == "__getitem__":
                return user_input_python_object[self._key]
            if self._accessor_method == "__getattr__":
                return getattr(user_input_python_object, self._key)
            # Previously this case fell through and silently returned None.
            raise ValueError(
                f"Unsupported accessor method {self._accessor_method!r} for "
                "dag input; expected '__getitem__' or '__getattr__'."
            )
        if isinstance(self._key, int):
            return user_input_python_object[self._key]
        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )

    def __str__(self) -> str:
        return get_dag_node_str(self, f'["{self._key}"]')

    def to_json(self) -> Dict[str, Any]:
        """Return a dict with the node type, resolved other-args and uuid."""
        return {
            DAGNODE_TYPE_KEY: InputAttributeNode.__name__,
            "other_args_to_resolve": self.get_other_args_to_resolve(),
            "uuid": self.get_stable_uuid(),
        }

    @classmethod
    def from_json(cls, input_json):
        """Rebuild a node from the dict layout produced by ``to_json``."""
        assert input_json[DAGNODE_TYPE_KEY] == InputAttributeNode.__name__
        node = cls(
            input_json["other_args_to_resolve"]["dag_input_node"],
            input_json["other_args_to_resolve"]["key"],
            input_json["other_args_to_resolve"]["accessor_method"],
        )
        node._stable_uuid = input_json["uuid"]
        return node


@DeveloperAPI
class DAGInputData:
    """Container for everything the user handed to dag.execute() when more
    than a single positional value was supplied.

    Positional values are reachable by int index, keyword values by their
    str name, both through ``obj[key]``.
    """

    def __init__(self, *args, **kwargs):
        # Keyword inputs are kept as the dict received; positional inputs
        # are copied from the incoming tuple into a list.
        self._kwargs = kwargs
        self._args = list(args)

    def __getitem__(self, key: Union[int, str]) -> Any:
        """Look up one user input by position (int) or by name (str).

        Raises:
            ValueError: If ``key`` is neither an int nor a str.
        """
        if isinstance(key, str):
            # Named input: look it up among the keyword arguments.
            return self._kwargs[key]
        if isinstance(key, int):
            # Positional input: look it up among the positional arguments.
            return self._args[key]
        raise ValueError(
            "Please only use int index or str as first-level key to "
            "access fields of dag input."
        )