import Nodes
import ExprNodes
import PyrexTypes
import Visitor
import Builtin
import UtilNodes
import TypeSlots
import Symtab
import Options

from Code import UtilityCode
from StringEncoding import EncodedString, BytesLiteral
from Errors import error
from ParseTreeTransforms import SkipDeclarations

import codecs

try:
    reduce
except NameError:
    from functools import reduce


def unwrap_node(node):
    """Strip any chain of ResultRefNode wrappers and return the wrapped node."""
    while isinstance(node, UtilNodes.ResultRefNode):
        node = node.expression
    return node


def is_common_value(a, b):
    """Return True if ``a`` and ``b`` refer to the same simple value.

    After unwrapping ResultRefNodes, two nodes are considered common when
    they are both name nodes with the same name, or both non-Python
    attribute nodes with the same attribute on a common object.  Any other
    combination (including ``None`` arguments) yields False.
    """
    a = unwrap_node(a)
    b = unwrap_node(b)
    if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode):
        return a.name == b.name
    if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode):
        return not a.is_py_attr and is_common_value(a.obj, b.obj) and a.attribute == b.attribute
    return False


class IterationTransform(Visitor.VisitorTransform):
    """Transform some common for-in loop patterns into efficient C loops:

    - for-in-dict loop becomes a while loop calling PyDict_Next()
    - for-in-enumerate is replaced by an external counter variable
    - for-in-range loop becomes a plain C for loop
    """
    # C signature of PyDict_Next(dict, &pos, &key, &value) -> bint
    PyDict_Next_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_bint_type, [
            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
            PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
            ])

    PyDict_Next_name = EncodedString("PyDict_Next")

    PyDict_Next_entry = Symtab.Entry(
        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)

    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_ModuleNode(self, node):
        """Record the module scope as the current scope, then recurse."""
        self.current_scope = node.scope
        self.visitchildren(node)
        return node

    def visit_DefNode(self, node):
        """Recurse into the function with its own scope as current scope."""
        oldscope = self.current_scope
        self.current_scope = node.entry.scope
        self.visitchildren(node)
        self.current_scope = oldscope
        return node

    def visit_ForInStatNode(self, node):
        """Optimise nested loops first, then try to optimise this loop."""
        self.visitchildren(node)
        return self._optimise_for_loop(node)

    def _optimise_for_loop(self, node):
        """Dispatch a for-in loop to the matching transformation.

        Returns the replacement node, or ``node`` itself when none of the
        known patterns (dict, enumerate, range/xrange) applies.
        """
        iterator = node.iterator.sequence
        if iterator.type is Builtin.dict_type:
            # like iterating over dict.keys()
            return self._transform_dict_iteration(
                node, dict_obj=iterator, keys=True, values=False)
        if not isinstance(iterator, ExprNodes.SimpleCallNode):
            return node

        function = iterator.function
        # dict iteration?
        if isinstance(function, ExprNodes.AttributeNode) and \
                function.obj.type == Builtin.dict_type:
            dict_obj = function.obj
            method = function.attribute

            keys = values = False
            if method == 'iterkeys':
                keys = True
            elif method == 'itervalues':
                values = True
            elif method == 'iteritems':
                keys = values = True
            else:
                return node
            return self._transform_dict_iteration(
                node, dict_obj, keys, values)

        # enumerate() ?
        if iterator.self is None and \
                isinstance(function, ExprNodes.NameNode) and \
                function.entry.is_builtin and \
                function.name == 'enumerate':
            return self._transform_enumerate_iteration(node, iterator)

        # range() iteration?
        if Options.convert_range and node.target.type.is_int:
            if iterator.self is None and \
                    isinstance(function, ExprNodes.NameNode) and \
                    function.entry.is_builtin and \
                    function.name in ('range', 'xrange'):
                return self._transform_range_iteration(node, iterator)

        return node

    def _transform_enumerate_iteration(self, node, enumerate_function):
        """Replace ``for i, x in enumerate(seq)`` by a loop over ``seq``
        that maintains the counter in a temporary variable.

        Reports an error for a wrong argument count and leaves the loop
        untouched whenever the target is not a two-element sequence whose
        first item is a plain name of Python object or C integer type.
        """
        args = enumerate_function.arg_tuple.args
        if len(args) == 0:
            error(enumerate_function.pos,
                  "enumerate() requires an iterable argument")
            return node
        elif len(args) > 1:
            error(enumerate_function.pos,
                  "enumerate() takes at most 1 argument")
            return node

        if not node.target.is_sequence_constructor:
            # leave this untouched for now
            return node
        targets = node.target.args
        if len(targets) != 2:
            # leave this untouched for now
            return node
        if not isinstance(targets[0], ExprNodes.NameNode):
            # leave this untouched for now
            return node

        enumerate_target, iterable_target = targets
        counter_type = enumerate_target.type

        if not counter_type.is_pyobject and not counter_type.is_int:
            # nothing we can do here, I guess
            return node

        temp = UtilNodes.LetRefNode(ExprNodes.IntNode(enumerate_function.pos, value='0',
                                                      type=counter_type,
                                                      constant_result=0))
        inc_expression = ExprNodes.AddNode(
            enumerate_function.pos,
            operand1 = temp,
            operand2 = ExprNodes.IntNode(node.pos, value='1',
                                         type=counter_type,
                                         constant_result=1),
            operator = '+',
            type = counter_type,
            is_temp = counter_type.is_pyobject
            )

        # first assign the current counter to the user's variable, then
        # increment the hidden counter for the next iteration
        loop_body = [
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = enumerate_target,
                rhs = temp),
            Nodes.SingleAssignmentNode(
                pos = enumerate_target.pos,
                lhs = temp,
                rhs = inc_expression)
            ]

        if isinstance(node.body, Nodes.StatListNode):
            node.body.stats = loop_body + node.body.stats
        else:
            loop_body.append(node.body)
            node.body = Nodes.StatListNode(
                node.body.pos,
                stats = loop_body)

        node.target = iterable_target
        node.iterator.sequence = enumerate_function.arg_tuple.args[0]

        # recurse into loop to check for further optimisations
        return UtilNodes.LetNode(temp, self._optimise_for_loop(node))

    def _transform_range_iteration(self, node, range_function):
        """Replace ``for i in range(...)`` by a ForFromStatNode.

        The loop is left untouched when the step is not a compile time
        integer constant (its direction is unknown) or when it is zero.
        """
        args = range_function.arg_tuple.args
        if len(args) < 3:
            step_pos = range_function.pos
            step_value = 1
            step = ExprNodes.IntNode(step_pos, value='1',
                                     constant_result=1)
        else:
            step = args[2]
            step_pos = step.pos
            if not isinstance(step.constant_result, (int, long)):
                # cannot determine step direction
                return node
            step_value = step.constant_result
            if step_value == 0:
                # will lead to an error elsewhere
                return node
            if not isinstance(step, ExprNodes.IntNode):
                step = ExprNodes.IntNode(step_pos, value=str(step_value),
                                         constant_result=step_value)

        if step_value < 0:
            # the direction is expressed by the relations, so the step
            # node carries the magnitude; keep 'value' a string like
            # every other IntNode created in this module
            step.value = str(-step_value)
            relation1 = '>='
            relation2 = '>'
        else:
            relation1 = '<='
            relation2 = '<'

        if len(args) == 1:
            bound1 = ExprNodes.IntNode(range_function.pos, value='0',
                                       constant_result=0)
            bound2 = args[0].coerce_to_integer(self.current_scope)
        else:
            bound1 = args[0].coerce_to_integer(self.current_scope)
            bound2 = args[1].coerce_to_integer(self.current_scope)
        step = step.coerce_to_integer(self.current_scope)

        for_node = Nodes.ForFromStatNode(
            node.pos,
            target=node.target,
            bound1=bound1, relation1=relation1,
            relation2=relation2, bound2=bound2,
            step=step, body=node.body,
            else_clause=node.else_clause,
            from_range=True)
        return for_node

    def _transform_dict_iteration(self, node, dict_obj, keys, values):
        """Replace iteration over a dict by a while loop on PyDict_Next().

        ``keys`` / ``values`` select which of the two items the loop
        target receives.  Returns a TempsBlockNode wrapping the new loop,
        or ``node`` unchanged if a sequence target for key/value
        iteration does not have exactly two items.
        """
        py_object_ptr = PyrexTypes.c_void_ptr_type

        temps = []
        temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
        temps.append(temp)
        dict_temp = temp.ref(dict_obj.pos)
        temp = UtilNodes.TempHandle(PyrexTypes.c_py_ssize_t_type)
        temps.append(temp)
        pos_temp = temp.ref(node.pos)
        pos_temp_addr = ExprNodes.AmpersandNode(
            node.pos, operand=pos_temp,
            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
        if keys:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            key_temp = temp.ref(node.target.pos)
            key_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=key_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            # NULL is passed to PyDict_Next() for an item we do not need
            key_temp_addr = key_temp = ExprNodes.NullNode(
                pos=node.target.pos)
        if values:
            temp = UtilNodes.TempHandle(py_object_ptr)
            temps.append(temp)
            value_temp = temp.ref(node.target.pos)
            value_temp_addr = ExprNodes.AmpersandNode(
                node.target.pos, operand=value_temp,
                type=PyrexTypes.c_ptr_type(py_object_ptr))
        else:
            value_temp_addr = value_temp = ExprNodes.NullNode(
                pos=node.target.pos)

        key_target = value_target = node.target
        tuple_target = None
        if keys and values:
            if node.target.is_sequence_constructor:
                if len(node.target.args) == 2:
                    key_target, value_target = node.target.args
                else:
                    # unusual case that may or may not lead to an error
                    return node
            else:
                tuple_target = node.target

        def coerce_object_to(obj_node, dest_type):
            """Return ``(result_node, coercion_stat_or_None)`` for
            converting ``obj_node`` to ``dest_type``."""
            class FakeEnv(object):
                nogil = False
            if dest_type.is_pyobject:
                if dest_type.is_extension_type or dest_type.is_builtin_type:
                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
                result = ExprNodes.TypecastNode(
                    obj_node.pos,
                    operand = obj_node,
                    type = dest_type)
                return (result, None)
            else:
                temp = UtilNodes.TempHandle(dest_type)
                temps.append(temp)
                temp_result = temp.ref(obj_node.pos)
                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
                    # writes the coerced value into the temp above
                    def result(self):
                        return temp_result.result()
                    def generate_execution_code(self, code):
                        self.generate_result_code(code)
                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))

        if isinstance(node.body, Nodes.StatListNode):
            body = node.body
        else:
            body = Nodes.StatListNode(pos = node.body.pos,
                                      stats = [node.body])

        if tuple_target:
            tuple_result = ExprNodes.TupleNode(
                pos = tuple_target.pos,
                args = [key_temp, value_temp],
                is_temp = 1,
                type = Builtin.tuple_type,
                )
            body.stats.insert(
                0, Nodes.SingleAssignmentNode(
                    pos = tuple_target.pos,
                    lhs = tuple_target,
                    rhs = tuple_result))
        else:
            # execute all coercions before the assignments
            coercion_stats = []
            assign_stats = []
            if keys:
                temp_result, coercion = coerce_object_to(
                    key_temp, key_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = key_temp.pos,
                        lhs = key_target,
                        rhs = temp_result))
            if values:
                temp_result, coercion = coerce_object_to(
                    value_temp, value_target.type)
                if coercion:
                    coercion_stats.append(coercion)
                assign_stats.append(
                    Nodes.SingleAssignmentNode(
                        pos = value_temp.pos,
                        lhs = value_target,
                        rhs = temp_result))
            body.stats[0:0] = coercion_stats + assign_stats

        result_code = [
            Nodes.SingleAssignmentNode(
                pos = dict_obj.pos,
                lhs = dict_temp,
                rhs = dict_obj),
            Nodes.SingleAssignmentNode(
                pos = node.pos,
                lhs = pos_temp,
                rhs = ExprNodes.IntNode(node.pos, value='0',
                                        constant_result=0)),
            Nodes.WhileStatNode(
                pos = node.pos,
                condition = ExprNodes.SimpleCallNode(
                    pos = dict_obj.pos,
                    type = PyrexTypes.c_bint_type,
                    function = ExprNodes.NameNode(
                        pos = dict_obj.pos,
                        name = self.PyDict_Next_name,
                        type = self.PyDict_Next_func_type,
                        entry = self.PyDict_Next_entry),
                    args = [dict_temp, pos_temp_addr,
                            key_temp_addr, value_temp_addr]
                    ),
                body = body,
                else_clause = node.else_clause
                )
            ]

        return UtilNodes.TempsBlockNode(
            node.pos, temps=temps,
            body=Nodes.StatListNode(
                node.pos,
                stats = result_code
                ))


class SwitchTransform(Visitor.VisitorTransform):
    """
    This transformation tries to turn long if statements into C switch statements.
    The requirement is that every clause be an (or of) var == value, where the var
    is common among all clauses and both var and value are ints.
    """
    def extract_conditions(self, cond):
        """Split a condition into ``(variable, [compared values])``.

        Returns ``(None, None)`` if ``cond`` is not a non-Python ``==``
        comparison of a simple variable with a constant, or an ``or`` of
        such comparisons on a common variable.
        """
        # look through wrapper nodes around the actual comparison
        while True:
            if isinstance(cond, ExprNodes.CoerceToTempNode):
                cond = cond.arg
            elif isinstance(cond, UtilNodes.EvalWithTempExprNode):
                # this is what we get from the FlattenInListTransform
                cond = cond.subexpression
            elif isinstance(cond, ExprNodes.TypecastNode):
                cond = cond.operand
            else:
                break

        if (isinstance(cond, ExprNodes.PrimaryCmpNode)
                and cond.cascade is None
                and cond.operator == '=='
                and not cond.is_python_comparison()):
            # is_common_value(x, x) is only true for name nodes and
            # non-Python attribute nodes, so comparing an operand with
            # itself tests whether it is a simple variable reference
            if is_common_value(cond.operand1, cond.operand1):
                if isinstance(cond.operand2, ExprNodes.ConstNode):
                    return cond.operand1, [cond.operand2]
                elif hasattr(cond.operand2, 'entry') and cond.operand2.entry and cond.operand2.entry.is_const:
                    return cond.operand1, [cond.operand2]
            if is_common_value(cond.operand2, cond.operand2):
                if isinstance(cond.operand1, ExprNodes.ConstNode):
                    return cond.operand2, [cond.operand1]
                elif hasattr(cond.operand1, 'entry') and cond.operand1.entry and cond.operand1.entry.is_const:
                    return cond.operand2, [cond.operand1]
        elif (isinstance(cond, ExprNodes.BoolBinopNode)
                and cond.operator == 'or'):
            t1, c1 = self.extract_conditions(cond.operand1)
            t2, c2 = self.extract_conditions(cond.operand2)
            if is_common_value(t1, t2):
                return t1, c1+c2
        return None, None

    def visit_IfStatNode(self, node):
        """Replace the if statement by a SwitchStatNode when every clause
        tests the same int variable against int values and there are at
        least two cases in total; otherwise return ``node`` unchanged."""
        self.visitchildren(node)
        common_var = None
        case_count = 0
        cases = []
        for if_clause in node.if_clauses:
            var, conditions = self.extract_conditions(if_clause.condition)
            if var is None:
                return node
            elif common_var is not None and not is_common_value(var, common_var):
                return node
            elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
                return node
            else:
                common_var = var
                case_count += len(conditions)
                cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
                                                  conditions = conditions,
                                                  body = if_clause.body))
        if case_count < 2:
            return node

        common_var = unwrap_node(common_var)
        return Nodes.SwitchStatNode(pos = node.pos,
                                    test = common_var,
                                    cases = cases,
                                    else_clause = node.else_clause)

    visit_Node = Visitor.VisitorTransform.recurse_to_children


class FlattenInListTransform(Visitor.VisitorTransform, SkipDeclarations):
    """
    This transformation flattens "x in [val1, ..., valn]" into a sequential list
    of comparisons.
    """

    def visit_PrimaryCmpNode(self, node):
        """Rewrite ``x in (a, b)`` as ``x == a or x == b`` and
        ``x not in (a, b)`` as ``x != a and x != b``.

        Only non-cascaded comparisons against a literal tuple or list are
        rewritten.  An empty literal folds to a constant BoolNode.
        """
        self.visitchildren(node)
        if node.cascade is not None:
            return node
        elif node.operator == 'in':
            conjunction = 'or'
            eq_or_neq = '=='
        elif node.operator == 'not_in':
            conjunction = 'and'
            eq_or_neq = '!='
        else:
            return node

        if not isinstance(node.operand2, (ExprNodes.TupleNode, ExprNodes.ListNode)):
            return node

        args = node.operand2.args
        if len(args) == 0:
            return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in')

        # the left operand is shared by all generated comparisons
        lhs = UtilNodes.ResultRefNode(node.operand1)

        conds = []
        for arg in args:
            cond = ExprNodes.PrimaryCmpNode(
                pos = node.pos,
                operand1 = lhs,
                operator = eq_or_neq,
                operand2 = arg,
                cascade = None)
            conds.append(ExprNodes.TypecastNode(
                pos = node.pos,
                operand = cond,
                type = PyrexTypes.c_bint_type))

        def concat(left, right):
            return ExprNodes.BoolBinopNode(
                pos = node.pos,
                operator = conjunction,
                operand1 = left,
                operand2 = right)

        condition = reduce(concat, conds)
        return UtilNodes.EvalWithTempExprNode(lhs, condition)

    visit_Node = Visitor.VisitorTransform.recurse_to_children


class OptimizeBuiltinCalls(Visitor.VisitorTransform):
    """Optimize some common methods calls and instantiation patterns
    for builtin types.
    """
    # only intercept on call nodes
    visit_Node = Visitor.VisitorTransform.recurse_to_children

    def visit_GeneralCallNode(self, node):
        """Dispatch calls with keyword arguments to a handler method."""
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.positional_args
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(
            node, function, arg_tuple, node.keyword_args)

    def visit_SimpleCallNode(self, node):
        """Dispatch positional-only calls to a handler method."""
        self.visitchildren(node)
        function = node.function
        if not function.type.is_pyobject:
            return node
        arg_tuple = node.arg_tuple
        if not isinstance(arg_tuple, ExprNodes.TupleNode):
            return node
        return self._dispatch_to_handler(
            node, node.function, arg_tuple)

    def visit_PyTypeTestNode(self, node):
        """Flatten redundant type checks after tree changes.
        """
        old_arg = node.arg
        self.visitchildren(node)
        if old_arg is node.arg or node.arg.type != node.type:
            return node
        return node.arg

    def _find_handler(self, match_name, has_kwargs):
        """Look up ``_handle_{general|simple}_<match_name>`` on this
        object, falling back to ``_handle_any_<match_name>``; returns
        None if neither exists."""
        call_type = has_kwargs and 'general' or 'simple'
        handler = getattr(self, '_handle_%s_%s' % (call_type, match_name), None)
        if handler is None:
            handler = getattr(self, '_handle_any_%s' % match_name, None)
        return handler

    def _dispatch_to_handler(self, node, function, arg_tuple, kwargs=None):
        """Call the handler for a function or method call, if there is one.

        Plain names map to ``function_<name>`` handlers, attribute calls
        map to ``method_<type>_<attribute>`` handlers.  Returns the
        handler's result, or ``node`` when no handler is found.
        """
        if function.is_name:
            function_handler = self._find_handler(
                "function_%s" % function.name, kwargs)
            if function_handler is None:
                return node
            if kwargs:
                return function_handler(node, arg_tuple, kwargs)
            else:
                return function_handler(node, arg_tuple)
        elif isinstance(function, ExprNodes.AttributeNode):
            arg_list = arg_tuple.args
            self_arg = function.obj
            obj_type = self_arg.type
            is_unbound_method = False
            if obj_type.is_builtin_type:
                if obj_type is Builtin.type_type and arg_list and \
                        arg_list[0].type.is_pyobject:
                    # calling an unbound method like 'list.append(L,x)'
                    # (ignoring 'type.mro()' here ...)
                    type_name = function.obj.name
                    self_arg = None
                    is_unbound_method = True
                else:
                    type_name = obj_type.name
            else:
                type_name = "object" # safety measure
            method_handler = self._find_handler(
                "method_%s_%s" % (type_name, function.attribute), kwargs)
            if method_handler is None:
                return node
            if self_arg is not None:
                # bound call: pass the object as first argument
                arg_list = [self_arg] + list(arg_list)
            if kwargs:
                return method_handler(node, arg_list, kwargs, is_unbound_method)
            else:
                return method_handler(node, arg_list, is_unbound_method)
        else:
            return node

    ### builtin types

    def _handle_general_function_dict(self, node, pos_args, kwargs):
        """Replace dict(a=b,c=d,...) by the underlying keyword dict
        construction which is done anyway.
        """
        if len(pos_args.args) > 0:
            return node
        if not isinstance(kwargs, ExprNodes.DictNode):
            return node
        if node.starstar_arg:
            # we could optimize this by updating the kw dict instead
            return node
        return kwargs

    PyDict_Copy_func_type = PyrexTypes.CFuncType(
        Builtin.dict_type, [
            PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
            ])

    def _handle_simple_function_dict(self, node, pos_args):
        """Replace dict(some_dict) by PyDict_Copy(some_dict) and
        dict([ (a,b) for ... ]) by a literal { a:b for ... }.
        """
        if len(pos_args.args) != 1:
            return node
        arg = pos_args.args[0]
        if arg.type is Builtin.dict_type:
            # guard the C-API call against a None argument
            dict_arg = ExprNodes.NoneCheckNode(
                arg, "PyExc_TypeError", "'NoneType' is not iterable")
            return ExprNodes.PythonCapiCallNode(
                node.pos, "PyDict_Copy", self.PyDict_Copy_func_type,
                args = [dict_arg],
                is_temp = node.is_temp
                )
        elif isinstance(arg, ExprNodes.ComprehensionNode) and \
                arg.type is Builtin.list_type:
            append_node = arg.append
            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
                    len(append_node.expr.args) == 2:
                key_node, value_node = append_node.expr.args
                target_node = ExprNodes.DictNode(
                    pos=arg.target.pos, key_value_pairs=[], is_temp=1)
                new_append_node = ExprNodes.DictComprehensionAppendNode(
                    append_node.pos, target=target_node,
                    key_expr=key_node, value_expr=value_node,
                    is_temp=1)
                arg.target = target_node
                arg.type = target_node.type
                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
                return replace_in(arg)
        return node

    def _handle_simple_function_set(self, node, pos_args):
        """Replace set([a,b,...]) by a literal set {a,b,...} and
        set([ x for ... ]) by a literal { x for ... }.
        """
        arg_count = len(pos_args.args)
        if arg_count == 0:
            return ExprNodes.SetNode(node.pos, args=[],
                                     type=Builtin.set_type, is_temp=1)
        if arg_count > 1:
            return node
        iterable = pos_args.args[0]
        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
            return ExprNodes.SetNode(node.pos, args=iterable.args,
                                     type=Builtin.set_type, is_temp=1)
        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
                iterable.type is Builtin.list_type:
            iterable.target = ExprNodes.SetNode(
                node.pos, args=[], type=Builtin.set_type, is_temp=1)
            iterable.type = Builtin.set_type
            iterable.pos = node.pos
            return iterable
        else:
            return node

    PyList_AsTuple_func_type = PyrexTypes.CFuncType(
        Builtin.tuple_type, [
            PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
            ])

    def _handle_simple_function_tuple(self, node, pos_args):
        """Replace tuple([...]) by a call to PyList_AsTuple.
        """
        if len(pos_args.args) != 1:
            return node
        list_arg = pos_args.args[0]
        if list_arg.type is not Builtin.list_type:
            return node
        if not isinstance(list_arg, (ExprNodes.ComprehensionNode,
                                     ExprNodes.ListNode)):
            # anything but a list literal or comprehension gets a
            # runtime None check before the C-API call
            pos_args.args[0] = ExprNodes.NoneCheckNode(
                list_arg, "PyExc_TypeError",
                "'NoneType' object is not iterable")

        return ExprNodes.PythonCapiCallNode(
            node.pos, "PyList_AsTuple", self.PyList_AsTuple_func_type,
            args = pos_args.args,
            is_temp = node.is_temp
            )

    ### builtin functions

    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            ])

    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_function_getattr(self, node, pos_args):
        """Replace 2-arg getattr() by PyObject_GetAttr() and 3-arg
        getattr() by the __Pyx_GetAttr3() utility function; any other
        argument count is reported as an error."""
        args = pos_args.args
        if len(args) == 2:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
                args = args,
                is_temp = node.is_temp
                )
        elif len(args) == 3:
            node = ExprNodes.PythonCapiCallNode(
                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
                utility_code = Builtin.getattr3_utility_code,
                args = args,
                is_temp = node.is_temp
                )
        else:
            error(node.pos, "getattr() called with wrong number of args, "
                  "expected 2 or 3, found %d" % len(args))
        return node

    Pyx_Type_func_type = PyrexTypes.CFuncType(
        Builtin.type_type, [
            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
            ])

    def _handle_simple_function_type(self, node, pos_args):
        """Replace single-argument type(obj) by a __Pyx_Type() call."""
        args = pos_args.args
        if len(args) != 1:
            return node
        node = ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_Type", self.Pyx_Type_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = pytype_utility_code,
            )
        return node

    ### methods of builtin types

    PyObject_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.py_object_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ])

    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
        """Replace obj.append(x) on an untyped object by the
        __Pyx_PyObject_Append() utility function."""
        # X.append() is almost always referring to a list
        if len(args) != 2:
            return node

        return ExprNodes.PythonCapiCallNode(
            node.pos, "__Pyx_PyObject_Append", self.PyObject_Append_func_type,
            args = args,
            is_temp = node.is_temp,
            utility_code = append_utility_code
            )

    PyList_Append_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("list", PyrexTypes.py_object_type, None),
            PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_append(self, node, args, is_unbound_method):
        """Replace list.append(x) by a PyList_Append() call."""
        if len(args) != 2:
            error(node.pos, "list.append(x) called with wrong number of args, found %d" %
                  len(args))
            return node
        return self._substitute_method_call(
            node, "PyList_Append", self.PyList_Append_func_type,
            'append', is_unbound_method, args)

    single_param_func_type = PyrexTypes.CFuncType(
        PyrexTypes.c_int_type, [
            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
            ],
        exception_value = "-1")

    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
        """Replace argument-less list.sort() by a PyList_Sort() call."""
        if len(args) != 1:
            return node
        return self._substitute_method_call(
            node, "PyList_Sort", self.single_param_func_type,
            'sort', is_unbound_method, args)

    def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
        """Replace list.reverse() by a PyList_Reverse() call."""
        if len(args) != 1:
            error(node.pos, "list.reverse(x) called with wrong number of args, found %d" %
                  len(args))
            return node
        return self._substitute_method_call(
            node, "PyList_Reverse", self.single_param_func_type,
            'reverse', is_unbound_method, args)

    PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            PyrexTypes.CFuncTypeArg("encoding", PyrexTypes.c_char_ptr_type, None),
            PyrexTypes.CFuncTypeArg("errors", PyrexTypes.c_char_ptr_type, None),
            ],
        exception_value = "NULL")

    PyUnicode_AsXyzString_func_type = PyrexTypes.CFuncType(
        Builtin.bytes_type, [
            PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
            ],
        exception_value = "NULL")

    _special_encodings = ['UTF8', 'UTF16', 'Latin1', 'ASCII',
                          'unicode_escape', 'raw_unicode_escape']

    # (name, encoder) pairs used to recognise aliases of the encodings above
    _special_encoders = [ (name, codecs.getencoder(name))
                          for name in _special_encodings ]

    def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
        """Replace unicode.encode(...) by a direct C-API call.

        Literal strings are encoded at compile time where possible.  With
        'strict' error handling, encodings that resolve to one of the
        ``_special_encodings`` use the dedicated PyUnicode_As*String()
        function; everything else uses PyUnicode_AsEncodedString().  The
        call is left untouched if the encoding or error handling
        argument is not a string literal.
        """
        if len(args) < 1 or len(args) > 3:
            error(node.pos, "unicode.encode(...) called with wrong number of args, found %d" %
                  len(args))
            return node

        null_node = ExprNodes.NullNode(node.pos)
        string_node = args[0]

        if len(args) == 1:
            return self._substitute_method_call(
                node, "PyUnicode_AsEncodedString",
                self.PyUnicode_AsEncodedString_func_type,
                'encode', is_unbound_method, [string_node, null_node, null_node])

        encoding_node = args[1]
        if isinstance(encoding_node, ExprNodes.CoerceToPyTypeNode):
            encoding_node = encoding_node.arg
        if not isinstance(encoding_node, (ExprNodes.UnicodeNode, ExprNodes.StringNode)):
            return node
        encoding = encoding_node.value
        encoding_node = ExprNodes.StringNode(encoding_node.pos, value=encoding,
                                             type=PyrexTypes.c_char_ptr_type)

        if len(args) == 3:
            error_handling_node = args[2]
            if isinstance(error_handling_node, ExprNodes.CoerceToPyTypeNode):
                error_handling_node = error_handling_node.arg
            if not isinstance(error_handling_node,
                              (ExprNodes.UnicodeNode, ExprNodes.StringNode)):
                return node
            error_handling = error_handling_node.value
            if error_handling == 'strict':
                error_handling_node = null_node
            else:
                error_handling_node = ExprNodes.StringNode(
                    error_handling_node.pos, value=error_handling,
                    type=PyrexTypes.c_char_ptr_type)
        else:
            error_handling = 'strict'
            error_handling_node = null_node

        if isinstance(string_node, ExprNodes.UnicodeNode):
            # constant, so try to do the encoding at compile time
            try:
                value = string_node.value.encode(encoding, error_handling)
            except Exception:
                # well, looks like we can't - fall through to a runtime
                # call (deliberately best-effort, but do not swallow
                # KeyboardInterrupt/SystemExit)
                pass
            else:
                value = BytesLiteral(value)
                value.encoding = encoding
                return ExprNodes.StringNode(
                    string_node.pos, value=value, type=Builtin.bytes_type)

        if error_handling == 'strict':
            # try to find a specific encoder function
            try:
                requested_encoder = codecs.getencoder(encoding)
            except Exception:
                # unknown encoding: let the generic call deal with it
                pass
            else:
                encode_function = None
                for name, encoder in self._special_encoders:
                    if encoder == requested_encoder:
                        if '_' in name:
                            name = ''.join([ s.capitalize()
                                             for s in name.split('_')])
                        encode_function = "PyUnicode_As%sString" % name
                        break
                if encode_function is not None:
                    return self._substitute_method_call(
                        node, encode_function,
                        self.PyUnicode_AsXyzString_func_type,
                        'encode', is_unbound_method, [string_node])

        return self._substitute_method_call(
            node, "PyUnicode_AsEncodedString",
            self.PyUnicode_AsEncodedString_func_type,
            'encode', is_unbound_method,
            [string_node, encoding_node, error_handling_node])

    def _substitute_method_call(self, node, name, func_type,
                                attr_name, is_unbound_method, args=()):
        """Build a PythonCapiCallNode calling C function ``name``.

        The first argument, if any, is wrapped in a NoneCheckNode whose
        exception type and message depend on whether the call was a
        bound or an unbound method call.
        """
        args = list(args)
        if args:
            self_arg = args[0]
            if is_unbound_method:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_TypeError",
                    "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
                    attr_name, node.function.obj.name))
            else:
                self_arg = ExprNodes.NoneCheckNode(
                    self_arg, "PyExc_AttributeError",
                    "'NoneType' object has no attribute '%s'" % attr_name)
            args[0] = self_arg
        # FIXME: args[0] may need a runtime None check (ticket #166)
        return ExprNodes.PythonCapiCallNode(
            node.pos, name, func_type,
            args = args,
            is_temp = node.is_temp
            )


# C helper: fast path for exact lists, generic method call otherwise
append_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
    if (likely(PyList_CheckExact(L))) {
        if (PyList_Append(L, x) < 0) return NULL;
        Py_INCREF(Py_None);
        return Py_None; /* this is just to have an accurate signature */
    }
    else {
        PyObject *r, *m;
        m = __Pyx_GetAttrString(L, "append");
        if (!m) return NULL;
        r = PyObject_CallFunctionObjArgs(m, x, NULL);
        Py_DECREF(m);
        return r;
    }
}
""",
impl = ""
)


# C helper: return a new reference to the type of an object
pytype_utility_code = UtilityCode(
proto = """
static INLINE PyObject* __Pyx_Type(PyObject* o) {
    PyObject* type = (PyObject*) Py_TYPE(o);
    Py_INCREF(type);
    return type;
}
"""
)


class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
    """Calculate the result of constant expressions to store it in
    ``expr_node.constant_result``, and replace trivial cases by their
    constant result.
    """
    def _calculate_const(self, node):
        """Set ``node.constant_result`` unless it has been set before.

        The result stays ``not_a_constant`` if any child is not constant
        or if the calculation raises one of the expected errors.
        """
        if node.constant_result is not ExprNodes.constant_value_not_set:
            return

        # make sure we always set the value
        not_a_constant = ExprNodes.not_a_constant
        node.constant_result = not_a_constant

        # check if all children are constant
        children = self.visitchildren(node)
        for child_result in children.itervalues():
            if type(child_result) is list:
                for child in child_result:
                    if child.constant_result is not_a_constant:
                        return
            elif child_result.constant_result is not_a_constant:
                return

        # now try to calculate the real constant value
        try:
            node.calculate_constant_result()
        except (ValueError, TypeError, KeyError, IndexError, AttributeError):
            # ignore all 'normal' errors here => no constant result
            pass
        except Exception:
            # this looks like a real error
            import traceback, sys
            traceback.print_exc(file=sys.stdout)

    # constant node classes in order of increasing width
    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
                       ExprNodes.LongNode, ExprNodes.FloatNode)

    def _widest_node_class(self, *nodes):
        """Return the class of ``nodes`` that comes last in
        NODE_TYPE_ORDER, or None if any node's class is not listed."""
        try:
            return self.NODE_TYPE_ORDER[
                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
        except ValueError:
            return None

    def visit_ExprNode(self, node):
        """Calculate the constant result of a generic expression node."""
        self._calculate_const(node)
        return node

    def visit_BinopNode(self, node):
        """Fold a binary operation on two constant nodes into a single
        constant node; float results and non-constant operands are left
        in place."""
        self._calculate_const(node)
        if node.constant_result is ExprNodes.not_a_constant:
            return node
        if isinstance(node.constant_result, float):
            # We calculate float constants to make them available to
            # the compiler, but we do not aggregate them into a
            # constant node to prevent any loss of precision.
            return node
        if not isinstance(node.operand1, ExprNodes.ConstNode) or \
                not isinstance(node.operand2, ExprNodes.ConstNode):
            # We calculate other constants to make them available to
            # the compiler, but we only aggregate constant nodes
            # recursively, so non-const nodes are straight out.
            return node

        # now inject a new constant node with the calculated value
        try:
            type1, type2 = node.operand1.type, node.operand2.type
            if type1 is None or type2 is None:
                return node
        except AttributeError:
            return node

        if type1 is type2:
            new_node = node.operand1
        else:
            widest_type = PyrexTypes.widest_numeric_type(type1, type2)
            if type(node.operand1) is type(node.operand2):
                new_node = node.operand1
                new_node.type = widest_type
            elif type1 is widest_type:
                new_node = node.operand1
            elif type2 is widest_type:
                new_node = node.operand2
            else:
                target_class = self._widest_node_class(
                    node.operand1, node.operand2)
                if target_class is None:
                    return node
                new_node = target_class(pos=node.pos, type = widest_type)

        new_node.constant_result = node.constant_result
        new_node.value = str(node.constant_result)
        #new_node = new_node.coerce_to(node.type, self.current_scope)
        return new_node

    # in the future, other nodes can have their own handler method here
    # that can replace them with a constant result node

    visit_Node = Visitor.VisitorTransform.recurse_to_children


class FinalOptimizePhase(Visitor.CythonTransform):
    """
    This visitor handles several commuting optimizations, and is run
    just before the C code generation phase.

    The optimizations currently implemented in this class are:
        - Eliminate None assignment and refcounting for first assignment.
        - isinstance -> typecheck for cdef types
    """
    def visit_SingleAssignmentNode(self, node):
        """Avoid redundant initialisation of local variables before their
        first assignment.
        """
        self.visitchildren(node)
        if node.first:
            lhs = node.lhs
            lhs.lhs_of_first_assignment = True
            if isinstance(lhs, ExprNodes.NameNode) and lhs.entry.type.is_pyobject:
                # Have variable initialized to 0 rather than None
                lhs.entry.init_to_none = False
                lhs.entry.init = 0
        return node

    def visit_SimpleCallNode(self, node):
        """Replace generic calls to isinstance(x, type) by a more efficient
        type check.
        """
        self.visitchildren(node)
        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
            if node.function.name == 'isinstance':
                type_arg = node.args[1]
                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
                    # imported here rather than at module level, as in
                    # the original code
                    from CythonScope import utility_scope
                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
                    node.function.type = node.function.entry.type
                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
        return node