Why Gemfury? Push, build, and install  RubyGems npm packages Python packages Maven artifacts PHP packages Go Modules Debian packages RPM packages NuGet packages

Repository URL to install this package:

Details    
sarus_sql / sarus_sql / ast_transform / table_renamer.py
Size: Mime:
import typing as t

from snsql._ast.ast import Column, Query, Sql, Table, UnifiedQuery

from .ast_walker import ASTWalker
from .utils import is_quoted, split_quote


class TableRenamer(ASTWalker):
    """Walk a parsed query and rename tables according to a mapping.

    ``Table`` nodes whose name matches a key of ``map_dict`` receive the
    corresponding value as their new name. ``Column`` nodes have their
    qualifier and their last component looked up in the same mapping.
    Matching nodes are modified by assigning to their ``name`` attribute.

    Args:
        query (Query | UnifiedQuery): the parsed query to rewrite; passed to
            ``ASTWalker.__init__`` together with any extra ``kwargs``.
        map_dict (Dict[Tuple[str, ...], Tuple[str, ...]]): dictionary mapping
            the components of an original table name to the components of
            its replacement name.
    """

    def __init__(
        self,
        query: t.Union[Query, UnifiedQuery],
        map_dict: t.Dict[t.Tuple[str, ...], t.Tuple[str, ...]],
        **kwargs: t.Any,
    ):
        super().__init__(query, **kwargs)
        self.map_dict = map_dict

    def rename_table_column(self, original_name: str) -> str:
        """Return the replacement for ``original_name`` from ``self.map_dict``.

        ``original_name`` is split with ``split_quote(original_name, "")``.
        Each key is compared against the split name. A key is a candidate
        only if it has at least as many components as the split name. It
        then matches when either:

        - its trailing components equal the split name, or
        - its one-element prefix equals the split name.

        Returns:
            str: ``original_name`` unchanged when no key matches. Otherwise,
            the single matching value joined with ``"."``, where every
            component that ``is_quoted`` rejects is wrapped in double quotes.

        Raises:
            ValueError: if more than one key matches.
        """
        # NOTE(review): the comparisons below assume split_quote returns a
        # tuple (the keys are tuples); a list would never compare equal.
        # The second argument "" presumably controls quote handling —
        # confirm against split_quote.
        target = split_quote(original_name, "")
        target_len = len(target)

        candidates = []
        for old_parts, new_parts in self.map_dict.items():
            start = len(old_parts) - target_len
            if start < 0:
                # Key is shorter than the name being looked up.
                continue
            # NOTE(review): the prefix test lets a one-component name match
            # a key's *first* component too (e.g. a schema name) — confirm
            # this is intended.
            if old_parts[start:] == target or old_parts[:1] == target:
                candidates.append(new_parts)

        if not candidates:
            return original_name
        if len(candidates) > 1:
            raise ValueError(
                f"Qualified table name {original_name} is ambiguous"
            )

        (replacement,) = candidates
        return ".".join(
            part if is_quoted(part) else '"' + part + '"'
            for part in replacement
        )

    def transform(self, expr: Sql) -> Sql:
        """Rename table references inside ``expr``.

        Handling depends on the node type:

        - ``Table``: the whole name is mapped through
          ``rename_table_column``.
        - ``Column``: the name is split with ``split_quote``. A
          one-component name is mapped as a whole. Otherwise the qualifier
          (all but the last component, re-joined with ``"."``) and the last
          component are mapped separately, then joined with ``"."``.
        - Any other node is delegated to ``self.walk_tree``.

        Args:
            expr (Sql): the AST node to transform.

        Returns:
            Sql: ``expr`` itself (with its ``name`` updated) for ``Table`` and
            ``Column`` nodes, otherwise the result of ``walk_tree``.
        """
        if isinstance(expr, Table):
            expr.name = self.rename_table_column(expr.name)
            return expr

        if not isinstance(expr, Column):
            return self.walk_tree(expr)

        # NOTE(review): the column's own last component is also looked up in
        # the table mapping, so a column named like a mapped table would be
        # renamed as well — confirm this is intended.
        parts = split_quote(expr.name)
        if len(parts) == 1:
            expr.name = self.rename_table_column(parts[-1])
        else:
            # Qualifier is renamed before the leaf, as in the original order.
            qualifier = self.rename_table_column(".".join(parts[:-1]))
            leaf = self.rename_table_column(parts[-1])
            expr.name = f"{qualifier}.{leaf}"
        return expr

    def rename(self) -> None:
        """Apply the renaming by running the walker via ``self.run()``."""
        self.run()