Commit eaae9bab authored by Vitja Makarov's avatar Vitja Makarov

Merge pull request #88 from vitek/_cyfunction_defaults

 cyfunction defaults
parents a102bc69 19bb02e7
...@@ -6087,7 +6087,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6087,7 +6087,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
# module_name EncodedString Name of defining module # module_name EncodedString Name of defining module
# code_object CodeObjectNode the PyCodeObject creator node # code_object CodeObjectNode the PyCodeObject creator node
subexprs = ['code_object'] subexprs = ['code_object', 'defaults_tuple']
self_object = None self_object = None
code_object = None code_object = None
...@@ -6096,6 +6096,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6096,6 +6096,7 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
defaults = None defaults = None
defaults_struct = None defaults_struct = None
defaults_pyobjects = 0 defaults_pyobjects = 0
defaults_tuple = None
type = py_object_type type = py_object_type
is_temp = 1 is_temp = 1
...@@ -6122,13 +6123,18 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6122,13 +6123,18 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
""" """
nonliteral_objects = [] nonliteral_objects = []
nonliteral_other = [] nonliteral_other = []
default_args = []
for arg in self.def_node.args: for arg in self.def_node.args:
if arg.default and not arg.default.is_literal: if arg.default:
if not arg.default.is_literal:
arg.is_dynamic = True arg.is_dynamic = True
if arg.type.is_pyobject: if arg.type.is_pyobject:
nonliteral_objects.append(arg) nonliteral_objects.append(arg)
else: else:
nonliteral_other.append(arg) nonliteral_other.append(arg)
else:
arg.default = DefaultLiteralArgNode(arg.pos, arg.default)
default_args.append(arg)
if nonliteral_objects or nonliteral_objects: if nonliteral_objects or nonliteral_objects:
module_scope = env.global_scope() module_scope = env.global_scope()
cname = module_scope.next_id(Naming.defaults_struct_prefix) cname = module_scope.next_id(Naming.defaults_struct_prefix)
...@@ -6153,6 +6159,28 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6153,6 +6159,28 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
Naming.dynamic_args_cname, entry.cname) Naming.dynamic_args_cname, entry.cname)
self.def_node.defaults_struct = self.defaults_struct.name self.def_node.defaults_struct = self.defaults_struct.name
if default_args:
if self.defaults_struct is None:
self.defaults_tuple = TupleNode(self.pos, args=[
arg.default for arg in default_args])
self.defaults_tuple.analyse_types(env)
else:
defaults_getter = Nodes.DefNode(
self.pos, args=[], star_arg=None, starstar_arg=None,
body=Nodes.ReturnStatNode(
self.pos, return_type=py_object_type,
value=DefaultsTupleNode(
self.pos, default_args,
self.defaults_struct)),
decorators=None, name="__defaults__")
defaults_getter.analyse_declarations(env)
defaults_getter.analyse_expressions(env)
defaults_getter.body.analyse_expressions(
defaults_getter.local_scope)
defaults_getter.py_wrapper_required = False
defaults_getter.pymethdef_required = False
self.def_node.defaults_getter = defaults_getter
def may_be_none(self): def may_be_none(self):
return False return False
...@@ -6245,6 +6273,13 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin): ...@@ -6245,6 +6273,13 @@ class PyCFunctionNode(ExprNode, ModuleNameMixin):
arg.generate_assignment_code(code, target='%s->%s' % ( arg.generate_assignment_code(code, target='%s->%s' % (
defaults, entry.cname)) defaults, entry.cname))
if self.defaults_tuple:
code.putln('__Pyx_CyFunction_SetDefaultsTuple(%s, %s);' % (
self.result(), self.defaults_tuple.py_result()))
if def_node.defaults_getter:
code.putln('__Pyx_CyFunction_SetDefaultsGetter(%s, %s);' % (
self.result(), def_node.defaults_getter.entry.pyfunc_cname))
if self.specialized_cpdefs: if self.specialized_cpdefs:
self.generate_fused_cpdef(code, code_object_result, flags) self.generate_fused_cpdef(code, code_object_result, flags)
...@@ -6375,6 +6410,73 @@ class CodeObjectNode(ExprNode): ...@@ -6375,6 +6410,73 @@ class CodeObjectNode(ExprNode):
)) ))
class DefaultLiteralArgNode(ExprNode):
# CyFunction's literal argument default value
#
# Evaluate literal only once.
subexprs = []
is_literal = True
is_temp = False
def __init__(self, pos, arg):
super(DefaultLiteralArgNode, self).__init__(pos)
self.arg = arg
self.type = self.arg.type
self.evaluated = False
def analyse_types(self, env):
pass
def generate_result_code(self, code):
pass
def generate_evaluation_code(self, code):
if not self.evaluated:
self.arg.generate_evaluation_code(code)
self.evaluated = True
def result(self):
return self.type.cast_code(self.arg.result())
class DefaultNonLiteralArgNode(ExprNode):
# CyFunction's non-literal argument default value
subexprs = []
def __init__(self, pos, arg, defaults_struct):
super(DefaultNonLiteralArgNode, self).__init__(pos)
self.arg = arg
self.defaults_struct = defaults_struct
def analyse_types(self, env):
self.type = self.arg.type
self.is_temp = False
def generate_result_code(self, code):
pass
def result(self):
return '__Pyx_CyFunction_Defaults(%s, %s)->%s' % (
self.defaults_struct.name, Naming.self_cname,
self.defaults_struct.lookup(self.arg.name).cname)
class DefaultsTupleNode(TupleNode):
# CyFunction's __defaults__ tuple
def __init__(self, pos, defaults, defaults_struct):
args = []
for arg in defaults:
if not arg.default.is_literal:
arg = DefaultNonLiteralArgNode(pos, arg, defaults_struct)
else:
arg = arg.default
args.append(arg)
super(DefaultsTupleNode, self).__init__(pos, args=args)
class LambdaNode(InnerFunctionNode): class LambdaNode(InnerFunctionNode):
# Lambda expression node (only used as a function reference) # Lambda expression node (only used as a function reference)
# #
......
...@@ -2704,6 +2704,8 @@ class DefNode(FuncDefNode): ...@@ -2704,6 +2704,8 @@ class DefNode(FuncDefNode):
py_wrapper_required = True py_wrapper_required = True
func_cname = None func_cname = None
defaults_getter = None
def __init__(self, pos, **kwds): def __init__(self, pos, **kwds):
FuncDefNode.__init__(self, pos, **kwds) FuncDefNode.__init__(self, pos, **kwds)
k = rk = r = 0 k = rk = r = 0
...@@ -3068,6 +3070,9 @@ class DefNode(FuncDefNode): ...@@ -3068,6 +3070,9 @@ class DefNode(FuncDefNode):
return 1 return 1
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
if self.defaults_getter:
self.defaults_getter.generate_function_definitions(env, code)
# Before closure cnames are mangled # Before closure cnames are mangled
if self.py_wrapper_required: if self.py_wrapper_required:
# func_cname might be modified by @cname # func_cname might be modified by @cname
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#define __Pyx_CyFunction_Defaults(type, f) \ #define __Pyx_CyFunction_Defaults(type, f) \
((type *)(((__pyx_CyFunctionObject *) (f))->defaults)) ((type *)(((__pyx_CyFunctionObject *) (f))->defaults))
#define __Pyx_CyFunction_SetDefaultsGetter(f, g) \
((__pyx_CyFunctionObject *) (f))->defaults_getter = (g)
typedef struct { typedef struct {
...@@ -31,6 +33,10 @@ typedef struct { ...@@ -31,6 +33,10 @@ typedef struct {
/* Dynamic default args*/ /* Dynamic default args*/
void *defaults; void *defaults;
int defaults_pyobjects; int defaults_pyobjects;
/* Defaults info */
PyObject *defaults_tuple; /* Const defaults tuple */
PyObject *(*defaults_getter)(PyObject *);
} __pyx_CyFunctionObject; } __pyx_CyFunctionObject;
static PyTypeObject *__pyx_CyFunctionType = 0; static PyTypeObject *__pyx_CyFunctionType = 0;
...@@ -46,6 +52,8 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *, ...@@ -46,6 +52,8 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *,
static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m, static CYTHON_INLINE void *__Pyx_CyFunction_InitDefaults(PyObject *m,
size_t size, size_t size,
int pyobjects); int pyobjects);
static CYTHON_INLINE void __Pyx_CyFunction_SetDefaultsTuple(PyObject *m,
PyObject *tuple);
static int __Pyx_CyFunction_init(void); static int __Pyx_CyFunction_init(void);
...@@ -197,6 +205,29 @@ __Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op) ...@@ -197,6 +205,29 @@ __Pyx_CyFunction_get_code(__pyx_CyFunctionObject *op)
return result; return result;
} }
static PyObject *
__Pyx_CyFunction_get_defaults(__pyx_CyFunctionObject *op)
{
if (op->defaults_tuple) {
Py_INCREF(op->defaults_tuple);
return op->defaults_tuple;
}
if (op->defaults_getter) {
PyObject *res = op->defaults_getter((PyObject *) op);
/* Cache result */
if (res) {
Py_INCREF(res);
op->defaults_tuple = res;
}
return res;
}
Py_INCREF(Py_None);
return Py_None;
}
static PyGetSetDef __pyx_CyFunction_getsets[] = { static PyGetSetDef __pyx_CyFunction_getsets[] = {
{(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, {(char *) "func_doc", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0},
{(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0}, {(char *) "__doc__", (getter)__Pyx_CyFunction_get_doc, (setter)__Pyx_CyFunction_set_doc, 0, 0},
...@@ -211,6 +242,8 @@ static PyGetSetDef __pyx_CyFunction_getsets[] = { ...@@ -211,6 +242,8 @@ static PyGetSetDef __pyx_CyFunction_getsets[] = {
{(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0}, {(char *) "__closure__", (getter)__Pyx_CyFunction_get_closure, 0, 0, 0},
{(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, {(char *) "func_code", (getter)__Pyx_CyFunction_get_code, 0, 0, 0},
{(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0}, {(char *) "__code__", (getter)__Pyx_CyFunction_get_code, 0, 0, 0},
{(char *) "func_defaults", (getter)__Pyx_CyFunction_get_defaults, 0, 0, 0},
{(char *) "__defaults__", (getter)__Pyx_CyFunction_get_defaults, 0, 0, 0},
{0, 0, 0, 0, 0} {0, 0, 0, 0, 0}
}; };
...@@ -261,6 +294,8 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *type, PyMethodDef *ml, int f ...@@ -261,6 +294,8 @@ static PyObject *__Pyx_CyFunction_New(PyTypeObject *type, PyMethodDef *ml, int f
/* Dynamic Default args */ /* Dynamic Default args */
op->defaults_pyobjects = 0; op->defaults_pyobjects = 0;
op->defaults = NULL; op->defaults = NULL;
op->defaults_tuple = NULL;
op->defaults_getter = NULL;
PyObject_GC_Track(op); PyObject_GC_Track(op);
return (PyObject *) op; return (PyObject *) op;
} }
...@@ -275,6 +310,7 @@ __Pyx_CyFunction_clear(__pyx_CyFunctionObject *m) ...@@ -275,6 +310,7 @@ __Pyx_CyFunction_clear(__pyx_CyFunctionObject *m)
Py_CLEAR(m->func_doc); Py_CLEAR(m->func_doc);
Py_CLEAR(m->func_code); Py_CLEAR(m->func_code);
Py_CLEAR(m->func_classobj); Py_CLEAR(m->func_classobj);
Py_CLEAR(m->defaults_tuple);
if (m->defaults) { if (m->defaults) {
PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
...@@ -308,6 +344,7 @@ static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit, ...@@ -308,6 +344,7 @@ static int __Pyx_CyFunction_traverse(__pyx_CyFunctionObject *m, visitproc visit,
Py_VISIT(m->func_doc); Py_VISIT(m->func_doc);
Py_VISIT(m->func_code); Py_VISIT(m->func_code);
Py_VISIT(m->func_classobj); Py_VISIT(m->func_classobj);
Py_VISIT(m->defaults_tuple);
if (m->defaults) { if (m->defaults) {
PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m); PyObject **pydefaults = __Pyx_CyFunction_Defaults(PyObject *, m);
...@@ -431,6 +468,13 @@ void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects) ...@@ -431,6 +468,13 @@ void *__Pyx_CyFunction_InitDefaults(PyObject *func, size_t size, int pyobjects)
m->defaults_pyobjects = pyobjects; m->defaults_pyobjects = pyobjects;
return m->defaults; return m->defaults;
} }
static void __Pyx_CyFunction_SetDefaultsTuple(PyObject *func, PyObject *tuple)
{
__pyx_CyFunctionObject *m = (__pyx_CyFunctionObject *) func;
m->defaults_tuple = tuple;
Py_INCREF(tuple);
}
//////////////////// CyFunctionClassCell.proto //////////////////// //////////////////// CyFunctionClassCell.proto ////////////////////
static CYTHON_INLINE void __Pyx_CyFunction_InitClassCell(PyObject *cyfunctions, static CYTHON_INLINE void __Pyx_CyFunction_InitClassCell(PyObject *cyfunctions,
PyObject *classobj); PyObject *classobj);
......
# cython: binding=True
# mode: run
# tag: cyfunction
import sys
def get_defaults(func):
if sys.version_info >= (2, 5, 0):
return func.__defaults__
return func.func_defaults
def test_defaults_none():
"""
>>> get_defaults(test_defaults_none)
"""
def test_defaults_literal(a=1, b=(1,2,3)):
"""
>>> get_defaults(test_defaults_literal) is get_defaults(test_defaults_literal)
True
>>> get_defaults(test_defaults_literal)
(1, (1, 2, 3))
>>> a, b = get_defaults(test_defaults_literal)
>>> c, d = test_defaults_literal()
>>> a is c
True
>>> b is d
True
"""
return a, b
def test_defaults_nonliteral():
"""
>>> f0, f1 = test_defaults_nonliteral()
>>> get_defaults(f0) is get_defaults(f0) # cached
True
>>> get_defaults(f0)
(0, {}, (1, 2, 3))
>>> a, b = get_defaults(f0)[1:]
>>> c, d = f0(0)
>>> a is c
True
>>> b is d
True
>>> get_defaults(f1) is get_defaults(f1) # cached
True
>>> get_defaults(f1)
(0, [], (1, 2, 3))
>>> a, b = get_defaults(f1)[1:]
>>> c, d = f1(0)
>>> a is c
True
>>> b is d
True
"""
ret = []
for i in {}, []:
def foo(a, b=0, c=i, d=(1,2,3)):
return c, d
ret.append(foo)
return ret
_counter = 0
def counter():
global _counter
_counter += 1
return _counter
def test_defaults_nonliteral_func_call(f):
"""
>>> f = test_defaults_nonliteral_func_call(counter)
>>> f()
1
>>> get_defaults(f)
(1,)
>>> f = test_defaults_nonliteral_func_call(lambda: list())
>>> f()
[]
>>> get_defaults(f)
([],)
>>> get_defaults(f)[0] is f()
True
"""
def func(a=f()):
return a
return func
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment