Commit f2baabfd authored by Jérome Perrin's avatar Jérome Perrin

IdTool: expect bytes for group_id

group_id is used as key of OOBtree and as [documented], it's not
possible to mix keys that can not be compared, so we can not have a mix
of string and bytes.

🤔 a BTree can not contain str and bytes as keys (the same way that on
python2 it can not contain unicode and str), basically we just need a
consistent type and something compatible with how we transform the data
from python2 databases.
parent 04b796c3
...@@ -29,11 +29,14 @@ ...@@ -29,11 +29,14 @@
############################################################################## ##############################################################################
import unittest import unittest
import warnings
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
from Products.ERP5Type.tests.utils import createZODBPythonScript from Products.ERP5Type.tests.utils import createZODBPythonScript
from MySQLdb import ProgrammingError from MySQLdb import ProgrammingError
from six.moves import range from six.moves import range
import six
class TestIdTool(ERP5TypeTestCase): class TestIdTool(ERP5TypeTestCase):
...@@ -146,6 +149,36 @@ class TestIdTool(ERP5TypeTestCase): ...@@ -146,6 +149,36 @@ class TestIdTool(ERP5TypeTestCase):
self.assertEqual(21, self.id_tool.generateNewId(id_generator=id_generator, self.assertEqual(21, self.id_tool.generateNewId(id_generator=id_generator,
id_group='d02', default=3)) id_group='d02', default=3))
# generateNewId expect str, but convert id_group when passed a wrong type
with warnings.catch_warnings(record=True) as recorded:
self.assertEqual(
self.id_tool.generateNewId(
id_generator=id_generator,
id_group=('d', 1),
), 0)
self.assertEqual(
[(type(w.message), str(w.message)) for w in recorded],
[(DeprecationWarning, 'id_group must be a string, other types are deprecated.')],
)
# on python3, it understands bytes and converts to string
self.assertEqual(
self.id_tool.generateNewId(
id_generator=id_generator,
id_group='bytes',
), 0)
with warnings.catch_warnings(record=True) as recorded:
self.assertEqual(
self.id_tool.generateNewId(
id_generator=id_generator,
id_group=b'bytes',
), 1)
if six.PY3:
self.assertEqual(
[(type(w.message), str(w.message)) for w in recorded],
[(BytesWarning, 'id_group must be a string, not bytes.')],
)
def test_02a_generateNewIdWithZODBGenerator(self): def test_02a_generateNewIdWithZODBGenerator(self):
""" """
Check the generateNewId with a zodb id generator Check the generateNewId with a zodb id generator
...@@ -234,6 +267,38 @@ class TestIdTool(ERP5TypeTestCase): ...@@ -234,6 +267,38 @@ class TestIdTool(ERP5TypeTestCase):
id_generator=id_generator, id_generator=id_generator,
id_group='d03', default=3, id_count=2)) id_group='d03', default=3, id_count=2))
# generateNewIdList expect str, but convert id_group when passed a wrong type
with warnings.catch_warnings(record=True) as recorded:
self.assertEqual(
self.id_tool.generateNewIdList(
id_generator=id_generator,
id_group=('d', 1),
id_count=1,), [0])
self.assertEqual(
[(type(w.message), str(w.message)) for w in recorded],
[(DeprecationWarning, 'id_group must be a string, other types are deprecated.')],
)
# on python3, it understands bytes and converts to string
self.assertEqual(
self.id_tool.generateNewIdList(
id_generator=id_generator,
id_group='bytes',
id_count=1,
), [0])
with warnings.catch_warnings(record=True) as recorded:
self.assertEqual(
self.id_tool.generateNewIdList(
id_generator=id_generator,
id_group=b'bytes',
id_count=1,
), [1])
if six.PY3:
self.assertEqual(
[(type(w.message), str(w.message)) for w in recorded],
[(BytesWarning, 'id_group must be a string, not bytes.')],
)
def test_03a_generateNewIdListWithZODBGenerator(self): def test_03a_generateNewIdListWithZODBGenerator(self):
""" """
Check the generateNewIdList with zodb generator Check the generateNewIdList with zodb generator
...@@ -282,7 +347,7 @@ class TestIdTool(ERP5TypeTestCase): ...@@ -282,7 +347,7 @@ class TestIdTool(ERP5TypeTestCase):
query = 'select last_id from portal_ids where id_group="foo_bar"' query = 'select last_id from portal_ids where id_group="foo_bar"'
self.assertRaises(ProgrammingError, sql_connection.manage_test, query) self.assertRaises(ProgrammingError, sql_connection.manage_test, query)
generator.rebuildSqlTable() generator.rebuildSqlTable()
result = sql_connection.manage_test(query) result = sql_connection.manage_test(query)
self.assertEqual(result[0].last_id, 4) self.assertEqual(result[0].last_id, 4)
def checkExportImportDict(self, id_generator): def checkExportImportDict(self, id_generator):
......
...@@ -60,7 +60,9 @@ class TestIdToolUpgrade(ERP5TypeTestCase): ...@@ -60,7 +60,9 @@ class TestIdToolUpgrade(ERP5TypeTestCase):
self.tic() self.tic()
def beforeTearDown(self): def beforeTearDown(self):
self.portal.portal_caches.clearAllCache()
self.id_tool.clearGenerator(all=True) self.id_tool.clearGenerator(all=True)
self.tic()
def createGenerators(self): def createGenerators(self):
""" """
...@@ -296,7 +298,26 @@ class TestIdToolUpgrade(ERP5TypeTestCase): ...@@ -296,7 +298,26 @@ class TestIdToolUpgrade(ERP5TypeTestCase):
id_generator.clearGenerator() # clear stored data id_generator.clearGenerator() # clear stored data
self._checkDataStructureMigration(id_generator) self._checkDataStructureMigration(id_generator)
def test_suite(): def test_portal_ids_table_id_group_column_binary(self):
suite = unittest.TestSuite() """portal_ids.id_group is now created as VARCHAR,
suite.addTest(unittest.makeSuite(TestIdToolUpgrade)) but it use to be binary. There is no data migration, the
return suite SQL method has been adjusted to cast during select.
This checks that id generator works well when the column
is VARBINARY, like it's the case for old instances.
"""
self.assertEqual(
self.sql_generator.generateNewId(id_group=self.id()),
0)
exported = self.sql_generator.exportGeneratorIdDict()
self.tic()
self.portal.portal_ids.IdTool_zCommit()
self.portal.erp5_sql_connection.manage_test(
'ALTER TABLE portal_ids MODIFY COLUMN id_group VARBINARY(255)'
)
self.tic()
self.sql_generator.importGeneratorIdDict(exported, clear=True)
self.tic()
self.assertEqual(
self.sql_generator.generateNewId(id_group=self.id()),
1)
...@@ -73,6 +73,8 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator): ...@@ -73,6 +73,8 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
# Check the arguments # Check the arguments
if id_group in (None, 'None'): if id_group in (None, 'None'):
raise ValueError('%r is not a valid group Id.' % id_group) raise ValueError('%r is not a valid group Id.' % id_group)
if not isinstance(id_group, str):
raise TypeError('id_group must be str')
if default is None: if default is None:
default = 0 default = 0
...@@ -134,6 +136,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator): ...@@ -134,6 +136,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
# the last id stored in the sql table # the last id stored in the sql table
for line in self._getValueListFromTable(): for line in self._getValueListFromTable():
id_group = line['id_group'] id_group = line['id_group']
assert isinstance(id_group, str)
last_id = line['last_id'] last_id = line['last_id']
if id_group in self.last_max_id_dict and \ if id_group in self.last_max_id_dict and \
self.last_max_id_dict[id_group].value > last_id: self.last_max_id_dict[id_group].value > last_id:
...@@ -197,6 +200,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator): ...@@ -197,6 +200,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
getattr(portal_ids, 'dict_length_ids', None) is None): getattr(portal_ids, 'dict_length_ids', None) is None):
dump_dict = portal_ids.dict_length_ids dump_dict = portal_ids.dict_length_ids
for id_group, last_id in dump_dict.items(): for id_group, last_id in dump_dict.items():
assert isinstance(id_group, str)
last_insert_id = get_last_id_method(id_group=id_group) last_insert_id = get_last_id_method(id_group=id_group)
last_id = int(last_id.value) last_id = int(last_id.value)
if len(last_insert_id) != 0: if len(last_insert_id) != 0:
......
...@@ -59,6 +59,8 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator): ...@@ -59,6 +59,8 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
""" """
if id_group in (None, 'None'): if id_group in (None, 'None'):
raise ValueError('%r is not a valid group Id.' % id_group) raise ValueError('%r is not a valid group Id.' % id_group)
if not isinstance(id_group, str):
raise TypeError('id_group must be str')
if default is None: if default is None:
default = 0 default = 0
last_id_dict = getattr(self, 'last_id_dict', None) last_id_dict = getattr(self, 'last_id_dict', None)
...@@ -109,6 +111,8 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator): ...@@ -109,6 +111,8 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
for id_group, last_id in portal_ids.dict_ids.items(): for id_group, last_id in portal_ids.dict_ids.items():
if not isinstance(id_group, str): if not isinstance(id_group, str):
id_group = repr(id_group) id_group = repr(id_group)
if isinstance(id_group, bytes):
raise NotImplementedErro('TODO' + repr(id_group))
if id_group in self.last_id_dict and \ if id_group in self.last_id_dict and \
self.last_id_dict[id_group] > last_id: self.last_id_dict[id_group] > last_id:
continue continue
...@@ -148,7 +152,9 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator): ...@@ -148,7 +152,9 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
self.clearGenerator() self.clearGenerator()
if not isinstance(id_dict, dict): if not isinstance(id_dict, dict):
raise TypeError('the argument given is not a dictionary') raise TypeError('the argument given is not a dictionary')
for value in id_dict.values(): for key, value in id_dict.items():
if not isinstance(key, str):
raise TypeError('key %r given in dictionary is not str' % (key, ))
if not isinstance(value, six.integer_types): if not isinstance(value, six.integer_types):
raise TypeError('the value given in dictionary is not a integer') raise TypeError('the value given in dictionary is not a integer')
self.last_id_dict.update(id_dict) self.last_id_dict.update(id_dict)
......
...@@ -118,6 +118,9 @@ class IdTool(BaseTool): ...@@ -118,6 +118,9 @@ class IdTool(BaseTool):
if id_group in (None, 'None'): if id_group in (None, 'None'):
raise ValueError('%r is not a valid id_group' % id_group) raise ValueError('%r is not a valid id_group' % id_group)
# for compatibilty with sql data, must not use id_group as a list # for compatibilty with sql data, must not use id_group as a list
if six.PY3 and isinstance(id_group, bytes):
warnings.warn('id_group must be a string, not bytes.', BytesWarning)
id_group = id_group.decode('utf-8')
if not isinstance(id_group, str): if not isinstance(id_group, str):
id_group = repr(id_group) id_group = repr(id_group)
warnings.warn('id_group must be a string, other types ' warnings.warn('id_group must be a string, other types '
...@@ -177,6 +180,9 @@ class IdTool(BaseTool): ...@@ -177,6 +180,9 @@ class IdTool(BaseTool):
""" """
if id_group in (None, 'None'): if id_group in (None, 'None'):
raise ValueError('%r is not a valid id_group' % id_group) raise ValueError('%r is not a valid id_group' % id_group)
if six.PY3 and isinstance(id_group, bytes):
warnings.warn('id_group must be a string, not bytes.', BytesWarning)
id_group = id_group.decode('utf-8')
# for compatibilty with sql data, must not use id_group as a list # for compatibilty with sql data, must not use id_group as a list
if not isinstance(id_group, str): if not isinstance(id_group, str):
id_group = repr(id_group) id_group = repr(id_group)
......
CREATE TABLE `portal_ids` ( CREATE TABLE `portal_ids` (
`id_group` VARBINARY(255), `id_group` VARCHAR(255),
`last_id` BIGINT UNSIGNED, `last_id` BIGINT UNSIGNED,
PRIMARY KEY (`id_group`) PRIMARY KEY (`id_group`)
) ENGINE=InnoDB; ) ENGINE=InnoDB;
\ No newline at end of file
CREATE TABLE `portal_ids` ( CREATE TABLE `portal_ids` (
`id_group` VARBINARY(255), `id_group` VARCHAR(255),
`last_id` BIGINT UNSIGNED, `last_id` BIGINT UNSIGNED,
PRIMARY KEY (`id_group`) PRIMARY KEY (`id_group`)
) ENGINE=InnoDB ) ENGINE=InnoDB
......
select id_group, last_id from portal_ids select id_group, cast(id_group as CHAR) id_group from portal_ids
\ No newline at end of file \ No newline at end of file
select id_group, last_id from portal_ids select cast(id_group as CHAR) id_group, last_id from portal_ids
<dtml-if id_group>where id_group > "<dtml-var id_group>"</dtml-if> <dtml-if id_group>where id_group > <dtml-sqlvar id_group type="string"></dtml-if>
order by id_group order by id_group
\ No newline at end of file
...@@ -3655,7 +3655,7 @@ class Base( ...@@ -3655,7 +3655,7 @@ class Base(
next_id = default next_id = default
new_next_id = None if poison else next_id + count new_next_id = None if poison else next_id + count
id_generator_state[group].value = new_next_id id_generator_state[group].value = new_next_id
return range(next_id, new_next_id) return ensure_list(range(next_id, new_next_id))
InitializeClass(Base) InitializeClass(Base)
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment