Commit bbbb58f0 authored by Kirill Smelkov's avatar Kirill Smelkov

golang_str: bstr/ustr support for + and *

Add support for +, *, += and *= operators to bstr and ustr.

For * rhs should be integer and the result, similarly to std strings, is
repetition of rhs times.

For + the other argument could be any supported string - bstr/ustr /
unicode/bytes/bytearray. And the result is always bstr or ustr:

    u()   +     *     ->  u()
    b()   +     *     ->  b()
    u''   +  u()/b()  ->  u()
    u''   +  u''      ->  u''
    b''   +  u()/b()  ->  b()
    b''   +      b''  ->  b''
    barr  +  u()/b()  ->  barr

in particular if lhs is bstr or ustr, the result will remain exactly of
original lhs type. This should be handy when one has e.g. bstr at hand
and wants to incrementally append something to it.

And if lhs is bytes/unicode, but we append bstr/ustr to it, we "upgrade"
the result to bstr/ustr correspondingly. Only if lhs is bytearray it
remains to stay that way because it is logical for appended object to
remain mutable if it was mutable in the beginning.

As before bytearray.__add__ and friends need to patched a bit for
bytearray not to reject ustr.
parent ebd18f3f
...@@ -265,8 +265,9 @@ object is either `bstr` or `ustr` correspondingly. ...@@ -265,8 +265,9 @@ object is either `bstr` or `ustr` correspondingly.
Usage example:: Usage example::
s = b('привет') # s is bstr corresponding to UTF-8 encoding of 'привет'. s = b('привет') # s is bstr corresponding to UTF-8 encoding of 'привет'.
s += ' мир' # s is b('привет мир')
for c in s: # c will iterate through for c in s: # c will iterate through
... # [u(_) for _ in ('п','р','и','в','е','т')] ... # [u(_) for _ in ('п','р','и','в','е','т',' ','м','и','р')]
def f(s): def f(s):
s = u(s) # make sure s is ustr, decoding as UTF-8(*) if it was bstr, bytes, bytearray or buffer. s = u(s) # make sure s is ustr, decoding as UTF-8(*) if it was bstr, bytes, bytearray or buffer.
......
...@@ -24,7 +24,7 @@ It is included from _golang.pyx . ...@@ -24,7 +24,7 @@ It is included from _golang.pyx .
from cpython cimport PyUnicode_AsUnicode, PyUnicode_GetSize, PyUnicode_FromUnicode from cpython cimport PyUnicode_AsUnicode, PyUnicode_GetSize, PyUnicode_FromUnicode
from cpython cimport PyUnicode_DecodeUTF8 from cpython cimport PyUnicode_DecodeUTF8
from cpython cimport PyTypeObject, Py_TYPE, richcmpfunc from cpython cimport PyTypeObject, Py_TYPE, richcmpfunc, binaryfunc
from cpython cimport Py_EQ, Py_NE, Py_LT, Py_GT, Py_LE, Py_GE from cpython cimport Py_EQ, Py_NE, Py_LT, Py_GT, Py_LE, Py_GE
from cpython.iterobject cimport PySeqIter_New from cpython.iterobject cimport PySeqIter_New
from cpython cimport PyObject_CheckBuffer from cpython cimport PyObject_CheckBuffer
...@@ -35,6 +35,12 @@ cdef extern from "Python.h": ...@@ -35,6 +35,12 @@ cdef extern from "Python.h":
ctypedef int (*initproc)(object, PyObject *, PyObject *) except -1 ctypedef int (*initproc)(object, PyObject *, PyObject *) except -1
ctypedef struct _XPyTypeObject "PyTypeObject": ctypedef struct _XPyTypeObject "PyTypeObject":
initproc tp_init initproc tp_init
PySequenceMethods *tp_as_sequence
ctypedef struct PySequenceMethods:
binaryfunc sq_concat
binaryfunc sq_inplace_concat
from libc.stdint cimport uint8_t from libc.stdint cimport uint8_t
...@@ -160,6 +166,16 @@ cdef _pyu_coerce(x): # -> ustr|unicode ...@@ -160,6 +166,16 @@ cdef _pyu_coerce(x): # -> ustr|unicode
else: else:
raise TypeError("u: coerce: invalid type %s" % type(x)) raise TypeError("u: coerce: invalid type %s" % type(x))
# _pybu_rcoerce coerces x from `x op b|u` to either bstr or ustr.
# NOTE bytearray is handled outside of this function.
cdef _pybu_rcoerce(x): # -> bstr|ustr
if isinstance(x, bytes):
return pyb(x)
elif isinstance(x, unicode):
return pyu(x)
else:
raise TypeError('b/u: coerce: invalid type %s' % type(x))
# __pystr converts obj to ~str of current python: # __pystr converts obj to ~str of current python:
# #
...@@ -307,6 +323,32 @@ class pybstr(bytes): ...@@ -307,6 +323,32 @@ class pybstr(bytes):
return pyu(self).__iter__() return pyu(self).__iter__()
# __add__, __radd__ (no need to override __iadd__)
def __add__(a, b):
# NOTE Cython < 3 does not automatically support __radd__ for cdef class
# https://cython.readthedocs.io/en/latest/src/userguide/migrating_to_cy30.html#arithmetic-special-methods
# but pybstr is currently _not_ cdef'ed class
# see also https://github.com/cython/cython/issues/4750
return pyb(bytes.__add__(a, _pyb_coerce(b)))
def __radd__(b, a):
# a.__add__(b) returned NotImplementedError, e.g. for unicode.__add__(bstr)
# u'' + b() -> u() ; same as u() + b() -> u()
# b'' + b() -> b() ; same as b() + b() -> b()
# barr + b() -> barr
if isinstance(a, bytearray):
# force `bytearray +=` to go via bytearray.sq_inplace_concat - see PyNumber_InPlaceAdd
return NotImplemented
a = _pybu_rcoerce(a)
return a.__add__(b)
# __mul__, __rmul__ (no need to override __imul__)
def __mul__(a, b):
return pyb(bytes.__mul__(a, b))
def __rmul__(b, a):
return b.__mul__(a)
# XXX cannot `cdef class` with __new__: https://github.com/cython/cython/issues/799 # XXX cannot `cdef class` with __new__: https://github.com/cython/cython/issues/799
class pyustr(unicode): class pyustr(unicode):
"""ustr is unicode-string. """ustr is unicode-string.
...@@ -411,6 +453,34 @@ class pyustr(unicode): ...@@ -411,6 +453,34 @@ class pyustr(unicode):
return PySeqIter_New(self) return PySeqIter_New(self)
# __add__, __radd__ (no need to override __iadd__)
def __add__(a, b):
# NOTE Cython < 3 does not automatically support __radd__ for cdef class
# https://cython.readthedocs.io/en/latest/src/userguide/migrating_to_cy30.html#arithmetic-special-methods
# but pyustr is currently _not_ cdef'ed class
# see also https://github.com/cython/cython/issues/4750
return pyu(unicode.__add__(a, _pyu_coerce(b)))
def __radd__(b, a):
# a.__add__(b) returned NotImplementedError, e.g. for unicode.__add__(bstr)
# u'' + u() -> u() ; same as u() + u() -> u()
# b'' + u() -> b() ; same as b() + u() -> b()
# barr + u() -> barr
if isinstance(a, bytearray):
# force `bytearray +=` to go via bytearray.sq_inplace_concat - see PyNumber_InPlaceAdd
# for pyustr this relies on patch to bytearray.sq_inplace_concat to accept ustr as bstr
return NotImplemented
a = _pybu_rcoerce(a)
return a.__add__(b)
# __mul__, __rmul__ (no need to override __imul__)
def __mul__(a, b):
return pyu(unicode.__mul__(a, b))
def __rmul__(b, a):
return b.__mul__(a)
# _pyustrIter wraps unicode iterator to return pyustr for each yielded character. # _pyustrIter wraps unicode iterator to return pyustr for each yielded character.
cdef class _pyustrIter: cdef class _pyustrIter:
cdef object uiter cdef object uiter
...@@ -570,7 +640,13 @@ if PY_MAJOR_VERSION < 3: ...@@ -570,7 +640,13 @@ if PY_MAJOR_VERSION < 3:
# - bytearray.__init__ to accept ustr instead of raising 'TypeError: # - bytearray.__init__ to accept ustr instead of raising 'TypeError:
# string argument without an encoding' (pybug: bytearray() should respect # string argument without an encoding' (pybug: bytearray() should respect
# __bytes__ similarly to bytes) # __bytes__ similarly to bytes)
#
# - bytearray.{sq_concat,sq_inplace_concat} to accept ustr instead of raising
# TypeError. (pybug: bytearray + and += should respect __bytes__)
cdef initproc _bytearray_tp_init = (<_XPyTypeObject*>bytearray) .tp_init cdef initproc _bytearray_tp_init = (<_XPyTypeObject*>bytearray) .tp_init
cdef binaryfunc _bytearray_sq_concat = (<_XPyTypeObject*>bytearray) .tp_as_sequence.sq_concat
cdef binaryfunc _bytearray_sq_iconcat = (<_XPyTypeObject*>bytearray) .tp_as_sequence.sq_inplace_concat
cdef int _bytearray_tp_xinit(object self, PyObject* args, PyObject* kw) except -1: cdef int _bytearray_tp_xinit(object self, PyObject* args, PyObject* kw) except -1:
if args != NULL and (kw == NULL or (not <object>kw)): if args != NULL and (kw == NULL or (not <object>kw)):
...@@ -583,9 +659,22 @@ cdef int _bytearray_tp_xinit(object self, PyObject* args, PyObject* kw) except - ...@@ -583,9 +659,22 @@ cdef int _bytearray_tp_xinit(object self, PyObject* args, PyObject* kw) except -
return _bytearray_tp_init(self, args, kw) return _bytearray_tp_init(self, args, kw)
cdef object _bytearray_sq_xconcat(object a, object b):
if isinstance(b, pyustr):
b = pyb(b)
return _bytearray_sq_concat(a, b)
cdef object _bytearray_sq_xiconcat(object a, object b):
if isinstance(b, pyustr):
b = pyb(b)
return _bytearray_sq_iconcat(a, b)
def _bytearray_x__init__(self, *argv, **kw): def _bytearray_x__init__(self, *argv, **kw):
# NOTE don't return - just call: __init__ should return None # NOTE don't return - just call: __init__ should return None
_bytearray_tp_xinit(self, <PyObject*>argv, <PyObject*>kw) _bytearray_tp_xinit(self, <PyObject*>argv, <PyObject*>kw)
def _bytearray_x__add__ (a, b): return _bytearray_sq_xconcat(a, b)
def _bytearray_x__iadd__(a, b): return _bytearray_sq_xiconcat(a, b)
def _(): def _():
cdef PyTypeObject* t cdef PyTypeObject* t
...@@ -596,6 +685,13 @@ def _(): ...@@ -596,6 +685,13 @@ def _():
if t_.tp_init == _bytearray_tp_init: if t_.tp_init == _bytearray_tp_init:
t_.tp_init = _bytearray_tp_xinit t_.tp_init = _bytearray_tp_xinit
_patch_slot(t, '__init__', _bytearray_x__init__) _patch_slot(t, '__init__', _bytearray_x__init__)
t_sq = t_.tp_as_sequence
if t_sq.sq_concat == _bytearray_sq_concat:
t_sq.sq_concat = _bytearray_sq_xconcat
_patch_slot(t, '__add__', _bytearray_x__add__)
if t_sq.sq_inplace_concat == _bytearray_sq_iconcat:
t_sq.sq_inplace_concat = _bytearray_sq_xiconcat
_patch_slot(t, '__iadd__', _bytearray_x__iadd__)
_() _()
# _patch_slot installs func_or_descr into typ's __dict__ as name. # _patch_slot installs func_or_descr into typ's __dict__ as name.
......
...@@ -437,6 +437,32 @@ def test_strings_iter(): ...@@ -437,6 +437,32 @@ def test_strings_iter():
assert list(XIter()) == ['м','и','р','у',' ','м','и','р'] assert list(XIter()) == ['м','и','р','у',' ','м','и','р']
# verify string operations like `x * 3` for all cases from bytes, bytearray, unicode, bstr and ustr.
@mark.parametrize('tx', (bytes, unicode, bytearray, bstr, ustr))
def test_strings_ops1(tx):
x = xstr(u'мир', tx)
assert type(x) is tx
# *
_ = x * 3
assert type(_) is tx
assert xudata(_) == u'мирмирмир'
_ = 3 * x
assert type(_) is tx
assert xudata(_) == u'мирмирмир'
# *=
_ = x
_ *= 3
assert type(_) is tx
assert xudata(_) == u'мирмирмир'
assert _ is x if tx is bytearray else \
_ is not x
# verify string operations like `x + y` for all combinations of pairs from # verify string operations like `x + y` for all combinations of pairs from
# bytes, unicode, bstr, ustr and bytearray. Except if both x and y are std # bytes, unicode, bstr, ustr and bytearray. Except if both x and y are std
# python types, e.g. (bytes, unicode), because those combinations are handled # python types, e.g. (bytes, unicode), because those combinations are handled
...@@ -483,6 +509,41 @@ def test_strings_ops2(tx, ty): ...@@ -483,6 +509,41 @@ def test_strings_ops2(tx, ty):
assert not (x > y) assert not (x > y)
assert y > x assert y > x
# +
#
# type(x + y) is determined by type(x):
# u() + * -> u()
# b() + * -> b()
# u'' + u()/b() -> u()
# u'' + u'' -> u''
# b'' + u()/b() -> b()
# b'' + b'' -> b''
# barr + u()/b() -> barr
if tx in (bstr, ustr):
tadd = tx
elif tx in (unicode, bytes):
if ty in (unicode, bytes, bytearray):
tadd = tx # we are skipping e.g. bytes + unicode
else:
assert ty in (bstr, ustr)
tadd = tbu(tx)
else:
assert tx is bytearray
tadd = tx
_ = x + y
assert type(_) is tadd
assert _ is not x; assert _ is not y
assert _ == xstr(u'hello мир', tadd)
# += (same typing rules as for +)
_ = x
_ += y
assert type(_) is tadd
assert _ == xstr(u'hello мир', tadd)
assert _ is x if tx is bytearray else \
_ is not x
# verify string operations like `x + y` for x being bstr/ustr and y being a # verify string operations like `x + y` for x being bstr/ustr and y being a
# type unsupported for coercion. # type unsupported for coercion.
...@@ -492,6 +553,9 @@ def test_strings_ops2_bufreject(tx, ty): ...@@ -492,6 +553,9 @@ def test_strings_ops2_bufreject(tx, ty):
x = xstr(u'мир', tx) x = xstr(u'мир', tx)
y = ty(b'123') y = ty(b'123')
with raises(TypeError): x + y
with raises(TypeError): x * y
assert (x == y) is False # see test_strings_ops2_eq_any assert (x == y) is False # see test_strings_ops2_eq_any
assert (x != y) is True assert (x != y) is True
with raises(TypeError): x >= y with raises(TypeError): x >= y
...@@ -499,6 +563,10 @@ def test_strings_ops2_bufreject(tx, ty): ...@@ -499,6 +563,10 @@ def test_strings_ops2_bufreject(tx, ty):
with raises(TypeError): x > y with raises(TypeError): x > y
with raises(TypeError): x < y with raises(TypeError): x < y
# reverse operations, e.g. memoryview + bstr
with raises(TypeError): y + x
with raises(TypeError): y * x
# `y > x` does not raise when x is bstr (= provides buffer): # `y > x` does not raise when x is bstr (= provides buffer):
y == x # not raises TypeError - see test_strings_ops2_eq_any y == x # not raises TypeError - see test_strings_ops2_eq_any
y != x # y != x #
...@@ -658,6 +726,28 @@ def test_strings_patched_transparently(): ...@@ -658,6 +726,28 @@ def test_strings_patched_transparently():
assert _(3) == r"bytearray(b'\x00\x00\x00')" assert _(3) == r"bytearray(b'\x00\x00\x00')"
assert _((1,2,3)) == r"bytearray(b'\x01\x02\x03')" assert _((1,2,3)) == r"bytearray(b'\x01\x02\x03')"
# bytearray.{sq_concat,sq_inplace_concat} stay unaffected
a = bytearray()
def _(delta):
aa = a + delta
aa_ = a.__add__(delta)
assert aa is not a
assert aa_ is not a
aclone = bytearray(a)
a_ = a
a_ += delta
aclone_ = aclone
aclone_.__iadd__(delta)
assert a_ is a
assert a_ == aa
assert aclone_ is aclone
assert aclone_ == a_
return a_
assert _(b'') == b''
assert _(b'a') == b'a'
assert _(b'b') == b'ab'
assert _(b'cde') == b'abcde'
# ---- benchmarks ---- # ---- benchmarks ----
...@@ -708,3 +798,31 @@ def xstr(text, typ): ...@@ -708,3 +798,31 @@ def xstr(text, typ):
s = _() s = _()
assert type(s) is typ assert type(s) is typ
return s return s
# xudata returns data of x converted to unicode string.
# x can be bytes/unicode/bytearray / bstr/ustr.
def xudata(x):
def _():
if type(x) in (bytes, bytearray):
return x.decode('utf-8')
elif type(x) is unicode:
return x
elif type(x) is ustr:
return _udata(x)
elif type(x) is bstr:
return _bdata(x).decode('utf-8')
else:
raise TypeError(x)
xu = _()
assert type(xu) is unicode
return xu
# tbu maps specified type to b/u:
# b/bytes/bytearray -> b; u/unicode -> u.
def tbu(typ):
if typ in (bytes, bytearray, bstr):
return bstr
if typ in (unicode, ustr):
return ustr
raise AssertionError("invalid type %r" % typ)
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment