From 043fa54aaaa02ff5231c93b670443c122361d49b Mon Sep 17 00:00:00 2001
From: Stefan Behnel <stefan_ml@behnel.de>
Date: Sat, 20 Aug 2016 18:51:34 +0200
Subject: [PATCH] speed up f-string building with a specialised
 PyUnicode_Join() implementation

---
 CHANGES.rst                      |  2 +-
 Cython/Compiler/ExprNodes.py     | 36 ++++++++++++++-
 Cython/Utility/ModuleSetupCode.c |  2 +
 Cython/Utility/StringTools.c     | 75 ++++++++++++++++++++++++++++++++
 tests/run/fstring.pyx            | 70 +++++++++++++++++++++++++++--
 5 files changed, 178 insertions(+), 7 deletions(-)

diff --git a/CHANGES.rst b/CHANGES.rst
index c00564292..53b15a381 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -22,7 +22,7 @@ Features added
 * Buffer variables are no longer excluded from ``locals()``.
   Patch by da-woods.
 
-* Formatting C integers in f-strings is faster.
+* Building f-strings is faster, especially when formatting C integers.
 
 * for-loop iteration over "std::string".
 
diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index dceaed22e..e2b6dd773 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -2983,16 +2983,42 @@ class JoinedStrNode(ExprNode):
         code.mark_pos(self.pos)
         num_items = len(self.values)
         list_var = code.funcstate.allocate_temp(py_object_type, manage_ref=True)
+        ulength_var = code.funcstate.allocate_temp(PyrexTypes.c_py_ssize_t_type, manage_ref=False)
+        max_char_var = code.funcstate.allocate_temp(PyrexTypes.c_py_ucs4_type, manage_ref=False)
 
         code.putln('%s = PyTuple_New(%s); %s' % (
             list_var,
             num_items,
             code.error_goto_if_null(list_var, self.pos)))
         code.put_gotref(list_var)
+        code.putln("%s = 0;" % ulength_var)
+        code.putln("%s = 127;" % max_char_var)  # at least ASCII character range
 
         for i, node in enumerate(self.values):
             node.generate_evaluation_code(code)
             node.make_owned_reference(code)
+
+            ulength = "__Pyx_PyUnicode_GET_LENGTH(%s)" % node.py_result()
+            max_char_value = "__Pyx_PyUnicode_MAX_CHAR_VALUE(%s)" % node.py_result()
+            is_ascii = False
+            if isinstance(node, UnicodeNode):
+                try:
+                    node.value.encode('iso8859-1')
+                    max_char_value = '255'
+                    node.value.encode('us-ascii')
+                    is_ascii = True
+                except UnicodeEncodeError:
+                    pass
+                else:
+                    ulength = str(len(node.value))
+            elif isinstance(node, FormattedValueNode) and node.value.type.is_numeric:
+                is_ascii = True  # formatted C numbers are always ASCII
+
+            if not is_ascii:
+                code.putln("%s = (%s > %s) ? %s : %s;" % (
+                    max_char_var, max_char_value, max_char_var, max_char_value, max_char_var))
+            code.putln("%s += %s;" % (ulength_var, ulength))
+
             code.put_giveref(node.py_result())
             code.putln('PyTuple_SET_ITEM(%s, %s, %s);' % (list_var, i, node.py_result()))
             node.generate_post_assignment_code(code)
@@ -3000,14 +3026,20 @@ class JoinedStrNode(ExprNode):
 
         code.mark_pos(self.pos)
         self.allocate_temp_result(code)
-        code.putln('%s = PyUnicode_Join(%s, %s); %s' % (
+        code.globalstate.use_utility_code(UtilityCode.load_cached("JoinPyUnicode", "StringTools.c"))
+        code.putln('%s = __Pyx_PyUnicode_Join(%s, %d, %s, %s); %s' % (
             self.result(),
-            Naming.empty_unicode,
             list_var,
+            num_items,
+            ulength_var,
+            max_char_var,
             code.error_goto_if_null(self.py_result(), self.pos)))
         code.put_gotref(self.py_result())
+
         code.put_decref_clear(list_var, py_object_type)
         code.funcstate.release_temp(list_var)
+        code.funcstate.release_temp(ulength_var)
+        code.funcstate.release_temp(max_char_var)
 
 
 class FormattedValueNode(ExprNode):
diff --git a/Cython/Utility/ModuleSetupCode.c b/Cython/Utility/ModuleSetupCode.c
index 711ff175b..1dd52aa5e 100644
--- a/Cython/Utility/ModuleSetupCode.c
+++ b/Cython/Utility/ModuleSetupCode.c
@@ -172,6 +172,7 @@
                                               0 : _PyUnicode_Ready((PyObject *)(op)))
   #define __Pyx_PyUnicode_GET_LENGTH(u)   PyUnicode_GET_LENGTH(u)
   #define __Pyx_PyUnicode_READ_CHAR(u, i) PyUnicode_READ_CHAR(u, i)
+  #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u)   PyUnicode_MAX_CHAR_VALUE(u)
   #define __Pyx_PyUnicode_KIND(u)         PyUnicode_KIND(u)
   #define __Pyx_PyUnicode_DATA(u)         PyUnicode_DATA(u)
   #define __Pyx_PyUnicode_READ(k, d, i)   PyUnicode_READ(k, d, i)
@@ -185,6 +186,7 @@
   #define __Pyx_PyUnicode_READY(op)       (0)
   #define __Pyx_PyUnicode_GET_LENGTH(u)   PyUnicode_GET_SIZE(u)
   #define __Pyx_PyUnicode_READ_CHAR(u, i) ((Py_UCS4)(PyUnicode_AS_UNICODE(u)[i]))
+  #define __Pyx_PyUnicode_MAX_CHAR_VALUE(u)   ((sizeof(Py_UNICODE) == 2) ? 65535 : 1114111)
   #define __Pyx_PyUnicode_KIND(u)         (sizeof(Py_UNICODE))
   #define __Pyx_PyUnicode_DATA(u)         ((void*)PyUnicode_AS_UNICODE(u))
   /* (void)(k) => avoid unused variable warning due to macro: */
diff --git a/Cython/Utility/StringTools.c b/Cython/Utility/StringTools.c
index ea3ff7210..9e4a373a1 100644
--- a/Cython/Utility/StringTools.c
+++ b/Cython/Utility/StringTools.c
@@ -732,6 +732,81 @@ static CYTHON_INLINE PyObject* __Pyx_PyBytes_Join(PyObject* sep, PyObject* value
 #endif
 
 
+/////////////// JoinPyUnicode.proto ///////////////
+
+static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength,
+                                      Py_UCS4 max_char);
+
+/////////////// JoinPyUnicode ///////////////
+//@requires: IncludeStringH
+//@substitute: naming
+
+// Concatenates the first 'value_count' unicode items of 'value_tuple' into a new string of 'result_ulength'
+// characters; 'max_char' must bound every code point. Returns a new reference, or NULL with an exception set.
+static PyObject* __Pyx_PyUnicode_Join(PyObject* value_tuple, Py_ssize_t value_count, Py_ssize_t result_ulength,
+                                      CYTHON_UNUSED Py_UCS4 max_char) {
+#if CYTHON_USE_UNICODE_INTERNALS && CYTHON_ASSUME_SAFE_MACROS && !CYTHON_AVOID_BORROWED_REFS
+    PyObject *result_uval;
+    int result_ukind;
+    Py_ssize_t i, char_pos;
+    void *result_udata;
+#if CYTHON_PEP393_ENABLED
+    // Py 3.3+  (post PEP-393)
+    result_uval = PyUnicode_New(result_ulength, max_char);
+    if (unlikely(!result_uval)) return NULL;
+    result_ukind = (max_char <= 255) ? PyUnicode_1BYTE_KIND : (max_char <= 65535) ? PyUnicode_2BYTE_KIND : PyUnicode_4BYTE_KIND;
+    result_udata = PyUnicode_DATA(result_uval);
+#else
+    // Py 2.x/3.2  (pre PEP-393)
+    result_uval = PyUnicode_FromUnicode(NULL, result_ulength);
+    if (unlikely(!result_uval)) return NULL;
+    result_ukind = sizeof(Py_UNICODE);
+    result_udata = PyUnicode_AS_UNICODE(result_uval);
+#endif
+
+    char_pos = 0;
+    for (i=0; i < value_count; i++) {
+        int ukind;
+        Py_ssize_t ulength;
+        void *udata;
+        PyObject *uval = PyTuple_GET_ITEM(value_tuple, i);
+        if (unlikely(__Pyx_PyUnicode_READY(uval))) goto bad;
+        ulength = __Pyx_PyUnicode_GET_LENGTH(uval);
+        if (unlikely(!ulength)) continue;
+        // check the remaining headroom instead of "char_pos + ulength < 0": signed overflow is undefined in C
+        if (unlikely(PY_SSIZE_T_MAX - ulength < char_pos)) goto overflow;
+        ukind = __Pyx_PyUnicode_KIND(uval);
+        udata = __Pyx_PyUnicode_DATA(uval);
+        if (!CYTHON_PEP393_ENABLED || ukind == result_ukind) {
+            memcpy((char *)result_udata + char_pos * result_ukind, udata, (size_t) (ulength * result_ukind));
+        } else {
+            #if CYTHON_COMPILING_IN_CPYTHON && PY_VERSION_HEX >= 0x030300F0
+            _PyUnicode_FastCopyCharacters(result_uval, char_pos, uval, 0, ulength);
+            #else
+            Py_ssize_t j;
+            for (j=0; j < ulength; j++) {
+                Py_UCS4 uchar = __Pyx_PyUnicode_READ(ukind, udata, j);
+                __Pyx_PyUnicode_WRITE(result_ukind, result_udata, char_pos+j, uchar);
+            }
+            #endif
+        }
+        char_pos += ulength;
+    }
+    return result_uval;
+overflow:
+    PyErr_SetString(PyExc_OverflowError, "join() result is too long for a Python string");
+bad:
+    Py_DECREF(result_uval);
+    return NULL;
+#else
+    // non-CPython fallback (the two increments only silence "unused argument" warnings)
+    result_ulength++;
+    value_count++;
+    return PyUnicode_Join($empty_unicode, value_tuple);
+#endif
+}
+
+
 /////////////// BuildPyUnicode.proto ///////////////
 
 static PyObject* __Pyx_PyUnicode_BuildFromAscii(Py_ssize_t ulength, char* chars, int clength,
diff --git a/tests/run/fstring.pyx b/tests/run/fstring.pyx
index c1020b763..7c8e8a2e0 100644
--- a/tests/run/fstring.pyx
+++ b/tests/run/fstring.pyx
@@ -8,6 +8,15 @@
 import sys
 IS_PYPY = hasattr(sys, 'pypy_version_info')
 
+cdef extern from *:
+    int INT_MAX
+    long LONG_MAX
+    long LONG_MIN
+
+max_int = INT_MAX  # Python-level copies of the C limits so that doctests can reference them
+max_long = LONG_MAX
+min_long = LONG_MIN  # NOTE(review): not referenced in the visible hunks - confirm it is used elsewhere
+
 
 def format2(ab, cd):
     """
@@ -70,6 +79,29 @@ def format_c_numbers(signed char c, short s, int n, long l, float f, double d):
     return s1, s2, s3, s4
 
 
+def format_c_numbers_max(int n, long l):  # doctests call this with the extreme C int/long values
+    """
+    >>> n, l = max_int, max_long
+    >>> s1, s2 = format_c_numbers_max(n, l)
+    >>> s1 == '{n}:{l}'.format(n=n, l=l) or s1
+    True
+    >>> s2 == '{n:012X}:{l:020X}'.format(n=n, l=l) or s2
+    True
+
+    >>> n, l = -max_int-1, -max_long-1
+    >>> s1, s2 = format_c_numbers_max(n, l)
+    >>> s1 == '{n}:{l}'.format(n=n, l=l) or s1
+    True
+    >>> s2 == '{n:012X}:{l:020X}'.format(n=n, l=l) or s2
+    True
+    """
+    s1 = f"{n}:{l}"  # default formatting of both C integers
+    assert isinstance(s1, unicode), type(s1)
+    s2 = f"{n:012X}:{l:020X}"  # zero-padded upper-case hex with widths 12 and 20
+    assert isinstance(s2, unicode), type(s2)
+    return s1, s2
+
+
 def format_bool(bint x):
     """
     >>> a, b, c, d = format_bool(1)
@@ -125,15 +157,17 @@ def format_c_values(Py_UCS4 uchar, Py_UNICODE pyunicode):
     s2 = f"{pyunicode}"
     assert isinstance(s2, unicode), type(s2)
     l = [1, 2, 3]
-    s3 = f"{l.reverse()}"  # C int return value => None
+    s3 = f"{l.reverse()}"  # C int return value => 'None'
     assert isinstance(s3, unicode), type(s3)
     assert l == [3, 2, 1]
     return s, s1, s2, s3
 
 
+xyz_ustring = u'xÄyÖz'  # non-ASCII, but still within the Latin-1 range
+
 def format_strings(str s, unicode u):
-    """
-    >>> a, b, c, d = format_strings('abc', b'xyz'.decode('ascii'))
+    u"""
+    >>> a, b, c, d, e, f, g = format_strings('abc', b'xyz'.decode('ascii'))
     >>> print(a)
     abcxyz
     >>> print(b)
@@ -142,6 +176,28 @@ def format_strings(str s, unicode u):
     uxyzsabc
     >>> print(d)
     sabcuxyz
+    >>> print(e)
+    sabcuÄÄuxyz
+    >>> print(f)
+    sabcu\N{SNOWMAN}uxyz
+    >>> print(g)
+    sabcu\N{OLD PERSIAN SIGN A}uxyz\N{SNOWMAN}
+
+    >>> a, b, c, d, e, f, g = format_strings('abc', xyz_ustring)
+    >>> print(a)
+    abcxÄyÖz
+    >>> print(b)
+    xÄyÖzabc
+    >>> print(c)
+    uxÄyÖzsabc
+    >>> print(d)
+    sabcuxÄyÖz
+    >>> print(e)
+    sabcuÄÄuxÄyÖz
+    >>> print(f)
+    sabcu\N{SNOWMAN}uxÄyÖz
+    >>> print(g)
+    sabcu\N{OLD PERSIAN SIGN A}uxÄyÖz\N{SNOWMAN}
     """
     a = f"{s}{u}"
     assert isinstance(a, unicode), type(a)
@@ -151,7 +207,13 @@ def format_strings(str s, unicode u):
     assert isinstance(c, unicode), type(c)
     d = f"s{s}u{u}"
     assert isinstance(d, unicode), type(d)
-    return a, b, c, d
+    e = f"s{s}uÄÄu{u}"  # literal part in the Latin-1 range
+    assert isinstance(e, unicode), type(e)
+    f = f"s{s}u\N{SNOWMAN}u{u}"  # literal part in the BMP, beyond Latin-1
+    assert isinstance(f, unicode), type(f)
+    g = f"s{s}u\N{OLD PERSIAN SIGN A}u{u}\N{SNOWMAN}"  # literal part beyond the BMP
+    assert isinstance(g, unicode), type(g)
+    return a, b, c, d, e, f, g
 
 
 def format_str(str s1, str s2):
-- 
2.30.9