Commit c0ceabe1 authored by msg555's avatar msg555 Committed by GitHub

Update GetItem to support __class_getitem__ for type objects (GH-3518)

Closes #2753.
parent b692adc1
...@@ -39,6 +39,9 @@ else: ...@@ -39,6 +39,9 @@ else:
_py_int_types = (int, long) _py_int_types = (int, long)
IMPLICIT_CLASSMETHODS = {"__init_subclass__", "__class_getitem__"}
def relative_position(pos): def relative_position(pos):
return (pos[0].get_filenametable_entry(), pos[1]) return (pos[0].get_filenametable_entry(), pos[1])
...@@ -2438,7 +2441,7 @@ class CFuncDefNode(FuncDefNode): ...@@ -2438,7 +2441,7 @@ class CFuncDefNode(FuncDefNode):
py_func_body = self.call_self_node(is_module_scope=env.is_module_scope) py_func_body = self.call_self_node(is_module_scope=env.is_module_scope)
if self.is_static_method: if self.is_static_method:
from .ExprNodes import NameNode from .ExprNodes import NameNode
decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name='staticmethod'))] decorators = [DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('staticmethod')))]
decorators[0].decorator.analyse_types(env) decorators[0].decorator.analyse_types(env)
else: else:
decorators = [] decorators = []
...@@ -2883,7 +2886,12 @@ class DefNode(FuncDefNode): ...@@ -2883,7 +2886,12 @@ class DefNode(FuncDefNode):
self.is_staticmethod = False self.is_staticmethod = False
if self.name == '__new__' and env.is_py_class_scope: if self.name == '__new__' and env.is_py_class_scope:
self.is_staticmethod = 1 self.is_staticmethod = True
if not self.is_classmethod and self.name in IMPLICIT_CLASSMETHODS and env.is_py_class_scope:
from .ExprNodes import NameNode
self.decorators = self.decorators or []
self.decorators.insert(0, DecoratorNode(self.pos, decorator=NameNode(self.pos, name=EncodedString('classmethod'))))
self.is_classmethod = True
self.analyse_argument_types(env) self.analyse_argument_types(env)
if self.name == '<lambda>': if self.name == '<lambda>':
......
...@@ -273,24 +273,21 @@ static CYTHON_INLINE int __Pyx_IterFinish(void) { ...@@ -273,24 +273,21 @@ static CYTHON_INLINE int __Pyx_IterFinish(void) {
/////////////// ObjectGetItem.proto /////////////// /////////////// ObjectGetItem.proto ///////////////
#if CYTHON_USE_TYPE_SLOTS #if CYTHON_USE_TYPE_SLOTS
static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key);/*proto*/ static CYTHON_INLINE PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key);/*proto*/
#else #else
#define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key) #define __Pyx_PyObject_GetItem(obj, key) PyObject_GetItem(obj, key)
#endif #endif
/////////////// ObjectGetItem /////////////// /////////////// ObjectGetItem ///////////////
// //@requires: GetItemInt - added in IndexNode as it uses templating. // //@requires: GetItemInt - added in IndexNode as it uses templating.
//@requires: PyObjectGetAttrStrNoError
//@requires: PyObjectCallOneArg
#if CYTHON_USE_TYPE_SLOTS #if CYTHON_USE_TYPE_SLOTS
static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) { static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject *index) {
// Get element from sequence object `obj` at index `index`.
PyObject *runerr; PyObject *runerr;
Py_ssize_t key_value; Py_ssize_t key_value;
PySequenceMethods *m = Py_TYPE(obj)->tp_as_sequence;
if (unlikely(!(m && m->sq_item))) {
PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name);
return NULL;
}
key_value = __Pyx_PyIndex_AsSsize_t(index); key_value = __Pyx_PyIndex_AsSsize_t(index);
if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) { if (likely(key_value != -1 || !(runerr = PyErr_Occurred()))) {
return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1); return __Pyx_GetItemInt_Fast(obj, key_value, 0, 1, 1);
...@@ -304,12 +301,34 @@ static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) { ...@@ -304,12 +301,34 @@ static PyObject *__Pyx_PyObject_GetIndex(PyObject *obj, PyObject* index) {
return NULL; return NULL;
} }
static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject* key) { static PyObject *__Pyx_PyObject_GetItem_Slow(PyObject *obj, PyObject *key) {
PyMappingMethods *m = Py_TYPE(obj)->tp_as_mapping; // Handles less common slow-path checks for GetItem
if (likely(m && m->mp_subscript)) { if (likely(PyType_Check(obj))) {
return m->mp_subscript(obj, key); PyObject *meth = __Pyx_PyObject_GetAttrStrNoError(obj, PYIDENT("__class_getitem__"));
if (meth) {
PyObject *result = __Pyx_PyObject_CallOneArg(meth, key);
Py_DECREF(meth);
return result;
}
}
PyErr_Format(PyExc_TypeError, "'%.200s' object is not subscriptable", Py_TYPE(obj)->tp_name);
return NULL;
}
static PyObject *__Pyx_PyObject_GetItem(PyObject *obj, PyObject *key) {
PyTypeObject *tp = Py_TYPE(obj);
PyMappingMethods *mm = tp->tp_as_mapping;
if (likely(mm && mm->mp_subscript)) {
return mm->mp_subscript(obj, key);
}
PySequenceMethods *sm = tp->tp_as_sequence;
if (likely(sm && sm->sq_item)) {
return __Pyx_PyObject_GetIndex(obj, key);
} }
return __Pyx_PyObject_GetIndex(obj, key);
return __Pyx_PyObject_GetItem_Slow(obj, key);
} }
#endif #endif
......
# mode: run
# tag: pure3.7
# cython: language_level=3
# COPIED FROM CPython 3.7
import unittest
import sys
class TestClassGetitem(unittest.TestCase):
# BEGIN - Additional tests from cython
def test_no_class_getitem(self):
class C: ...
with self.assertRaises(TypeError):
C[int]
# END - Additional tests from cython
def test_class_getitem(self):
getitem_args = []
class C:
def __class_getitem__(*args, **kwargs):
getitem_args.extend([args, kwargs])
return None
C[int, str]
self.assertEqual(getitem_args[0], (C, (int, str)))
self.assertEqual(getitem_args[1], {})
def test_class_getitem_format(self):
class C:
def __class_getitem__(cls, item):
return f'C[{item.__name__}]'
self.assertEqual(C[int], 'C[int]')
self.assertEqual(C[C], 'C[C]')
def test_class_getitem_inheritance(self):
class C:
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_inheritance_2(self):
class C:
def __class_getitem__(cls, item):
return 'Should not see this'
class D(C):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_classmethod(self):
class C:
@classmethod
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
@unittest.skipIf(sys.version_info < (3, 6), "__init_subclass__() requires Py3.6+ (PEP 487)")
def test_class_getitem_patched(self):
class C:
def __init_subclass__(cls):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
cls.__class_getitem__ = classmethod(__class_getitem__)
class D(C): ...
self.assertEqual(D[int], 'D[int]')
self.assertEqual(D[D], 'D[D]')
def test_class_getitem_with_builtins(self):
class A(dict):
called_with = None
def __class_getitem__(cls, item):
cls.called_with = item
class B(A):
pass
self.assertIs(B.called_with, None)
B[int]
self.assertIs(B.called_with, int)
def test_class_getitem_errors(self):
class C_too_few:
def __class_getitem__(cls):
return None
with self.assertRaises(TypeError):
C_too_few[int]
class C_too_many:
def __class_getitem__(cls, one, two):
return None
with self.assertRaises(TypeError):
C_too_many[int]
def test_class_getitem_errors_2(self):
class C:
def __class_getitem__(cls, item):
return None
with self.assertRaises(TypeError):
C()[int]
class E: ...
e = E()
e.__class_getitem__ = lambda cls, item: 'This will not work'
with self.assertRaises(TypeError):
e[int]
class C_not_callable:
__class_getitem__ = "Surprise!"
with self.assertRaises(TypeError):
C_not_callable[int]
def test_class_getitem_metaclass(self):
class Meta(type):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(Meta[int], 'Meta[int]')
def test_class_getitem_with_metaclass(self):
class Meta(type): pass
class C(metaclass=Meta):
def __class_getitem__(cls, item):
return f'{cls.__name__}[{item.__name__}]'
self.assertEqual(C[int], 'C[int]')
def test_class_getitem_metaclass_first(self):
class Meta(type):
def __getitem__(cls, item):
return 'from metaclass'
class C(metaclass=Meta):
def __class_getitem__(cls, item):
return 'from __class_getitem__'
self.assertEqual(C[int], 'from metaclass')
if __name__ == '__main__':
unittest.main()
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment