Commit be95915c authored by Stefan Behnel's avatar Stefan Behnel

fix error handling in backported matrix multiplication and compare to actual behaviour in Py3.5

parent 1fed1015
...@@ -1353,7 +1353,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) { ...@@ -1353,7 +1353,8 @@ static CYTHON_INLINE PyObject* __Pyx_PyObject_CallNoArg(PyObject *func) {
#define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y) #define __Pyx_PyNumber_MatrixMultiply(x,y) PyNumber_MatrixMultiply(x,y)
#define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y) #define __Pyx_PyNumber_InPlaceMatrixMultiply(x,y) PyNumber_InPlaceMatrixMultiply(x,y)
#else #else
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y); #define __Pyx_PyNumber_MatrixMultiply(x,y) __Pyx__PyNumber_MatrixMultiply(x, y, "@")
static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name);
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y); static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y);
#endif #endif
...@@ -1392,7 +1393,7 @@ bad: ...@@ -1392,7 +1393,7 @@ bad:
return result; return result;
} }
static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) { static PyObject* __Pyx__PyNumber_MatrixMultiply(PyObject* x, PyObject* y, const char* op_name) {
PyObject *func; PyObject *func;
// FIXME: make subtype aware // FIXME: make subtype aware
// see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types // see note at https://docs.python.org/3/reference/datamodel.html#emulating-numeric-types
...@@ -1410,10 +1411,20 @@ static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) { ...@@ -1410,10 +1411,20 @@ static PyObject* __Pyx_PyNumber_MatrixMultiply(PyObject* x, PyObject* y) {
func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__")); func = __Pyx_PyObject_GetAttrStr(y, PYIDENT("__rmatmul__"));
if (func) { if (func) {
PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x); PyObject *result = __Pyx_PyObject_CallMatrixMethod(func, x);
return result; if (result != Py_NotImplemented)
return result;
Py_DECREF(result);
} else {
if (!PyErr_ExceptionMatches(PyExc_AttributeError))
return NULL;
PyErr_Clear();
} }
Py_INCREF(Py_NotImplemented); PyErr_Format(PyExc_TypeError,
return Py_NotImplemented; "unsupported operand type(s) for %.2s: '%.100s' and '%.100s'",
op_name,
Py_TYPE(x)->tp_name,
Py_TYPE(y)->tp_name);
return NULL;
} }
static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) { static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) {
...@@ -1429,6 +1440,6 @@ static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y) ...@@ -1429,6 +1440,6 @@ static PyObject* __Pyx_PyNumber_InPlaceMatrixMultiply(PyObject* x, PyObject* y)
return NULL; return NULL;
PyErr_Clear(); PyErr_Clear();
} }
return __Pyx_PyNumber_MatrixMultiply(x, y); return __Pyx__PyNumber_MatrixMultiply(x, y, "@=");
} }
#endif #endif
...@@ -17,6 +17,28 @@ ExtMatMult(1) @ 22 ...@@ -17,6 +17,28 @@ ExtMatMult(1) @ 22
ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)')
>>> print(test_imatmul(a, b)) >>> print(test_imatmul(a, b))
ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)") ExtMatMult("ExtMatMult('ExtMatMult(1) @ ExtMatMult(2)') @ ExtMatMult(2)")
>>> x = y = 1
>>> x @ y
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @: 'int' and 'int'
>>> x @= y
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @=: 'int' and 'int'
>>> y = MatMult(22)
>>> x @= y
>>> print(x)
1 @ MatMult(22)
>>> x = MatMult(22)
>>> print(x @ 1)
MatMult(22) @ 1
>>> print(1 @ x)
1 @ MatMult(22)
>>> x @= 1
>>> print(x)
MatMult('MatMult(22) @ 1')
""" """
...@@ -71,6 +93,10 @@ def test_matmul(a, b): ...@@ -71,6 +93,10 @@ def test_matmul(a, b):
11 @ MatMult(2) 11 @ MatMult(2)
>>> print(test_matmul(MatMult('abc'), MatMult('def'))) >>> print(test_matmul(MatMult('abc'), MatMult('def')))
MatMult('abc') @ MatMult('def') MatMult('abc') @ MatMult('def')
>>> test_matmul(1, 2)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @: 'int' and 'int'
""" """
return a @ b return a @ b
...@@ -81,6 +107,14 @@ def test_imatmul(a, b): ...@@ -81,6 +107,14 @@ def test_imatmul(a, b):
MatMult('MatMult(1) @ MatMult(2)') MatMult('MatMult(1) @ MatMult(2)')
>>> print(test_imatmul(MatMult('abc'), MatMult('def'))) >>> print(test_imatmul(MatMult('abc'), MatMult('def')))
MatMult("MatMult('abc') @ MatMult('def')") MatMult("MatMult('abc') @ MatMult('def')")
>>> print(test_imatmul(11, MatMult('def')))
11 @ MatMult('def')
>>> print(test_imatmul(MatMult('abc'), 11))
MatMult("MatMult('abc') @ 11")
>>> test_imatmul(1, 2)
Traceback (most recent call last):
TypeError: unsupported operand type(s) for @=: 'int' and 'int'
""" """
a @= b a @= b
return a return a
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment