# Repository URL to install this package:
# Version: 3.0.0.dev3
import typing as t
from snsql._ast.ast import Column, Query, Sql, Table, UnifiedQuery
from .ast_walker import ASTWalker
from .utils import is_quoted, split_quote
class TableRenamer(ASTWalker):
    """TableRenamer rewrites a `Query` object by replacing all the tables
    specified as keys in `map_dict` by their corresponding values.

    Each part of a replacement name is wrapped in double quotes unless it is
    already quoted, and the parts are joined with ".".

    Args:
        query (Union[Query, UnifiedQuery]): a parsed query
        map_dict (Dict[Tuple[str, ...], Tuple[str, ...]]): the dictionary
            mapping original qualified table names (tuples of name parts)
            to the new qualified table names
        **kwargs: forwarded unchanged to `ASTWalker.__init__`
    """

    def __init__(
        self,
        query: t.Union[Query, UnifiedQuery],
        map_dict: t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]],
        **kwargs: t.Any,
    ):
        super().__init__(query, **kwargs)
        self.map_dict = map_dict

    def rename_table_column(self, original_name: str) -> str:
        """Return the new name for `original_name` according to `self.map_dict`.

        `original_name` is split into parts with `split_quote(..., "")`.
        A key of `self.map_dict` matches when those parts equal the trailing
        parts of the key, or when the key's first part alone equals them.

        Args:
            original_name: a possibly qualified, possibly quoted name.

        Returns:
            The mapped name (unquoted parts wrapped in double quotes, parts
            joined with "."), or `original_name` unchanged if no key matches.

        Raises:
            ValueError: if the name matches keys that map to *different*
                new names.
        """
        escaped_name = split_quote(original_name, "")
        escaped_name_len = len(escaped_name)
        # A dict keeps insertion order and collapses duplicates: several keys
        # resolving to the same new name are not an ambiguity.
        matching_names: t.Dict[t.Tuple[str, ...], None] = {}
        for old, new in self.map_dict.items():
            # `len(old) - escaped_name_len` (not `-escaped_name_len`) so that a
            # zero-length name compares against an empty suffix, not all of `old`.
            if len(old) >= escaped_name_len and (
                old[len(old) - escaped_name_len :] == escaped_name
                or old[:1] == escaped_name
            ):
                matching_names[tuple(new)] = None

        if len(matching_names) > 1:
            raise ValueError(
                f"Qualified table name {original_name} is ambiguous"
            )
        if not matching_names:
            return original_name

        (new_parts,) = matching_names
        return ".".join(
            part if is_quoted(part) else '"' + part + '"'
            for part in new_parts
        )

    def transform(self, expr: Sql) -> Sql:
        """Map the table names of `expr` according to `self.map_dict`.

        `Table` and `Column` nodes have their `name` rewritten in place and
        are returned directly (`walk_tree` is not called on them); any other
        node is passed to `self.walk_tree`.

        Args:
            expr (Sql): the AST node to transform

        Returns:
            Sql: modified ast expression
        """
        if isinstance(expr, Table):
            expr.name = self.rename_table_column(expr.name)
            return expr
        if isinstance(expr, Column):
            split_name = split_quote(expr.name)
            if len(split_name) == 1:
                # NOTE(review): a bare column name is also looked up in the
                # table map, so a column sharing its name with a mapped table
                # is renamed too -- confirm this is intended.
                new_name = self.rename_table_column(split_name[-1])
            else:
                # Everything before the last part is the table qualifier.
                tablename = self.rename_table_column(".".join(split_name[:-1]))
                colname = self.rename_table_column(split_name[-1])
                new_name = f"{tablename}.{colname}"
            expr.name = new_name
            return expr
        return self.walk_tree(expr)

    def rename(self) -> None:
        """Run the AST walk via `self.run()`.

        Presumably this applies `transform` to the query's nodes; see
        `ASTWalker.run` to confirm.
        """
        self.run()