Commit 267b8fb9 authored by scoder's avatar scoder

Merge pull request #288 from scoder/pep465

Implement PEP 465: dedicated infix operator for matrix multiplication
parents 1eb122af a3c4210f
......@@ -8843,6 +8843,16 @@ class TypeofNode(ExprNode):
#
#-------------------------------------------------------------------
try:
matmul_operator = operator.matmul
except AttributeError:
def matmul_operator(a, b):
try:
func = a.__matmul__
except AttributeError:
func = b.__rmatmul__
return func(a, b)
compile_time_binary_operators = {
'<': operator.lt,
'<=': operator.le,
......@@ -8864,6 +8874,7 @@ compile_time_binary_operators = {
'>>': operator.rshift,
'-': operator.sub,
'^': operator.xor,
'@': matmul_operator,
'in': lambda x, seq: x in seq,
'not_in': lambda x, seq: x not in seq,
}
......@@ -9180,10 +9191,11 @@ class NumBinopNode(BinopNode):
"+": "PyNumber_Add",
"-": "PyNumber_Subtract",
"*": "PyNumber_Multiply",
"@": "__Pyx_PyNumber_MatrixMultiply",
"/": "__Pyx_PyNumber_Divide",
"//": "PyNumber_FloorDivide",
"%": "PyNumber_Remainder",
"**": "PyNumber_Power"
"**": "PyNumber_Power",
}
overflow_op_names = {
......@@ -9282,6 +9294,17 @@ class MulNode(NumBinopNode):
return None
class MatMultNode(NumBinopNode):
# '@' operator.
def is_py_operation_types(self, type1, type2):
return True
def generate_evaluation_code(self, code):
code.globalstate.use_utility_code(UtilityCode.load_cached("MatrixMultiply", "ObjectHandling.c"))
super(MatMultNode, self).generate_evaluation_code(code)
class DivNode(NumBinopNode):
# '/' or '//' operator.
......@@ -10451,10 +10474,11 @@ binop_node_classes = {
"+": AddNode,
"-": SubNode,
"*": MulNode,
"@": MatMultNode,
"/": DivNode,
"//": DivNode,
"%": ModNode,
"**": PowNode
"**": PowNode,
}
def binop_node(pos, operator, operand1, operand2, inplace=False):
......
......@@ -10,6 +10,7 @@ char_prefixes = "cC"
any_string_prefix = raw_prefixes + string_prefixes + char_prefixes
IDENT = 'IDENT'
def make_lexicon():
from Cython.Plex import \
Str, Any, AnyBut, AnyChar, Rep, Rep1, Opt, Bol, Eol, Eof, \
......@@ -50,13 +51,12 @@ def make_lexicon():
Str('u') + four_hex | Str('x') + two_hex |
Str('U') + four_hex + four_hex | AnyChar)
deco = Str("@")
bra = Any("([{")
ket = Any(")]}")
punct = Any(":,;+-*/|&<>=.%`~^?!")
punct = Any(":,;+-*/|&<>=.%`~^?!@")
diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "//",
"+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=",
"<<=", ">>=", "**=", "//=", "->")
"<<=", ">>=", "**=", "//=", "->", "@=")
spaces = Rep1(Any(" \t\f"))
escaped_newline = Str("\\\n")
lineterm = Eol + Opt(Str("\n"))
......@@ -68,7 +68,6 @@ def make_lexicon():
(intliteral, 'INT'),
(fltconst, 'FLOAT'),
(imagconst, 'IMAG'),
(deco, 'DECORATOR'),
(punct | diphthong, TEXT),
(bra, Method('open_bracket_action')),
......
......@@ -267,10 +267,10 @@ def p_shift_expr(s):
def p_arith_expr(s):
return p_binop_expr(s, ('+', '-'), p_term)
#term: factor (('*'|'/'|'%') factor)*
#term: factor (('*'|'@'|'/'|'%'|'//') factor)*
def p_term(s):
return p_binop_expr(s, ('*', '/', '%', '//'), p_factor)
return p_binop_expr(s, ('*', '@', '/', '%', '//'), p_factor)
#factor: ('+'|'-'|'~'|'&'|typecast|sizeof) factor | power
......@@ -1129,7 +1129,7 @@ def p_expression_or_assignment(s):
expr = p_testlist_star_expr(s)
expr_list.append(expr)
if len(expr_list) == 1:
if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//)=", s.sy):
if re.match(r"([+*/\%^\&|-]|<<|>>|\*\*|//|@)=", s.sy):
lhs = expr_list[0]
if isinstance(lhs, ExprNodes.SliceIndexNode):
# implementation requires IndexNode
......@@ -1837,7 +1837,7 @@ def p_statement(s, ctx, first_statement = 0):
return p_DEF_statement(s)
elif s.sy == 'IF':
return p_IF_statement(s, ctx)
elif s.sy == 'DECORATOR':
elif s.sy == '@':
if ctx.level not in ('module', 'class', 'c_class', 'function', 'property', 'module_pxd', 'c_class_pxd', 'other'):
s.error('decorator not allowed here')
s.level = ctx.level
......@@ -2884,7 +2884,7 @@ def p_ctypedef_statement(s, ctx):
def p_decorators(s):
decorators = []
while s.sy == 'DECORATOR':
while s.sy == '@':
pos = s.position()
s.next()
decstring = p_dotted_name(s, as_allowed=0)[2]
......
......@@ -703,7 +703,11 @@ PyNumberMethods = (
MethodSlot(ibinaryfunc, "nb_inplace_true_divide", "__itruediv__"),
# Added in release 2.5
MethodSlot(unaryfunc, "nb_index", "__index__", ifdef = "PY_VERSION_HEX >= 0x02050000")
MethodSlot(unaryfunc, "nb_index", "__index__"),
# Added in release 3.5
MethodSlot(binaryfunc, "nb_matrix_multiply", "__matmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
MethodSlot(ibinaryfunc, "nb_inplace_matrix_multiply", "__imatmul__", ifdef="PY_VERSION_HEX >= 0x03050000"),
)
PySequenceMethods = (
......
......@@ -1156,3 +1156,62 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_Call(PyObject *func, PyObject *arg
return result;
}
#endif
/////////////// MatrixMultiply.proto ///////////////
#if PY_VERSION_HEX >= 0x03050000
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif
/////////////// MatrixMultiply ///////////////
//@requires: PyObjectGetAttrStr
#if PY_VERSION_HEX < 0x03050000
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
// FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__matmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, x, NULL);
Py_DECREF(func);
return result;
}
Py_INCREF(Py_NotImplemented);
return Py_NotImplemented;
}
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
PyObject *func;
func = __Pyx_PyObject_GetAttrStr(x, PYIDENT("__imatmul__"));
if (func) {
PyObject *result = PyObject_CallFunctionObjArgs(func, y, NULL);
Py_DECREF(func);
if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
}
return __Pyx_PyNumber_MatrixMultiply(x, y);
}
#endif
import sys
if sys.version_info >= (3, 5):
__doc__ = """\
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
>>> a, b = ExtMatMult(1), ExtMatMult(2)
>>> print(test_matmul(a, b))
ExtMatMult(1) @ ExtMatMult(2)
>>> print(test_matmul(a, 22))
ExtMatMult(1) @ 22
>>> print(test_matmul(11, b))
11 @ ExtMatMult(2)
>>> print(test_imatmul(a, b))
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
"""
class MatMult(object):
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'MatMult(%r)' % self.myself
cdef class ExtMatMult:
"""
Note: support for providing Python special methods despite missing the C-level slot
is currently not supported.
"""
cdef object myself
def __init__(self, myself):
self.myself = myself
def __matmul__(self, other):
return '%r @ %r' % (self, other)
def __rmatmul__(self, other):
return '%r @ %r' % (other, self)
def __imatmul__(self, other):
self.myself = '%r @ %r' % (self, other)
return self
def __repr__(self):
return 'ExtMatMult(%r)' % self.myself
def test_matmul(a, b):
"""
>>> print(test_matmul(MatMult(1), MatMult(2)))
MatMult(1) @ MatMult(2)
>>> print(test_matmul(MatMult(1), 22))
MatMult(1) @ 22
>>> print(test_matmul(11, MatMult(2)))
11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def')
"""
return a @ b
def test_imatmul(a, b):
"""
>>> print(test_imatmul(MatMult(1), MatMult(2)))
MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')")
"""
a @= b
return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment