Commit e6a81240 authored by Robert Bradshaw's avatar Robert Bradshaw

Python-style binary operation methods.

parent 9c78524a
......@@ -56,6 +56,12 @@ Bugs fixed
* The signature of the NumPy C-API function ``PyArray_SearchSorted()`` was fixed.
Patch by Brock Mendel. (Github issue #3606)
* Added support for Python binary operator semantics.
One can now define, e.g. both ``__add__`` and ``__radd__`` for cdef classes
as for standard Python classes rather than a single ``__add__`` method where
self can be either the first or second argument. (Github issue #2056)
This behavior is guarded by the c_api_binop_methods directive.
0.29.17 (2020-04-26)
====================
......
......@@ -29,7 +29,7 @@ from . import Pythran
from .Errors import error, warning
from .PyrexTypes import py_object_type
from ..Utils import open_new_file, replace_suffix, decode_filename, build_hex_version
from .Code import UtilityCode, IncludeCode
from .Code import UtilityCode, IncludeCode, TempitaUtilityCode
from .StringEncoding import EncodedString
from .Pythran import has_np_pythran
......@@ -1255,6 +1255,9 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
self.generate_dict_getter_function(scope, code)
if scope.defines_any_special(TypeSlots.richcmp_special_methods):
self.generate_richcmp_function(scope, code)
for slot in TypeSlots.PyNumberMethods:
if slot.is_binop and scope.defines_any_special(slot.user_methods):
self.generate_binop_function(scope, slot, code)
self.generate_property_accessors(scope, code)
self.generate_method_table(scope, code)
self.generate_getset_table(scope, code)
......@@ -1892,6 +1895,44 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
code.putln("}") # switch
code.putln("}")
def generate_binop_function(self, scope, slot, code):
func_name = scope.mangle_internal(slot.slot_name)
code.putln()
preprocessor_guard = slot.preprocessor_guard_code()
if preprocessor_guard:
code.putln(preprocessor_guard)
if scope.directives['c_api_binop_methods']:
code.putln('#define %s %s' % (func_name, slot.left_slot.slot_code(scope)))
else:
def has_slot_method(method_name):
entry = scope.lookup(method_name)
return bool(entry and entry.is_special and entry.func_cname)
def call_slot_method(method_name, reverse):
entry = scope.lookup(method_name)
if reverse:
operands = "right, left"
else:
operands = "left, right"
if entry and entry.is_special and entry.func_cname:
return "%s(%s)" % (entry.func_cname, operands)
else:
py_ident = code.intern_identifier(EncodedString(method_name))
return "%s_maybe_call_super(%s, %s)" % (func_name, operands, py_ident)
code.putln(
TempitaUtilityCode.load_cached(
"BinopSlot", "ExtensionTypes.c",
context={
"func_name": func_name,
"slot_name": slot.slot_name,
"overloads_left": int(has_slot_method(slot.left_slot.method_name)),
"call_left": call_slot_method(slot.left_slot.method_name, reverse=False),
"call_right": call_slot_method(slot.right_slot.method_name, reverse=True),
"type_cname": '((PyTypeObject*) %s)' % scope.namespace_cname,
}).impl.strip())
code.putln()
if preprocessor_guard:
code.putln("#endif")
def generate_getattro_function(self, scope, code):
# First try to get the attribute using __getattribute__, if defined, or
# PyObject_GenericGetAttr.
......
......@@ -178,6 +178,7 @@ _directive_defaults = {
'auto_pickle': None,
'cdivision': False, # was True before 0.12
'cdivision_warnings': False,
'c_api_binop_methods': True, # Change for 3.0
'overflowcheck': False,
'overflowcheck.fold': True,
'always_allow_keywords': False,
......
......@@ -180,13 +180,14 @@ class SlotDescriptor(object):
# ifdef Full #ifdef string that slot is wrapped in. Using this causes py3, py2 and flags to be ignored.)
def __init__(self, slot_name, dynamic=False, inherited=False,
py3=True, py2=True, ifdef=None):
py3=True, py2=True, ifdef=None, is_binop=False):
self.slot_name = slot_name
self.is_initialised_dynamically = dynamic
self.is_inherited = inherited
self.ifdef = ifdef
self.py3 = py3
self.py2 = py2
self.is_binop = is_binop
def preprocessor_guard_code(self):
ifdef = self.ifdef
......@@ -405,6 +406,17 @@ class SyntheticSlot(InternalMethodSlot):
return self.default_value
class BinopSlot(SyntheticSlot):
def __init__(self, signature, slot_name, left_method, **kargs):
assert left_method.startswith('__')
right_method = '__r' + left_method[2:]
SyntheticSlot.__init__(
self, slot_name, [left_method, right_method], "0", is_binop=True, **kargs)
# MethodSlot causes special method registration.
self.left_slot = MethodSlot(signature, "", left_method)
self.right_slot = MethodSlot(signature, "", right_method)
class RichcmpSlot(MethodSlot):
def slot_code(self, scope):
entry = scope.lookup_here(self.method_name)
......@@ -728,23 +740,23 @@ property_accessor_signatures = {
PyNumberMethods_Py3_GUARD = "PY_MAJOR_VERSION < 3 || (CYTHON_COMPILING_IN_PYPY && PY_VERSION_HEX < 0x03050000)"
PyNumberMethods = (
MethodSlot(binaryfunc, "nb_add", "__add__"),
MethodSlot(binaryfunc, "nb_subtract", "__sub__"),
MethodSlot(binaryfunc, "nb_multiply", "__mul__"),
MethodSlot(binaryfunc, "nb_divide", "__div__", ifdef = PyNumberMethods_Py3_GUARD),
MethodSlot(binaryfunc, "nb_remainder", "__mod__"),
MethodSlot(binaryfunc, "nb_divmod", "__divmod__"),
MethodSlot(ternaryfunc, "nb_power", "__pow__"),
BinopSlot(binaryfunc, "nb_add", "__add__"),
BinopSlot(binaryfunc, "nb_subtract", "__sub__"),
BinopSlot(binaryfunc, "nb_multiply", "__mul__"),
BinopSlot(binaryfunc, "nb_divide", "__div__", ifdef = PyNumberMethods_Py3_GUARD),
BinopSlot(binaryfunc, "nb_remainder", "__mod__"),
BinopSlot(binaryfunc, "nb_divmod", "__divmod__"),
BinopSlot(ternaryfunc, "nb_power", "__pow__"),
MethodSlot(unaryfunc, "nb_negative", "__neg__"),
MethodSlot(unaryfunc, "nb_positive", "__pos__"),
MethodSlot(unaryfunc, "nb_absolute", "__abs__"),
MethodSlot(inquiry, "nb_nonzero", "__nonzero__", py3 = ("nb_bool", "__bool__")),
MethodSlot(unaryfunc, "nb_invert", "__invert__"),
MethodSlot(binaryfunc, "nb_lshift", "__lshift__"),
MethodSlot(binaryfunc, "nb_rshift", "__rshift__"),
MethodSlot(binaryfunc, "nb_and", "__and__"),
MethodSlot(binaryfunc, "nb_xor", "__xor__"),
MethodSlot(binaryfunc, "nb_or", "__or__"),
BinopSlot(binaryfunc, "nb_lshift", "__lshift__"),
BinopSlot(binaryfunc, "nb_rshift", "__rshift__"),
BinopSlot(binaryfunc, "nb_and", "__and__"),
BinopSlot(binaryfunc, "nb_xor", "__xor__"),
BinopSlot(binaryfunc, "nb_or", "__or__"),
EmptySlot("nb_coerce", ifdef = PyNumberMethods_Py3_GUARD),
MethodSlot(unaryfunc, "nb_int", "__int__", fallback="__long__"),
MethodSlot(unaryfunc, "nb_long", "__long__", fallback="__int__", py3 = "<RESERVED>"),
......@@ -767,8 +779,8 @@ PyNumberMethods = (
# Added in release 2.2
# The following require the Py_TPFLAGS_HAVE_CLASS flag
MethodSlot(binaryfunc, "nb_floor_divide", "__floordiv__"),
MethodSlot(binaryfunc, "nb_true_divide", "__truediv__"),
BinopSlot(binaryfunc, "nb_floor_divide", "__floordiv__"),
BinopSlot(binaryfunc, "nb_true_divide", "__truediv__"),
MethodSlot(ibinaryfunc, "nb_inplace_floor_divide", "__ifloordiv__"),
MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"),
......@@ -776,7 +788,7 @@ PyNumberMethods = (
MethodSlot(unaryfunc, "nb_index", "__index__"),
# Added in release 3.5
MethodSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
BinopSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
MethodSlot(ibinaryfunc, "nb_inplace_matrix_multiply", "__imatmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
)
......
......@@ -278,3 +278,59 @@ __PYX_GOOD:
Py_XDECREF(setstate_cython);
return ret;
}
/////////////// BinopSlot ///////////////
static CYTHON_INLINE PyObject *{{func_name}}_maybe_call_super(PyObject *self, PyObject *other, PyObject* name) {
PyObject *res;
PyObject *method;
if (!Py_TYPE(self)->tp_base) {
return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
}
// TODO: Use _PyType_LookupId or similar.
method = PyObject_GetAttr((PyObject*) Py_TYPE(self)->tp_base, name);
if (!method) {
PyErr_Clear();
return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
}
res = __Pyx_PyObject_Call2Args(method, self, other);
Py_DECREF(method);
if (!res) {
return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
}
return res;
}
static PyObject *{{func_name}}(PyObject *left, PyObject *right) {
PyObject *res;
int maybe_self_is_left, maybe_self_is_right = 0;
maybe_self_is_left = Py_TYPE(left) == Py_TYPE(right)
|| (Py_TYPE(left)->tp_as_number && Py_TYPE(left)->tp_as_number->{{slot_name}} == &{{func_name}})
|| PyType_IsSubtype(Py_TYPE(left), {{type_cname}});
// Optimize for the common case where the left operation is defined (and successful).
if (!{{overloads_left}}) {
maybe_self_is_right = Py_TYPE(left) == Py_TYPE(right)
|| (Py_TYPE(right)->tp_as_number && Py_TYPE(right)->tp_as_number->{{slot_name}} == &{{func_name}})
|| PyType_IsSubtype(Py_TYPE(right), {{type_cname}});
}
if (maybe_self_is_left) {
if (maybe_self_is_right && !{{overloads_left}}) {
res = {{call_right}};
if (res != Py_NotImplemented) return res;
Py_DECREF(res);
maybe_self_is_right = 0; // Don't bother calling it again.
}
res = {{call_left}};
if (res != Py_NotImplemented) return res;
Py_DECREF(res);
}
if ({{overloads_left}}) {
maybe_self_is_right = Py_TYPE(left) == Py_TYPE(right)
|| (Py_TYPE(right)->tp_as_number && Py_TYPE(right)->tp_as_number->{{slot_name}} == &{{func_name}})
|| PyType_IsSubtype(Py_TYPE(right), {{type_cname}});
}
if (maybe_self_is_right) {
return {{call_right}};
}
return Py_INCREF(Py_NotImplemented), Py_NotImplemented;
}
cimport cython
@cython.c_api_binop_methods(False)
@cython.cclass
class Base(object):
"""
>>> Base() + 2
'Base.__add__(Base(), 2)'
>>> 2 + Base()
'Base.__radd__(Base(), 2)'
"""
def __add__(self, other):
return "Base.__add__(%s, %s)" % (self, other)
def __radd__(self, other):
return "Base.__radd__(%s, %s)" % (self, other)
def __repr__(self):
return "%s()" % (self.__class__.__name__)
@cython.c_api_binop_methods(False)
@cython.cclass
class OverloadLeft(Base):
"""
>>> OverloadLeft() + 2
'OverloadLeft.__add__(OverloadLeft(), 2)'
>>> 2 + OverloadLeft()
'Base.__radd__(OverloadLeft(), 2)'
>>> OverloadLeft() + Base()
'OverloadLeft.__add__(OverloadLeft(), Base())'
>>> Base() + OverloadLeft()
'Base.__add__(Base(), OverloadLeft())'
"""
def __add__(self, other):
return "OverloadLeft.__add__(%s, %s)" % (self, other)
@cython.c_api_binop_methods(False)
@cython.cclass
class OverloadRight(Base):
"""
>>> OverloadRight() + 2
'Base.__add__(OverloadRight(), 2)'
>>> 2 + OverloadRight()
'OverloadRight.__radd__(OverloadRight(), 2)'
>>> OverloadRight() + Base()
'Base.__add__(OverloadRight(), Base())'
>>> Base() + OverloadRight()
'OverloadRight.__radd__(OverloadRight(), Base())'
"""
def __radd__(self, other):
return "OverloadRight.__radd__(%s, %s)" % (self, other)
@cython.c_api_binop_methods(True)
@cython.cclass
class OverloadCApi(Base):
"""
>>> OverloadCApi() + 2
'OverloadCApi.__add__(OverloadCApi(), 2)'
>>> 2 + OverloadCApi()
'OverloadCApi.__add__(2, OverloadCApi())'
>>> OverloadCApi() + Base()
'OverloadCApi.__add__(OverloadCApi(), Base())'
>>> Base() + OverloadCApi()
'OverloadCApi.__add__(Base(), OverloadCApi())'
"""
def __add__(self, other):
return "OverloadCApi.__add__(%s, %s)" % (self, other)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment