Commit 2420c89f authored by Arnaud Fontaine's avatar Arnaud Fontaine

py3: _mysql.string_literal() returns bytes().

And _mysql/mysqldb API (_mysql.connection.query()) converts the query string to
bytes() (additionally, cursor.execute(QUERY, ARGS) calls query() after
converting everything to bytes() too).
parent 8024ba17
This diff is collapsed.
...@@ -85,7 +85,7 @@ class SQLDict(SQLBase): ...@@ -85,7 +85,7 @@ class SQLDict(SQLBase):
uid = line.uid uid = line.uid
original_uid = path_and_method_id_dict.get(key) original_uid = path_and_method_id_dict.get(key)
if original_uid is None: if original_uid is None:
sql_method_id = " AND method_id = %s AND group_method_id = %s" % ( sql_method_id = b" AND method_id = %s AND group_method_id = %s" % (
quote(method_id), quote(line.group_method_id)) quote(method_id), quote(line.group_method_id))
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
merge_parent = m.activity_kw.get('merge_parent') merge_parent = m.activity_kw.get('merge_parent')
...@@ -102,11 +102,11 @@ class SQLDict(SQLBase): ...@@ -102,11 +102,11 @@ class SQLDict(SQLBase):
uid_list = [] uid_list = []
if path_list: if path_list:
# Select parent messages. # Select parent messages.
result = Results(db.query("SELECT * FROM message" result = Results(db.query(b"SELECT * FROM message"
" WHERE processing_node IN (0, %s) AND path IN (%s)%s" b" WHERE processing_node IN (0, %d) AND path IN (%s)%s"
" ORDER BY path LIMIT 1 FOR UPDATE" % ( b" ORDER BY path LIMIT 1 FOR UPDATE" % (
processing_node, processing_node,
','.join(map(quote, path_list)), b','.join(map(quote, path_list)),
sql_method_id, sql_method_id,
), 0)) ), 0))
if result: # found a parent if result: # found a parent
...@@ -119,11 +119,11 @@ class SQLDict(SQLBase): ...@@ -119,11 +119,11 @@ class SQLDict(SQLBase):
m = Message.load(line.message, uid=uid, line=line) m = Message.load(line.message, uid=uid, line=line)
# return unreserved similar children # return unreserved similar children
path = line.path path = line.path
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)" b" WHERE processing_node = 0 AND (path = %s OR path LIKE %s)"
"%s FOR UPDATE" % ( b"%s FOR UPDATE" % (
quote(path), quote(path.replace('_', r'\_') + '/%'), quote(path), quote(path.replace('_', r'\_') + '/%'),
sql_method_id, sql_method_id.encode(),
), 0)[1] ), 0)[1]
reserve_uid_list = [x for x, in result] reserve_uid_list = [x for x, in result]
uid_list += reserve_uid_list uid_list += reserve_uid_list
...@@ -132,8 +132,8 @@ class SQLDict(SQLBase): ...@@ -132,8 +132,8 @@ class SQLDict(SQLBase):
reserve_uid_list.append(uid) reserve_uid_list.append(uid)
else: else:
# Select duplicates. # Select duplicates.
result = db.query("SELECT uid FROM message" result = db.query(b"SELECT uid FROM message"
" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % ( b" WHERE processing_node = 0 AND path = %s%s FOR UPDATE" % (
quote(path), sql_method_id, quote(path), sql_method_id,
), 0)[1] ), 0)[1]
reserve_uid_list = uid_list = [x for x, in result] reserve_uid_list = uid_list = [x for x, in result]
......
...@@ -27,6 +27,7 @@ from __future__ import absolute_import ...@@ -27,6 +27,7 @@ from __future__ import absolute_import
# #
############################################################################## ##############################################################################
from six import string_types as basestring from six import string_types as basestring
from Products.ERP5Type.Utils import ensure_list
import socket import socket
from six.moves import urllib from six.moves import urllib
...@@ -1356,7 +1357,7 @@ class ActivityTool (BaseTool): ...@@ -1356,7 +1357,7 @@ class ActivityTool (BaseTool):
# use a round-robin algorithm. # use a round-robin algorithm.
# XXX: We always finish by iterating over all queues, in case that # XXX: We always finish by iterating over all queues, in case that
# getPriority does not see messages dequeueMessage would process. # getPriority does not see messages dequeueMessage would process.
activity_list = activity_dict.values() activity_list = ensure_list(activity_dict.values())
def sort_key(activity): def sort_key(activity):
return activity.getPriority(self, processing_node, return activity.getPriority(self, processing_node,
node_family_id_set) node_family_id_set)
...@@ -1390,7 +1391,7 @@ class ActivityTool (BaseTool): ...@@ -1390,7 +1391,7 @@ class ActivityTool (BaseTool):
path = None if obj is None else '/'.join(obj.getPhysicalPath()) path = None if obj is None else '/'.join(obj.getPhysicalPath())
db = self.getSQLConnection() db = self.getSQLConnection()
quote = db.string_literal quote = db.string_literal
return bool(db.query("(%s)" % ") UNION ALL (".join( return bool(db.query(b"(%s)" % b") UNION ALL (".join(
activity.hasActivitySQL(quote, path=path, **kw) activity.hasActivitySQL(quote, path=path, **kw)
for activity in activity_dict.itervalues()))[1]) for activity in activity_dict.itervalues()))[1])
......
...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM ...@@ -111,6 +111,7 @@ from Shared.DC.ZRDB.TM import TM
from DateTime import DateTime from DateTime import DateTime
from zLOG import LOG, ERROR, WARNING from zLOG import LOG, ERROR, WARNING
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.ERP5Type.Utils import str2bytes
hosed_connection = ( hosed_connection = (
CR.SERVER_GONE_ERROR, CR.SERVER_GONE_ERROR,
...@@ -203,7 +204,7 @@ def ord_or_None(s): ...@@ -203,7 +204,7 @@ def ord_or_None(s):
return ord(s) return ord(s)
match_select = re.compile( match_select = re.compile(
r'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)', rb'(?:SET\s+STATEMENT\s+(.+?)\s+FOR\s+)?SELECT\s+(.+)',
re.IGNORECASE | re.DOTALL, re.IGNORECASE | re.DOTALL,
).match ).match
...@@ -417,12 +418,14 @@ class DB(TM): ...@@ -417,12 +418,14 @@ class DB(TM):
"""Execute 'query_string' and return at most 'max_rows'.""" """Execute 'query_string' and return at most 'max_rows'."""
self._use_TM and self._register() self._use_TM and self._register()
desc = None desc = None
if not isinstance(query_string, bytes):
query_string = str2bytes(query_string)
# XXX deal with a typical mistake that the user appends # XXX deal with a typical mistake that the user appends
# an unnecessary and rather harmful semicolon at the end. # an unnecessary and rather harmful semicolon at the end.
# Unfortunately, MySQLdb does not want to be graceful. # Unfortunately, MySQLdb does not want to be graceful.
if query_string[-1:] == ';': if query_string[-1:] == b';':
query_string = query_string[:-1] query_string = query_string[:-1]
for qs in query_string.split('\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
select_match = match_select(qs) select_match = match_select(qs)
...@@ -431,12 +434,12 @@ class DB(TM): ...@@ -431,12 +434,12 @@ class DB(TM):
if query_timeout is not None: if query_timeout is not None:
statement, select = select_match.groups() statement, select = select_match.groups()
if statement: if statement:
statement += ", max_statement_time=%f" % query_timeout statement += b", max_statement_time=%f" % query_timeout
else: else:
statement = "max_statement_time=%f" % query_timeout statement = b"max_statement_time=%f" % query_timeout
qs = "SET STATEMENT %s FOR SELECT %s" % (statement, select) qs = b"SET STATEMENT %s FOR SELECT %s" % (statement, select)
if max_rows: if max_rows:
qs = "%s LIMIT %d" % (qs, max_rows) qs = b"%s LIMIT %d" % (qs, max_rows)
c = self._query(qs) c = self._query(qs)
if c: if c:
if desc is not None is not c.describe(): if desc is not None is not c.describe():
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment