Commit a8163914 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Add rewriting for 'a in b' expressions

parent a3477848
...@@ -1007,7 +1007,7 @@ std::string getCurrentPythonLine() { ...@@ -1007,7 +1007,7 @@ std::string getCurrentPythonLine() {
return "unknown:-1"; return "unknown:-1";
} }
void logByCurrentPythonLine(std::string& stat_name) { void logByCurrentPythonLine(const std::string& stat_name) {
std::string stat = stat_name + "<" + getCurrentPythonLine() + ">"; std::string stat = stat_name + "<" + getCurrentPythonLine() + ">";
Stats::log(Stats::getStatCounter(stat)); Stats::log(Stats::getStatCounter(stat));
} }
......
...@@ -46,7 +46,7 @@ ExecutionPoint getExecutionPoint(); ...@@ -46,7 +46,7 @@ ExecutionPoint getExecutionPoint();
std::string getCurrentPythonLine(); std::string getCurrentPythonLine();
// doesn't really belong in unwinding.h, since it's stats related, but it needs to unwind to get the current line... // doesn't really belong in unwinding.h, since it's stats related, but it needs to unwind to get the current line...
void logByCurrentPythonLine(std::string& stat_name); void logByCurrentPythonLine(const std::string& stat_name);
// Adds stack locals and closure locals into the locals dict, and returns it. // Adds stack locals and closure locals into the locals dict, and returns it.
Box* fastLocalsToBoxedLocals(); Box* fastLocalsToBoxedLocals();
......
...@@ -3727,6 +3727,19 @@ static bool convert3wayCompareResultToBool(Box* v, int op_type) { ...@@ -3727,6 +3727,19 @@ static bool convert3wayCompareResultToBool(Box* v, int op_type) {
}; };
} }
Box* nonzeroAndBox(Box* b, bool negate) {
if (likely(b->cls == bool_cls)) {
if (negate)
return boxBool(b != True);
return b;
}
bool t = b->nonzeroIC();
if (negate)
t = !t;
return boxBool(t);
}
Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrite_args) { Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrite_args) {
if (op_type == AST_TYPE::Is || op_type == AST_TYPE::IsNot) { if (op_type == AST_TYPE::Is || op_type == AST_TYPE::IsNot) {
bool neg = (op_type == AST_TYPE::IsNot); bool neg = (op_type == AST_TYPE::IsNot);
...@@ -3742,11 +3755,26 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit ...@@ -3742,11 +3755,26 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit
} }
if (op_type == AST_TYPE::In || op_type == AST_TYPE::NotIn) { if (op_type == AST_TYPE::In || op_type == AST_TYPE::NotIn) {
// TODO do rewrite
static BoxedString* contains_str = static_cast<BoxedString*>(PyString_InternFromString("__contains__")); static BoxedString* contains_str = static_cast<BoxedString*>(PyString_InternFromString("__contains__"));
Box* contained = callattrInternal1(rhs, contains_str, CLASS_ONLY, NULL, ArgPassSpec(1), lhs);
Box* contained;
RewriterVar* r_contained;
if (rewrite_args) {
CallRewriteArgs crewrite_args(rewrite_args->rewriter, rewrite_args->rhs, rewrite_args->destination);
crewrite_args.arg1 = rewrite_args->lhs;
contained = callattrInternal1(rhs, contains_str, CLASS_ONLY, &crewrite_args, ArgPassSpec(1), lhs);
if (!crewrite_args.out_success)
rewrite_args = NULL;
else if (contained)
r_contained = crewrite_args.out_rtn;
} else {
contained = callattrInternal1(rhs, contains_str, CLASS_ONLY, NULL, ArgPassSpec(1), lhs);
}
if (contained == NULL) { if (contained == NULL) {
rewrite_args = NULL;
int result = _PySequence_IterSearch(rhs, lhs, PY_ITERSEARCH_CONTAINS); int result = _PySequence_IterSearch(rhs, lhs, PY_ITERSEARCH_CONTAINS);
if (result < 0) if (result < 0)
throwCAPIException(); throwCAPIException();
...@@ -3754,6 +3782,14 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit ...@@ -3754,6 +3782,14 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit
return boxBool(result); return boxBool(result);
} }
if (rewrite_args) {
auto r_negate = rewrite_args->rewriter->loadConst((int)(op_type == AST_TYPE::NotIn));
RewriterVar* r_contained_box
= rewrite_args->rewriter->call(true, (void*)nonzeroAndBox, r_contained, r_negate);
rewrite_args->out_rtn = r_contained_box;
rewrite_args->out_success = true;
}
bool b; bool b;
if (contained->cls == bool_cls) if (contained->cls == bool_cls)
b = contained == True; b = contained == True;
......
...@@ -410,7 +410,7 @@ Box* setPop(BoxedSet* self) { ...@@ -410,7 +410,7 @@ Box* setPop(BoxedSet* self) {
Box* setContains(BoxedSet* self, Box* v) { Box* setContains(BoxedSet* self, Box* v) {
RELEASE_ASSERT(PyAnySet_Check(self), ""); RELEASE_ASSERT(PyAnySet_Check(self), "");
return boxBool(self->s.count(v) != 0); return boxBool(self->s.find(v) != self->s.end());
} }
Box* setEq(BoxedSet* self, BoxedSet* rhs) { Box* setEq(BoxedSet* self, BoxedSet* rhs) {
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment