Commit 21a0f693 authored by Jérome Perrin's avatar Jérome Perrin

IdTool: expect bytes for group_id

group_id is used as key of OOBtree and as [documented], it's not
possible to mix keys that can not be compared, so we can not have a mix
of string and bytes.
parent 6a6e11e5
......@@ -3413,13 +3413,13 @@ class TestTransactions(AccountingTestCase):
# ...except uid generator
new_uid, = portal_ids.generateNewIdList(
id_generator='uid',
id_group='catalog_uid',
id_group=b'catalog_uid',
id_count=1,
)
portal_ids.clearGenerator(all=True)
portal_ids.generateNewIdList(
id_generator='uid',
id_group='catalog_uid',
id_group=b'catalog_uid',
id_count=1,
default=new_uid,
)
......
......@@ -132,19 +132,28 @@ class TestIdTool(ERP5TypeTestCase):
Check the method generateNewId
"""
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='a02'))
id_group=b'a02'))
# Different groups generate different ids
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='b02'))
id_group=b'b02'))
self.assertEqual(1, self.id_tool.generateNewId(id_generator=id_generator,
id_group='a02'))
id_group=b'a02'))
# With default value
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='c02', default=0))
id_group=b'c02', default=0))
self.assertEqual(20, self.id_tool.generateNewId(id_generator=id_generator,
id_group='d02', default=20))
id_group=b'd02', default=20))
self.assertEqual(21, self.id_tool.generateNewId(id_generator=id_generator,
id_group='d02', default=3))
id_group=b'd02', default=3))
# generateNewId expect types, but convert id_group when passed a wrong type
# TODO assert warning
# self.assertRaises(
# TypeError,
# self.id_tool.generateNewId,
# id_generator=id_generator,
# id_group='wrong type !')
def test_02a_generateNewIdWithZODBGenerator(self):
"""
......@@ -159,8 +168,8 @@ class TestIdTool(ERP5TypeTestCase):
# generate ids
self.checkGenerateNewId('test_application_zodb')
# check zodb dict
self.assertEqual(zodb_generator.last_id_dict['c02'], 0)
self.assertEqual(zodb_generator.last_id_dict['d02'], 21)
self.assertEqual(zodb_generator.last_id_dict[b'c02'], 0)
self.assertEqual(zodb_generator.last_id_dict[b'd02'], 21)
def checkGenerateNewIdWithSQL(self, store):
"""
......@@ -187,12 +196,12 @@ class TestIdTool(ERP5TypeTestCase):
# generate ids
self.checkGenerateNewId('test_application_sql')
# check last_id in sql
self.assertEqual(last_id_method(id_group='c02')[0]['LAST_INSERT_ID()'], 0)
self.assertEqual(last_id_method(id_group='d02')[0]['LAST_INSERT_ID()'], 21)
self.assertEqual(last_id_method(id_group=b'c02')[0]['LAST_INSERT_ID()'], 0)
self.assertEqual(last_id_method(id_group=b'd02')[0]['LAST_INSERT_ID()'], 21)
# check zodb dict
if store:
self.assertEqual(sql_generator.last_max_id_dict['c02'].value, 0)
self.assertEqual(sql_generator.last_max_id_dict['d02'].value, 21)
self.assertEqual(sql_generator.last_max_id_dict[b'c02'].value, 0)
self.assertEqual(sql_generator.last_max_id_dict[b'd02'].value, 21)
else:
self.assertEqual(len(sql_generator.last_max_id_dict), 0)
......@@ -215,24 +224,24 @@ class TestIdTool(ERP5TypeTestCase):
Check the generateNewIdList
"""
self.assertEqual([0], self.id_tool.generateNewIdList(\
id_generator=id_generator, id_group='a03'))
id_generator=id_generator, id_group=b'a03'))
# Different groups generate different ids
self.assertEqual([0, 1], self.id_tool.generateNewIdList(\
id_generator=id_generator,
id_group='b03', id_count=2))
id_group=b'b03', id_count=2))
self.assertEqual([1 ,2, 3], self.id_tool.generateNewIdList(\
id_generator=id_generator,
id_group='a03', id_count=3))
id_group=b'a03', id_count=3))
# With default value
self.assertEqual([0, 1, 2], self.id_tool.generateNewIdList(\
id_generator=id_generator,
id_group='c03', default=0, id_count=3))
id_group=b'c03', default=0, id_count=3))
self.assertEqual([20, 21, 22], self.id_tool.generateNewIdList(\
id_generator=id_generator,
id_group='d03', default=20, id_count=3))
id_group=b'd03', default=20, id_count=3))
self.assertEqual([23, 24], self.id_tool.generateNewIdList(\
id_generator=id_generator,
id_group='d03', default=3, id_count=2))
id_group=b'd03', default=3, id_count=2))
def test_03a_generateNewIdListWithZODBGenerator(self):
"""
......@@ -253,16 +262,16 @@ class TestIdTool(ERP5TypeTestCase):
"""
self.assertEqual([1, 2, 3], self.id_tool.generateNewIdList(
id_generator='test_application_zodb',
id_group='a04', default=1, id_count=3))
id_group=b'a04', default=1, id_count=3))
self.assertEqual(4, self.id_tool.generateNewId(
id_generator='test_application_zodb',
id_group='a04'))
id_group=b'a04'))
self.assertEqual(1, self.id_tool.generateNewId(
id_generator='test_application_sql',
id_group='a04', default=1))
id_group=b'a04', default=1))
self.assertEqual([2, 3, 4], self.id_tool.generateNewIdList(
id_generator='test_application_sql',
id_group='a04', id_count=3))
id_group=b'a04', id_count=3))
def test_05_RebuildTableForDefaultSQLNonContinuousIncreasingIdGenerator(self):
"""
......@@ -273,8 +282,8 @@ class TestIdTool(ERP5TypeTestCase):
generator = self.id_tool._getLatestGeneratorValue(
'mysql_non_continuous_increasing')
self.assertTrue(generator is not None)
generator.generateNewId(id_group='foo_bar', default=4)
self.assertEqual(generator.last_max_id_dict['foo_bar'].value, 4)
generator.generateNewId(id_group=b'foo_bar', default=4)
self.assertEqual(generator.last_max_id_dict[b'foo_bar'].value, 4)
portal.IdTool_zDropTable()
# make sure to use same connector as IdTool_zDropTable to avoid mariadb :
# "Waiting for table metadata lock"
......@@ -282,7 +291,7 @@ class TestIdTool(ERP5TypeTestCase):
query = 'select last_id from portal_ids where id_group="foo_bar"'
self.assertRaises(ProgrammingError, sql_connection.manage_test, query)
generator.rebuildSqlTable()
result = sql_connection.manage_test(query)
result = sql_connection.manage_test(query)
self.assertEqual(result[0].last_id, 4)
def checkExportImportDict(self, id_generator):
......@@ -291,12 +300,12 @@ class TestIdTool(ERP5TypeTestCase):
"""
generator = self.getLastGenerator(id_generator)
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='06'))
id_group=b'06'))
id_dict = generator.exportGeneratorIdDict()
self.assertEqual(0, id_dict['06'])
generator.importGeneratorIdDict(id_dict={'06':6})
self.assertEqual(0, id_dict[b'06'])
generator.importGeneratorIdDict(id_dict={b'06': 6})
self.assertEqual(7, self.id_tool.generateNewId(id_generator=id_generator,
id_group='06'))
id_group=b'06'))
def test_06_ExportImportDict(self):
"""
......@@ -311,7 +320,7 @@ class TestIdTool(ERP5TypeTestCase):
"""
generator = self.getLastGenerator(id_generator)
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
id_group=b'07'))
id_dict = generator.exportGeneratorIdDict()
id_dict_before = dict(id_dict)
generator.importGeneratorIdDict(id_dict=id_dict, clear=True)
......@@ -321,10 +330,10 @@ class TestIdTool(ERP5TypeTestCase):
# make sure generating a new id will increment
self.assertEqual(1, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
id_group=b'07'))
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='another_group'))
id_group=b'another_group'))
# reimport clearing, the group we just use should have been cleared out
generator.importGeneratorIdDict(id_dict=id_dict, clear=True)
id_dict = generator.exportGeneratorIdDict()
......@@ -346,28 +355,27 @@ class TestIdTool(ERP5TypeTestCase):
sql_generator = self.getLastGenerator(id_generator)
sql_generator.setStoredInZodb(True)
sql_generator.setStoreInterval(2)
#sql_generator.setStoreInterval(2)
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
self.assertEqual(sql_generator.last_max_id_dict['07'].value, 0)
id_group=b'07'))
self.assertEqual(sql_generator.last_max_id_dict[b'07'].value, 0)
self.assertEqual(1, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
id_group=b'07'))
# last_id isn't stored because 1 < last_id (0) + store_interval
self.assertEqual(sql_generator.last_max_id_dict['07'].value, 0)
self.assertEqual(sql_generator.last_max_id_dict[b'07'].value, 0)
self.assertEqual(2, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
self.assertEqual(sql_generator.last_max_id_dict['07'].value, 2)
id_group=b'07'))
self.assertEqual(sql_generator.last_max_id_dict[b'07'].value, 2)
self.getLastGenerator(id_generator).\
importGeneratorIdDict(id_dict = {'07':5})
importGeneratorIdDict(id_dict={b'07': 5})
self.assertEqual(6, self.id_tool.generateNewId(id_generator=id_generator,
id_group='07'))
id_group=b'07'))
# last_id stored because 6 < last_id (5) + store_interval
self.assertEqual(sql_generator.last_max_id_dict['07'].value, 5)
self.assertEqual(sql_generator.last_max_id_dict[b'07'].value, 5)
# the sql value is higher that zodb value so the export return the sql
# value
id_dict = self.getLastGenerator(id_generator).exportGeneratorIdDict()
self.assertEqual(id_dict['07'], 6)
self.assertEqual(id_dict[b'07'], 6)
def test_08_updateLastMaxIdDictFromTable(self):
"""
......@@ -378,27 +386,27 @@ class TestIdTool(ERP5TypeTestCase):
sql_generator = self.getLastGenerator(id_generator)
sql_generator.setStoredInZodb(False)
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='A-08'))
id_group=b'A-08'))
self.assertEqual(1, self.id_tool.generateNewId(id_generator=id_generator,
id_group='A-08'))
id_group=b'A-08'))
self.assertEqual(2, self.id_tool.generateNewId(id_generator=id_generator,
id_group='A-08'))
id_group=b'A-08'))
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group='B-08'))
id_group=b'B-08'))
self.assertEqual(1, self.id_tool.generateNewId(id_generator=id_generator,
id_group='B-08'))
id_group=b'B-08'))
A_LOT_OF_KEY = 2500
var_id = 'C-%04d'
for x in range(A_LOT_OF_KEY):
self.assertEqual(0, self.id_tool.generateNewId(id_generator=id_generator,
id_group=var_id % x))
id_group=(var_id % x).encode()))
# test before update
self.assertEqual(None, sql_generator.last_max_id_dict.get('A-08'))
self.assertEqual(None, sql_generator.last_max_id_dict.get('B-08'))
self.assertEqual(None, sql_generator.last_max_id_dict.get(b'A-08'))
self.assertEqual(None, sql_generator.last_max_id_dict.get(b'B-08'))
for x in range(A_LOT_OF_KEY):
self.assertEqual(None, sql_generator.last_max_id_dict.get(var_id % x))
self.assertEqual(None, sql_generator.last_max_id_dict.get((var_id % x).encode()))
createZODBPythonScript(
self.portal.portal_skins.custom,
'IdTool_updateLastMaxId',
......@@ -419,10 +427,10 @@ if new_last_id_group is not None:
self.tic()
# asserts
self.assertEqual(2, sql_generator.last_max_id_dict['A-08'].value)
self.assertEqual(1, sql_generator.last_max_id_dict['B-08'].value)
self.assertEqual(2, sql_generator.last_max_id_dict[b'A-08'].value)
self.assertEqual(1, sql_generator.last_max_id_dict[b'B-08'].value)
for x in range(A_LOT_OF_KEY):
self.assertEqual(0, sql_generator.last_max_id_dict[var_id % x].value)
self.assertEqual(0, sql_generator.last_max_id_dict[(var_id % x).encode()].value)
def test_decentralised_ZODB_id_generator(self):
"""
......@@ -435,8 +443,8 @@ if new_last_id_group is not None:
old_id_group = str((
'test_decentralised_ZODB_id_generator',
container.getPath(),
))
new_id_group = 'test_decentralised_ZODB_id_generator'
)).encode()
new_id_group = b'test_decentralised_ZODB_id_generator'
latest_id_old_generator, = portal_ids.generateNewIdList(
id_group=old_id_group,
id_generator='zodb_continuous_increasing',
......
......@@ -91,9 +91,10 @@ class IdGenerator(Base):
by BTrees.Length to manage conflict in the zodb, use also a persistant
mapping to be persistent
"""
# For compatibilty with sql data, must not use id_group as a list
if not isinstance(id_group, str):
raise TypeError('id_group is not a string')
# Type of id groups must be consistent, because we use them as BTree keys
# https://btrees.readthedocs.io/en/latest/overview.html#total-ordering-and-persistence
if not isinstance(id_group, bytes):
raise TypeError('id_group must be bytes')
return self._getLatestSpecialiseValue().generateNewIdList(id_group=id_group,
id_count=id_count,
default=default,
......
......@@ -71,8 +71,10 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
mapping to be persistent
"""
# Check the arguments
if id_group in (None, 'None'):
if id_group in (None, b'None'):
raise ValueError('%r is not a valid group Id.' % id_group)
if not isinstance(id_group, bytes):
raise TypeError('id_group must be bytes')
if default is None:
default = 0
......@@ -134,6 +136,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
# the last id stored in the sql table
for line in self._getValueListFromTable():
id_group = line['id_group']
assert isinstance(id_group, bytes)
last_id = line['last_id']
if id_group in self.last_max_id_dict and \
self.last_max_id_dict[id_group].value > last_id:
......@@ -197,6 +200,7 @@ class SQLNonContinuousIncreasingIdGenerator(IdGenerator):
getattr(portal_ids, 'dict_length_ids', None) is None):
dump_dict = portal_ids.dict_length_ids
for id_group, last_id in dump_dict.items():
assert isinstance(id_group, bytes)
last_insert_id = get_last_id_method(id_group=id_group)
last_id = int(last_id.value)
if len(last_insert_id) != 0:
......
......@@ -57,8 +57,10 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
Use int to store the last_id, use also a persistant mapping for to be
persistent.
"""
if id_group in (None, 'None'):
if id_group in (None, b'None'):
raise ValueError('%r is not a valid group Id.' % id_group)
if not isinstance(id_group, bytes):
raise TypeError('id_group must be bytes')
if default is None:
default = 0
last_id_dict = getattr(self, 'last_id_dict', None)
......@@ -107,8 +109,8 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
# Dump the dict_ids dictionary
if getattr(portal_ids, 'dict_ids', None) is not None:
for id_group, last_id in portal_ids.dict_ids.items():
if not isinstance(id_group, str):
id_group = repr(id_group)
if not isinstance(id_group, bytes):
id_group = repr(id_group).encode()
if id_group in self.last_id_dict and \
self.last_id_dict[id_group] > last_id:
continue
......@@ -148,7 +150,9 @@ class ZODBContinuousIncreasingIdGenerator(IdGenerator):
self.clearGenerator()
if not isinstance(id_dict, dict):
raise TypeError('the argument given is not a dictionary')
for value in id_dict.values():
for key, value in id_dict.items():
if not isinstance(key, bytes):
raise TypeError('key %r given in dictionary is not bytes' % (key, ))
if not isinstance(value, six.integer_types):
raise TypeError('the value given in dictionary is not a integer')
self.last_id_dict.update(id_dict)
......
......@@ -115,12 +115,13 @@ class IdTool(BaseTool):
"""
Generate the next id in the sequence of ids of a particular group
"""
if id_group in (None, 'None'):
if id_group in (None, b'None'):
raise ValueError('%r is not a valid id_group' % id_group)
# for compatibilty with sql data, must not use id_group as a list
if not isinstance(id_group, str):
id_group = repr(id_group)
warnings.warn('id_group must be a string, other types '
if not isinstance(id_group, bytes):
# TODO: check that this repr same as python2 !
id_group = repr(id_group).encode()
warnings.warn('id_group must be bytes, other types '
'are deprecated.', DeprecationWarning)
if id_generator is None:
id_generator = 'document'
......@@ -175,11 +176,11 @@ class IdTool(BaseTool):
"""
Generate a list of next ids in the sequence of ids of a particular group
"""
if id_group in (None, 'None'):
if id_group in (None, b'None'):
raise ValueError('%r is not a valid id_group' % id_group)
# for compatibilty with sql data, must not use id_group as a list
if not isinstance(id_group, str):
id_group = repr(id_group)
if not isinstance(id_group, bytes):
id_group = repr(id_group).encode()
warnings.warn('id_group must be a string, other types '
'are deprecated.', DeprecationWarning)
if id_generator is None:
......
......@@ -3591,7 +3591,7 @@ class Base(
sequence and for <self> instance, and is monotonously increasing by 1 for
each generated id.
group (string):
group (bytes):
Identifies the sequence to use.
count (int):
How many identifiers to generate.
......@@ -3623,8 +3623,8 @@ class Base(
It is expected that group creation is a rare event, very unlikely to
happen concurrently in multiple transactions on the same object.
"""
if not isinstance(group, basestring):
raise TypeError('group must be a string')
if not isinstance(group, bytes):
raise TypeError('group must be bytes')
if not isinstance(default, six.integer_types):
raise TypeError('default must be an integer')
if not isinstance(count, six.integer_types):
......@@ -3649,7 +3649,7 @@ class Base(
next_id = default
new_next_id = None if poison else next_id + count
id_generator_state[group].value = new_next_id
return range(next_id, new_next_id)
return ensure_list(range(next_id, new_next_id))
InitializeClass(Base)
......
......@@ -1148,7 +1148,7 @@ class Catalog(Folder,
uid_buffer.extend(
self.getPortalObject().portal_ids.generateNewIdList(
id_generator='uid',
id_group='catalog_uid',
id_group=b'catalog_uid',
id_count=UID_BUFFER_SIZE,
default=getattr(self, '_max_uid', lambda: 1)(),
),
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment