from __future__ import absolute_import
from .Errors import error, message
from . import ExprNodes
from . import Nodes
from . import Builtin
from . import PyrexTypes
from .. import Utils
from .PyrexTypes import py_object_type, unspecified_type
from .Visitor import CythonTransform, EnvTransform
try:
reduce
except NameError:
from functools import reduce
class TypedExprNode(ExprNodes.ExprNode):
    # Stand-in expression that only carries a result type.
    # Used for declaring assignments of a specified type without a known entry.

    # There are no sub-expressions to traverse.
    subexprs = []

    def __init__(self, type, pos=None):
        # The base initialiser receives the position first and the type
        # as a keyword argument.
        super(TypedExprNode, self).__init__(pos, type=type)
# Shared placeholder right-hand side of Python object type, used when marking
# assignments whose value has no expression node of its own (e.g. except
# clause targets and names bound by "from ... import").
object_expr = TypedExprNode(py_object_type)
class MarkParallelAssignments(EnvTransform):
    # Collects the assignments made inside parallel blocks
    # (prange loops and "with parallel" sections).
    # Perhaps it's better to move it to ControlFlowAnalysis.

    # tells us whether we're in a normal loop
    in_loop = False

    # Set after a nesting error was reported while visiting a parallel
    # block; it suppresses further nesting errors until it is reset at the
    # end of visit_ParallelStatNode().
    parallel_errors = False

    def __init__(self, context):
        # Stack of the enclosing parallel block scopes, innermost last
        # (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def _record_parallel_assignment(self, lhs, inplace_op):
        # Register an assignment to the name node 'lhs' on the innermost
        # enclosing parallel block, if there is one.
        if lhs.entry is None:
            # TODO: This shouldn't happen...
            return
        if not self.parallel_block_stack:
            return

        block = self.parallel_block_stack[-1]
        earlier = block.assignments.get(lhs.entry)
        if earlier:
            # If there was a previous assignment to the variable, keep the
            # previous assignment position
            pos, earlier_op = earlier
            if inplace_op and earlier_op and inplace_op != earlier_op:
                # x += y; x *= y
                error(lhs.pos,
                      "Reduction operator '%s' is inconsistent "
                      "with previous reduction operator '%s'" % (
                          inplace_op, earlier_op))
        else:
            pos = lhs.pos

        block.assignments[lhs.entry] = (pos, inplace_op)
        block.assigned_nodes.append(lhs)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        # Record that 'lhs' is assigned from 'rhs'.  Name-like targets are
        # registered on the current parallel block, sequence targets are
        # unpacked item by item, anything else is ignored.
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            self._record_parallel_assignment(lhs, inplace_op)
            return

        if not isinstance(lhs, ExprNodes.SequenceNode):
            # Could use this info to infer cdef class attributes...
            return

        for index, item in enumerate(lhs.args):
            # Starred items and unknown right-hand sides have no per-item
            # expression that could be handed on.
            if rhs and not item.is_starred:
                item_rhs = rhs.inferable_item_node(index)
            else:
                item_rhs = None
            self.mark_assignment(item, item_rhs)

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.with_node.enter_call)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        # a = b = rhs: every target receives the same right-hand side
        for target in node.lhs_list:
            self.mark_assignment(target, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
        # The operator is passed on so that conflicting operators on the
        # same variable can be detected in mark_assignment().
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def _unshadowed_builtin_name(self, call):
        # Return the called name if 'call' is a plain call of a name that
        # is either unknown in the current scope or a builtin there.
        # Returns None for everything else.
        if not isinstance(call, ExprNodes.SimpleCallNode):
            return None
        function = call.function
        if call.self is not None or not function.is_name:
            return None
        entry = self.current_env().lookup(function.name)
        if entry and not entry.is_builtin:
            return None
        return function.name

    def _ssize_t_max_node(self, pos):
        # Integer constant node of Py_ssize_t type, used as a stand-in
        # for index values.
        return ExprNodes.IntNode(pos, value='PY_SSIZE_T_MAX',
                                 type=PyrexTypes.c_py_ssize_t_type)

    def visit_ForInStatNode(self, node):
        # TODO: Remove redundancy with range optimization...
        sequence = node.iterator.sequence
        target = node.target

        # First look through a wrapping reversed(...) / enumerate(...) call.
        wrapper = self._unshadowed_builtin_name(sequence)
        if wrapper == 'reversed' and len(sequence.args) == 1:
            sequence = sequence.args[0]
        elif wrapper == 'enumerate' and len(sequence.args) == 1:
            if target.is_sequence_constructor and len(target.args) == 2:
                iterator = sequence.args[0]
                if iterator.is_name:
                    iterator_type = iterator.infer_type(self.current_env())
                    if iterator_type.is_builtin_type:
                        # assume that builtin types have a length within Py_ssize_t
                        self.mark_assignment(
                            target.args[0],
                            self._ssize_t_max_node(target.pos))
                        target = target.args[1]
                        sequence = sequence.args[0]

        # Then special-case iteration over range() / xrange().
        if self._unshadowed_builtin_name(sequence) in ('range', 'xrange'):
            for bound in sequence.args[:2]:
                self.mark_assignment(target, bound)
            if len(sequence.args) > 2:
                # with a step argument, also mark "start + step"
                self.mark_assignment(
                    target,
                    ExprNodes.binop_node(
                        node.pos, '+', sequence.args[0], sequence.args[2]))
        else:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(
                target,
                ExprNodes.IndexNode(
                    node.pos,
                    base=sequence,
                    index=self._ssize_t_max_node(target.pos)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        target = node.target
        self.mark_assignment(target, node.bound1)
        if node.step is not None:
            stepped = ExprNodes.binop_node(
                node.pos, '+', node.bound1, node.step)
            self.mark_assignment(target, stepped)
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        # Can't be assigned to...
        # Nothing is recorded and, unlike the other visitors, no node
        # is returned.
        return None

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name == "*":
                continue
            self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        packed_args = (
            (node.star_arg, Builtin.tuple_type),
            (node.starstar_arg, Builtin.dict_type),
        )
        for arg, arg_type in packed_args:
            if arg:
                self.mark_assignment(arg, TypedExprNode(arg_type, node.pos))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        # each deleted argument is marked as an assignment from itself
        for arg in node.args:
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        stack = self.parallel_block_stack
        parent = stack[-1] if stack else None
        node.parent = parent

        if not node.is_prange:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = parent and parent.is_prange
        elif parent:
            node.is_parallel = (parent.is_prange or not parent.is_parallel)
            nested = parent.is_prange
        else:
            node.is_parallel = True
            nested = False

        stack.append(node)

        nested = nested or len(stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
            # Temporarily restrict the traversed children so that the
            # else clause is visited after this block has been popped.
            saved_child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = saved_child_attrs

            stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "'%s' not allowed in parallel sections" % node.expr_keyword)
        return node

    def visit_ReturnStatNode(self, node):
        # remember whether the return statement sits inside a parallel block
        node.in_parallel = bool(self.parallel_block_stack)
        return node
class MarkOverflowingArithmetic(CythonTransform):
    # Sets 'might_overflow' on the entries of names that are used as
    # operands of operations treated as overflow-prone, and of names that
    # are assigned a long integer literal.
    #
    # It may be possible to integrate this with MarkParallelAssignments
    # for performance improvements (though likely not worth it).

    # True while the subtree being visited belongs to an operation that is
    # considered able to overflow.
    might_overflow = False

    def __call__(self, root):
        # env_stack holds the scopes of the enclosing functions,
        # env is the scope currently used for name lookups.
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def visit_safe_node(self, node):
        # Children are visited with the overflow flag cleared; the previous
        # state is restored afterwards.
        saved = self.might_overflow
        self.might_overflow = False
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        # Children inherit whatever overflow state is currently active.
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
        # Children are visited with the overflow flag set; the previous
        # state is restored afterwards.
        saved = self.might_overflow
        self.might_overflow = True
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        # Switch name lookups to the function's local scope while its
        # body is being visited.
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def _mark_entry_overflow(self, name_node):
        # Flag the entry behind 'name_node', looking it up by name in the
        # current scope when the node carries no entry of its own.
        entry = name_node.entry or self.env.lookup(name_node.name)
        if entry:
            entry.might_overflow = True

    def visit_NameNode(self, node):
        if self.might_overflow:
            self._mark_entry_overflow(node)
        return node

    def visit_BinopNode(self, node):
        # Compare against the individual operators.  A substring test on
        # '&|^' would also accept '' and multi-character strings like '&|'.
        if node.operator in ('&', '|', '^'):
            return self.visit_neutral_node(node)
        return self.visit_dangerous_node(node)

    def visit_SimpleCallNode(self, node):
        if node.function.is_name and node.function.name == 'abs':
            # Overflows for minimum value of fixed size ints.
            return self.visit_dangerous_node(node)
        return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node
    visit_UnaryMinusNode = visit_dangerous_node
    visit_InPlaceAssignmentNode = visit_dangerous_node
    # Every node type without a dedicated handler resets the flag.
    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
        # A name assigned an integer literal that Utils.long_literal()
        # accepts is flagged directly.
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            self._mark_entry_overflow(lhs)

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node
class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Give every entry of the scope whose type is still unspecified
        the generic Python object type.
        """
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry.type = py_object_type
class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        """Set the type on the entry and on everything in entry.all_entries()."""
        entry.type = entry_type
        for linked in entry.all_entries():
            linked.type = entry_type

    def infer_types(self, scope):
        """
        Infer a type for every entry of the scope that has no declared type.

        The 'infer_types' directive selects the mode: a true value uses
        aggressive_spanning_type, None (safe mode) uses safe_spanning_type,
        and anything else turns inference off, so that all unspecified
        entries become Python objects.
        """
        directives = scope.directives
        enabled = directives['infer_types']
        verbose = directives['infer_types.verbose']

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None:  # safe mode
            spanning_type = safe_spanning_type
        else:
            # inference disabled
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Assignments that still need a type / that already have one.
        pending = set()
        assmts_resolved = set()
        # assignment -> assignments it depends on
        dependencies = {}
        # assignment -> name nodes its type depends on
        assmt_to_names = {}

        for entry in scope.entries.values():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                dependencies[assmt] = {
                    dep for name_node in names for dep in name_node.cf_state}
            if entry.type is unspecified_type:
                pending.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
            # Span the types of all assignments reaching the name node.
            types = [assmt.inferred_type for assmt in node.cf_state]
            if types:
                entry = node.entry
                node.inferred_type = spanning_type(
                    types, entry.might_overflow, entry.pos, scope)
            else:
                node.inferred_type = py_object_type

        def infer_name_node_type_partial(node):
            # Like infer_name_node_type(), but only uses the assignments
            # whose type is already known and returns the result (or None).
            known = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not known:
                return None
            entry = node.entry
            return spanning_type(known, entry.might_overflow, entry.pos, scope)

        def inferred_types(entry):
            seen_none = False
            seen_pyobject = False
            types = []
            for assmt in entry.cf_assignments:
                if assmt.rhs.is_none:
                    seen_none = True
                    continue
                rhs_type = assmt.inferred_type
                if rhs_type and rhs_type.is_pyobject:
                    seen_pyobject = True
                types.append(rhs_type)
            # Ignore None assignments as long as there are concrete Python type assignments.
            # but include them if None is the only assigned Python object.
            if seen_none and not seen_pyobject:
                types.append(py_object_type)
            return types

        def resolve_assignments(assignments):
            # Infer every assignment whose dependencies are all resolved,
            # remove those from 'assignments' and return them.
            resolved = set()
            for assmt in assignments:
                if not assmts_resolved.issuperset(dependencies[assmt]):
                    continue
                for name_node in assmt_to_names[assmt]:
                    infer_name_node_type(name_node)
                # Resolve assmt
                assmt.infer_type()
                assmts_resolved.add(assmt)
                resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            # Only succeeds if each name the assignment depends on has at
            # least one typed assignment.
            partial_types = []
            for name_node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(name_node)
                if partial_type is None:
                    return False
                partial_types.append((name_node, partial_type))
            for name_node, partial_type in partial_types:
                name_node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()

        def resolve_partial(assignments):
            # try to handle circular references
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments: keep going while either step makes progress;
        # the partial step is only tried when the full one resolved nothing.
        while resolve_assignments(pending) or resolve_partial(pending):
            pass

        inferred = set()
        # First pass
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = inferred_types(entry)
                if types and all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos, scope)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
            # Re-run inference for the inferred entries and report whether
            # any entry type changed.
            dirty = False
            for entry in inferred:
                for assmt in entry.cf_assignments:
                    assmt.infer_type()
                new_type = spanning_type(
                    inferred_types(entry), entry.might_overflow, entry.pos, scope)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # types propagation
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))
def find_spanning_type(type1, type2):
    """Return a single type for inference that covers both given types."""
    if type1 is not type2:
        bint = PyrexTypes.c_bint_type
        if type1 is bint or type2 is bint:
            # type inference can break the coercion back to a Python bool
            # if it returns an arbitrary int type here
            return py_object_type
        spanned = PyrexTypes.spanning_type(type1, type2)
    else:
        spanned = type1

    float_like = (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                  Builtin.float_type)
    if spanned in float_like:
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return spanned
def simply_type(result_type, pos):
    """
    Reduce a spanning type to the type used for an inferred variable.

    A reference type is replaced by its ref_base_type, then a const type
    by its const_base_type.  For a C++ class, check_nullary_constructor()
    is called with 'pos'.  An array type is turned into a pointer to its
    base type.
    """
    simplified = result_type
    if simplified.is_reference:
        simplified = simplified.ref_base_type
    if simplified.is_const:
        simplified = simplified.const_base_type
    if simplified.is_cpp_class:
        simplified.check_nullary_constructor(pos)
    if simplified.is_array:
        return PyrexTypes.c_ptr_type(simplified.base_type)
    return simplified
def aggressive_spanning_type(types, might_overflow, pos, scope):
    """
    Span all given types and simplify the result.

    'might_overflow' and 'scope' are accepted so that the signature
    matches safe_spanning_type(); they are not used here.
    """
    spanned = reduce(find_spanning_type, types)
    return simply_type(spanned, pos)
def safe_spanning_type(types, might_overflow, pos, scope):
    """
    Span all given types, but only keep the result if it is on the list
    of cases considered safe below; otherwise fall back to the generic
    Python object type.
    """
    result_type = simply_type(reduce(find_spanning_type, types), pos)

    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        return result_type

    if result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type

    if result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type

    if result_type.is_pythran_expr:
        return result_type

    if result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type

    if result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type

    if result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type

    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    if (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type

    if (not result_type.can_coerce_to_pyobject(scope)
            and not result_type.is_error):
        return result_type

    return py_object_type
def get_type_inferer():
    """Return a new SimpleAssignmentTypeInferer instance."""
    inferer = SimpleAssignmentTypeInferer()
    return inferer