Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
ray / dag / vis_utils.py
Size: Mime:
import os
import tempfile

from ray.dag import DAGNode
from ray.dag.utils import _DAGNodeNameGenerator
from ray.util.annotations import DeveloperAPI


@DeveloperAPI
def plot(dag: DAGNode, to_file=None):
    if to_file is None:
        tmp_file = tempfile.NamedTemporaryFile(suffix=".png")
        to_file = tmp_file.name
        extension = "png"
    else:
        _, extension = os.path.splitext(to_file)
        if not extension:
            extension = "png"
        else:
            extension = extension[1:]

    graph = _dag_to_dot(dag)
    graph.write(to_file, format=extension)

    # Render the image directly if running inside a Jupyter notebook
    try:
        from IPython import display

        return display.Image(filename=to_file)
    except ImportError:
        pass

    # close temp file if needed
    try:
        tmp_file.close()
    except NameError:
        pass


def _check_pydot_and_graphviz():
    """Check if pydot and graphviz are installed.

    pydot and graphviz are required for plotting. We check this
    during runtime rather than adding them to Ray dependencies.

    """
    try:
        import pydot
    except ImportError:
        raise ImportError(
            "pydot is required to plot DAG, install it with `pip install pydot`."
        )
    try:
        pydot.Dot.create(pydot.Dot())
    except (OSError, pydot.InvocationException):
        raise ImportError(
            "graphviz is required to plot DAG, "
            "download it from https://graphviz.gitlab.io/download/"
        )


def _get_nodes_and_edges(dag: DAGNode):
    """Get all unique nodes and edges in the DAG.

    A basic dfs with memorization to get all unique nodes
    and edges in the DAG.
    Unique nodes will be used to generate unique names,
    while edges will be used to construct the graph.
    """

    edges = []
    nodes = []

    def _dfs(node):
        nodes.append(node)
        for child_node in node._get_all_child_nodes():
            edges.append((child_node, node))
        return node

    dag.apply_recursive(_dfs)
    return nodes, edges


def _dag_to_dot(dag: DAGNode):
    """Create a Dot graph from dag.

    TODO(lchu):
    1. add more Dot configs in kwargs,
    e.g. rankdir, alignment, etc.
    2. add more contents to graph,
    e.g. args, kwargs and options of each node

    """
    # Step 0: check dependencies and init graph
    _check_pydot_and_graphviz()
    import pydot

    graph = pydot.Dot(rankdir="LR")

    # Step 1: generate unique name for each node in dag
    nodes, edges = _get_nodes_and_edges(dag)
    name_generator = _DAGNodeNameGenerator()
    node_names = {}
    for node in nodes:
        node_names[node] = name_generator.get_node_name(node)

    # Step 2: create graph with all the edges
    for edge in edges:
        graph.add_edge(pydot.Edge(node_names[edge[0]], node_names[edge[1]]))
    # if there is only one node
    if len(nodes) == 1 and len(edges) == 0:
        graph.add_node(pydot.Node(node_names[nodes[0]]))

    return graph