Commit 92117aab authored by Kevin Modzelewski's avatar Kevin Modzelewski

List comprehension support

Should support:
- multiple comprehensions
- multiple if conditions
- nested control flow expressions
- OSR'ing from the list comprehension

Though it tends to hit the OSR bug in the previous commit.

Some extra changes that could have been split out:
- use pointers-to-const instead of references-to-const for attribute-name passing,
  to make it harder to bind to a temporary name that will go away.
- add a 'cls_only' flag to getattr / getattrType to not have to special-case clsattrs
  (or simply get it wrong, in the case of getattrType)
parent 33e65f3d
......@@ -170,6 +170,7 @@ class NameCollectorVisitor : public ASTVisitor {
virtual bool visit_break(AST_Break *node) { return false; }
virtual bool visit_call(AST_Call *node) { return false; }
virtual bool visit_compare(AST_Compare *node) { return false; }
virtual bool visit_comprehension(AST_comprehension *node) { return false; }
//virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue *node) { return false; }
virtual bool visit_dict(AST_Dict *node) { return false; }
......@@ -181,6 +182,7 @@ class NameCollectorVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { return false; }
//virtual bool visit_keyword(AST_keyword *node) { return false; }
virtual bool visit_list(AST_List *node) { return false; }
virtual bool visit_listcomp(AST_ListComp *node) { return false; }
//virtual bool visit_module(AST_Module *node) { return false; }
//virtual bool visit_name(AST_Name *node) { return false; }
virtual bool visit_num(AST_Num *node) { return false; }
......
......@@ -159,7 +159,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_attribute(AST_Attribute *node) {
CompilerType *t = getType(node->value);
assert(node->ctx_type == AST_TYPE::Load);
CompilerType *rtn = t->getattrType(node->attr);
CompilerType *rtn = t->getattrType(&node->attr, false);
//if (speculation != TypeAnalysis::NONE && (node->attr == "x" || node->attr == "y" || node->attr == "z")) {
//rtn = processSpeculation(float_cls, node, rtn);
......@@ -175,7 +175,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_clsattribute(AST_ClsAttribute *node) {
CompilerType *t = getType(node->value);
CompilerType *rtn = t->getattrType(node->attr);
CompilerType *rtn = t->getattrType(&node->attr, true);
if (VERBOSITY() >= 2 && rtn == UNDEF) {
printf("Think %s.%s is undefined, at %d:%d\n", t->debugName().c_str(), node->attr.c_str(), node->lineno, node->col_offset);
print_ast(node);
......@@ -190,7 +190,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the exact behavior
std::string name = getOpName(node->op_type);
CompilerType *attr_type = left->getattrType(name);
CompilerType *attr_type = left->getattrType(&name, true);
std::vector<CompilerType*> arg_types;
arg_types.push_back(right);
......@@ -253,13 +253,21 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
return BOOL;
}
std::string name = getOpName(node->ops[0]);
CompilerType *attr_type = left->getattrType(name);
CompilerType *attr_type = left->getattrType(&name, true);
std::vector<CompilerType*> arg_types;
arg_types.push_back(right);
return attr_type->callType(arg_types);
}
virtual void* visit_dict(AST_Dict *node) {
// Get all the sub-types, even though they're not necessary to
// determine the expression type, so that things like speculations
// can be processed.
for (auto k : node->keys)
getType(k);
for (auto v : node->values)
getType(v);
return DICT;
}
......@@ -268,6 +276,13 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
}
virtual void* visit_list(AST_List *node) {
// Get all the sub-types, even though they're not necessary to
// determine the expression type, so that things like speculations
// can be processed.
for (auto elt : node->elts) {
getType(elt);
}
return LIST;
}
......@@ -312,7 +327,8 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_subscript(AST_Subscript *node) {
CompilerType *val = getType(node->value);
CompilerType *slice = getType(node->slice);
CompilerType *getitem_type = val->getattrType("__getitem__");
static std::string name("__getitem__");
CompilerType *getitem_type = val->getattrType(&name, true);
std::vector<CompilerType*> args;
args.push_back(slice);
return getitem_type->callType(args);
......@@ -331,7 +347,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the exact behavior
std::string name = getOpName(node->op_type);
CompilerType *attr_type = operand->getattrType(name);
CompilerType *attr_type = operand->getattrType(&name, true);
std::vector<CompilerType*> arg_types;
return attr_type->callType(arg_types);
}
......@@ -353,7 +369,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the right behavior
std::string name = getOpName(node->op_type);
name = "__i" + name.substr(2);
CompilerType *attr_type = t->getattrType(name);
CompilerType *attr_type = t->getattrType(&name, true);
std::vector<CompilerType*> arg_types;
arg_types.push_back(v);
......
This diff is collapsed.
......@@ -36,7 +36,7 @@ class CompilerType {
virtual ConcreteCompilerType* getConcreteType() = 0;
virtual ConcreteCompilerType* getBoxType() = 0;
virtual bool canConvertTo(ConcreteCompilerType* other_type) = 0;
virtual CompilerType* getattrType(const std::string &attr) = 0;
virtual CompilerType* getattrType(const std::string *attr, bool cls_only) = 0;
virtual CompilerType* callType(std::vector<CompilerType*> &arg_types) = 0;
virtual BoxedClass* guaranteedClass() = 0;
};
......@@ -80,15 +80,15 @@ class _ValuedCompilerType : public CompilerType {
printf("nonzero not defined for %s\n", debugName().c_str());
abort();
}
virtual CompilerVariable* getattr(IREmitter &emitter, VAR* value, const std::string &attr) {
virtual CompilerVariable* getattr(IREmitter &emitter, VAR* value, const std::string *attr, bool cls_only) {
printf("getattr not defined for %s\n", debugName().c_str());
abort();
}
virtual void setattr(IREmitter &emitter, VAR* value, const std::string &attr, CompilerVariable *v) {
virtual void setattr(IREmitter &emitter, VAR* value, const std::string *attr, CompilerVariable *v) {
printf("setattr not defined for %s\n", debugName().c_str());
abort();
}
virtual CompilerVariable* callattr(IREmitter &emitter, VAR* value, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
virtual CompilerVariable* callattr(IREmitter &emitter, VAR* value, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
printf("callattr not defined for %s\n", debugName().c_str());
abort();
}
......@@ -112,7 +112,7 @@ class _ValuedCompilerType : public CompilerType {
printf("makeClassCheck not defined for %s\n", debugName().c_str());
abort();
}
virtual CompilerType* getattrType(const std::string &attr) {
virtual CompilerType* getattrType(const std::string *attr, bool cls_only) {
printf("getattrType not defined for %s\n", debugName().c_str());
abort();
}
......@@ -205,9 +205,9 @@ class CompilerVariable {
virtual BoxedClass* guaranteedClass() = 0;
virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) = 0;
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string& attr) = 0;
virtual void setattr(IREmitter &emitter, const std::string& attr, CompilerVariable* v) = 0;
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) = 0;
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string *attr, bool cls_only) = 0;
virtual void setattr(IREmitter &emitter, const std::string *attr, CompilerVariable* v) = 0;
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) = 0;
virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) = 0;
virtual void print(IREmitter &emitter) = 0;
virtual ConcreteCompilerVariable* len(IREmitter &emitter) = 0;
......@@ -268,13 +268,13 @@ class ValuedCompilerVariable : public CompilerVariable {
virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) {
return type->nonzero(emitter, this);
}
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string& attr) {
return type->getattr(emitter, this, attr);
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string *attr, bool cls_only) {
return type->getattr(emitter, this, attr, cls_only);
}
virtual void setattr(IREmitter &emitter, const std::string& attr, CompilerVariable *v) {
virtual void setattr(IREmitter &emitter, const std::string *attr, CompilerVariable *v) {
type->setattr(emitter, this, attr, v);
}
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
return type->callattr(emitter, this, attr, clsonly, args);
}
virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) {
......
......@@ -217,7 +217,7 @@ class IRGeneratorImpl : public IRGenerator {
CompilerVariable *value = evalExpr(node->value);
CompilerVariable *rtn = value->getattr(emitter, node->attr);
CompilerVariable *rtn = value->getattr(emitter, &node->attr, false);
value->decvref(emitter);
return rtn;
}
......@@ -226,29 +226,9 @@ class IRGeneratorImpl : public IRGenerator {
assert(state != PARTIAL);
CompilerVariable *value = evalExpr(node->value);
//ASSERT((node->attr == "__iter__" || node->attr == "__hasnext__" || node->attr == "next" || node->attr == "__enter__" || node->attr == "__exit__") && (value->getType() == UNDEF || value->getType() == value->getBoxType()) && "inefficient for anything else, should change", "%s", node->attr.c_str());
ConcreteCompilerVariable *converted = value->makeConverted(emitter, value->getBoxType());
CompilerVariable *rtn = value->getattr(emitter, &node->attr, true);
value->decvref(emitter);
bool do_patchpoint = ENABLE_ICGETATTRS && emitter.getTarget() != IREmitter::INTERPRETER;
llvm::Value *rtn;
if (do_patchpoint) {
PatchpointSetupInfo *pp = patchpoints::createGetattrPatchpoint(emitter.currentFunction());
std::vector<llvm::Value*> llvm_args;
llvm_args.push_back(converted->getValue());
llvm_args.push_back(getStringConstantPtr(node->attr + '\0'));
llvm::Value* uncasted = emitter.createPatchpoint(pp, (void*)pyston::getclsattr, llvm_args);
rtn = emitter.getBuilder()->CreateIntToPtr(uncasted, g.llvm_value_type_ptr);
} else {
rtn = emitter.getBuilder()->CreateCall2(g.funcs.getclsattr,
converted->getValue(), getStringConstantPtr(node->attr + '\0'));
}
converted->decvref(emitter);
return new ConcreteCompilerVariable(UNKNOWN, rtn, true);
return rtn;
}
enum BinExpType {
......@@ -538,7 +518,7 @@ class IRGeneratorImpl : public IRGenerator {
CompilerVariable *rtn;
if (is_callattr) {
rtn = func->callattr(emitter, *attr, callattr_clsonly, args);
rtn = func->callattr(emitter, attr, callattr_clsonly, args);
} else {
rtn = func->call(emitter, args);
}
......@@ -560,7 +540,8 @@ class IRGeneratorImpl : public IRGenerator {
llvm::Value* v = emitter.getBuilder()->CreateCall(g.funcs.createDict);
ConcreteCompilerVariable *rtn = new ConcreteCompilerVariable(DICT, v, true);
if (node->keys.size()) {
CompilerVariable *setitem = rtn->getattr(emitter, "__setitem__");
static const std::string setitem_str("__setitem__");
CompilerVariable *setitem = rtn->getattr(emitter, &setitem_str, true);
for (int i = 0; i < node->keys.size(); i++) {
CompilerVariable *key = evalExpr(node->keys[i]);
CompilerVariable *value = evalExpr(node->values[i]);
......@@ -654,7 +635,7 @@ class IRGeneratorImpl : public IRGenerator {
// Method 2 [testing-only]: (ab)uses existing getattr patchpoints and just calls module.getattr()
// This option exists for performance testing because method 1 does not currently use patchpoints.
ConcreteCompilerVariable *mod = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false);
CompilerVariable *attr = mod->getattr(emitter, node->id);
CompilerVariable *attr = mod->getattr(emitter, &node->id, false);
mod->decvref(emitter);
return attr;
}
......@@ -822,9 +803,6 @@ class IRGeneratorImpl : public IRGenerator {
case AST_TYPE::List:
rtn = evalList(static_cast<AST_List*>(node));
break;
//case AST_TYPE::ListComp:
//rtn = evalListComp(static_cast<AST_ListComp*>(node));
//break;
case AST_TYPE::Name:
rtn = evalName(static_cast<AST_Name*>(node));
break;
......@@ -1012,7 +990,7 @@ class IRGeneratorImpl : public IRGenerator {
if (irstate->getScopeInfo()->refersToGlobal(name)) {
// TODO do something special here so that it knows to only emit a monomorphic inline cache?
ConcreteCompilerVariable* module = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false);
module->setattr(emitter, name, val);
module->setattr(emitter, &name, val);
module->decvref(emitter);
} else {
CompilerVariable* &prev = symbol_table[name];
......@@ -1027,7 +1005,7 @@ class IRGeneratorImpl : public IRGenerator {
void _doSetattr(AST_Attribute* target, CompilerVariable* val) {
assert(state != PARTIAL);
CompilerVariable *t = evalExpr(target->value);
t->setattr(emitter, target->attr, val);
t->setattr(emitter, &target->attr, val);
t->decvref(emitter);
}
......@@ -1148,7 +1126,7 @@ class IRGeneratorImpl : public IRGenerator {
AST_FunctionDef *fdef = static_cast<AST_FunctionDef*>(node->body[i]);
CLFunction *cl = this->_wrapFunction(fdef);
CompilerVariable *func = makeFunction(emitter, cl);
cls->setattr(emitter, fdef->name, func);
cls->setattr(emitter, &fdef->name, func);
func->decvref(emitter);
} else {
RELEASE_ASSERT(node->body[i]->type == AST_TYPE::Pass, "%d", type);
......@@ -1369,7 +1347,7 @@ class IRGeneratorImpl : public IRGenerator {
for (SortedSymbolTable::iterator it = sorted_symbol_table.begin(), end = sorted_symbol_table.end(); it != end; ++it, ++i) {
// I don't think this can fail, but if it can we should filter out dead symbols before
// passing them on:
assert(irstate->getSourceInfo()->liveness->isLiveAtEnd(it->first, myblock));
ASSERT(irstate->getSourceInfo()->liveness->isLiveAtEnd(it->first, myblock), "%d %s", myblock->idx, it->first.c_str());
// This line can never get hit right now since we unnecessarily force every variable to be concrete
// for a loop, since we generate all potential phis:
......
......@@ -127,12 +127,15 @@ static void readExprVector(std::vector<AST_expr*> &vec, BufferedReader *reader)
}
}
static void readMiscVector(std::vector<AST*> &vec, BufferedReader *reader) {
template <class T>
static void readMiscVector(std::vector<T*> &vec, BufferedReader *reader) {
int num_elts = reader->readShort();
if (VERBOSITY("parsing") >= 2)
printf("%d elts to read\n", num_elts);
for (int i = 0; i < num_elts; i++) {
vec.push_back(readASTMisc(reader));
AST* read = readASTMisc(reader);
assert(read->type == T::TYPE);
vec.push_back(static_cast<T*>(read));
}
}
......@@ -240,12 +243,7 @@ AST_Call* read_call(BufferedReader *reader) {
rtn->col_offset = readColOffset(reader);
rtn->func = readASTExpr(reader);
std::vector<AST*> keyword_vec;
readMiscVector(keyword_vec, reader);
for (int i = 0; i < keyword_vec.size(); i++) {
assert(keyword_vec[i]->type == AST_TYPE::keyword);
rtn->keywords.push_back(static_cast<AST_keyword*>(keyword_vec[i]));
}
readMiscVector(rtn->keywords, reader);
rtn->kwargs = readASTExpr(reader);
rtn->lineno = reader->readULL();
......@@ -267,26 +265,18 @@ AST_expr* read_compare(BufferedReader *reader) {
rtn->ops.push_back((AST_TYPE::AST_TYPE)reader->readByte());
}
/*{
assert(rtn->ops.size() == 1);
AST_Attribute *func = new AST_Attribute();
func->type = AST_TYPE::Attribute;
func->attr = getOpName(rtn->ops[0]);
func->col_offset = rtn->col_offset;
func->ctx_type = AST_TYPE::Load;
func->lineno = rtn->lineno;
func->value = rtn->left;
return rtn;
}
AST_comprehension* read_comprehension(BufferedReader *reader) {
AST_comprehension *rtn = new AST_comprehension();
AST_Call *call = new AST_Call();
call->type = AST_TYPE::Call;
call->args.push_back(rtn->comparators[0]);
call->col_offset = rtn->col_offset;
call->func = func;
call->kwargs = NULL;
call->lineno = rtn->lineno;
call->starargs = NULL;
return call;
}*/
readExprVector(rtn->ifs, reader);
rtn->iter = readASTExpr(reader);
rtn->target = readASTExpr(reader);
rtn->col_offset = -1;
rtn->lineno = -1;
return rtn;
}
......@@ -436,6 +426,16 @@ AST_List* read_list(BufferedReader *reader) {
return rtn;
}
AST_ListComp* read_listcomp(BufferedReader *reader) {
AST_ListComp *rtn = new AST_ListComp();
rtn->col_offset = readColOffset(reader);
rtn->elt = readASTExpr(reader);
readMiscVector(rtn->generators, reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_Module* read_module(BufferedReader *reader) {
if (VERBOSITY("parsing") >= 2)
printf("reading module\n");
......@@ -612,6 +612,8 @@ AST_expr* readASTExpr(BufferedReader *reader) {
return read_index(reader);
case AST_TYPE::List:
return read_list(reader);
case AST_TYPE::ListComp:
return read_listcomp(reader);
case AST_TYPE::Name:
return read_name(reader);
case AST_TYPE::Num:
......@@ -698,6 +700,8 @@ AST* readASTMisc(BufferedReader *reader) {
return read_alias(reader);
case AST_TYPE::arguments:
return read_arguments(reader);
case AST_TYPE::comprehension:
return read_comprehension(reader);
case AST_TYPE::keyword:
return read_keyword(reader);
case AST_TYPE::Module:
......
......@@ -287,6 +287,17 @@ void* AST_Compare::accept_expr(ExprVisitor *v) {
return v->visit_compare(this);
}
void AST_comprehension::accept(ASTVisitor *v) {
bool skip = v->visit_comprehension(this);
if (skip) return;
target->accept(v);
iter->accept(v);
for (auto if_ : ifs) {
if_->accept(v);
}
}
void AST_ClassDef::accept(ASTVisitor *v) {
bool skip = v->visit_classdef(this);
if (skip) return;
......@@ -436,6 +447,21 @@ void* AST_List::accept_expr(ExprVisitor *v) {
return v->visit_list(this);
}
void AST_ListComp::accept(ASTVisitor *v) {
bool skip = v->visit_listcomp(this);
if (skip) return;
for (auto c : generators) {
c->accept(v);
}
elt->accept(v);
}
void* AST_ListComp::accept_expr(ExprVisitor *v) {
return v->visit_listcomp(this);
}
void AST_Module::accept(ASTVisitor *v) {
bool skip = v->visit_module(this);
if (skip) return;
......@@ -781,6 +807,20 @@ bool PrintVisitor::visit_compare(AST_Compare *node) {
return true;
}
bool PrintVisitor::visit_comprehension(AST_comprehension *node) {
printf("for ");
node->target->accept(this);
printf(" in ");
node->iter->accept(this);
for (AST_expr *i : node->ifs) {
printf(" if ");
i->accept(this);
}
return true;
}
bool PrintVisitor::visit_classdef(AST_ClassDef *node) {
for (int i = 0, n = node->decorator_list.size(); i < n; i++) {
printf("@");
......@@ -928,6 +968,17 @@ bool PrintVisitor::visit_list(AST_List *node) {
return true;
}
bool PrintVisitor::visit_listcomp(AST_ListComp *node) {
printf("[");
node->elt->accept(this);
for (auto c : node->generators) {
printf(" ");
c->accept(this);
}
printf("]");
return true;
}
bool PrintVisitor::visit_keyword(AST_keyword *node) {
printf("%s=", node->arg.c_str());
node->value->accept(this);
......@@ -1135,6 +1186,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_call(AST_Call *node) { output->push_back(node); return false; }
virtual bool visit_classdef(AST_ClassDef *node) { output->push_back(node); return !expand_scopes; }
virtual bool visit_compare(AST_Compare *node) { output->push_back(node); return false; }
virtual bool visit_comprehension(AST_comprehension *node) { output->push_back(node); return false; }
virtual bool visit_continue(AST_Continue *node) { output->push_back(node); return false; }
virtual bool visit_dict(AST_Dict *node) { output->push_back(node); return false; }
virtual bool visit_expr(AST_Expr *node) { output->push_back(node); return false; }
......@@ -1145,6 +1197,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { output->push_back(node); return false; }
virtual bool visit_keyword(AST_keyword *node) { output->push_back(node); return false; }
virtual bool visit_list(AST_List *node) { output->push_back(node); return false; }
virtual bool visit_listcomp(AST_ListComp *node) { output->push_back(node); return false; }
virtual bool visit_module(AST_Module *node) { output->push_back(node); return !expand_scopes; }
virtual bool visit_name(AST_Name *node) { output->push_back(node); return false; }
virtual bool visit_num(AST_Num *node) { output->push_back(node); return false; }
......@@ -1174,4 +1227,12 @@ std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes) {
return rtn;
}
std::vector<AST*>* flatten(AST_expr* root, bool expand_scopes) {
std::vector<AST*> *rtn = new std::vector<AST*>();
FlattenVisitor visitor(rtn, expand_scopes);
root->accept(&visitor);
return rtn;
}
}
......@@ -249,6 +249,8 @@ class AST_Call : public AST_expr {
virtual void* accept_expr(ExprVisitor *v);
AST_Call() : AST_expr(AST_TYPE::Call) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Call;
};
class AST_Compare : public AST_expr {
......@@ -263,6 +265,19 @@ class AST_Compare : public AST_expr {
AST_Compare() : AST_expr(AST_TYPE::Compare) {}
};
class AST_comprehension : public AST {
public:
AST_expr* target;
AST_expr* iter;
std::vector<AST_expr*> ifs;
virtual void accept(ASTVisitor *v);
AST_comprehension() : AST(AST_TYPE::comprehension) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::comprehension;
};
class AST_ClassDef : public AST_stmt {
public:
virtual void accept(ASTVisitor *v);
......@@ -391,6 +406,8 @@ class AST_keyword : public AST {
virtual void accept(ASTVisitor *v);
AST_keyword() : AST(AST_TYPE::keyword) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
};
class AST_List : public AST_expr {
......@@ -406,6 +423,19 @@ class AST_List : public AST_expr {
AST_List() : AST_expr(AST_TYPE::List) {}
};
class AST_ListComp : public AST_expr {
public:
const static AST_TYPE::AST_TYPE TYPE = AST_TYPE::ListComp;
std::vector<AST_comprehension*> generators;
AST_expr* elt;
virtual void accept(ASTVisitor *v);
virtual void* accept_expr(ExprVisitor *v);
AST_ListComp() : AST_expr(AST_TYPE::ListComp) {}
};
class AST_Module : public AST {
public:
// no lineno, col_offset attributes
......@@ -607,6 +637,7 @@ class ASTVisitor {
virtual bool visit_call(AST_Call *node) { assert(0); abort(); }
virtual bool visit_clsattribute(AST_ClsAttribute *node) { assert(0); abort(); }
virtual bool visit_compare(AST_Compare *node) { assert(0); abort(); }
virtual bool visit_comprehension(AST_comprehension *node) { assert(0); abort(); }
virtual bool visit_classdef(AST_ClassDef *node) { assert(0); abort(); }
virtual bool visit_continue(AST_Continue *node) { assert(0); abort(); }
virtual bool visit_dict(AST_Dict *node) { assert(0); abort(); }
......@@ -620,6 +651,7 @@ class ASTVisitor {
virtual bool visit_index(AST_Index *node) { assert(0); abort(); }
virtual bool visit_keyword(AST_keyword *node) { assert(0); abort(); }
virtual bool visit_list(AST_List *node) { assert(0); abort(); }
virtual bool visit_listcomp(AST_ListComp *node) { assert(0); abort(); }
virtual bool visit_module(AST_Module *node) { assert(0); abort(); }
virtual bool visit_name(AST_Name *node) { assert(0); abort(); }
virtual bool visit_num(AST_Num *node) { assert(0); abort(); }
......@@ -654,6 +686,7 @@ class NoopASTVisitor : public ASTVisitor {
virtual bool visit_call(AST_Call *node) { return false; }
virtual bool visit_clsattribute(AST_ClsAttribute *node) { return false; }
virtual bool visit_compare(AST_Compare *node) { return false; }
virtual bool visit_comprehension(AST_comprehension *node) { return false; }
virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue *node) { return false; }
virtual bool visit_dict(AST_Dict *node) { return false; }
......@@ -667,6 +700,7 @@ class NoopASTVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { return false; }
virtual bool visit_keyword(AST_keyword *node) { return false; }
virtual bool visit_list(AST_List *node) { return false; }
virtual bool visit_listcomp(AST_ListComp *node) { return false; }
virtual bool visit_module(AST_Module *node) { return false; }
virtual bool visit_name(AST_Name *node) { return false; }
virtual bool visit_num(AST_Num *node) { return false; }
......@@ -700,6 +734,7 @@ class ExprVisitor {
virtual void* visit_ifexp(AST_IfExp *node) { assert(0); abort(); }
virtual void* visit_index(AST_Index *node) { assert(0); abort(); }
virtual void* visit_list(AST_List *node) { assert(0); abort(); }
virtual void* visit_listcomp(AST_ListComp *node) { assert(0); abort(); }
virtual void* visit_name(AST_Name *node) { assert(0); abort(); }
virtual void* visit_num(AST_Num *node) { assert(0); abort(); }
virtual void* visit_slice(AST_Slice *node) { assert(0); abort(); }
......@@ -754,6 +789,7 @@ class PrintVisitor : public ASTVisitor {
virtual bool visit_break(AST_Break *node);
virtual bool visit_call(AST_Call *node);
virtual bool visit_compare(AST_Compare *node);
virtual bool visit_comprehension(AST_comprehension *node);
virtual bool visit_classdef(AST_ClassDef *node);
virtual bool visit_clsattribute(AST_ClsAttribute *node);
virtual bool visit_continue(AST_Continue *node);
......@@ -768,6 +804,7 @@ class PrintVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node);
virtual bool visit_keyword(AST_keyword *node);
virtual bool visit_list(AST_List *node);
virtual bool visit_listcomp(AST_ListComp *node);
virtual bool visit_module(AST_Module *node);
virtual bool visit_name(AST_Name *node);
virtual bool visit_num(AST_Num *node);
......@@ -790,9 +827,10 @@ class PrintVisitor : public ASTVisitor {
// This is useful for analyses that care more about the constituent nodes than the
// exact tree structure; ex, finding all "global" directives.
std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes);
std::vector<AST*>* flatten(AST_expr *root, bool expand_scopes);
// Similar to the flatten() function, but filters for a specific type of ast nodes:
template <class T>
std::vector<T*>* findNodes(std::vector<AST_stmt*> &roots, bool expand_scopes) {
template <class T, class R>
std::vector<T*>* findNodes(const R &roots, bool expand_scopes) {
std::vector<T*> *rtn = new std::vector<T*>();
std::vector<AST*> *flattened = flatten(roots, expand_scopes);
for (int i = 0; i < flattened->size(); i++) {
......
......@@ -153,6 +153,17 @@ class CFGVisitor : public ASTVisitor {
return call;
}
AST_Call* makeCall(AST_expr* func, AST_expr* arg0) {
AST_Call *call = new AST_Call();
call->args.push_back(arg0);
call->starargs = NULL;
call->kwargs = NULL;
call->func = func;
call->col_offset = func->col_offset;
call->lineno = func->lineno;
return call;
}
AST_Name* makeName(const std::string &id, AST_TYPE::AST_TYPE ctx_type, int lineno=-1, int col_offset=-1) {
AST_Name *name = new AST_Name();
name->id = id;
......@@ -193,6 +204,12 @@ class CFGVisitor : public ASTVisitor {
return std::string(buf);
}
std::string nodeName(AST_expr* node, const std::string &suffix, int idx) {
char buf[50];
snprintf(buf, 50, "!%p_%s_%d", node, suffix.c_str(), idx);
return std::string(buf);
}
AST_expr* remapAttribute(AST_Attribute* node) {
AST_Attribute *rtn = new AST_Attribute();
......@@ -386,6 +403,135 @@ class CFGVisitor : public ASTVisitor {
return rtn;
}
AST_expr* remapListComp(AST_ListComp* node) {
std::string rtn_name = nodeName(node);
push_back(makeAssign(rtn_name, new AST_List()));
std::vector<CFGBlock*> exit_blocks;
// Where the current level should jump to after finishing its iteration.
// For the outermost comprehension, this is NULL, and it doesn't jump anywhere;
// for the inner comprehensions, they should jump to the next-outer comprehension
// when they are done iterating.
CFGBlock *finished_block = NULL;
for (int i = 0, n = node->generators.size(); i < n; i++) {
AST_comprehension *c = node->generators[i];
bool is_innermost = (i == n-1);
AST_expr *remapped_iter = remapExpr(c->iter);
AST_expr *iter_attr = makeLoadAttribute(remapped_iter, "__iter__", true);
AST_expr *iter_call = makeCall(iter_attr);
std::string iter_name = nodeName(node, "iter", i);
AST_stmt *iter_assign = makeAssign(iter_name, iter_call);
push_back(iter_assign);
// TODO bad to save these like this?
AST_expr *hasnext_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "__hasnext__", true);
AST_expr *next_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "next", true);
AST_Jump *j;
CFGBlock *test_block = cfg->addBlock();
test_block->info = "listcomp_test";
//printf("Test block for comp %d is %d\n", i, test_block->idx);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block);
push_back(j);
curblock = test_block;
AST_expr *test_call = makeCall(hasnext_attr);
CFGBlock* body_block = cfg->addBlock();
body_block->info = "listcomp_body";
CFGBlock* exit_block = cfg->addDeferredBlock();
exit_block->info = "listcomp_exit";
exit_blocks.push_back(exit_block);
//printf("Body block for comp %d is %d\n", i, body_block->idx);
AST_Branch *br = new AST_Branch();
br->col_offset = node->col_offset;
br->lineno = node->lineno;
br->test = test_call;
br->iftrue = body_block;
br->iffalse = exit_block;
curblock->connectTo(body_block);
curblock->connectTo(exit_block);
push_back(br);
curblock = body_block;
push_back(makeAssign(c->target, makeCall(next_attr)));
for (AST_expr *if_condition : c->ifs) {
AST_expr *remapped = remapExpr(if_condition);
AST_Branch *br = new AST_Branch();
br->test = remapped;
push_back(br);
// Put this below the entire body?
CFGBlock *body_tramp = cfg->addBlock();
body_tramp->info = "listcomp_if_trampoline";
//printf("body_tramp for %d is %d\n", i, body_tramp->idx);
CFGBlock *body_continue = cfg->addBlock();
body_continue->info = "listcomp_if_continue";
//printf("body_continue for %d is %d\n", i, body_continue->idx);
br->iffalse = body_tramp;
curblock->connectTo(body_tramp);
br->iftrue = body_continue;
curblock->connectTo(body_continue);
curblock = body_tramp;
j = new AST_Jump();
j->target = test_block;
push_back(j);
curblock->connectTo(test_block, true);
curblock = body_continue;
}
CFGBlock *body_end = curblock;
assert((finished_block != NULL) == (i != 0));
if (finished_block) {
curblock = exit_block;
j = new AST_Jump();
j->target = finished_block;
curblock->connectTo(finished_block, true);
push_back(j);
}
finished_block = test_block;
curblock = body_end;
if (is_innermost) {
AST_expr *elt = remapExpr(node->elt);
push_back(makeExpr(makeCall(makeLoadAttribute(makeName(rtn_name, AST_TYPE::Load), "append", true), elt)));
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block, true);
push_back(j);
assert(exit_blocks.size());
curblock = exit_blocks[0];
} else {
// continue onto the next comprehension and add to this body
}
}
// Wait until the end to place the end blocks, so that
// we get a nice nesting structure, that looks similar to what
// you'd get with a nested for loop:
for (int i = exit_blocks.size() - 1; i >= 0; i--) {
cfg->placeBlock(exit_blocks[i]);
printf("Exit block for comp %d is %d\n", i, exit_blocks[i]->idx);
}
return makeName(rtn_name, AST_TYPE::Load);
};
AST_expr* remapSlice(AST_Slice* node) {
AST_Slice *rtn = new AST_Slice();
rtn->lineno = node->lineno;
......@@ -464,6 +610,9 @@ class CFGVisitor : public ASTVisitor {
case AST_TYPE::List:
rtn = remapList(static_cast<AST_List*>(node));
break;
case AST_TYPE::ListComp:
rtn = remapListComp(static_cast<AST_ListComp*>(node));
break;
case AST_TYPE::Name:
return node;
case AST_TYPE::Num:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment