Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / ast_visualization.py
Size: Mime:
# pylint: disable=too-many-arguments
"""Provide the `visualize_ast` function, which builds a graphviz
`Digraph` of a complete parsed query or of a fragment of one.
"""

from typing import Any, Dict, Union

from graphviz import Digraph  # type: ignore # pylint: disable=import-error
from snsql._ast.ast import Sql, Token


def _label_node(expr: Sql, n_trunc: Union[int, None]) -> str:
    """Label an expression"""
    str_expr = str(expr)
    if n_trunc and len(str_expr) > n_trunc:
        str_expr = str_expr[:n_trunc] + "..."
    return f"{type(expr).__name__}: {str_expr}"


def _color_node(expr: Sql, color_types: Dict[Any, str]) -> str:
    """Color the node containing 'expr'.
    Args:
        expr: ast expression contained in the node
        color_types: dict of types and corresponding color
            Example: {'str': 'red'}
            If `expr` type in `str` then the corresponding node will be red.
            If no color is provided for the `expr` type, it will be printed
            in black.
    Returns:
        the color of the node
    """
    expr_type = type(expr).__name__
    if expr_type in color_types:
        return color_types[expr_type]
    return "black"


def _visit_nodes(
    graph: Digraph,
    node: Sql,
    path_to_node: str,
    color_types: dict,
    n_trunc: Union[int, None],
) -> None:
    """Recursively add the descendants of `node` to `graph`.

    Each drawn child gets the node identifier ``f"{path_to_node}.{i}"``,
    where ``i`` is the child's position in ``node.children()``. This has
    two consequences:

    - Identifiers are unique within the tree. Two siblings never share one,
      even when their SQL text is identical.
    - Identifiers contain no SQL text. This matters because
      ``Digraph.edge`` treats ``:`` in an identifier as a port/compass
      separator. An expression such as ``'12:00'`` would otherwise yield
      edges pointing at the wrong node.

    Args:
        graph: The ongoing directed graph; nodes and edges are added to it.
        node: The current node to visit. Python scalars (str, float, int,
            bool) are leaves and are not expanded.
        path_to_node: Identifier of `node` in `graph`. It should not
            contain ``:`` for the reason given above.
        color_types: A dictionary which maps node type names to colors.
            Example: {'str': 'red', 'CaseExpression': 'green'}
        n_trunc: passed to `_label_node` to truncate long labels; None
            means no truncation.
    """
    if type(node) in (str, float, int, bool):
        # Python scalars have no `children()` to expand.
        return
    for position, child_node in enumerate(node.children()):
        # None children and Token children are not drawn (Tokens are
        # presumably keywords/punctuation — TODO confirm against snsql).
        if child_node is None or isinstance(child_node, Token):
            continue
        # The position-based id is unique among siblings and free of SQL
        # text; the SQL text only goes into the label.
        child_id = f"{path_to_node}.{position}"
        graph.node(
            child_id,
            _label_node(child_node, n_trunc),
            color=_color_node(child_node, color_types),
        )
        graph.edge(path_to_node, child_id)
        _visit_nodes(graph, child_node, child_id, color_types, n_trunc)


def visualize_ast(
    parsed_query: Sql,
    name: Union[str, None] = None,
    format_file: str = "png",
    view: bool = True,
    color_types: Union[Dict[Any, str], None] = None,
    n_trunc: Union[int, None] = None,
) -> Digraph:
    """Construct and optionally render the AST graph.

    Args:
        parsed_query: The AST to visualize. This can be a complete parsed
            query or a fragment of one.
        name: Optional file name for the rendered graph.
            If no name is provided (None or ""), the graph is not saved.
        format_file: Rendering output format ('pdf', 'png', ...). It is
            also the format of the returned Digraph. By default, 'png'.
        view: If a name is given, open the rendered result with the
            default application. By default, True.
        color_types: Optional dictionary which maps node type names to
            colors. Example: {'str': 'red', 'CaseExpression': 'green'}.
            Nodes whose type has no entry are drawn in black.
            By default, None.
        n_trunc: for visibility, truncate the expressions shown in the
            labels. By default, no truncation is done.

    Returns:
        graphviz Digraph representation of 'parsed_query'.
    """
    if color_types is None:
        color_types = {}
    graph = Digraph(format=format_file)
    # Use a fixed, colon-free identifier for the root rather than the query
    # text: Digraph.edge would parse any ':' in the text as a port
    # separator. The query text is still shown through the label.
    root_id = "root"
    graph.node(
        root_id,
        _label_node(parsed_query, n_trunc),
        color=_color_node(parsed_query, color_types),
    )
    _visit_nodes(graph, parsed_query, root_id, color_types, n_trunc)
    if name:
        # cleanup=True removes the intermediate DOT source file after
        # rendering.
        graph.render(name, view=view, cleanup=True)
        print(f"graph saved in {name}.{format_file}")
    return graph