Commit 5816b3df authored by Marius Wachtler

binop: use PyNumber_* for user-defined classes

This speeds up a simple numpy benchmark by about 10x.
parent c429f19d
...@@ -1846,77 +1846,73 @@ extern "C" int PyNumber_Check(PyObject* obj) noexcept { ...@@ -1846,77 +1846,73 @@ extern "C" int PyNumber_Check(PyObject* obj) noexcept {
return obj->cls->tp_as_number && (obj->cls->tp_as_number->nb_int || obj->cls->tp_as_number->nb_float); return obj->cls->tp_as_number && (obj->cls->tp_as_number->nb_int || obj->cls->tp_as_number->nb_float);
} }
extern "C" PyObject* PyNumber_Add(PyObject* lhs, PyObject* rhs) noexcept { extern "C" PyObject* PyNumber_Add(PyObject* v, PyObject* w) noexcept {
try { PyObject* result = binary_op1(v, w, NB_SLOT(nb_add));
return binop(lhs, rhs, AST_TYPE::Add); if (result == Py_NotImplemented) {
} catch (ExcInfo e) { PySequenceMethods* m = v->cls->tp_as_sequence;
setCAPIException(e); Py_DECREF(result);
return nullptr; if (m && m->sq_concat) {
return (*m->sq_concat)(v, w);
}
result = binop_type_error(v, w, "+");
} }
return result;
} }
extern "C" PyObject* PyNumber_Subtract(PyObject* lhs, PyObject* rhs) noexcept { #define BINARY_FUNC(func, op, op_name) \
try { extern "C" PyObject* func(PyObject* v, PyObject* w) noexcept { return binary_op(v, w, NB_SLOT(op), op_name); }
return binop(lhs, rhs, AST_TYPE::Sub);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Multiply(PyObject* lhs, PyObject* rhs) noexcept { BINARY_FUNC(PyNumber_Or, nb_or, "|")
try { BINARY_FUNC(PyNumber_Xor, nb_xor, "^")
return binop(lhs, rhs, AST_TYPE::Mult); BINARY_FUNC(PyNumber_And, nb_and, "&")
} catch (ExcInfo e) { BINARY_FUNC(PyNumber_Lshift, nb_lshift, "<<")
setCAPIException(e); BINARY_FUNC(PyNumber_Rshift, nb_rshift, ">>")
return nullptr; BINARY_FUNC(PyNumber_Subtract, nb_subtract, "-")
} BINARY_FUNC(PyNumber_Divide, nb_divide, "/")
} BINARY_FUNC(PyNumber_Divmod, nb_divmod, "divmod()")
extern "C" PyObject* PyNumber_Divide(PyObject* lhs, PyObject* rhs) noexcept { static PyObject* sequence_repeat(ssizeargfunc repeatfunc, PyObject* seq, PyObject* n) noexcept {
try { Py_ssize_t count;
return binop(lhs, rhs, AST_TYPE::Div); if (PyIndex_Check(n)) {
} catch (ExcInfo e) { count = PyNumber_AsSsize_t(n, PyExc_OverflowError);
setCAPIException(e); if (count == -1 && PyErr_Occurred())
return nullptr; return NULL;
} else {
return type_error("can't multiply sequence by "
"non-int of type '%.200s'",
n);
} }
return (*repeatfunc)(seq, count);
} }
extern "C" PyObject* PyNumber_FloorDivide(PyObject* lhs, PyObject* rhs) noexcept { extern "C" PyObject* PyNumber_Multiply(PyObject* v, PyObject* w) noexcept {
try { PyObject* result = binary_op1(v, w, NB_SLOT(nb_multiply));
return binop(lhs, rhs, AST_TYPE::FloorDiv); if (result == Py_NotImplemented) {
} catch (ExcInfo e) { PySequenceMethods* mv = v->cls->tp_as_sequence;
setCAPIException(e); PySequenceMethods* mw = w->cls->tp_as_sequence;
return nullptr; Py_DECREF(result);
if (mv && mv->sq_repeat) {
return sequence_repeat(mv->sq_repeat, v, w);
} else if (mw && mw->sq_repeat) {
return sequence_repeat(mw->sq_repeat, w, v);
}
result = binop_type_error(v, w, "*");
} }
return result;
} }
extern "C" PyObject* PyNumber_TrueDivide(PyObject* lhs, PyObject* rhs) noexcept { extern "C" PyObject* PyNumber_FloorDivide(PyObject* v, PyObject* w) noexcept {
try { /* XXX tp_flags test */
return binop(lhs, rhs, AST_TYPE::TrueDiv); return binary_op(v, w, NB_SLOT(nb_floor_divide), "//");
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
} }
extern "C" PyObject* PyNumber_Remainder(PyObject* lhs, PyObject* rhs) noexcept { extern "C" PyObject* PyNumber_TrueDivide(PyObject* v, PyObject* w) noexcept {
try { /* XXX tp_flags test */
return binop(lhs, rhs, AST_TYPE::Mod); return binary_op(v, w, NB_SLOT(nb_true_divide), "/");
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
} }
extern "C" PyObject* PyNumber_Divmod(PyObject* lhs, PyObject* rhs) noexcept { extern "C" PyObject* PyNumber_Remainder(PyObject* v, PyObject* w) noexcept {
try { return binary_op(v, w, NB_SLOT(nb_remainder), "%");
return binop(lhs, rhs, AST_TYPE::DivMod);
} catch (ExcInfo e) {
e.clear();
fatalOrError(PyExc_NotImplementedError, "unimplemented");
return nullptr;
}
} }
extern "C" PyObject* PyNumber_Power(PyObject* v, PyObject* w, PyObject* z) noexcept { extern "C" PyObject* PyNumber_Power(PyObject* v, PyObject* w, PyObject* z) noexcept {
...@@ -1960,57 +1956,15 @@ extern "C" PyObject* PyNumber_Absolute(PyObject* o) noexcept { ...@@ -1960,57 +1956,15 @@ extern "C" PyObject* PyNumber_Absolute(PyObject* o) noexcept {
} }
extern "C" PyObject* PyNumber_Invert(PyObject* o) noexcept { extern "C" PyObject* PyNumber_Invert(PyObject* o) noexcept {
try { PyNumberMethods* m;
return unaryop(o, AST_TYPE::Invert);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Lshift(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::LShift);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Rshift(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::RShift);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_And(PyObject* lhs, PyObject* rhs) noexcept {
try {
return binop(lhs, rhs, AST_TYPE::BitAnd);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Xor(PyObject* lhs, PyObject* rhs) noexcept { if (o == NULL)
try { return null_error();
return binop(lhs, rhs, AST_TYPE::BitXor); m = o->cls->tp_as_number;
} catch (ExcInfo e) { if (m && m->nb_invert)
setCAPIException(e); return (*m->nb_invert)(o);
return nullptr;
}
}
extern "C" PyObject* PyNumber_Or(PyObject* lhs, PyObject* rhs) noexcept { return type_error("bad operand type for unary ~: '%.200s'", o);
try {
return binop(lhs, rhs, AST_TYPE::BitOr);
} catch (ExcInfo e) {
setCAPIException(e);
return nullptr;
}
} }
extern "C" PyObject* PyNumber_InPlaceAdd(PyObject* v, PyObject* w) noexcept { extern "C" PyObject* PyNumber_InPlaceAdd(PyObject* v, PyObject* w) noexcept {
...@@ -2036,20 +1990,6 @@ extern "C" PyObject* PyNumber_InPlaceSubtract(PyObject* v, PyObject* w) noexcept ...@@ -2036,20 +1990,6 @@ extern "C" PyObject* PyNumber_InPlaceSubtract(PyObject* v, PyObject* w) noexcept
return binary_iop(v, w, NB_SLOT(nb_inplace_subtract), NB_SLOT(nb_subtract), "-="); return binary_iop(v, w, NB_SLOT(nb_inplace_subtract), NB_SLOT(nb_subtract), "-=");
} }
// Calls a sequence repeat slot (`repeatfunc`) on `seq`, using `n` as the repeat count.
//
// repeatfunc: the ssizeargfunc to invoke; it receives `seq` and the converted count.
// seq:        the object passed through unchanged as the first argument of `repeatfunc`.
// n:          the object supplying the count; it must satisfy PyIndex_Check().
//
// Returns the result of `repeatfunc(seq, count)`. Returns NULL if converting `n`
// to a Py_ssize_t fails, or the result of type_error() if `n` fails PyIndex_Check().
// NOTE(review): type_error() is defined elsewhere; presumably it sets a TypeError and
// returns NULL -- confirm against its definition.
static PyObject* sequence_repeat(ssizeargfunc repeatfunc, PyObject* seq, PyObject* n) {
Py_ssize_t count;
if (PyIndex_Check(n)) {
// Ask for PyExc_OverflowError if the value does not fit in a Py_ssize_t.
count = PyNumber_AsSsize_t(n, PyExc_OverflowError);
// -1 is also a valid count, so only treat it as failure when an error is pending.
if (count == -1 && PyErr_Occurred())
return NULL;
} else {
return type_error("can't multiply sequence by "
"non-int of type '%.200s'",
n);
}
return (*repeatfunc)(seq, count);
}
extern "C" PyObject* PyNumber_InPlaceMultiply(PyObject* v, PyObject* w) noexcept { extern "C" PyObject* PyNumber_InPlaceMultiply(PyObject* v, PyObject* w) noexcept {
PyObject* result = binary_iop1(v, w, NB_SLOT(nb_inplace_multiply), NB_SLOT(nb_multiply)); PyObject* result = binary_iop1(v, w, NB_SLOT(nb_inplace_multiply), NB_SLOT(nb_multiply));
if (result == Py_NotImplemented) { if (result == Py_NotImplemented) {
......
...@@ -5476,6 +5476,77 @@ Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteAr ...@@ -5476,6 +5476,77 @@ Box* binopInternal(Box* lhs, Box* rhs, int op_type, bool inplace, BinopRewriteAr
rewrite_args = NULL; rewrite_args = NULL;
} }
// Currently can't patchpoint user-defined binops since we can't assume that just because
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
if (!can_patchpoint) {
PyObject* (*func)(PyObject*, PyObject*) = NULL;
switch (op_type) {
case AST_TYPE::Add:
func = inplace ? PyNumber_InPlaceAdd : PyNumber_Add;
break;
case AST_TYPE::BitOr:
func = inplace ? PyNumber_InPlaceOr : PyNumber_Or;
break;
case AST_TYPE::BitXor:
func = inplace ? PyNumber_InPlaceXor : PyNumber_Xor;
break;
case AST_TYPE::BitAnd:
func = inplace ? PyNumber_InPlaceAnd : PyNumber_And;
break;
case AST_TYPE::LShift:
func = inplace ? PyNumber_InPlaceLshift : PyNumber_Lshift;
break;
case AST_TYPE::RShift:
func = inplace ? PyNumber_InPlaceRshift : PyNumber_Rshift;
break;
case AST_TYPE::Sub:
func = inplace ? PyNumber_InPlaceSubtract : PyNumber_Subtract;
break;
case AST_TYPE::Div:
func = inplace ? PyNumber_InPlaceDivide : PyNumber_Divide;
break;
case AST_TYPE::Mod:
func = inplace ? PyNumber_InPlaceRemainder : PyNumber_Remainder;
break;
case AST_TYPE::Mult:
func = inplace ? PyNumber_InPlaceMultiply : PyNumber_Multiply;
break;
case AST_TYPE::FloorDiv:
func = inplace ? PyNumber_InPlaceFloorDivide : PyNumber_FloorDivide;
break;
case AST_TYPE::TrueDiv:
func = inplace ? PyNumber_InPlaceTrueDivide : PyNumber_TrueDivide;
break;
case AST_TYPE::DivMod:
func = inplace ? NULL : PyNumber_Divmod;
break;
};
if (func) {
if (rewrite_args) {
rewrite_args->lhs->addAttrGuard(offsetof(Box, cls), (intptr_t)lhs->cls);
rewrite_args->rhs->addAttrGuard(offsetof(Box, cls), (intptr_t)rhs->cls);
RewriterVar* r_ret = rewrite_args->rewriter->call(true, (void*)func, rewrite_args->lhs,
rewrite_args->rhs)->setType(RefType::OWNED);
rewrite_args->rewriter->checkAndThrowCAPIException(r_ret);
rewrite_args->out_rtn = r_ret;
rewrite_args->out_success = true;
}
Box* rtn = func(lhs, rhs);
if (!rtn)
throwCAPIException();
return rtn;
}
}
if (!can_patchpoint)
rewrite_args = NULL;
RewriterVar* r_lhs = NULL; RewriterVar* r_lhs = NULL;
RewriterVar* r_rhs = NULL; RewriterVar* r_rhs = NULL;
if (rewrite_args) { if (rewrite_args) {
...@@ -5546,7 +5617,6 @@ template Box* binopInternal<NOT_REWRITABLE>(Box*, Box*, int, bool, BinopRewriteA ...@@ -5546,7 +5617,6 @@ template Box* binopInternal<NOT_REWRITABLE>(Box*, Box*, int, bool, BinopRewriteA
extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) { extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
STAT_TIMER(t0, "us_timer_slowpath_binop", 10); STAT_TIMER(t0, "us_timer_slowpath_binop", 10);
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
#if 0 #if 0
static uint64_t* st_id = Stats::getStatCounter("us_timer_slowpath_binop_patchable"); static uint64_t* st_id = Stats::getStatCounter("us_timer_slowpath_binop_patchable");
static uint64_t* st_id_nopatch = Stats::getStatCounter("us_timer_slowpath_binop_nopatch"); static uint64_t* st_id_nopatch = Stats::getStatCounter("us_timer_slowpath_binop_nopatch");
...@@ -5561,14 +5631,8 @@ extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) { ...@@ -5561,14 +5631,8 @@ extern "C" Box* binop(Box* lhs, Box* rhs, int op_type) {
// int id = Stats::getStatId("slowpath_binop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs)); // int id = Stats::getStatId("slowpath_binop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs));
// Stats::log(id); // Stats::log(id);
std::unique_ptr<Rewriter> rewriter((Rewriter*)NULL); std::unique_ptr<Rewriter> rewriter(
// Currently can't patchpoint user-defined binops since we can't assume that just because Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
if (can_patchpoint)
rewriter.reset(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
Box* rtn; Box* rtn;
if (rewriter.get()) { if (rewriter.get()) {
...@@ -5605,15 +5669,8 @@ extern "C" Box* augbinop(Box* lhs, Box* rhs, int op_type) { ...@@ -5605,15 +5669,8 @@ extern "C" Box* augbinop(Box* lhs, Box* rhs, int op_type) {
// int id = Stats::getStatId("slowpath_augbinop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs)); // int id = Stats::getStatId("slowpath_augbinop_" + *getTypeName(lhs) + op_name + *getTypeName(rhs));
// Stats::log(id); // Stats::log(id);
std::unique_ptr<Rewriter> rewriter((Rewriter*)NULL); std::unique_ptr<Rewriter> rewriter(
// Currently can't patchpoint user-defined binops since we can't assume that just because Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
// resolving it one way right now (ex, using the value from lhs.__add__) means that later
// we'll resolve it the same way, even for the same argument types.
// TODO implement full resolving semantics inside the rewrite?
bool can_patchpoint = !lhs->cls->is_user_defined && !rhs->cls->is_user_defined;
if (can_patchpoint)
rewriter.reset(
Rewriter::createRewriter(__builtin_extract_return_addr(__builtin_return_address(0)), 3, "binop"));
Box* rtn; Box* rtn;
if (rewriter.get()) { if (rewriter.get()) {
......
...@@ -33,7 +33,7 @@ def install_and_test_lxml(): ...@@ -33,7 +33,7 @@ def install_and_test_lxml():
subprocess.check_call([PYTHON_EXE, "setup.py", "build_ext", "-i", "--with-cython"], cwd=LXML_DIR) subprocess.check_call([PYTHON_EXE, "setup.py", "build_ext", "-i", "--with-cython"], cwd=LXML_DIR)
expected = [{'ran': 1381, 'failures': 3, 'errors': 1}] expected = [{'ran': 1381, 'failures': 3}]
run_test([PYTHON_EXE, "test.py"], cwd=LXML_DIR, expected=expected) run_test([PYTHON_EXE, "test.py"], cwd=LXML_DIR, expected=expected)
create_virtenv(ENV_NAME, None, force_create = True) create_virtenv(ENV_NAME, None, force_create = True)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment