Commit 9f426f26 authored by Jérome Perrin's avatar Jérome Perrin

Encoding errors with catalog searches

This fixes some regressions introduced by nexedi/erp5!1545 when searching the catalog for non ascii text

See merge request nexedi/erp5!1752
parents cd24fb39 7d5ea201
############################################################################## ##############################################################################
# # coding: utf-8
# Copyright (c) 2005 Nexedi SARL and Contributors. All Rights Reserved. # Copyright (c) 2005 Nexedi SARL and Contributors. All Rights Reserved.
# Kevin Deldycke <kevin_AT_nexedi_DOT_com> # Kevin Deldycke <kevin_AT_nexedi_DOT_com>
# #
...@@ -29,6 +29,7 @@ ...@@ -29,6 +29,7 @@
from collections import defaultdict from collections import defaultdict
import os import os
import six
from DateTime import DateTime from DateTime import DateTime
from Products.ERP5Type.Utils import convertToUpperCase from Products.ERP5Type.Utils import convertToUpperCase
from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase from Products.ERP5Type.tests.ERP5TypeTestCase import ERP5TypeTestCase
...@@ -1103,6 +1104,74 @@ class TestERP5Base(ERP5TypeTestCase): ...@@ -1103,6 +1104,74 @@ class TestERP5Base(ERP5TypeTestCase):
translated_portal_type='Personne')]) translated_portal_type='Personne')])
self.abort() self.abort()
def test_standard_translated_related_keys_non_ascii(self):
# make sure we can search by "translated_validation_state_title" and
# "translated_portal_type" with non ascii translations
message_catalog = self.portal.Localizer.erp5_ui
lang = 'fr'
if lang not in [x['id'] for x in
self.portal.Localizer.get_languages_map()]:
self.portal.Localizer.manage_addLanguage(lang)
message_catalog.gettext('Draft', add=1)
message_catalog.gettext('Person', add=1)
message_catalog.message_edit('Draft', lang, u'Broüillon', '')
message_catalog.message_edit('Person', lang, u'Pérsonne', '')
self.portal.ERP5Site_updateTranslationTable()
person_1 = self.portal.person_module.newContent(portal_type='Person', first_name='名前')
person_1.validate()
person_2 = self.portal.person_module.newContent(portal_type='Person')
organisation = self.portal.organisation_module.newContent(
portal_type='Organisation')
self.tic()
# patch the method, we'll abort later
self.portal.Localizer.get_selected_language = lambda: lang
self.assertEqual({person_1, person_2}, {x.getObject()
for x in self.portal.portal_catalog(translated_portal_type='Pérsonne')})
self.assertEqual({person_2, organisation}, {x.getObject()
for x in self.portal.portal_catalog(
translated_validation_state_title='Broüillon',
portal_type=('Person', 'Organisation'))})
self.assertEqual([person_2],
[x.getObject() for x in
self.portal.portal_catalog(translated_validation_state_title='Broüillon',
translated_portal_type='Pérsonne')])
self.assertEqual([person_1],
[x.getObject() for x in
self.portal.portal_catalog(title='名前',
translated_portal_type='Pérsonne')])
if six.PY2:
# listbox (for example) searches catalog with unicode
self.assertEqual({person_1, person_2}, {x.getObject()
for x in self.portal.portal_catalog(translated_portal_type=u'Pérsonne')})
self.assertEqual({person_2, organisation}, {x.getObject()
for x in self.portal.portal_catalog(
translated_validation_state_title=u'Broüillon',
portal_type=('Person', 'Organisation'))})
self.assertEqual([person_2],
[x.getObject() for x in
self.portal.portal_catalog(translated_validation_state_title=u'Broüillon',
translated_portal_type=u'Pérsonne')])
self.assertEqual([person_1],
[x.getObject() for x in
self.portal.portal_catalog(title=u'名前',
translated_portal_type='Pérsonne')])
self.assertEqual([person_1],
[x.getObject() for x in
self.portal.portal_catalog(title='名前',
translated_portal_type=u'Pérsonne')])
self.assertEqual([person_1],
[x.getObject() for x in
self.portal.portal_catalog(title=u'名前',
translated_portal_type=u'Pérsonne')])
self.abort()
def test_Base_createCloneDocument(self): def test_Base_createCloneDocument(self):
module = self.portal.person_module module = self.portal.person_module
module.manage_permission('Add portal content', ['Member'], 0) module.manage_permission('Add portal content', ['Member'], 0)
......
...@@ -633,6 +633,10 @@ class TestERP5Catalog(ERP5TypeTestCase, LogInterceptor): ...@@ -633,6 +633,10 @@ class TestERP5Catalog(ERP5TypeTestCase, LogInterceptor):
folder_object_list = [x.getObject().getId() for x in folder_object_list = [x.getObject().getId() for x in
person_module.searchFolder(title=title)] person_module.searchFolder(title=title)]
self.assertEqual(['5'],folder_object_list) self.assertEqual(['5'],folder_object_list)
if six.PY2:
folder_object_list = [x.getObject().getId() for x in
person_module.searchFolder(title=unicode(title, 'utf-8'))]
self.assertEqual(['5'],folder_object_list)
def test_Collation(self): def test_Collation(self):
person_module = self.getPersonModule() person_module = self.getPersonModule()
...@@ -654,6 +658,10 @@ class TestERP5Catalog(ERP5TypeTestCase, LogInterceptor): ...@@ -654,6 +658,10 @@ class TestERP5Catalog(ERP5TypeTestCase, LogInterceptor):
person_module.searchFolder(title='sebastien')] person_module.searchFolder(title='sebastien')]
self.assertEqual(['5'],folder_object_list) self.assertEqual(['5'],folder_object_list)
if six.PY2:
folder_object_list = [x.getObject().getId() for x in
person_module.searchFolder(title=u'Sebastien')]
self.assertEqual(['5'],folder_object_list)
def test_20_SearchFolderWithDynamicRelatedKey(self): def test_20_SearchFolderWithDynamicRelatedKey(self):
# Create some objects # Create some objects
......
...@@ -30,7 +30,6 @@ from Acquisition import aq_base, aq_parent ...@@ -30,7 +30,6 @@ from Acquisition import aq_base, aq_parent
from zLOG import LOG, INFO, ERROR from zLOG import LOG, INFO, ERROR
from io import BytesIO from io import BytesIO
from Products.ERP5Type import Permissions from Products.ERP5Type import Permissions
from Products.ERP5Type.Utils import str2bytes
security = ClassSecurityInfo() security = ClassSecurityInfo()
DA.security = security DA.security = security
...@@ -207,7 +206,7 @@ def DA__call__(self, REQUEST=None, __ick__=None, src__=0, test__=0, **kw): ...@@ -207,7 +206,7 @@ def DA__call__(self, REQUEST=None, __ick__=None, src__=0, test__=0, **kw):
security=getSecurityManager() security=getSecurityManager()
security.addContext(self) security.addContext(self)
try: try:
query = str2bytes(self.template(p, **argdata)) query = self.template(p, **argdata)
except TypeError as msg: except TypeError as msg:
msg = str(msg) msg = str(msg)
if 'client' in msg: if 'client' in msg:
...@@ -223,8 +222,6 @@ def DA__call__(self, REQUEST=None, __ick__=None, src__=0, test__=0, **kw): ...@@ -223,8 +222,6 @@ def DA__call__(self, REQUEST=None, __ick__=None, src__=0, test__=0, **kw):
result=self._cached_result(DB__, query, self.max_rows_, c) result=self._cached_result(DB__, query, self.max_rows_, c)
else: else:
try: try:
# if 'portal_ids' in query:
# LOG("DA query", INFO, "query = %s" %(query,))
result=DB__.query(query, self.max_rows_) result=DB__.query(query, self.max_rows_)
except: except:
LOG("DA call raise", ERROR, "DB = %s, c = %s, query = %s" %(DB__, c, query), error=True) LOG("DA call raise", ERROR, "DB = %s, c = %s, query = %s" %(DB__, c, query), error=True)
......
...@@ -89,6 +89,7 @@ $Id: DA.py,v 1.4 2001/08/09 20:16:36 adustman Exp $''' % database_type ...@@ -89,6 +89,7 @@ $Id: DA.py,v 1.4 2001/08/09 20:16:36 adustman Exp $''' % database_type
__version__='$Revision: 1.4 $'[11:-2] __version__='$Revision: 1.4 $'[11:-2]
import os import os
import six
from collections import defaultdict from collections import defaultdict
from weakref import WeakKeyDictionary from weakref import WeakKeyDictionary
import transaction import transaction
...@@ -172,6 +173,8 @@ class Connection(DABase.Connection): ...@@ -172,6 +173,8 @@ class Connection(DABase.Connection):
# any reason, that would generate an infinite loop. # any reason, that would generate an infinite loop.
self.connect(self.connection_string) self.connect(self.connection_string)
connection = self._v_database_connection connection = self._v_database_connection
if not isinstance(v, six.binary_type):
v = v.encode('utf-8')
return connection.string_literal(v) return connection.string_literal(v)
......
...@@ -111,7 +111,6 @@ from Shared.DC.ZRDB.TM import TM ...@@ -111,7 +111,6 @@ from Shared.DC.ZRDB.TM import TM
from DateTime import DateTime from DateTime import DateTime
from zLOG import LOG, ERROR, WARNING from zLOG import LOG, ERROR, WARNING
from ZODB.POSException import ConflictError from ZODB.POSException import ConflictError
from Products.ERP5Type.Utils import str2bytes
hosed_connection = ( hosed_connection = (
CR.SERVER_GONE_ERROR, CR.SERVER_GONE_ERROR,
...@@ -425,8 +424,8 @@ class DB(TM): ...@@ -425,8 +424,8 @@ class DB(TM):
"""Execute 'query_string' and return at most 'max_rows'.""" """Execute 'query_string' and return at most 'max_rows'."""
self._use_TM and self._register() self._use_TM and self._register()
desc = None desc = None
if not isinstance(query_string, bytes): if isinstance(query_string, six.text_type):
query_string = str2bytes(query_string) query_string = query_string.encode('utf-8')
# XXX deal with a typical mistake that the user appends # XXX deal with a typical mistake that the user appends
# an unnecessary and rather harmful semicolon at the end. # an unnecessary and rather harmful semicolon at the end.
# Unfortunately, MySQLdb does not want to be graceful. # Unfortunately, MySQLdb does not want to be graceful.
...@@ -466,6 +465,8 @@ class DB(TM): ...@@ -466,6 +465,8 @@ class DB(TM):
return items, result return items, result
def string_literal(self, s): def string_literal(self, s):
# This method accepts bytes or str with only ASCII characters
# and return bytes.
return self.db.string_literal(s) return self.db.string_literal(s)
def _begin(self, *ignored): def _begin(self, *ignored):
...@@ -649,6 +650,8 @@ class DeferredDB(DB): ...@@ -649,6 +650,8 @@ class DeferredDB(DB):
def query(self, query_string, max_rows=1000): def query(self, query_string, max_rows=1000):
self._register() self._register()
if isinstance(query_string, six.text_type):
query_string = query_string.encode('utf-8')
for qs in query_string.split(b'\0'): for qs in query_string.split(b'\0'):
qs = qs.strip() qs = qs.strip()
if qs: if qs:
......
Markdown is supported
0%
or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment