Commit f3e03b35 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Convert "a in (b, c)" to "a == b or a == c"

Do this by adding "contains" to our codegen type system, and
implement a special contains on the unboxedtuple type.

This makes this operation quite a lot faster, but it looks like
largely because we don't implement a couple optimizations that
we should:
- we create a new tuple object every time we hit that line
- our generic contains code goes through compare(), which returns
  a box (since "<" and friends can return non-bools), but contains
  will always return a bool, so we have a bunch of extra boxing/unboxing

We probably should separate out the contains logic from the rest of the
comparisons, since it works quite differently and doesn't
gain anything by being there.
parent b0db3e65
...@@ -455,6 +455,10 @@ public: ...@@ -455,6 +455,10 @@ public:
return new ConcreteCompilerVariable(UNKNOWN, rtn, true); return new ConcreteCompilerVariable(UNKNOWN, rtn, true);
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return lhs->binexp(emitter, info, var, AST_TYPE::In, Compare);
}
Box* deserializeFromFrame(const FrameVals& vals) override { Box* deserializeFromFrame(const FrameVals& vals) override {
assert(vals.size() == 1); assert(vals.size() == 1);
return reinterpret_cast<Box*>(vals[0]); return reinterpret_cast<Box*>(vals[0]);
...@@ -1081,6 +1085,10 @@ public: ...@@ -1081,6 +1085,10 @@ public:
} }
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return makeBool(false);
}
ConcreteCompilerType* getBoxType() override { return BOXED_INT; } ConcreteCompilerType* getBoxType() override { return BOXED_INT; }
Box* deserializeFromFrame(const FrameVals& vals) override { Box* deserializeFromFrame(const FrameVals& vals) override {
...@@ -1319,6 +1327,10 @@ public: ...@@ -1319,6 +1327,10 @@ public:
return rtn; return rtn;
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return makeBool(false);
}
ConcreteCompilerType* getBoxType() override { return BOXED_FLOAT; } ConcreteCompilerType* getBoxType() override { return BOXED_FLOAT; }
Box* deserializeFromFrame(const FrameVals& vals) override { Box* deserializeFromFrame(const FrameVals& vals) override {
...@@ -1688,6 +1700,10 @@ public: ...@@ -1688,6 +1700,10 @@ public:
return rtn; return rtn;
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return UNKNOWN->contains(emitter, info, var, lhs);
}
CompilerVariable* getitem(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* slice) override { CompilerVariable* getitem(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* slice) override {
static BoxedString* attr = static_cast<BoxedString*>(PyString_InternFromString("__getitem__")); static BoxedString* attr = static_cast<BoxedString*>(PyString_InternFromString("__getitem__"));
bool no_attribute = false; bool no_attribute = false;
...@@ -1914,6 +1930,13 @@ public: ...@@ -1914,6 +1930,13 @@ public:
return rtn; return rtn;
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
ConcreteCompilerVariable* converted = var->makeConverted(emitter, STR);
CompilerVariable* rtn = converted->contains(emitter, info, lhs);
converted->decvref(emitter);
return rtn;
}
ConcreteCompilerVariable* nonzero(IREmitter& emitter, const OpInfo& info, VAR* var) override { ConcreteCompilerVariable* nonzero(IREmitter& emitter, const OpInfo& info, VAR* var) override {
return makeBool(var->getValue()->size() != 0); return makeBool(var->getValue()->size() != 0);
} }
...@@ -2033,6 +2056,10 @@ public: ...@@ -2033,6 +2056,10 @@ public:
return rtn; return rtn;
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return makeBool(false);
}
ConcreteCompilerType* getBoxType() override { return BOXED_BOOL; } ConcreteCompilerType* getBoxType() override { return BOXED_BOOL; }
Box* deserializeFromFrame(const FrameVals& vals) override { Box* deserializeFromFrame(const FrameVals& vals) override {
...@@ -2196,6 +2223,41 @@ public: ...@@ -2196,6 +2223,41 @@ public:
return rtn; return rtn;
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
llvm::SmallVector<std::pair<llvm::BasicBlock*, llvm::Value*>, 4> phi_incoming;
llvm::BasicBlock* end = emitter.createBasicBlock();
for (CompilerVariable* e : *var->getValue()) {
CompilerVariable* eq = lhs->binexp(emitter, info, e, AST_TYPE::Eq, Compare);
ConcreteCompilerVariable* eq_nonzero = eq->nonzero(emitter, info);
assert(eq_nonzero->getType() == BOOL);
llvm::Value* raw = i1FromBool(emitter, eq_nonzero);
phi_incoming.push_back(std::make_pair(emitter.currentBasicBlock(), getConstantInt(1, g.i1)));
llvm::BasicBlock* new_bb = emitter.createBasicBlock();
new_bb->moveAfter(emitter.currentBasicBlock());
emitter.getBuilder()->CreateCondBr(raw, end, new_bb);
emitter.setCurrentBasicBlock(new_bb);
}
// TODO This last block is unnecessary:
phi_incoming.push_back(std::make_pair(emitter.currentBasicBlock(), getConstantInt(0, g.i1)));
emitter.getBuilder()->CreateBr(end);
end->moveAfter(emitter.currentBasicBlock());
emitter.setCurrentBasicBlock(end);
auto phi = emitter.getBuilder()->CreatePHI(g.i1, phi_incoming.size());
for (auto p : phi_incoming) {
phi->addIncoming(p.second, p.first);
}
return boolFromI1(emitter, phi);
}
CompilerVariable* callattr(IREmitter& emitter, const OpInfo& info, VAR* var, BoxedString* attr, CallattrFlags flags, CompilerVariable* callattr(IREmitter& emitter, const OpInfo& info, VAR* var, BoxedString* attr, CallattrFlags flags,
const std::vector<CompilerVariable*>& args, const std::vector<CompilerVariable*>& args,
const std::vector<BoxedString*>* keyword_names) override { const std::vector<BoxedString*>* keyword_names) override {
...@@ -2320,6 +2382,10 @@ public: ...@@ -2320,6 +2382,10 @@ public:
return undefVariable(); return undefVariable();
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs) override {
return undefVariable();
}
CompilerVariable* getitem(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var, CompilerVariable* getitem(IREmitter& emitter, const OpInfo& info, ConcreteCompilerVariable* var,
CompilerVariable* slice) override { CompilerVariable* slice) override {
return undefVariable(); return undefVariable();
......
...@@ -152,6 +152,7 @@ public: ...@@ -152,6 +152,7 @@ public:
printf("binexp not defined for %s\n", debugName().c_str()); printf("binexp not defined for %s\n", debugName().c_str());
abort(); abort();
} }
virtual CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, VAR* var, CompilerVariable* lhs);
virtual llvm::Value* makeClassCheck(IREmitter& emitter, VAR* var, BoxedClass* c) { virtual llvm::Value* makeClassCheck(IREmitter& emitter, VAR* var, BoxedClass* c) {
printf("makeClassCheck not defined for %s\n", debugName().c_str()); printf("makeClassCheck not defined for %s\n", debugName().c_str());
abort(); abort();
...@@ -277,6 +278,7 @@ public: ...@@ -277,6 +278,7 @@ public:
virtual CompilerVariable* getPystonIter(IREmitter& emitter, const OpInfo& info) = 0; virtual CompilerVariable* getPystonIter(IREmitter& emitter, const OpInfo& info) = 0;
virtual CompilerVariable* binexp(IREmitter& emitter, const OpInfo& info, CompilerVariable* rhs, virtual CompilerVariable* binexp(IREmitter& emitter, const OpInfo& info, CompilerVariable* rhs,
AST_TYPE::AST_TYPE op_type, BinExpType exp_type) = 0; AST_TYPE::AST_TYPE op_type, BinExpType exp_type) = 0;
virtual CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, CompilerVariable* lhs) = 0;
virtual void serializeToFrame(std::vector<llvm::Value*>& stackmap_args) = 0; virtual void serializeToFrame(std::vector<llvm::Value*>& stackmap_args) = 0;
...@@ -369,6 +371,9 @@ public: ...@@ -369,6 +371,9 @@ public:
BinExpType exp_type) override { BinExpType exp_type) override {
return type->binexp(emitter, info, this, rhs, op_type, exp_type); return type->binexp(emitter, info, this, rhs, op_type, exp_type);
} }
CompilerVariable* contains(IREmitter& emitter, const OpInfo& info, CompilerVariable* lhs) override {
return type->contains(emitter, info, this, lhs);
}
llvm::Value* makeClassCheck(IREmitter& emitter, BoxedClass* cls) override { llvm::Value* makeClassCheck(IREmitter& emitter, BoxedClass* cls) override {
return type->makeClassCheck(emitter, this, cls); return type->makeClassCheck(emitter, this, cls);
...@@ -421,6 +426,15 @@ CompilerVariable* _ValuedCompilerType<V>::getPystonIter(IREmitter& emitter, cons ...@@ -421,6 +426,15 @@ CompilerVariable* _ValuedCompilerType<V>::getPystonIter(IREmitter& emitter, cons
return r; return r;
} }
template <typename V>
CompilerVariable* _ValuedCompilerType<V>::contains(IREmitter& emitter, const OpInfo& info, VAR* var,
CompilerVariable* rhs) {
ConcreteCompilerVariable* converted = makeConverted(emitter, var, getBoxType());
auto r = UNKNOWN->contains(emitter, info, converted, rhs);
converted->decvref(emitter);
return r;
}
template <typename V> template <typename V>
std::vector<CompilerVariable*> _ValuedCompilerType<V>::unpack(IREmitter& emitter, const OpInfo& info, VAR* var, std::vector<CompilerVariable*> _ValuedCompilerType<V>::unpack(IREmitter& emitter, const OpInfo& info, VAR* var,
int num_into) { int num_into) {
......
...@@ -766,6 +766,19 @@ private: ...@@ -766,6 +766,19 @@ private:
assert(left); assert(left);
assert(right); assert(right);
if (type == AST_TYPE::In || type == AST_TYPE::NotIn) {
CompilerVariable* r = right->contains(emitter, getOpInfoForNode(node, unw_info), left);
assert(r->getType() == BOOL);
if (type == AST_TYPE::NotIn) {
ConcreteCompilerVariable* converted = r->makeConverted(emitter, BOOL);
// TODO: would be faster to just do unboxBoolNegated
llvm::Value* raw = i1FromBool(emitter, converted);
raw = emitter.getBuilder()->CreateXor(raw, getConstantInt(1, g.i1));
r = boolFromI1(emitter, raw);
}
return r;
}
return left->binexp(emitter, getOpInfoForNode(node, unw_info), right, type, exp_type); return left->binexp(emitter, getOpInfoForNode(node, unw_info), right, type, exp_type);
} }
......
...@@ -219,11 +219,11 @@ public: ...@@ -219,11 +219,11 @@ public:
if (!lookup_success) { if (!lookup_success) {
llvm::Constant* int_val llvm::Constant* int_val
= llvm::ConstantInt::get(g.i64, reinterpret_cast<uintptr_t>(addr), false); = llvm::ConstantInt::get(g.i64, reinterpret_cast<uintptr_t>(addr), false);
llvm::Constant* ptr_val = llvm::ConstantExpr::getIntToPtr(int_val, g.i8_ptr); llvm::Constant* ptr_val = llvm::ConstantExpr::getIntToPtr(int_val, g.i8);
ii->setArgOperand(i, ptr_val); ii->setArgOperand(i, ptr_val);
continue; continue;
} else { } else {
ii->setArgOperand(i, module->getOrInsertGlobal(name, g.i8_ptr)); ii->setArgOperand(i, module->getOrInsertGlobal(name, g.i8));
} }
} }
} }
......
...@@ -223,10 +223,8 @@ Box* tupleNonzero(BoxedTuple* self) { ...@@ -223,10 +223,8 @@ Box* tupleNonzero(BoxedTuple* self) {
Box* tupleContains(BoxedTuple* self, Box* elt) { Box* tupleContains(BoxedTuple* self, Box* elt) {
int size = self->size(); int size = self->size();
for (int i = 0; i < size; i++) { for (Box* e : *self) {
Box* e = self->elts[i]; int r = PyObject_RichCompareBool(elt, e, Py_EQ);
int r = PyObject_RichCompareBool(e, elt, Py_EQ);
if (r == -1) if (r == -1)
throwCAPIException(); throwCAPIException();
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment