Commit 92117aab authored by Kevin Modzelewski's avatar Kevin Modzelewski

List comprehension support

Should support:
- multiple comprehensions
- multiple if conditions
- nested control flow expressions
- OSR'ing from the list comprehension

Though it tends to hit the OSR bug in the previous commit.

Some extra changes that could have been split out:
- use pointers-to-const instead of references-to-const for attribute-name passing,
  to make it harder to bind to a temporary name that will go away.
- add a 'cls_only' flag to getattr / getattrType to not have to special-case clsattrs
  (or simply get it wrong, in the case of getattrType)
parent 33e65f3d
...@@ -170,6 +170,7 @@ class NameCollectorVisitor : public ASTVisitor { ...@@ -170,6 +170,7 @@ class NameCollectorVisitor : public ASTVisitor {
virtual bool visit_break(AST_Break *node) { return false; } virtual bool visit_break(AST_Break *node) { return false; }
virtual bool visit_call(AST_Call *node) { return false; } virtual bool visit_call(AST_Call *node) { return false; }
virtual bool visit_compare(AST_Compare *node) { return false; } virtual bool visit_compare(AST_Compare *node) { return false; }
virtual bool visit_comprehension(AST_comprehension *node) { return false; }
//virtual bool visit_classdef(AST_ClassDef *node) { return false; } //virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue *node) { return false; } virtual bool visit_continue(AST_Continue *node) { return false; }
virtual bool visit_dict(AST_Dict *node) { return false; } virtual bool visit_dict(AST_Dict *node) { return false; }
...@@ -181,6 +182,7 @@ class NameCollectorVisitor : public ASTVisitor { ...@@ -181,6 +182,7 @@ class NameCollectorVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { return false; } virtual bool visit_index(AST_Index *node) { return false; }
//virtual bool visit_keyword(AST_keyword *node) { return false; } //virtual bool visit_keyword(AST_keyword *node) { return false; }
virtual bool visit_list(AST_List *node) { return false; } virtual bool visit_list(AST_List *node) { return false; }
virtual bool visit_listcomp(AST_ListComp *node) { return false; }
//virtual bool visit_module(AST_Module *node) { return false; } //virtual bool visit_module(AST_Module *node) { return false; }
//virtual bool visit_name(AST_Name *node) { return false; } //virtual bool visit_name(AST_Name *node) { return false; }
virtual bool visit_num(AST_Num *node) { return false; } virtual bool visit_num(AST_Num *node) { return false; }
......
...@@ -159,7 +159,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -159,7 +159,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_attribute(AST_Attribute *node) { virtual void* visit_attribute(AST_Attribute *node) {
CompilerType *t = getType(node->value); CompilerType *t = getType(node->value);
assert(node->ctx_type == AST_TYPE::Load); assert(node->ctx_type == AST_TYPE::Load);
CompilerType *rtn = t->getattrType(node->attr); CompilerType *rtn = t->getattrType(&node->attr, false);
//if (speculation != TypeAnalysis::NONE && (node->attr == "x" || node->attr == "y" || node->attr == "z")) { //if (speculation != TypeAnalysis::NONE && (node->attr == "x" || node->attr == "y" || node->attr == "z")) {
//rtn = processSpeculation(float_cls, node, rtn); //rtn = processSpeculation(float_cls, node, rtn);
...@@ -175,7 +175,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -175,7 +175,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_clsattribute(AST_ClsAttribute *node) { virtual void* visit_clsattribute(AST_ClsAttribute *node) {
CompilerType *t = getType(node->value); CompilerType *t = getType(node->value);
CompilerType *rtn = t->getattrType(node->attr); CompilerType *rtn = t->getattrType(&node->attr, true);
if (VERBOSITY() >= 2 && rtn == UNDEF) { if (VERBOSITY() >= 2 && rtn == UNDEF) {
printf("Think %s.%s is undefined, at %d:%d\n", t->debugName().c_str(), node->attr.c_str(), node->lineno, node->col_offset); printf("Think %s.%s is undefined, at %d:%d\n", t->debugName().c_str(), node->attr.c_str(), node->lineno, node->col_offset);
print_ast(node); print_ast(node);
...@@ -190,7 +190,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -190,7 +190,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the exact behavior // TODO this isn't the exact behavior
std::string name = getOpName(node->op_type); std::string name = getOpName(node->op_type);
CompilerType *attr_type = left->getattrType(name); CompilerType *attr_type = left->getattrType(&name, true);
std::vector<CompilerType*> arg_types; std::vector<CompilerType*> arg_types;
arg_types.push_back(right); arg_types.push_back(right);
...@@ -253,13 +253,21 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -253,13 +253,21 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
return BOOL; return BOOL;
} }
std::string name = getOpName(node->ops[0]); std::string name = getOpName(node->ops[0]);
CompilerType *attr_type = left->getattrType(name); CompilerType *attr_type = left->getattrType(&name, true);
std::vector<CompilerType*> arg_types; std::vector<CompilerType*> arg_types;
arg_types.push_back(right); arg_types.push_back(right);
return attr_type->callType(arg_types); return attr_type->callType(arg_types);
} }
virtual void* visit_dict(AST_Dict *node) { virtual void* visit_dict(AST_Dict *node) {
// Get all the sub-types, even though they're not necessary to
// determine the expression type, so that things like speculations
// can be processed.
for (auto k : node->keys)
getType(k);
for (auto v : node->values)
getType(v);
return DICT; return DICT;
} }
...@@ -268,6 +276,13 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -268,6 +276,13 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
} }
virtual void* visit_list(AST_List *node) { virtual void* visit_list(AST_List *node) {
// Get all the sub-types, even though they're not necessary to
// determine the expression type, so that things like speculations
// can be processed.
for (auto elt : node->elts) {
getType(elt);
}
return LIST; return LIST;
} }
...@@ -312,7 +327,8 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -312,7 +327,8 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void* visit_subscript(AST_Subscript *node) { virtual void* visit_subscript(AST_Subscript *node) {
CompilerType *val = getType(node->value); CompilerType *val = getType(node->value);
CompilerType *slice = getType(node->slice); CompilerType *slice = getType(node->slice);
CompilerType *getitem_type = val->getattrType("__getitem__"); static std::string name("__getitem__");
CompilerType *getitem_type = val->getattrType(&name, true);
std::vector<CompilerType*> args; std::vector<CompilerType*> args;
args.push_back(slice); args.push_back(slice);
return getitem_type->callType(args); return getitem_type->callType(args);
...@@ -331,7 +347,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -331,7 +347,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the exact behavior // TODO this isn't the exact behavior
std::string name = getOpName(node->op_type); std::string name = getOpName(node->op_type);
CompilerType *attr_type = operand->getattrType(name); CompilerType *attr_type = operand->getattrType(&name, true);
std::vector<CompilerType*> arg_types; std::vector<CompilerType*> arg_types;
return attr_type->callType(arg_types); return attr_type->callType(arg_types);
} }
...@@ -353,7 +369,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor { ...@@ -353,7 +369,7 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
// TODO this isn't the right behavior // TODO this isn't the right behavior
std::string name = getOpName(node->op_type); std::string name = getOpName(node->op_type);
name = "__i" + name.substr(2); name = "__i" + name.substr(2);
CompilerType *attr_type = t->getattrType(name); CompilerType *attr_type = t->getattrType(&name, true);
std::vector<CompilerType*> arg_types; std::vector<CompilerType*> arg_types;
arg_types.push_back(v); arg_types.push_back(v);
......
This diff is collapsed.
...@@ -36,7 +36,7 @@ class CompilerType { ...@@ -36,7 +36,7 @@ class CompilerType {
virtual ConcreteCompilerType* getConcreteType() = 0; virtual ConcreteCompilerType* getConcreteType() = 0;
virtual ConcreteCompilerType* getBoxType() = 0; virtual ConcreteCompilerType* getBoxType() = 0;
virtual bool canConvertTo(ConcreteCompilerType* other_type) = 0; virtual bool canConvertTo(ConcreteCompilerType* other_type) = 0;
virtual CompilerType* getattrType(const std::string &attr) = 0; virtual CompilerType* getattrType(const std::string *attr, bool cls_only) = 0;
virtual CompilerType* callType(std::vector<CompilerType*> &arg_types) = 0; virtual CompilerType* callType(std::vector<CompilerType*> &arg_types) = 0;
virtual BoxedClass* guaranteedClass() = 0; virtual BoxedClass* guaranteedClass() = 0;
}; };
...@@ -80,15 +80,15 @@ class _ValuedCompilerType : public CompilerType { ...@@ -80,15 +80,15 @@ class _ValuedCompilerType : public CompilerType {
printf("nonzero not defined for %s\n", debugName().c_str()); printf("nonzero not defined for %s\n", debugName().c_str());
abort(); abort();
} }
virtual CompilerVariable* getattr(IREmitter &emitter, VAR* value, const std::string &attr) { virtual CompilerVariable* getattr(IREmitter &emitter, VAR* value, const std::string *attr, bool cls_only) {
printf("getattr not defined for %s\n", debugName().c_str()); printf("getattr not defined for %s\n", debugName().c_str());
abort(); abort();
} }
virtual void setattr(IREmitter &emitter, VAR* value, const std::string &attr, CompilerVariable *v) { virtual void setattr(IREmitter &emitter, VAR* value, const std::string *attr, CompilerVariable *v) {
printf("setattr not defined for %s\n", debugName().c_str()); printf("setattr not defined for %s\n", debugName().c_str());
abort(); abort();
} }
virtual CompilerVariable* callattr(IREmitter &emitter, VAR* value, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) { virtual CompilerVariable* callattr(IREmitter &emitter, VAR* value, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
printf("callattr not defined for %s\n", debugName().c_str()); printf("callattr not defined for %s\n", debugName().c_str());
abort(); abort();
} }
...@@ -112,7 +112,7 @@ class _ValuedCompilerType : public CompilerType { ...@@ -112,7 +112,7 @@ class _ValuedCompilerType : public CompilerType {
printf("makeClassCheck not defined for %s\n", debugName().c_str()); printf("makeClassCheck not defined for %s\n", debugName().c_str());
abort(); abort();
} }
virtual CompilerType* getattrType(const std::string &attr) { virtual CompilerType* getattrType(const std::string *attr, bool cls_only) {
printf("getattrType not defined for %s\n", debugName().c_str()); printf("getattrType not defined for %s\n", debugName().c_str());
abort(); abort();
} }
...@@ -205,9 +205,9 @@ class CompilerVariable { ...@@ -205,9 +205,9 @@ class CompilerVariable {
virtual BoxedClass* guaranteedClass() = 0; virtual BoxedClass* guaranteedClass() = 0;
virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) = 0; virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) = 0;
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string& attr) = 0; virtual CompilerVariable* getattr(IREmitter &emitter, const std::string *attr, bool cls_only) = 0;
virtual void setattr(IREmitter &emitter, const std::string& attr, CompilerVariable* v) = 0; virtual void setattr(IREmitter &emitter, const std::string *attr, CompilerVariable* v) = 0;
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) = 0; virtual CompilerVariable* callattr(IREmitter &emitter, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) = 0;
virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) = 0; virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) = 0;
virtual void print(IREmitter &emitter) = 0; virtual void print(IREmitter &emitter) = 0;
virtual ConcreteCompilerVariable* len(IREmitter &emitter) = 0; virtual ConcreteCompilerVariable* len(IREmitter &emitter) = 0;
...@@ -268,13 +268,13 @@ class ValuedCompilerVariable : public CompilerVariable { ...@@ -268,13 +268,13 @@ class ValuedCompilerVariable : public CompilerVariable {
virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) { virtual ConcreteCompilerVariable* nonzero(IREmitter &emitter) {
return type->nonzero(emitter, this); return type->nonzero(emitter, this);
} }
virtual CompilerVariable* getattr(IREmitter &emitter, const std::string& attr) { virtual CompilerVariable* getattr(IREmitter &emitter, const std::string *attr, bool cls_only) {
return type->getattr(emitter, this, attr); return type->getattr(emitter, this, attr, cls_only);
} }
virtual void setattr(IREmitter &emitter, const std::string& attr, CompilerVariable *v) { virtual void setattr(IREmitter &emitter, const std::string *attr, CompilerVariable *v) {
type->setattr(emitter, this, attr, v); type->setattr(emitter, this, attr, v);
} }
virtual CompilerVariable* callattr(IREmitter &emitter, const std::string &attr, bool clsonly, const std::vector<CompilerVariable*>& args) { virtual CompilerVariable* callattr(IREmitter &emitter, const std::string *attr, bool clsonly, const std::vector<CompilerVariable*>& args) {
return type->callattr(emitter, this, attr, clsonly, args); return type->callattr(emitter, this, attr, clsonly, args);
} }
virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) { virtual CompilerVariable* call(IREmitter &emitter, const std::vector<CompilerVariable*>& args) {
......
...@@ -217,7 +217,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -217,7 +217,7 @@ class IRGeneratorImpl : public IRGenerator {
CompilerVariable *value = evalExpr(node->value); CompilerVariable *value = evalExpr(node->value);
CompilerVariable *rtn = value->getattr(emitter, node->attr); CompilerVariable *rtn = value->getattr(emitter, &node->attr, false);
value->decvref(emitter); value->decvref(emitter);
return rtn; return rtn;
} }
...@@ -226,29 +226,9 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -226,29 +226,9 @@ class IRGeneratorImpl : public IRGenerator {
assert(state != PARTIAL); assert(state != PARTIAL);
CompilerVariable *value = evalExpr(node->value); CompilerVariable *value = evalExpr(node->value);
CompilerVariable *rtn = value->getattr(emitter, &node->attr, true);
//ASSERT((node->attr == "__iter__" || node->attr == "__hasnext__" || node->attr == "next" || node->attr == "__enter__" || node->attr == "__exit__") && (value->getType() == UNDEF || value->getType() == value->getBoxType()) && "inefficient for anything else, should change", "%s", node->attr.c_str());
ConcreteCompilerVariable *converted = value->makeConverted(emitter, value->getBoxType());
value->decvref(emitter); value->decvref(emitter);
return rtn;
bool do_patchpoint = ENABLE_ICGETATTRS && emitter.getTarget() != IREmitter::INTERPRETER;
llvm::Value *rtn;
if (do_patchpoint) {
PatchpointSetupInfo *pp = patchpoints::createGetattrPatchpoint(emitter.currentFunction());
std::vector<llvm::Value*> llvm_args;
llvm_args.push_back(converted->getValue());
llvm_args.push_back(getStringConstantPtr(node->attr + '\0'));
llvm::Value* uncasted = emitter.createPatchpoint(pp, (void*)pyston::getclsattr, llvm_args);
rtn = emitter.getBuilder()->CreateIntToPtr(uncasted, g.llvm_value_type_ptr);
} else {
rtn = emitter.getBuilder()->CreateCall2(g.funcs.getclsattr,
converted->getValue(), getStringConstantPtr(node->attr + '\0'));
}
converted->decvref(emitter);
return new ConcreteCompilerVariable(UNKNOWN, rtn, true);
} }
enum BinExpType { enum BinExpType {
...@@ -538,7 +518,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -538,7 +518,7 @@ class IRGeneratorImpl : public IRGenerator {
CompilerVariable *rtn; CompilerVariable *rtn;
if (is_callattr) { if (is_callattr) {
rtn = func->callattr(emitter, *attr, callattr_clsonly, args); rtn = func->callattr(emitter, attr, callattr_clsonly, args);
} else { } else {
rtn = func->call(emitter, args); rtn = func->call(emitter, args);
} }
...@@ -560,7 +540,8 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -560,7 +540,8 @@ class IRGeneratorImpl : public IRGenerator {
llvm::Value* v = emitter.getBuilder()->CreateCall(g.funcs.createDict); llvm::Value* v = emitter.getBuilder()->CreateCall(g.funcs.createDict);
ConcreteCompilerVariable *rtn = new ConcreteCompilerVariable(DICT, v, true); ConcreteCompilerVariable *rtn = new ConcreteCompilerVariable(DICT, v, true);
if (node->keys.size()) { if (node->keys.size()) {
CompilerVariable *setitem = rtn->getattr(emitter, "__setitem__"); static const std::string setitem_str("__setitem__");
CompilerVariable *setitem = rtn->getattr(emitter, &setitem_str, true);
for (int i = 0; i < node->keys.size(); i++) { for (int i = 0; i < node->keys.size(); i++) {
CompilerVariable *key = evalExpr(node->keys[i]); CompilerVariable *key = evalExpr(node->keys[i]);
CompilerVariable *value = evalExpr(node->values[i]); CompilerVariable *value = evalExpr(node->values[i]);
...@@ -654,7 +635,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -654,7 +635,7 @@ class IRGeneratorImpl : public IRGenerator {
// Method 2 [testing-only]: (ab)uses existing getattr patchpoints and just calls module.getattr() // Method 2 [testing-only]: (ab)uses existing getattr patchpoints and just calls module.getattr()
// This option exists for performance testing because method 1 does not currently use patchpoints. // This option exists for performance testing because method 1 does not currently use patchpoints.
ConcreteCompilerVariable *mod = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false); ConcreteCompilerVariable *mod = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false);
CompilerVariable *attr = mod->getattr(emitter, node->id); CompilerVariable *attr = mod->getattr(emitter, &node->id, false);
mod->decvref(emitter); mod->decvref(emitter);
return attr; return attr;
} }
...@@ -822,9 +803,6 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -822,9 +803,6 @@ class IRGeneratorImpl : public IRGenerator {
case AST_TYPE::List: case AST_TYPE::List:
rtn = evalList(static_cast<AST_List*>(node)); rtn = evalList(static_cast<AST_List*>(node));
break; break;
//case AST_TYPE::ListComp:
//rtn = evalListComp(static_cast<AST_ListComp*>(node));
//break;
case AST_TYPE::Name: case AST_TYPE::Name:
rtn = evalName(static_cast<AST_Name*>(node)); rtn = evalName(static_cast<AST_Name*>(node));
break; break;
...@@ -1012,7 +990,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -1012,7 +990,7 @@ class IRGeneratorImpl : public IRGenerator {
if (irstate->getScopeInfo()->refersToGlobal(name)) { if (irstate->getScopeInfo()->refersToGlobal(name)) {
// TODO do something special here so that it knows to only emit a monomorphic inline cache? // TODO do something special here so that it knows to only emit a monomorphic inline cache?
ConcreteCompilerVariable* module = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false); ConcreteCompilerVariable* module = new ConcreteCompilerVariable(MODULE, embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_value_type_ptr), false);
module->setattr(emitter, name, val); module->setattr(emitter, &name, val);
module->decvref(emitter); module->decvref(emitter);
} else { } else {
CompilerVariable* &prev = symbol_table[name]; CompilerVariable* &prev = symbol_table[name];
...@@ -1027,7 +1005,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -1027,7 +1005,7 @@ class IRGeneratorImpl : public IRGenerator {
void _doSetattr(AST_Attribute* target, CompilerVariable* val) { void _doSetattr(AST_Attribute* target, CompilerVariable* val) {
assert(state != PARTIAL); assert(state != PARTIAL);
CompilerVariable *t = evalExpr(target->value); CompilerVariable *t = evalExpr(target->value);
t->setattr(emitter, target->attr, val); t->setattr(emitter, &target->attr, val);
t->decvref(emitter); t->decvref(emitter);
} }
...@@ -1148,7 +1126,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -1148,7 +1126,7 @@ class IRGeneratorImpl : public IRGenerator {
AST_FunctionDef *fdef = static_cast<AST_FunctionDef*>(node->body[i]); AST_FunctionDef *fdef = static_cast<AST_FunctionDef*>(node->body[i]);
CLFunction *cl = this->_wrapFunction(fdef); CLFunction *cl = this->_wrapFunction(fdef);
CompilerVariable *func = makeFunction(emitter, cl); CompilerVariable *func = makeFunction(emitter, cl);
cls->setattr(emitter, fdef->name, func); cls->setattr(emitter, &fdef->name, func);
func->decvref(emitter); func->decvref(emitter);
} else { } else {
RELEASE_ASSERT(node->body[i]->type == AST_TYPE::Pass, "%d", type); RELEASE_ASSERT(node->body[i]->type == AST_TYPE::Pass, "%d", type);
...@@ -1369,7 +1347,7 @@ class IRGeneratorImpl : public IRGenerator { ...@@ -1369,7 +1347,7 @@ class IRGeneratorImpl : public IRGenerator {
for (SortedSymbolTable::iterator it = sorted_symbol_table.begin(), end = sorted_symbol_table.end(); it != end; ++it, ++i) { for (SortedSymbolTable::iterator it = sorted_symbol_table.begin(), end = sorted_symbol_table.end(); it != end; ++it, ++i) {
// I don't think this can fail, but if it can we should filter out dead symbols before // I don't think this can fail, but if it can we should filter out dead symbols before
// passing them on: // passing them on:
assert(irstate->getSourceInfo()->liveness->isLiveAtEnd(it->first, myblock)); ASSERT(irstate->getSourceInfo()->liveness->isLiveAtEnd(it->first, myblock), "%d %s", myblock->idx, it->first.c_str());
// This line can never get hit right now since we unnecessarily force every variable to be concrete // This line can never get hit right now since we unnecessarily force every variable to be concrete
// for a loop, since we generate all potential phis: // for a loop, since we generate all potential phis:
......
...@@ -127,12 +127,15 @@ static void readExprVector(std::vector<AST_expr*> &vec, BufferedReader *reader) ...@@ -127,12 +127,15 @@ static void readExprVector(std::vector<AST_expr*> &vec, BufferedReader *reader)
} }
} }
static void readMiscVector(std::vector<AST*> &vec, BufferedReader *reader) { template <class T>
static void readMiscVector(std::vector<T*> &vec, BufferedReader *reader) {
int num_elts = reader->readShort(); int num_elts = reader->readShort();
if (VERBOSITY("parsing") >= 2) if (VERBOSITY("parsing") >= 2)
printf("%d elts to read\n", num_elts); printf("%d elts to read\n", num_elts);
for (int i = 0; i < num_elts; i++) { for (int i = 0; i < num_elts; i++) {
vec.push_back(readASTMisc(reader)); AST* read = readASTMisc(reader);
assert(read->type == T::TYPE);
vec.push_back(static_cast<T*>(read));
} }
} }
...@@ -240,12 +243,7 @@ AST_Call* read_call(BufferedReader *reader) { ...@@ -240,12 +243,7 @@ AST_Call* read_call(BufferedReader *reader) {
rtn->col_offset = readColOffset(reader); rtn->col_offset = readColOffset(reader);
rtn->func = readASTExpr(reader); rtn->func = readASTExpr(reader);
std::vector<AST*> keyword_vec; readMiscVector(rtn->keywords, reader);
readMiscVector(keyword_vec, reader);
for (int i = 0; i < keyword_vec.size(); i++) {
assert(keyword_vec[i]->type == AST_TYPE::keyword);
rtn->keywords.push_back(static_cast<AST_keyword*>(keyword_vec[i]));
}
rtn->kwargs = readASTExpr(reader); rtn->kwargs = readASTExpr(reader);
rtn->lineno = reader->readULL(); rtn->lineno = reader->readULL();
...@@ -267,26 +265,18 @@ AST_expr* read_compare(BufferedReader *reader) { ...@@ -267,26 +265,18 @@ AST_expr* read_compare(BufferedReader *reader) {
rtn->ops.push_back((AST_TYPE::AST_TYPE)reader->readByte()); rtn->ops.push_back((AST_TYPE::AST_TYPE)reader->readByte());
} }
/*{ return rtn;
assert(rtn->ops.size() == 1); }
AST_Attribute *func = new AST_Attribute();
func->type = AST_TYPE::Attribute; AST_comprehension* read_comprehension(BufferedReader *reader) {
func->attr = getOpName(rtn->ops[0]); AST_comprehension *rtn = new AST_comprehension();
func->col_offset = rtn->col_offset;
func->ctx_type = AST_TYPE::Load;
func->lineno = rtn->lineno;
func->value = rtn->left;
AST_Call *call = new AST_Call(); readExprVector(rtn->ifs, reader);
call->type = AST_TYPE::Call; rtn->iter = readASTExpr(reader);
call->args.push_back(rtn->comparators[0]); rtn->target = readASTExpr(reader);
call->col_offset = rtn->col_offset;
call->func = func; rtn->col_offset = -1;
call->kwargs = NULL; rtn->lineno = -1;
call->lineno = rtn->lineno;
call->starargs = NULL;
return call;
}*/
return rtn; return rtn;
} }
...@@ -436,6 +426,16 @@ AST_List* read_list(BufferedReader *reader) { ...@@ -436,6 +426,16 @@ AST_List* read_list(BufferedReader *reader) {
return rtn; return rtn;
} }
AST_ListComp* read_listcomp(BufferedReader *reader) {
AST_ListComp *rtn = new AST_ListComp();
rtn->col_offset = readColOffset(reader);
rtn->elt = readASTExpr(reader);
readMiscVector(rtn->generators, reader);
rtn->lineno = reader->readULL();
return rtn;
}
AST_Module* read_module(BufferedReader *reader) { AST_Module* read_module(BufferedReader *reader) {
if (VERBOSITY("parsing") >= 2) if (VERBOSITY("parsing") >= 2)
printf("reading module\n"); printf("reading module\n");
...@@ -612,6 +612,8 @@ AST_expr* readASTExpr(BufferedReader *reader) { ...@@ -612,6 +612,8 @@ AST_expr* readASTExpr(BufferedReader *reader) {
return read_index(reader); return read_index(reader);
case AST_TYPE::List: case AST_TYPE::List:
return read_list(reader); return read_list(reader);
case AST_TYPE::ListComp:
return read_listcomp(reader);
case AST_TYPE::Name: case AST_TYPE::Name:
return read_name(reader); return read_name(reader);
case AST_TYPE::Num: case AST_TYPE::Num:
...@@ -698,6 +700,8 @@ AST* readASTMisc(BufferedReader *reader) { ...@@ -698,6 +700,8 @@ AST* readASTMisc(BufferedReader *reader) {
return read_alias(reader); return read_alias(reader);
case AST_TYPE::arguments: case AST_TYPE::arguments:
return read_arguments(reader); return read_arguments(reader);
case AST_TYPE::comprehension:
return read_comprehension(reader);
case AST_TYPE::keyword: case AST_TYPE::keyword:
return read_keyword(reader); return read_keyword(reader);
case AST_TYPE::Module: case AST_TYPE::Module:
......
...@@ -287,6 +287,17 @@ void* AST_Compare::accept_expr(ExprVisitor *v) { ...@@ -287,6 +287,17 @@ void* AST_Compare::accept_expr(ExprVisitor *v) {
return v->visit_compare(this); return v->visit_compare(this);
} }
void AST_comprehension::accept(ASTVisitor *v) {
bool skip = v->visit_comprehension(this);
if (skip) return;
target->accept(v);
iter->accept(v);
for (auto if_ : ifs) {
if_->accept(v);
}
}
void AST_ClassDef::accept(ASTVisitor *v) { void AST_ClassDef::accept(ASTVisitor *v) {
bool skip = v->visit_classdef(this); bool skip = v->visit_classdef(this);
if (skip) return; if (skip) return;
...@@ -436,6 +447,21 @@ void* AST_List::accept_expr(ExprVisitor *v) { ...@@ -436,6 +447,21 @@ void* AST_List::accept_expr(ExprVisitor *v) {
return v->visit_list(this); return v->visit_list(this);
} }
void AST_ListComp::accept(ASTVisitor *v) {
bool skip = v->visit_listcomp(this);
if (skip) return;
for (auto c : generators) {
c->accept(v);
}
elt->accept(v);
}
void* AST_ListComp::accept_expr(ExprVisitor *v) {
return v->visit_listcomp(this);
}
void AST_Module::accept(ASTVisitor *v) { void AST_Module::accept(ASTVisitor *v) {
bool skip = v->visit_module(this); bool skip = v->visit_module(this);
if (skip) return; if (skip) return;
...@@ -781,6 +807,20 @@ bool PrintVisitor::visit_compare(AST_Compare *node) { ...@@ -781,6 +807,20 @@ bool PrintVisitor::visit_compare(AST_Compare *node) {
return true; return true;
} }
bool PrintVisitor::visit_comprehension(AST_comprehension *node) {
printf("for ");
node->target->accept(this);
printf(" in ");
node->iter->accept(this);
for (AST_expr *i : node->ifs) {
printf(" if ");
i->accept(this);
}
return true;
}
bool PrintVisitor::visit_classdef(AST_ClassDef *node) { bool PrintVisitor::visit_classdef(AST_ClassDef *node) {
for (int i = 0, n = node->decorator_list.size(); i < n; i++) { for (int i = 0, n = node->decorator_list.size(); i < n; i++) {
printf("@"); printf("@");
...@@ -928,6 +968,17 @@ bool PrintVisitor::visit_list(AST_List *node) { ...@@ -928,6 +968,17 @@ bool PrintVisitor::visit_list(AST_List *node) {
return true; return true;
} }
bool PrintVisitor::visit_listcomp(AST_ListComp *node) {
printf("[");
node->elt->accept(this);
for (auto c : node->generators) {
printf(" ");
c->accept(this);
}
printf("]");
return true;
}
bool PrintVisitor::visit_keyword(AST_keyword *node) { bool PrintVisitor::visit_keyword(AST_keyword *node) {
printf("%s=", node->arg.c_str()); printf("%s=", node->arg.c_str());
node->value->accept(this); node->value->accept(this);
...@@ -1135,6 +1186,7 @@ class FlattenVisitor : public ASTVisitor { ...@@ -1135,6 +1186,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_call(AST_Call *node) { output->push_back(node); return false; } virtual bool visit_call(AST_Call *node) { output->push_back(node); return false; }
virtual bool visit_classdef(AST_ClassDef *node) { output->push_back(node); return !expand_scopes; } virtual bool visit_classdef(AST_ClassDef *node) { output->push_back(node); return !expand_scopes; }
virtual bool visit_compare(AST_Compare *node) { output->push_back(node); return false; } virtual bool visit_compare(AST_Compare *node) { output->push_back(node); return false; }
virtual bool visit_comprehension(AST_comprehension *node) { output->push_back(node); return false; }
virtual bool visit_continue(AST_Continue *node) { output->push_back(node); return false; } virtual bool visit_continue(AST_Continue *node) { output->push_back(node); return false; }
virtual bool visit_dict(AST_Dict *node) { output->push_back(node); return false; } virtual bool visit_dict(AST_Dict *node) { output->push_back(node); return false; }
virtual bool visit_expr(AST_Expr *node) { output->push_back(node); return false; } virtual bool visit_expr(AST_Expr *node) { output->push_back(node); return false; }
...@@ -1145,6 +1197,7 @@ class FlattenVisitor : public ASTVisitor { ...@@ -1145,6 +1197,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { output->push_back(node); return false; } virtual bool visit_index(AST_Index *node) { output->push_back(node); return false; }
virtual bool visit_keyword(AST_keyword *node) { output->push_back(node); return false; } virtual bool visit_keyword(AST_keyword *node) { output->push_back(node); return false; }
virtual bool visit_list(AST_List *node) { output->push_back(node); return false; } virtual bool visit_list(AST_List *node) { output->push_back(node); return false; }
virtual bool visit_listcomp(AST_ListComp *node) { output->push_back(node); return false; }
virtual bool visit_module(AST_Module *node) { output->push_back(node); return !expand_scopes; } virtual bool visit_module(AST_Module *node) { output->push_back(node); return !expand_scopes; }
virtual bool visit_name(AST_Name *node) { output->push_back(node); return false; } virtual bool visit_name(AST_Name *node) { output->push_back(node); return false; }
virtual bool visit_num(AST_Num *node) { output->push_back(node); return false; } virtual bool visit_num(AST_Num *node) { output->push_back(node); return false; }
...@@ -1174,4 +1227,12 @@ std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes) { ...@@ -1174,4 +1227,12 @@ std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes) {
return rtn; return rtn;
} }
std::vector<AST*>* flatten(AST_expr* root, bool expand_scopes) {
std::vector<AST*> *rtn = new std::vector<AST*>();
FlattenVisitor visitor(rtn, expand_scopes);
root->accept(&visitor);
return rtn;
}
} }
...@@ -249,6 +249,8 @@ class AST_Call : public AST_expr { ...@@ -249,6 +249,8 @@ class AST_Call : public AST_expr {
virtual void* accept_expr(ExprVisitor *v); virtual void* accept_expr(ExprVisitor *v);
AST_Call() : AST_expr(AST_TYPE::Call) {} AST_Call() : AST_expr(AST_TYPE::Call) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Call;
}; };
class AST_Compare : public AST_expr { class AST_Compare : public AST_expr {
...@@ -263,6 +265,19 @@ class AST_Compare : public AST_expr { ...@@ -263,6 +265,19 @@ class AST_Compare : public AST_expr {
AST_Compare() : AST_expr(AST_TYPE::Compare) {} AST_Compare() : AST_expr(AST_TYPE::Compare) {}
}; };
class AST_comprehension : public AST {
public:
AST_expr* target;
AST_expr* iter;
std::vector<AST_expr*> ifs;
virtual void accept(ASTVisitor *v);
AST_comprehension() : AST(AST_TYPE::comprehension) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::comprehension;
};
class AST_ClassDef : public AST_stmt { class AST_ClassDef : public AST_stmt {
public: public:
virtual void accept(ASTVisitor *v); virtual void accept(ASTVisitor *v);
...@@ -391,6 +406,8 @@ class AST_keyword : public AST { ...@@ -391,6 +406,8 @@ class AST_keyword : public AST {
virtual void accept(ASTVisitor *v); virtual void accept(ASTVisitor *v);
AST_keyword() : AST(AST_TYPE::keyword) {} AST_keyword() : AST(AST_TYPE::keyword) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::keyword;
}; };
class AST_List : public AST_expr { class AST_List : public AST_expr {
...@@ -406,6 +423,19 @@ class AST_List : public AST_expr { ...@@ -406,6 +423,19 @@ class AST_List : public AST_expr {
AST_List() : AST_expr(AST_TYPE::List) {} AST_List() : AST_expr(AST_TYPE::List) {}
}; };
class AST_ListComp : public AST_expr {
public:
const static AST_TYPE::AST_TYPE TYPE = AST_TYPE::ListComp;
std::vector<AST_comprehension*> generators;
AST_expr* elt;
virtual void accept(ASTVisitor *v);
virtual void* accept_expr(ExprVisitor *v);
AST_ListComp() : AST_expr(AST_TYPE::ListComp) {}
};
class AST_Module : public AST { class AST_Module : public AST {
public: public:
// no lineno, col_offset attributes // no lineno, col_offset attributes
...@@ -607,6 +637,7 @@ class ASTVisitor { ...@@ -607,6 +637,7 @@ class ASTVisitor {
virtual bool visit_call(AST_Call *node) { assert(0); abort(); } virtual bool visit_call(AST_Call *node) { assert(0); abort(); }
virtual bool visit_clsattribute(AST_ClsAttribute *node) { assert(0); abort(); } virtual bool visit_clsattribute(AST_ClsAttribute *node) { assert(0); abort(); }
virtual bool visit_compare(AST_Compare *node) { assert(0); abort(); } virtual bool visit_compare(AST_Compare *node) { assert(0); abort(); }
virtual bool visit_comprehension(AST_comprehension *node) { assert(0); abort(); }
virtual bool visit_classdef(AST_ClassDef *node) { assert(0); abort(); } virtual bool visit_classdef(AST_ClassDef *node) { assert(0); abort(); }
virtual bool visit_continue(AST_Continue *node) { assert(0); abort(); } virtual bool visit_continue(AST_Continue *node) { assert(0); abort(); }
virtual bool visit_dict(AST_Dict *node) { assert(0); abort(); } virtual bool visit_dict(AST_Dict *node) { assert(0); abort(); }
...@@ -620,6 +651,7 @@ class ASTVisitor { ...@@ -620,6 +651,7 @@ class ASTVisitor {
virtual bool visit_index(AST_Index *node) { assert(0); abort(); } virtual bool visit_index(AST_Index *node) { assert(0); abort(); }
virtual bool visit_keyword(AST_keyword *node) { assert(0); abort(); } virtual bool visit_keyword(AST_keyword *node) { assert(0); abort(); }
virtual bool visit_list(AST_List *node) { assert(0); abort(); } virtual bool visit_list(AST_List *node) { assert(0); abort(); }
virtual bool visit_listcomp(AST_ListComp *node) { assert(0); abort(); }
virtual bool visit_module(AST_Module *node) { assert(0); abort(); } virtual bool visit_module(AST_Module *node) { assert(0); abort(); }
virtual bool visit_name(AST_Name *node) { assert(0); abort(); } virtual bool visit_name(AST_Name *node) { assert(0); abort(); }
virtual bool visit_num(AST_Num *node) { assert(0); abort(); } virtual bool visit_num(AST_Num *node) { assert(0); abort(); }
...@@ -654,6 +686,7 @@ class NoopASTVisitor : public ASTVisitor { ...@@ -654,6 +686,7 @@ class NoopASTVisitor : public ASTVisitor {
virtual bool visit_call(AST_Call *node) { return false; } virtual bool visit_call(AST_Call *node) { return false; }
virtual bool visit_clsattribute(AST_ClsAttribute *node) { return false; } virtual bool visit_clsattribute(AST_ClsAttribute *node) { return false; }
virtual bool visit_compare(AST_Compare *node) { return false; } virtual bool visit_compare(AST_Compare *node) { return false; }
virtual bool visit_comprehension(AST_comprehension *node) { return false; }
virtual bool visit_classdef(AST_ClassDef *node) { return false; } virtual bool visit_classdef(AST_ClassDef *node) { return false; }
virtual bool visit_continue(AST_Continue *node) { return false; } virtual bool visit_continue(AST_Continue *node) { return false; }
virtual bool visit_dict(AST_Dict *node) { return false; } virtual bool visit_dict(AST_Dict *node) { return false; }
...@@ -667,6 +700,7 @@ class NoopASTVisitor : public ASTVisitor { ...@@ -667,6 +700,7 @@ class NoopASTVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node) { return false; } virtual bool visit_index(AST_Index *node) { return false; }
virtual bool visit_keyword(AST_keyword *node) { return false; } virtual bool visit_keyword(AST_keyword *node) { return false; }
virtual bool visit_list(AST_List *node) { return false; } virtual bool visit_list(AST_List *node) { return false; }
virtual bool visit_listcomp(AST_ListComp *node) { return false; }
virtual bool visit_module(AST_Module *node) { return false; } virtual bool visit_module(AST_Module *node) { return false; }
virtual bool visit_name(AST_Name *node) { return false; } virtual bool visit_name(AST_Name *node) { return false; }
virtual bool visit_num(AST_Num *node) { return false; } virtual bool visit_num(AST_Num *node) { return false; }
...@@ -700,6 +734,7 @@ class ExprVisitor { ...@@ -700,6 +734,7 @@ class ExprVisitor {
virtual void* visit_ifexp(AST_IfExp *node) { assert(0); abort(); } virtual void* visit_ifexp(AST_IfExp *node) { assert(0); abort(); }
virtual void* visit_index(AST_Index *node) { assert(0); abort(); } virtual void* visit_index(AST_Index *node) { assert(0); abort(); }
virtual void* visit_list(AST_List *node) { assert(0); abort(); } virtual void* visit_list(AST_List *node) { assert(0); abort(); }
virtual void* visit_listcomp(AST_ListComp *node) { assert(0); abort(); }
virtual void* visit_name(AST_Name *node) { assert(0); abort(); } virtual void* visit_name(AST_Name *node) { assert(0); abort(); }
virtual void* visit_num(AST_Num *node) { assert(0); abort(); } virtual void* visit_num(AST_Num *node) { assert(0); abort(); }
virtual void* visit_slice(AST_Slice *node) { assert(0); abort(); } virtual void* visit_slice(AST_Slice *node) { assert(0); abort(); }
...@@ -754,6 +789,7 @@ class PrintVisitor : public ASTVisitor { ...@@ -754,6 +789,7 @@ class PrintVisitor : public ASTVisitor {
virtual bool visit_break(AST_Break *node); virtual bool visit_break(AST_Break *node);
virtual bool visit_call(AST_Call *node); virtual bool visit_call(AST_Call *node);
virtual bool visit_compare(AST_Compare *node); virtual bool visit_compare(AST_Compare *node);
virtual bool visit_comprehension(AST_comprehension *node);
virtual bool visit_classdef(AST_ClassDef *node); virtual bool visit_classdef(AST_ClassDef *node);
virtual bool visit_clsattribute(AST_ClsAttribute *node); virtual bool visit_clsattribute(AST_ClsAttribute *node);
virtual bool visit_continue(AST_Continue *node); virtual bool visit_continue(AST_Continue *node);
...@@ -768,6 +804,7 @@ class PrintVisitor : public ASTVisitor { ...@@ -768,6 +804,7 @@ class PrintVisitor : public ASTVisitor {
virtual bool visit_index(AST_Index *node); virtual bool visit_index(AST_Index *node);
virtual bool visit_keyword(AST_keyword *node); virtual bool visit_keyword(AST_keyword *node);
virtual bool visit_list(AST_List *node); virtual bool visit_list(AST_List *node);
virtual bool visit_listcomp(AST_ListComp *node);
virtual bool visit_module(AST_Module *node); virtual bool visit_module(AST_Module *node);
virtual bool visit_name(AST_Name *node); virtual bool visit_name(AST_Name *node);
virtual bool visit_num(AST_Num *node); virtual bool visit_num(AST_Num *node);
...@@ -790,9 +827,10 @@ class PrintVisitor : public ASTVisitor { ...@@ -790,9 +827,10 @@ class PrintVisitor : public ASTVisitor {
// This is useful for analyses that care more about the constituent nodes than the // This is useful for analyses that care more about the constituent nodes than the
// exact tree structure; ex, finding all "global" directives. // exact tree structure; ex, finding all "global" directives.
std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes); std::vector<AST*>* flatten(std::vector<AST_stmt*> &roots, bool expand_scopes);
std::vector<AST*>* flatten(AST_expr *root, bool expand_scopes);
// Similar to the flatten() function, but filters for a specific type of ast nodes: // Similar to the flatten() function, but filters for a specific type of ast nodes:
template <class T> template <class T, class R>
std::vector<T*>* findNodes(std::vector<AST_stmt*> &roots, bool expand_scopes) { std::vector<T*>* findNodes(const R &roots, bool expand_scopes) {
std::vector<T*> *rtn = new std::vector<T*>(); std::vector<T*> *rtn = new std::vector<T*>();
std::vector<AST*> *flattened = flatten(roots, expand_scopes); std::vector<AST*> *flattened = flatten(roots, expand_scopes);
for (int i = 0; i < flattened->size(); i++) { for (int i = 0; i < flattened->size(); i++) {
......
...@@ -153,6 +153,17 @@ class CFGVisitor : public ASTVisitor { ...@@ -153,6 +153,17 @@ class CFGVisitor : public ASTVisitor {
return call; return call;
} }
AST_Call* makeCall(AST_expr* func, AST_expr* arg0) {
AST_Call *call = new AST_Call();
call->args.push_back(arg0);
call->starargs = NULL;
call->kwargs = NULL;
call->func = func;
call->col_offset = func->col_offset;
call->lineno = func->lineno;
return call;
}
AST_Name* makeName(const std::string &id, AST_TYPE::AST_TYPE ctx_type, int lineno=-1, int col_offset=-1) { AST_Name* makeName(const std::string &id, AST_TYPE::AST_TYPE ctx_type, int lineno=-1, int col_offset=-1) {
AST_Name *name = new AST_Name(); AST_Name *name = new AST_Name();
name->id = id; name->id = id;
...@@ -193,6 +204,12 @@ class CFGVisitor : public ASTVisitor { ...@@ -193,6 +204,12 @@ class CFGVisitor : public ASTVisitor {
return std::string(buf); return std::string(buf);
} }
std::string nodeName(AST_expr* node, const std::string &suffix, int idx) {
char buf[50];
snprintf(buf, 50, "!%p_%s_%d", node, suffix.c_str(), idx);
return std::string(buf);
}
AST_expr* remapAttribute(AST_Attribute* node) { AST_expr* remapAttribute(AST_Attribute* node) {
AST_Attribute *rtn = new AST_Attribute(); AST_Attribute *rtn = new AST_Attribute();
...@@ -386,6 +403,135 @@ class CFGVisitor : public ASTVisitor { ...@@ -386,6 +403,135 @@ class CFGVisitor : public ASTVisitor {
return rtn; return rtn;
} }
AST_expr* remapListComp(AST_ListComp* node) {
std::string rtn_name = nodeName(node);
push_back(makeAssign(rtn_name, new AST_List()));
std::vector<CFGBlock*> exit_blocks;
// Where the current level should jump to after finishing its iteration.
// For the outermost comprehension, this is NULL, and it doesn't jump anywhere;
// for the inner comprehensions, they should jump to the next-outer comprehension
// when they are done iterating.
CFGBlock *finished_block = NULL;
for (int i = 0, n = node->generators.size(); i < n; i++) {
AST_comprehension *c = node->generators[i];
bool is_innermost = (i == n-1);
AST_expr *remapped_iter = remapExpr(c->iter);
AST_expr *iter_attr = makeLoadAttribute(remapped_iter, "__iter__", true);
AST_expr *iter_call = makeCall(iter_attr);
std::string iter_name = nodeName(node, "iter", i);
AST_stmt *iter_assign = makeAssign(iter_name, iter_call);
push_back(iter_assign);
// TODO bad to save these like this?
AST_expr *hasnext_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "__hasnext__", true);
AST_expr *next_attr = makeLoadAttribute(makeName(iter_name, AST_TYPE::Load), "next", true);
AST_Jump *j;
CFGBlock *test_block = cfg->addBlock();
test_block->info = "listcomp_test";
//printf("Test block for comp %d is %d\n", i, test_block->idx);
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block);
push_back(j);
curblock = test_block;
AST_expr *test_call = makeCall(hasnext_attr);
CFGBlock* body_block = cfg->addBlock();
body_block->info = "listcomp_body";
CFGBlock* exit_block = cfg->addDeferredBlock();
exit_block->info = "listcomp_exit";
exit_blocks.push_back(exit_block);
//printf("Body block for comp %d is %d\n", i, body_block->idx);
AST_Branch *br = new AST_Branch();
br->col_offset = node->col_offset;
br->lineno = node->lineno;
br->test = test_call;
br->iftrue = body_block;
br->iffalse = exit_block;
curblock->connectTo(body_block);
curblock->connectTo(exit_block);
push_back(br);
curblock = body_block;
push_back(makeAssign(c->target, makeCall(next_attr)));
for (AST_expr *if_condition : c->ifs) {
AST_expr *remapped = remapExpr(if_condition);
AST_Branch *br = new AST_Branch();
br->test = remapped;
push_back(br);
// Put this below the entire body?
CFGBlock *body_tramp = cfg->addBlock();
body_tramp->info = "listcomp_if_trampoline";
//printf("body_tramp for %d is %d\n", i, body_tramp->idx);
CFGBlock *body_continue = cfg->addBlock();
body_continue->info = "listcomp_if_continue";
//printf("body_continue for %d is %d\n", i, body_continue->idx);
br->iffalse = body_tramp;
curblock->connectTo(body_tramp);
br->iftrue = body_continue;
curblock->connectTo(body_continue);
curblock = body_tramp;
j = new AST_Jump();
j->target = test_block;
push_back(j);
curblock->connectTo(test_block, true);
curblock = body_continue;
}
CFGBlock *body_end = curblock;
assert((finished_block != NULL) == (i != 0));
if (finished_block) {
curblock = exit_block;
j = new AST_Jump();
j->target = finished_block;
curblock->connectTo(finished_block, true);
push_back(j);
}
finished_block = test_block;
curblock = body_end;
if (is_innermost) {
AST_expr *elt = remapExpr(node->elt);
push_back(makeExpr(makeCall(makeLoadAttribute(makeName(rtn_name, AST_TYPE::Load), "append", true), elt)));
j = new AST_Jump();
j->target = test_block;
curblock->connectTo(test_block, true);
push_back(j);
assert(exit_blocks.size());
curblock = exit_blocks[0];
} else {
// continue onto the next comprehension and add to this body
}
}
// Wait until the end to place the end blocks, so that
// we get a nice nesting structure, that looks similar to what
// you'd get with a nested for loop:
for (int i = exit_blocks.size() - 1; i >= 0; i--) {
cfg->placeBlock(exit_blocks[i]);
printf("Exit block for comp %d is %d\n", i, exit_blocks[i]->idx);
}
return makeName(rtn_name, AST_TYPE::Load);
};
AST_expr* remapSlice(AST_Slice* node) { AST_expr* remapSlice(AST_Slice* node) {
AST_Slice *rtn = new AST_Slice(); AST_Slice *rtn = new AST_Slice();
rtn->lineno = node->lineno; rtn->lineno = node->lineno;
...@@ -464,6 +610,9 @@ class CFGVisitor : public ASTVisitor { ...@@ -464,6 +610,9 @@ class CFGVisitor : public ASTVisitor {
case AST_TYPE::List: case AST_TYPE::List:
rtn = remapList(static_cast<AST_List*>(node)); rtn = remapList(static_cast<AST_List*>(node));
break; break;
case AST_TYPE::ListComp:
rtn = remapListComp(static_cast<AST_ListComp*>(node));
break;
case AST_TYPE::Name: case AST_TYPE::Name:
return node; return node;
case AST_TYPE::Num: case AST_TYPE::Num:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment