from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import *
from Cython.Compiler.TreeFragment import TreeFragment
from Cython.Utils import EncodedString
from Cython.Compiler.Errors import CompileError
from sets import Set as set

class BufferTransform(CythonTransform):
    """
    Run after type analysis. Takes care of the buffer functionality.

    For every scope entry whose type carries buffer options, auxiliary
    variables are declared in that scope: a buffer-info variable, a
    temporary of the entry's own type, and one stride and one shape
    variable per dimension.  Assignments to such variables are rewritten
    to release the old buffer and acquire the new one, and indexing into
    them is rewritten into pointer arithmetic on the buffer-info variable.
    """
    # Scope that generated tree fragments are analysed against.  It is set
    # by handle_scope() and restored by visit_FuncDefNode() on the way out.
    scope = None

    def __call__(self, node):
        """Look up the Py_buffer type and run the transform on ``node``."""
        cymod = self.context.modules[u'__cython__']
        self.buffer_type = cymod.entries[u'Py_buffer'].type
        return super(BufferTransform, self).__call__(node)

    def handle_scope(self, node, scope):
        """
        Declare the auxiliary buffer variables for ``scope`` and make it
        the current scope.

        ``node`` is only used for its position when declaring variables.
        The declared variables are also reachable through the
        ``buffer_aux`` attribute attached to each buffer entry.
        """
        # Collect the buffer entries up front: declaring variables below
        # adds to the scope, so it must not be iterated at the same time.
        bufvars = [(name, entry) for name, entry
                   in scope.entries.iteritems()
                   if entry.type.buffer_options is not None]

        for name, entry in bufvars:
            # Variable has buffer opts, declare auxiliary vars
            bufopts = entry.type.buffer_options

            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
                                        self.buffer_type, node.pos)

            temp_var = scope.declare_var(temp_name_handle(u"%s_tmp" % name),
                                         entry.type, node.pos)

            stridevars = []
            shapevars = []
            for idx in range(bufopts.ndim):
                # stride
                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
                var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
                stridevars.append(var)
                # shape
                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
                var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
                shapevars.append(var)
            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars,
                                                shapevars)
            entry.buffer_aux.temp_var = temp_var
        self.scope = scope

    def visit_ModuleNode(self, node):
        """Set up the module scope, then transform the module's children."""
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        """
        Set up the function's local scope, transform its children, and
        restore the enclosing scope afterwards.
        """
        # handle_scope() overwrites self.scope.  Without restoring it,
        # statements that follow this function would be analysed against
        # this function's local scope.
        enclosing_scope = self.scope
        self.handle_scope(node, node.local_scope)
        try:
            self.visitchildren(node)
        finally:
            self.scope = enclosing_scope
        return node

    # Statement sequence substituted for "LHS = RHS" when LHS is a buffer
    # variable: release the buffer of the old value (if not None), acquire
    # the buffer of the new value, refresh the auxiliary variables
    # (ASSIGN_AUX), then perform the actual assignment.
    acquire_buffer_fragment = TreeFragment(u"""
        TMP = LHS
        if TMP is not None:
            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
        TMP = RHS
        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
        ASSIGN_AUX
        LHS = TMP
    """)

    # Copies one stride out of the buffer-info variable.
    fetch_strides = TreeFragment(u"""
        TARGET = BUFINFO.strides[IDX]
    """)

    # Copies one shape extent out of the buffer-info variable.
    fetch_shape = TreeFragment(u"""
        TARGET = BUFINFO.shape[IDX]
    """)

    def visit_SingleAssignmentNode(self, node):
        """
        Rewrite assignments that involve buffers; return every other
        assignment unchanged.
        """
        # On assignments, two buffer-related things can happen:
        # a) A buffer variable is assigned to (reacquisition)
        # b) Buffer access assignment: arr[...] = ...
        # Since we don't allow nested buffers, these don't overlap.

        self.visitchildren(node)
        # Only acquire buffers on vars (not attributes) for now.
        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
            # Is buffer variable
            return self.reacquire_buffer(node)
        elif (isinstance(node.lhs, IndexNode) and
              isinstance(node.lhs.base, NameNode) and
              node.lhs.base.entry.buffer_aux is not None):
            return self.assign_into_buffer(node)
        # Not buffer related: hand the node back.  Falling off the end
        # would return None, i.e. no replacement node for this statement.
        return node

    def _fetch_aux_assignments(self, fragment, entries, bufinfo_name, pos):
        """
        Build one ``TARGET = BUFINFO.<field>[IDX]`` statement per entry
        from ``fragment`` and mark each entry as used.
        """
        assignments = []
        for idx, entry in enumerate(entries):
            entry.used = True
            assignments.append(fragment.substitute({
                u"TARGET": NameNode(pos, name=entry.name),
                u"BUFINFO": NameNode(pos, name=bufinfo_name),
                u"IDX": IntNode(pos, value=EncodedString(idx))
            }))
        return assignments

    def reacquire_buffer(self, node):
        """
        Replace an assignment to a buffer variable with the statements of
        ``acquire_buffer_fragment``.

        Returns the list of statements of the substituted fragment, after
        analysing them against the current scope.
        """
        bufaux = node.lhs.entry.buffer_aux
        bufinfo_name = bufaux.buffer_info_var.name

        # Strides first, then shapes.
        auxass = self._fetch_aux_assignments(
            self.fetch_strides, bufaux.stridevars, bufinfo_name, node.pos)
        auxass.extend(self._fetch_aux_assignments(
            self.fetch_shape, bufaux.shapevars, bufinfo_name, node.pos))

        bufaux.buffer_info_var.used = True
        acq = self.acquire_buffer_fragment.substitute({
            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
            u"LHS" : node.lhs,
            u"RHS": node.rhs,
            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
            u"BUFINFO": NameNode(pos=node.pos, name=bufinfo_name)
        }, pos=node.pos)
        # Note: The below should probably be refactored into something
        # like fragment.substitute(..., context=self.context), with
        # TreeFragment getting context.pipeline_until_now() and
        # applying it on the fragment.
        acq.analyse_declarations(self.scope)
        acq.analyse_expressions(self.scope)
        return acq.stats

    def assign_into_buffer(self, node):
        """
        Replace ``arr[...] = rhs`` on a buffer variable with an assignment
        whose target is the expression built by buffer_index().
        """
        result = SingleAssignmentNode(node.pos,
                                      rhs=self.visit(node.rhs),
                                      lhs=self.buffer_index(node.lhs))
        result.analyse_expressions(self.scope)
        return result

    def buffer_index(self, node):
        """
        Build the element-access expression for an index node whose base
        is a buffer variable.

        The offset is the sum of ``index * stride`` over the index/stride
        pairs; it is substituted into ``buffer_access`` together with the
        buffer-info variable.  The returned expression is not analysed.
        """
        bufaux = node.base.entry.buffer_aux
        assert bufaux is not None
        # indices * strides...
        to_sum = [ IntBinopNode(node.pos, operator='*',
                                operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
                                operand2=NameNode(node.pos, name=stride.name))
            for index, stride in zip(node.indices, bufaux.stridevars)]

        # then sum them
        expr = to_sum[0]
        for term in to_sum[1:]:
            expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=term)

        tmp = self.buffer_access.substitute({
            'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
            'OFFSET': expr
            }, pos=node.pos)

        # The fragment holds a single expression statement; unwrap it.
        return tmp.stats[0].expr

    # Expression used to address one element of a buffer.
    # NOTE(review): the element pointer type is hard-coded to
    # unsigned char* here, independent of the buffer's declared type.
    buffer_access = TreeFragment(u"""
        (<unsigned char*>(BUF.buf + OFFSET))[0]
    """)

    def visit_IndexNode(self, node):
        """
        Replace a buffer access with the analysed expression from
        buffer_index(); return any other index node unchanged.
        """
        # Only occurs when the IndexNode is an rvalue
        if node.is_buffer_access:
            assert node.index is None
            assert node.indices is not None
            result = self.buffer_index(node)
            result.analyse_expressions(self.scope)
            return result
        else:
            return node

    def visit_CallNode(self, node):
        """Return call nodes unchanged."""
        # NOTE(review): children are not visited here, so buffer accesses
        # inside call arguments are left untransformed -- confirm intended.
        return node
    
#    def visit_FuncDefNode(self, node):
#        print node.dump()