Commit 102366d1 authored by da-woods's avatar da-woods Committed by GitHub

Implement cdef dataclasses (GH-3400)

New decorator/function "@cython.dataclasses.dataclass" and "cython.dataclasses.field()" to mark dataclasses and their fields.

Tries to match the interface provided by a regular dataclass as much as possible.
This means taking the types from the dataclasses module if available (so they match exactly) or a fallback Python version that just implements the core parts (executed with "PyRun_String()" in the C source).

Use of placeholders in generated "__init__" code means the code in the C file isn't hugely readable. Probably not a huge issue, but don't really see a way round that.

As part of this I've also implemented a Cython version of "typing.ClassVar". Although really designed for use with dataclasses it behaves sensibly when used in types in a normal cdef class. This is worth documenting more thoroughly.

Closes https://github.com/cython/cython/issues/2903
parent 07f45205
......@@ -4,7 +4,8 @@
from __future__ import absolute_import
from .Symtab import BuiltinScope, StructOrUnionScope
from .StringEncoding import EncodedString
from .Symtab import BuiltinScope, StructOrUnionScope, ModuleScope
from .Code import UtilityCode
from .TypeSlots import Signature
from . import PyrexTypes
......@@ -451,3 +452,56 @@ def init_builtins():
init_builtins()
##############################
# Support for a few standard library modules that Cython understands (currently typing and dataclasses)
##############################
_known_module_scopes = {}
def get_known_standard_library_module_scope(module_name):
    """
    Return a synthetic ModuleScope describing the parts of a standard library
    module that Cython has special knowledge of (currently "typing" and
    "dataclasses"), or None for any other module name.

    Scopes are built lazily and cached in _known_module_scopes, so repeated
    calls for the same module return the same scope object.
    """
    mod = _known_module_scopes.get(module_name)
    if mod:
        return mod
    if module_name == "typing":
        mod = ModuleScope(module_name, None, None)
        # Generic containers that map directly onto a concrete builtin type.
        for name, tp in [
                ('Dict', dict_type),
                ('List', list_type),
                ('Tuple', tuple_type),
                ('Set', set_type),
                ('FrozenSet', frozenset_type),
                ]:
            name = EncodedString(name)
            if name == "Tuple":
                # Tuple gets its own constructor class (its subscript can
                # describe heterogeneous per-element types).
                indexed_type = PyrexTypes.PythonTupleTypeConstructor(EncodedString("typing."+name), tp)
            else:
                indexed_type = PyrexTypes.PythonTypeConstructor(EncodedString("typing."+name), tp)
            entry = mod.declare_type(name, indexed_type, pos = None)
        # These two don't correspond to a concrete runtime type; they modify
        # the interpretation of the type they wrap.
        for name in ['ClassVar', 'Optional']:
            indexed_type = PyrexTypes.SpecialPythonTypeConstructor(EncodedString("typing."+name))
            entry = mod.declare_type(name, indexed_type, pos = None)
        _known_module_scopes[module_name] = mod
    elif module_name == "dataclasses":
        mod = ModuleScope(module_name, None, None)
        # InitVar is the only dataclasses name modelled as a type here.
        indexed_type = PyrexTypes.SpecialPythonTypeConstructor(EncodedString("dataclasses.InitVar"))
        entry = mod.declare_type(EncodedString("InitVar"), indexed_type, pos = None)
        _known_module_scopes[module_name] = mod
    return mod
def get_known_standard_library_entry(qualified_name):
    """
    Look up a dotted name such as "typing.ClassVar" among the standard
    library modules Cython understands.

    Returns the matching entry, or None when the module is unknown, the
    attribute is absent, or the name refers into a nested module.
    """
    module_name, _, attribute = qualified_name.partition(".")
    module_name = EncodedString(module_name)
    if "." in attribute:
        # for now, we don't know how to deal with any nested modules
        return None
    scope = get_known_standard_library_module_scope(module_name)
    # eventually handle more sophisticated multiple lookups if needed
    if scope and attribute:
        return scope.lookup_here(attribute)
    return None
......@@ -6,6 +6,7 @@ from .UtilityCode import CythonUtilityCode
from .Errors import error
from .Scanning import StringSourceDescriptor
from . import MemoryView
from .StringEncoding import EncodedString
class CythonScope(ModuleScope):
......@@ -135,9 +136,16 @@ class CythonScope(ModuleScope):
for ext_type in ext_types:
ext_type.is_cython_builtin_type = 1
# self.entries["array"] = view_utility_scope.entries.pop("array")
# dataclasses scope
dc_str = EncodedString(u'dataclasses')
dataclassesscope = ModuleScope(dc_str, self, context=None)
self.declare_module(dc_str, dataclassesscope, pos=None).as_module = dataclassesscope
dataclassesscope.is_cython_builtin = True
dataclassesscope.pxd_file_loaded = True
# doesn't actually have any contents
def create_cython_scope(context):
# One could in fact probably make it a singleton,
......
# functions to transform a c class into a dataclass
from collections import OrderedDict
from textwrap import dedent
import operator
from . import ExprNodes
from . import Nodes
from . import PyrexTypes
from . import UtilNodes
from . import Builtin
from . import Naming
from .Errors import error, warning
from .Code import UtilityCode, TempitaUtilityCode
from .Visitor import VisitorTransform
from .StringEncoding import BytesLiteral, EncodedString
from .TreeFragment import TreeFragment
from .ParseTreeTransforms import (NormalizeTree, SkipDeclarations, AnalyseDeclarationsTransform,
MarkClosureVisitor)
from .Options import copy_inherited_directives
_dataclass_loader_utilitycode = None
def make_dataclasses_module_callnode(pos):
    """
    Create an expression node that evaluates to the "dataclasses" module.

    The C helper "__Pyx_Load_dataclasses_Module" it calls embeds a bundled
    fallback implementation of the module as a C string (see the commit
    notes: executed with PyRun_String when the real module is unavailable).
    """
    global _dataclass_loader_utilitycode
    if not _dataclass_loader_utilitycode:
        # Built lazily and cached: the utility code is identical for every
        # use site in a compilation.
        python_utility_code = UtilityCode.load_cached("Dataclasses_fallback", "Dataclasses.py")
        python_utility_code = EncodedString(python_utility_code.impl)
        _dataclass_loader_utilitycode = TempitaUtilityCode.load(
            "SpecificModuleLoader", "Dataclasses.c",
            context={'cname': "dataclasses", 'py_code': python_utility_code.as_c_string_literal()})
    return ExprNodes.PythonCapiCallNode(
        pos, "__Pyx_Load_dataclasses_Module",
        PyrexTypes.CFuncType(PyrexTypes.py_object_type, []),
        utility_code=_dataclass_loader_utilitycode,
        args=[],
    )
_INTERNAL_DEFAULTSHOLDER_NAME = EncodedString('__pyx_dataclass_defaults')
class RemoveAssignmentsToNames(VisitorTransform, SkipDeclarations):
    """
    Cython (and Python) normally treats

        class A:
            x = 1

    as generating a class attribute. However for dataclasses the `= 1` should be interpreted as
    a default value to initialize an instance attribute with.
    This transform therefore removes the `x=1` assignment so that the class attribute isn't
    generated, while recording what it has removed so that it can be used in the initialization.
    """
    def __init__(self, names):
        # names: the attribute names whose class-level assignments should be
        # stripped out and recorded as defaults.
        super(RemoveAssignmentsToNames, self).__init__()
        self.names = names
        # name -> the removed right-hand-side expression node
        self.removed_assignments = {}

    def visit_CClassNode(self, node):
        # The dataclass body itself - recurse into its statements.
        self.visitchildren(node)
        return node

    def visit_PyClassNode(self, node):
        # Assignments inside a nested Python class are real class attributes.
        return node  # go no further

    def visit_FuncDefNode(self, node):
        # Assignments inside methods are ordinary assignments, not defaults.
        return node  # go no further

    def visit_SingleAssignmentNode(self, node):
        if node.lhs.is_name and node.lhs.name in self.names:
            if node.lhs.name in self.removed_assignments:
                warning(node.pos, ("Multiple assignments for '%s' in dataclass; "
                                   "using most recent") % node.lhs.name, 1)
            self.removed_assignments[node.lhs.name] = node.rhs
            # Returning [] deletes this statement from the tree.
            return []
        return node

    # I believe cascaded assignment is always a syntax error with annotations
    # so there's no need to define visit_CascadedAssignmentNode

    def visit_Node(self, node):
        self.visitchildren(node)
        return node
class _MISSING_TYPE(object):
    # Sentinel type, mirroring dataclasses._MISSING_TYPE: distinguishes
    # "no value supplied" from an explicit value of None.
    pass

# Singleton sentinel instance; always tested with "is"/"is not".
MISSING = _MISSING_TYPE()
class Field(object):
    """
    Field is based on the dataclasses.field class from the standard library module.
    It is used internally during the generation of Cython dataclasses to keep track
    of the settings for individual attributes.

    Attributes of this class are stored as nodes so they can be used in code construction
    more readily (i.e. we store BoolNode rather than bool)
    """
    # Class-level sentinels: only shadowed on the instance when the user
    # actually supplied a value, so "is MISSING" checks keep working.
    default = MISSING
    default_factory = MISSING
    # Set to True afterwards for attributes with "private" visibility.
    private = False

    # Keyword arguments that must be supplied as compile-time literal nodes.
    literal_keys = ("repr", "hash", "init", "compare", "metadata")

    # default values are defined by the CPython dataclasses.field
    def __init__(self, pos, default=MISSING, default_factory=MISSING,
                 repr=None, hash=None, init=None,
                 compare=None, metadata=None,
                 is_initvar=False, is_classvar=False,
                 **additional_kwds):
        # pos is used as the position for any default nodes synthesized here.
        if default is not MISSING:
            self.default = default
        if default_factory is not MISSING:
            self.default_factory = default_factory
        # Fall back to the dataclasses-module defaults, as literal nodes.
        self.repr = repr or ExprNodes.BoolNode(pos, value=True)
        self.hash = hash or ExprNodes.NoneNode(pos)
        self.init = init or ExprNodes.BoolNode(pos, value=True)
        self.compare = compare or ExprNodes.BoolNode(pos, value=True)
        self.metadata = metadata or ExprNodes.NoneNode(pos)
        self.is_initvar = is_initvar
        self.is_classvar = is_classvar

        for k, v in additional_kwds.items():
            # There should not be any additional keywords!
            error(v.pos, "cython.dataclasses.field() got an unexpected keyword argument '%s'" % k)

        for field_name in self.literal_keys:
            field_value = getattr(self, field_name)
            if not field_value.is_literal:
                error(field_value.pos,
                      "cython.dataclasses.field parameter '%s' must be a literal value" % field_name)

    def iterate_record_node_arguments(self):
        # Yield (name, node) pairs for everything that should be recorded on
        # the runtime Field object; MISSING values are skipped entirely.
        for key in (self.literal_keys + ('default', 'default_factory')):
            value = getattr(self, key)
            if value is not MISSING:
                yield key, value
def process_class_get_fields(node):
    """
    Collect the dataclass fields of cdef class "node" (including inherited
    ones), strip class-level default-value assignments from the class body,
    and return an ordered mapping name -> Field.

    Side effects: removes the stripped assignments from the tree, resolves
    InitVar/ClassVar wrapper types on the entries, and stores the result on
    node.entry.type.dataclass_fields for use by subclasses.
    """
    var_entries = node.scope.var_entries
    # order of definition is used in the dataclass
    var_entries = sorted(var_entries, key=operator.attrgetter('pos'))
    var_names = [entry.name for entry in var_entries]

    # don't treat `x = 1` as an assignment of a class attribute within the dataclass
    transform = RemoveAssignmentsToNames(var_names)
    transform(node)
    default_value_assignments = transform.removed_assignments

    if node.base_type and node.base_type.dataclass_fields:
        # Inherit the base dataclass's fields; ours are appended/override.
        fields = node.base_type.dataclass_fields.copy()
    else:
        fields = OrderedDict()

    for entry in var_entries:
        name = entry.name
        is_initvar = (entry.type.python_type_constructor_name == "dataclasses.InitVar")
        # TODO - classvars aren't included in "var_entries" so are missed here
        # and thus this code is never triggered
        is_classvar = (entry.type.python_type_constructor_name == "typing.ClassVar")
        if is_initvar or is_classvar:
            entry.type = entry.type.resolve()  # no longer need the special type
        if name in default_value_assignments:
            assignment = default_value_assignments[name]
            if (isinstance(assignment, ExprNodes.CallNode)
                    and assignment.function.as_cython_attribute() == "dataclasses.field"):
                # I believe most of this is well-enforced when it's treated as a directive
                # but it doesn't hurt to make sure
                if (not isinstance(assignment, ExprNodes.GeneralCallNode)
                        or not isinstance(assignment.positional_args, ExprNodes.TupleNode)
                        or assignment.positional_args.args
                        or not isinstance(assignment.keyword_args, ExprNodes.DictNode)):
                    error(assignment.pos, "Call to 'cython.dataclasses.field' must only consist "
                          "of compile-time keyword arguments")
                    continue
                keyword_args = assignment.keyword_args.as_python_dict()
                if 'default' in keyword_args and 'default_factory' in keyword_args:
                    error(assignment.pos, "cannot specify both default and default_factory")
                    continue
                field = Field(node.pos, **keyword_args)
            else:
                if isinstance(assignment, ExprNodes.CallNode):
                    func = assignment.function
                    # Heuristic: a call to something named "field" that wasn't
                    # recognised as cython.dataclasses.field is probably a typo.
                    if ((func.is_name and func.name == "field")
                            or (func.is_attribute and func.attribute == "field")):
                        warning(assignment.pos, "Do you mean cython.dataclasses.field instead?", 1)
                if assignment.type in [Builtin.list_type, Builtin.dict_type, Builtin.set_type]:
                    # The standard library module generates a TypeError at runtime
                    # in this situation.
                    # Error message is copied from CPython
                    error(assignment.pos, "mutable default <class '{0}'> for field {1} is not allowed: "
                          "use default_factory".format(assignment.type.name, name))
                field = Field(node.pos, default=assignment)
        else:
            field = Field(node.pos)
        field.is_initvar = is_initvar
        field.is_classvar = is_classvar
        if entry.visibility == "private":
            field.private = True
        fields[name] = field
    node.entry.type.dataclass_fields = fields
    return fields
def handle_cclass_dataclass(node, dataclass_args, analyse_decs_transform):
    """
    Turn the cdef class "node" into a dataclass: build the
    __dataclass_params__ and __dataclass_fields__ attributes, generate the
    requested special methods (__init__, __repr__, comparisons, __hash__),
    and analyse the synthesized statements in the class scope.

    node -- the CClassDefNode being decorated
    dataclass_args -- (positional_args, keyword_args) passed to the
        @cython.dataclasses.dataclass decorator, or None if it was used bare
    analyse_decs_transform -- the AnalyseDeclarationsTransform instance used
        to analyse the freshly generated code
    """
    # default argument values from https://docs.python.org/3/library/dataclasses.html
    kwargs = dict(init=True, repr=True, eq=True,
                  order=False, unsafe_hash=False, frozen=False)
    if dataclass_args is not None:
        if dataclass_args[0]:
            error(node.pos, "cython.dataclasses.dataclass takes no positional arguments")
        for k, v in dataclass_args[1].items():
            if k not in kwargs:
                error(node.pos,
                      "cython.dataclasses.dataclass() got an unexpected keyword argument '%s'" % k)
            if not isinstance(v, ExprNodes.BoolNode):
                error(node.pos,
                      "Arguments passed to cython.dataclasses.dataclass must be True or False")
            kwargs[k] = v

    fields = process_class_get_fields(node)

    dataclass_module = make_dataclasses_module_callnode(node.pos)

    # create __dataclass_params__ attribute. I try to use the exact
    # `_DataclassParams` class defined in the standard library module if at all possible
    # for maximum duck-typing compatibility.
    dataclass_params_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                                    attribute=EncodedString("_DataclassParams"))
    dataclass_params_keywords = ExprNodes.DictNode.from_pairs(
        node.pos,
        [ (ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
           ExprNodes.BoolNode(node.pos, value=v))
          for k, v in kwargs.items() ])
    dataclass_params = ExprNodes.GeneralCallNode(node.pos,
                                                 function = dataclass_params_func,
                                                 positional_args = ExprNodes.TupleNode(node.pos, args=[]),
                                                 keyword_args = dataclass_params_keywords)
    dataclass_params_assignment = Nodes.SingleAssignmentNode(
        node.pos,
        lhs = ExprNodes.NameNode(node.pos, name=EncodedString("__dataclass_params__")),
        rhs = dataclass_params)

    dataclass_fields_stats = _set_up_dataclass_fields(node, fields, dataclass_module)

    stats = Nodes.StatListNode(node.pos,
                               stats=[dataclass_params_assignment] + dataclass_fields_stats)

    # Collect the generated method sources, their TreeFragment placeholders
    # and any extra statement nodes, then parse/substitute them all at once.
    code_lines = []
    placeholders = {}
    extra_stats = []
    for cl, ph, es in [ generate_init_code(kwargs['init'], node, fields),
                        generate_repr_code(kwargs['repr'], node, fields),
                        generate_eq_code(kwargs['eq'], node, fields),
                        generate_order_code(kwargs['order'], node, fields),
                        generate_hash_code(kwargs['unsafe_hash'], kwargs['eq'], kwargs['frozen'], node, fields) ]:
        code_lines.append(cl)
        placeholders.update(ph)
        # BUG FIX: this previously read "extra_stats.extend(extra_stats)",
        # which extended the list with itself and discarded "es", silently
        # dropping every statement returned by the generators (e.g. the
        # "__hash__ = None" assignment).
        extra_stats.extend(es)

    code_lines = "\n".join(code_lines)
    code_tree = TreeFragment(code_lines, level='c_class', pipeline=[NormalizeTree(node.scope)]
                             ).substitute(placeholders)

    stats.stats += (code_tree.stats + extra_stats)

    # turn off annotation typing, so all arguments to __init__ are accepted as
    # generic objects and thus can accept _HAS_DEFAULT_FACTORY.
    # Type conversion comes later
    comp_directives = Nodes.CompilerDirectivesNode(node.pos,
        directives=copy_inherited_directives(node.scope.directives, annotation_typing=False),
        body=stats)

    comp_directives.analyse_declarations(node.scope)
    # probably already in this scope, but it doesn't hurt to make sure
    analyse_decs_transform.enter_scope(node, node.scope)
    analyse_decs_transform.visit(comp_directives)
    analyse_decs_transform.exit_scope()

    node.body.stats.append(comp_directives)
def generate_init_code(init, node, fields):
    """
    All of these "generate_*_code" functions return a tuple of:
    - code string
    - placeholder dict (often empty)
    - stat list (often empty)
    which can then be combined later and processed once.

    Notes on CPython generated "__init__":
    * Implemented in `_init_fn`.
    * The use of the `dataclasses._HAS_DEFAULT_FACTORY` sentinel value as
      the default argument for fields that need constructing with a factory
      function is copied from the CPython implementation. (`None` isn't
      suitable because it could also be a value for the user to pass.)
      There's no real reason why it needs importing from the dataclasses module
      though - it could equally be a value generated by Cython when the module loads.
    * seen_default and the associated error message are copied directly from Python
    * Call to user-defined __post_init__ function (if it exists) is copied from
      CPython.
    """
    if not init or node.scope.lookup_here("__init__"):
        # __init__ generation disabled, or the user wrote their own.
        return "", {}, []
    # selfname behaviour copied from the cpython module
    selfname = "__dataclass_self__" if "self" in fields else "self"
    args = [selfname]

    placeholders = {}
    # Wrapped in a list so the nested function can mutate it
    # (Python 2 compatible substitute for "nonlocal").
    placeholder_count = [0]

    # create a temp to get _HAS_DEFAULT_FACTORY
    dataclass_module = make_dataclasses_module_callnode(node.pos)
    has_default_factory = ExprNodes.AttributeNode(
        node.pos,
        obj=dataclass_module,
        attribute=EncodedString("_HAS_DEFAULT_FACTORY")
    )

    def get_placeholder_name():
        # Produce a fresh TreeFragment placeholder name.
        while True:
            name = "INIT_PLACEHOLDER_%d" % placeholder_count[0]
            if (name not in placeholders
                    and name not in fields):
                # make sure name isn't already used and doesn't
                # conflict with a variable name (which is unlikely but possible)
                break
            placeholder_count[0] += 1
        return name

    # One shared sentinel default for all default_factory arguments.
    default_factory_placeholder = get_placeholder_name()
    placeholders[default_factory_placeholder] = has_default_factory

    function_body_code_lines = []

    seen_default = False
    for name, field in fields.items():
        if not field.init.value:
            # Field excluded from __init__ (init=False).
            continue
        entry = node.scope.lookup(name)
        if entry.annotation:
            # Reproduce the user's annotation on the __init__ argument.
            annotation = u": %s" % entry.annotation.string.value
        else:
            annotation = u""
        assignment = u''
        if field.default is not MISSING or field.default_factory is not MISSING:
            seen_default = True
            if field.default_factory is not MISSING:
                ph_name = default_factory_placeholder
            else:
                ph_name = get_placeholder_name()
                placeholders[ph_name] = field.default  # should be a node
            assignment = u" = %s" % ph_name
        elif seen_default:
            error(entry.pos, ("non-default argument '%s' follows default argument "
                              "in dataclass __init__") % name)
            return "", {}, []

        args.append(u"%s%s%s" % (name, annotation, assignment))

        if field.is_initvar:
            # InitVars are only passed on to __post_init__, never stored.
            continue
        elif field.default_factory is MISSING:
            if field.init.value:
                function_body_code_lines.append(u" %s.%s = %s" % (selfname, name, name))
        else:
            ph_name = get_placeholder_name()
            placeholders[ph_name] = field.default_factory
            if field.init.value:
                # close to:
                # def __init__(self, name=_PLACEHOLDER_VALUE):
                #     self.name = name_default_factory() if name is _PLACEHOLDER_VALUE else name
                function_body_code_lines.append(u" %s.%s = %s() if %s is %s else %s" % (
                    selfname, name, ph_name, name, default_factory_placeholder, name))
            else:
                # still need to use the default factory to initialize
                function_body_code_lines.append(u" %s.%s = %s()"
                                                % (selfname, name, ph_name))

    args = u", ".join(args)
    func_def = u"def __init__(%s):" % args

    code_lines = [func_def] + (function_body_code_lines or ["pass"])

    if node.scope.lookup("__post_init__"):
        post_init_vars = ", ".join(name for name, field in fields.items()
                                   if field.is_initvar)
        code_lines.append(" %s.__post_init__(%s)" % (selfname, post_init_vars))
    return u"\n".join(code_lines), placeholders, []
def generate_repr_code(repr, node, fields):
    """
    Build the source text of __repr__ for the dataclass, returning the usual
    (code, placeholders, stats) triple.

    CPython's implementation is essentially:

        ['return self.__class__.__qualname__ + f"(' +
         ', '.join([f"{f.name}={{self.{f.name}!r}}"
                    for f in fields]) +
         ')"'],

    The notable difference here is self.__class__.__qualname__ ->
    type(self).__name__ (via getattr), because Cython currently supports
    Python 2, which has no __qualname__.
    """
    if not repr or node.scope.lookup("__repr__"):
        # Generation switched off, or a user-defined __repr__ wins.
        return "", {}, []
    attr_formats = []
    for name, field in fields.items():
        if field.repr.value and not field.is_initvar:
            attr_formats.append(u"%s={self.%s!r}" % (name, name))
    body = u", ".join(attr_formats)
    lines = [
        "def __repr__(self):",
        u' name = getattr(type(self), "__qualname__", type(self).__name__)',
        u" return f'{name}(%s)'" % body,
    ]
    return u"\n".join(lines), {}, []
def generate_cmp_code(op, funcname, node, fields):
    """
    Build the source of one rich-comparison method (e.g. op="==",
    funcname="__eq__") comparing all fields marked compare=True.
    Returns the usual (code, placeholders, stats) triple; empty strings/
    containers when the method already exists or nothing is comparable.
    """
    if node.scope.lookup_here(funcname):
        # User already defined this comparison - leave it alone.
        return "", {}, []
    comparable = [n for n, f in fields.items()
                  if (f.compare.value and not f.is_initvar)]
    if not comparable:
        return "", {}, []  # no comparable types
    cls = node.class_name
    lines = [
        "def %s(self, other):" % funcname,
        " cdef %s other_cast" % cls,
        " if isinstance(other, %s):" % cls,
        " other_cast = <%s>other" % cls,
        " else:",
        " return NotImplemented"
    ]
    # The Python implementation of dataclasses.py does a tuple comparison
    # (roughly):
    #    return self._attributes_to_tuple() {op} other._attributes_to_tuple()
    #
    # For the Cython implementation a tuple comparison isn't an option because
    # not all attributes can be converted to Python objects and stored in a tuple
    #
    # TODO - better diagnostics of whether the types support comparison before
    #  generating the code. Plus, do we want to convert C structs to dicts and
    #  compare them that way (I think not, but it might be in demand)?
    attr_tests = ["(self.%s %s other_cast.%s)" % (n, op, n) for n in comparable]
    if attr_tests:
        lines.append(" return " + " and ".join(attr_tests))
    else:
        # Unreachable in practice ("comparable" is non-empty here), but kept
        # to mirror the tuple-comparison semantics.
        if "=" in op:
            lines.append(" return True")  # "() == ()" is True
        else:
            lines.append(" return False")
    return u"\n".join(lines), {}, []
def generate_eq_code(eq, node, fields):
    """
    Generate __eq__ (when eq=True), returning the standard
    (code, placeholders, stats) triple like the other generate_*_code
    functions.
    """
    if not eq:
        # BUG FIX: this previously was "return code_lines, {}, []" with
        # "code_lines" never defined, raising NameError for any dataclass
        # declared with eq=False.
        return "", {}, []
    return generate_cmp_code("==", "__eq__", node, fields)
def generate_order_code(order, node, fields):
    """
    Generate the four ordering comparisons (__lt__, __le__, __gt__, __ge__)
    when order=True, merging their individual (code, placeholders, stats)
    results into a single triple.
    """
    if not order:
        return "", {}, []
    all_code = []
    all_placeholders = {}
    all_stats = []
    for cmp_op, method_name in (("<", "__lt__"),
                                ("<=", "__le__"),
                                (">", "__gt__"),
                                (">=", "__ge__")):
        code, ph, stats = generate_cmp_code(cmp_op, method_name, node, fields)
        all_code.append(code)
        all_placeholders.update(ph)
        all_stats.extend(stats)
    return "\n".join(all_code), all_placeholders, all_stats
def generate_hash_code(unsafe_hash, eq, frozen, node, fields):
    """
    Copied from CPython implementation - the intention is to follow this as far as
    is possible:
    #    +------------------- unsafe_hash= parameter
    #    |       +----------- eq= parameter
    #    |       |       +--- frozen= parameter
    #    |       |       |
    #    v       v       v    |        |        |
    #                         |   no   |  yes   |  <--- class has explicitly defined __hash__
    # +=======+=======+=======+========+========+
    # | False | False | False |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | False | True  |        |        | No __eq__, use the base class __hash__
    # +-------+-------+-------+--------+--------+
    # | False | True  | False | None   |        | <-- the default, not hashable
    # +-------+-------+-------+--------+--------+
    # | False | True  | True  | add    |        | Frozen, so hashable, allows override
    # +-------+-------+-------+--------+--------+
    # | True  | False | False | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | False | True  | add    | raise  | Has no __eq__, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | False | add    | raise  | Not frozen, but hashable
    # +-------+-------+-------+--------+--------+
    # | True  | True  | True  | add    | raise  | Frozen, so hashable
    # +=======+=======+=======+========+========+
    # For boxes that are blank, __hash__ is untouched and therefore
    # inherited from the base class.  If the base is object, then
    # id-based hashing is used.

    The Python implementation creates a tuple of all the fields, then hashes them.
    This implementation creates a tuple of all the hashes of all the fields and hashes that.
    The reason for this slight difference is to avoid to-Python conversions for anything
    that Cython knows how to hash directly (It doesn't look like this currently applies to
    anything though...).
    """
    hash_entry = node.scope.lookup_here("__hash__")
    if hash_entry:
        # TODO ideally assignment of __hash__ to None shouldn't trigger this
        # but difficult to get the right information here
        if unsafe_hash:
            # error message taken from CPython dataclasses module
            error(node.pos, "Cannot overwrite attribute __hash__ in class %s" % node.class_name)
        return "", {}, []
    if not unsafe_hash:
        if not eq:
            # Nothing to do - the base class __hash__ is inherited.
            # BUG FIX: this was a bare "return" (i.e. None), which the caller
            # would then fail to unpack into (code, placeholders, stats).
            return "", {}, []
        if not frozen:
            # The default case: set __hash__ = None so instances are unhashable.
            return "", {}, [Nodes.SingleAssignmentNode(
                node.pos,
                lhs=ExprNodes.NameNode(node.pos, name=EncodedString("__hash__")),
                rhs=ExprNodes.NoneNode(node.pos),
            )]

    names = [
        name for name, field in fields.items()
        if (not field.is_initvar and
            (field.compare.value if field.hash.value is None else field.hash.value))
    ]
    if not names:
        return "", {}, []  # nothing to hash

    # make a tuple of the hashes
    tpl = u", ".join(u"hash(self.%s)" % name for name in names)

    # if we're here we want to generate a hash
    code_lines = dedent(u"""\
        def __hash__(self):
            return hash((%s))
        """) % tpl

    return code_lines, {}, []
def get_field_type(pos, entry):
    """
    Return a node describing the declared type of a field, for use as the
    runtime Field.type value.  The user's annotation is returned when one
    exists (matching what the dataclasses module stores); otherwise a string
    rendering of the C type is produced as a fallback.
    """
    if entry.annotation:
        # Right now it doesn't look like cdef classes generate an
        # __annotations__ dict, therefore it's safe to just return
        # entry.annotation
        # (TODO: remove .string if we ditch PEP563)
        # If they do in future then we may need to look up into that
        # to duplicating the node. The code below should do this:
        #  class_name_node = ExprNodes.NameNode(pos, name=entry.scope.name)
        #  annotations = ExprNodes.AttributeNode(
        #      pos, obj=class_name_node,
        #      attribute=EncodedString("__annotations__")
        #  )
        #  return ExprNodes.IndexNode(
        #      pos, base=annotations,
        #      index=ExprNodes.StringNode(pos, value=entry.name)
        #  )
        return entry.annotation.string
    # it's slightly unclear what the best option is here - we could
    # try to return PyType_Type. This case should only happen with
    # attributes defined with cdef so Cython is free to make it's own
    # decision
    type_as_string = entry.type.declaration_code("", for_display=1)
    return ExprNodes.StringNode(pos, value=type_as_string)
class FieldRecordNode(ExprNodes.ExprNode):
    """
    __dataclass_fields__ contains a bunch of field objects recording how each field
    of the dataclass was initialized (mainly corresponding to the arguments passed to
    the "field" function). This node is used for the attributes of these field objects.

    If possible, coerces `arg` to a Python object.
    Otherwise, generates a sensible backup string.
    """
    # Single child expression, handled by the standard ExprNode machinery.
    subexprs = ['arg']

    def __init__(self, pos, arg):
        super(FieldRecordNode, self).__init__(pos, arg=arg)

    def analyse_types(self, env):
        # NOTE(review): the result of arg.analyse_types is not re-assigned
        # back to self.arg here - presumably analysis of these nodes is
        # in-place/idempotent in this context; confirm.
        self.arg.analyse_types(env)
        self.type = self.arg.type
        return self

    def coerce_to_pyobject(self, env):
        if self.arg.type.can_coerce_to_pyobject(env):
            return self.arg.coerce_to_pyobject(env)
        else:
            # A string representation of the code that gave the field seems like a reasonable
            # fallback. This'll mostly happen for "default" and "default_factory" where the
            # type may be a C-type that can't be converted to Python.
            return self._make_string()

    def _make_string(self):
        # Render the wrapped expression as source text and wrap it in a
        # StringNode; used when no Python-object coercion exists.
        from .AutoDocTransforms import AnnotationWriter
        writer = AnnotationWriter(description="Dataclass field")
        string = writer.write(self.arg)
        return ExprNodes.StringNode(self.pos, value=EncodedString(string))

    def generate_evaluation_code(self, code):
        # Delegate entirely to the wrapped expression.
        return self.arg.generate_evaluation_code(code)
def _set_up_dataclass_fields(node, fields, dataclass_module):
    """
    Build the statements that create the class's __dataclass_fields__ dict
    (one runtime Field object per non-private field) and hoist non-trivial
    default/default_factory expressions into module-level constants.
    Returns the list of statement nodes to prepend to the class body.
    """
    # For defaults and default_factories containing things like lambda,
    # they're already declared in the class scope, and it creates a big
    # problem if multiple copies are floating around in both the __init__
    # function, and in the __dataclass_fields__ structure.
    # Therefore, create module-level constants holding these values and
    # pass those around instead
    #
    # If possible we use the `Field` class defined in the standard library
    # module so that the information stored here is as close to a regular
    # dataclass as is possible.
    variables_assignment_stats = []
    for name, field in fields.items():
        if field.private:
            continue  # doesn't appear in the public interface
        for attrname in [ "default", "default_factory" ]:
            field_default = getattr(field, attrname)
            if field_default is MISSING or field_default.is_literal or field_default.is_name:
                # some simple cases where we don't need to set up
                # the variable as a module-level constant
                continue
            global_scope = node.scope.global_scope()
            # Double-mangle: once for the class, once for the field name.
            module_field_name = global_scope.mangle(
                global_scope.mangle(Naming.dataclass_field_default_cname, node.class_name),
                name)
            # create an entry in the global scope for this variable to live
            field_node = ExprNodes.NameNode(field_default.pos, name=EncodedString(module_field_name))
            field_node.entry = global_scope.declare_var(field_node.name, type=field_default.type or PyrexTypes.unspecified_type,
                                                        pos=field_default.pos, cname=field_node.name, is_cdef=1)
            # replace the field so that future users just receive the namenode
            setattr(field, attrname, field_node)
            variables_assignment_stats.append(
                Nodes.SingleAssignmentNode(field_default.pos, lhs=field_node, rhs=field_default))

    placeholders = {}
    field_func = ExprNodes.AttributeNode(node.pos, obj=dataclass_module,
                                         attribute=EncodedString("field"))
    dc_fields = ExprNodes.DictNode(node.pos, key_value_pairs=[])
    dc_fields_namevalue_assignments = []

    for name, field in fields.items():
        if field.private:
            continue  # doesn't appear in the public interface
        # Placeholder for the field's declared type (annotation or fallback).
        type_placeholder_name = "PLACEHOLDER_%s" % name
        placeholders[type_placeholder_name] = get_field_type(
            node.pos, node.scope.entries[name]
        )

        # defining these make the fields introspect more like a Python dataclass
        field_type_placeholder_name = "PLACEHOLDER_FIELD_TYPE_%s" % name
        if field.is_initvar:
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD_INITVAR")
            )
        elif field.is_classvar:
            # TODO - currently this isn't triggered
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD_CLASSVAR")
            )
        else:
            placeholders[field_type_placeholder_name] = ExprNodes.AttributeNode(
                node.pos, obj=dataclass_module,
                attribute=EncodedString("_FIELD")
            )

        # Build the call "field(repr=..., init=..., default=..., ...)" that
        # creates the runtime Field object for this attribute.
        dc_field_keywords = ExprNodes.DictNode.from_pairs(
            node.pos,
            [(ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(k)),
              FieldRecordNode(node.pos, arg=v))
             for k, v in field.iterate_record_node_arguments()]
        )
        dc_field_call = ExprNodes.GeneralCallNode(
            node.pos, function = field_func,
            positional_args = ExprNodes.TupleNode(node.pos, args=[]),
            keyword_args = dc_field_keywords)
        dc_fields.key_value_pairs.append(
            ExprNodes.DictItemNode(
                node.pos,
                key=ExprNodes.IdentifierStringNode(node.pos, value=EncodedString(name)),
                value=dc_field_call))
        # name/type/_field_type can't be passed to field(); patch them onto
        # the created Field objects afterwards (placeholders filled in below).
        dc_fields_namevalue_assignments.append(
            dedent(u"""\
                __dataclass_fields__[{0!r}].name = {0!r}
                __dataclass_fields__[{0!r}].type = {1}
                __dataclass_fields__[{0!r}]._field_type = {2}
            """).format(name, type_placeholder_name, field_type_placeholder_name))

    dataclass_fields_assignment = \
        Nodes.SingleAssignmentNode(node.pos,
                                   lhs = ExprNodes.NameNode(node.pos,
                                                            name=EncodedString("__dataclass_fields__")),
                                   rhs = dc_fields)

    dc_fields_namevalue_assignments = u"\n".join(dc_fields_namevalue_assignments)
    dc_fields_namevalue_assignments = TreeFragment(dc_fields_namevalue_assignments,
                                                   level="c_class",
                                                   pipeline=[NormalizeTree(None)])
    dc_fields_namevalue_assignments = dc_fields_namevalue_assignments.substitute(placeholders)

    return (variables_assignment_stats
            + [dataclass_fields_assignment]
            + dc_fields_namevalue_assignments.stats)
......@@ -1183,6 +1183,15 @@ class ExprNode(Node):
kwargs[attr_name] = value
return cls(node.pos, **kwargs)
def get_known_standard_library_import(self):
    """
    Return the "module.path" this node was imported from.

    The base implementation returns None: most expression nodes either were
    not imported at all, or their origin is ambiguous - any false value
    signals that to callers.
    """
    return None
class AtomicExprNode(ExprNode):
# Abstract base class for expression nodes which have
......@@ -2038,13 +2047,25 @@ class NameNode(AtomicExprNode):
"'%s' cannot be specialized since its type is not a fused argument to this function" %
self.name)
atype = error_type
visibility = 'private'
if 'dataclasses.dataclass' in env.directives:
# handle "frozen" directive - full inspection of the dataclass directives happens
# in Dataclass.py
frozen_directive = None
dataclass_directive = env.directives['dataclasses.dataclass']
if dataclass_directive:
dataclass_directive_kwds = dataclass_directive[1]
frozen_directive = dataclass_directive_kwds.get('frozen', None)
is_frozen = frozen_directive and frozen_directive.is_literal and frozen_directive.value
if atype.is_pyobject or atype.can_coerce_to_pyobject(env):
visibility = 'readonly' if is_frozen else 'public'
# If the object can't be coerced that's fine - we just don't create a property
if as_target and env.is_c_class_scope and not (atype.is_pyobject or atype.is_error):
# TODO: this will need revising slightly if either cdef dataclasses or
# annotated cdef attributes are implemented
# TODO: this will need revising slightly if annotated cdef attributes are implemented
atype = py_object_type
warning(annotation.pos, "Annotation ignored since class-level attributes must be Python objects. "
"Were you trying to set up an instance attribute?", 2)
entry = self.entry = env.declare_var(name, atype, self.pos, is_cdef=not as_target)
entry = self.entry = env.declare_var(name, atype, self.pos, is_cdef=not as_target, visibility=visibility)
# Even if the entry already exists, make sure we're supplying an annotation if we can.
if annotation and not entry.annotation:
entry.annotation = annotation
......@@ -2057,6 +2078,10 @@ class NameNode(AtomicExprNode):
entry = env.lookup(self.name)
if entry and entry.as_module:
return entry.as_module
if entry and entry.known_standard_library_import:
scope = Builtin.get_known_standard_library_module_scope(entry.known_standard_library_import)
if scope and scope.is_module_scope:
return scope
return None
def analyse_as_type(self, env):
......@@ -2071,6 +2096,10 @@ class NameNode(AtomicExprNode):
entry = env.lookup(self.name)
if entry and entry.is_type:
return entry.type
elif entry and entry.known_standard_library_import:
entry = Builtin.get_known_standard_library_entry(entry.known_standard_library_import)
if entry and entry.is_type:
return entry.type
else:
return None
......@@ -2098,9 +2127,14 @@ class NameNode(AtomicExprNode):
self.entry = env.lookup_assignment_expression_target(self.name)
else:
self.entry = env.lookup_here(self.name)
if self.entry:
self.entry.known_standard_library_import = "" # already exists somewhere and so is now ambiguous
if not self.entry and self.annotation is not None:
# name : type = ...
self.declare_from_annotation(env, as_target=True)
is_dataclass = 'dataclasses.dataclass' in env.directives
# In a dataclass, an assignment should not prevent a name from becoming an instance attribute.
# Hence, "as_target = not is_dataclass".
self.declare_from_annotation(env, as_target=not is_dataclass)
if not self.entry:
if env.directives['warn.undeclared']:
warning(self.pos, "implicit declaration of '%s'" % self.name, 1)
......@@ -2609,6 +2643,11 @@ class NameNode(AtomicExprNode):
style, text = 'c_call', 'c function (%s)'
code.annotate(pos, AnnotationItem(style, text % self.type, size=len(self.name)))
def get_known_standard_library_import(self):
    """Report the standard library path this name is known to refer to, if any.

    Returns the entry's recorded "module.attribute" string, "" when the
    name has become ambiguous, or None when nothing is known or there is
    no entry at all.
    """
    entry = self.entry
    return entry.known_standard_library_import if entry else None
class BackquoteNode(ExprNode):
# `expr`
#
......@@ -2718,6 +2757,9 @@ class ImportNode(ExprNode):
code.error_goto_if_null(self.result(), self.pos)))
self.generate_gotref(code)
def get_known_standard_library_import(self):
    """An import node always knows exactly which module it names."""
    module_name_node = self.module_name
    return module_name_node.value
class IteratorNode(ExprNode):
# Used as part of for statement implementation.
......@@ -3630,9 +3672,9 @@ class IndexNode(_IndexingBaseNode):
def analyse_as_type(self, env):
base_type = self.base.analyse_as_type(env)
if base_type and not base_type.is_pyobject:
if base_type.is_cpp_class:
if isinstance(self.index, TupleNode):
if base_type and (not base_type.is_pyobject or base_type.python_type_constructor_name):
if base_type.is_cpp_class or base_type.python_type_constructor_name:
if self.index.is_sequence_constructor:
template_values = self.index.args
else:
template_values = [self.index]
......@@ -7489,6 +7531,12 @@ class AttributeNode(ExprNode):
style, text = 'c_attr', 'c attribute (%s)'
code.annotate(self.pos, AnnotationItem(style, text % self.type, size=len(self.attribute)))
def get_known_standard_library_import(self):
    """Extend the object's known stdlib module path with this attribute.

    E.g. if ``self.obj`` is known to be the "typing" module, then the
    attribute access reports "typing.<attribute>".
    """
    base_path = self.obj.get_known_standard_library_import()
    if not base_path:
        return None
    return StringEncoding.EncodedString("%s.%s" % (base_path, self.attribute))
#-------------------------------------------------------------------
#
......@@ -9029,6 +9077,11 @@ class DictNode(ExprNode):
for item in self.key_value_pairs:
item.annotate(code)
def as_python_dict(self):
    """Return a plain dict mapping constant key values to their value Nodes.

    Only valid on DictNodes whose keys are ConstNodes or PyConstNode, i.e.
    nodes that expose a compile-time ``.value``; the values stay as Nodes.
    """
    # A dict comprehension avoids building a throwaway intermediate list
    # (the old form was dict([(k, v) for ...])).
    return {key.value: value for key, value in self.key_value_pairs}
class DictItemNode(ExprNode):
# Represents a single item in a DictNode
......@@ -9859,6 +9912,9 @@ class LambdaNode(InnerFunctionNode):
name = StringEncoding.EncodedString('<lambda>')
def analyse_declarations(self, env):
if hasattr(self, "lambda_name"):
# this if-statement makes it safe to run twice
return
self.lambda_name = self.def_node.lambda_name = env.next_id('lambda')
self.def_node.no_assignment_synthesis = True
self.def_node.pymethdef_required = True
......@@ -9888,6 +9944,9 @@ class GeneratorExpressionNode(LambdaNode):
binding = False
def analyse_declarations(self, env):
if hasattr(self, "genexpr_name"):
# this if-statement makes it safe to run twice
return
self.genexpr_name = env.next_id('genexpr')
super(GeneratorExpressionNode, self).analyse_declarations(env)
# No pymethdef required
......
......@@ -135,6 +135,10 @@ type_dict_guard_temp = pyrex_prefix + "typedict_guard"
cython_runtime_cname = pyrex_prefix + "cython_runtime"
cyfunction_type_cname = pyrex_prefix + "CyFunctionType"
fusedfunction_type_cname = pyrex_prefix + "FusedFunctionType"
# the name "dflt" was picked by analogy with the CPython dataclass module which stores
# the default values in variables named f"_dflt_{field.name}" in a hidden scope that's
# passed to the __init__ function. (The name is unimportant to the exact workings though)
dataclass_field_default_cname = pyrex_prefix + "dataclass_dflt"
global_code_object_cache_find = pyrex_prefix + 'find_code_object'
global_code_object_cache_insert = pyrex_prefix + 'insert_code_object'
......
......@@ -585,7 +585,9 @@ class CArrayDeclaratorNode(CDeclaratorNode):
child_attrs = ["base", "dimension"]
def analyse(self, base_type, env, nonempty=0, visibility=None, in_pxd=False):
if (base_type.is_cpp_class and base_type.is_template_type()) or base_type.is_cfunction:
if ((base_type.is_cpp_class and base_type.is_template_type()) or
base_type.is_cfunction or
base_type.python_type_constructor_name):
from .ExprNodes import TupleNode
if isinstance(self.dimension, TupleNode):
args = self.dimension.args
......@@ -597,7 +599,7 @@ class CArrayDeclaratorNode(CDeclaratorNode):
error(args[ix].pos, "Template parameter not a type")
base_type = error_type
else:
base_type = base_type.specialize_here(self.pos, values)
base_type = base_type.specialize_here(self.pos, env, values)
return self.base.analyse(base_type, env, nonempty=nonempty, visibility=visibility, in_pxd=in_pxd)
if self.dimension:
self.dimension = self.dimension.analyse_const_expression(env)
......@@ -963,6 +965,11 @@ class CArgDeclNode(Node):
base_type, arg_type = annotation.analyse_type_annotation(env, assigned_value=self.default)
if base_type is not None:
self.base_type = base_type
if arg_type and arg_type.python_type_constructor_name == "typing.Optional":
self.or_none = True
arg_type = arg_type.resolve()
if arg_type and arg_type.is_pyobject and not self.or_none:
self.not_none = True
return arg_type
def calculate_default_value_code(self, code):
......@@ -1064,7 +1071,13 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
else:
scope = None
break
if scope is None and len(self.module_path) == 1:
# (may be possible to handle longer module paths?)
# TODO: probably not the best place to declare it?
from .Builtin import get_known_standard_library_module_scope
found_entry = env.lookup(self.module_path[0])
if found_entry and found_entry.known_standard_library_import:
scope = get_known_standard_library_module_scope(found_entry.known_standard_library_import)
if scope is None:
# Maybe it's a cimport.
scope = env.find_imported_module(self.module_path, self.pos)
......@@ -1189,20 +1202,23 @@ class TemplatedTypeNode(CBaseTypeNode):
base_type = self.base_type_node.analyse(env)
if base_type.is_error: return base_type
if base_type.is_cpp_class and base_type.is_template_type():
if ((base_type.is_cpp_class and base_type.is_template_type()) or
base_type.python_type_constructor_name):
# Templated class
if self.keyword_args and self.keyword_args.key_value_pairs:
error(self.pos, "c++ templates cannot take keyword arguments")
tp = "c++ templates" if base_type.is_cpp_class else "indexed types"
error(self.pos, "%s cannot take keyword arguments" % tp)
self.type = PyrexTypes.error_type
else:
template_types = []
for template_node in self.positional_args:
type = template_node.analyse_as_type(env)
if type is None:
if type is None and base_type.is_cpp_class:
error(template_node.pos, "unknown type in template argument")
type = error_type
# for indexed_pytype we can be a bit more flexible and pass None
template_types.append(type)
self.type = base_type.specialize_here(self.pos, template_types)
self.type = base_type.specialize_here(self.pos, env, template_types)
elif base_type.is_pyobject:
# Buffer
......@@ -5066,6 +5082,7 @@ class CClassDefNode(ClassDefNode):
check_size = None
decorators = None
shadow = False
is_dataclass = False
@property
def punycode_class_name(self):
......@@ -5115,6 +5132,8 @@ class CClassDefNode(ClassDefNode):
if env.in_cinclude and not self.objstruct_name:
error(self.pos, "Object struct name specification required for C class defined in 'extern from' block")
if "dataclasses.dataclass" in env.directives:
self.is_dataclass = True
if self.decorators:
error(self.pos, "Decorators not allowed on cdef classes (used on type '%s')" % self.class_name)
self.base_type = None
......@@ -5846,6 +5865,13 @@ class SingleAssignmentNode(AssignmentNode):
self.lhs.analyse_assignment_expression_target_declaration(env)
else:
self.lhs.analyse_target_declaration(env)
# if an entry doesn't exist that just implies that lhs isn't made up purely
# of AttributeNodes and NameNodes - it isn't useful as a known path to
# a standard library module
if (self.lhs.is_attribute or self.lhs.is_name) and self.lhs.entry and not self.lhs.entry.known_standard_library_import:
stdlib_import_name = self.rhs.get_known_standard_library_import()
if stdlib_import_name:
self.lhs.entry.known_standard_library_import = stdlib_import_name
def analyse_types(self, env, use_temp=0):
from . import ExprNodes
......@@ -8548,7 +8574,8 @@ class CImportStatNode(StatNode):
env.declare_module(top_name, top_module_scope, self.pos)
else:
name = self.as_name or self.module_name
env.declare_module(name, module_scope, self.pos)
entry = env.declare_module(name, module_scope, self.pos)
entry.known_standard_library_import = self.module_name
if self.module_name in utility_code_for_cimports:
env.use_utility_code(utility_code_for_cimports[self.module_name]())
......@@ -8662,6 +8689,14 @@ class FromImportStatNode(StatNode):
self.import_star = 1
else:
target.analyse_target_declaration(env)
if target.entry:
if target.get_known_standard_library_import() is None:
target.entry.known_standard_library_import = EncodedString(
"%s.%s" % (self.module.module_name.value, name))
else:
# it isn't unambiguous
target.entry.known_standard_library_import = ""
def analyse_expressions(self, env):
from . import ExprNodes
......
......@@ -195,16 +195,6 @@ class IterationTransform(Visitor.EnvTransform):
annotation = iterable.entry.annotation.expr
if annotation.is_subscript:
annotation = annotation.base # container base type
# FIXME: generalise annotation evaluation => maybe provide a "qualified name" also for imported names?
if annotation.is_name:
if annotation.entry and annotation.entry.qualified_name == 'typing.Dict':
annotation_type = Builtin.dict_type
elif annotation.name == 'Dict':
annotation_type = Builtin.dict_type
if annotation.entry and annotation.entry.qualified_name in ('typing.Set', 'typing.FrozenSet'):
annotation_type = Builtin.set_type
elif annotation.name in ('Set', 'FrozenSet'):
annotation_type = Builtin.set_type
if Builtin.dict_type in (iterable.type, annotation_type):
# like iterating over dict.keys()
......
......@@ -302,6 +302,12 @@ def normalise_encoding_name(option_name, encoding):
return name
return encoding
# used as a sentinel value to defer analysis of the arguments
# instead of analysing them in InterpretCompilerDirectives. The dataclass directives are quite
# complicated and it's easier to deal with them at the point the dataclass is created
class DEFER_ANALYSIS_OF_ARGUMENTS:
    """Sentinel directive "type" requesting that the directive's arguments be
    passed through unanalysed (they are interpreted later, at the point the
    dataclass is actually created)."""
# Replace the class with a singleton instance so it can be compared with "is".
DEFER_ANALYSIS_OF_ARGUMENTS = DEFER_ANALYSIS_OF_ARGUMENTS()
# Override types possibilities above, if needed
directive_types = {
......@@ -328,6 +334,8 @@ directive_types = {
'c_string_encoding': normalise_encoding_name,
'trashcan': bool,
'total_ordering': bool,
'dataclasses.dataclass': DEFER_ANALYSIS_OF_ARGUMENTS,
'dataclasses.field': DEFER_ANALYSIS_OF_ARGUMENTS,
}
for key, val in _directive_defaults.items():
......@@ -372,6 +380,7 @@ directive_scopes = { # defaults to available everywhere
'iterable_coroutine': ('module', 'function'),
'trashcan' : ('cclass',),
'total_ordering': ('cclass', ),
'dataclasses.dataclass' : ('class', 'cclass',),
'cpp_locals': ('module', 'function', 'cclass'), # I don't think they make sense in a with_statement
}
......
......@@ -954,7 +954,6 @@ class InterpretCompilerDirectives(CythonTransform):
for pos, name, as_name, kind in node.imported_names:
full_name = submodule + name
qualified_name = u"cython." + full_name
if self.is_parallel_directive(qualified_name, node.pos):
# from cython cimport parallel, or
# from cython.parallel cimport parallel, prange, ...
......@@ -964,6 +963,10 @@ class InterpretCompilerDirectives(CythonTransform):
if kind is not None:
self.context.nonfatal_error(PostParseError(pos,
"Compiler directive imports must be plain imports"))
elif full_name in ['dataclasses', 'typing']:
self.directive_names[as_name or name] = full_name
# unlike many directives, still treat it as a regular module
newimp.append((pos, name, as_name, kind))
else:
newimp.append((pos, name, as_name, kind))
......@@ -1105,7 +1108,7 @@ class InterpretCompilerDirectives(CythonTransform):
if directivetype is bool:
arg = ExprNodes.BoolNode(node.pos, value=True)
return [self.try_to_parse_directive(optname, [arg], None, node.pos)]
elif directivetype is None:
elif directivetype is None or directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS:
return [(optname, None)]
else:
raise PostParseError(
......@@ -1160,7 +1163,7 @@ class InterpretCompilerDirectives(CythonTransform):
if len(args) != 0:
raise PostParseError(pos,
'The %s directive takes no prepositional arguments' % optname)
return optname, dict([(key.value, value) for key, value in kwds.key_value_pairs])
return optname, kwds.as_python_dict()
elif directivetype is list:
if kwds and len(kwds.key_value_pairs) != 0:
raise PostParseError(pos,
......@@ -1172,6 +1175,9 @@ class InterpretCompilerDirectives(CythonTransform):
raise PostParseError(pos,
'The %s directive takes one compile-time string argument' % optname)
return (optname, directivetype(optname, str(args[0].value)))
elif directivetype is Options.DEFER_ANALYSIS_OF_ARGUMENTS:
# signal to pass things on without processing
return (optname, (args, kwds.as_python_dict()))
else:
assert False
......@@ -1239,7 +1245,8 @@ class InterpretCompilerDirectives(CythonTransform):
name, value = directive
if self.directives.get(name, object()) != value:
directives.append(directive)
if directive[0] == 'staticmethod':
if (directive[0] == 'staticmethod' or
(directive[0] == 'dataclasses.dataclass' and scope_name == 'class')):
both.append(dec)
# Adapt scope type based on decorators that change it.
if directive[0] == 'cclass' and scope_name == 'class':
......@@ -1248,6 +1255,12 @@ class InterpretCompilerDirectives(CythonTransform):
realdecs.append(dec)
if realdecs and (scope_name == 'cclass' or
isinstance(node, (Nodes.CClassDefNode, Nodes.CVarDefNode))):
for realdec in realdecs:
realdec = realdec.decorator
if ((realdec.is_name and realdec.name == "dataclass") or
(realdec.is_attribute and realdec.attribute == "dataclass")):
error(realdec.pos,
"Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass")
# Note - arbitrary C function decorators are caught later in DecoratorTransform
raise PostParseError(realdecs[0].pos, "Cdef functions/classes cannot take arbitrary decorators.")
node.decorators = realdecs[::-1] + both[::-1]
......@@ -1906,6 +1919,9 @@ if VALUE is not None:
def visit_CClassDefNode(self, node):
node = self.visit_ClassDefNode(node)
if node.scope and 'dataclasses.dataclass' in node.scope.directives:
from .Dataclass import handle_cclass_dataclass
handle_cclass_dataclass(node, node.scope.directives['dataclasses.dataclass'], self)
if node.scope and node.scope.implemented and node.body:
stats = []
for entry in node.scope.var_entries:
......
......@@ -194,6 +194,7 @@ class PyrexType(BaseType):
# is_string boolean Is a C char * type
# is_pyunicode_ptr boolean Is a C PyUNICODE * type
# is_cpp_string boolean Is a C++ std::string type
# python_type_constructor_name string or None non-None if it is a Python type constructor that can be indexed/"templated"
# is_unicode_char boolean Is either Py_UCS4 or Py_UNICODE
# is_returncode boolean Is used only to signal exceptions
# is_error boolean Is the dummy error type
......@@ -257,6 +258,7 @@ class PyrexType(BaseType):
is_struct_or_union = 0
is_cpp_class = 0
is_optional_cpp_class = 0
python_type_constructor_name = None
is_cpp_string = 0
is_struct = 0
is_enum = 0
......@@ -1507,12 +1509,14 @@ class PyExtensionType(PyObjectType):
# early_init boolean Whether to initialize early (as opposed to during module execution).
# defered_declarations [thunk] Used to declare class hierarchies in order
# check_size 'warn', 'error', 'ignore' What to do if tp_basicsize does not match
# dataclass_fields OrderedDict or None Used for inheriting from dataclasses
is_extension_type = 1
has_attributes = 1
early_init = 1
objtypedef_cname = None
dataclass_fields = None
def __init__(self, name, typedef_flag, base_type, is_external=0, check_size=None):
self.name = name
......@@ -3872,7 +3876,7 @@ class CppClassType(CType):
T.get_fused_types(result, seen)
return result
def specialize_here(self, pos, template_values=None):
def specialize_here(self, pos, env, template_values=None):
if not self.is_template_type():
error(pos, "'%s' type is not a template" % self)
return error_type
......@@ -4400,6 +4404,102 @@ class ErrorType(PyrexType):
return "dummy"
class PythonTypeConstructor(PyObjectType):
    """Used to help Cython interpret indexed types from the typing module (or similar)
    """

    def __init__(self, name, base_type=None):
        self.python_type_constructor_name = name
        self.base_type = base_type

    def __repr__(self):
        # NOTE(review): relies on "name" coming from the base type — confirm
        # the repr shows the intended identifier.
        if self.base_type:
            return "%s[%r]" % (self.name, self.base_type)
        return self.name

    def is_template_type(self):
        # These constructors exist precisely to be indexed/"templated".
        return True

    def specialize_here(self, pos, env, template_values=None):
        # for a lot of the typing classes it doesn't really matter what the
        # template is (i.e. typing.Dict[int] is really just a dict)
        return self.base_type if self.base_type else self
class PythonTupleTypeConstructor(PythonTypeConstructor):
    """Specialises typing.Tuple[...] to a Cython ctuple when every element
    type is known and is not a Python object; otherwise falls back to the
    generic behaviour (a plain tuple).
    """

    def specialize_here(self, pos, env, template_values=None):
        can_be_ctuple = bool(template_values) and (
            None not in template_values and
            not any(v.is_pyobject for v in template_values))
        if can_be_ctuple:
            entry = env.declare_tuple_type(pos, template_values)
            if entry:
                return entry.type
        return super(PythonTupleTypeConstructor, self).specialize_here(
            pos, env, template_values)
class SpecialPythonTypeConstructor(PythonTypeConstructor):
    """
    For things like ClassVar, Optional, etc, which have extra features on top of being
    a "templated" type.
    """

    def __init__(self, name, template_type=None):
        super(SpecialPythonTypeConstructor, self).__init__(name, None)
        if (name == "typing.ClassVar" and template_type
                and not template_type.is_pyobject):
            # because classvars end up essentially used as globals they have
            # to be PyObjects. Try to find the nearest suitable type (although
            # practically I doubt this matters).
            py_type_name = template_type.py_type_name()
            if py_type_name:
                from .Builtin import builtin_scope
                template_type = (builtin_scope.lookup_type(py_type_name)
                                 or py_object_type)
            else:
                # bug fix: this previously referenced the undefined name
                # "py_object_types" and would have raised a NameError.
                template_type = py_object_type
        self.template_type = template_type

    def __repr__(self):
        if self.template_type:
            return "%s[%r]" % (self.name, self.template_type)
        else:
            return self.name

    def is_template_type(self):
        # Only an unspecialised constructor (no template type yet) can still
        # be indexed.
        return self.template_type is None

    def resolve(self):
        """Return the underlying "real" type (e.g. list for ClassVar[list])."""
        if self.template_type:
            return self.template_type.resolve()
        else:
            return self

    def specialize_here(self, pos, env, template_values=None):
        if len(template_values) != 1:
            error(pos, "'%s' takes exactly one template argument." % self.name)
        # return a copy of the template type with python_type_constructor_name as an attribute
        # so it can be identified, and a resolve function that gets back to
        # the original type (since types are usually tested with "is")
        new_type = template_values[0]
        if (self.python_type_constructor_name == "typing.Optional" and
                not new_type.is_pyobject):
            # optional must be a py_object, but can be a specialized py_object
            new_type = py_object_type
        # (For "typing.ClassVar", a non-Python template type is coerced to a
        # suitable Python type in __init__ instead.)
        # bug fix: "new_type" was previously computed and then discarded in
        # favour of the raw template value, so Optional[<c type>] kept the
        # unusable C type.
        return SpecialPythonTypeConstructor(
            self.python_type_constructor_name,
            template_type = new_type)

    def __getattr__(self, name):
        # delegate anything we don't define ourselves to the specialised type
        if self.template_type:
            return getattr(self.template_type, name)
        # bug fix: "object" provides no __getattr__ to defer to, so raise the
        # normal AttributeError directly instead of failing with a confusing
        # "'super' object has no attribute '__getattr__'".
        raise AttributeError(name)
rank_to_type_name = (
"char", # 0
"short", # 1
......
......@@ -159,6 +159,9 @@ class Entry(object):
# is a specialization
# is_cgetter boolean Is a c-level getter function
# is_cpp_optional boolean Entry should be declared as std::optional (cpp_locals directive)
# known_standard_library_import Either None (default), an empty string (definitely can't be determined)
# or a string of "modulename.something.attribute"
# Used for identifying imports from typing/dataclasses etc
# TODO: utility_code and utility_code_definition serves the same purpose...
......@@ -166,6 +169,7 @@ class Entry(object):
borrowed = 0
init = ""
annotation = None
pep563_annotation = None
visibility = 'private'
is_builtin = 0
is_cglobal = 0
......@@ -231,6 +235,7 @@ class Entry(object):
outer_entry = None
is_cgetter = False
is_cpp_optional = False
known_standard_library_import = None
def __init__(self, name, cname, type, pos = None, init = None):
self.name = name
......@@ -998,13 +1003,27 @@ class Scope(object):
entry = self.declare_var(name, py_object_type, None)
return entry
def lookup_type(self, name):
entry = self.lookup(name)
def _type_or_specialized_type_from_entry(self, entry):
if entry and entry.is_type:
if entry.type.is_fused and self.fused_to_specific:
return entry.type.specialize(self.fused_to_specific)
return entry.type
return None
def lookup_type(self, name):
    """Look up ``name`` in this scope and return the (possibly fused-specialised)
    type it refers to, or None if it does not name a type.
    """
    entry = self.lookup(name)
    # The logic here is:
    # 1. if entry is a type then return it (and maybe specialize it)
    # 2. if the entry comes from a known standard library import then follow that
    # 3. repeat step 1 with the (possibly) updated entry
    tp = self._type_or_specialized_type_from_entry(entry)
    if tp:
        return tp
    # allow us to find types from the "typing" module and similar
    if entry and entry.known_standard_library_import:
        from .Builtin import get_known_standard_library_entry
        entry = get_known_standard_library_entry(entry.known_standard_library_import)
        return self._type_or_specialized_type_from_entry(entry)
    # falls through to an implicit None when the name is unknown or not a type
def lookup_operator(self, operator, operands):
if operands[0].type.is_cpp_class:
......@@ -2284,6 +2303,15 @@ class CClassScope(ClassScope):
cname = None, visibility = 'private',
api = 0, in_pxd = 0, is_cdef = 0):
name = self.mangle_class_private_name(name)
if type.python_type_constructor_name == "typing.ClassVar":
is_cdef = 0
type = type.resolve()
if (type.python_type_constructor_name == "dataclasses.InitVar" and
'dataclasses.dataclass' not in self.directives):
error(pos, "Use of cython.dataclasses.InitVar does not make sense outside a dataclass")
if is_cdef:
# Add an entry for an attribute.
if self.defined:
......@@ -2530,6 +2558,7 @@ class CClassScope(ClassScope):
base_entry.name, adapt(base_entry.cname),
base_entry.type, None, 'private')
entry.is_variable = 1
entry.annotation = base_entry.annotation
self.inherited_var_entries.append(entry)
# If the class defined in a pxd, specific entries have not been added.
......
......@@ -525,6 +525,30 @@ class CythonDotParallel(object):
# def threadsavailable(self):
# return 1
class CythonDotImportedFromElsewhere(object):
    """
    cython.dataclasses just shadows the standard library modules of the same name
    """
    def __init__(self, module):
        # Minimal module-like attributes so this object can live in sys.modules.
        self.__path__ = []
        self.__file__ = None
        self.__name__ = module
        self.__package__ = module

    def __getattr__(self, attr):
        # we typically only expect this to be called once
        from importlib import import_module
        import sys
        try:
            real_module = import_module(self.__name__)
        except ImportError:
            # but if they don't exist (Python is not sufficiently up-to-date) then
            # you can't use them
            raise AttributeError("%s: the standard library module %s is not available" %
                                 (attr, self.__name__))
        # From now on, resolve "cython.<module>" directly to the real module.
        sys.modules['cython.%s' % self.__name__] = real_module
        return getattr(real_module, attr)
class CythonCImports(object):
"""
......@@ -547,4 +571,7 @@ sys.modules['cython.parallel'] = CythonDotParallel()
sys.modules['cython.cimports'] = CythonCImports('cython.cimports')
sys.modules['cython.cimports.libc'] = CythonCImports('cython.cimports.libc')
sys.modules['cython.cimports.libc.math'] = math
# In pure Python mode @cython.dataclasses.dataclass and cython.dataclasses.field() should just
# shadow the standard library ones (if they are available)
dataclasses = sys.modules['cython.dataclasses'] = CythonDotImportedFromElsewhere('dataclasses')
del math, sys
/////////////// FetchSharedCythonModule.proto ///////
static PyObject *__Pyx_FetchSharedCythonABIModule(void);
/////////////// FetchSharedCythonModule ////////////
static PyObject *__Pyx_FetchSharedCythonABIModule(void) {
    // Get (creating on first use) the ABI-versioned module shared between all
    // Cython-compiled modules, returning a NEW reference or NULL on failure.
    PyObject *abi_module = PyImport_AddModule((char*) __PYX_ABI_MODULE_NAME);
    // PyImport_AddModule returns a borrowed reference (or NULL), so promote it
    // to a new reference before handing it out; Py_XINCREF is a no-op on NULL.
    Py_XINCREF(abi_module);
    return abi_module;
}
/////////////// FetchCommonType.proto ///////////////
#if !CYTHON_USE_TYPE_SPECS
......@@ -8,15 +21,9 @@ static PyTypeObject* __Pyx_FetchCommonTypeFromSpec(PyObject *module, PyType_Spec
/////////////// FetchCommonType ///////////////
//@requires:ExtensionTypes.c::FixUpExtensionType
//@requires: FetchSharedCythonModule
//@requires:StringTools.c::IncludeStringH
static PyObject *__Pyx_FetchSharedCythonABIModule(void) {
PyObject *abi_module = PyImport_AddModule((char*) __PYX_ABI_MODULE_NAME);
if (!abi_module) return NULL;
Py_INCREF(abi_module);
return abi_module;
}
static int __Pyx_VerifyCachedType(PyObject *cached_type,
const char *name,
Py_ssize_t basicsize,
......
///////////////////// ModuleLoader.proto //////////////////////////
static PyObject* __Pyx_LoadInternalModule(const char* name, const char* fallback_code); /* proto */
//////////////////// ModuleLoader ///////////////////////
//@requires: CommonStructures.c::FetchSharedCythonModule
static PyObject* __Pyx_LoadInternalModule(const char* name, const char* fallback_code) {
    // We want to be able to use the contents of the standard library dataclasses module where available.
    // If those objects aren't available (due to Python version) then a simple fallback is substituted
    // instead, which largely just fails with a not-implemented error.
    //
    // The fallbacks are placed in the "shared abi module" as a convenient internal place to
    // store them
    //
    // Returns a new reference, or NULL with an exception set.
    PyObject *shared_abi_module = 0, *module = 0;

    shared_abi_module = __Pyx_FetchSharedCythonABIModule();
    if (!shared_abi_module) return NULL;

    // Fast path: a previous call (possibly from another Cython module sharing
    // the same ABI) already cached a module under this name.
    if (PyObject_HasAttrString(shared_abi_module, name)) {
        PyObject* result = PyObject_GetAttrString(shared_abi_module, name);
        Py_DECREF(shared_abi_module);
        return result;
    }

    // the best and simplest case is simply to defer to the standard library (if available)
    module = PyImport_ImportModule(name);
    if (!module) {
        PyObject *localDict, *runValue, *builtins, *modulename;
        // Only an ImportError means "fall back"; anything else is propagated.
        if (!PyErr_ExceptionMatches(PyExc_ImportError)) goto bad;
        PyErr_Clear(); // this is reasonably likely (especially on older versions of Python)

        // Register the fallback under a private, ABI-tagged name so it can't
        // clash with (or be mistaken for) the real stdlib module.
#if PY_MAJOR_VERSION < 3
        modulename = PyBytes_FromFormat("_cython_" CYTHON_ABI ".%s", name);
#else
        modulename = PyUnicode_FromFormat("_cython_" CYTHON_ABI ".%s", name);
#endif
        if (!modulename) goto bad;
#if PY_MAJOR_VERSION >= 3 && CYTHON_COMPILING_IN_CPYTHON
        module = PyImport_AddModuleObject(modulename); // borrowed
#else
        module = PyImport_AddModule(PyBytes_AsString(modulename)); // borrowed
#endif
        Py_DECREF(modulename);
        if (!module) goto bad;
        // Promote the borrowed reference before caching/returning it.
        Py_INCREF(module);
        // Cache on the shared ABI module so later calls take the fast path.
        if (PyObject_SetAttrString(shared_abi_module, name, module) < 0) goto bad;
        localDict = PyModule_GetDict(module); // borrowed
        if (!localDict) goto bad;
        builtins = PyEval_GetBuiltins(); // borrowed
        if (!builtins) goto bad;
        if (PyDict_SetItemString(localDict, "__builtins__", builtins) <0) goto bad;

        // Populate the fallback module by executing the supplied Python source
        // in its namespace.
        runValue = PyRun_String(fallback_code, Py_file_input, localDict, localDict);
        if (!runValue) goto bad;
        Py_DECREF(runValue);
    }
    goto shared_cleanup;
    bad:
    // Py_CLEAR both drops our reference (if any) and NULLs the return value.
    Py_CLEAR(module);
    shared_cleanup:
    Py_XDECREF(shared_abi_module);
    return module;
}
///////////////////// SpecificModuleLoader.proto //////////////////////
//@substitute: tempita
static PyObject* __Pyx_Load_{{cname}}_Module(void); /* proto */
//////////////////// SpecificModuleLoader ///////////////////////
//@requires: ModuleLoader
static PyObject* __Pyx_Load_{{cname}}_Module(void) {
    // Tempita-generated entry point: load the standard library module named
    // "{{cname}}" (or install its fallback). Returns a new reference.
    return __Pyx_LoadInternalModule("{{cname}}", {{py_code}});
}
################### Dataclasses_fallback ###############################
# This is the fallback dataclass code if the stdlib module isn't available.
# It defines enough of the support types to be used with cdef classes
# and to fail if used on regular types.
# (Intended to be included as py code - not compiled)
from collections import namedtuple
try:
from types import MappingProxyType
except ImportError:
# mutable fallback if unavailable
MappingProxyType = lambda x: x
class _MISSING_TYPE(object):
pass
MISSING = _MISSING_TYPE()
_DataclassParams = namedtuple('_DataclassParams',
["init", "repr", "eq", "order", "unsafe_hash", "frozen"])
class Field(object):
    """Fallback mirror of dataclasses.Field: records the configuration of a
    single field. ``name`` and ``type`` are filled in after construction.
    """
    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        # Be aware that if MappingProxyType is unavailable (i.e. py2?) then we
        # don't enforce non-mutability that the real module does
        if metadata is None:
            self.metadata = MappingProxyType({})
        else:
            self.metadata = MappingProxyType(metadata)
        self._field_type = None

    def __repr__(self):
        # Built to produce exactly the same text as the stdlib-style
        # 'Field(name=...,type=...,...,metadata=...,)' representation.
        labels = ('name', 'type', 'default', 'default_factory', 'init',
                  'repr', 'hash', 'compare', 'metadata')
        values = (self.name, self.type, self.default, self.default_factory,
                  self.init, self.repr, self.hash, self.compare,
                  self.metadata)
        body = ''.join('%s=%r,' % (label, value)
                       for label, value in zip(labels, values))
        return 'Field(' + body + ')'
# A sentinel object for default values to signal that a default
# factory will be used. This is given a nice repr() which will appear
# in the function signature of dataclasses' constructors.
class _HAS_DEFAULT_FACTORY_CLASS:
    def __repr__(self):
        # This text appears in generated constructor signatures in place of
        # defaults that come from a factory.
        return '<factory>'
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
def dataclass(*args, **kwds):
    """Fallback for dataclasses.dataclass: the real decorator cannot be
    emulated here, so report clearly why it is unavailable.

    Always raises NotImplementedError.
    """
    # bug fix: the implicitly-concatenated string literals previously joined
    # as "...'dataclasses' moduleis unavailable" - add the missing space.
    raise NotImplementedError("Standard library 'dataclasses' module "
        "is unavailable, likely due to the version of Python you're using.")
# Markers for the various kinds of fields and pseudo-fields.
class _FIELD_BASE:
def __init__(self, name):
self.name = name
def __repr__(self):
return self.name
_FIELD = _FIELD_BASE('_FIELD')
_FIELD_CLASSVAR = _FIELD_BASE('_FIELD_CLASSVAR')
_FIELD_INITVAR = _FIELD_BASE('_FIELD_INITVAR')
def field(*ignore, **kwds):
    """Fallback stand-in for dataclasses.field().

    Accepts the same keyword arguments as the stdlib function (default,
    default_factory, init, repr, hash, compare, metadata), rejects positional
    arguments, and returns a Field describing the configuration.
    """
    known_options = (("default", MISSING), ("default_factory", MISSING),
                     ("init", True), ("repr", True), ("hash", None),
                     ("compare", True), ("metadata", None))
    values = {}
    for option, fallback in known_options:
        values[option] = kwds.pop(option, fallback)
    if kwds:
        raise ValueError("field received unexpected keyword arguments: %s"
            % list(kwds.keys()))
    if values["default"] is not MISSING and values["default_factory"] is not MISSING:
        raise ValueError('cannot specify both default and default_factory')
    if ignore:
        raise ValueError("'field' does not take any positional arguments")
    return Field(values["default"], values["default_factory"], values["init"],
                 values["repr"], values["hash"], values["compare"],
                 values["metadata"])
cimport cython
try:
import typing
import dataclasses
except ImportError:
pass # The modules don't actually have to exist for Cython to use them as annotations
@cython.dataclasses.dataclass
cdef class MyDataclass:
    # fields can be declared using annotations
    a: cython.int = 0
    b: double = cython.dataclasses.field(default_factory = lambda: 10, repr=False)

    # fields can also be declared using `cdef`:
    cdef str c
    c = "hello"  # assignment of default value on a separate line

    # dataclasses.InitVar and typing.ClassVar also work
    d: dataclasses.InitVar[double] = 5
    e: typing.ClassVar[list] = []
......@@ -335,15 +335,27 @@ declare types of variables in a Python 3.6 compatible way as follows:
There is currently no way to express the visibility of object attributes.
Cython does not support the full range of annotations described by PEP-484.
For example it does not currently understand features from the ``typing`` module
such as ``Optional[]`` or typed containers such as ``List[str]``. This is partly
because some of these type hints are not relevant for the compilation to
efficient C code.
``typing`` Module
^^^^^^^^^^^^^^^^^
Support for the full range of annotations described by PEP-484 is not yet
complete. Cython 3 currently understands the following features from the
``typing`` module:
* ``Optional[tp]``, which is interpreted as ``tp or None``;
* typed containers such as ``List[str]``, which is interpreted as ``list``. The
hint that the elements are of type ``str`` is currently ignored;
* ``Tuple[...]``, which is converted into a Cython C-tuple where possible
and a regular Python ``tuple`` otherwise.
* ``ClassVar[...]``, which is understood in the context of
``cdef class`` or ``@cython.cclass``.
Some of the unsupported features are likely to remain
unsupported since these type hints are not relevant for the compilation to
efficient C code. In other cases, however, where the generated C code could
benefit from these type hints but does not currently, help is welcome to
improve the type analysis in Cython.
Tips and Tricks
---------------
......
......@@ -1051,5 +1051,29 @@ generated containing declarations for its object struct and type object. By
including the ``.h`` file in external C code that you write, that code can
access the attributes of the extension type.
Dataclass extension types
=========================
Cython supports extension types that behave like the dataclasses defined in
the Python 3.7+ standard library. The main benefit of using a dataclass is
that it can auto-generate simple ``__init__``, ``__repr__`` and comparison
functions. The Cython implementation behaves as much like the Python
standard library implementation as possible and therefore the documentation
here only briefly outlines the differences - if you plan on using them
then please read `the documentation for the standard library module
<https://docs.python.org/3/library/dataclasses.html>`_.
Dataclasses can be declared using the ``@cython.dataclasses.dataclass``
decorator on a Cython extension type. ``@cython.dataclasses.dataclass``
can only be applied to extension types (types marked ``cdef`` or created with the
``cython.cclass`` decorator) and not to regular classes. If
you need to define special properties on a field then use ``cython.dataclasses.field``.
.. literalinclude:: ../../examples/userguide/extension_types/dataclass.pyx
You may use C-level types such as structs, pointers, or C++ classes.
However, you may find these types are not compatible with the auto-generated
special methods - for example if they cannot be converted from a Python
type they cannot be passed to a constructor, and so you must use a
``default_factory`` to initialize them. As with the Python implementation, you can also control
which special methods an attribute is used in by using ``field()``.
# mode: error
cimport cython
@cython.dataclasses.dataclass(1, shouldnt_be_here=True, init=5, unsafe_hash=True)
cdef class C:
    a: list = []  # mutable default value is rejected
    b: int = cython.dataclasses.field(default=5, default_factory=int)  # both given -> error
    c: int  # non-default field after fields with defaults -> error
    def __hash__(self):
        pass
_ERRORS = """
6:5: Arguments passed to cython.dataclasses.dataclass must be True or False
6:5: Cannot overwrite attribute __hash__ in class C
6:5: cython.dataclasses.dataclass() got an unexpected keyword argument 'shouldnt_be_here'
6:5: cython.dataclasses.dataclass takes no positional arguments
7:14: mutable default <class 'list'> for field a is not allowed: use default_factory
8:37: cannot specify both default and default_factory
9:4: non-default argument 'c' follows default argument in dataclass __init__
"""
# mode: error
# tag: dataclass
import dataclasses
@dataclasses.dataclass  # the plain Python decorator is rejected on cdef classes
cdef class C:
    pass
_ERRORS = """
6:0: Cdef functions/classes cannot take arbitrary decorators.
6:0: Use '@cython.dataclasses.dataclass' on cdef classes to create a dataclass
"""
# mode: compile
# tag: dataclass, warnings
cimport cython
from dataclass import field
@cython.dataclasses.dataclass
cdef class E:
    a: int = field()  # 'field' was imported from the wrong module, triggering the hint
_WARNINGS="""
9:18: Do you mean cython.dataclasses.field instead?
"""
# mode: error
cimport cython
@cython.dataclasses.dataclass
cdef class C:
    a: int = cython.dataclasses.field(unexpected=True)  # unknown keyword is reported
_ERRORS = """
7:49: cython.dataclasses.field() got an unexpected keyword argument 'unexpected'
"""
# mode: run
# tag: dataclass
from cython cimport dataclasses
from cython.dataclasses cimport dataclass, field
try:
import typing
from typing import ClassVar
from dataclasses import InitVar
import dataclasses as py_dataclasses
except ImportError:
pass
import cython
from libc.stdlib cimport malloc, free
include "../testsupport/cythonarrayutil.pxi"
# Deliberately not decorated: used below to check how dataclasses interact
# with plain cdef classes (as a field type and as a base class).
cdef class NotADataclass:
    cdef cython.int a
    b: float

    def __repr__(self):
        return "NADC"

    def __str__(self):
        return "string of NotADataclass"  # should NOT be called - the dataclass repr uses repr()

    def __eq__(self, other):
        return type(self) == type(other)

    def __hash__(self):
        return 1
@dataclass(unsafe_hash=True)
cdef class BasicDataclass:
    """
    >>> sorted(list(BasicDataclass.__dataclass_fields__.keys()))
    ['a', 'b', 'c', 'd']

    # Check the field type attribute - this is currently a string since
    # it's taken from the annotation, but if we drop PEP563 in future
    # then it may change
    >>> BasicDataclass.__dataclass_fields__["a"].type
    'float'
    >>> BasicDataclass.__dataclass_fields__["b"].type
    'NotADataclass'
    >>> BasicDataclass.__dataclass_fields__["c"].type
    'object'
    >>> BasicDataclass.__dataclass_fields__["d"].type
    'list'

    >>> inst1 = BasicDataclass() # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: __init__() takes at least 1 ...
    >>> inst1 = BasicDataclass(2.0)

    # The error at least demonstrates that the hash function has been created
    >>> hash(inst1) # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: ...unhashable...

    >>> inst2 = BasicDataclass(2.0)
    >>> inst1 == inst2
    True
    >>> inst2 = BasicDataclass(2.0, NotADataclass(), [])
    >>> inst1 == inst2
    False
    >>> inst2 = BasicDataclass(2.0, NotADataclass(), [], [1,2,3])
    >>> inst2
    BasicDataclass(a=2.0, b=NADC, c=[], d=[1, 2, 3])
    >>> inst2.c = "Some string"
    >>> inst2
    BasicDataclass(a=2.0, b=NADC, c='Some string', d=[1, 2, 3])
    """
    a: float
    b: NotADataclass = field(default_factory=NotADataclass)
    c: object = field(default=0)
    # 'dataclasses' here is the cimported cython.dataclasses module (see the
    # imports at the top of this file), so both spellings of field() work.
    d: list = dataclasses.field(default_factory=list)
@dataclasses.dataclass
cdef class InheritsFromDataclass(BasicDataclass):
    """
    >>> sorted(list(InheritsFromDataclass.__dataclass_fields__.keys()))
    ['a', 'b', 'c', 'd', 'e']
    >>> InheritsFromDataclass(a=1.0, e=5)
    In __post_init__
    InheritsFromDataclass(a=1.0, b=NADC, c=0, d=[], e=5)
    """
    e: cython.int = 0

    def __post_init__(self):
        # Runs after the generated __init__ has assigned all fields.
        print "In __post_init__"
@cython.dataclasses.dataclass
cdef class InheritsFromNotADataclass(NotADataclass):
    """
    >>> sorted(list(InheritsFromNotADataclass.__dataclass_fields__.keys()))
    ['c']
    >>> InheritsFromNotADataclass()
    InheritsFromNotADataclass(c=1)
    >>> InheritsFromNotADataclass(5)
    InheritsFromNotADataclass(c=5)
    """
    # Only this class's own annotated field is registered; the
    # non-dataclass base class contributes nothing to __dataclass_fields__.
    c: cython.int = 1
cdef struct S:
    int a

ctypedef S* S_ptr

# Returns an uninitialised heap-allocated S; used below as a
# default_factory for a field whose type has no Python conversion.
cdef S_ptr malloc_a_struct():
    return <S_ptr>malloc(sizeof(S))
@dataclass
cdef class ContainsNonPyFields:
    """
    >>> ContainsNonPyFields() # doctest: +ELLIPSIS
    Traceback (most recent call last):
    TypeError: __init__() takes ... 1 positional ...
    >>> ContainsNonPyFields(mystruct={'a': 1 }) # doctest: +ELLIPSIS
    ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
    >>> ContainsNonPyFields(mystruct={'a': 1 }, memview=create_array((2,2), "c")) # doctest: +ELLIPSIS
    ContainsNonPyFields(mystruct={'a': 1}, memview=<MemoryView of 'array' at ...>)
    >>> ContainsNonPyFields(mystruct={'a': 1 }, mystruct_ptr=0)
    Traceback (most recent call last):
    TypeError: __init__() got an unexpected keyword argument 'mystruct_ptr'
    """
    mystruct: S = cython.dataclasses.field(compare=False)
    # A pointer can't be converted from Python, so it is kept out of
    # __init__ and produced by a default_factory instead.
    mystruct_ptr: S_ptr = field(init=False, repr=False, default_factory=malloc_a_struct)
    memview: int[:, ::1] = field(default=create_array((3,1), "c"),  # mutable so not great but OK for a test
                                 compare=False)

    def __dealloc__(self):
        # Free the struct allocated by malloc_a_struct().
        free(self.mystruct_ptr)
@dataclass
cdef class InitClassVars:
    """
    Private (i.e. defined with "cdef") members deliberately don't appear
    TODO - ideally c1 and c2 should also be listed here
    >>> sorted(list(InitClassVars.__dataclass_fields__.keys()))
    ['a', 'b1', 'b2']
    >>> InitClassVars.c1
    2.0
    >>> InitClassVars.e1
    []
    >>> inst1 = InitClassVars()
    In __post_init__
    >>> inst1  # init vars don't appear in string
    InitClassVars(a=0)
    >>> inst2 = InitClassVars(b1=5, d2=100)
    In __post_init__
    >>> inst1 == inst2  # comparison ignores the initvar
    True
    """
    a: cython.int = 0
    # InitVar pseudo-fields: accepted by __init__ and forwarded to
    # __post_init__, but excluded from repr and comparisons.
    b1: InitVar[double] = 1.0
    b2: py_dataclasses.InitVar[double] = 1.0
    # ClassVar pseudo-fields: class-level attributes, not init arguments.
    c1: ClassVar[float] = 2.0
    c2: typing.ClassVar[float] = 2.0
    # The same pseudo-fields declared with cdef syntax.
    cdef InitVar[cython.int] d1
    cdef py_dataclasses.InitVar[cython.int] d2
    d1 = 5
    d2 = 5
    cdef ClassVar[list] e1
    cdef typing.ClassVar[list] e2
    e1 = []
    e2 = []

    def __post_init__(self, b1, b2, d1, d2):
        # Check that the initvars haven't been assigned yet
        assert self.b1==0, self.b1
        assert self.b2==0, self.b2
        assert self.d1==0, self.d1
        assert self.d2==0, self.d2
        self.b1 = b1
        self.b2 = b2
        self.d1 = d1
        self.d2 = d2
        print "In __post_init__"
@dataclass
cdef class TestVisibility:
    """
    >>> inst = TestVisibility()
    >>> "a" in TestVisibility.__dataclass_fields__
    False
    >>> hasattr(inst, "a")
    False
    >>> "b" in TestVisibility.__dataclass_fields__
    True
    >>> hasattr(inst, "b")
    True
    >>> "c" in TestVisibility.__dataclass_fields__
    True
    >>> TestVisibility.__dataclass_fields__["c"].type
    'double'
    >>> hasattr(inst, "c")
    True
    """
    cdef double a  # plain cdef attribute: private, not a dataclass field
    a = 1.0
    b: double = 2.0  # annotated attribute: becomes a visible dataclass field
    cdef public double c  # public cdef attribute: also becomes a field
    c = 3.0
@dataclass(frozen=True)
cdef class TestFrozen:
    """
    >>> inst = TestFrozen(a=5)
    >>> inst.a
    5.0
    >>> inst.a = 2. # doctest: +ELLIPSIS
    Traceback (most recent call last):
    AttributeError: attribute 'a' of '...TestFrozen' objects is not writable
    """
    a: double = 2.0  # assignment after construction raises AttributeError
import sys
if sys.version_info >= (3, 7):
    # Only run these doctests where the stdlib dataclasses module exists,
    # since they exercise its introspection helpers on Cython dataclasses.
    __doc__ = """
    >>> from dataclasses import Field, is_dataclass, fields

    # It uses the types from the standard library where available
    >>> all(isinstance(v, Field) for v in BasicDataclass.__dataclass_fields__.values())
    True

    # check that our Cython dataclasses are close enough to convince it
    >>> is_dataclass(BasicDataclass)
    True
    >>> is_dataclass(BasicDataclass(1.5))
    True
    >>> is_dataclass(InheritsFromDataclass)
    True
    >>> is_dataclass(NotADataclass)
    False
    >>> is_dataclass(InheritsFromNotADataclass)
    True
    >>> [ f.name for f in fields(BasicDataclass)]
    ['a', 'b', 'c', 'd']
    >>> [ f.name for f in fields(InitClassVars)]
    ['a']
    """
cimport cython
try:
import typing
from typing import Optional
except ImportError:
pass # Cython can still identify the use of "typing" even if the module doesn't exist
### extension types
......@@ -79,6 +84,39 @@ def ext_not_none(MyExtType x not None):
"""
return attr(x)
# An annotated extension-type argument rejects None by default,
# exactly like the explicit "not None" spelling above.
def ext_annotations(x: MyExtType):
    """
    Behaves the same as "MyExtType x not None"

    >>> ext_annotations(MyExtType())
    123
    >>> ext_annotations(None)
    Traceback (most recent call last):
    TypeError: Argument 'x' has incorrect type (expected ext_type_none_arg.MyExtType, got NoneType)
    """
    return attr(x)
# The directive set to False should not change anything here: annotated
# extension-type arguments already disallow None.
@cython.allow_none_for_extension_args(False)
def ext_annotations_check_on(x: MyExtType):
    """
    >>> ext_annotations_check_on(MyExtType())
    123
    >>> ext_annotations_check_on(None)
    Traceback (most recent call last):
    TypeError: Argument 'x' has incorrect type (expected ext_type_none_arg.MyExtType, got NoneType)
    """
    return attr(x)
# Optional[...] re-allows None, both via the typing module and the
# directly imported name.
def ext_optional(x: typing.Optional[MyExtType], y: Optional[MyExtType]):
    """
    Behaves the same as "or None"

    >>> ext_optional(MyExtType(), MyExtType())
    246
    >>> ext_optional(MyExtType(), None)
    444
    >>> ext_optional(None, MyExtType())
    444
    """
    return attr(x) + attr(y)
### builtin types (using list)
......
......@@ -5,10 +5,13 @@
import cython
from typing import Dict, List, TypeVar, Optional, Generic, Tuple
try:
import typing
from typing import Set as _SET_
from typing import ClassVar
except ImportError:
ClassVar = Optional # fake it in Py3.5
pass # this should allow Cython to interpret the directives even when the module doesn't exist
var = 1 # type: annotation
......@@ -51,6 +54,8 @@ class BasicStarship(object):
'Picard'
>>> bs.stats
{}
>>> BasicStarship.stats
{}
"""
captain: str = 'Picard' # instance variable with default
damage: cython.int # instance variable without default
......@@ -117,11 +122,7 @@ def iter_declared_dict(d):
>>> iter_declared_dict(d)
7.0
>>> class D(object):
... def __getitem__(self, x): return 2
... def __iter__(self): return iter([1, 2, 3])
>>> iter_declared_dict(D())
6.0
# specialized "compiled" test in module-level __doc__
"""
typed_dict : Dict[float, float] = d
s = 0.0
......@@ -140,11 +141,7 @@ def iter_declared_dict_arg(d : Dict[float, float]):
>>> iter_declared_dict_arg(d)
7.0
>>> class D(object):
... def __getitem__(self, x): return 2
... def __iter__(self): return iter([1, 2, 3])
>>> iter_declared_dict_arg(D())
6.0
# module level "compiled" test in __doc__ below
"""
s = 0.0
for key in d:
......@@ -161,13 +158,58 @@ def literal_list_ptr():
return a[3]
def test_subscripted_types():
    """
    >>> test_subscripted_types()
    dict object
    list object
    set object
    """
    a: typing.Dict[int, float] = {}
    b: List[int] = []
    c: _SET_[object] = set()

    # When interpreted, cython.typeof() returns just e.g. "dict"; append
    # " object" so both interpreted and compiled runs print the same text.
    print(cython.typeof(a) + (" object" if not cython.compiled else ""))
    print(cython.typeof(b) + (" object" if not cython.compiled else ""))
    print(cython.typeof(c) + (" object" if not cython.compiled else ""))
# because tuple is specifically special cased to go to ctuple where possible
def test_tuple(a: typing.Tuple[int, float], b: typing.Tuple[int, ...],
               c: Tuple[int, object]  # cannot be a ctuple
               ):
    """
    >>> test_tuple((1, 1.0), (1, 1.0), (1, 1.0))
    int
    int
    tuple object
    tuple object
    """
    x: typing.Tuple[int, float] = (a[0], a[1])
    y: Tuple[int, ...] = (1,2.)
    z = a[0]  # should infer to int

    print(cython.typeof(z))
    print(cython.typeof(x[0]))
    # The " object" suffix keeps the interpreted output identical to the
    # compiled output (see test_subscripted_types above).
    print(cython.typeof(y) + (" object" if not cython.compiled else ""))
    print(cython.typeof(c) + (" object" if not cython.compiled else ""))
if cython.compiled:
    # Compiled-only doctests: a variable declared as dict then rejects
    # non-dict arguments instead of duck-typing them.
    __doc__ = """
    # passing non-dicts to variables declared as dict now fails
    >>> class D(object):
    ...     def __getitem__(self, x): return 2
    ...     def __iter__(self): return iter([1, 2, 3])
    >>> iter_declared_dict(D())  # doctest:+IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError: Expected dict, got D
    >>> iter_declared_dict_arg(D())  # doctest:+IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    TypeError: Expected dict, got D
    """
# Expected compiler warnings for this test file (line:col positions refer
# to the full original file, including lines outside this excerpt).
_WARNINGS = """
37:19: Unknown type declaration in annotation, ignoring
38:12: Unknown type declaration in annotation, ignoring
39:18: Unknown type declaration in annotation, ignoring
73:11: Annotation ignored since class-level attributes must be Python objects. Were you trying to set up an instance attribute?
73:19: Unknown type declaration in annotation, ignoring
# FIXME: these are sort-of evaluated now, so the warning is misleading
126:21: Unknown type declaration in annotation, ignoring
137:35: Unknown type declaration in annotation, ignoring
"""
# mode: run
import cython
try:
import typing
from typing import List, Tuple
from typing import Set as _SET_
except:
pass # this should allow Cython to interpret the directives even when the module doesn't exist
def test_subscripted_types():
    """
    >>> test_subscripted_types()
    dict object
    list object
    set object
    """
    # Subscripted typing types in cdef declarations reduce to the plain
    # Python container types; the element-type hints are ignored.
    cdef typing.Dict[int, float] a = {}
    cdef List[int] b = []
    cdef _SET_[object] c = set()

    print(cython.typeof(a))
    print(cython.typeof(b))
    print(cython.typeof(c))
cdef class TestClassVar:
    """
    >>> TestClassVar.cls
    5
    >>> TestClassVar.regular # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AttributeError:
    """
    cdef int regular  # instance attribute: not accessible on the class itself
    cdef typing.ClassVar[int] cls  # class-level attribute, like a Python class attribute
    cls = 5
# because tuple is specifically special cased to go to ctuple where possible
def test_tuple(typing.Tuple[int, float] a, typing.Tuple[int, ...] b,
               Tuple[int, object] c  # cannot be a ctuple
               ):
    """
    >>> test_tuple((1, 1.0), (1, 1.0), (1, 1.0))
    int
    int
    tuple object
    tuple object
    """
    # Tuple[int, float] becomes a C-tuple, so element access is typed;
    # Tuple[int, ...] and Tuple[int, object] remain Python tuples.
    cdef typing.Tuple[int, float] x = (a[0], a[1])
    cdef Tuple[int, ...] y = (1,2.)
    z = a[0]  # should infer to int

    print(cython.typeof(z))
    print(cython.typeof(x[0]))
    print(cython.typeof(y))
    print(cython.typeof(c))
# mode: run
# tag: dataclass, pure3.7
from __future__ import print_function
import cython
# order=True generates the rich-comparison methods; unsafe_hash=True forces
# generation of __hash__ even though the class is mutable.
@cython.dataclasses.dataclass(order=True, unsafe_hash=True)
@cython.cclass
class MyDataclass:
    """
    >>> sorted(list(MyDataclass.__dataclass_fields__.keys()))
    ['a', 'self']
    >>> inst1 = MyDataclass(2.0, ['a', 'b'])
    >>> print(inst1)
    MyDataclass(a=2.0, self=['a', 'b'])
    >>> inst2 = MyDataclass()
    >>> print(inst2)
    MyDataclass(a=1, self=[])
    >>> inst1 == inst2
    False
    >>> inst1 > inst2
    True
    >>> inst2 == MyDataclass()
    True
    >>> hash(inst1) != id(inst1)
    True
    """
    a: int = 1
    self: list = cython.dataclasses.field(default_factory=list, hash=False)  # test that arguments of init don't conflict
# mode: run
# tag: pure3.6
from __future__ import print_function
import cython
try:
import typing
from typing import List
from typing import Set as _SET_
except ImportError:
pass # this should allow Cython to interpret the directives even when the module doesn't exist
def test_subscripted_types():
    """
    >>> test_subscripted_types()
    dict object
    list object
    set object
    """
    a: typing.Dict[int, float] = {}
    b: List[int] = []
    c: _SET_[object] = set()

    # When interpreted, cython.typeof() returns just e.g. "dict"; append
    # " object" so both interpreted and compiled runs print the same text.
    print(cython.typeof(a) + (" object" if not cython.compiled else ""))
    print(cython.typeof(b) + (" object" if not cython.compiled else ""))
    print(cython.typeof(c) + (" object" if not cython.compiled else ""))
@cython.cclass
class TestClassVar:
    """
    >>> TestClassVar.cls
    5
    >>> TestClassVar.regular # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AttributeError:
    """
    regular: int  # instance attribute: not accessible on the class itself
    cls: typing.ClassVar[int] = 5   # this is a little redundant really because the assignment ensures it
# mode: run
import cython
try:
import typing
from typing import List
from typing import Set as _SET_
except ImportError:
pass # this should allow Cython to interpret the directives even when the module doesn't exist
def test_subscripted_types():
    """
    >>> test_subscripted_types()
    dict object
    list object
    set object
    """
    # Subscripted typing types in cdef declarations reduce to the plain
    # Python container types; the element-type hints are ignored.
    cdef typing.Dict[int, float] a = {}
    cdef List[int] b = []
    cdef _SET_[object] c = set()

    print(cython.typeof(a))
    print(cython.typeof(b))
    print(cython.typeof(c))
cdef class TestClassVar:
    """
    >>> TestClassVar.cls
    5
    >>> TestClassVar.regular # doctest: +IGNORE_EXCEPTION_DETAIL
    Traceback (most recent call last):
    ...
    AttributeError:
    """
    cdef int regular  # instance attribute: not accessible on the class itself
    cdef typing.ClassVar[int] cls  # class-level attribute, like a Python class attribute
    cls = 5
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment