Commit cf7b26ef authored by Kevin Modzelewski's avatar Kevin Modzelewski

Add the __getitem__ fallback for 'in' comparisons

Also add instance.__contains__
parent 2928b4a5
......@@ -932,6 +932,87 @@ extern "C" PyObject* PySequence_List(PyObject* v) noexcept {
return result;
}
/* Iterate over seq. Result depends on the operation:
PY_ITERSEARCH_COUNT: -1 if error, else # of times obj appears in seq.
PY_ITERSEARCH_INDEX: 0-based index of first occurrence of obj in seq;
set ValueError and return -1 if none found; also return -1 on error.
Py_ITERSEARCH_CONTAINS: return 1 if obj in seq, else 0; -1 on error.
*/
extern "C" Py_ssize_t _PySequence_IterSearch(PyObject* seq, PyObject* obj, int operation) noexcept {
Py_ssize_t n;
int wrapped; /* for PY_ITERSEARCH_INDEX, true iff n wrapped around */
PyObject* it; /* iter(seq) */
if (seq == NULL || obj == NULL) {
null_error();
return -1;
}
it = PyObject_GetIter(seq);
if (it == NULL) {
type_error("argument of type '%.200s' is not iterable", seq);
return -1;
}
n = wrapped = 0;
for (;;) {
int cmp;
PyObject* item = PyIter_Next(it);
if (item == NULL) {
if (PyErr_Occurred())
goto Fail;
break;
}
cmp = PyObject_RichCompareBool(obj, item, Py_EQ);
Py_DECREF(item);
if (cmp < 0)
goto Fail;
if (cmp > 0) {
switch (operation) {
case PY_ITERSEARCH_COUNT:
if (n == PY_SSIZE_T_MAX) {
PyErr_SetString(PyExc_OverflowError, "count exceeds C integer size");
goto Fail;
}
++n;
break;
case PY_ITERSEARCH_INDEX:
if (wrapped) {
PyErr_SetString(PyExc_OverflowError, "index exceeds C integer size");
goto Fail;
}
goto Done;
case PY_ITERSEARCH_CONTAINS:
n = 1;
goto Done;
default:
assert(!"unknown operation");
}
}
if (operation == PY_ITERSEARCH_INDEX) {
if (n == PY_SSIZE_T_MAX)
wrapped = 1;
++n;
}
}
if (operation != PY_ITERSEARCH_INDEX)
goto Done;
PyErr_SetString(PyExc_ValueError, "sequence.index(x): x not in sequence");
/* fall into failure code */
Fail:
n = -1;
/* fall through */
Done:
Py_DECREF(it);
return n;
}
extern "C" PyObject* PyObject_CallFunction(PyObject* callable, const char* format, ...) noexcept {
va_list va;
PyObject* args;
......
......@@ -490,7 +490,8 @@ extern "C" PyObject* PyIter_Next(PyObject* iter) noexcept {
return callattr(iter, &next_str, CallattrFlags({.cls_only = true, .null_on_nonexistent = false }),
ArgPassSpec(0), NULL, NULL, NULL, NULL, NULL);
} catch (ExcInfo e) {
setCAPIException(e);
if (!e.matches(StopIteration))
setCAPIException(e);
return NULL;
}
}
......
......@@ -16,6 +16,7 @@
#include <sstream>
#include "capi/types.h"
#include "core/types.h"
#include "gc/collector.h"
#include "runtime/objmodel.h"
......@@ -277,6 +278,24 @@ Box* instanceDelitem(Box* _inst, Box* key) {
return runtimeCall(delitem_func, ArgPassSpec(1), key, NULL, NULL, NULL, NULL);
}
Box* instanceContains(Box* _inst, Box* key) {
RELEASE_ASSERT(_inst->cls == instance_cls, "");
BoxedInstance* inst = static_cast<BoxedInstance*>(_inst);
Box* contains_func = _instanceGetattribute(inst, boxStrConstant("__contains__"), false);
if (!contains_func) {
int result = _PySequence_IterSearch(inst, key, PY_ITERSEARCH_CONTAINS);
if (result < 0)
throwCAPIException();
assert(result == 0 || result == 1);
return boxBool(result);
}
Box* r = runtimeCall(contains_func, ArgPassSpec(1), key, NULL, NULL, NULL, NULL);
return boxBool(nonzero(r));
}
void setupClassobj() {
classobj_cls = BoxedHeapClass::create(type_cls, object_cls, &BoxedClassobj::gcHandler,
offsetof(BoxedClassobj, attrs), 0, sizeof(BoxedClassobj), false, "classobj");
......@@ -304,6 +323,7 @@ void setupClassobj() {
instance_cls->giveAttr("__getitem__", new BoxedFunction(boxRTFunction((void*)instanceGetitem, UNKNOWN, 2)));
instance_cls->giveAttr("__setitem__", new BoxedFunction(boxRTFunction((void*)instanceSetitem, UNKNOWN, 3)));
instance_cls->giveAttr("__delitem__", new BoxedFunction(boxRTFunction((void*)instanceDelitem, UNKNOWN, 2)));
instance_cls->giveAttr("__contains__", new BoxedFunction(boxRTFunction((void*)instanceContains, UNKNOWN, 2)));
instance_cls->freeze();
}
......
......@@ -75,6 +75,12 @@ Box* seqiterNext(Box* s) {
RELEASE_ASSERT(s->cls == seqiter_cls || s->cls == seqreviter_cls, "");
BoxedSeqIter* self = static_cast<BoxedSeqIter*>(s);
if (!self->next) {
Box* hasnext = seqiterHasnext(s);
if (hasnext == False)
raiseExcHelper(StopIteration, "");
}
RELEASE_ASSERT(self->next, "");
Box* r = self->next;
self->next = NULL;
......
......@@ -3206,17 +3206,11 @@ Box* compareInternal(Box* lhs, Box* rhs, int op_type, CompareRewriteArgs* rewrit
Box* contained = callattrInternal1(rhs, &contains_str, CLASS_ONLY, NULL, ArgPassSpec(1), lhs);
if (contained == NULL) {
Box* iter = callattrInternal0(rhs, &iter_str, CLASS_ONLY, NULL, ArgPassSpec(0));
if (iter)
ASSERT(isUserDefined(rhs->cls), "%s should probably have a __contains__", getTypeName(rhs));
RELEASE_ASSERT(iter == NULL, "need to try iterating");
Box* getitem = typeLookup(rhs->cls, getitem_str, NULL);
if (getitem)
ASSERT(isUserDefined(rhs->cls), "%s should probably have a __contains__", getTypeName(rhs));
RELEASE_ASSERT(getitem == NULL, "need to try old iteration protocol");
raiseExcHelper(TypeError, "argument of type '%s' is not iterable", getTypeName(rhs));
int result = _PySequence_IterSearch(rhs, lhs, PY_ITERSEARCH_CONTAINS);
if (result < 0)
throwCAPIException();
assert(result == 0 || result == 1);
return boxBool(result);
}
bool b = nonzero(contained);
......
......@@ -147,8 +147,6 @@ static const char* objectNewParameterTypeErrorMsg() {
}
}
bool exceptionMatches(const ExcInfo& e, BoxedClass* cls);
// This function will ascii-encode any unicode objects it gets passed, or return the argument
// unmodified if it wasn't a unicode object.
// This is intended for functions that deal with attribute or variable names, which we internally
......
......@@ -21,3 +21,42 @@ for i in xrange(1, 4):
print i in (1, 2, 5)
class D(object):
def __getitem__(self, i):
print i
if i < 10:
return i ** 2
raise IndexError()
d = D()
print 5 in d
print 15 in d
print 25 in d
class D():
def __getitem__(self, i):
print i
if i < 10:
return i ** 2
raise IndexError()
d = D()
print 5 in d
print 15 in d
print 25 in d
class F():
def __init__(self):
self.n = 0
def __iter__(self):
return self
def next(self):
if self.n >= 10:
raise StopIteration()
self.n += 1
return self.n ** 2
f = F()
print 5 in f
print 15 in f
print 25 in f
# expected: fail
# - exceptions
class D(object):
def __getitem__(self, idx):
print "getitem", idx
if idx >= 20:
raise IndexError()
return idx
print 10 in D()
print 1000 in D()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment