From 44490b49ee45b50eea3156c76b8f2ef9594d44e1 Mon Sep 17 00:00:00 2001
From: Lisandro Dalcin <dalcinl@gmail.com>
Date: Wed, 19 Jul 2017 22:17:29 +0300
Subject: [PATCH] Update embedsignature directive

* emit function annotations
* implement ExpressionWriter visitor
---
 Cython/CodeWriter.py                 | 273 +++++++++++++++++++++++++++
 Cython/Compiler/AutoDocTransforms.py | 115 +++++------
 tests/run/embedsignatures.pyx        |   8 +-
 3 files changed, 326 insertions(+), 70 deletions(-)

diff --git a/Cython/CodeWriter.py b/Cython/CodeWriter.py
index cfdbf48a6..5b5566642 100644
--- a/Cython/CodeWriter.py
+++ b/Cython/CodeWriter.py
@@ -519,3 +519,276 @@ class PxdWriter(DeclarationWriter):
 
     def visit_StatNode(self, node):
         pass
+
+
+class ExpressionWriter(TreeVisitor):
+
+    def __init__(self, result=None):
+        super(ExpressionWriter, self).__init__()
+        if result is None:
+            result = u""
+        self.result = result
+
+    def write(self, tree):
+        self.visit(tree)
+        return self.result
+
+    def put(self, s):
+        self.result += s
+
+    def remove(self, s):
+        if self.result.endswith(s):
+            self.result = self.result[:-len(s)]
+
+    def comma_separated_list(self, items):
+        if len(items) > 0:
+            for item in items[:-1]:
+                self.visit(item)
+                self.put(u", ")
+            self.visit(items[-1])
+
+    def visit_Node(self, node):
+        raise AssertionError("Node not handled by serializer: %r" % node)
+
+    def visit_NameNode(self, node):
+        self.put(node.name)
+
+    def visit_NoneNode(self, node):
+        self.put(u"None")
+
+    def visit_BoolNode(self, node):
+        self.put(str(node.value))
+
+    def visit_ConstNode(self, node):
+        self.put(str(node.value))
+
+    def visit_ImagNode(self, node):
+        self.put(node.value)
+        self.put(u"j")
+
+    def visit_BytesNode(self, node):
+        repr_val = repr(node.value)
+        if repr_val[0] == 'b':
+            repr_val = repr_val[1:]
+        self.put(u"b%s" % repr_val)
+
+    def visit_StringNode(self, node):
+        repr_val = repr(node.value)
+        if repr_val[0] in 'ub':
+            repr_val = repr_val[1:]
+        self.put(u"%s" % repr_val)
+
+    def visit_UnicodeNode(self, node):
+        repr_val = repr(node.value)
+        if repr_val[0] == 'u':
+            repr_val = repr_val[1:]
+        self.put(u"u%s" % repr_val)
+
+    def emit_sequence(self, node, parens):
+        open_paren, close_paren = parens
+        items = node.subexpr_nodes()
+        self.put(open_paren)
+        self.comma_separated_list(items)
+        self.put(close_paren)
+
+    def visit_ListNode(self, node):
+        self.emit_sequence(node, u"[]")
+
+    def visit_TupleNode(self, node):
+        self.emit_sequence(node, u"()")
+
+    def visit_SetNode(self, node):
+        self.emit_sequence(node, u"{}")
+
+    def visit_DictNode(self, node):
+        self.emit_sequence(node, u"{}")
+
+    def visit_DictItemNode(self, node):
+        self.visit(node.key)
+        self.put(u": ")
+        self.visit(node.value)
+
+    unop_precedence = {
+        'not': 3, '!': 3,
+        '+': 11, '-': 11, '~': 11,
+    }
+    binop_precedence = {
+        'or': 1,
+        'and': 2,
+        # unary: 'not': 3, '!': 3,
+        'in': 4, 'not_in': 4, 'is': 4, 'is_not': 4, '<': 4, '<=': 4, '>': 4, '>=': 4, '!=': 4, '==': 4,
+        '|': 5,
+        '^': 6,
+        '&': 7,
+        '<<': 8, '>>': 8,
+        '+': 9, '-': 9,
+        '*': 10, '/': 10, '//': 10, '%': 10,
+        # unary: '+': 11, '-': 11, '~': 11
+        '**': 12,
+    }
+
+    def operator_enter(self, new_prec):
+        if not hasattr(self, 'precedence'):
+            self.precedence = [0]
+        old_prec = self.precedence[-1]
+        if old_prec > new_prec:
+            self.put(u"(")
+        self.precedence.append(new_prec)
+
+    def operator_exit(self):
+        old_prec, new_prec = self.precedence[-2:]
+        if old_prec > new_prec:
+            self.put(u")")
+        self.precedence.pop()
+
+    def visit_NotNode(self, node):
+        op = 'not'
+        prec = self.unop_precedence[op]
+        self.operator_enter(prec)
+        self.put(u"not ")
+        self.visit(node.operand)
+        self.operator_exit()
+
+    def visit_UnopNode(self, node):
+        op = node.operator
+        prec = self.unop_precedence[op]
+        self.operator_enter(prec)
+        self.put(u"%s" % node.operator)
+        self.visit(node.operand)
+        self.operator_exit()
+
+    def visit_BinopNode(self, node):
+        op = node.operator
+        prec = self.binop_precedence.get(op, 0)
+        self.operator_enter(prec)
+        self.visit(node.operand1)
+        self.put(u" %s " % op.replace('_', ' '))
+        self.visit(node.operand2)
+        self.operator_exit()
+
+    def visit_BoolBinopNode(self, node):
+        self.visit_BinopNode(node)
+
+    def visit_PrimaryCmpNode(self, node):
+        self.visit_BinopNode(node)
+
+    def visit_IndexNode(self, node):
+        self.visit(node.base)
+        self.put(u"[")
+        self.visit(node.index)
+        self.put(u"]")
+
+    def visit_SliceIndexNode(self, node):
+        self.visit(node.base)
+        self.put(u"[")
+        if node.start:
+            self.visit(node.start)
+        self.put(u":")
+        if node.stop:
+            self.visit(node.stop)
+        if node.slice:
+            self.put(u":")
+            self.visit(node.slice)
+        self.put(u"]")
+
+    def visit_SliceNode(self, node):
+        if not node.start.is_none:
+            self.visit(node.start)
+        self.put(u":")
+        if not node.stop.is_none:
+            self.visit(node.stop)
+        if not node.step.is_none:
+            self.put(u":")
+            self.visit(node.step)
+
+    def visit_CondExprNode(self, node):
+        self.visit(node.true_val)
+        self.put(u" if ")
+        self.visit(node.test)
+        self.put(u" else ")
+        self.visit(node.false_val)
+
+    def visit_AttributeNode(self, node):
+        self.visit(node.obj)
+        self.put(u".%s" % node.attribute)
+
+    def visit_SimpleCallNode(self, node):
+        self.visit(node.function)
+        self.put(u"(")
+        self.comma_separated_list(node.args)
+        self.put(")")
+
+    def emit_pos_args(self, node):
+        if node is None:
+            return
+        if isinstance(node, AddNode):
+            self.emit_pos_args(node.operand1)
+            self.emit_pos_args(node.operand2)
+        elif isinstance(node, TupleNode):
+            for expr in node.subexpr_nodes():
+                self.visit(expr)
+                self.put(u", ")
+        elif isinstance(node, AsTupleNode):
+            self.put("*")
+            self.visit(node.arg)
+            self.put(u", ")
+        else:
+            self.visit(node)
+            self.put(u", ")
+
+    def emit_kwd_args(self, node):
+        if node is None:
+            return
+        if isinstance(node, MergedDictNode):
+            for expr in node.subexpr_nodes():
+                self.emit_kwd_args(expr)
+        elif isinstance(node, DictNode):
+            for expr in node.subexpr_nodes():
+                self.put(u"%s=" % expr.key.value)
+                self.visit(expr.value)
+                self.put(u", ")
+        else:
+            self.put(u"**")
+            self.visit(node)
+            self.put(u", ")
+
+    def visit_GeneralCallNode(self, node):
+        self.visit(node.function)
+        self.put(u"(")
+        self.emit_pos_args(node.positional_args)
+        self.emit_kwd_args(node.keyword_args)
+        self.remove(u", ")
+        self.put(")")
+
+    def visit_ComprehensionNode(self, node):
+        tpmap = {'list': u"[]", 'dict': u"{}", 'set': u"{}"}
+        parens = tpmap[node.type.py_type_name()]
+        open_paren, close_paren = parens
+
+        body = node.loop.body
+        target = node.loop.target
+        sequence = node.loop.iterator.sequence
+        if isinstance(body, ComprehensionAppendNode):
+            condition = None
+        else:
+            condition = body.if_clauses[0].condition
+            body = body.if_clauses[0].body
+
+        self.put(open_paren)
+        self.visit(body)
+        self.put(u" for ")
+        self.visit(target)
+        self.put(u" in ")
+        self.visit(sequence)
+        if condition:
+            self.put(u" if ")
+            self.visit(condition)
+        self.put(close_paren)
+
+    def visit_ComprehensionAppendNode(self, node):
+        self.visit(node.expr)
+
+    def visit_DictComprehensionAppendNode(self, node):
+        self.visit(node.key_expr)
+        self.put(u": ")
+        self.visit(node.value_expr)
diff --git a/Cython/Compiler/AutoDocTransforms.py b/Cython/Compiler/AutoDocTransforms.py
index 8adaba099..d3c0a1d0d 100644
--- a/Cython/Compiler/AutoDocTransforms.py
+++ b/Cython/Compiler/AutoDocTransforms.py
@@ -1,89 +1,59 @@
-from __future__ import absolute_import
+from __future__ import absolute_import, print_function
 
 from .Visitor import CythonTransform
 from .StringEncoding import EncodedString
 from . import Options
 from . import PyrexTypes, ExprNodes
+from ..CodeWriter import ExpressionWriter
+
+
+class AnnotationWriter(ExpressionWriter):
+
+    def visit_Node(self, node):
+        self.put(u"<???>")
+
+    def visit_LambdaNode(self, node):
+        # XXX Should we do better?
+        self.put("<lambda>")
+
 
 class EmbedSignature(CythonTransform):
 
     def __init__(self, context):
         super(EmbedSignature, self).__init__(context)
-        self.denv = None # XXX
         self.class_name = None
         self.class_node = None
 
-    unop_precedence = 11
-    binop_precedence = {
-        'or': 1,
-        'and': 2,
-        'not': 3,
-        'in': 4, 'not in': 4, 'is': 4, 'is not': 4, '<': 4, '<=': 4, '>': 4, '>=': 4, '!=': 4, '==': 4,
-        '|': 5,
-        '^': 6,
-        '&': 7,
-        '<<': 8, '>>': 8,
-        '+': 9, '-': 9,
-        '*': 10, '/': 10, '//': 10, '%': 10,
-        # unary: '+': 11, '-': 11, '~': 11
-        '**': 12}
-
-    def _fmt_expr_node(self, node, precedence=0):
-        if isinstance(node, ExprNodes.BinopNode) and not node.inplace:
-            new_prec = self.binop_precedence.get(node.operator, 0)
-            result = '%s %s %s' % (self._fmt_expr_node(node.operand1, new_prec),
-                                   node.operator,
-                                   self._fmt_expr_node(node.operand2, new_prec))
-            if precedence > new_prec:
-                result = '(%s)' % result
-        elif isinstance(node, ExprNodes.UnopNode):
-            result = '%s%s' % (node.operator,
-                               self._fmt_expr_node(node.operand, self.unop_precedence))
-            if precedence > self.unop_precedence:
-                result = '(%s)' % result
-        elif isinstance(node, ExprNodes.AttributeNode):
-            result = '%s.%s' % (self._fmt_expr_node(node.obj), node.attribute)
-        else:
-            result = node.name
+    def _fmt_expr(self, node):
+        writer = AnnotationWriter()
+        result = writer.write(node)
+        # print(type(node).__name__, '-->', result)
         return result
 
-    def _fmt_arg_defv(self, arg):
-        default_val = arg.default
-        if not default_val:
-            return None
-        if isinstance(default_val, ExprNodes.NullNode):
-            return 'NULL'
-        try:
-            denv = self.denv  # XXX
-            ctval = default_val.compile_time_value(self.denv)
-            repr_val = repr(ctval)
-            if isinstance(default_val, ExprNodes.UnicodeNode):
-                if repr_val[:1] != 'u':
-                    return u'u%s' % repr_val
-            elif isinstance(default_val, ExprNodes.BytesNode):
-                if repr_val[:1] != 'b':
-                    return u'b%s' % repr_val
-            elif isinstance(default_val, ExprNodes.StringNode):
-                if repr_val[:1] in 'ub':
-                    return repr_val[1:]
-            return repr_val
-        except Exception:
-            try:
-                return self._fmt_expr_node(default_val)
-            except AttributeError:
-                return '<???>'
-
     def _fmt_arg(self, arg):
         if arg.type is PyrexTypes.py_object_type or arg.is_self_arg:
             doc = arg.name
         else:
             doc = arg.type.declaration_code(arg.name, for_display=1)
-        if arg.default:
-            arg_defv = self._fmt_arg_defv(arg)
-            if arg_defv:
-                doc = doc + ('=%s' % arg_defv)
+
+        if arg.annotation:
+            annotation = self._fmt_expr(arg.annotation)
+            doc = doc + (': %s' % annotation)
+            if arg.default:
+                default = self._fmt_expr(arg.default)
+                doc = doc + (' = %s' % default)
+        elif arg.default:
+            default = self._fmt_expr(arg.default)
+            doc = doc + ('=%s' % default)
         return doc
 
+    def _fmt_star_arg(self, arg):
+        arg_doc = arg.name
+        if arg.annotation:
+            annotation = self._fmt_expr(arg.annotation)
+            arg_doc = arg_doc + (': %s' % annotation)
+        return arg_doc
+
     def _fmt_arglist(self, args,
                      npargs=0, pargs=None,
                      nkargs=0, kargs=None,
@@ -94,11 +64,13 @@ class EmbedSignature(CythonTransform):
                 arg_doc = self._fmt_arg(arg)
                 arglist.append(arg_doc)
         if pargs:
-            arglist.insert(npargs, '*%s' % pargs.name)
+            arg_doc = self._fmt_star_arg(pargs)
+            arglist.insert(npargs, '*%s' % arg_doc)
         elif nkargs:
             arglist.insert(npargs, '*')
         if kargs:
-            arglist.append('**%s' % kargs.name)
+            arg_doc = self._fmt_star_arg(kargs)
+            arglist.append('**%s' % arg_doc)
         return arglist
 
     def _fmt_ret_type(self, ret):
@@ -110,6 +82,7 @@ class EmbedSignature(CythonTransform):
     def _fmt_signature(self, cls_name, func_name, args,
                        npargs=0, pargs=None,
                        nkargs=0, kargs=None,
+                       return_expr=None,
                        return_type=None, hide_self=False):
         arglist = self._fmt_arglist(args,
                                     npargs, pargs,
@@ -119,10 +92,13 @@ class EmbedSignature(CythonTransform):
         func_doc = '%s(%s)' % (func_name, arglist_doc)
         if cls_name:
             func_doc = '%s.%s' % (cls_name, func_doc)
-        if return_type:
+        ret_doc = None
+        if return_expr:
+            ret_doc = self._fmt_expr(return_expr)
+        elif return_type:
             ret_doc = self._fmt_ret_type(return_type)
-            if ret_doc:
-                func_doc = '%s -> %s' % (func_doc, ret_doc)
+        if ret_doc:
+            func_doc = '%s -> %s' % (func_doc, ret_doc)
         return func_doc
 
     def _embed_signature(self, signature, node_doc):
@@ -177,6 +153,7 @@ class EmbedSignature(CythonTransform):
             class_name, func_name, node.args,
             npargs, node.star_arg,
             nkargs, node.starstar_arg,
+            return_expr=node.return_type_annotation,
             return_type=None, hide_self=hide_self)
         if signature:
             if is_constructor:
diff --git a/tests/run/embedsignatures.pyx b/tests/run/embedsignatures.pyx
index b629e9c81..32581d150 100644
--- a/tests/run/embedsignatures.pyx
+++ b/tests/run/embedsignatures.pyx
@@ -80,6 +80,9 @@ __doc__ = ur"""
     >>> print (Ext.m.__doc__)
     Ext.m(self, a=u'spam')
 
+    >>> print (Ext.n.__doc__)
+    Ext.n(self, a: int, b: float = 1.0, *args: tuple, **kwargs: dict) -> (None, True)
+
     >>> print (Ext.get_int.__doc__)
     Ext.get_int(self) -> int
 
@@ -185,7 +188,7 @@ __doc__ = ur"""
     f_defexpr4(int x=(Ext.CONST1 + FLAG1) * Ext.CONST2)
 
     >>> print(funcdoc(f_defexpr5))
-    f_defexpr5(int x=4)
+    f_defexpr5(int x=2 + 2)
 
     >>> print(funcdoc(f_charptr_null))
     f_charptr_null(char *s=NULL) -> char *
@@ -259,6 +262,9 @@ cdef class Ext:
     def m(self, a=u'spam'):
         pass
 
+    def n(self, a: int, b: float = 1.0, *args: tuple, **kwargs: dict) -> (None, True):
+        pass
+
     cpdef int get_int(self):
         return 0
 
-- 
2.30.9