Commit 7f2d28c7 authored by Kevin Modzelewski's avatar Kevin Modzelewski

Implement basic assert support

parent 5e0bfd08
......@@ -174,6 +174,7 @@ class DefinednessVisitor : public ASTVisitor {
DefinednessVisitor(Map &state) : state(state) {
}
virtual bool visit_assert(AST_Assert* node) { return true; }
virtual bool visit_branch(AST_Branch* node) { return true; }
virtual bool visit_expr(AST_Expr* node) { return true; }
virtual bool visit_global(AST_Global* node) { return true; }
......
......@@ -395,6 +395,12 @@ class BasicBlockTypePropagator : public ExprVisitor, public StmtVisitor {
virtual void visit_assert(AST_Assert* node) {
getType(node->test);
if (node->msg)
getType(node->msg);
}
virtual void visit_assign(AST_Assign* node) {
CompilerType* t = getType(node->value);
for (int i = 0; i < node->targets.size(); i++) {
......
......@@ -669,7 +669,7 @@ class IntType : public ConcreteCompilerType {
} _INT;
ConcreteCompilerType *INT = &_INT;
CompilerVariable* makeInt(int64_t n) {
ConcreteCompilerVariable* makeInt(int64_t n) {
return new ConcreteCompilerVariable(INT, llvm::ConstantInt::get(g.i64, n, true), true);
}
......@@ -761,7 +761,7 @@ class FloatType : public ConcreteCompilerType {
} _FLOAT;
ConcreteCompilerType *FLOAT = &_FLOAT;
CompilerVariable* makeFloat(double d) {
ConcreteCompilerVariable* makeFloat(double d) {
return new ConcreteCompilerVariable(FLOAT, llvm::ConstantFP::get(g.double_, d), true);
}
......@@ -1032,18 +1032,23 @@ class StrConstantType : public ValuedCompilerType<std::string*> {
std::string debugName() {
return "str_constant";
}
virtual ConcreteCompilerType* getConcreteType() {
return STR;
}
virtual ConcreteCompilerType* getBoxType() {
return STR;
}
virtual void drop(IREmitter &emitter, VAR *var) {
// pass
}
virtual void grab(IREmitter &emitter, VAR *var) {
// pass
}
virtual void print(IREmitter &emitter, ValuedCompilerVariable<std::string*> *value) {
llvm::Constant* ptr = getStringConstantPtr(*(value->getValue()) + '\0');
llvm::Constant* fmt = getStringConstantPtr("%s\0");
......@@ -1081,6 +1086,10 @@ class StrConstantType : public ValuedCompilerType<std::string*> {
return rtn;
}
ConcreteCompilerVariable *nonzero(IREmitter &emitter, const OpInfo& info, VAR *var) override {
return makeBool(var->getValue()->size() != 0);
}
virtual CompilerVariable* dup(VAR *var, DupCache &cache) {
CompilerVariable* &rtn = cache[var];
......@@ -1152,7 +1161,7 @@ class BoolType : public ConcreteCompilerType {
}
};
ConcreteCompilerType *BOOL = new BoolType();
CompilerVariable* makeBool(bool b) {
ConcreteCompilerVariable* makeBool(bool b) {
return new ConcreteCompilerVariable(BOOL, llvm::ConstantInt::get(g.i1, b, false), true);
}
......
......@@ -313,9 +313,9 @@ class ValuedCompilerVariable : public CompilerVariable {
//assert(value->getType() == type->llvmType());
//}
CompilerVariable* makeInt(int64_t);
CompilerVariable* makeFloat(double);
CompilerVariable* makeBool(bool);
ConcreteCompilerVariable* makeInt(int64_t);
ConcreteCompilerVariable* makeFloat(double);
ConcreteCompilerVariable* makeBool(bool);
CompilerVariable* makeStr(std::string*);
CompilerVariable* makeFunction(IREmitter &emitter, CLFunction*);
CompilerVariable* undefVariable();
......
......@@ -1122,6 +1122,29 @@ class IRGeneratorImpl : public IRGenerator {
}
}
void doAssert(AST_Assert *node) {
AST_expr* test = node->test;
assert(test->type == AST_TYPE::Num);
AST_Num* num = ast_cast<AST_Num>(test);
assert(num->num_type == AST_Num::INT);
assert(num->n_int == 0);
std::vector<llvm::Value*> llvm_args;
llvm_args.push_back(embedConstantPtr(irstate->getSourceInfo()->parent_module, g.llvm_module_type_ptr));
ConcreteCompilerVariable *converted_msg = NULL;
if (node->msg) {
CompilerVariable *msg = evalExpr(node->msg);
converted_msg = msg->makeConverted(emitter, msg->getBoxType());
msg->decvref(emitter);
llvm_args.push_back(converted_msg->getValue());
} else {
llvm_args.push_back(embedConstantPtr(NULL, g.llvm_value_type_ptr));
}
llvm::CallInst *call = emitter.getBuilder()->CreateCall(g.funcs.assertFail, llvm_args);
call->setDoesNotReturn();
}
void doAssign(AST_Assign *node) {
CompilerVariable *val = evalExpr(node->value);
if (state == PARTIAL)
......@@ -1496,6 +1519,9 @@ class IRGeneratorImpl : public IRGenerator {
void doStmt(AST *node) {
switch (node->type) {
case AST_TYPE::Assert:
doAssert(ast_cast<AST_Assert>(node));
break;
case AST_TYPE::Assign:
doAssign(ast_cast<AST_Assign>(node));
break;
......
......@@ -175,6 +175,16 @@ AST_arguments* read_arguments(BufferedReader *reader) {
return rtn;
}
AST_Assert* read_assert(BufferedReader *reader) {
AST_Assert *rtn = new AST_Assert();
rtn->col_offset = readColOffset(reader);
rtn->lineno = reader->readULL();
rtn->msg = readASTExpr(reader);
rtn->test = readASTExpr(reader);
return rtn;
}
AST_Assign* read_assign(BufferedReader *reader) {
AST_Assign *rtn = new AST_Assign();
......@@ -654,6 +664,8 @@ AST_stmt* readASTStmt(BufferedReader *reader) {
assert(checkbyte == 0xae);
switch (type) {
case AST_TYPE::Assert:
return read_assert(reader);
case AST_TYPE::Assign:
return read_assign(reader);
case AST_TYPE::AugAssign:
......
......@@ -165,6 +165,7 @@ void initGlobalFuncs(GlobalState &g) {
GET(raiseAttributeErrorStr);
GET(raiseNotIterableError);
GET(assertNameDefined);
GET(assertFail);
GET(printFloat);
GET(listAppendInternal);
......
......@@ -22,7 +22,7 @@ struct GlobalFuncs {
llvm::Value *boxInt, *unboxInt, *boxFloat, *unboxFloat, *boxStringPtr, *boxCLFunction, *unboxCLFunction, *boxInstanceMethod, *boxBool, *unboxBool, *createTuple, *createDict, *createList, *createSlice, *createClass;
llvm::Value *getattr, *setattr, *print, *nonzero, *binop, *compare, *augbinop, *unboxedLen, *getitem, *getclsattr, *getGlobal, *setitem, *unaryop, *import;
llvm::Value *checkUnpackingLength, *raiseAttributeError, *raiseAttributeErrorStr, *raiseNotIterableError, *assertNameDefined;
llvm::Value *checkUnpackingLength, *raiseAttributeError, *raiseAttributeErrorStr, *raiseNotIterableError, *assertNameDefined, *assertFail;
llvm::Value *printFloat, *listAppendInternal;
llvm::Value *dump;
llvm::Value *runtimeCall0, *runtimeCall1, *runtimeCall2, *runtimeCall3, *runtimeCall;
......
......@@ -194,6 +194,18 @@ void AST_arguments::accept(ASTVisitor *v) {
if (kwarg) kwarg->accept(v);
}
void AST_Assert::accept(ASTVisitor *v) {
bool skip = v->visit_assert(this);
if (skip) return;
test->accept(v);
if (msg) msg->accept(v);
}
void AST_Assert::accept_stmt(StmtVisitor *v) {
v->visit_assert(this);
}
void AST_Assign::accept(ASTVisitor *v) {
bool skip = v->visit_assign(this);
if (skip) return;
......@@ -698,6 +710,16 @@ bool PrintVisitor::visit_arguments(AST_arguments *node) {
return true;
}
bool PrintVisitor::visit_assert(AST_Assert *node) {
printf("assert ");
node->test->accept(this);
if (node->msg) {
printf(", ");
node->msg->accept(this);
}
return true;
}
bool PrintVisitor::visit_assign(AST_Assign *node) {
for (int i = 0; i < node->targets.size(); i++) {
node->targets[i]->accept(this);
......@@ -1225,6 +1247,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_alias(AST_alias *node) { output->push_back(node); return false; }
virtual bool visit_arguments(AST_arguments *node) { output->push_back(node); return false; }
virtual bool visit_assert(AST_Assert *node) { output->push_back(node); return false; }
virtual bool visit_assign(AST_Assign *node) { output->push_back(node); return false; }
virtual bool visit_augassign(AST_AugAssign *node) { output->push_back(node); return false; }
virtual bool visit_augbinop(AST_AugBinOp *node) { output->push_back(node); return false; }
......@@ -1243,6 +1266,7 @@ class FlattenVisitor : public ASTVisitor {
virtual bool visit_functiondef(AST_FunctionDef *node) { output->push_back(node); return !expand_scopes; }
virtual bool visit_global(AST_Global *node) { output->push_back(node); return false; }
virtual bool visit_if(AST_If *node) { output->push_back(node); return false; }
virtual bool visit_ifexp(AST_IfExp *node) { output->push_back(node); return false; }
virtual bool visit_import(AST_Import *node) { output->push_back(node); return false; }
virtual bool visit_importfrom(AST_ImportFrom *node) { output->push_back(node); return false; }
virtual bool visit_index(AST_Index *node) { output->push_back(node); return false; }
......
......@@ -181,6 +181,18 @@ class AST_arguments : public AST {
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::arguments;
};
class AST_Assert : public AST_stmt {
public:
AST_expr *msg, *test;
virtual void accept(ASTVisitor *v);
virtual void accept_stmt(StmtVisitor *v);
AST_Assert() : AST_stmt(AST_TYPE::Assert) {}
static const AST_TYPE::AST_TYPE TYPE = AST_TYPE::Assert;
};
class AST_Assign : public AST_stmt {
public:
std::vector<AST_expr*> targets;
......@@ -732,6 +744,7 @@ class ASTVisitor {
virtual bool visit_alias(AST_alias *node) { assert(0); abort(); }
virtual bool visit_arguments(AST_arguments *node) { assert(0); abort(); }
virtual bool visit_assert(AST_Assert *node) { assert(0); abort(); }
virtual bool visit_assign(AST_Assign *node) { assert(0); abort(); }
virtual bool visit_augassign(AST_AugAssign *node) { assert(0); abort(); }
virtual bool visit_augbinop(AST_AugBinOp *node) { assert(0); abort(); }
......@@ -783,6 +796,7 @@ class NoopASTVisitor : public ASTVisitor {
virtual bool visit_alias(AST_alias *node) { return false; }
virtual bool visit_arguments(AST_arguments *node) { return false; }
virtual bool visit_assert(AST_Assert *node) { return false; }
virtual bool visit_assign(AST_Assign *node) { return false; }
virtual bool visit_augassign(AST_AugAssign *node) { return false; }
virtual bool visit_augbinop(AST_AugBinOp *node) { return false; }
......@@ -858,6 +872,7 @@ class StmtVisitor {
public:
virtual ~StmtVisitor() {}
virtual void visit_assert(AST_Assert *node) { assert(0); abort(); }
virtual void visit_assign(AST_Assign *node) { assert(0); abort(); }
virtual void visit_augassign(AST_AugAssign *node) { assert(0); abort(); }
virtual void visit_break(AST_Break *node) { assert(0); abort(); }
......@@ -891,6 +906,7 @@ class PrintVisitor : public ASTVisitor {
virtual bool visit_alias(AST_alias *node);
virtual bool visit_arguments(AST_arguments *node);
virtual bool visit_assert(AST_Assert *node);
virtual bool visit_assign(AST_Assign *node);
virtual bool visit_augassign(AST_AugAssign *node);
virtual bool visit_augbinop(AST_AugBinOp *node);
......
......@@ -677,6 +677,55 @@ class CFGVisitor : public ASTVisitor {
virtual bool visit_importfrom(AST_ImportFrom* node) { push_back(node); return true; }
virtual bool visit_pass(AST_Pass* node) { return true; }
bool visit_assert(AST_Assert* node) override {
AST_Branch* br = new AST_Branch();
br->test = remapExpr(node->test);
push_back(br);
CFGBlock *iffalse = cfg->addBlock();
iffalse->info = "assert_fail";
curblock->connectTo(iffalse);
CFGBlock *iftrue = cfg->addBlock();
iftrue->info = "assert_pass";
curblock->connectTo(iftrue);
br->iftrue = iftrue;
br->iffalse = iffalse;
curblock = iffalse;
// The rest of this is pretty hacky:
// Emit a "assert(0, msg()); while (1) {}" section that basically captures
// what the assert will do but in a very hacky way.
AST_Assert* remapped = new AST_Assert();
if (node->msg)
remapped->msg = remapExpr(node->msg);
else
remapped->msg = NULL;
AST_Num* fake_test = new AST_Num();
fake_test->num_type = AST_Num::INT;
fake_test->n_int = 0;
remapped->test = fake_test;
remapped->lineno = node->lineno;
remapped->col_offset = node->col_offset;
push_back(remapped);
CFGBlock* unreachable = cfg->addBlock();
unreachable->info = "unreachable";
curblock->connectTo(unreachable);
AST_Jump* j = new AST_Jump();
j->target = unreachable;
push_back(j);
curblock = unreachable;
push_back(j);
curblock->connectTo(unreachable, true);
curblock = iftrue;
return true;
}
virtual bool visit_assign(AST_Assign* node) {
AST_Assign* remapped = new AST_Assign();
remapped->lineno = node->lineno;
......
......@@ -76,6 +76,7 @@ void force() {
FORCE(raiseAttributeErrorStr);
FORCE(raiseNotIterableError);
FORCE(assertNameDefined);
FORCE(assertFail);
FORCE(printFloat);
FORCE(listAppendInternal);
......
......@@ -226,6 +226,17 @@ extern "C" void my_assert(bool b) {
assert(b);
}
extern "C" void assertFail(BoxedModule *inModule, Box *msg) {
if (msg) {
BoxedString *tostr = str(msg);
fprintf(stderr, "AssertionError: %s\n", tostr->s.c_str());
raiseExc();
} else {
fprintf(stderr, "AssertionError\n");
raiseExc();
}
}
extern "C" void assertNameDefined(bool b, const char* name) {
if (!b) {
fprintf(stderr, "UnboundLocalError: local variable '%s' referenced before assignment\n", name);
......@@ -1840,7 +1851,7 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs *rewrit
if (op_type == AST_TYPE::LtE)
return boxBool(cmp1 <= cmp2);
}
RELEASE_ASSERT(0, "");
RELEASE_ASSERT(0, "%d", op_type);
}
extern "C" Box* compare(Box* lhs, Box* rhs, int op_type) {
......
......@@ -63,6 +63,7 @@ extern "C" Box* unaryop(Box* operand, int op_type);
extern "C" Box* import(const std::string *name);
extern "C" void checkUnpackingLength(i64 expected, i64 given);
extern "C" void assertNameDefined(bool b, const char* name);
extern "C" void assertFail(BoxedModule *inModule, Box *msg);
struct CompareRewriteArgs;
Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs *rewrite_args);
......
def msg():
print "msg()"
return "failure message"
assert 1
assert True if "a" else 1, [msg() for i in xrange(5)]
assert 1, msg()
assert 0, msg()
......@@ -10,3 +10,8 @@ print " test ".split()
print " test ".split(' ')
print " test ".split(None)
print "1<>2<>3".split('<>')
print map(bool, ["hello", "", "world"])
if "":
print "bad"
# expected: fail
# - string iteration
# This should probably be moved into the str_functions test, once it's no longer failing
for c in "hello world":
print repr(c)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment