Repository URL to install this package:
|
Version:
3.0.0.dev3 ▾
|
# pylint: disable=too-many-arguments
"""Provide the `visualize_ast` function, which builds a graphviz
`Digraph` of a complete parsed query or of any part of one.
"""
from typing import Any, Dict, Union
from graphviz import Digraph # type: ignore # pylint: disable=import-error
from snsql._ast.ast import Sql, Token
def _label_node(expr: Sql, n_trunc: Union[int, None]) -> str:
    """Build the display label of an AST node.

    The label has the form ``"<TypeName>: <text>"``. ``<text>`` is
    ``str(expr)``. If ``n_trunc`` is truthy and the text is longer than
    ``n_trunc``, the text is cut to ``text[:n_trunc]`` and ``"..."`` is
    appended. A falsy ``n_trunc`` (``None`` or ``0``) disables truncation.
    """
    text = str(expr)
    needs_cut = bool(n_trunc) and len(text) > n_trunc
    shown = f"{text[:n_trunc]}..." if needs_cut else text
    return "{}: {}".format(type(expr).__name__, shown)
def _color_node(expr: Sql, color_types: Dict[Any, str]) -> str:
    """Color the node containing 'expr'.

    Args:
        expr: ast expression contained in the node
        color_types: dict mapping node types to colors. Keys may be either
            type names or the type objects themselves, e.g.
            {'str': 'red'} or {str: 'red', CaseExpression: 'green'}.
            If both forms are present for the same type, the type-object
            key takes precedence.
    Returns:
        the color of the node, or "black" when no color is provided for
        the type of `expr`.
    """
    expr_type = type(expr)
    # Accept both key forms: the type object itself and its name. Only the
    # exact type is considered; base classes are not looked up.
    for key in (expr_type, expr_type.__name__):
        if key in color_types:
            return color_types[key]
    return "black"
def _visit_nodes(
    graph: Digraph,
    node: Sql,
    path_to_node: str,
    color_types: dict,
    n_trunc: Union[int, None],
) -> None:
    """Recursively add the descendants of `node` to `graph`.

    Each child of `node` that is neither None nor a Token is handled in
    turn. It is added as a graph node labelled via `_label_node` and
    colored via `_color_node`, linked by an edge from `node`, and then
    visited recursively.

    Args:
        graph: The ongoing directed graph.
        node: The current node to visit. Values whose exact type is str,
            float, int or bool are leaves and are not expanded.
        path_to_node: The graphviz identifier of `node`, which must
            already exist in `graph`. Child identifiers are derived from
            it by appending the child's position, so every node of the
            tree gets a unique, colon-free identifier.
        color_types: A dictionary which maps node types (type names or
            type objects) to the color of the corresponding nodes.
            Example: {'str': 'red', 'CaseExpression': 'green'}
        n_trunc: for visibility, truncate the expressions if n_trunc is
            not None.
    """
    if type(node) in (str, float, int, bool):
        return
    visible_children = (
        child
        for child in node.children()
        if child is not None and not isinstance(child, Token)
    )
    for position, child_node in enumerate(visible_children):
        # Identify children by position rather than by their text. Sibling
        # texts may repeat or contain dots, which makes text-based ids
        # collide. graphviz also parses ':' in edge endpoints as a port
        # separator, so text-based ids break on queries containing colons.
        child_id = f"{path_to_node}.{position}"
        graph.node(
            child_id,
            _label_node(child_node, n_trunc),
            color=_color_node(child_node, color_types),
        )
        graph.edge(path_to_node, child_id)
        _visit_nodes(graph, child_node, child_id, color_types, n_trunc)


def visualize_ast(
    parsed_query: Sql,
    name: Union[str, None] = None,
    format_file: str = "png",
    view: bool = True,
    color_types: Union[Dict[Any, str], None] = None,
    n_trunc: Union[int, None] = None,
) -> Digraph:
    """Construct and display the AST graph.

    Args:
        parsed_query: The AST to visualize. This can be a complete parsed
            query or any part of one.
        name: Optional, graph name for the source code.
            If no name is provided, the graph will not be saved.
        format_file: Output format of the graph ('pdf', 'png', ...),
            used when it is rendered. By default, 'png'.
        view: If name is not None, open the rendered result
            with the default application. By default, True.
        color_types: Optional, a dictionary which maps node types (type
            names or type objects) to the color of the nodes.
            Example: {'str': 'red', 'CaseExpression': 'green'}
            If no color is provided, the displayed node will be black.
            By default, None.
        n_trunc: for visibility, truncate the expressions. By default,
            no truncation is done.
    Returns:
        graphviz Digraph representation of 'parsed_query'
    """
    if color_types is None:
        color_types = {}
    graph = Digraph(format=format_file)
    # A fixed, colon-free identifier for the root. The query text appears
    # only in the label, so it cannot be misparsed as an edge port.
    root_id = "root"
    graph.node(
        root_id,
        _label_node(parsed_query, n_trunc),
        color=_color_node(parsed_query, color_types),
    )
    _visit_nodes(graph, parsed_query, root_id, color_types, n_trunc)
    if name:
        graph.render(name, view=view, cleanup=True)
        print(f"graph saved in {name}.{format_file}")
    return graph