Buffer.py 29 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13
from __future__ import absolute_import

from .Visitor import CythonTransform
from .ModuleNode import ModuleNode
from .Errors import CompileError
from .UtilityCode import CythonUtilityCode
from .Code import UtilityCode, TempitaUtilityCode

from . import Options
from . import Interpreter
from . import PyrexTypes
from . import Naming
from . import Symtab
14 15 16


def dedent(text, reindent=0):
    """Strip the common leading whitespace from *text*, optionally re-indenting.

    Parameters:
        text: the (typically triple-quoted) string to normalise.
        reindent: number of spaces to prepend to every line of the dedented
            text. Values <= 0 leave the dedented text as is.

    Returns the processed string.
    """
    # Imported under an alias so that the stdlib helper is not confused with
    # this module-level wrapper of the same name.
    from textwrap import dedent as _strip_margin
    text = _strip_margin(text)
    if reindent > 0:
        indent = " " * reindent
        # Every line receives the prefix, including empty ones.
        text = '\n'.join([indent + line for line in text.split('\n')])
    return text
23 24 25 26 27 28 29 30

class IntroduceBufferAuxiliaryVars(CythonTransform):
    """Transform that declares the auxiliary variables of buffer entries.

    For every buffer-typed entry found in a module or function scope, two
    anonymous auxiliary variables are declared in that scope and attached to
    the entry as ``entry.buffer_aux``.  When any buffer or memoryview slice
    entry is seen, the module scope is additionally asked to use the buffer
    utility code once the whole tree has been visited.
    """

    #
    # Entry point
    #

    # Set to True as soon as a scope with buffer or memoryview slice entries
    # has been seen.
    buffers_exists = False
    # Set to True when a scope holds a 'memoryview' entry whose utility code
    # definition is a CythonUtilityCode.
    using_memoryview = False

    def __call__(self, node):
        """Run the transform on the module `node` and return the result."""
        assert isinstance(node, ModuleNode)
        self.max_ndim = 0
        result = super(IntroduceBufferAuxiliaryVars, self).__call__(node)
        if self.buffers_exists:
            use_bufstruct_declare_code(node.scope)
            use_py2_buffer_functions(node.scope)
            node.scope.use_utility_code(empty_bufstruct_utility)

        return result

    #
    # Basic operations for transforms
    #
    def handle_scope(self, node, scope):
        """Declare auxiliary variables for all buffer entries of `scope`.

        Raises CompileError for buffer variables at module scope, for
        buffers with a pointer dtype and for buffers whose ndim exceeds
        Options.buffer_max_dims.  On success, the sorted list of buffer
        entries is stored as ``scope.buffer_entries``.
        """
        # For all buffers, insert extra variables in the scope.
        # The variables are also accessible from the buffer_info
        # on the buffer entry
        # Note: .items() (rather than .iteritems()) works on Python 2 and 3.
        scope_items = scope.entries.items()
        bufvars = [entry for name, entry in scope_items
                   if entry.type.is_buffer]
        if len(bufvars) > 0:
            bufvars.sort(key=lambda entry: entry.name)
            self.buffers_exists = True

        memviewslicevars = [entry for name, entry in scope_items
                            if entry.type.is_memoryviewslice]
        if len(memviewslicevars) > 0:
            self.buffers_exists = True

        for (name, entry) in scope_items:
            if name == 'memoryview' and isinstance(entry.utility_code_definition, CythonUtilityCode):
                self.using_memoryview = True
                break
        del scope_items

        if isinstance(node, ModuleNode) and len(bufvars) > 0:
            # for now...note that pos is wrong
            raise CompileError(node.pos, "Buffer vars not allowed in module scope")

        auxvars = ((PyrexTypes.c_pyx_buffer_nd_type, Naming.pybuffernd_prefix),
                   (PyrexTypes.c_pyx_buffer_type, Naming.pybufferstruct_prefix))
        for entry in bufvars:
            if entry.type.dtype.is_ptr:
                raise CompileError(node.pos, "Buffers with pointer types not yet supported.")

            buftype = entry.type
            if buftype.ndim > Options.buffer_max_dims:
                raise CompileError(node.pos,
                        "Buffer ndims exceeds Options.buffer_max_dims = %d" % Options.buffer_max_dims)
            if buftype.ndim > self.max_ndim:
                self.max_ndim = buftype.ndim

            # Declare auxiliary vars
            pybuffernd, rcbuffer = [
                self._declare_aux_var(node, scope, entry, var_type, prefix)
                for (var_type, prefix) in auxvars]

            entry.buffer_aux = Symtab.BufferAux(pybuffernd, rcbuffer)

        scope.buffer_entries = bufvars
        self.scope = scope

    def _declare_aux_var(self, node, scope, entry, var_type, prefix):
        # Declare one anonymous auxiliary variable for `entry`; its C name
        # is obtained by mangling `prefix` with the entry name.
        cname = scope.mangle(prefix, entry.name)
        aux_var = scope.declare_var(name=None, cname=cname,
                                    type=var_type, pos=node.pos)
        if entry.is_arg:
            aux_var.used = True # otherwise, NameNode will mark whether it is used
        return aux_var

    def visit_ModuleNode(self, node):
        self.handle_scope(node, node.scope)
        self.visitchildren(node)
        return node

    def visit_FuncDefNode(self, node):
        self.handle_scope(node, node.local_scope)
        self.visitchildren(node)
        return node

116 117 118
#
# Analysis
#
119 120
# Names of all supported buffer options.  Positional arguments are matched
# against this tuple in order, so the order must not be changed.
buffer_options = ("dtype", "ndim", "mode", "negative_indices", "cast") # ordered!
# Values used for options that were not supplied ("dtype" has no default).
buffer_defaults = {"ndim": 1, "mode": "full", "negative_indices": True, "cast": False}
buffer_positional_options_count = 1 # anything beyond this needs keyword argument

# Error message templates used by analyse_buffer_options().
ERR_BUF_OPTION_UNKNOWN = '"%s" is not a buffer option'
ERR_BUF_TOO_MANY = 'Too many buffer options'
ERR_BUF_DUP = '"%s" buffer option already supplied'
ERR_BUF_MISSING = '"%s" missing'
ERR_BUF_MODE = 'Only allowed buffer modes are: "c", "fortran", "full", "strided" (as a compile-time string)'
ERR_BUF_NDIM = 'ndim must be a non-negative integer'
ERR_BUF_DTYPE = 'dtype must be "object", numeric type or a struct'
ERR_BUF_BOOL = '"%s" must be a boolean'
131 132 133 134 135 136 137 138 139 140 141 142 143 144

def analyse_buffer_options(globalpos, env, posargs, dictargs, defaults=None, need_complete=True):
    """
    Must be called during type analysis, as analyse is called
    on the dtype argument.

    posargs and dictargs should consist of a list and a dict
    of tuples (value, pos). Defaults should be a dict of values.

    If need_complete is true, an option that is neither supplied nor
    present in the defaults is reported as an error.

    Returns a dict containing all the options a buffer can have and
    its value (with the positions stripped).

    Raises CompileError for unknown, duplicate, surplus, missing or
    ill-typed options.
    """
    if defaults is None:
        defaults = buffer_defaults

    posargs, dictargs = Interpreter.interpret_compiletime_options(
        posargs, dictargs, type_env=env, type_args=(0, 'dtype'))

    if len(posargs) > buffer_positional_options_count:
        raise CompileError(posargs[-1][1], ERR_BUF_TOO_MANY)

    options = {}
    # .items() (rather than .iteritems()) works on both Python 2 and 3.
    for name, (value, pos) in dictargs.items():
        if name not in buffer_options:
            raise CompileError(pos, ERR_BUF_OPTION_UNKNOWN % name)
        options[name] = value

    # Positional arguments map onto the option names in declaration order,
    # so the names are always valid here; only duplicates need checking.
    for name, (value, pos) in zip(buffer_options, posargs):
        if name in options:
            raise CompileError(pos, ERR_BUF_DUP % name)
        options[name] = value

    # Check that they are all there and copy defaults
    for name in buffer_options:
        if name not in options:
            try:
                options[name] = defaults[name]
            except KeyError:
                if need_complete:
                    raise CompileError(globalpos, ERR_BUF_MISSING % name)

    dtype = options.get("dtype")
    if dtype and dtype.is_extension_type:
        raise CompileError(globalpos, ERR_BUF_DTYPE)

    ndim = options.get("ndim")
    if ndim and (not isinstance(ndim, int) or ndim < 0):
        raise CompileError(globalpos, ERR_BUF_NDIM)

    mode = options.get("mode")
    if mode and mode not in ('full', 'strided', 'c', 'fortran'):
        raise CompileError(globalpos, ERR_BUF_MODE)

    def assert_bool(name):
        x = options.get(name)
        if not isinstance(x, bool):
            raise CompileError(globalpos, ERR_BUF_BOOL % name)

    assert_bool('negative_indices')
    assert_bool('cast')

    return options
194

195 196 197 198

#
# Code generation
#
199

200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219
class BufferEntry(object):
    """Code generation helper wrapping a buffer-typed symbol table entry.

    Attributes set by the constructor:
        entry: the wrapped entry
        type: the entry's type
        cname: C name of the entry's ``buflocal_nd_var`` auxiliary variable
        buf_ptr: C expression for the data pointer of the acquired buffer
        buf_ptr_type: ``buffer_ptr_type`` of the entry's type
    """

    def __init__(self, entry):
        self.entry = entry
        self.type = entry.type
        self.cname = entry.buffer_aux.buflocal_nd_var.cname
        self.buf_ptr = "%s.rcbuffer->pybuffer.buf" % self.cname
        self.buf_ptr_type = self.entry.type.buffer_ptr_type

    def get_buf_suboffsetvars(self):
        return self._for_all_ndim("%s.diminfo[%d].suboffsets")

    def get_buf_stridevars(self):
        return self._for_all_ndim("%s.diminfo[%d].strides")

    def get_buf_shapevars(self):
        return self._for_all_ndim("%s.diminfo[%d].shape")

    def _for_all_ndim(self, s):
        # Instantiate the template `s` (cname, dimension) once per dimension.
        return [s % (self.cname, i) for i in range(self.type.ndim)]

    def generate_buffer_lookup_code(self, code, index_cnames):
        """Return a C expression for a pointer to the indexed buffer item.

        The lookup is done via utility macros/inline functions, which vary
        according to the access mode used.  The utility code for the chosen
        mode and number of dimensions is generated on first use.
        """
        params = []
        nd = self.type.ndim
        mode = self.type.mode
        if mode == 'full':
            for i, s, o in zip(index_cnames,
                               self.get_buf_stridevars(),
                               self.get_buf_suboffsetvars()):
                params.append(i)
                params.append(s)
                params.append(o)
            funcname = "__Pyx_BufPtrFull%dd" % nd
            funcgen = buf_lookup_full_code
        else:
            if mode == 'strided':
                funcname = "__Pyx_BufPtrStrided%dd" % nd
                funcgen = buf_lookup_strided_code
            elif mode == 'c':
                funcname = "__Pyx_BufPtrCContig%dd" % nd
                funcgen = buf_lookup_c_code
            elif mode == 'fortran':
                funcname = "__Pyx_BufPtrFortranContig%dd" % nd
                funcgen = buf_lookup_fortran_code
            else:
                # Raised explicitly so that running with -O cannot turn an
                # unknown mode into a NameError further down.
                raise AssertionError("unknown buffer mode: %r" % (mode,))
            for i, s in zip(index_cnames, self.get_buf_stridevars()):
                params.append(i)
                params.append(s)

        # Make sure the utility code is available
        if funcname not in code.globalstate.utility_codes:
            code.globalstate.utility_codes.add(funcname)
            protocode = code.globalstate['utility_code_proto']
            defcode = code.globalstate['utility_code_def']
            funcgen(protocode, defcode, name=funcname, nd=nd)

        buf_ptr_type_code = self.buf_ptr_type.empty_declaration_code()
        ptrcode = "%s(%s, %s, %s)" % (funcname, buf_ptr_type_code, self.buf_ptr,
                                      ", ".join(params))
        return ptrcode

264

265
def get_flags(buffer_aux, buffer_type):
    """Return the C expression of PyBUF_* flags used to request a buffer.

    The flags depend on ``buffer_type.mode`` ('full', 'strided', 'c' or
    'fortran'); PyBUF_WRITABLE is added when ``buffer_aux.writable_needed``
    is true.  An unknown mode raises AssertionError.
    """
    mode_flags = {
        'full': '| PyBUF_INDIRECT',
        'strided': '| PyBUF_STRIDES',
        'c': '| PyBUF_C_CONTIGUOUS',
        'fortran': '| PyBUF_F_CONTIGUOUS',
    }
    mode = buffer_type.mode
    if mode not in mode_flags:
        raise AssertionError("unknown buffer mode: %r" % (mode,))
    flags = 'PyBUF_FORMAT' + mode_flags[mode]
    if buffer_aux.writable_needed:
        flags += "| PyBUF_WRITABLE"
    return flags
280

281 282
def used_buffer_aux_vars(entry):
    """Mark both auxiliary variables of the buffer `entry` as used."""
    buffer_aux = entry.buffer_aux
    buffer_aux.buflocal_nd_var.used = True
    buffer_aux.rcbuf_var.used = True
285

286
def put_unpack_buffer_aux_into_scope(buf_entry, code):
    """Emit C code copying the per-dimension buffer info into ``diminfo``.

    For each dimension of the buffer, strides and shape (plus suboffsets in
    'full' mode) are copied from the acquired Py_buffer into the local
    struct.  All assignments are written as a single output line.
    """
    # Generate code to copy the needed struct info into local
    # variables.
    buffer_aux, mode = buf_entry.buffer_aux, buf_entry.type.mode
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    fldnames = ['strides', 'shape']
    if mode == 'full':
        fldnames.append('suboffsets')

    assignments = []
    for i in range(buf_entry.type.ndim):
        for fldname in fldnames:
            assignments.append(
                "%s.diminfo[%d].%s = %s.rcbuffer->pybuffer.%s[%d];" % (
                    pybuffernd_struct, i, fldname,
                    pybuffernd_struct, fldname, i))
    code.putln(' '.join(assignments))
303

304 305 306 307 308 309 310 311 312 313 314 315
def put_init_vars(entry, code):
    """Emit C code that resets the auxiliary buffer structs of `entry`.

    The rcbuffer struct gets a NULL data pointer and a zero refcount; the
    nd-buffer struct gets a NULL data pointer and is pointed at the
    rcbuffer struct.
    """
    aux = entry.buffer_aux
    nd_cname = aux.buflocal_nd_var.cname
    rc_cname = aux.rcbuf_var.cname
    # The buffer object itself is not initialised here
    # (code.put_init_var_to_py_none(entry) is intentionally not called).
    statements = (
        # init pybuffer_struct
        "%s.pybuffer.buf = NULL;" % rc_cname,
        "%s.refcount = 0;" % rc_cname,
        # init the pybuffernd_struct
        "%s.data = NULL;" % nd_cname,
        "%s.rcbuffer = &%s;" % (nd_cname, rc_cname),
    )
    for statement in statements:
        code.putln(statement)
316 317

def put_acquire_arg_buffer(entry, code, pos):
    """Emit C code acquiring the buffer of the argument `entry`.

    The generated code jumps to the error label for `pos` when the
    acquisition call evaluates to -1; afterwards the per-dimension info
    is unpacked into the local struct.
    """
    code.globalstate.use_utility_code(acquire_utility_code)
    buffer_aux = entry.buffer_aux
    getbuffer = get_getbuffer_call(code, entry.cname, buffer_aux, entry.type)

    # Acquire any new buffer
    code.putln("{")
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % entry.type.dtype.struct_nesting_depth())
    code.putln(code.error_goto_if("%s == -1" % getbuffer, pos))
    code.putln("}")
    # An exception raised in arg parsing cannot be caught, so no
    # need to care about the buffer then.
    put_unpack_buffer_aux_into_scope(entry, code)
330

331 332
def put_release_buffer_code(code, entry):
    """Emit a __Pyx_SafeReleaseBuffer() call for the buffer held by `entry`."""
    code.globalstate.use_utility_code(acquire_utility_code)
    code.putln("__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);" % entry.buffer_aux.buflocal_nd_var.cname)
334

335 336 337 338
def get_getbuffer_call(code, obj_cname, buffer_aux, buffer_type):
    """Return the C call expression that acquires and validates a buffer.

    `obj_cname` is inserted verbatim as the object argument, so callers may
    pass a "%s" placeholder and fill in the object later.  The generated
    call refers to a ``__pyx_stack`` variable, which the caller must have
    declared in the enclosing C block.
    """
    dtype_typeinfo = get_type_information_cname(code, buffer_type.dtype)

    # An explicit mapping (instead of locals()) keeps the template's inputs
    # visible and independent of local variable names.
    return ("__Pyx_GetBufferAndValidate(&%(pybuffernd_struct)s.rcbuffer->pybuffer, "
            "(PyObject*)%(obj_cname)s, &%(dtype_typeinfo)s, %(flags)s, %(ndim)d, "
            "%(cast)d, __pyx_stack)" % {
                'pybuffernd_struct': buffer_aux.buflocal_nd_var.cname,
                'obj_cname': obj_cname,
                'dtype_typeinfo': dtype_typeinfo,
                'flags': get_flags(buffer_aux, buffer_type),
                'ndim': buffer_type.ndim,
                'cast': int(buffer_type.cast),
            })
346

347
def put_assign_to_buffer(lhs_cname, rhs_cname, buf_entry,
                         is_initialized, pos, code):
    """
    Generate code for reassigning a buffer variable. This only deals with getting
    the buffer auxiliary structure and variables set up correctly, the assignment
    itself and refcounting is the responsibility of the caller.

    However, the assignment operation may throw an exception so that the reassignment
    never happens.

    Depending on the circumstances there are two possible outcomes:
    - Old buffer released, new acquired, rhs assigned to lhs
    - Old buffer released, new acquired which fails, reacquire old lhs buffer
      (which may or may not succeed).
    """

    buffer_aux, buffer_type = buf_entry.buffer_aux, buf_entry.type
    code.globalstate.use_utility_code(acquire_utility_code)
    pybuffernd_struct = buffer_aux.buflocal_nd_var.cname

    code.putln("{")  # Set up necessary stack for getbuffer
    code.putln("__Pyx_BufFmt_StackElem __pyx_stack[%d];" % buffer_type.dtype.struct_nesting_depth())

    getbuffer = get_getbuffer_call(code, "%s", buffer_aux, buffer_type) # fill in object below

    if is_initialized:
        # Release any existing buffer
        code.putln('__Pyx_SafeReleaseBuffer(&%s.rcbuffer->pybuffer);' % pybuffernd_struct)
        # Acquire
        retcode_cname = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = %s;" % (retcode_cname, getbuffer % rhs_cname))
        code.putln('if (%s) {' % (code.unlikely("%s < 0" % retcode_cname)))
        # If acquisition failed, attempt to reacquire the old buffer
        # before raising the exception. A failure of reacquisition
        # will cause the reacquisition exception to be reported, one
        # can consider working around this later.
        exc_temps = [code.funcstate.allocate_temp(PyrexTypes.py_object_type, manage_ref=False)
                     for _ in range(3)]
        exc_type, exc_value, exc_tb = exc_temps
        code.putln('PyErr_Fetch(&%s, &%s, &%s);' % (exc_type, exc_value, exc_tb))
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % lhs_cname)))
        code.putln('Py_XDECREF(%s); Py_XDECREF(%s); Py_XDECREF(%s);' % (
            exc_type, exc_value, exc_tb)) # Do not refnanny these!
        code.globalstate.use_utility_code(raise_buffer_fallback_code)
        code.putln('__Pyx_RaiseBufferFallbackError();')
        code.putln('} else {')
        code.putln('PyErr_Restore(%s, %s, %s);' % (exc_type, exc_value, exc_tb))
        for temp in exc_temps:
            code.funcstate.release_temp(temp)
        code.putln('}')
        code.putln('}')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln(code.error_goto_if_neg(retcode_cname, pos))
        code.funcstate.release_temp(retcode_cname)
    else:
        # Our entry had no previous value, so set to None when acquisition fails.
        # In this case, auxiliary vars should be set up right in initialization to a zero-buffer,
        # so it suffices to set the buf field to NULL.
        code.putln('if (%s) {' % code.unlikely("%s == -1" % (getbuffer % rhs_cname)))
        code.putln('%s = %s; __Pyx_INCREF(Py_None); %s.rcbuffer->pybuffer.buf = NULL;' %
                   (lhs_cname,
                    PyrexTypes.typecast(buffer_type, PyrexTypes.py_object_type, "Py_None"),
                    pybuffernd_struct))
        code.putln(code.error_goto(pos))
        code.put('} else {')
        # Unpack indices
        put_unpack_buffer_aux_into_scope(buf_entry, code)
        code.putln('}')

    code.putln("}") # Release stack
417

418

419
def put_buffer_lookup_code(entry, index_signeds, index_cnames, directives,
                           pos, code, negative_indices, in_nogil_context):
    """
    Generates code to process indices and calculate an offset into
    a buffer. Returns a C string which gives a pointer which can be
    read from or written to at will (it is an expression so caller should
    store it in a temporary if it is used more than once).

    As the bounds checking can have any number of combinations of unsigned
    arguments, smart optimizations etc. we insert it directly in the function
    body. The lookup however is delegated to a inline function that is instantiated
    once per ndim (lookup with suboffsets tend to get quite complicated).

    entry is a BufferEntry
    """
    # Negative indices are only wrapped when both the 'wraparound' directive
    # and the caller's negative_indices flag allow it.
    negative_indices = directives['wraparound'] and negative_indices

    if directives['boundscheck']:
        # Check bounds and fix negative indices.
        # We allocate a temporary which is initialized to -1, meaning OK (!).
        # If an error occurs, the temp is set to the index dimension the
        # error is occurring at.
        failed_dim_temp = code.funcstate.allocate_temp(PyrexTypes.c_int_type, manage_ref=False)
        code.putln("%s = -1;" % failed_dim_temp)
        for dim, (signed, cname, shape) in enumerate(zip(index_signeds, index_cnames, entry.get_buf_shapevars())):
            if signed != 0:
                # not unsigned, deal with negative index
                code.putln("if (%s < 0) {" % cname)
                if negative_indices:
                    code.putln("%s += %s;" % (cname, shape))
                    code.putln("if (%s) %s = %d;" % (
                        code.unlikely("%s < 0" % cname),
                        failed_dim_temp, dim))
                else:
                    code.putln("%s = %d;" % (failed_dim_temp, dim))
                code.put("} else ")
                cast = ""
            else:
                # Unsigned index: compare against the shape as size_t.
                cast = "(size_t)"
            # check bounds in positive direction
            code.putln("if (%s) %s = %d;" % (
                code.unlikely("%s >= %s%s" % (cname, cast, shape)),
                failed_dim_temp, dim))

        if in_nogil_context:
            code.globalstate.use_utility_code(raise_indexerror_nogil)
            func = '__Pyx_RaiseBufferIndexErrorNogil'
        else:
            code.globalstate.use_utility_code(raise_indexerror_code)
            func = '__Pyx_RaiseBufferIndexError'

        code.putln("if (%s) {" % code.unlikely("%s != -1" % failed_dim_temp))
        code.putln('%s(%s);' % (func, failed_dim_temp))
        code.putln(code.error_goto(pos))
        code.putln('}')
        code.funcstate.release_temp(failed_dim_temp)
    elif negative_indices:
        # Only fix negative indices.
        for signed, cname, shape in zip(index_signeds, index_cnames, entry.get_buf_shapevars()):
            if signed != 0:
                code.putln("if (%s < 0) %s += %s;" % (cname, cname, shape))

    return entry.generate_buffer_lookup_code(code, index_cnames)
483

484

485 486 487
def use_bufstruct_declare_code(env):
    """Register the buffer struct declaration utility code with `env`."""
    utility = buffer_struct_declare_code
    env.use_utility_code(utility)

488 489

def get_empty_bufstruct_code(max_ndim):
    """Return utility code declaring the __Pyx_zeros and __Pyx_minusones arrays.

    Each static Py_ssize_t array is initialised with `max_ndim` entries
    (0 and -1 respectively).
    """
    code = dedent("""
        static Py_ssize_t __Pyx_zeros[] = {%s};
        static Py_ssize_t __Pyx_minusones[] = {%s};
    """) % (", ".join(["0"] * max_ndim), ", ".join(["-1"] * max_ndim))
    return UtilityCode(proto=code)
495

496
# Utility code with the zero / minus-one arrays, sized for the maximum
# number of buffer dimensions configured in Options.buffer_max_dims.
empty_bufstruct_utility = get_empty_bufstruct_code(Options.buffer_max_dims)
497

498
def buf_lookup_full_code(proto, defin, name, nd):
    """
    Generates a buffer lookup function for the right number
    of dimensions. The function gives back a void* at the right location.

    A casting macro called `name` and the prototype of its `name`_imp
    helper are written to `proto`; the helper's definition goes to `defin`.
    """
    # _i_ndex, _s_tride, sub_o_ffset
    macroargs = ", ".join(["i%d, s%d, o%d" % (i, i, i) for i in range(nd)])
    proto.putln("#define %s(type, buf, %s) (type)(%s_imp(buf, %s))" % (name, macroargs, name, macroargs))

    funcargs = ", ".join(["Py_ssize_t i%d, Py_ssize_t s%d, Py_ssize_t o%d" % (i, i, i) for i in range(nd)])
    proto.putln("static CYTHON_INLINE void* %s_imp(void* buf, %s);" % (name, funcargs))

    header = dedent("""
        static CYTHON_INLINE void* %s_imp(void* buf, %s) {
          char* ptr = (char*)buf;
        """) % (name, funcargs)
    # One step per dimension: advance by the stride, then follow the
    # indirection when the suboffset is non-negative.
    steps = "".join([dedent("""\
          ptr += s%d * i%d;
          if (o%d >= 0) ptr = *((char**)ptr) + o%d;
        """) % (i, i, i, i) for i in range(nd)])
    defin.putln(header + steps + "\nreturn ptr;\n}")
517

518 519 520 521 522 523 524 525 526 527 528 529 530 531 532 533 534 535 536 537 538 539 540 541 542 543 544 545 546 547 548 549 550
def buf_lookup_strided_code(proto, defin, name, nd):
    """
    Emit the lookup macro for strided buffers with `nd` dimensions.

    The macro is written to `proto` only (`defin` is not used); it adds
    index * stride for every dimension to the buffer pointer and casts
    the result to the requested type.
    """
    # _i_ndex, _s_tride
    macro_params = []
    offset_terms = []
    for dim in range(nd):
        macro_params.append("i%d, s%d" % (dim, dim))
        offset_terms.append("i%d * s%d" % (dim, dim))
    proto.putln("#define %s(type, buf, %s) (type)((char*)buf + %s)" % (
        name, ", ".join(macro_params), " + ".join(offset_terms)))

def buf_lookup_c_code(proto, defin, name, nd):
    """
    Emit the lookup macro for C-contiguous buffers with `nd` dimensions.

    Like the strided lookup, except that the last index is applied as plain
    pointer arithmetic on the typed pointer instead of being multiplied by
    its stride.  The macro keeps the strided macro's signature.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    last = nd - 1
    macro_params = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
    # All dimensions but the last contribute index * stride in bytes.
    byte_offset = " + ".join("i%d * s%d" % (dim, dim) for dim in range(last))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
        name, macro_params, byte_offset, last))

def buf_lookup_fortran_code(proto, defin, name, nd):
    """
    Emit the lookup macro for Fortran-contiguous buffers with `nd` dimensions.

    Mirrors the C-contiguous lookup, but here it is the first index that is
    applied as plain pointer arithmetic on the typed pointer.
    """
    if nd == 1:
        proto.putln("#define %s(type, buf, i0, s0) ((type)buf + i0)" % name)
        return
    macro_params = ", ".join("i%d, s%d" % (dim, dim) for dim in range(nd))
    # All dimensions but the first contribute index * stride in bytes.
    byte_offset = " + ".join("i%d * s%d" % (dim, dim) for dim in range(1, nd))
    proto.putln("#define %s(type, buf, %s) ((type)((char*)buf + %s) + i%d)" % (
        name, macro_params, byte_offset, 0))
551

552 553

def use_py2_buffer_functions(env):
    """Register the get/release buffer emulation utility code with `env`."""
    env.use_utility_code(GetAndReleaseBufferUtilityCode())

class GetAndReleaseBufferUtilityCode(object):
    # Emulation of PyObject_GetBuffer and PyBuffer_Release for Python 2.
    # For >= 2.6 we do double mode -- use the new buffer interface on objects
    # which has the right tp_flags set, but emulation otherwise.
    #
    # All instances compare equal and share one hash value, so they are
    # interchangeable when used as set members or dict keys.

    requires = None
    is_cython_utility = False

    def __init__(self):
        pass

    def __eq__(self, other):
        return isinstance(other, GetAndReleaseBufferUtilityCode)

    def __hash__(self):
        return 24342342

    def get_tree(self): pass

    def put_code(self, output):
        """Write the "GetAndReleaseBuffer" utility code into `output`.

        The module scope and all scopes reachable through cimported modules
        are searched for extension types defining __getbuffer__; the
        collected (typeptr cname, getbuffer cname, releasebuffer cname)
        tuples are passed to the Tempita template as `types`.
        """
        code = output['utility_code_def']
        proto_code = output['utility_code_proto']
        env = output.module_node.scope
        cython_scope = env.context.cython_scope

        # Search all types for __getbuffer__ overloads
        types = []
        visited_scopes = set()
        def find_buffer_types(scope):
            if scope in visited_scopes:
                return
            visited_scopes.add(scope)
            for m in scope.cimported_modules:
                find_buffer_types(m)
            for e in scope.type_entries:
                if isinstance(e.utility_code_definition, CythonUtilityCode):
                    continue
                t = e.type
                if t.is_extension_type:
                    # Unused types from the cython scope are skipped.
                    if scope is cython_scope and not e.used:
                        continue
                    release = get = None
                    for x in t.scope.pyfunc_entries:
                        if x.name == u"__getbuffer__": get = x.func_cname
                        elif x.name == u"__releasebuffer__": release = x.func_cname
                    if get:
                        types.append((t.typeptr_cname, get, release))

        find_buffer_types(env)

        util_code = TempitaUtilityCode.load(
            "GetAndReleaseBuffer", from_file="Buffer.c",
            context=dict(types=types))

        proto = util_code.format_code(util_code.proto)
        impl = util_code.format_code(
            util_code.inject_string_constants(util_code.impl, output)[1])

        proto_code.putln(proto)
        code.putln(impl)
616

617

618 619 620 621 622 623 624 625 626 627 628 629
def mangle_dtype_name(dtype):
    """Return an identifier-friendly name for `dtype`.

    Python object types map to "object" and pointer types to "ptr".  For
    all other types the empty declaration code is used, with spaces and
    square brackets replaced by underscores.
    """
    # Use prefixes to separate user defined types from builtins
    # (consider "typedef float unsigned_int")
    if dtype.is_pyobject:
        return "object"
    if dtype.is_ptr:
        return "ptr"
    if dtype.is_typedef or dtype.is_struct_or_union:
        prefix = "nn_"
    else:
        prefix = ""
    type_decl = dtype.empty_declaration_code()
    for char in (" ", "[", "]"):
        type_decl = type_decl.replace(char, "_")
    return prefix + type_decl
633 634

def get_type_information_cname(code, dtype, maxdepth=None):
    """
    Output the run-time type information (__Pyx_TypeInfo) for given dtype,
    and return the name of the type info struct.

    The struct is only written the first time a given name is requested.
    For error types, "<error>" is returned and nothing is written.
    `maxdepth` bounds the recursion into struct fields and defaults to
    dtype.struct_nesting_depth().

    Structs with two floats of the same size are encoded as complex numbers.
    One can separate between complex numbers declared as struct or with native
    encoding by inspecting to see if the fields field of the type is
    filled in.
    """
    # Nothing needs to be named or generated for error types.
    if dtype.is_error: return "<error>"

    namesuffix = mangle_dtype_name(dtype)
    name = "__Pyx_TypeInfo_%s" % namesuffix
    structinfo_name = "__Pyx_StructFields_%s" % namesuffix

    # It's critical that walking the type info doesn't use more stack
    # depth than dtype.struct_nesting_depth() returns, so use an assertion for this
    if maxdepth is None: maxdepth = dtype.struct_nesting_depth()
    assert maxdepth > 0

    if name not in code.globalstate.utility_codes:
        code.globalstate.utility_codes.add(name)
        typecode = code.globalstate['typeinfo']

        # Strip array dimensions; the remaining info describes the base type.
        arraysizes = []
        while dtype.is_array:
            arraysizes.append(dtype.size)
            dtype = dtype.base_type

        complex_possible = dtype.is_struct_or_union and dtype.can_be_complex()

        declcode = dtype.empty_declaration_code()
        if dtype.is_simple_buffer_dtype():
            structinfo_name = "NULL"
        elif dtype.is_struct:
            fields = dtype.scope.var_entries
            # Must pre-call all used types in order not to recurse utility code
            # writing.
            assert len(fields) > 0
            types = [get_type_information_cname(code, f.type, maxdepth - 1)
                     for f in fields]
            typecode.putln("static __Pyx_StructField %s[] = {" % structinfo_name, safe=True)
            for f, typeinfo in zip(fields, types):
                typecode.putln('  {&%s, "%s", offsetof(%s, %s)},' %
                           (typeinfo, f.name, dtype.empty_declaration_code(), f.cname), safe=True)
            typecode.putln('  {NULL, NULL, 0}', safe=True)
            typecode.putln("};", safe=True)
        else:
            assert False

        rep = str(dtype)

        flags = "0"
        is_unsigned = "0"
        if dtype is PyrexTypes.c_char_type:
            is_unsigned = "IS_UNSIGNED(%s)" % declcode
            typegroup = "'H'"
        elif dtype.is_int:
            is_unsigned = "IS_UNSIGNED(%s)" % declcode
            typegroup = "%s ? 'U' : 'I'" % is_unsigned
        elif complex_possible or dtype.is_complex:
            typegroup = "'C'"
        elif dtype.is_float:
            typegroup = "'R'"
        elif dtype.is_struct:
            typegroup = "'S'"
            if dtype.packed:
                flags = "__PYX_BUF_FLAGS_PACKED_STRUCT"
        elif dtype.is_pyobject:
            typegroup = "'O'"
        else:
            assert False, dtype

        typeinfo = ('static __Pyx_TypeInfo %s = '
                        '{ "%s", %s, sizeof(%s), { %s }, %s, %s, %s, %s };')
        tup = (name, rep, structinfo_name, declcode,
               ', '.join([str(x) for x in arraysizes]) or '0', len(arraysizes),
               typegroup, is_unsigned, flags)
        typecode.putln(typeinfo % tup, safe=True)

    return name
718

719 720 721 722 723
def load_buffer_utility(util_code_name, context=None, **kwargs):
    """Load the utility code `util_code_name` from "Buffer.c".

    When a `context` is given, it is loaded as TempitaUtilityCode with that
    context; otherwise as plain UtilityCode.  Extra keyword arguments are
    passed on to the loader.
    """
    if context is not None:
        return TempitaUtilityCode.load(util_code_name, "Buffer.c", context=context, **kwargs)
    return UtilityCode.load(util_code_name, "Buffer.c", **kwargs)
724

Mark Florisson's avatar
Mark Florisson committed
725
# Template context shared by the context-dependent buffer utility code below.
context = dict(max_dims=str(Options.buffer_max_dims))
buffer_struct_declare_code = load_buffer_utility("BufferStructDeclare",
                                                 context=context)


# Utility function to set the right exception
# The caller should immediately goto_error
raise_indexerror_code = load_buffer_utility("BufferIndexError")
raise_indexerror_nogil = load_buffer_utility("BufferIndexErrorNogil")

raise_buffer_fallback_code = load_buffer_utility("BufferFallbackError")
buffer_structs_code = load_buffer_utility(
        "BufferFormatStructs", proto_block='utility_code_proto_before_types')
acquire_utility_code = load_buffer_utility("BufferFormatCheck",
                                           context=context,
                                           requires=[buffer_structs_code])

# See utility code BufferFormatFromTypeInfo
_typeinfo_to_format_code = load_buffer_utility("TypeInfoToFormat", context={},
                                               requires=[buffer_structs_code])
typeinfo_compare_code = load_buffer_utility("TypeInfoCompare", context={},
                                            requires=[buffer_structs_code])