Commit ad24a17c authored by scoder's avatar scoder Committed by GitHub

Allow None to coerce to C types separately from other object values. (GH-4740)

This is used by some optimisations for builtins that call C-API functions directly but need to convert None arguments to NULL or special integer values in order to mimic the original Python interface.

Also add and backport the CPython macros for None checks (and True/False, while we're at it):
https://docs.python.org/3/c-api/structures.html#c.Py_Is

Closes https://github.com/cython/cython/issues/4737
See https://github.com/cython/cython/issues/4706
parent 69cb05b3
......@@ -13614,6 +13614,9 @@ class CoerceFromPyTypeNode(CoercionNode):
# This node is used to convert a Python object
# to a C data type.
# Allow 'None' to map to a difference C value independent of the coercion, e.g. to 'NULL' or '0'.
special_none_cvalue = None
def __init__(self, result_type, arg, env):
CoercionNode.__init__(self, arg)
self.type = result_type
......@@ -13643,7 +13646,10 @@ class CoerceFromPyTypeNode(CoercionNode):
NoneCheckNode.generate_if_needed(self.arg, code, "expected bytes, NoneType found")
code.putln(self.type.from_py_call_code(
self.arg.py_result(), self.result(), self.pos, code, from_py_function=from_py_function))
self.arg.py_result(), self.result(), self.pos, code,
from_py_function=from_py_function,
special_none_cvalue=self.special_none_cvalue,
))
if self.type.is_pyobject:
self.generate_gotref(code)
......
......@@ -3620,6 +3620,8 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
return node
if len(args) < 2:
args.append(ExprNodes.NullNode(node.pos))
else:
self._inject_null_for_none(args, 1)
self._inject_int_default_argument(
node, args, 2, PyrexTypes.c_py_ssize_t_type, "-1")
......@@ -4135,13 +4137,35 @@ class OptimizeBuiltinCalls(Visitor.NodeRefCleanupMixin,
format_args=[attr_name])
return self_arg
obj_to_obj_func_type = PyrexTypes.CFuncType(
PyrexTypes.py_object_type, [
PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None)
])
def _inject_null_for_none(self, args, index):
if len(args) <= index:
return
arg = args[index]
args[index] = ExprNodes.NullNode(arg.pos) if arg.is_none else ExprNodes.PythonCapiCallNode(
arg.pos, "__Pyx_NoneAsNull",
self.obj_to_obj_func_type,
args=[arg.coerce_to_simple(self.current_env())],
is_temp=0,
)
def _inject_int_default_argument(self, node, args, arg_index, type, default_value):
# Python usually allows passing None for range bounds,
# so we treat that as requesting the default.
assert len(args) >= arg_index
if len(args) == arg_index:
if len(args) == arg_index or args[arg_index].is_none:
args.append(ExprNodes.IntNode(node.pos, value=str(default_value),
type=type, constant_result=default_value))
else:
args[arg_index] = args[arg_index].coerce_to(type, self.current_env())
arg = args[arg_index].coerce_to(type, self.current_env())
if isinstance(arg, ExprNodes.CoerceFromPyTypeNode):
# Add a runtime check for None and map it to the default value.
arg.special_none_cvalue = str(default_value)
args[arg_index] = arg
def _inject_bint_default_argument(self, node, args, arg_index, default_value):
assert len(args) >= arg_index
......
......@@ -341,7 +341,8 @@ class PyrexType(BaseType):
return 0
def _assign_from_py_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None, extra_args=None):
from_py_function=None, error_condition=None, extra_args=None,
special_none_cvalue=None):
args = ', ' + ', '.join('%s' % arg for arg in extra_args) if extra_args else ''
convert_call = "%s(%s%s)" % (
from_py_function or self.from_py_function,
......@@ -350,6 +351,10 @@ class PyrexType(BaseType):
)
if self.is_enum:
convert_call = typecast(self, c_long_type, convert_call)
if special_none_cvalue:
# NOTE: requires 'source_code' to be simple!
convert_call = "(__Pyx_Py_IsNone(%s) ? (%s) : (%s))" % (
source_code, special_none_cvalue, convert_call)
return '%s = %s; %s' % (
result_code,
convert_call,
......@@ -555,11 +560,13 @@ class CTypedefType(BaseType):
source_code, result_code, result_type, to_py_function)
def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None):
from_py_function=None, error_condition=None,
special_none_cvalue=None):
return self.typedef_base_type.from_py_call_code(
source_code, result_code, error_pos, code,
from_py_function or self.from_py_function,
error_condition or self.error_condition(result_code)
error_condition or self.error_condition(result_code),
special_none_cvalue=special_none_cvalue,
)
def overflow_check_binop(self, binop, env, const_rhs=False):
......@@ -978,13 +985,16 @@ class MemoryViewSliceType(PyrexType):
return True
def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None):
from_py_function=None, error_condition=None,
special_none_cvalue=None):
# NOTE: auto-detection of readonly buffers is disabled:
# writable = self.writable_needed or not self.dtype.is_const
writable = not self.dtype.is_const
return self._assign_from_py_code(
source_code, result_code, error_pos, code, from_py_function, error_condition,
extra_args=['PyBUF_WRITABLE' if writable else '0'])
extra_args=['PyBUF_WRITABLE' if writable else '0'],
special_none_cvalue=special_none_cvalue,
)
def create_to_py_utility_code(self, env):
self._dtype_to_py_func, self._dtype_from_py_func = self.dtype_object_conversion_funcs(env)
......@@ -1674,9 +1684,11 @@ class CType(PyrexType):
source_code or 'NULL')
def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None):
from_py_function=None, error_condition=None,
special_none_cvalue=None):
return self._assign_from_py_code(
source_code, result_code, error_pos, code, from_py_function, error_condition)
source_code, result_code, error_pos, code, from_py_function, error_condition,
special_none_cvalue=special_none_cvalue)
......@@ -2675,8 +2687,10 @@ class CArrayType(CPointerBaseType):
return True
def from_py_call_code(self, source_code, result_code, error_pos, code,
from_py_function=None, error_condition=None):
from_py_function=None, error_condition=None,
special_none_cvalue=None):
assert not error_condition, '%s: %s' % (error_pos, error_condition)
assert not special_none_cvalue, '%s: %s' % (error_pos, special_none_cvalue) # not currently supported
call_code = "%s(%s, %s, %s)" % (
from_py_function or self.from_py_function,
source_code, result_code, self.size)
......
......@@ -628,6 +628,28 @@ class __Pyx_FakeReference {
#define __Pyx_IS_TYPE(ob, type) (((const PyObject*)ob)->ob_type == (type))
#endif
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_Is)
#define __Pyx_Py_Is(x, y) Py_Is(x, y)
#else
#define __Pyx_Py_Is(x, y) ((x) == (y))
#endif
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsNone)
#define __Pyx_Py_IsNone(ob) Py_IsNone(ob)
#else
#define __Pyx_Py_IsNone(ob) __Pyx_Py_Is((ob), Py_None)
#endif
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsTrue)
#define __Pyx_Py_IsTrue(ob) Py_IsTrue(ob)
#else
#define __Pyx_Py_IsTrue(ob) __Pyx_Py_Is((ob), Py_True)
#endif
#if PY_VERSION_HEX >= 0x030A00B1 || defined(Py_IsFalse)
#define __Pyx_Py_IsFalse(ob) Py_IsFalse(ob)
#else
#define __Pyx_Py_IsFalse(ob) __Pyx_Py_Is((ob), Py_False)
#endif
#define __Pyx_NoneAsNull(obj) (__Pyx_Py_IsNone(obj) ? NULL : (obj))
#ifndef Py_TPFLAGS_CHECKTYPES
#define Py_TPFLAGS_CHECKTYPES 0
#endif
......
......@@ -59,6 +59,24 @@ def split_sep(unicode s, sep):
ab jd
sdflk as sa
sadas asdas fsdf\x20
>>> print_all( text.split(None) )
ab
jd
sdflk
as
sa
sadas
asdas
fsdf
>>> print_all( split_sep(text, None) )
ab
jd
sdflk
as
sa
sadas
asdas
fsdf
"""
return s.split(sep)
......@@ -76,6 +94,14 @@ def split_sep_max(unicode s, sep, max):
>>> print_all( split_sep_max(text, sep, 1) )
ab jd
sdflk as sa sadas asdas fsdf\x20
>>> print_all( text.split(None, 2) )
ab
jd
sdflk as sa sadas asdas fsdf\x20
>>> print_all( split_sep_max(text, None, 2) )
ab
jd
sdflk as sa sadas asdas fsdf\x20
"""
return s.split(sep, max)
......@@ -92,6 +118,12 @@ def split_sep_max_int(unicode s, sep):
>>> print_all( split_sep_max_int(text, sep) )
ab jd
sdflk as sa sadas asdas fsdf\x20
>>> print_all( text.split(None, 1) )
ab
jd sdflk as sa sadas asdas fsdf\x20
>>> print_all( split_sep_max_int(text, None) )
ab
jd sdflk as sa sadas asdas fsdf\x20
"""
return s.split(sep, 1)
......@@ -337,6 +369,11 @@ def startswith_start_end(unicode s, sub, start, end):
False
>>> startswith_start_end(text, 'b X', 1, 5)
'NO MATCH'
>>> text.startswith('ab ', None, None)
True
>>> startswith_start_end(text, 'ab ', None, None)
'MATCH'
"""
if s.startswith(sub, start, end):
return 'MATCH'
......@@ -407,6 +444,11 @@ def endswith_start_end(unicode s, sub, start, end):
True
>>> endswith_start_end(text, ('fsdf ', 'fsdf X'), 10, len(text)-1)
'NO MATCH'
>>> text.endswith('fsdf ', None, None)
True
>>> endswith_start_end(text, 'fsdf ', None, None)
'MATCH'
"""
if s.endswith(sub, start, end):
return 'MATCH'
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment