##############################################################################
#
# Copyright (c) 2008-2009 Nexedi SARL and Contributors. All Rights Reserved.
#                    Jean-Paul Smets-Solanes <jp@nexedi.com>
#                    Vincent Pelletier <vincent@nexedi.com>
#
# WARNING: This program as such is intended to be used by professional
# programmers who take the whole responsability of assessing all potential
# consequences resulting from its eventual inadequacies and bugs
# End users who are looking for a ready-to-use solution with commercial
# garantees and support are strongly adviced to contract a Free Software
# Service Company
#
# This program is Free Software; you can redistribute it and/or
# modify it under the terms of the GNU General Public License
# as published by the Free Software Foundation; either version 2
# of the License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program; if not, write to the Free Software
# Foundation, Inc., 59 Temple Place - Suite 330, Boston, MA  02111-1307, USA.
#
##############################################################################

from zLOG import LOG
from Interface.ISQLExpression import ISQLExpression
from Interface.Verify import verifyClass
from types import NoneType
from SQLCatalog import profiler_decorator

SQL_LIST_SEPARATOR = ', '
SQL_TABLE_FORMAT = '%s' # XXX: should be changed to '`%s`', but this breaks some ZSQLMethods.
SQL_SELECT_ALIAS_FORMAT = '%s AS `%s`'

# TODO:
#   - change table_alias_dict in internals to represent computed tables:
#     ie: '(SELECT * FROM `bar` WHERE `baz` = "hoge") AS `foo`'
#         '`foo` LEFT JOIN `bar` WHERE (`baz` = "hoge")'

# Set to True to keep a reference (self.query) to the query which created
# each SQLExpression; it is only used to make error messages more detailed.
# Set to False to avoid keeping a reference to that object.
DEBUG = True

def defaultDict(value):
  """
    Returns a new empty dict if value is None, otherwise a shallow copy of
    value (which must be a dict), so that later changes made by
    SQLExpression never affect the caller's dict.
  """
  if value is None:
    return {}
  assert isinstance(value, dict)
  return value.copy()

class SQLExpression(object):
  """
    Holds the SQL fragments (table aliases, WHERE, ORDER BY, GROUP BY,
    LIMIT, SELECT) generated for a query.

    An SQLExpression may contain nested SQLExpressions (sql_expression_list).
    The get* accessors merge the local values with the nested ones and raise
    ValueError when they conflict. asSQLExpressionDict renders everything
    into a dict of strings.
  """

  __implements__ = ISQLExpression

  @profiler_decorator
  def __init__(self,
               query,
               table_alias_dict=None,
               order_by_list=(),
               order_by_dict=None,
               group_by_list=(),
               where_expression=None,
               where_expression_operator=None,
               sql_expression_list=(),
               select_dict=None,
               limit=None,
               from_expression=None,
               can_merge_select_dict=False):
    """
      query
        Object which created this expression. Kept as self.query only when
        DEBUG is true, for error reporting.
      table_alias_dict (dict or None)
        table alias -> table name.
      order_by_list (sequence)
        Sort entries. getOrderByList compares entries by their first item.
      order_by_dict (dict or None)
        order_by_list entry -> rendered string (see getOrderByExpression).
      group_by_list (sequence of strings)
      where_expression (string or None)
        Literal WHERE expression. Mutually exclusive with
        where_expression_operator.
      where_expression_operator ('and', 'or', 'not' or None)
        Operator combining the WHERE expressions of sql_expression_list.
      sql_expression_list (sequence of SQLExpression, or None)
        Nested expressions. None entries are ignored. None is treated as an
        empty list.
      select_dict (dict or None)
        alias -> column.
      limit (None, a single value, or a list/tuple of at most 2 items)
      from_expression
        Deprecated (a warning is logged when given).
      can_merge_select_dict (bool)
        Whether conflicting select_dict entries may be merged (see
        _getSelectDict).

      Raises ValueError if limit is a list/tuple of 3 or more items.
    """
    if DEBUG:
      self.query = query
    self.table_alias_dict = defaultDict(table_alias_dict)
    self.order_by_list = list(order_by_list)
    self.group_by_list = list(group_by_list)
    self.order_by_dict = defaultDict(order_by_dict)
    self.can_merge_select_dict = can_merge_select_dict
    # Only one of (where_expression, where_expression_operator) must be given (never both)
    assert None in (where_expression, where_expression_operator)
    # Exactly one of (where_expression, where_expression_operator) must be given, except if sql_expression_list is given and contains exactly one entry
    assert where_expression is not None or where_expression_operator is not None or (sql_expression_list is not None and len(sql_expression_list) == 1)
    # where_expression must be a basestring instance if given
    assert isinstance(where_expression, (NoneType, basestring))
    # where_expression_operator must be 'and', 'or' or 'not' (if given)
    assert where_expression_operator in (None, 'and', 'or', 'not'), where_expression_operator
    self.where_expression = where_expression
    self.where_expression_operator = where_expression_operator
    # At least one of (where_expression, sql_expression_list) must be given (XXX: duplicate of previous conditions ?)
    assert where_expression is not None or sql_expression_list is not None
    if sql_expression_list is None:
      # Previously crashed in list(None); None means "no nested expression".
      sql_expression_list = ()
    # Drop None entries whatever the kind of iterable given.
    self.sql_expression_list = [x for x in sql_expression_list if x is not None]
    self.select_dict = defaultDict(select_dict)
    if limit is None:
      self.limit = ()
    elif isinstance(limit, (list, tuple)):
      if len(limit) < 3:
        self.limit = limit
      else:
        raise ValueError('Unrecognized "limit" value: %r' % (limit, ))
    else:
      self.limit = (limit, )
    if from_expression is not None:
      LOG('SQLExpression', 0, 'Providing a from_expression is deprecated.')
    self.from_expression = from_expression

  def _getDebugMessageSuffix(self, sql_expression):
    """
      Returns the text appended to merge-conflict error messages, describing
      which query created this expression and which nested expression was
      being merged. Returns an empty string when DEBUG is false, because the
      'query' attributes it reads are only set in DEBUG mode.
    """
    if not DEBUG:
      return ''
    return '. I was created by %r, and I am working on %r (%r) out of [%s]' % (
      self.query,
      sql_expression,
      sql_expression.query,
      ', '.join('%r (%r)' % (x, x.query) for x in self.sql_expression_list))

  @profiler_decorator
  def getTableAliasDict(self):
    """
      Returns a dictionary:
        key: table alias (string)
        value: table name (string)

      If there are nested SQLExpressions, it aggregates their mappings and
      checks that they don't alias different table with the same name. If
      they do, it raises a ValueError.
    """
    result = self.table_alias_dict.copy()
    for sql_expression in self.sql_expression_list:
      for alias, table_name in sql_expression.getTableAliasDict().iteritems():
        existing_value = result.get(alias)
        if existing_value not in (None, table_name):
          message = '%r is a known alias for table %r, can\'t alias it now to table %r' % (alias, existing_value, table_name)
          raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
        result[alias] = table_name
    return result

  @profiler_decorator
  def getFromExpression(self):
    """
      Returns the (deprecated) from_expression, or None.

      The value is used by asSQLExpressionDict as a mapping of table name to
      FROM clause fragment.

      If there are nested SQLExpressions, it checks that they either don't
      define any from_expression or the exact same from_expression, and
      returns the one defined. Otherwise, it raises a ValueError.
    """
    result = self.from_expression
    for sql_expression in self.sql_expression_list:
      from_expression = sql_expression.getFromExpression()
      if from_expression is None:
        continue
      if result is None:
        # Propagate a nested from_expression instead of silently dropping it.
        result = from_expression
      elif from_expression != result:
        message = 'I don\'t know how to merge from_expressions'
        raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
    return result

  @profiler_decorator
  def getOrderByList(self):
    """
      Returns a list of order_by entries (local ones first, then nested ones
      in order).

      If there are nested SQLExpression, it checks that they don't define
      sorts for columns which are already sorted. If they do, it raises a
      ValueError.
    """
    result = self.order_by_list[:]
    # Entries are identified by their first item (the sorted column).
    known_column_set = set([x[0] for x in result])
    for sql_expression in self.sql_expression_list:
      for order_by in sql_expression.getOrderByList():
        if order_by[0] in known_column_set:
          message = 'I don\'t know how to merge order_by yet'
          raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
        result.append(order_by)
        known_column_set.add(order_by[0])
    return result

  @profiler_decorator
  def getOrderByDict(self):
    """
      Returns a dict mapping order_by entries to their rendered string,
      merged with nested SQLExpressions.

      Raises a ValueError if two expressions map the same key to different
      values.
    """
    result_dict = self.order_by_dict.copy()
    for sql_expression in self.sql_expression_list:
      order_by_dict = sql_expression.getOrderByDict()
      for key, value in order_by_dict.iteritems():
        if key in result_dict and value != result_dict[key]:
          message = 'I don\'t know how to merge order_by_dict with ' \
                    'conflicting entries for key %r: %r vs. %r' % (key, result_dict[key], value)
          raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
      result_dict.update(order_by_dict)
    return result_dict

  @profiler_decorator
  def getOrderByExpression(self):
    """
      Returns a string.

      Returns a rendered "order by" expression. See getOrderByList.
      Each entry is rendered with getOrderByDict, falling back to str(entry).
    """
    get = self.getOrderByDict().get
    return SQL_LIST_SEPARATOR.join(get(x, str(x))
                                   for x in self.getOrderByList())

  @profiler_decorator
  def getWhereExpression(self):
    """
      Returns a string.

      Returns a rendered "where" expression: the literal where_expression if
      given, otherwise the nested where expressions combined with
      where_expression_operator ('(1)' when there is none).
    """
    if self.where_expression is not None:
      return self.where_expression
    sql_expression_list = self.sql_expression_list
    if self.where_expression_operator == 'not':
      assert len(sql_expression_list) == 1
      return '(NOT %s)' % (sql_expression_list[0].getWhereExpression())
    if len(sql_expression_list) == 1:
      return sql_expression_list[0].getWhereExpression()
    if len(sql_expression_list) == 0:
      return '(1)'
    operator = '\n    ' + self.where_expression_operator.upper() + ' '
    return '(%s)' % (operator.join(x.getWhereExpression()
                                   for x in sql_expression_list), )

  @profiler_decorator
  def getLimit(self):
    """
      Returns a list of 0, 1 or 2 items (int or string).

      If there are nested SQLExpression, it checks that they either don't
      define any limit or the exact same limit. Otherwise it raises a
      ValueError.
    """
    result = list(self.limit)
    for sql_expression in self.sql_expression_list:
      other_limit = sql_expression.getLimit()
      if other_limit not in ([], result):
        message = 'I don\'t know how to merge limits yet'
        raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
    return result

  @profiler_decorator
  def getLimitExpression(self):
    """
      Returns a string.

      Returns a rendered "limit" expression. See getLimit.
    """
    return SQL_LIST_SEPARATOR.join(str(x) for x in self.getLimit())

  @profiler_decorator
  def getGroupByset(self):
    """
      Returns a set of strings.

      If there are nested SQLExpression, it merges (union of sets) them with
      local value.
    """
    result = set(self.group_by_list)
    for sql_expression in self.sql_expression_list:
      result.update(sql_expression.getGroupByset())
    return result

  @profiler_decorator
  def getGroupByExpression(self):
    """
      Returns a string.

      Returns a rendered "group by" expression. See getGroupByset.
    """
    return SQL_LIST_SEPARATOR.join(self.getGroupByset())

  def canMergeSelectDict(self):
    """
      Returns whether conflicting select_dict entries of this expression
      may be merged by _getSelectDict.
    """
    return self.can_merge_select_dict

  @profiler_decorator
  def _getSelectDict(self):
    """
      Returns a 2-tuple:
      - a dict (alias -> column) aggregating local and nested select_dicts
      - the set of aliases whose conflicting definitions may be merged

      When a nested expression which can merge its select_dict defines an
      alias already present with a different column, and that alias is
      mergeable, both columns are combined as '<existing> + <new>'. Any
      other conflict raises a ValueError.
    """
    result = self.select_dict.copy()
    mergeable_set = set()
    if self.canMergeSelectDict():
      mergeable_set.update(result)
    for sql_expression in self.sql_expression_list:
      can_merge_sql_expression = sql_expression.canMergeSelectDict()
      sql_expression_select_dict, sql_expression_mergeable_set = \
        sql_expression._getSelectDict()
      mergeable_set.update(sql_expression_mergeable_set)
      for alias, column in sql_expression_select_dict.iteritems():
        existing_value = result.get(alias)
        if existing_value not in (None, column):
          if can_merge_sql_expression and alias in mergeable_set:
            # Custom conflict resolution
            column = '%s + %s' % (existing_value, column)
          else:
            message = '%r is a known alias for column %r, can\'t alias it now to column %r' % (alias, existing_value, column)
            raise ValueError(message + self._getDebugMessageSuffix(sql_expression))
        # Store the column itself (not a tuple): values of this dict are
        # rendered directly by getSelectExpression and compared to columns.
        result[alias] = column
        if can_merge_sql_expression:
          mergeable_set.add(alias)
    return result, mergeable_set

  @profiler_decorator
  def getSelectDict(self):
    """
      Returns a dict:
        key: alias (string)
        value: column (string) or None

      If there are nested SQLExpression, it aggregates their mappings and
      checks that they don't alias different columns with the same name. If
      they do, it raises a ValueError.
    """
    return self._getSelectDict()[0]

  @profiler_decorator
  def getSelectExpression(self):
    """
      Returns a string.

      Returns a rendered "select" expression. See getSelectDict.
    """
    return SQL_LIST_SEPARATOR.join(
      SQL_SELECT_ALIAS_FORMAT % (column, alias)
      for alias, column in self.getSelectDict().iteritems())

  @profiler_decorator
  def asSQLExpressionDict(self):
    """
      Returns a dict of rendered SQL fragments with these keys:
        where_expression, order_by_expression, limit_expression,
        select_expression, group_by_expression: strings
        from_table_list: list of (alias, table) pairs, each formatted with
          SQL_TABLE_FORMAT
        from_expression: string, or None if no (deprecated) from_expression
          was given
    """
    table_alias_dict = self.getTableAliasDict()
    from_table_list = []
    append = from_table_list.append
    for alias, table in table_alias_dict.iteritems():
      append((SQL_TABLE_FORMAT % (alias, ), SQL_TABLE_FORMAT % (table, )))
    from_expression_dict = self.getFromExpression()
    if from_expression_dict is not None:
      # Tables without a custom fragment get a plain "`table` AS `alias`".
      from_expression = SQL_LIST_SEPARATOR.join(
        from_expression_dict.get(table, '`%s` AS `%s`' % (table, alias))
        for alias, table in table_alias_dict.iteritems())
    else:
      from_expression = None
    return {
      'where_expression': self.getWhereExpression(),
      'order_by_expression': self.getOrderByExpression(),
      'from_table_list': from_table_list,
      'from_expression': from_expression,
      'limit_expression': self.getLimitExpression(),
      'select_expression': self.getSelectExpression(),
      'group_by_expression': self.getGroupByExpression()
    }

verifyClass(ISQLExpression, SQLExpression)