from __future__ import absolute_import

from .Errors import error, message
from . import ExprNodes
from . import Nodes
from . import Builtin
from . import PyrexTypes
from .. import Utils
from .PyrexTypes import py_object_type, unspecified_type
from .Visitor import CythonTransform, EnvTransform

try:
    reduce
except NameError:
    from functools import reduce


class TypedExprNode(ExprNodes.ExprNode):
    # Used for declaring assignments of a specified type without a known entry.
    subexprs = []

    def __init__(self, type, pos=None):
        super(TypedExprNode, self).__init__(pos, type=type)

object_expr = TypedExprNode(py_object_type)
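# A shared dummy RHS for targets that are only known to receive a generic
# Python object (except-clause targets, "from x import y" targets below).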


class MarkParallelAssignments(EnvTransform):
    # Collects assignments inside parallel blocks (prange, with parallel).
    # Perhaps it's better to move it to ControlFlowAnalysis.

    # tells us whether we're in a normal loop
    in_loop = False

    parallel_errors = False

    def __init__(self, context):
        # Track the parallel block scopes (with parallel, for i in prange())
        self.parallel_block_stack = []
        super(MarkParallelAssignments, self).__init__(context)

    def mark_assignment(self, lhs, rhs, inplace_op=None):
        if isinstance(lhs, (ExprNodes.NameNode, Nodes.PyArgDeclNode)):
            if lhs.entry is None:
                # TODO: This shouldn't happen...
                return

            if self.parallel_block_stack:
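                # Record the assignment on the innermost parallel block.
                # These records are used later to decide how the variable is
                # handled in the parallel section (e.g. as a reduction when
                # an in-place operator is involved), which is why the
                # operator is stored alongside the position.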
                parallel_node = self.parallel_block_stack[-1]
                previous_assignment = parallel_node.assignments.get(lhs.entry)

                # If there was a previous assignment to the variable, keep the
                # previous assignment position
                if previous_assignment:
                    pos, previous_inplace_op = previous_assignment

                    if (inplace_op and previous_inplace_op and
                            inplace_op != previous_inplace_op):
                        # x += y; x *= y
                        t = (inplace_op, previous_inplace_op)
                        error(lhs.pos,
                              "Reduction operator '%s' is inconsistent "
                              "with previous reduction operator '%s'" % t)
                else:
                    pos = lhs.pos

                parallel_node.assignments[lhs.entry] = (pos, inplace_op)
                parallel_node.assigned_nodes.append(lhs)

        elif isinstance(lhs, ExprNodes.SequenceNode):
            for i, arg in enumerate(lhs.args):
                if not rhs or arg.is_starred:
                    item_node = None
                else:
                    item_node = rhs.inferable_item_node(i)
                self.mark_assignment(arg, item_node)
        else:
            # Could use this info to infer cdef class attributes...
            pass

    def visit_WithTargetAssignmentStatNode(self, node):
        self.mark_assignment(node.lhs, node.with_node.enter_call)
        self.visitchildren(node)
        return node

    def visit_SingleAssignmentNode(self, node):
        self.mark_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.mark_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_InPlaceAssignmentNode(self, node):
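        # An in-place assignment "x op= y" is treated as assigning "x op y"
        # to x; the operator itself is passed along so that conflicting
        # reduction operators inside a parallel block can be detected.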
        self.mark_assignment(node.lhs, node.create_binop_node(), node.operator)
        self.visitchildren(node)
        return node

    def visit_ForInStatNode(self, node):
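        # Infer the loop target from the iterable: unwrap reversed() and
        # enumerate() calls, special-case range()/xrange() so that the target
        # is typed from the bounds, and otherwise model the loop as repeated
        # indexing of the sequence.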
        # TODO: Remove redundancy with range optimization...
        is_special = False
        sequence = node.iterator.sequence
        target = node.target
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name == 'reversed' and len(sequence.args) == 1:
                        sequence = sequence.args[0]
                    elif function.name == 'enumerate' and len(sequence.args) == 1:
                        if target.is_sequence_constructor and len(target.args) == 2:
                            iterator = sequence.args[0]
                            if iterator.is_name:
                                iterator_type = iterator.infer_type(self.current_env())
                                if iterator_type.is_builtin_type:
                                    # assume that builtin types have a length within Py_ssize_t
                                    self.mark_assignment(
                                        target.args[0],
                                        ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                                          type=PyrexTypes.c_py_ssize_t_type))
                                    target = target.args[1]
                                    sequence = sequence.args[0]
        if isinstance(sequence, ExprNodes.SimpleCallNode):
            function = sequence.function
            if sequence.self is None and function.is_name:
                entry = self.current_env().lookup(function.name)
                if not entry or entry.is_builtin:
                    if function.name in ('range', 'xrange'):
                        is_special = True
                        for arg in sequence.args[:2]:
                            self.mark_assignment(target, arg)
                        if len(sequence.args) > 2:
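                            # Also mark the target with "start + step" so
                            # that the step's type takes part in the
                            # inference of the loop variable.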
                            self.mark_assignment(
                                target,
                                ExprNodes.binop_node(node.pos,
                                                     '+',
                                                     sequence.args[0],
                                                     sequence.args[2]))

        if not is_special:
            # A for-loop basically translates to subsequent calls to
            # __getitem__(), so using an IndexNode here allows us to
            # naturally infer the base type of pointers, C arrays,
            # Python strings, etc., while correctly falling back to an
            # object type when the base type cannot be handled.
            self.mark_assignment(target, ExprNodes.IndexNode(
                node.pos,
                base=sequence,
                index=ExprNodes.IntNode(target.pos, value='PY_SSIZE_T_MAX',
                                        type=PyrexTypes.c_py_ssize_t_type)))

        self.visitchildren(node)
        return node

    def visit_ForFromStatNode(self, node):
        self.mark_assignment(node.target, node.bound1)
        if node.step is not None:
            self.mark_assignment(node.target,
                    ExprNodes.binop_node(node.pos,
                                         '+',
                                         node.bound1,
                                         node.step))
        self.visitchildren(node)
        return node

    def visit_WhileStatNode(self, node):
        self.visitchildren(node)
        return node

    def visit_ExceptClauseNode(self, node):
        if node.target is not None:
            self.mark_assignment(node.target, object_expr)
        self.visitchildren(node)
        return node

    def visit_FromCImportStatNode(self, node):
        return node  # Can't be assigned to...

    def visit_FromImportStatNode(self, node):
        for name, target in node.items:
            if name != "*":
                self.mark_assignment(target, object_expr)
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        # use fake expressions with the right result type
        if node.star_arg:
            self.mark_assignment(
                node.star_arg, TypedExprNode(Builtin.tuple_type, node.pos))
        if node.starstar_arg:
            self.mark_assignment(
                node.starstar_arg, TypedExprNode(Builtin.dict_type, node.pos))
        EnvTransform.visit_FuncDefNode(self, node)
        return node

    def visit_DelStatNode(self, node):
        for arg in node.args:
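            # Treat "del x" as an assignment of x to itself: the type does
            # not change, but the name is still recorded as assigned inside
            # an enclosing parallel block.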
            self.mark_assignment(arg, arg)
        self.visitchildren(node)
        return node

    def visit_ParallelStatNode(self, node):
        if self.parallel_block_stack:
            node.parent = self.parallel_block_stack[-1]
        else:
            node.parent = None

        nested = False
        if node.is_prange:
            if not node.parent:
                node.is_parallel = True
            else:
                node.is_parallel = (node.parent.is_prange or
                                    not node.parent.is_parallel)
                nested = node.parent.is_prange
        else:
            node.is_parallel = True
            # Note: nested with parallel() blocks are handled by
            # ParallelRangeTransform!
            # nested = node.parent
            nested = node.parent and node.parent.is_prange

        self.parallel_block_stack.append(node)

        nested = nested or len(self.parallel_block_stack) > 2
        if not self.parallel_errors and nested and not node.is_prange:
            error(node.pos, "Only prange() may be nested")
            self.parallel_errors = True

        if node.is_prange:
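            # Restrict the visited children so that the else clause is not
            # traversed while this prange is still on the stack; it is
            # visited separately below, after the prange has been popped.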
            child_attrs = node.child_attrs
            node.child_attrs = ['body', 'target', 'args']
            self.visitchildren(node)
            node.child_attrs = child_attrs

            self.parallel_block_stack.pop()
            if node.else_clause:
                node.else_clause = self.visit(node.else_clause)
        else:
            self.visitchildren(node)
            self.parallel_block_stack.pop()

        self.parallel_errors = False
        return node

    def visit_YieldExprNode(self, node):
        if self.parallel_block_stack:
            error(node.pos, "'%s' not allowed in parallel sections" % node.expr_keyword)
        return node

    def visit_ReturnStatNode(self, node):
        node.in_parallel = bool(self.parallel_block_stack)
        return node


class MarkOverflowingArithmetic(CythonTransform):

    # It may be possible to integrate this with the above for
    # performance improvements (though likely not worth it).
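    #
    # Flags name entries that take part in arithmetic which might overflow;
    # safe_spanning_type() below refuses to infer a plain C integer type for
    # entries flagged this way.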

    might_overflow = False

    def __call__(self, root):
        self.env_stack = []
        self.env = root.scope
        return super(MarkOverflowingArithmetic, self).__call__(root)

    def visit_safe_node(self, node):
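        # A "safe" node resets the overflow context: its children are
        # visited as if no dangerous arithmetic surrounded them.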
        self.might_overflow, saved = False, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_neutral_node(self, node):
        self.visitchildren(node)
        return node

    def visit_dangerous_node(self, node):
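        # A "dangerous" node (e.g. arithmetic) marks every name visited
        # inside it as potentially overflowing.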
        self.might_overflow, saved = True, self.might_overflow
        self.visitchildren(node)
        self.might_overflow = saved
        return node

    def visit_FuncDefNode(self, node):
        self.env_stack.append(self.env)
        self.env = node.local_scope
        self.visit_safe_node(node)
        self.env = self.env_stack.pop()
        return node

    def visit_NameNode(self, node):
        if self.might_overflow:
            entry = node.entry or self.env.lookup(node.name)
            if entry:
                entry.might_overflow = True
        return node

    def visit_BinopNode(self, node):
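        # Bitwise &, | and ^ cannot produce values outside the range of
        # their operands, so they are neutral; other binary operators
        # (+, -, *, <<, ...) might overflow.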
        if node.operator in '&|^':
            return self.visit_neutral_node(node)
        else:
            return self.visit_dangerous_node(node)

    def visit_SimpleCallNode(self, node):
        if node.function.is_name and node.function.name == 'abs':
            # Overflows for minimum value of fixed size ints.
            return self.visit_dangerous_node(node)
        else:
            return self.visit_neutral_node(node)

    visit_UnopNode = visit_neutral_node

    visit_UnaryMinusNode = visit_dangerous_node

    visit_InPlaceAssignmentNode = visit_dangerous_node

    visit_Node = visit_safe_node

    def visit_assignment(self, lhs, rhs):
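        # An assignment of a large integer literal (as judged by
        # Utils.long_literal()) suggests that the variable may hold values
        # that overflow a plain C int, so flag its entry.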
        if (isinstance(rhs, ExprNodes.IntNode)
                and isinstance(lhs, ExprNodes.NameNode)
                and Utils.long_literal(rhs.value)):
            entry = lhs.entry or self.env.lookup(lhs.name)
            if entry:
                entry.might_overflow = True

    def visit_SingleAssignmentNode(self, node):
        self.visit_assignment(node.lhs, node.rhs)
        self.visitchildren(node)
        return node

    def visit_CascadedAssignmentNode(self, node):
        for lhs in node.lhs_list:
            self.visit_assignment(lhs, node.rhs)
        self.visitchildren(node)
        return node

class PyObjectTypeInferer(object):
    """
    If it's not declared, it's a PyObject.
    """
    def infer_types(self, scope):
        """
        Given a scope, set the type of all entries that are still
        unspecified to py_object_type.
        """
        for name, entry in scope.entries.items():
            if entry.type is unspecified_type:
                entry.type = py_object_type

class SimpleAssignmentTypeInferer(object):
    """
    Very basic type inference.

    Note: in order to support cross-closure type inference, this must be
    applied to nested scopes in top-down order.
    """
    def set_entry_type(self, entry, entry_type):
        entry.type = entry_type
        for e in entry.all_entries():
            e.type = entry_type

    def infer_types(self, scope):
        enabled = scope.directives['infer_types']
        verbose = scope.directives['infer_types.verbose']
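        # The 'infer_types' directive is three-valued: True enables
        # aggressive inference, None means safe inference only, and any
        # other value (notably False) disables inference and leaves all
        # unspecified entries as generic Python objects.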

        if enabled == True:
            spanning_type = aggressive_spanning_type
        elif enabled is None: # safe mode
            spanning_type = safe_spanning_type
        else:
            for entry in scope.entries.values():
                if entry.type is unspecified_type:
                    self.set_entry_type(entry, py_object_type)
            return

        # Gather the data needed for inference:
        #  - 'assignments' holds the assignments whose LHS entry still has an
        #    unspecified type (the ones we need to infer),
        #  - 'assmts_resolved' starts out with the assignments to entries
        #    whose type is already known,
        #  - 'dependencies' maps each assignment to the set of assignments
        #    that its RHS type depends on (via the control flow state of the
        #    names it reads),
        #  - 'assmt_to_names' records those name nodes for each assignment.
        assignments = set()
        assmts_resolved = set()
        dependencies = {}
        assmt_to_names = {}

        for name, entry in scope.entries.items():
            for assmt in entry.cf_assignments:
                names = assmt.type_dependencies()
                assmt_to_names[assmt] = names
                assmts = set()
                for node in names:
                    assmts.update(node.cf_state)
                dependencies[assmt] = assmts
            if entry.type is unspecified_type:
                assignments.update(entry.cf_assignments)
            else:
                assmts_resolved.update(entry.cf_assignments)

        def infer_name_node_type(node):
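            # The type of a name use is the spanning type of all assignments
            # that may reach it according to its control flow state.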
            types = [assmt.inferred_type for assmt in node.cf_state]
            if not types:
                node_type = py_object_type
            else:
                entry = node.entry
                node_type = spanning_type(
                    types, entry.might_overflow, entry.pos, scope)
            node.inferred_type = node_type

        def infer_name_node_type_partial(node):
            types = [assmt.inferred_type for assmt in node.cf_state
                     if assmt.inferred_type is not None]
            if not types:
                return
            entry = node.entry
            return spanning_type(types, entry.might_overflow, entry.pos, scope)

        def inferred_types(entry):
            has_none = False
            has_pyobjects = False
            types = []
            for assmt in entry.cf_assignments:
                if assmt.rhs.is_none:
                    has_none = True
                else:
                    rhs_type = assmt.inferred_type
                    if rhs_type and rhs_type.is_pyobject:
                        has_pyobjects = True
                    types.append(rhs_type)
            # Ignore None assignments as long as there are concrete Python
            # type assignments, but include them if None is the only assigned
            # Python object.
            if has_none and not has_pyobjects:
                types.append(py_object_type)
            return types

        def resolve_assignments(assignments):
            resolved = set()
            for assmt in assignments:
                deps = dependencies[assmt]
                # All assignments that this one depends on have been resolved.
                if assmts_resolved.issuperset(deps):
                    for node in assmt_to_names[assmt]:
                        infer_name_node_type(node)
                    # Resolve the assignment itself; infer_type() records the
                    # inferred type on the assignment node (see
                    # infer_name_node_type() above).
                    assmt.infer_type()
                    assmts_resolved.add(assmt)
                    resolved.add(assmt)
            assignments.difference_update(resolved)
            return resolved

        def partial_infer(assmt):
            partial_types = []
            for node in assmt_to_names[assmt]:
                partial_type = infer_name_node_type_partial(node)
                if partial_type is None:
                    return False
                partial_types.append((node, partial_type))
            for node, partial_type in partial_types:
                node.inferred_type = partial_type
            assmt.infer_type()
            return True

        partial_assmts = set()
        def resolve_partial(assignments):
            # Try to break circular dependencies: partially infer an
            # assignment from whatever dependencies have already been
            # resolved, then treat it as resolved so that the main loop can
            # make progress.
            partials = set()
            for assmt in assignments:
                if assmt in partial_assmts:
                    continue
                if partial_infer(assmt):
                    partials.add(assmt)
                    assmts_resolved.add(assmt)
            partial_assmts.update(partials)
            return partials

        # Infer assignments: alternate between resolving assignments whose
        # dependencies are fully known and partially resolving cyclic ones,
        # until neither step makes any more progress.
        while True:
            if not resolve_assignments(assignments):
                if not resolve_partial(assignments):
                    break
        inferred = set()
        # First pass: assign a type to every entry that still has an
        # unspecified type.  Entries whose assignments were all resolved get
        # the spanning type of their inferred assignment types; anything else
        # falls back to a generic Python object.
        for entry in scope.entries.values():
            if entry.type is not unspecified_type:
                continue
            entry_type = py_object_type
            if assmts_resolved.issuperset(entry.cf_assignments):
                types = inferred_types(entry)
                if types and all(types):
                    entry_type = spanning_type(
                        types, entry.might_overflow, entry.pos, scope)
                    inferred.add(entry)
            self.set_entry_type(entry, entry_type)

        def reinfer():
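            # Re-run inference for the entries typed above: their assignments
            # may now infer differently because other entries' types have
            # changed.  Returns True if any entry type changed.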
            dirty = False
            for entry in inferred:
                for assmt in entry.cf_assignments:
                    assmt.infer_type()
                types = inferred_types(entry)
                new_type = spanning_type(types, entry.might_overflow, entry.pos, scope)
                if new_type != entry.type:
                    self.set_entry_type(entry, new_type)
                    dirty = True
            return dirty

        # Propagate the newly inferred entry types until a fixed point is
        # reached.
        while reinfer():
            pass

        if verbose:
            for entry in inferred:
                message(entry.pos, "inferred '%s' to be of type '%s'" % (
                    entry.name, entry.type))


def find_spanning_type(type1, type2):
    if type1 is type2:
        result_type = type1
    elif type1 is PyrexTypes.c_bint_type or type2 is PyrexTypes.c_bint_type:
        # type inference can break the coercion back to a Python bool
        # if it returns an arbitrary int type here
        return py_object_type
    else:
        result_type = PyrexTypes.spanning_type(type1, type2)
    if result_type in (PyrexTypes.c_double_type, PyrexTypes.c_float_type,
                       Builtin.float_type):
        # Python's float type is just a C double, so it's safe to
        # use the C type instead
        return PyrexTypes.c_double_type
    return result_type

def simply_type(result_type, pos):
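    # Reduce the spanned type to something that can be assigned to: strip
    # references and const qualifiers, decay C arrays to pointers, and
    # require that C++ classes have a default constructor.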
    if result_type.is_reference:
        result_type = result_type.ref_base_type
    if result_type.is_const:
        result_type = result_type.const_base_type
    if result_type.is_cpp_class:
        result_type.check_nullary_constructor(pos)
    if result_type.is_array:
        result_type = PyrexTypes.c_ptr_type(result_type.base_type)
    return result_type

def aggressive_spanning_type(types, might_overflow, pos, scope):
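    # Aggressive mode: use the spanning type of all assigned types directly.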
    return simply_type(reduce(find_spanning_type, types), pos)

def safe_spanning_type(types, might_overflow, pos, scope):
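    # Safe mode: only keep the spanned C type when substituting it for a
    # generic Python object is known not to change behavior; otherwise fall
    # back to object.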
    result_type = simply_type(reduce(find_spanning_type, types), pos)
    if result_type.is_pyobject:
        # In theory, any specific Python type is always safe to
        # infer. However, inferring str can cause some existing code
        # to break, since we are also now much more strict about
        # coercion from str to char *. See trac #553.
        if result_type.name == 'str':
            return py_object_type
        else:
            return result_type
    elif result_type is PyrexTypes.c_double_type:
        # Python's float type is just a C double, so it's safe to use
        # the C type instead
        return result_type
    elif result_type is PyrexTypes.c_bint_type:
        # find_spanning_type() only returns 'bint' for clean boolean
        # operations without other int types, so this is safe, too
        return result_type
    elif result_type.is_pythran_expr:
        return result_type
    elif result_type.is_ptr:
        # Any pointer except (signed|unsigned|) char* can't implicitly
        # become a PyObject, and inferring char* is now accepted, too.
        return result_type
    elif result_type.is_cpp_class:
        # These can't implicitly become Python objects either.
        return result_type
    elif result_type.is_struct:
        # Though we have struct -> object for some structs, this is uncommonly
        # used, won't arise in pure Python, and there shouldn't be side
        # effects, so I'm declaring this safe.
        return result_type
    # TODO: double complex should be OK as well, but we need
    # to make sure everything is supported.
    elif (result_type.is_int or result_type.is_enum) and not might_overflow:
        return result_type
    elif (not result_type.can_coerce_to_pyobject(scope)
            and not result_type.is_error):
        return result_type
    return py_object_type


def get_type_inferer():
    return SimpleAssignmentTypeInferer()