Why Gemfury? Push, build, and install RubyGems, npm packages, Python packages, Maven artifacts, PHP packages, Go Modules, Debian packages, RPM packages, and NuGet packages.

Repository URL to install this package:

Details    
mypy / semanal_enum.py
Size: Mime:
"""Semantic analysis of call-based Enum definitions.

This is conceptually part of mypy.semanal (semantic analyzer pass 2).
"""

from typing import List, Tuple, Optional, Union, cast

from mypy.nodes import (
    Expression, Context, TypeInfo, AssignmentStmt, NameExpr, CallExpr, RefExpr, StrExpr,
    UnicodeExpr, TupleExpr, ListExpr, DictExpr, Var, SymbolTableNode, GDEF, MDEF, ARG_POS,
    EnumCallExpr
)
from mypy.semanal_shared import SemanticAnalyzerInterface
from mypy.options import Options


class EnumCallAnalyzer:
    """Recognize call-based Enum definitions during semantic analysis.

    An assignment whose right-hand side calls enum.Enum, enum.IntEnum,
    enum.Flag or enum.IntFlag is turned into a TypeInfo that is registered
    through the semantic analyzer interface given to the constructor.
    """

    def __init__(self, options: Options, api: SemanticAnalyzerInterface) -> None:
        """Keep references to the options and the semantic analyzer interface."""
        self.api = api
        self.options = options

    def process_enum_call(self, s: AssignmentStmt, is_func_scope: bool) -> None:
        """Point the assigned name at an Enum TypeInfo if s is an Enum call.

        Only assignments with exactly one target that is a NameExpr are
        considered; anything else is ignored.
        """
        targets = s.lvalues
        if len(targets) != 1:
            return
        target = targets[0]
        if not isinstance(target, NameExpr):
            return
        var_name = target.name
        info = self.check_enum_call(s.rvalue, var_name, is_func_scope)
        if info is None:
            return
        # The right-hand side was recognized as an Enum call: redirect the
        # symbol already known under the target name to the new TypeInfo.
        symbol = self.api.lookup(var_name, s)
        if symbol:
            symbol.kind = GDEF   # TODO locally defined Enum
            symbol.node = info

    def check_enum_call(self,
                        node: Expression,
                        var_name: Optional[str],
                        is_func_scope: bool) -> Optional[TypeInfo]:
        """Return the TypeInfo described by an Enum call, or None.

        Example:

          A = enum.Enum('A', 'foo bar')

        is equivalent to:

          class A(enum.Enum):
              foo = 1
              bar = 2

        None is returned when node is not a call, when its callee is not a
        reference, or when the callee is none of enum.Enum, enum.IntEnum,
        enum.Flag and enum.IntFlag.  If the arguments cannot be parsed, an
        error has been reported and a placeholder TypeInfo named 'Enum' with
        no members is returned; in that case nothing is added to the symbol
        table and node.analyzed is left untouched.
        """
        if not isinstance(node, CallExpr):
            return None
        callee = node.callee
        if not isinstance(callee, RefExpr):
            return None
        base_fullname = callee.fullname
        if base_fullname not in ('enum.Enum', 'enum.IntEnum', 'enum.Flag', 'enum.IntFlag'):
            return None
        # The short class name is only used in error messages.
        short_name = base_fullname.rpartition('.')[2]
        items, values, ok = self.parse_enum_call_args(node, short_name)
        if not ok:
            # Error. Construct dummy return value.
            return self.build_enum_call_typeinfo('Enum', [], base_fullname)
        enum_name = cast(Union[StrExpr, UnicodeExpr], node.args[0]).value
        if is_func_scope or enum_name != var_name:
            # Make the name unique by appending the line number of the call.
            enum_name = enum_name + '@' + str(node.line)
        info = self.build_enum_call_typeinfo(enum_name, items, base_fullname)
        # Register the type as a global (or in the nearest class if there is
        # one) so that it is reachable even if it stays anonymous.
        self.api.add_symbol_table_node(enum_name, SymbolTableNode(GDEF, info))
        analyzed = EnumCallExpr(info, items, values)
        node.analyzed = analyzed
        analyzed.set_line(node.line, node.column)
        return info

    def build_enum_call_typeinfo(self, name: str, items: List[str], fullname: str) -> TypeInfo:
        """Create an enum TypeInfo called name with one member Var per item.

        The base type is looked up by fullname and must exist.
        """
        base_type = self.api.named_type_or_none(fullname)
        assert base_type is not None
        enum_info = self.api.basic_new_typeinfo(name, base_type)
        enum_info.metaclass_type = enum_info.calculate_metaclass_type()
        enum_info.is_enum = True
        for member_name in items:
            member = Var(member_name)
            member.info = enum_info
            member.is_property = True
            member._fullname = '{}.{}'.format(self.api.qualified_name(name), member_name)
            enum_info.names[member_name] = SymbolTableNode(MDEF, member)
        return enum_info

    def parse_enum_call_args(self, call: CallExpr,
                             class_name: str) -> Tuple[List[str],
                                                       List[Optional[Expression]], bool]:
        """Extract member names and value expressions from an Enum call.

        Exactly two positional arguments are accepted.  The first must be a
        string literal.  The second may be:

          * a string literal of names separated by whitespace and/or commas,
          * a tuple or list of string literals,
          * a tuple or list of two-item tuples/lists whose first item is a
            string literal (name, value),
          * a dict literal whose keys are string literals.

        Return (items, values, ok).  values has one entry per item; entries
        are None when the call supplied no explicit values.  On any problem
        an error is reported and ([], [], False) is returned.
        """
        args = call.args
        arg_count = len(args)
        if arg_count < 2:
            return self.fail_enum_call_arg("Too few arguments for %s()" % class_name, call)
        if arg_count > 2:
            return self.fail_enum_call_arg("Too many arguments for %s()" % class_name, call)
        if call.arg_kinds != [ARG_POS, ARG_POS]:
            return self.fail_enum_call_arg("Unexpected arguments to %s()" % class_name, call)
        name_arg, members_arg = args
        if not isinstance(name_arg, (StrExpr, UnicodeExpr)):
            return self.fail_enum_call_arg(
                "%s() expects a string literal as the first argument" % class_name, call)
        items = []  # type: List[str]
        values = []  # type: List[Optional[Expression]]
        if isinstance(members_arg, (StrExpr, UnicodeExpr)):
            # 'a b', 'a,b' and 'a, b' all describe the same members.
            items.extend(members_arg.value.replace(',', ' ').split())
        elif isinstance(members_arg, (TupleExpr, ListExpr)):
            entries = members_arg.items
            if all(isinstance(entry, (StrExpr, UnicodeExpr)) for entry in entries):
                items = [cast(Union[StrExpr, UnicodeExpr], entry).value
                         for entry in entries]
            elif all(isinstance(entry, (TupleExpr, ListExpr))
                     and len(entry.items) == 2
                     and isinstance(entry.items[0], (StrExpr, UnicodeExpr))
                     for entry in entries):
                for entry in entries:
                    # Already verified above; repeated for the type checker.
                    assert isinstance(entry, (TupleExpr, ListExpr))
                    member_name, member_value = entry.items
                    assert isinstance(member_name, (StrExpr, UnicodeExpr))
                    items.append(member_name.value)
                    values.append(member_value)
            else:
                return self.fail_enum_call_arg(
                    "%s() with tuple or list expects strings or (name, value) pairs" %
                    class_name,
                    call)
        elif isinstance(members_arg, DictExpr):
            for key, member_value in members_arg.items:
                if not isinstance(key, (StrExpr, UnicodeExpr)):
                    return self.fail_enum_call_arg(
                        "%s() with dict literal requires string literals" % class_name, call)
                items.append(key.value)
                values.append(member_value)
        else:
            # TODO: Allow dict(x=1, y=2) as a substitute for {'x': 1, 'y': 2}?
            return self.fail_enum_call_arg(
                "%s() expects a string, tuple, list or dict literal as the second argument" %
                class_name,
                call)
        if not items:
            return self.fail_enum_call_arg("%s() needs at least one item" % class_name, call)
        if not values:
            # No explicit values were given: pad with one None per member.
            values = [None for _ in items]
        assert len(items) == len(values)
        return items, values, True

    def fail_enum_call_arg(self, message: str,
                           context: Context) -> Tuple[List[str],
                                                      List[Optional[Expression]], bool]:
        """Report message at context and return the failed-parse triple ([], [], False)."""
        self.fail(message, context)
        no_items = []  # type: List[str]
        no_values = []  # type: List[Optional[Expression]]
        return no_items, no_values, False

    # Helpers

    def fail(self, msg: str, ctx: Context) -> None:
        """Forward an error message and its context to the semantic analyzer interface."""
        report = self.api.fail
        report(msg, ctx)