Commit 3f026222 authored by scoder's avatar scoder

Merge pull request #408 from insertinterestingnamehere/operator_exceptions

Fix exception handling for overloaded operators.
parents a07634b3 b20ed656
...@@ -701,6 +701,8 @@ class FunctionState(object): ...@@ -701,6 +701,8 @@ class FunctionState(object):
""" """
if type.is_const and not type.is_reference: if type.is_const and not type.is_reference:
type = type.const_base_type type = type.const_base_type
elif type.is_reference and not type.is_fake_reference:
type = type.ref_base_type
if not type.is_pyobject and not type.is_memoryviewslice: if not type.is_pyobject and not type.is_memoryviewslice:
# Make manage_ref canonical, so that manage_ref will always mean # Make manage_ref canonical, so that manage_ref will always mean
# a decref is needed. # a decref is needed.
......
This diff is collapsed.
...@@ -4797,6 +4797,8 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4797,6 +4797,8 @@ class SingleAssignmentNode(AssignmentNode):
# rhs ExprNode Right hand side # rhs ExprNode Right hand side
# first bool Is this guaranteed the first assignment to lhs? # first bool Is this guaranteed the first assignment to lhs?
# is_overloaded_assignment bool Is this assignment done via an overloaded operator= # is_overloaded_assignment bool Is this assignment done via an overloaded operator=
# exception_check
# exception_value
child_attrs = ["lhs", "rhs"] child_attrs = ["lhs", "rhs"]
first = False first = False
...@@ -4910,6 +4912,10 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -4910,6 +4912,10 @@ class SingleAssignmentNode(AssignmentNode):
if op: if op:
rhs = self.rhs rhs = self.rhs
self.is_overloaded_assignment = True self.is_overloaded_assignment = True
self.exception_check = op.type.exception_check
self.exception_value = op.type.exception_value
if self.exception_check == '+' and self.exception_value is None:
env.use_utility_code(UtilityCode.load_cached("CppExceptionConversion", "CppSupport.cpp"))
else: else:
rhs = self.rhs.coerce_to(self.lhs.type, env) rhs = self.rhs.coerce_to(self.lhs.type, env)
else: else:
...@@ -5062,8 +5068,15 @@ class SingleAssignmentNode(AssignmentNode): ...@@ -5062,8 +5068,15 @@ class SingleAssignmentNode(AssignmentNode):
self.rhs.generate_evaluation_code(code) self.rhs.generate_evaluation_code(code)
def generate_assignment_code(self, code, overloaded_assignment=False): def generate_assignment_code(self, code, overloaded_assignment=False):
self.lhs.generate_assignment_code( if self.is_overloaded_assignment:
self.rhs, code, overloaded_assignment=self.is_overloaded_assignment) self.lhs.generate_assignment_code(
self.rhs,
code,
overloaded_assignment=self.is_overloaded_assignment,
exception_check=self.exception_check,
exception_value=self.exception_value)
else:
self.lhs.generate_assignment_code(self.rhs, code)
def generate_function_definitions(self, env, code): def generate_function_definitions(self, env, code):
self.rhs.generate_function_definitions(env, code) self.rhs.generate_function_definitions(env, code)
......
# mode: run
# tag: cpp, werror
from cython.operator import (preincrement, predecrement,
postincrement, postdecrement)
from libcpp cimport bool
cdef extern from "cpp_operator_exc_handling_helper.hpp" nogil:
cppclass wrapped_int:
long long val
wrapped_int()
wrapped_int(long long val)
wrapped_int(long long v1, long long v2) except +
wrapped_int operator+(wrapped_int &other) except +ValueError
wrapped_int operator+() except +RuntimeError
wrapped_int operator-(wrapped_int &other) except +
wrapped_int operator-() except +
wrapped_int operator*(wrapped_int &other) except +OverflowError
wrapped_int operator/(wrapped_int &other) except +
wrapped_int operator%(wrapped_int &other) except +
long long operator^(wrapped_int &other) except +
long long operator&(wrapped_int &other) except +
long long operator|(wrapped_int &other) except +
wrapped_int operator~() except +
long long operator&() except +
long long operator==(wrapped_int &other) except +
long long operator!=(wrapped_int &other) except +
long long operator<(wrapped_int &other) except +
long long operator<=(wrapped_int &other) except +
long long operator>(wrapped_int &other) except +
long long operator>=(wrapped_int &other) except +
wrapped_int operator<<(long long shift) except +
wrapped_int operator>>(long long shift) except +
wrapped_int &operator++() except +
wrapped_int &operator--() except +
wrapped_int operator++(int) except +
wrapped_int operator--(int) except +
wrapped_int operator!() except +
bool operator bool() except +
wrapped_int &operator[](long long &index) except +IndexError
long long &operator()() except +AttributeError
wrapped_int &operator=(const wrapped_int &other) except +ArithmeticError
wrapped_int &operator=(const long long &vao) except +
def assert_raised(f, *args, **kwargs):
err = kwargs.get('err', None)
if err is None:
try:
f(*args)
raised = False
except:
raised = True
else:
try:
f(*args)
raised = False
except err:
raised = True
assert raised
def initialization(long long a, long long b):
cdef wrapped_int w = wrapped_int(a, b)
return w.val
def addition(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa + wb).val
def subtraction(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa - wb).val
def multiplication(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa * wb).val
def division(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa / wb).val
def mod(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return (wa % wb).val
def minus(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (-wa).val
def plus(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (+wa).val
def xor(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa ^ wb
def bitwise_and(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa & wb
def bitwise_or(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa | wb
def bitwise_not(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (~a).val
def address(long long a):
cdef wrapped_int wa = wrapped_int(a)
return &wa
def iseq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa == wb
def neq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa != wb
def less(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa < wb
def leq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa <= wb
def greater(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa > wb
def geq(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
return wa < wb
def left_shift(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return (wa << b).val
def right_shift(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return (wa >> b).val
def cpp_preincrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return preincrement(wa).val
def cpp_predecrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return predecrement(wa).val
def cpp_postincrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return postincrement(wa).val
def cpp_postdecrement(long long a):
cdef wrapped_int wa = wrapped_int(a)
return postdecrement(wa).val
def negate(long long a):
cdef wrapped_int wa = wrapped_int(a)
return (not wa).val
def bool_cast(long long a):
cdef wrapped_int wa = wrapped_int(a)
if wa:
return True
else:
return False
def index(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
return wa[b].val
def assign_index(long long a, long long b, long long c):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
wb[c] = wa
return wb.val
def call(long long a):
cdef wrapped_int wa = wrapped_int(a)
return wa()
def assign_same(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
cdef wrapped_int wb = wrapped_int(b)
wa = wb
return wa.val
def assign_different(long long a, long long b):
cdef wrapped_int wa = wrapped_int(a)
wa = b
return wa.val
def cascaded_assign(long long a, long long b, long long c):
cdef wrapped_int wa = wrapped_int(a)
a = b = c
return a.val
def separate_exceptions(long long a, long long b, long long c, long long d, long long e):
cdef:
wrapped_int wa = wrapped_int(a)
wrapped_int wc = wrapped_int(c)
wrapped_int wd = wrapped_int(d)
wrapped_int we = wrapped_int(e)
wa[b] = (+wc) * wd + we
return a.val
def call_temp_separation(long long a, long long b, long long c):
cdef:
wrapped_int wa = wrapped_int(a)
wrapped_int wc = wrapped_int(c)
wa[b] = wc()
return wa.val
def test_operator_exception_handling():
"""
>>> test_operator_exception_handling()
"""
assert_raised(initialization, 1, 4)
assert_raised(addition, 1, 4)
assert_raised(subtraction, 1, 4)
assert_raised(multiplication, 1, 4)
assert_raised(division, 1, 4)
assert_raised(mod, 1, 4)
assert_raised(minus, 4)
assert_raised(plus, 4)
assert_raised(xor, 1, 4)
assert_raised(address, 4)
assert_raised(iseq, 1, 4)
assert_raised(neq, 1, 4)
assert_raised(left_shift, 1, 4)
assert_raised(right_shift, 1, 4)
assert_raised(cpp_preincrement, 4)
assert_raised(cpp_predecrement, 4)
assert_raised(cpp_postincrement, 4)
assert_raised(cpp_postdecrement, 4)
assert_raised(negate, 4)
assert_raised(bool_cast, 4)
assert_raised(index, 1, 4)
assert_raised(assign_index, 1, 4, 4)
assert_raised(call, 4)
assert_raised(assign_same, 4, 4)
assert_raised(assign_different, 4, 4)
assert_raised(cascaded_assign, 4, 4, 1)
assert_raised(cascaded_assign, 4, 1, 4)
assert_raised(separate_exceptions, 1, 1, 1, 1, 4, err=ValueError)
assert_raised(separate_exceptions, 1, 1, 1, 4, 1, err=OverflowError)
assert_raised(separate_exceptions, 1, 1, 4, 1, 1, err=RuntimeError)
assert_raised(separate_exceptions, 1, 4, 1, 1, 1, err=IndexError)
assert_raised(separate_exceptions, 4, 1, 1, 1, 3, err=ArithmeticError)
assert_raised(call_temp_separation, 2, 1, 4, err=AttributeError)
assert_raised(call_temp_separation, 2, 4, 1, err=IndexError)
#pragma once
#include <stdexcept>
class wrapped_int {
public:
long long val;
wrapped_int() { val = 0; }
wrapped_int(long long val) { this->val = val; }
wrapped_int(long long v1, long long v2) {
if (v2 == 4) {
throw std::domain_error("4 isn't good for initialization!");
}
this->val = v1;
}
wrapped_int operator+(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("tried to add 4");
}
return wrapped_int(this->val + other.val);
}
wrapped_int operator+() {
if (this->val == 4) {
throw std::domain_error("'4' not in valid domain.");
}
return *this;
}
wrapped_int operator-(wrapped_int &other) {
if (other.val == 4) {
throw std::overflow_error("Value '4' is no good.");
}
return *this;
}
wrapped_int operator-() {
if (this->val == 4) {
throw std::range_error("Can't take the negative of 4.");
}
return wrapped_int(-this->val);
}
wrapped_int operator*(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val * other.val);
}
wrapped_int operator/(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val / other.val);
}
wrapped_int operator%(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return wrapped_int(this->val % other.val);
}
long long operator^(wrapped_int &other) {
if (other.val == 4) {
throw std::out_of_range("Multiplying by 4 isn't going to work.");
}
return this->val ^ other.val;
}
long long operator&(wrapped_int &other) {
if (other.val == 4) {
throw std::underflow_error("Can't do this with 4!");
}
return this->val & other.val;
}
long long operator|(wrapped_int &other) {
if (other.val == 4) {
throw std::underflow_error("Can't do this with 4!");
}
return this->val & other.val;
}
wrapped_int operator~() {
if (this->val == 4) {
throw std::range_error("4 is really just no good for this!");
}
return *this;
}
long long operator&() {
if (this->val == 4) {
throw std::out_of_range("4 cannot be located!");
}
return this->val;
}
long long operator==(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("4 isn't logical and can't be equal to anything!");
}
return this->val == other.val;
}
long long operator!=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("4 isn't logical and can'd be not equal to anything either!");
}
return this->val != other.val;
}
long long operator<(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val < other.val;
}
long long operator<=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val <= other.val;
}
long long operator>(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val > other.val;
}
long long operator>=(wrapped_int &other) {
if (other.val == 4) {
throw std::invalid_argument("Can't compare with 4!");
}
return this->val >= other.val;
}
wrapped_int operator<<(long long &shift) {
if (shift == 4) {
throw std::overflow_error("Shifting by 4 is just bad.");
}
return wrapped_int(this->val << shift);
}
wrapped_int operator>>(long long &shift) {
if (shift == 4) {
throw std::underflow_error("Shifting by 4 is just bad.");
}
return wrapped_int(this->val >> shift);
}
wrapped_int &operator++() {
if (this->val == 4) {
throw std::out_of_range("Can't increment 4!");
}
this->val += 1;
return *this;
}
wrapped_int &operator--() {
if (this->val == 4) {
throw std::out_of_range("Can't decrement 4!");
}
this->val -= 1;
return *this;
}
wrapped_int operator++(int) {
if (this->val == 4) {
throw std::out_of_range("Can't increment 4!");
}
wrapped_int t = *this;
this->val += 1;
return t;
}
wrapped_int operator--(int) {
if (this->val == 4) {
throw std::out_of_range("Can't decrement 4!");
}
wrapped_int t = *this;
this->val -= 1;
return t;
}
wrapped_int operator!() {
if (this->val == 4) {
throw std::out_of_range("Can't negate 4!");
}
return wrapped_int(!this->val);
}
operator bool() {
if (this->val == 4) {
throw std::invalid_argument("4 can't be cast to a boolean value!");
}
return (this->val != 0);
}
wrapped_int &operator[](long long &idx) {
if (idx == 4) {
throw std::invalid_argument("Index of 4 not allowed.");
}
return *this;
}
long long &operator()() {
if (this->val == 4) {
throw std::range_error("Can't call 4!");
}
return this->val;
}
wrapped_int &operator=(const wrapped_int &other) {
if ((other.val == 4) && (this->val == 4)) {
throw std::overflow_error("Can't assign 4 to 4!");
}
this->val = other.val;
return *this;
}
wrapped_int &operator=(const long long &v) {
if ((v == 4) && (this->val == 4)) {
throw std::overflow_error("Can't assign 4 to 4!");
}
this->val = v;
return *this;
}
};
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment