Commit c5a3212a authored by Marius Wachtler's avatar Marius Wachtler

Support extended slice (multiple slices inside a tuple) syntax

parent e16f9f54
...@@ -473,6 +473,7 @@ public: ...@@ -473,6 +473,7 @@ public:
bool visit_dict(AST_Dict* node) override { return false; } bool visit_dict(AST_Dict* node) override { return false; }
bool visit_excepthandler(AST_ExceptHandler* node) override { return false; } bool visit_excepthandler(AST_ExceptHandler* node) override { return false; }
bool visit_expr(AST_Expr* node) override { return false; } bool visit_expr(AST_Expr* node) override { return false; }
bool visit_extslice(AST_ExtSlice* node) override { return false; }
bool visit_for(AST_For* node) override { return false; } bool visit_for(AST_For* node) override { return false; }
// bool visit_functiondef(AST_FunctionDef *node) override { return false; } // bool visit_functiondef(AST_FunctionDef *node) override { return false; }
// bool visit_global(AST_Global *node) override { return false; } // bool visit_global(AST_Global *node) override { return false; }
......
...@@ -453,6 +453,14 @@ private: ...@@ -453,6 +453,14 @@ private:
void* visit_slice(AST_Slice* node) override { return SLICE; } void* visit_slice(AST_Slice* node) override { return SLICE; }
void* visit_extslice(AST_ExtSlice* node) override {
std::vector<CompilerType*> elt_types;
for (auto* e : node->dims) {
elt_types.push_back(getType(e));
}
return makeTupleType(elt_types);
}
void* visit_str(AST_Str* node) override { void* visit_str(AST_Str* node) override {
if (node->str_type == AST_Str::STR) if (node->str_type == AST_Str::STR)
return STR; return STR;
......
...@@ -104,6 +104,7 @@ private: ...@@ -104,6 +104,7 @@ private:
Value visit_dict(AST_Dict* node); Value visit_dict(AST_Dict* node);
Value visit_expr(AST_expr* node); Value visit_expr(AST_expr* node);
Value visit_expr(AST_Expr* node); Value visit_expr(AST_Expr* node);
Value visit_extslice(AST_ExtSlice* node);
Value visit_index(AST_Index* node); Value visit_index(AST_Index* node);
Value visit_lambda(AST_Lambda* node); Value visit_lambda(AST_Lambda* node);
Value visit_list(AST_List* node); Value visit_list(AST_List* node);
...@@ -430,6 +431,14 @@ Value ASTInterpreter::visit_slice(AST_Slice* node) { ...@@ -430,6 +431,14 @@ Value ASTInterpreter::visit_slice(AST_Slice* node) {
return createSlice(lower.o, upper.o, step.o); return createSlice(lower.o, upper.o, step.o);
} }
Value ASTInterpreter::visit_extslice(AST_ExtSlice* node) {
int num_slices = node->dims.size();
BoxedTuple* rtn = BoxedTuple::create(num_slices);
for (int i = 0; i < num_slices; ++i)
rtn->elts[i] = visit_expr(node->dims[i]).o;
return rtn;
}
Value ASTInterpreter::visit_branch(AST_Branch* node) { Value ASTInterpreter::visit_branch(AST_Branch* node) {
Value v = visit_expr(node->test); Value v = visit_expr(node->test);
ASSERT(v.o == True || v.o == False, "Should have called NONZERO before this branch"); ASSERT(v.o == True || v.o == False, "Should have called NONZERO before this branch");
...@@ -942,6 +951,8 @@ Value ASTInterpreter::visit_expr(AST_expr* node) { ...@@ -942,6 +951,8 @@ Value ASTInterpreter::visit_expr(AST_expr* node) {
return visit_compare((AST_Compare*)node); return visit_compare((AST_Compare*)node);
case AST_TYPE::Dict: case AST_TYPE::Dict:
return visit_dict((AST_Dict*)node); return visit_dict((AST_Dict*)node);
case AST_TYPE::ExtSlice:
return visit_extslice((AST_ExtSlice*)node);
case AST_TYPE::Index: case AST_TYPE::Index:
return visit_index((AST_Index*)node); return visit_index((AST_Index*)node);
case AST_TYPE::Lambda: case AST_TYPE::Lambda:
......
...@@ -1092,6 +1092,20 @@ private: ...@@ -1092,6 +1092,20 @@ private:
return new ConcreteCompilerVariable(SLICE, rtn, true); return new ConcreteCompilerVariable(SLICE, rtn, true);
} }
CompilerVariable* evalExtSlice(AST_ExtSlice* node, UnwindInfo unw_info) {
std::vector<CompilerVariable*> elts;
for (auto* e : node->dims) {
elts.push_back(evalExpr(e, unw_info));
}
// TODO makeTuple should probably just transfer the vref, but I want to keep things consistent
CompilerVariable* rtn = makeTuple(elts);
for (auto* e : elts) {
e->decvref(emitter);
}
return rtn;
}
CompilerVariable* evalStr(AST_Str* node, UnwindInfo unw_info) { CompilerVariable* evalStr(AST_Str* node, UnwindInfo unw_info) {
if (node->str_type == AST_Str::STR) { if (node->str_type == AST_Str::STR) {
llvm::Value* rtn = embedConstantPtr( llvm::Value* rtn = embedConstantPtr(
...@@ -1350,6 +1364,9 @@ private: ...@@ -1350,6 +1364,9 @@ private:
case AST_TYPE::Dict: case AST_TYPE::Dict:
rtn = evalDict(ast_cast<AST_Dict>(node), unw_info); rtn = evalDict(ast_cast<AST_Dict>(node), unw_info);
break; break;
case AST_TYPE::ExtSlice:
rtn = evalExtSlice(ast_cast<AST_ExtSlice>(node), unw_info);
break;
case AST_TYPE::Index: case AST_TYPE::Index:
rtn = evalIndex(ast_cast<AST_Index>(node), unw_info); rtn = evalIndex(ast_cast<AST_Index>(node), unw_info);
break; break;
......
...@@ -403,6 +403,15 @@ AST_Expr* read_expr(BufferedReader* reader) { ...@@ -403,6 +403,15 @@ AST_Expr* read_expr(BufferedReader* reader) {
return rtn; return rtn;
} }
AST_ExtSlice* read_extslice(BufferedReader* reader) {
AST_ExtSlice* rtn = new AST_ExtSlice();
rtn->col_offset = -1;
rtn->lineno = -1;
readExprVector(rtn->dims, reader);
return rtn;
}
AST_For* read_for(BufferedReader* reader) { AST_For* read_for(BufferedReader* reader) {
AST_For* rtn = new AST_For(); AST_For* rtn = new AST_For();
...@@ -798,6 +807,8 @@ AST_expr* readASTExpr(BufferedReader* reader) { ...@@ -798,6 +807,8 @@ AST_expr* readASTExpr(BufferedReader* reader) {
return read_dict(reader); return read_dict(reader);
case AST_TYPE::DictComp: case AST_TYPE::DictComp:
return read_dictcomp(reader); return read_dictcomp(reader);
case AST_TYPE::ExtSlice:
return read_extslice(reader);
case AST_TYPE::GeneratorExp: case AST_TYPE::GeneratorExp:
return read_generatorexp(reader); return read_generatorexp(reader);
case AST_TYPE::IfExp: case AST_TYPE::IfExp:
......
...@@ -285,6 +285,10 @@ private: ...@@ -285,6 +285,10 @@ private:
writeExpr(node->value); writeExpr(node->value);
return true; return true;
} }
virtual bool visit_extslice(AST_ExtSlice* node) {
writeExprVector(node->dims);
return true;
}
virtual bool visit_for(AST_For* node) { virtual bool visit_for(AST_For* node) {
writeStmtVector(node->body); writeStmtVector(node->body);
writeColOffset(node->col_offset); writeColOffset(node->col_offset);
......
...@@ -1997,6 +1997,10 @@ public: ...@@ -1997,6 +1997,10 @@ public:
output->push_back(node); output->push_back(node);
return false; return false;
} }
virtual bool visit_extslice(AST_ExtSlice* node) {
output->push_back(node);
return false;
}
virtual bool visit_for(AST_For* node) { virtual bool visit_for(AST_For* node) {
output->push_back(node); output->push_back(node);
return !expand_scopes; return !expand_scopes;
......
...@@ -863,6 +863,16 @@ private: ...@@ -863,6 +863,16 @@ private:
return rtn; return rtn;
} }
AST_expr* remapExtSlice(AST_ExtSlice* node) {
AST_ExtSlice* rtn = new AST_ExtSlice();
rtn->lineno = node->lineno;
rtn->col_offset = node->col_offset;
for (auto* e : node->dims)
rtn->dims.push_back(remapExpr(e));
return rtn;
}
// This is a helper function used for generators expressions and comprehensions. // This is a helper function used for generators expressions and comprehensions.
// //
// Generates a FunctionDef which produces scope for `node'. The function produced is empty, so you'd better fill it. // Generates a FunctionDef which produces scope for `node'. The function produced is empty, so you'd better fill it.
...@@ -1174,6 +1184,9 @@ private: ...@@ -1174,6 +1184,9 @@ private:
case AST_TYPE::DictComp: case AST_TYPE::DictComp:
rtn = remapScopedComprehension<AST_Dict>(ast_cast<AST_DictComp>(node)); rtn = remapScopedComprehension<AST_Dict>(ast_cast<AST_DictComp>(node));
break; break;
case AST_TYPE::ExtSlice:
rtn = remapExtSlice(ast_cast<AST_ExtSlice>(node));
break;
case AST_TYPE::GeneratorExp: case AST_TYPE::GeneratorExp:
rtn = remapGeneratorExp(ast_cast<AST_GeneratorExp>(node)); rtn = remapGeneratorExp(ast_cast<AST_GeneratorExp>(node));
break; break;
......
...@@ -13,3 +13,7 @@ print sl ...@@ -13,3 +13,7 @@ print sl
sl = slice(1, 2, "hello") sl = slice(1, 2, "hello")
print sl print sl
C()[:,:]
C()[1:2,3:4]
C()[1:2:3,3:4:5]
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment