Commit 4f208e13 authored by Mark Florisson's avatar Mark Florisson

branch merge

parents 841e7703 d591804e
...@@ -94,7 +94,7 @@ class DistutilsInfo(object): ...@@ -94,7 +94,7 @@ class DistutilsInfo(object):
value = [tuple(macro.split('=')) for macro in value] value = [tuple(macro.split('=')) for macro in value]
self.values[key] = value self.values[key] = value
elif exn is not None: elif exn is not None:
for key in self.distutils_settings: for key in distutils_settings:
if key in ('name', 'sources'): if key in ('name', 'sources'):
pass pass
value = getattr(exn, key, None) value = getattr(exn, key, None)
...@@ -154,19 +154,33 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -154,19 +154,33 @@ def strip_string_literals(code, prefix='__Pyx_L'):
in_quote = False in_quote = False
raw = False raw = False
while True: while True:
hash_mark = code.find('#', q)
single_q = code.find("'", q) single_q = code.find("'", q)
double_q = code.find('"', q) double_q = code.find('"', q)
q = min(single_q, double_q) q = min(single_q, double_q)
if q == -1: q = max(single_q, double_q) if q == -1: q = max(single_q, double_q)
if q == -1:
if in_quote: # Process comment.
counter += 1 if -1 < hash_mark and (hash_mark < q or q == -1):
label = "'%s%s" % (prefix, counter) end = code.find('\n', hash_mark)
literals[label] = code[start:] if end == -1:
new_code.append(label) end = None
else: new_code.append(code[start:hash_mark+1])
new_code.append(code[start:]) counter += 1
label = "%s%s" % (prefix, counter)
literals[label] = code[hash_mark+1:end]
new_code.append(label)
if end is None:
break
q = end
start = q
# We're done.
elif q == -1:
new_code.append(code[start:])
break break
# Try to close the quote.
elif in_quote: elif in_quote:
if code[q-1] == '\\': if code[q-1] == '\\':
k = 2 k = 2
...@@ -179,12 +193,14 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -179,12 +193,14 @@ def strip_string_literals(code, prefix='__Pyx_L'):
counter += 1 counter += 1
label = "%s%s" % (prefix, counter) label = "%s%s" % (prefix, counter)
literals[label] = code[start+len(in_quote):q] literals[label] = code[start+len(in_quote):q]
new_code.append("'%s'" % label) new_code.append("%s%s%s" % (in_quote, label, in_quote))
q += len(in_quote) q += len(in_quote)
start = q
in_quote = False in_quote = False
start = q
else: else:
q += 1 q += 1
# Open the quote.
else: else:
raw = False raw = False
if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]): if len(code) >= q+3 and (code[q+1] == code[q] == code[q+2]):
...@@ -202,13 +218,13 @@ def strip_string_literals(code, prefix='__Pyx_L'): ...@@ -202,13 +218,13 @@ def strip_string_literals(code, prefix='__Pyx_L'):
return "".join(new_code), literals return "".join(new_code), literals
def parse_dependencies(source_filename): def parse_dependencies(source_filename):
# Actual parsing is way to slow, so we use regular expressions. # Actual parsing is way to slow, so we use regular expressions.
# The only catch is that we must strip comments and string # The only catch is that we must strip comments and string
# literals ahead of time. # literals ahead of time.
source = Utils.open_source_file(source_filename, "rU").read() source = Utils.open_source_file(source_filename, "rU").read()
distutils_info = DistutilsInfo(source) distutils_info = DistutilsInfo(source)
source = re.sub('#.*', '', source)
source, literals = strip_string_literals(source) source, literals = strip_string_literals(source)
source = source.replace('\\\n', ' ') source = source.replace('\\\n', ' ')
if '\t' in source: if '\t' in source:
...@@ -389,8 +405,8 @@ def create_extension_list(patterns, ctx=None, aliases=None): ...@@ -389,8 +405,8 @@ def create_extension_list(patterns, ctx=None, aliases=None):
continue continue
template = pattern template = pattern
name = template.name name = template.name
base = DistutilsInfo(template) base = DistutilsInfo(exn=template)
exn_type = type(template) exn_type = template.__class__
else: else:
raise TypeError(pattern) raise TypeError(pattern)
for file in glob(filepattern): for file in glob(filepattern):
......
#no doctest
print "Warning: Using prototype cython.inline code..." print "Warning: Using prototype cython.inline code..."
import tempfile import tempfile
...@@ -8,14 +9,15 @@ try: ...@@ -8,14 +9,15 @@ try:
except ImportError: except ImportError:
import md5 as hashlib import md5 as hashlib
from distutils.dist import Distribution from distutils.core import Distribution, Extension
from Cython.Distutils.extension import Extension from distutils.command.build_ext import build_ext
from Cython.Distutils import build_ext
import Cython
from Cython.Compiler.Main import Context, CompilationOptions, default_options from Cython.Compiler.Main import Context, CompilationOptions, default_options
from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform from Cython.Compiler.ParseTreeTransforms import CythonTransform, SkipDeclarations, AnalyseDeclarationsTransform
from Cython.Compiler.TreeFragment import parse_from_strings from Cython.Compiler.TreeFragment import parse_from_strings
from Cython.Build.Dependencies import strip_string_literals, cythonize
_code_cache = {} _code_cache = {}
...@@ -81,6 +83,7 @@ def cython_inline(code, ...@@ -81,6 +83,7 @@ def cython_inline(code,
locals=None, locals=None,
globals=None, globals=None,
**kwds): **kwds):
code, literals = strip_string_literals(code)
code = strip_common_indent(code) code = strip_common_indent(code)
ctx = Context(include_dirs, default_options) ctx = Context(include_dirs, default_options)
if locals is None: if locals is None:
...@@ -103,42 +106,54 @@ def cython_inline(code, ...@@ -103,42 +106,54 @@ def cython_inline(code,
arg_names = kwds.keys() arg_names = kwds.keys()
arg_names.sort() arg_names.sort()
arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names]) arg_sigs = tuple([(get_type(kwds[arg], ctx), arg) for arg in arg_names])
key = code, arg_sigs key = code, arg_sigs, sys.version_info, sys.executable, Cython.__version__
module = _code_cache.get(key) module_name = "_cython_inline_" + hashlib.md5(str(key)).hexdigest()
if not module: # # TODO: Does this cover all the platforms?
# if (not os.path.exists(os.path.join(lib_dir, module_name + ".so")) and
# not os.path.exists(os.path.join(lib_dir, module_name + ".dll"))):
try:
if not os.path.exists(lib_dir):
os.makedirs(lib_dir)
if lib_dir not in sys.path:
sys.path.append(lib_dir)
__import__(module_name)
except ImportError:
c_include_dirs = []
cimports = [] cimports = []
qualified = re.compile(r'([.\w]+)[.]') qualified = re.compile(r'([.\w]+)[.]')
for type, _ in arg_sigs: for type, _ in arg_sigs:
m = qualified.match(type) m = qualified.match(type)
if m: if m:
cimports.append('\ncimport %s' % m.groups()[0]) cimports.append('\ncimport %s' % m.groups()[0])
# one special case
if m.groups()[0] == 'numpy':
import numpy
c_include_dirs.append(numpy.get_include())
module_body, func_body = extract_func_code(code) module_body, func_body = extract_func_code(code)
params = ', '.join(['%s %s' % a for a in arg_sigs]) params = ', '.join(['%s %s' % a for a in arg_sigs])
module_code = """ module_code = """
%(cimports)s
%(module_body)s %(module_body)s
%(cimports)s
def __invoke(%(params)s): def __invoke(%(params)s):
%(func_body)s %(func_body)s
""" % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body } """ % {'cimports': '\n'.join(cimports), 'module_body': module_body, 'params': params, 'func_body': func_body }
# print module_code for key, value in literals.items():
_, pyx_file = tempfile.mkstemp('.pyx') module_code = module_code.replace(key, value)
pyx_file = os.path.join(lib_dir, module_name + '.pyx')
open(pyx_file, 'w').write(module_code) open(pyx_file, 'w').write(module_code)
module = "_" + hashlib.md5(code + str(arg_sigs)).hexdigest()
extension = Extension( extension = Extension(
name = module, name = module_name,
sources = [pyx_file], sources = [pyx_file],
pyrex_include_dirs = include_dirs) include_dirs = c_include_dirs)
build_extension = build_ext(Distribution()) build_extension = build_ext(Distribution())
build_extension.finalize_options() build_extension.finalize_options()
build_extension.extensions = [extension] build_extension.extensions = cythonize([extension])
build_extension.build_temp = os.path.dirname(pyx_file) build_extension.build_temp = os.path.dirname(pyx_file)
if lib_dir not in sys.path:
sys.path.append(lib_dir)
build_extension.build_lib = lib_dir build_extension.build_lib = lib_dir
build_extension.run() build_extension.run()
_code_cache[key] = module _code_cache[key] = module_name
arg_list = [kwds[arg] for arg in arg_names] arg_list = [kwds[arg] for arg in arg_names]
return __import__(module).__invoke(*arg_list) return __import__(module_name).__invoke(*arg_list)
non_space = re.compile('[^ ]') non_space = re.compile('[^ ]')
def strip_common_indent(code): def strip_common_indent(code):
...@@ -165,7 +180,6 @@ module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor ...@@ -165,7 +180,6 @@ module_statement = re.compile(r'^((cdef +(extern|class))|cimport|(from .+ cimpor
def extract_func_code(code): def extract_func_code(code):
module = [] module = []
function = [] function = []
# TODO: string literals, backslash
current = function current = function
code = code.replace('\t', ' ') code = code.replace('\t', ' ')
lines = code.split('\n') lines = code.split('\n')
...@@ -177,3 +191,54 @@ def extract_func_code(code): ...@@ -177,3 +191,54 @@ def extract_func_code(code):
current = function current = function
current.append(line) current.append(line)
return '\n'.join(module), ' ' + '\n '.join(function) return '\n'.join(module), ' ' + '\n '.join(function)
try:
from inspect import getcallargs
except ImportError:
def getcallargs(func, *arg_values, **kwd_values):
all = {}
args, varargs, kwds, defaults = inspect.getargspec(func)
if varargs is not None:
all[varargs] = arg_values[len(args):]
for name, value in zip(args, arg_values):
all[name] = value
for name, value in kwd_values.items():
if name in args:
if name in all:
raise TypeError, "Duplicate argument %s" % name
all[name] = kwd_values.pop(name)
if kwds is not None:
all[kwds] = kwd_values
elif kwd_values:
raise TypeError, "Unexpected keyword arguments: %s" % kwd_values.keys()
if defaults is None:
defaults = ()
first_default = len(args) - len(defaults)
for ix, name in enumerate(args):
if name not in all:
if ix >= first_default:
all[name] = defaults[ix - first_default]
else:
raise TypeError, "Missing argument: %s" % name
return all
def get_body(source):
ix = source.index(':')
if source[:5] == 'lambda':
return "return %s" % source[ix+1:]
else:
return source[ix+1:]
# Lots to be done here... It would be especially cool if compiled functions
# could invoke each other quickly.
class RuntimeCompiledFunction(object):
def __init__(self, f):
self._f = f
self._body = get_body(inspect.getsource(f))
def __call__(self, *args, **kwds):
all = getcallargs(self._f, *args, **kwds)
return cython_inline(self._body, locals=self._f.func_globals, globals=self._f.func_globals, **all)
This diff is collapsed.
...@@ -759,16 +759,22 @@ class GlobalState(object): ...@@ -759,16 +759,22 @@ class GlobalState(object):
try: try:
return self.input_file_contents[source_desc] return self.input_file_contents[source_desc]
except KeyError: except KeyError:
pass
source_file = source_desc.get_lines(encoding='ASCII',
error_handling='ignore')
try:
F = [u' * ' + line.rstrip().replace( F = [u' * ' + line.rstrip().replace(
u'*/', u'*[inserted by cython to avoid comment closer]/' u'*/', u'*[inserted by cython to avoid comment closer]/'
).replace( ).replace(
u'/*', u'/[inserted by cython to avoid comment start]*' u'/*', u'/[inserted by cython to avoid comment start]*'
) )
for line in source_desc.get_lines(encoding='ASCII', for line in source_file]
error_handling='ignore')] finally:
if len(F) == 0: F.append(u'') if hasattr(source_file, 'close'):
self.input_file_contents[source_desc] = F source_file.close()
return F if not F: F.append(u'')
self.input_file_contents[source_desc] = F
return F
# #
# Utility code state # Utility code state
......
This diff is collapsed.
...@@ -102,6 +102,7 @@ class Context(object): ...@@ -102,6 +102,7 @@ class Context(object):
from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
from ParseTreeTransforms import ExpandInplaceOperators
from TypeInference import MarkAssignments, MarkOverflowingArithmetic from TypeInference import MarkAssignments, MarkOverflowingArithmetic
from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
from AnalysedTreeTransforms import AutoTestDictTransform from AnalysedTreeTransforms import AutoTestDictTransform
...@@ -147,6 +148,7 @@ class Context(object): ...@@ -147,6 +148,7 @@ class Context(object):
IntroduceBufferAuxiliaryVars(self), IntroduceBufferAuxiliaryVars(self),
_check_c_declarations, _check_c_declarations,
AnalyseExpressionsTransform(self), AnalyseExpressionsTransform(self),
ExpandInplaceOperators(self),
OptimizeBuiltinCalls(self), ## Necessary? OptimizeBuiltinCalls(self), ## Necessary?
IterationTransform(), IterationTransform(),
SwitchTransform(), SwitchTransform(),
......
This diff is collapsed.
This diff is collapsed.
from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor from Cython.Compiler.Visitor import VisitorTransform, TreeVisitor
from Cython.Compiler.Visitor import CythonTransform, EnvTransform from Cython.Compiler.Visitor import CythonTransform, EnvTransform, ScopeTrackingTransform
from Cython.Compiler.ModuleNode import ModuleNode from Cython.Compiler.ModuleNode import ModuleNode
from Cython.Compiler.Nodes import * from Cython.Compiler.Nodes import *
from Cython.Compiler.ExprNodes import * from Cython.Compiler.ExprNodes import *
...@@ -133,7 +133,7 @@ class PostParseError(CompileError): pass ...@@ -133,7 +133,7 @@ class PostParseError(CompileError): pass
ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions' ERR_CDEF_INCLASS = 'Cannot assign default value to fields in cdef classes, structs or unions'
ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)' ERR_BUF_DEFAULTS = 'Invalid buffer defaults specification (see docs)'
ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared' ERR_INVALID_SPECIALATTR_TYPE = 'Special attributes must not have a type declared'
class PostParse(CythonTransform): class PostParse(ScopeTrackingTransform):
""" """
Basic interpretation of the parse tree, as well as validity Basic interpretation of the parse tree, as well as validity
checking that can be done on a very basic level on the parse checking that can be done on a very basic level on the parse
...@@ -168,9 +168,6 @@ class PostParse(CythonTransform): ...@@ -168,9 +168,6 @@ class PostParse(CythonTransform):
if a more pure Abstract Syntax Tree is wanted. if a more pure Abstract Syntax Tree is wanted.
""" """
# Track our context.
scope_type = None # can be either of 'module', 'function', 'class'
def __init__(self, context): def __init__(self, context):
super(PostParse, self).__init__(context) super(PostParse, self).__init__(context)
self.specialattribute_handlers = { self.specialattribute_handlers = {
...@@ -178,28 +175,8 @@ class PostParse(CythonTransform): ...@@ -178,28 +175,8 @@ class PostParse(CythonTransform):
} }
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.scope_type = 'module'
self.scope_node = node
self.lambda_counter = 1 self.lambda_counter = 1
self.visitchildren(node) return super(PostParse, self).visit_ModuleNode(node)
return node
def visit_scope(self, node, scope_type):
prev = self.scope_type, self.scope_node
self.scope_type = scope_type
self.scope_node = node
self.visitchildren(node)
self.scope_type, self.scope_node = prev
return node
def visit_ClassDefNode(self, node):
return self.visit_scope(node, 'class')
def visit_FuncDefNode(self, node):
return self.visit_scope(node, 'function')
def visit_CStructOrUnionDefNode(self, node):
return self.visit_scope(node, 'struct')
def visit_LambdaNode(self, node): def visit_LambdaNode(self, node):
# unpack a lambda expression into the corresponding DefNode # unpack a lambda expression into the corresponding DefNode
...@@ -242,7 +219,7 @@ class PostParse(CythonTransform): ...@@ -242,7 +219,7 @@ class PostParse(CythonTransform):
declbase = declbase.base declbase = declbase.base
if isinstance(declbase, CNameDeclaratorNode): if isinstance(declbase, CNameDeclaratorNode):
if declbase.default is not None: if declbase.default is not None:
if self.scope_type in ('class', 'struct'): if self.scope_type in ('cclass', 'pyclass', 'struct'):
if isinstance(self.scope_node, CClassDefNode): if isinstance(self.scope_node, CClassDefNode):
handler = self.specialattribute_handlers.get(decl.name) handler = self.specialattribute_handlers.get(decl.name)
if handler: if handler:
...@@ -1197,7 +1174,60 @@ class AnalyseExpressionsTransform(CythonTransform): ...@@ -1197,7 +1174,60 @@ class AnalyseExpressionsTransform(CythonTransform):
node.analyse_scoped_expressions(node.expr_scope) node.analyse_scoped_expressions(node.expr_scope)
self.visitchildren(node) self.visitchildren(node)
return node return node
class ExpandInplaceOperators(EnvTransform):
def visit_InPlaceAssignmentNode(self, node):
lhs = node.lhs
rhs = node.rhs
if lhs.type.is_cpp_class:
# No getting around this exact operator here.
return node
if isinstance(lhs, IndexNode) and lhs.is_buffer_access:
# There is code to handle this case.
return node
def side_effect_free_reference(node, setting=False):
if isinstance(node, NameNode):
return node, []
elif node.type.is_pyobject and not setting:
node = LetRefNode(node)
return node, [node]
elif isinstance(node, IndexNode):
if node.is_buffer_access:
raise ValueError, "Buffer access"
base, temps = side_effect_free_reference(node.base)
index = LetRefNode(node.index)
return IndexNode(node.pos, base=base, index=index), temps + [index]
elif isinstance(node, AttributeNode):
obj, temps = side_effect_free_reference(node.obj)
return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
else:
node = LetRefNode(node)
return node, [node]
try:
lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
except ValueError:
return node
dup = lhs.__class__(**lhs.__dict__)
binop = binop_node(node.pos,
operator = node.operator,
operand1 = dup,
operand2 = rhs,
inplace=True)
node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop)
# Use LetRefNode to avoid side effects.
let_ref_nodes.reverse()
for t in let_ref_nodes:
node = LetNode(t, node)
node.analyse_expressions(self.current_env())
return node
def visit_ExprNode(self, node):
# In-place assignments can't happen within an expression.
return node
class AlignFunctionDefinitions(CythonTransform): class AlignFunctionDefinitions(CythonTransform):
""" """
This class takes the signatures from a .pxd file and applies them to This class takes the signatures from a .pxd file and applies them to
...@@ -1236,15 +1266,11 @@ class AlignFunctionDefinitions(CythonTransform): ...@@ -1236,15 +1266,11 @@ class AlignFunctionDefinitions(CythonTransform):
def visit_DefNode(self, node): def visit_DefNode(self, node):
pxd_def = self.scope.lookup(node.name) pxd_def = self.scope.lookup(node.name)
if pxd_def: if pxd_def:
if self.scope.is_c_class_scope and len(pxd_def.type.args) > 0: if not pxd_def.is_cfunction:
# The self parameter type needs adjusting.
pxd_def.type.args[0].type = self.scope.parent_type
if pxd_def.is_cfunction:
node = node.as_cfunction(pxd_def)
else:
error(node.pos, "'%s' redeclared" % node.name) error(node.pos, "'%s' redeclared" % node.name)
error(pxd_def.pos, "previous declaration here") error(pxd_def.pos, "previous declaration here")
return None return None
node = node.as_cfunction(pxd_def)
elif self.scope.is_module_scope and self.directives['auto_cpdef']: elif self.scope.is_module_scope and self.directives['auto_cpdef']:
node = node.as_cfunction(scope=self.scope) node = node.as_cfunction(scope=self.scope)
# Enable this when internal def functions are allowed. # Enable this when internal def functions are allowed.
......
...@@ -35,6 +35,8 @@ cpdef p_yield_statement(PyrexScanner s) ...@@ -35,6 +35,8 @@ cpdef p_yield_statement(PyrexScanner s)
cpdef p_power(PyrexScanner s) cpdef p_power(PyrexScanner s)
cpdef p_new_expr(PyrexScanner s) cpdef p_new_expr(PyrexScanner s)
cpdef p_trailer(PyrexScanner s, node1) cpdef p_trailer(PyrexScanner s, node1)
cpdef p_call_parse_args(PyrexScanner s, bint allow_genexp = *)
cpdef p_call_build_packed_args(pos, positional_args, keyword_args, star_arg)
cpdef p_call(PyrexScanner s, function) cpdef p_call(PyrexScanner s, function)
cpdef p_index(PyrexScanner s, base) cpdef p_index(PyrexScanner s, base)
cpdef p_subscript_list(PyrexScanner s) cpdef p_subscript_list(PyrexScanner s)
......
...@@ -381,7 +381,7 @@ def p_trailer(s, node1): ...@@ -381,7 +381,7 @@ def p_trailer(s, node1):
# arglist: argument (',' argument)* [','] # arglist: argument (',' argument)* [',']
# argument: [test '='] test # Really [keyword '='] test # argument: [test '='] test # Really [keyword '='] test
def p_call(s, function): def p_call_parse_args(s, allow_genexp = True):
# s.sy == '(' # s.sy == '('
pos = s.position() pos = s.position()
s.next() s.next()
...@@ -428,29 +428,43 @@ def p_call(s, function): ...@@ -428,29 +428,43 @@ def p_call(s, function):
if s.sy == ',': if s.sy == ',':
s.next() s.next()
s.expect(')') s.expect(')')
return positional_args, keyword_args, star_arg, starstar_arg
def p_call_build_packed_args(pos, positional_args, keyword_args, star_arg):
arg_tuple = None
keyword_dict = None
if positional_args or not star_arg:
arg_tuple = ExprNodes.TupleNode(pos,
args = positional_args)
if star_arg:
star_arg_tuple = ExprNodes.AsTupleNode(pos, arg = star_arg)
if arg_tuple:
arg_tuple = ExprNodes.binop_node(pos,
operator = '+', operand1 = arg_tuple,
operand2 = star_arg_tuple)
else:
arg_tuple = star_arg_tuple
if keyword_args:
keyword_args = [ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args]
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = keyword_args)
return arg_tuple, keyword_dict
def p_call(s, function):
# s.sy == '('
pos = s.position()
positional_args, keyword_args, star_arg, starstar_arg = \
p_call_parse_args(s)
if not (keyword_args or star_arg or starstar_arg): if not (keyword_args or star_arg or starstar_arg):
return ExprNodes.SimpleCallNode(pos, return ExprNodes.SimpleCallNode(pos,
function = function, function = function,
args = positional_args) args = positional_args)
else: else:
arg_tuple = None arg_tuple, keyword_dict = p_call_build_packed_args(
keyword_dict = None pos, positional_args, keyword_args, star_arg)
if positional_args or not star_arg:
arg_tuple = ExprNodes.TupleNode(pos,
args = positional_args)
if star_arg:
star_arg_tuple = ExprNodes.AsTupleNode(pos, arg = star_arg)
if arg_tuple:
arg_tuple = ExprNodes.binop_node(pos,
operator = '+', operand1 = arg_tuple,
operand2 = star_arg_tuple)
else:
arg_tuple = star_arg_tuple
if keyword_args:
keyword_args = [ExprNodes.DictItemNode(pos=key.pos, key=key, value=value)
for key, value in keyword_args]
keyword_dict = ExprNodes.DictNode(pos,
key_value_pairs = keyword_args)
return ExprNodes.GeneralCallNode(pos, return ExprNodes.GeneralCallNode(pos,
function = function, function = function,
positional_args = arg_tuple, positional_args = arg_tuple,
...@@ -2607,16 +2621,23 @@ def p_class_statement(s, decorators): ...@@ -2607,16 +2621,23 @@ def p_class_statement(s, decorators):
s.next() s.next()
class_name = EncodedString( p_ident(s) ) class_name = EncodedString( p_ident(s) )
class_name.encoding = s.source_encoding class_name.encoding = s.source_encoding
arg_tuple = None
keyword_dict = None
starstar_arg = None
if s.sy == '(': if s.sy == '(':
s.next() positional_args, keyword_args, star_arg, starstar_arg = \
base_list = p_simple_expr_list(s) p_call_parse_args(s, allow_genexp = False)
s.expect(')') arg_tuple, keyword_dict = p_call_build_packed_args(
else: pos, positional_args, keyword_args, star_arg)
base_list = [] if arg_tuple is None:
# XXX: empty arg_tuple
arg_tuple = ExprNodes.TupleNode(pos, args = [])
doc, body = p_suite(s, Ctx(level = 'class'), with_doc = 1) doc, body = p_suite(s, Ctx(level = 'class'), with_doc = 1)
return Nodes.PyClassDefNode(pos, return Nodes.PyClassDefNode(pos,
name = class_name, name = class_name,
bases = ExprNodes.TupleNode(pos, args = base_list), bases = arg_tuple,
keyword_args = keyword_dict,
starstar_arg = starstar_arg,
doc = doc, body = body, decorators = decorators) doc = doc, body = body, decorators = decorators)
def p_c_class_definition(s, pos, ctx): def p_c_class_definition(s, pos, ctx):
......
...@@ -379,14 +379,8 @@ class BuiltinObjectType(PyObjectType): ...@@ -379,14 +379,8 @@ class BuiltinObjectType(PyObjectType):
base_type = None base_type = None
module_name = '__builtin__' module_name = '__builtin__'
alternative_name = None # used for str/bytes duality
def __init__(self, name, cname): def __init__(self, name, cname):
self.name = name self.name = name
if name == 'str':
self.alternative_name = 'bytes'
elif name == 'bytes':
self.alternative_name = 'str'
self.cname = cname self.cname = cname
self.typeptr_cname = "&" + cname self.typeptr_cname = "&" + cname
...@@ -403,9 +397,7 @@ class BuiltinObjectType(PyObjectType): ...@@ -403,9 +397,7 @@ class BuiltinObjectType(PyObjectType):
def assignable_from(self, src_type): def assignable_from(self, src_type):
if isinstance(src_type, BuiltinObjectType): if isinstance(src_type, BuiltinObjectType):
return src_type.name == self.name or ( return src_type.name == self.name
src_type.name == self.alternative_name and
src_type.name is not None)
elif src_type.is_extension_type: elif src_type.is_extension_type:
return (src_type.module_name == '__builtin__' and return (src_type.module_name == '__builtin__' and
src_type.name == self.name) src_type.name == self.name)
......
...@@ -358,8 +358,11 @@ class PyrexScanner(Scanner): ...@@ -358,8 +358,11 @@ class PyrexScanner(Scanner):
self.error("Unrecognized character") self.error("Unrecognized character")
if sy == IDENT: if sy == IDENT:
if systring in self.keywords: if systring in self.keywords:
if systring == 'print' and \ if systring == 'print' and print_function in self.context.future_directives:
print_function in self.context.future_directives: self.keywords.remove('print')
systring = EncodedString(systring)
elif systring == 'exec' and self.context.language_level >= 3:
self.keywords.remove('exec')
systring = EncodedString(systring) systring = EncodedString(systring)
else: else:
sy = systring sy = systring
...@@ -416,7 +419,11 @@ class PyrexScanner(Scanner): ...@@ -416,7 +419,11 @@ class PyrexScanner(Scanner):
if message: if message:
self.error(message) self.error(message)
else: else:
self.error("Expected '%s'" % what) if self.sy == IDENT:
found = self.systring
else:
found = self.sy
self.error("Expected '%s', found '%s'" % (what, found))
def expect_indent(self): def expect_indent(self):
self.expect('INDENT', self.expect('INDENT',
......
...@@ -70,7 +70,7 @@ class Entry(object): ...@@ -70,7 +70,7 @@ class Entry(object):
# or class attribute during # or class attribute during
# class construction # class construction
# is_member boolean Is an assigned class member # is_member boolean Is an assigned class member
# is_real_dict boolean Is a real dict, PyClass attributes dict # is_pyclass_attr boolean Is a name in a Python class namespace
# is_variable boolean Is a variable # is_variable boolean Is a variable
# is_cfunction boolean Is a C function # is_cfunction boolean Is a C function
# is_cmethod boolean Is a C method of an extension type # is_cmethod boolean Is a C method of an extension type
...@@ -132,7 +132,7 @@ class Entry(object): ...@@ -132,7 +132,7 @@ class Entry(object):
is_cglobal = 0 is_cglobal = 0
is_pyglobal = 0 is_pyglobal = 0
is_member = 0 is_member = 0
is_real_dict = 0 is_pyclass_attr = 0
is_variable = 0 is_variable = 0
is_cfunction = 0 is_cfunction = 0
is_cmethod = 0 is_cmethod = 0
...@@ -539,7 +539,7 @@ class Scope(object): ...@@ -539,7 +539,7 @@ class Scope(object):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'private', defining = 0, cname = None, visibility = 'private', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = (), utility_code = None):
# Add an entry for a C function. # Add an entry for a C function.
if not cname: if not cname:
if api or visibility != 'private': if api or visibility != 'private':
...@@ -552,7 +552,18 @@ class Scope(object): ...@@ -552,7 +552,18 @@ class Scope(object):
warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1) warning(pos, "Function '%s' previously declared as '%s'" % (name, entry.visibility), 1)
if not entry.type.same_as(type): if not entry.type.same_as(type):
if visibility == 'extern' and entry.visibility == 'extern': if visibility == 'extern' and entry.visibility == 'extern':
can_override = False
if self.is_cpp(): if self.is_cpp():
can_override = True
elif cname:
# if all alternatives have different cnames,
# it's safe to allow signature overrides
for alt_entry in entry.all_alternatives():
if not alt_entry.cname or cname == alt_entry.cname:
break # cname not unique!
else:
can_override = True
if can_override:
temp = self.add_cfunction(name, type, pos, cname, visibility, modifiers) temp = self.add_cfunction(name, type, pos, cname, visibility, modifiers)
temp.overloaded_alternatives = entry.all_alternatives() temp.overloaded_alternatives = entry.all_alternatives()
entry = temp entry = temp
...@@ -574,6 +585,7 @@ class Scope(object): ...@@ -574,6 +585,7 @@ class Scope(object):
entry.is_implemented = True entry.is_implemented = True
if modifiers: if modifiers:
entry.func_modifiers = modifiers entry.func_modifiers = modifiers
entry.utility_code = utility_code
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
...@@ -711,8 +723,8 @@ class BuiltinScope(Scope): ...@@ -711,8 +723,8 @@ class BuiltinScope(Scope):
# If python_equiv == "*", the Python equivalent has the same name # If python_equiv == "*", the Python equivalent has the same name
# as the entry, otherwise it has the name specified by python_equiv. # as the entry, otherwise it has the name specified by python_equiv.
name = EncodedString(name) name = EncodedString(name)
entry = self.declare_cfunction(name, type, None, cname, visibility='extern') entry = self.declare_cfunction(name, type, None, cname, visibility='extern',
entry.utility_code = utility_code utility_code = utility_code)
if python_equiv: if python_equiv:
if python_equiv == "*": if python_equiv == "*":
python_equiv = name python_equiv = name
...@@ -1352,7 +1364,7 @@ class StructOrUnionScope(Scope): ...@@ -1352,7 +1364,7 @@ class StructOrUnionScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'private', defining = 0, cname = None, visibility = 'private', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = ()): # currently no utility code ...
return self.declare_var(name, type, pos, cname, visibility) return self.declare_var(name, type, pos, cname, visibility)
class ClassScope(Scope): class ClassScope(Scope):
...@@ -1407,7 +1419,7 @@ class PyClassScope(ClassScope): ...@@ -1407,7 +1419,7 @@ class PyClassScope(ClassScope):
entry = Scope.declare_var(self, name, type, pos, entry = Scope.declare_var(self, name, type, pos,
cname, visibility, is_cdef) cname, visibility, is_cdef)
entry.is_pyglobal = 1 entry.is_pyglobal = 1
entry.is_real_dict = 1 entry.is_pyclass_attr = 1
return entry return entry
def add_default_value(self, type): def add_default_value(self, type):
...@@ -1528,7 +1540,8 @@ class CClassScope(ClassScope): ...@@ -1528,7 +1540,8 @@ class CClassScope(ClassScope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'private', cname = None, visibility = 'private',
defining = 0, api = 0, in_pxd = 0, modifiers = ()): defining = 0, api = 0, in_pxd = 0, modifiers = (),
utility_code = None):
if get_special_method_signature(name): if get_special_method_signature(name):
error(pos, "Special methods must be declared with 'def', not 'cdef'") error(pos, "Special methods must be declared with 'def', not 'cdef'")
args = type.args args = type.args
...@@ -1562,6 +1575,7 @@ class CClassScope(ClassScope): ...@@ -1562,6 +1575,7 @@ class CClassScope(ClassScope):
visibility, modifiers) visibility, modifiers)
if defining: if defining:
entry.func_cname = self.mangle(Naming.func_prefix, name) entry.func_cname = self.mangle(Naming.func_prefix, name)
entry.utility_code = utility_code
return entry return entry
def add_cfunction(self, name, type, pos, cname, visibility, modifiers): def add_cfunction(self, name, type, pos, cname, visibility, modifiers):
...@@ -1572,7 +1586,20 @@ class CClassScope(ClassScope): ...@@ -1572,7 +1586,20 @@ class CClassScope(ClassScope):
entry.is_cmethod = 1 entry.is_cmethod = 1
entry.prev_entry = prev_entry entry.prev_entry = prev_entry
return entry return entry
def declare_builtin_cfunction(self, name, type, cname, utility_code = None):
# overridden methods of builtin types still have their Python
# equivalent that must be accessible to support bound methods
name = EncodedString(name)
entry = self.declare_cfunction(name, type, None, cname, visibility='extern',
utility_code = utility_code)
var_entry = Entry(name, name, py_object_type)
var_entry.is_variable = 1
var_entry.is_builtin = 1
var_entry.utility_code = utility_code
entry.as_variable = var_entry
return entry
def declare_property(self, name, doc, pos): def declare_property(self, name, doc, pos):
entry = self.lookup_here(name) entry = self.lookup_here(name)
if entry is None: if entry is None:
...@@ -1660,7 +1687,7 @@ class CppClassScope(Scope): ...@@ -1660,7 +1687,7 @@ class CppClassScope(Scope):
def declare_cfunction(self, name, type, pos, def declare_cfunction(self, name, type, pos,
cname = None, visibility = 'extern', defining = 0, cname = None, visibility = 'extern', defining = 0,
api = 0, in_pxd = 0, modifiers = ()): api = 0, in_pxd = 0, modifiers = (), utility_code = None):
if name == self.name.split('::')[-1] and cname is None: if name == self.name.split('::')[-1] and cname is None:
self.check_base_default_constructor(pos) self.check_base_default_constructor(pos)
name = '<init>' name = '<init>'
...@@ -1669,6 +1696,8 @@ class CppClassScope(Scope): ...@@ -1669,6 +1696,8 @@ class CppClassScope(Scope):
entry = self.declare_var(name, type, pos, cname, visibility) entry = self.declare_var(name, type, pos, cname, visibility)
if prev_entry: if prev_entry:
entry.overloaded_alternatives = prev_entry.all_alternatives() entry.overloaded_alternatives = prev_entry.all_alternatives()
entry.utility_code = utility_code
return entry
def declare_inherited_cpp_attributes(self, base_scope): def declare_inherited_cpp_attributes(self, base_scope):
# Declare entries for all the C++ attributes of an # Declare entries for all the C++ attributes of an
...@@ -1689,7 +1718,8 @@ class CppClassScope(Scope): ...@@ -1689,7 +1718,8 @@ class CppClassScope(Scope):
for base_entry in base_scope.cfunc_entries: for base_entry in base_scope.cfunc_entries:
entry = self.declare_cfunction(base_entry.name, base_entry.type, entry = self.declare_cfunction(base_entry.name, base_entry.type,
base_entry.pos, base_entry.cname, base_entry.pos, base_entry.cname,
base_entry.visibility, base_entry.func_modifiers) base_entry.visibility, base_entry.func_modifiers,
utility_code = base_entry.utility_code)
entry.is_inherited = 1 entry.is_inherited = 1
def specialize(self, values): def specialize(self, values):
...@@ -1710,7 +1740,8 @@ class CppClassScope(Scope): ...@@ -1710,7 +1740,8 @@ class CppClassScope(Scope):
scope.declare_cfunction(e.name, scope.declare_cfunction(e.name,
e.type.specialize(values), e.type.specialize(values),
e.pos, e.pos,
e.cname) e.cname,
utility_code = e.utility_code)
return scope return scope
def add_include_file(self, filename): def add_include_file(self, filename):
......
...@@ -64,6 +64,7 @@ class Signature(object): ...@@ -64,6 +64,7 @@ class Signature(object):
error_value_map = { error_value_map = {
'O': "NULL", 'O': "NULL",
'T': "NULL",
'i': "-1", 'i': "-1",
'b': "-1", 'b': "-1",
'l': "-1", 'l': "-1",
...@@ -91,6 +92,10 @@ class Signature(object): ...@@ -91,6 +92,10 @@ class Signature(object):
# argument is 'self' for methods or 'class' for classmethods # argument is 'self' for methods or 'class' for classmethods
return self.fixed_arg_format[i] == 'T' return self.fixed_arg_format[i] == 'T'
def returns_self_type(self):
# return type is same as 'self' argument type
return self.ret_format == 'T'
def fixed_arg_type(self, i): def fixed_arg_type(self, i):
return self.format_map[self.fixed_arg_format[i]] return self.format_map[self.fixed_arg_format[i]]
...@@ -100,13 +105,20 @@ class Signature(object): ...@@ -100,13 +105,20 @@ class Signature(object):
def exception_value(self): def exception_value(self):
return self.error_value_map.get(self.ret_format) return self.error_value_map.get(self.ret_format)
def function_type(self): def function_type(self, self_arg_override=None):
# Construct a C function type descriptor for this signature # Construct a C function type descriptor for this signature
args = [] args = []
for i in xrange(self.num_fixed_args()): for i in xrange(self.num_fixed_args()):
arg_type = self.fixed_arg_type(i) if self_arg_override is not None and self.is_self_arg(i):
args.append(PyrexTypes.CFuncTypeArg("", arg_type, None)) assert isinstance(self_arg_override, PyrexTypes.CFuncTypeArg)
ret_type = self.return_type() args.append(self_arg_override)
else:
arg_type = self.fixed_arg_type(i)
args.append(PyrexTypes.CFuncTypeArg("", arg_type, None))
if self_arg_override is not None and self.returns_self_type():
ret_type = self_arg_override.type
else:
ret_type = self.return_type()
exc_value = self.exception_value() exc_value = self.exception_value()
return PyrexTypes.CFuncType(ret_type, args, exception_value = exc_value) return PyrexTypes.CFuncType(ret_type, args, exception_value = exc_value)
......
...@@ -8,6 +8,7 @@ import Nodes ...@@ -8,6 +8,7 @@ import Nodes
import ExprNodes import ExprNodes
from Nodes import Node from Nodes import Node
from ExprNodes import AtomicExprNode from ExprNodes import AtomicExprNode
from PyrexTypes import c_ptr_type
class TempHandle(object): class TempHandle(object):
# THIS IS DEPRECATED, USE LetRefNode instead # THIS IS DEPRECATED, USE LetRefNode instead
...@@ -196,6 +197,8 @@ class LetNodeMixin: ...@@ -196,6 +197,8 @@ class LetNodeMixin:
def setup_temp_expr(self, code): def setup_temp_expr(self, code):
self.temp_expression.generate_evaluation_code(code) self.temp_expression.generate_evaluation_code(code)
self.temp_type = self.temp_expression.type self.temp_type = self.temp_expression.type
if self.temp_type.is_array:
self.temp_type = c_ptr_type(self.temp_type.base_type)
self._result_in_temp = self.temp_expression.result_in_temp() self._result_in_temp = self.temp_expression.result_in_temp()
if self._result_in_temp: if self._result_in_temp:
self.temp = self.temp_expression.result() self.temp = self.temp_expression.result()
......
...@@ -3,14 +3,15 @@ cimport cython ...@@ -3,14 +3,15 @@ cimport cython
cdef class BasicVisitor: cdef class BasicVisitor:
cdef dict dispatch_table cdef dict dispatch_table
cpdef visit(self, obj) cpdef visit(self, obj)
cpdef find_handler(self, obj) cdef _visit(self, obj)
cdef find_handler(self, obj)
cdef class TreeVisitor(BasicVisitor): cdef class TreeVisitor(BasicVisitor):
cdef public list access_path cdef public list access_path
cpdef visitchild(self, child, parent, attrname, idx) cdef _visitchild(self, child, parent, attrname, idx)
@cython.locals(idx=int) @cython.locals(idx=int)
cpdef dict _visitchildren(self, parent, attrs) cdef dict _visitchildren(self, parent, attrs)
# cpdef visitchildren(self, parent, attrs=*) cpdef visitchildren(self, parent, attrs=*)
cdef class VisitorTransform(TreeVisitor): cdef class VisitorTransform(TreeVisitor):
cpdef visitchildren(self, parent, attrs=*) cpdef visitchildren(self, parent, attrs=*)
...@@ -19,3 +20,15 @@ cdef class VisitorTransform(TreeVisitor): ...@@ -19,3 +20,15 @@ cdef class VisitorTransform(TreeVisitor):
cdef class CythonTransform(VisitorTransform): cdef class CythonTransform(VisitorTransform):
cdef public context cdef public context
cdef public current_directives cdef public current_directives
cdef class ScopeTrackingTransform(CythonTransform):
cdef public scope_type
cdef public scope_node
cdef visit_scope(self, node, scope_type)
cdef class EnvTransform(CythonTransform):
cdef public list env_stack
cdef class RecursiveNodeReplacer(VisitorTransform):
cdef public orig_node
cdef public new_node
...@@ -20,6 +20,9 @@ class BasicVisitor(object): ...@@ -20,6 +20,9 @@ class BasicVisitor(object):
self.dispatch_table = {} self.dispatch_table = {}
def visit(self, obj): def visit(self, obj):
return self._visit(obj)
def _visit(self, obj):
try: try:
handler_method = self.dispatch_table[type(obj)] handler_method = self.dispatch_table[type(obj)]
except KeyError: except KeyError:
...@@ -173,10 +176,10 @@ class TreeVisitor(BasicVisitor): ...@@ -173,10 +176,10 @@ class TreeVisitor(BasicVisitor):
last_node.pos, self.__class__.__name__, last_node.pos, self.__class__.__name__,
u'\n'.join(trace), e, stacktrace) u'\n'.join(trace), e, stacktrace)
def visitchild(self, child, parent, attrname, idx): def _visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx)) self.access_path.append((parent, attrname, idx))
try: try:
result = self.visit(child) result = self._visit(child)
except Errors.CompileError: except Errors.CompileError:
raise raise
except Exception, e: except Exception, e:
...@@ -206,9 +209,9 @@ class TreeVisitor(BasicVisitor): ...@@ -206,9 +209,9 @@ class TreeVisitor(BasicVisitor):
child = getattr(parent, attr) child = getattr(parent, attr)
if child is not None: if child is not None:
if type(child) is list: if type(child) is list:
childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)] childretval = [self._visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
else: else:
childretval = self.visitchild(child, parent, attr, None) childretval = self._visitchild(child, parent, attr, None)
assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent) assert not isinstance(childretval, list), 'Cannot insert list here: %s in %r' % (attr, parent)
result[attr] = childretval result[attr] = childretval
return result return result
...@@ -256,7 +259,7 @@ class VisitorTransform(TreeVisitor): ...@@ -256,7 +259,7 @@ class VisitorTransform(TreeVisitor):
return node return node
def __call__(self, root): def __call__(self, root):
return self.visit(root) return self._visit(root)
class CythonTransform(VisitorTransform): class CythonTransform(VisitorTransform):
""" """
...@@ -288,8 +291,8 @@ class CythonTransform(VisitorTransform): ...@@ -288,8 +291,8 @@ class CythonTransform(VisitorTransform):
class ScopeTrackingTransform(CythonTransform): class ScopeTrackingTransform(CythonTransform):
# Keeps track of type of scopes # Keeps track of type of scopes
scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass' #scope_type: can be either of 'module', 'function', 'cclass', 'pyclass', 'struct'
scope_node = None #scope_node: the node that owns the current scope
def visit_ModuleNode(self, node): def visit_ModuleNode(self, node):
self.scope_type = 'module' self.scope_type = 'module'
...@@ -388,7 +391,7 @@ class PrintTree(TreeVisitor): ...@@ -388,7 +391,7 @@ class PrintTree(TreeVisitor):
def __call__(self, tree, phase=None): def __call__(self, tree, phase=None):
print("Parse tree dump at phase '%s'" % phase) print("Parse tree dump at phase '%s'" % phase)
self.visit(tree) self._visit(tree)
return tree return tree
# Don't do anything about process_list, the defaults gives # Don't do anything about process_list, the defaults gives
......
...@@ -5,7 +5,7 @@ cdef class Scanner: ...@@ -5,7 +5,7 @@ cdef class Scanner:
cdef public lexicon cdef public lexicon
cdef public stream cdef public stream
cdef public name cdef public name
cdef public buffer cdef public unicode buffer
cdef public Py_ssize_t buf_start_pos cdef public Py_ssize_t buf_start_pos
cdef public Py_ssize_t next_pos cdef public Py_ssize_t next_pos
cdef public Py_ssize_t cur_pos cdef public Py_ssize_t cur_pos
...@@ -26,16 +26,15 @@ cdef class Scanner: ...@@ -26,16 +26,15 @@ cdef class Scanner:
@cython.locals(input_state=long) @cython.locals(input_state=long)
cpdef next_char(self) cpdef next_char(self)
@cython.locals(queue=list)
cpdef tuple read(self) cpdef tuple read(self)
cpdef tuple scan_a_token(self) cdef tuple scan_a_token(self)
cpdef tuple position(self) cpdef tuple position(self)
@cython.locals(cur_pos=long, cur_line=long, cur_line_start=long, @cython.locals(cur_pos=long, cur_line=long, cur_line_start=long,
input_state=long, next_pos=long, input_state=long, next_pos=long, state=dict,
buf_start_pos=long, buf_len=long, buf_index=long, buf_start_pos=long, buf_len=long, buf_index=long,
trace=bint, discard=long) trace=bint, discard=long, data=unicode, buffer=unicode)
cpdef run_machine_inlined(self) cdef run_machine_inlined(self)
cpdef begin(self, state) cpdef begin(self, state)
cpdef produce(self, value, text = *) cpdef produce(self, value, text = *)
...@@ -163,7 +163,8 @@ class Scanner(object): ...@@ -163,7 +163,8 @@ class Scanner(object):
buffer = self.buffer buffer = self.buffer
buf_start_pos = self.buf_start_pos buf_start_pos = self.buf_start_pos
buf_len = len(buffer) buf_len = len(buffer)
backup_state = None b_action, b_cur_pos, b_cur_line, b_cur_line_start, b_cur_char, b_input_state, b_next_pos = \
None, 0, 0, 0, u'', 0, 0
trace = self.trace trace = self.trace
while 1: while 1:
if trace: #TRACE# if trace: #TRACE#
...@@ -173,8 +174,8 @@ class Scanner(object): ...@@ -173,8 +174,8 @@ class Scanner(object):
#action = state.action #@slow #action = state.action #@slow
action = state['action'] #@fast action = state['action'] #@fast
if action is not None: if action is not None:
backup_state = ( b_action, b_cur_pos, b_cur_line, b_cur_line_start, b_cur_char, b_input_state, b_next_pos = \
action, cur_pos, cur_line, cur_line_start, cur_char, input_state, next_pos) action, cur_pos, cur_line, cur_line_start, cur_char, input_state, next_pos
# End inlined self.save_for_backup() # End inlined self.save_for_backup()
c = cur_char c = cur_char
#new_state = state.new_state(c) #@slow #new_state = state.new_state(c) #@slow
...@@ -234,9 +235,11 @@ class Scanner(object): ...@@ -234,9 +235,11 @@ class Scanner(object):
if trace: #TRACE# if trace: #TRACE#
print("blocked") #TRACE# print("blocked") #TRACE#
# Begin inlined: action = self.back_up() # Begin inlined: action = self.back_up()
if backup_state is not None: if b_action is not None:
(action, cur_pos, cur_line, cur_line_start, (action, cur_pos, cur_line, cur_line_start,
cur_char, input_state, next_pos) = backup_state cur_char, input_state, next_pos) = \
(b_action, b_cur_pos, b_cur_line, b_cur_line_start,
b_cur_char, b_input_state, b_next_pos)
else: else:
action = None action = None
break # while 1 break # while 1
......
...@@ -18,6 +18,10 @@ def inline(f, *args, **kwds): ...@@ -18,6 +18,10 @@ def inline(f, *args, **kwds):
assert len(args) == len(kwds) == 0 assert len(args) == len(kwds) == 0
return f return f
def compile(f):
from Cython.Build.Inline import RuntimeCompiledFunction
return RuntimeCompiledFunction(f)
# Special functions # Special functions
def cdiv(a, b): def cdiv(a, b):
......
...@@ -173,15 +173,23 @@ def unpack_source_tree(tree_file, dir=None): ...@@ -173,15 +173,23 @@ def unpack_source_tree(tree_file, dir=None):
dir = tempfile.mkdtemp() dir = tempfile.mkdtemp()
header = [] header = []
cur_file = None cur_file = None
for line in open(tree_file).readlines(): f = open(tree_file)
lines = f.readlines()
f.close()
f = None
for line in lines:
if line[:5] == '#####': if line[:5] == '#####':
filename = line.strip().strip('#').strip().replace('/', os.path.sep) filename = line.strip().strip('#').strip().replace('/', os.path.sep)
path = os.path.join(dir, filename) path = os.path.join(dir, filename)
if not os.path.exists(os.path.dirname(path)): if not os.path.exists(os.path.dirname(path)):
os.makedirs(os.path.dirname(path)) os.makedirs(os.path.dirname(path))
if cur_file is not None:
cur_file.close()
cur_file = open(path, 'w') cur_file = open(path, 'w')
elif cur_file is not None: elif cur_file is not None:
cur_file.write(line) cur_file.write(line)
else: else:
header.append(line) header.append(line)
if cur_file is not None:
cur_file.close()
return dir, ''.join(header) return dir, ''.join(header)
...@@ -316,18 +316,22 @@ class CythonCompileTestCase(unittest.TestCase): ...@@ -316,18 +316,22 @@ class CythonCompileTestCase(unittest.TestCase):
source_file = os.path.join(test_directory, module) + '.pyx' source_file = os.path.join(test_directory, module) + '.pyx'
source_and_output = codecs.open( source_and_output = codecs.open(
self.find_module_source_file(source_file), 'rU', 'ISO-8859-1') self.find_module_source_file(source_file), 'rU', 'ISO-8859-1')
out = codecs.open(os.path.join(workdir, module + '.pyx'), try:
'w', 'ISO-8859-1') out = codecs.open(os.path.join(workdir, module + '.pyx'),
for line in source_and_output: 'w', 'ISO-8859-1')
last_line = line for line in source_and_output:
if line.startswith("_ERRORS"): last_line = line
out.close() if line.startswith("_ERRORS"):
out = ErrorWriter() out.close()
else: out = ErrorWriter()
out.write(line) else:
out.write(line)
finally:
source_and_output.close()
try: try:
geterrors = out.geterrors geterrors = out.geterrors
except AttributeError: except AttributeError:
out.close()
return [] return []
else: else:
return geterrors() return geterrors()
...@@ -660,7 +664,10 @@ def collect_doctests(path, module_prefix, suite, selectors): ...@@ -660,7 +664,10 @@ def collect_doctests(path, module_prefix, suite, selectors):
for f in filenames: for f in filenames:
if file_matches(f): if file_matches(f):
if not f.endswith('.py'): continue if not f.endswith('.py'): continue
filepath = os.path.join(dirpath, f)[:-len(".py")] filepath = os.path.join(dirpath, f)
if os.path.getsize(filepath) == 0: continue
if 'no doctest' in open(filepath).next(): continue
filepath = filepath[:-len(".py")]
modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.') modulename = module_prefix + filepath[len(path)+1:].replace(os.path.sep, '.')
if not [ 1 for match in selectors if match(modulename) ]: if not [ 1 for match in selectors if match(modulename) ]:
continue continue
......
...@@ -101,8 +101,12 @@ def compile_cython_modules(profile=False): ...@@ -101,8 +101,12 @@ def compile_cython_modules(profile=False):
pyx_source_file = source_file + ".py" pyx_source_file = source_file + ".py"
else: else:
pyx_source_file = source_file + ".pyx" pyx_source_file = source_file + ".pyx"
dep_files = []
if os.path.exists(source_file + '.pxd'):
dep_files.append(source_file + '.pxd')
extensions.append( extensions.append(
Extension(module, sources = [pyx_source_file]) Extension(module, sources = [pyx_source_file],
depends = dep_files)
) )
class build_ext(build_ext_orig): class build_ext(build_ext_orig):
...@@ -154,9 +158,18 @@ def compile_cython_modules(profile=False): ...@@ -154,9 +158,18 @@ def compile_cython_modules(profile=False):
else: else:
pyx_source_file = source_file + ".pyx" pyx_source_file = source_file + ".pyx"
c_source_file = source_file + ".c" c_source_file = source_file + ".c"
if not os.path.exists(c_source_file) or \ source_is_newer = False
Utils.file_newer_than(pyx_source_file, if not os.path.exists(c_source_file):
Utils.modification_time(c_source_file)): source_is_newer = True
else:
c_last_modified = Utils.modification_time(c_source_file)
if Utils.file_newer_than(pyx_source_file, c_last_modified):
source_is_newer = True
else:
pxd_source_file = source_file + ".pxd"
if os.path.exists(pxd_source_file) and Utils.file_newer_than(pxd_source_file, c_last_modified):
source_is_newer = True
if source_is_newer:
print("Compiling module %s ..." % module) print("Compiling module %s ..." % module)
result = compile(pyx_source_file) result = compile(pyx_source_file)
c_source_file = result.c_file c_source_file = result.c_file
...@@ -241,6 +254,7 @@ setup( ...@@ -241,6 +254,7 @@ setup(
scripts = scripts, scripts = scripts,
packages=[ packages=[
'Cython', 'Cython',
'Cython.Build',
'Cython.Compiler', 'Cython.Compiler',
'Cython.Runtime', 'Cython.Runtime',
'Cython.Distutils', 'Cython.Distutils',
......
...@@ -16,8 +16,12 @@ with_statement_module_level_T536 ...@@ -16,8 +16,12 @@ with_statement_module_level_T536
function_as_method_T494 function_as_method_T494
closure_inside_cdef_T554 closure_inside_cdef_T554
ipow_crash_T562 ipow_crash_T562
pure_mode_cmethod_inheritance_T583
# CPython regression tests that don't current work: # CPython regression tests that don't current work:
pyregr.test_threadsignals pyregr.test_threadsignals
pyregr.test_module pyregr.test_module
# CPython regression tests that don't make sense
pyregr.test_gdb
...@@ -5,7 +5,7 @@ PYTHON -c "import a" ...@@ -5,7 +5,7 @@ PYTHON -c "import a"
# TODO: Better interface... # TODO: Better interface...
from Cython.Compiler.Dependencies import cythonize from Cython.Build.Dependencies import cythonize
from distutils.core import setup from distutils.core import setup
......
...@@ -5,7 +5,7 @@ PYTHON -c "import a" ...@@ -5,7 +5,7 @@ PYTHON -c "import a"
# TODO: Better interface... # TODO: Better interface...
from Cython.Compiler.Dependencies import cythonize from Cython.Build.Dependencies import cythonize
from distutils.core import setup from distutils.core import setup
......
...@@ -15,6 +15,7 @@ cdef int f() except -1: ...@@ -15,6 +15,7 @@ cdef int f() except -1:
i = len(x) i = len(x)
x = open(y, z) x = open(y, z)
x = pow(y, z, w) x = pow(y, z, w)
x = pow(y, z)
x = reload(y) x = reload(y)
x = repr(y) x = repr(y)
setattr(x, y, z) setattr(x, y, z)
......
...@@ -6,5 +6,5 @@ cdef nogil class test: pass ...@@ -6,5 +6,5 @@ cdef nogil class test: pass
_ERRORS = u""" _ERRORS = u"""
2: 5: Expected an identifier, found 'pass' 2: 5: Expected an identifier, found 'pass'
3: 9: Empty declarator 3: 9: Empty declarator
4:11: Expected ':' 4:11: Expected ':', found 'class'
""" """
...@@ -5,5 +5,5 @@ cpdef nogil class test: pass ...@@ -5,5 +5,5 @@ cpdef nogil class test: pass
_ERRORS = u""" _ERRORS = u"""
2: 6: cdef blocks cannot be declared cpdef 2: 6: cdef blocks cannot be declared cpdef
3: 6: cdef blocks cannot be declared cpdef 3: 6: cdef blocks cannot be declared cpdef
3:12: Expected ':' 3:12: Expected ':', found 'class'
""" """
...@@ -2,5 +2,5 @@ cdef packed foo: ...@@ -2,5 +2,5 @@ cdef packed foo:
pass pass
_ERRORS = u""" _ERRORS = u"""
1:12: Expected 'struct' 1:12: Expected 'struct', found 'foo'
""" """
cimport cython
_set = set # CPython may not define it (in Py2.3), but Cython does :)
def test_set_clear_bound():
"""
>>> type(test_set_clear_bound()) is _set
True
>>> list(test_set_clear_bound())
[]
"""
cdef set s1 = set([1])
clear = s1.clear
clear()
return s1
text = u'ab jd sdflk as sa sadas asdas fsdf '
pipe_sep = u'|'
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NameNode")
def test_unicode_join_bound(unicode sep, l):
"""
>>> l = text.split()
>>> len(l)
8
>>> print( pipe_sep.join(l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf
>>> print( test_unicode_join_bound(pipe_sep, l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf
"""
join = sep.join
return join(l)
import sys
IS_PY3 = sys.version_info[0] >= 3
__doc__ = """
>>> it = iter([1,2,3])
>>> if not IS_PY3:
... next = type(it).next
>>> next(it)
1
>>> next(it)
2
>>> next(it)
3
>>> next(it)
Traceback (most recent call last):
StopIteration
>>> next(it)
Traceback (most recent call last):
StopIteration
>>> if IS_PY3: next(it, 123)
... else: print(123)
123
"""
if IS_PY3:
__doc__ += """
>>> next(123)
Traceback (most recent call last):
TypeError: int object is not an iterator
"""
def test_next_not_iterable(it):
"""
>>> test_next_not_iterable(123)
Traceback (most recent call last):
TypeError: int object is not an iterator
"""
return next(it)
def test_single_next(it):
"""
>>> it = iter([1,2,3])
>>> test_single_next(it)
1
>>> test_single_next(it)
2
>>> test_single_next(it)
3
>>> test_single_next(it)
Traceback (most recent call last):
StopIteration
>>> test_single_next(it)
Traceback (most recent call last):
StopIteration
"""
return next(it)
def test_default_next(it, default):
"""
>>> it = iter([1,2,3])
>>> test_default_next(it, 99)
1
>>> test_default_next(it, 99)
2
>>> test_default_next(it, 99)
3
>>> test_default_next(it, 99)
99
>>> test_default_next(it, 99)
99
"""
return next(it, default)
def test_next_override(it):
"""
>>> it = iter([1,2,3])
>>> test_next_override(it)
1
>>> test_next_override(it)
1
>>> test_next_override(it)
1
>>> test_next_override(it)
1
"""
def next(it):
return 1
return next(it)
def pow3(a,b,c):
"""
>>> pow3(2,3,5)
3
>>> pow3(3,3,5)
2
"""
return pow(a,b,c)
def pow3_const():
"""
>>> pow3_const()
3
"""
return pow(2,3,5)
def pow2(a,b):
"""
>>> pow2(2,3)
8
>>> pow2(3,3)
27
"""
return pow(a,b)
def pow2_const():
"""
>>> pow2_const()
8
"""
return pow(2,3)
def pow_args(*args):
"""
>>> pow_args(2,3)
8
>>> pow_args(2,3,5)
3
"""
return pow(*args)
...@@ -37,8 +37,8 @@ def long_int_mix(): ...@@ -37,8 +37,8 @@ def long_int_mix():
""" """
>>> long_int_mix() == 1 + (2 * 3) // 2 >>> long_int_mix() == 1 + (2 * 3) // 2
True True
>>> if IS_PY3: type(long_int_mix()) is int >>> if IS_PY3: type(long_int_mix()) is int or type(long_int_mix())
... else: type(long_int_mix()) is long ... else: type(long_int_mix()) is long or type(long_int_mix())
True True
""" """
return 1L + (2 * 3L) // 2 return 1L + (2 * 3L) // 2
......
# cython: language_level=3 # cython: language_level=3
cimport cython
try: try:
sorted sorted
except NameError: except NameError:
...@@ -15,6 +17,25 @@ def print_function(*args): ...@@ -15,6 +17,25 @@ def print_function(*args):
""" """
print(*args) # this isn't valid Py2 syntax print(*args) # this isn't valid Py2 syntax
def exec3_function(cmd):
"""
>>> exec3_function('a = 1+1')['a']
2
"""
g = {}
l = {}
exec(cmd, g, l)
return l
def exec2_function(cmd):
"""
>>> exec2_function('a = 1+1')['a']
2
"""
g = {}
exec(cmd, g)
return g
ustring = "abcdefg" ustring = "abcdefg"
def unicode_literals(): def unicode_literals():
...@@ -36,6 +57,13 @@ def list_comp(): ...@@ -36,6 +57,13 @@ def list_comp():
assert x == 'abc' # don't leak in Py3 code assert x == 'abc' # don't leak in Py3 code
return result return result
def list_comp_unknown_type(l):
"""
>>> list_comp_unknown_type(range(5))
[0, 4, 8]
"""
return [x*2 for x in l if x % 2 == 0]
def set_comp(): def set_comp():
""" """
>>> sorted(set_comp()) >>> sorted(set_comp())
...@@ -55,3 +83,26 @@ def dict_comp(): ...@@ -55,3 +83,26 @@ def dict_comp():
result = {x:x*2 for x in range(5) if x % 2 == 0} result = {x:x*2 for x in range(5) if x % 2 == 0}
assert x == 'abc' # don't leak assert x == 'abc' # don't leak
return result return result
# in Python 3, d.keys/values/items() are the iteration methods
@cython.test_assert_path_exists(
"//WhileStatNode",
"//WhileStatNode/SimpleCallNode",
"//WhileStatNode/SimpleCallNode/NameNode")
@cython.test_fail_if_path_exists(
"//ForInStatNode")
def dict_iter(dict d):
"""
>>> d = {'a' : 1, 'b' : 2, 'c' : 3}
>>> keys, values, items = dict_iter(d)
>>> sorted(keys)
['a', 'b', 'c']
>>> sorted(values)
[1, 2, 3]
>>> sorted(items)
[('a', 1), ('b', 2), ('c', 3)]
"""
keys = [ key for key in d.keys() ]
values = [ value for value in d.values() ]
items = [ item for item in d.items() ]
return keys, values, items
__doc__ = u""" cimport cython
>>> str(f(5, 7))
'29509034655744'
"""
def f(a,b): def f(a,b):
"""
>>> str(f(5, 7))
'29509034655744'
"""
a += b a += b
a *= b a *= b
a **= b a **= b
...@@ -117,3 +117,130 @@ def test_side_effects(): ...@@ -117,3 +117,130 @@ def test_side_effects():
b[side_effect(3)] += 10 b[side_effect(3)] += 10
b[c_side_effect(4)] += 100 b[c_side_effect(4)] += 100
return a, [b[i] for i from 0 <= i < 5] return a, [b[i] for i from 0 <= i < 5]
@cython.cdivision(True)
def test_inplace_cdivision(int a, int b):
"""
>>> test_inplace_cdivision(13, 10)
3
>>> test_inplace_cdivision(13, -10)
3
>>> test_inplace_cdivision(-13, 10)
-3
>>> test_inplace_cdivision(-13, -10)
-3
"""
a %= b
return a
@cython.cdivision(False)
def test_inplace_pydivision(int a, int b):
"""
>>> test_inplace_pydivision(13, 10)
3
>>> test_inplace_pydivision(13, -10)
-7
>>> test_inplace_pydivision(-13, 10)
7
>>> test_inplace_pydivision(-13, -10)
-3
"""
a %= b
return a
def test_complex_inplace(double complex x, double complex y):
"""
>>> test_complex_inplace(1, 1)
(2+0j)
>>> test_complex_inplace(2, 3)
(15+0j)
>>> test_complex_inplace(2+3j, 4+5j)
(-16+62j)
"""
x += y
x *= y
return x
# The following is more subtle than one might expect.
cdef struct Inner:
int x
cdef struct Aa:
int value
Inner inner
cdef struct NestedA:
Aa a
cdef struct ArrayOfA:
Aa[10] a
def nested_struct_assignment():
"""
>>> nested_struct_assignment()
"""
cdef NestedA nested
nested.a.value = 2
nested.a.value += 3
assert nested.a.value == 5
nested.a.inner.x = 5
nested.a.inner.x += 10
assert nested.a.inner.x == 15
def nested_array_assignment():
"""
>>> nested_array_assignment()
c side effect 0
c side effect 1
"""
cdef ArrayOfA array
array.a[0].value = 2
array.a[c_side_effect(0)].value += 3
assert array.a[0].value == 5
array.a[1].inner.x = 5
array.a[c_side_effect(1)].inner.x += 10
assert array.a[1].inner.x == 15
cdef class VerboseDict(object):
cdef name
cdef dict dict
def __init__(self, name, **kwds):
self.name = name
self.dict = kwds
def __getitem__(self, key):
print self.name, "__getitem__", key
return self.dict[key]
def __setitem__(self, key, value):
print self.name, "__setitem__", key, value
self.dict[key] = value
def __repr__(self):
return repr(self.name)
def deref_and_increment(o, key):
"""
>>> deref_and_increment({'a': 1}, 'a')
side effect a
>>> v = VerboseDict('v', a=10)
>>> deref_and_increment(v, 'a')
side effect a
v __getitem__ a
v __setitem__ a 11
"""
o[side_effect(key)] += 1
def double_deref_and_increment(o, key1, key2):
"""
>>> v = VerboseDict('v', a=10)
>>> w = VerboseDict('w', vkey=v)
>>> double_deref_and_increment(w, 'vkey', 'a')
side effect vkey
w __getitem__ vkey
side effect a
v __getitem__ a
v __setitem__ a 11
"""
o[side_effect(key1)][side_effect(key2)] += 1
...@@ -3,11 +3,18 @@ __doc__ = u""" ...@@ -3,11 +3,18 @@ __doc__ = u"""
(1, 1L, -1L, 18446744073709551615L) (1, 1L, -1L, 18446744073709551615L)
>>> py_longs() >>> py_longs()
(1, 1L, 100000000000000000000000000000000L, -100000000000000000000000000000000L) (1, 1L, 100000000000000000000000000000000L, -100000000000000000000000000000000L)
>>> py_huge_calculated_long()
1606938044258990275541962092341162602522202993782792835301376L
>>> py_huge_computation_small_result_neg()
(-2535301200456458802993406410752L, -2535301200456458802993406410752L)
""" """
import sys cimport cython
from cython cimport typeof from cython cimport typeof
import sys
if sys.version_info[0] >= 3: if sys.version_info[0] >= 3:
__doc__ = __doc__.replace(u'L', u'') __doc__ = __doc__.replace(u'L', u'')
...@@ -27,6 +34,25 @@ def c_longs(): ...@@ -27,6 +34,25 @@ def c_longs():
def py_longs(): def py_longs():
return 1, 1L, 100000000000000000000000000000000, -100000000000000000000000000000000 return 1, 1L, 100000000000000000000000000000000, -100000000000000000000000000000000
@cython.test_fail_if_path_exists("//NumBinopNode", "//IntBinopNode")
@cython.test_assert_path_exists("//ReturnStatNode/IntNode")
def py_huge_calculated_long():
return 1 << 200
@cython.test_fail_if_path_exists("//NumBinopNode", "//IntBinopNode")
@cython.test_assert_path_exists("//ReturnStatNode/IntNode")
def py_huge_computation_small_result():
"""
>>> py_huge_computation_small_result()
2
"""
return (1 << 200) >> 199
@cython.test_fail_if_path_exists("//NumBinopNode", "//IntBinopNode")
#@cython.test_assert_path_exists("//ReturnStatNode/IntNode")
def py_huge_computation_small_result_neg():
return -(2 ** 101), (-2) ** 101
def large_literal(): def large_literal():
""" """
>>> type(large_literal()) is int >>> type(large_literal()) is int
...@@ -59,50 +85,67 @@ def c_long_types(): ...@@ -59,50 +85,67 @@ def c_long_types():
def c_oct(): def c_oct():
""" """
>>> c_oct() >>> c_oct()
(1, 17, 63) (1, -17, 63)
""" """
cdef int a = 0o01 cdef int a = 0o01
cdef int b = 0o21 cdef int b = -0o21
cdef int c = 0o77 cdef int c = 0o77
return a,b,c return a,b,c
def c_oct_py2_legacy():
"""
>>> c_oct_py2_legacy()
(1, -17, 63)
"""
cdef int a = 001
cdef int b = -021
cdef int c = 077
return a,b,c
def py_oct(): def py_oct():
""" """
>>> py_oct() >>> py_oct()
(1, 17, 63) (1, -17, 63)
"""
return 0o01, -0o21, 0o77
def py_oct_py2_legacy():
"""
>>> py_oct_py2_legacy()
(1, -17, 63)
""" """
return 0o01, 0o21, 0o77 return 001, -021, 077
def c_hex(): def c_hex():
""" """
>>> c_hex() >>> c_hex()
(1, 33, 255) (1, -33, 255)
""" """
cdef int a = 0x01 cdef int a = 0x01
cdef int b = 0x21 cdef int b = -0x21
cdef int c = 0xFF cdef int c = 0xFF
return a,b,c return a,b,c
def py_hex(): def py_hex():
""" """
>>> py_hex() >>> py_hex()
(1, 33, 255) (1, -33, 255)
""" """
return 0x01, 0x21, 0xFF return 0x01, -0x21, 0xFF
def c_bin(): def c_bin():
""" """
>>> c_bin() >>> c_bin()
(1, 2, 15) (1, -2, 15)
""" """
cdef int a = 0b01 cdef int a = 0b01
cdef int b = 0b10 cdef int b = -0b10
cdef int c = 0b1111 cdef int c = 0b1111
return a,b,c return a,b,c
def py_bin(): def py_bin():
""" """
>>> py_bin() >>> py_bin()
(1, 2, 15) (1, -2, 15)
""" """
return 0b01, 0b10, 0b1111 return 0b01, -0b10, 0b1111
cimport cython
def f(obj1, obj2, obj3, obj4, obj5): def f(obj1, obj2, obj3, obj4, obj5):
""" """
>>> f(1, 2, 3, 4, 5) >>> f(1, 2, 3, 4, 5)
...@@ -54,6 +57,7 @@ def test_list_sort_reversed(): ...@@ -54,6 +57,7 @@ def test_list_sort_reversed():
l1.sort(reversed=True) l1.sort(reversed=True)
return l1 return l1
@cython.test_assert_path_exists("//SimpleCallNode//NoneCheckNode")
def test_list_reverse(): def test_list_reverse():
""" """
>>> test_list_reverse() >>> test_list_reverse()
...@@ -64,6 +68,17 @@ def test_list_reverse(): ...@@ -64,6 +68,17 @@ def test_list_reverse():
l1.reverse() l1.reverse()
return l1 return l1
@cython.test_assert_path_exists("//SimpleCallNode//NoneCheckNode")
def test_list_append():
"""
>>> test_list_append()
[1, 2, 3, 4]
"""
cdef list l1 = [1,2]
l1.append(3)
l1.append(4)
return l1
def test_list_pop(): def test_list_pop():
""" """
>>> test_list_pop() >>> test_list_pop()
......
"""
>>> obj = Foo()
>>> obj.metaclass_was_here
True
"""
class Base(type): class Base(type):
def __new__(cls, name, bases, attrs): def __new__(cls, name, bases, attrs):
attrs['metaclass_was_here'] = True attrs['metaclass_was_here'] = True
return type.__new__(cls, name, bases, attrs) return type.__new__(cls, name, bases, attrs)
class Foo(object): class Foo(object):
"""
>>> obj = Foo()
>>> obj.metaclass_was_here
True
"""
__metaclass__ = Base __metaclass__ = Base
class Py3Base(type):
def __new__(cls, name, bases, attrs, foo=None):
attrs['foo'] = foo
return type.__new__(cls, name, bases, attrs)
def __init__(self, cls, attrs, obj, foo=None):
pass
@staticmethod
def __prepare__(name, bases, **kwargs):
return {'bar': 666, 'dirty': True}
class Py3Foo(object, metaclass=Py3Base, foo=123):
"""
>>> obj = Py3Foo()
>>> obj.foo
123
>>> obj.bar
666
>>> obj.dirty
False
"""
dirty = False
cdef class Base:
cpdef str noargs(self)
cpdef str int_arg(self, int i)
cpdef str _class(tp)
cdef class Derived(Base):
cpdef str noargs(self)
cpdef str int_arg(self, int i)
cpdef str _class(tp)
cdef class DerivedDerived(Derived):
cpdef str noargs(self)
cpdef str int_arg(self, int i)
cpdef str _class(tp)
cdef class Derived2(Base):
cpdef str noargs(self)
cpdef str int_arg(self, int i)
cpdef str _class(tp)
class Base(object):
'''
>>> base = Base()
>>> print(base.noargs())
Base
>>> print(base.int_arg(1))
Base
>>> print(base._class())
Base
'''
def noargs(self):
return "Base"
def int_arg(self, i):
return "Base"
@classmethod
def _class(tp):
return "Base"
class Derived(Base):
'''
>>> derived = Derived()
>>> print(derived.noargs())
Derived
>>> print(derived.int_arg(1))
Derived
>>> print(derived._class())
Derived
'''
def noargs(self):
return "Derived"
def int_arg(self, i):
return "Derived"
@classmethod
def _class(tp):
return "Derived"
class DerivedDerived(Derived):
'''
>>> derived = DerivedDerived()
>>> print(derived.noargs())
DerivedDerived
>>> print(derived.int_arg(1))
DerivedDerived
>>> print(derived._class())
DerivedDerived
'''
def noargs(self):
return "DerivedDerived"
def int_arg(self, i):
return "DerivedDerived"
@classmethod
def _class(tp):
return "DerivedDerived"
class Derived2(Base):
'''
>>> derived = Derived2()
>>> print(derived.noargs())
Derived2
>>> print(derived.int_arg(1))
Derived2
>>> print(derived._class())
Derived2
'''
def noargs(self):
return "Derived2"
def int_arg(self, i):
return "Derived2"
@classmethod
def _class(tp):
return "Derived2"
...@@ -56,6 +56,15 @@ def test_set_clear(): ...@@ -56,6 +56,15 @@ def test_set_clear():
s1.clear() s1.clear()
return s1 return s1
def test_set_clear_None():
"""
>>> test_set_clear_None()
Traceback (most recent call last):
AttributeError: 'NoneType' object has no attribute 'clear'
"""
cdef set s1 = None
s1.clear()
def test_set_list_comp(): def test_set_list_comp():
""" """
>>> type(test_set_list_comp()) is _set >>> type(test_set_list_comp()) is _set
......
...@@ -350,8 +350,7 @@ cdef object some_float_value(): ...@@ -350,8 +350,7 @@ cdef object some_float_value():
@cython.test_fail_if_path_exists('//NameNode[@type.is_pyobject = True]') @cython.test_fail_if_path_exists('//NameNode[@type.is_pyobject = True]')
@cython.test_assert_path_exists('//InPlaceAssignmentNode/NameNode', @cython.test_assert_path_exists('//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject]',
'//NameNode[@type.is_pyobject = False]') '//NameNode[@type.is_pyobject = False]')
@infer_types(None) @infer_types(None)
def double_loop(): def double_loop():
......
...@@ -243,7 +243,6 @@ def index_add(unicode ustring, Py_ssize_t i, Py_ssize_t j): ...@@ -243,7 +243,6 @@ def index_add(unicode ustring, Py_ssize_t i, Py_ssize_t j):
@cython.test_assert_path_exists("//CoerceToPyTypeNode", @cython.test_assert_path_exists("//CoerceToPyTypeNode",
"//IndexNode", "//IndexNode",
"//InPlaceAssignmentNode",
"//CoerceToPyTypeNode//IndexNode") "//CoerceToPyTypeNode//IndexNode")
@cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode") @cython.test_fail_if_path_exists("//IndexNode//CoerceToPyTypeNode")
def index_concat_loop(unicode ustring): def index_concat_loop(unicode ustring):
......
...@@ -180,9 +180,12 @@ pipe_sep = u'|' ...@@ -180,9 +180,12 @@ pipe_sep = u'|'
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
"//CoerceToPyTypeNode", "//CoerceFromPyTypeNode", "//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
"//CastNode", "//TypecastNode") "//CastNode", "//TypecastNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//PythonCapiCallNode") "//SimpleCallNode",
"//SimpleCallNode//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def join(unicode sep, l): def join(unicode sep, l):
""" """
>>> l = text.split() >>> l = text.split()
...@@ -197,9 +200,11 @@ def join(unicode sep, l): ...@@ -197,9 +200,11 @@ def join(unicode sep, l):
@cython.test_fail_if_path_exists( @cython.test_fail_if_path_exists(
"//CoerceToPyTypeNode", "//CoerceFromPyTypeNode", "//CoerceToPyTypeNode", "//CoerceFromPyTypeNode",
"//CastNode", "//TypecastNode", "//NoneCheckNode") "//CastNode", "//TypecastNode", "//NoneCheckNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = true]")
@cython.test_assert_path_exists( @cython.test_assert_path_exists(
"//PythonCapiCallNode") "//SimpleCallNode",
"//SimpleCallNode//AttributeNode[@is_py_attr = false]")
def join_sep(l): def join_sep(l):
""" """
>>> l = text.split() >>> l = text.split()
...@@ -212,6 +217,22 @@ def join_sep(l): ...@@ -212,6 +217,22 @@ def join_sep(l):
""" """
return u'|'.join(l) return u'|'.join(l)
@cython.test_assert_path_exists(
"//SimpleCallNode",
"//SimpleCallNode//NameNode")
def join_unbound(unicode sep, l):
"""
>>> l = text.split()
>>> len(l)
8
>>> print( pipe_sep.join(l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf
>>> print( join_unbound(pipe_sep, l) )
ab|jd|sdflk|as|sa|sadas|asdas|fsdf
"""
join = unicode.join
return join(sep, l)
# unicode.startswith(s, prefix, [start, [end]]) # unicode.startswith(s, prefix, [start, [end]])
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment