Ticket #17025: wherenode_refactor.diff

File wherenode_refactor.diff, 68.3 KB (added by Anssi Kääriäinen, 13 years ago)
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index be42d02..3070915 100644
    a b class QuerySet(object):  
    883883        Prepare the query for computing a result that contains aggregate annotations.
    884884        """
    885885        opts = self.model._meta
    886         if self.query.group_by is None:
     886        if not self.query.group_by:
    887887            field_names = [f.attname for f in opts.fields]
    888888            self.query.add_fields(field_names, False)
    889             self.query.set_group_by()
     889            self.query.group_by = True
    890890
    891891    def _prepare(self):
    892892        return self
    class ValuesQuerySet(QuerySet):  
    938938
    939939        if self._fields:
    940940            self.extra_names = []
     941            # Collect here the aggregate names that will be in the values
     942            # list, taken from the backing query's aggregates. The list
     943            # starts as [] rather than None so that at the end of this
     944            # method we can distinguish the two cases: an empty list sets
     945            # the aggregate mask of the backing query to empty (we aren't
     946            # interested in any aggregates), while any collected names end
     947            # up in the mask instead.
    941948            self.aggregate_names = []
    942949            if not self.query.extra and not self.query.aggregates:
    943950                # Short cut - if there are no extra or aggregates, then
    class ValuesQuerySet(QuerySet):  
    946953            else:
    947954                self.query.default_cols = False
    948955                self.field_names = []
     956                # OK, we have a list of fields - now we split them into
     957                # aggregates, extra select fields and normal fields, since
     958                # each list is applied to the query separately below.
    949959                for f in self._fields:
    950960                    # we inspect the full extra_select list since we might
    951961                    # be adding back an extra select item that we hadn't
    class ValuesQuerySet(QuerySet):  
    962972            self.field_names = [f.attname for f in self.model._meta.fields]
    963973            self.aggregate_names = None
    964974
     975        # Why can't we just keep the values we are interested in, pass them
     976        # to the compiler, and let it do the final pruning?
    965977        self.query.select = []
    966978        if self.extra_names is not None:
    967979            self.query.set_extra_mask(self.extra_names)
    968980        self.query.add_fields(self.field_names, True)
     981        # Ok, if we are called without fields, this means we do keep the
     982        # aggregates.
    969983        if self.aggregate_names is not None:
    970984            self.query.set_aggregate_mask(self.aggregate_names)
    971985
    class ValuesQuerySet(QuerySet):  
    9971011        """
    9981012        Prepare the query for computing a result that contains aggregate annotations.
    9991013        """
    1000         self.query.set_group_by()
     1014        # The super call below adds all the fields of the model to the
     1015        # query unless group_by is set. Because we set group_by = True
     1016        # first, the call is effectively a no-op here.
     1017        self.query.group_by = True
     1018        super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
    10011019
     1020        # Set the new additional aggregates into the aggregate mask.
    10021021        if self.aggregate_names is not None:
    10031022            self.aggregate_names.extend(aggregates)
    10041023            self.query.set_aggregate_mask(self.aggregate_names)
    10051024
    1006         super(ValuesQuerySet, self)._setup_aggregate_query(aggregates)
     1025
    10071026
    10081027    def _as_sql(self, connection):
    10091028        """
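
    A rough usage-level sketch of the aggregate mask handling described in the comments above (the Book model and field names are illustrative, not from the patch):

        from django.db.models import Count

        # With explicit fields, ValuesQuerySet starts out with aggregate_names == [];
        # _setup_aggregate_query() then extends it with the new alias and calls
        # query.set_aggregate_mask(['n_books']), so only that aggregate is selected.
        qs = Book.objects.values('publisher').annotate(n_books=Count('id'))

        # Without explicit fields, aggregate_names is None, the mask is left
        # untouched, and every aggregate on the query stays selected.
        qs = Book.objects.values().annotate(n_books=Count('id'))
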
  • django/db/models/query_utils.py

    diff --git a/django/db/models/query_utils.py b/django/db/models/query_utils.py
    index a56ab5c..10e532d 100644
    a b class Q(tree.Node):  
    4343    def __init__(self, *args, **kwargs):
    4444        super(Q, self).__init__(children=list(args) + kwargs.items())
    4545
     46    def _new_instance(cls, children=None, connector=None, negated=False):
     47        obj = tree.Node(children, connector, negated)
     48        obj.__class__ = cls
     49        return obj
     50    _new_instance = classmethod(_new_instance)
     51
    4652    def _combine(self, other, conn):
    4753        if not isinstance(other, Q):
    4854            raise TypeError(other)
    4955        obj = type(self)()
    50         obj.add(self, conn)
    51         obj.add(other, conn)
     56        obj.connector = conn
     57        if len(self) == 1 and not self.negated:
     58            obj.add(self.children[0], conn)
     59        else:
     60            obj.add(self, conn)
     61        if len(other) == 1 and not other.negated:
     62            obj.add(other.children[0], conn)
     63        else:
     64            obj.add(other, conn)
    5265        return obj
    5366
    5467    def __or__(self, other):
    class Q(tree.Node):  
    5871        return self._combine(other, self.AND)
    5972
    6073    def __invert__(self):
    61         obj = type(self)()
    62         obj.add(self, self.AND)
     74        obj = self.clone()
    6375        obj.negate()
    6476        return obj
    6577
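
    A quick sketch of what the new _combine() and __invert__() are meant to do at the Python level (the values in the comments are the expected shape, assuming the rest of the patch, e.g. Node.clone(), is applied):

        from django.db.models import Q

        lhs, rhs = Q(name='fred'), Q(age__gte=30)
        combined = lhs & rhs
        # Both sides are single-child, non-negated Q objects, so their leaf
        # constraints become direct children of one AND node instead of two
        # nested Q objects:
        #   combined.connector == 'AND'
        #   combined.children  == [('name', 'fred'), ('age__gte', 30)]

        negated = ~lhs
        # __invert__ now clones before negating, so the original Q is untouched:
        #   lhs.negated == False, negated.negated == True
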
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..aef8483 100644
    a b  
     1import copy
    12"""
    23Classes to represent the default SQL aggregate functions
    34"""
    class Aggregate(object):  
    6970
    7071        self.field = tmp
    7172
     73    def clone(self):
     74        clone = copy.copy(self)
     75        clone.col = self.col[:]
     76        return clone
     77
    7278    def relabel_aliases(self, change_map):
    7379        if isinstance(self.col, (list, tuple)):
    7480            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 6bf7de2..174cfd1 100644
    a b class SQLCompiler(object):  
    4848        self.quote_cache[name] = r
    4949        return r
    5050
     51    def where_to_sql(self):
     52        """
     53        This method is responsible for:
     54           - Removing always True / always False parts of the tree
     55           - Splitting the tree into having and where
     56           - Getting the group by columns from the having part of the query
     57           - And finally turning the remaining trees into SQL and params
     58
     59        Returns a 3-tuple of the form:
     60           ((where, w_params), (having, h_params), having_group_by)
     61
     62        where having_group_by is a set of (alias, column) pairs to add to
     63        the GROUP BY clause, for example ("T1", "some_field").
     64        """
     65        # Prune the tree. If we are left with a tree that matches nothing,
     66        # EmptyResultSet is raised below.
     67        where = self.query.where.clone_internal()
     68        where.final_prune(self.quote_name_unless_alias, self.connection)
     69        if where.match_nothing:
     70            raise EmptyResultSet
     71        if self.query.aggregates:
     72            having = self.query.where_class()
     73            where.split_aggregates(having)
     74            where.prune_tree(); having.prune_tree()
     75            group_by = set(); having.get_group_by(group_by)
     76            return (where.as_sql(), having.as_sql(), group_by)
     77        else:
     78            return (where and where.as_sql() or ('', []),  ('', []), set())
     79
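
    Roughly, for a query that mixes a plain column constraint with an aggregate constraint, where_to_sql() is meant to split the single where tree as sketched below (models, SQL and values are illustrative; in the real code path this happens inside as_sql()):

        from django.db.models import Count, Q

        qs = Book.objects.annotate(n_authors=Count('authors')) \
                         .filter(Q(price__lt=10) | Q(n_authors__gt=1))
        compiler = qs.query.get_compiler('default')
        (where, w_params), (having, h_params), having_group_by = compiler.where_to_sql()
        # The OR subtree references an aggregate, so the whole condition moves
        # to HAVING:
        #   where  ~ ''   (nothing left for the plain WHERE)
        #   having ~ '("book"."price" < %s OR COUNT(...) > %s)', h_params ~ [10, 1]
        # The plain column used inside HAVING also has to be grouped by, which
        # is what having_group_by reports:
        #   having_group_by ~ set([('book', 'price')])
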
    5180    def as_sql(self, with_limits=True, with_col_aliases=False):
    5281        """
    5382        Creates the SQL for this query. Returns the SQL string and list of
    class SQLCompiler(object):  
    6897        from_, f_params = self.get_from_clause()
    6998
    7099        qn = self.quote_name_unless_alias
    71 
    72         where, w_params = self.query.where.as_sql(qn=qn, connection=self.connection)
    73         having, h_params = self.query.having.as_sql(qn=qn, connection=self.connection)
     100        where_tpl, having_tpl, having_group_by = self.where_to_sql()
     101        having, h_params = having_tpl
     102        where, w_params = where_tpl
     103       
    74104        params = []
    75105        for val in self.query.extra_select.itervalues():
    76106            params.extend(val[1])
    class SQLCompiler(object):  
    88118            result.append('WHERE %s' % where)
    89119            params.extend(w_params)
    90120
    91         grouping, gb_params = self.get_grouping()
    92         if grouping:
     121        grouping, gb_params = self.get_grouping(having_group_by)
     122        if self.query.group_by:
    93123            if ordering:
    94124                # If the backend can't group by PK (i.e., any database
    95125                # other than MySQL), then any fields mentioned in the
    class SQLCompiler(object):  
    101131                            gb_params.extend(col_params)
    102132            else:
    103133                ordering = self.connection.ops.force_no_ordering()
    104             result.append('GROUP BY %s' % ', '.join(grouping))
    105             params.extend(gb_params)
     134            if grouping:
     135                result.append('GROUP BY %s' % ', '.join(grouping))
     136                params.extend(gb_params)
    106137
    107138        if having:
    108139            result.append('HAVING %s' % having)
    109140            params.extend(h_params)
    110141
    111         if ordering:
     142        # This is a hack: we rely on the ordering for GROUP BY. Subqueries
     143        # do not use ordering, so instead of clearing the ordering,
     144        # subqueries flag the query as not using the ordering that is
     145        # defined. This is sure to bite us and should be fixed. The real
     146        # fix might be to deprecate relying on .order_by() to get the
     147        # wanted GROUP BY. Or maybe we should have a variable
     148        # ordering_group_by, making it explicit that we collect the
     149        # order_by GROUP BY clauses in a different scope than the actual
     150        # ORDER BY. But that just sounds hacky. Or maybe we should just
     151        # resurrect the query.group_by set.
     152        if ordering and self.query.use_ordering:
    112153            result.append('ORDER BY %s' % ', '.join(ordering))
    113154
    114155        if with_limits:
    class SQLCompiler(object):  
    142183        """
    143184        obj = self.query.clone()
    144185        if obj.low_mark == 0 and obj.high_mark is None:
    145             # If there is no slicing in use, then we can safely drop all ordering
    146             obj.clear_ordering(True)
     186            # If there is no slicing in use, we could safely drop all
     187            # ordering.
     188            # TODO: However, we rely on the ordering to determine the GROUP
     189            # BY clause, so we keep it but tell the compiler only to group
     190            # by it, not to append it to the query. Refactor.
     191            obj.use_ordering = False
     192            obj.order_by = [f for f in self.query.order_by if f not in self.query.aggregates]
     193            # The list above essentially plays the role of a group_by
     194            # variable. It seems clear we need a dedicated group_by variable
     195            # (as the original code had) and should use it properly.
    147196        obj.bump_prefix()
    148197        return obj.get_compiler(connection=self.connection).as_sql()
    149198
    class SQLCompiler(object):  
    474523                first = False
    475524        return result, []
    476525
    477     def get_grouping(self):
     526    def get_grouping(self, where_group_by):
    478527        """
    479528        Returns a tuple representing the SQL elements in the "group by" clause.
    480529        """
     530        if not self.query.group_by:
     531            return [], []
    481532        qn = self.quote_name_unless_alias
    482533        result, params = [], []
    483         if self.query.group_by is not None:
    484             if (len(self.query.model._meta.fields) == len(self.query.select) and
    485                 self.connection.features.allows_group_by_pk):
    486                 self.query.group_by = [
     534        group_by = where_group_by
     535        if (len(self.query.model._meta.fields) == len(self.query.select) and
     536            self.connection.features.allows_group_by_pk):
     537                group_by = set([
    487538                    (self.query.model._meta.db_table, self.query.model._meta.pk.column)
    488                 ]
    489 
    490             group_by = self.query.group_by or []
    491 
    492             extra_selects = []
    493             for extra_select, extra_params in self.query.extra_select.itervalues():
    494                 extra_selects.append(extra_select)
    495                 params.extend(extra_params)
    496             cols = (group_by + self.query.select +
    497                 self.query.related_select_cols + extra_selects)
    498             seen = set()
    499             for col in cols:
    500                 if col in seen:
    501                     continue
    502                 seen.add(col)
    503                 if isinstance(col, (list, tuple)):
    504                     result.append('%s.%s' % (qn(col[0]), qn(col[1])))
    505                 elif hasattr(col, 'as_sql'):
    506                     result.append(col.as_sql(qn, self.connection))
    507                 else:
    508                     result.append('(%s)' % str(col))
     539                ])
     540
     541        extra_selects = []
     542        for extra_select, extra_params in self.query.extra_select.itervalues():
     543            extra_selects.append(extra_select)
     544            params.extend(extra_params)
     545       
     546        cols = group_by.union(self.query.select +
     547            self.query.related_select_cols + extra_selects)
     548        for col in cols:
     549            if isinstance(col, (list, tuple)):
     550                result.append('%s.%s' % (qn(col[0]), qn(col[1])))
     551            elif hasattr(col, 'as_sql'):
     552                result.append(col.as_sql(qn, self.connection))
     553            else:
     554                result.append('(%s)' % str(col))
    509555        return result, params
    510556
    511557    def fill_related_selections(self, opts=None, root_alias=None, cur_depth=1,
    class SQLDeleteCompiler(SQLCompiler):  
    864910                "Can only delete from one table at a time."
    865911        qn = self.quote_name_unless_alias
    866912        result = ['DELETE FROM %s' % qn(self.query.tables[0])]
    867         where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
     913        where_tpl, _, _ = self.where_to_sql()
     914        where, params = where_tpl
    868915        result.append('WHERE %s' % where)
    869916        return ' '.join(result), tuple(params)
    870917
    class SQLUpdateCompiler(SQLCompiler):  
    909956        if not values:
    910957            return '', ()
    911958        result.append(', '.join(values))
    912         where, params = self.query.where.as_sql(qn=qn, connection=self.connection)
     959        where_tpl, _, _ = self.where_to_sql()
     960        where, params = where_tpl
    913961        if where:
    914962            result.append('WHERE %s' % where)
    915963        return ' '.join(result), tuple(update_params + params)
  • django/db/models/sql/datastructures.py

    diff --git a/django/db/models/sql/datastructures.py b/django/db/models/sql/datastructures.py
    index 92d64e1..b8e06da 100644
    a b the SQL domain.  
    66class EmptyResultSet(Exception):
    77    pass
    88
    9 class FullResultSet(Exception):
    10     pass
    11 
    129class MultiJoin(Exception):
    1310    """
    1411    Used by join construction code to indicate the point at which a
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index 61fd2be..453df65 100644
    a b from django.db.models.sql import aggregates as base_aggregates_module  
    2020from django.db.models.sql.constants import *
    2121from django.db.models.sql.datastructures import EmptyResultSet, Empty, MultiJoin
    2222from django.db.models.sql.expressions import SQLEvaluator
    23 from django.db.models.sql.where import (WhereNode, Constraint, EverythingNode,
    24     ExtraWhere, AND, OR)
     23from django.db.models.sql.where import (WhereNode, Constraint, ExtraWhere,
     24    AND, OR)
    2525from django.core.exceptions import FieldError
    2626
    2727__all__ = ['Query', 'RawQuery']
    class RawQuery(object):  
    4747        return RawQuery(self.sql, using, params=self.params)
    4848
    4949    def convert_values(self, value, field, connection):
    50         """Convert the database-returned value into a type that is consistent
     50        """
     51        Convert the database-returned value into a type that is consistent
    5152        across database backends.
    5253
    5354        By default, this defers to the underlying backend operations, but
    class RawQuery(object):  
    8182        self.cursor = connections[self.using].cursor()
    8283        self.cursor.execute(self.sql, self.params)
    8384
    84 
    8585class Query(object):
    8686    """
    8787    A single SQL query.
    class Query(object):  
    121121        self.tables = []    # Aliases in the order they are created.
    122122        self.where = where()
    123123        self.where_class = where
    124         self.group_by = None
    125         self.having = where()
     124        self.use_ordering = True
    126125        self.order_by = []
    127126        self.low_mark, self.high_mark = 0, None  # Used for offset/limit
    128127        self.distinct = False
    class Query(object):  
    131130        self.select_related = False
    132131        self.related_select_cols = []
    133132
    134         # SQL aggregate-related attributes
     133        # Here is some random rambling about aggregates. First, the current
     134        # implementation is pretty darned hard to understand. There is little
     135        # to no documentation, and these variables are used all over the
     136        # place.
     137        #
     138        # So, here is a list of what we need for successful aggregate
     139        # queries. First, we naturally need the actual aggregates; these are
     140        # stored in self.aggregates, and that part is pretty clear.
     141        #
     142        # Next, we need the fields to group by. We shouldn't keep a record
     143        # of these, as the set of fields to group by is the wanted select
     144        # fields, having fields, order fields, and extra select fields.
     145        # We can and do compute these when the query gets executed. They
     146        # come out as a side product of preparing the other parts of the
     147        # query for execution.
     148
     149        # We still need the having clause. This is under control now, as
     150        # we have gotten rid of query.having: query.where is split into
     151        # having and where based on the actual need.
     152
     153        # The current implementation of blindly adding the fields to the query
     154        # is a bit dangerous - it leads to potential multijoins which will
     155        # result in duplicate rows for aggregation. This is a hard problem to
     156        # solve correctly.
     157
     158        # So, what do the variables below represent? aggregates is clear:
     159        # it holds the aggregates of the query. Next comes group_by, a flag
     160        # that indicates whether we should do a GROUP BY at all.
     161
     162        # Then we have self.aggregate_select_mask. This holds the aggregates
     163        # actually selected by the query. Why self.aggregates does not get
     164        # changed when we change the aggregates actually in the query is
     165        # unknown. The selected aggregates can be accessed through
     166        # self.aggregate_select, which is a property showing only the
     167        # entries in the aggregate select mask.
     168
     169        # To make things more complicated, db/models/query.py keeps its own
     170        # aggregate_names variable. It seems this is collected from the
     171        # aggregate_select property, and then used to add fields to the
     172        # query. This is partly speculation; I do not completely understand
     173        # what it does.
     174
    135175        self.aggregates = SortedDict() # Maps alias -> SQL aggregate function
     176        self.group_by = False
    136177        self.aggregate_select_mask = None
    137178        self._aggregate_select_cache = None
    138179
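
    To summarize the bookkeeping above in code (the query is hypothetical and the comments describe the expected shape of each attribute, not verified output):

        from django.db.models import Count

        q = Book.objects.annotate(n_authors=Count('authors')).query
        q.aggregates             # SortedDict mapping alias -> SQL aggregate, e.g. {'n_authors': <Count>}
        q.group_by               # True: a GROUP BY clause will be generated for this query
        q.aggregate_select_mask  # aliases of the aggregates the SELECT should include (or None)
        q.aggregate_select       # property: q.aggregates filtered through the mask
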
    class Query(object):  
    254295        obj.dupe_avoidance = self.dupe_avoidance.copy()
    255296        obj.select = self.select[:]
    256297        obj.tables = self.tables[:]
    257         obj.where = copy.deepcopy(self.where, memo=memo)
     298        # We do not need to clone the leaf nodes - they are immutable until
     299        # the query is executed or relabel_aliases is called. In either case
     300        # we take care of the copying where needed. This can be a major
     301        # speed optimization when the where tree has a lot of leaf nodes.
     302        obj.where = self.where.clone_internal()
    258303        obj.where_class = self.where_class
    259         if self.group_by is None:
    260             obj.group_by = None
    261         else:
    262             obj.group_by = self.group_by[:]
    263         obj.having = copy.deepcopy(self.having, memo=memo)
    264304        obj.order_by = self.order_by[:]
     305        obj.use_ordering = self.use_ordering
    265306        obj.low_mark, obj.high_mark = self.low_mark, self.high_mark
    266307        obj.distinct = self.distinct
    267308        obj.select_for_update = self.select_for_update
    268309        obj.select_for_update_nowait = self.select_for_update_nowait
    269310        obj.select_related = self.select_related
    270311        obj.related_select_cols = []
    271         obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
     312        if self.aggregates:
     313            obj.aggregates = copy.deepcopy(self.aggregates, memo=memo)
     314        else:
     315            obj.aggregates = SortedDict()
    272316        if self.aggregate_select_mask is None:
    273317            obj.aggregate_select_mask = None
    274318        else:
    class Query(object):  
    279323        # It will get re-populated in the cloned queryset the next time it's
    280324        # used.
    281325        obj._aggregate_select_cache = None
     326        obj.group_by = self.group_by
    282327        obj.max_depth = self.max_depth
    283328        obj.extra = self.extra.copy()
    284329        if self.extra_select_mask is None:
    class Query(object):  
    291336            obj._extra_select_cache = self._extra_select_cache.copy()
    292337        obj.extra_tables = self.extra_tables
    293338        obj.extra_order_by = self.extra_order_by
    294         obj.deferred_loading = copy.deepcopy(self.deferred_loading, memo=memo)
     339        obj.deferred_loading = self.deferred_loading[0].copy(), self.deferred_loading[1]
    295340        if self.filter_is_sticky and self.used_aliases:
    296341            obj.used_aliases = self.used_aliases.copy()
    297342        else:
    class Query(object):  
    343388        # If there is a group by clause, aggregating does not add useful
    344389        # information but retrieves only the first row. Aggregate
    345390        # over the subquery instead.
    346         if self.group_by is not None:
     391        if self.group_by:
    347392            from django.db.models.sql.subqueries import AggregateQuery
    348393            query = AggregateQuery(self.model)
    349394
    class Query(object):  
    406451                obj.add_subquery(subquery, using=using)
    407452            except EmptyResultSet:
    408453                # add_subquery evaluates the query, if it's an EmptyResultSet
    409                 # then there are can be no results, and therefore there the
    410                 # count is obviously 0
     454                # then there can be no results. Therefore the count is 0.
    411455                return 0
    412456
    413457        obj.add_count_column()
    class Query(object):  
    499543                if self.alias_refcount.get(alias) or rhs.alias_refcount.get(alias):
    500544                    self.promote_alias(alias, True)
    501545
    502         # Now relabel a copy of the rhs where-clause and add it to the current
    503         # one.
    504         if rhs.where:
    505             w = copy.deepcopy(rhs.where)
    506             w.relabel_aliases(change_map)
    507             if not self.where:
    508                 # Since 'self' matches everything, add an explicit "include
    509                 # everything" where-constraint so that connections between the
    510                 # where clauses won't exclude valid results.
    511                 self.where.add(EverythingNode(), AND)
    512         elif self.where:
    513             # rhs has an empty where clause.
    514             w = self.where_class()
    515             w.add(EverythingNode(), AND)
     546        if connector == OR and (not self.where or not rhs.where):
     547            # One of the two sides matches everything and the connector is OR.
     548            # This means the new where condition must match everything.
     549            self.where = self.where_class()
    516550        else:
    517             w = self.where_class()
    518         self.where.add(w, connector)
     551            rhs_where = rhs.where.clone()
     552            rhs_where.relabel_aliases(change_map)
     553            self.where = self.where_class([self.where, rhs_where], connector)
     554            # the root node's connector must always be AND
     555            if self.where.connector == OR:
     556                self.where = self.where_class([self.where])
     557        self.where.prune_tree()
    519558
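
    The effect of the new merging logic, sketched in terms of querysets (models are hypothetical):

        # OR-combining with a side whose where tree is empty: the result must
        # match everything, so the merged query gets a fresh, empty where tree
        # and no WHERE clause is emitted at all.
        qs = Book.objects.filter(price__lt=10) | Book.objects.all()

        # Otherwise both trees are kept: rhs's tree is cloned, its aliases are
        # relabelled, and the two trees become children of a new root node. If
        # the requested connector is OR, that node is wrapped in an extra AND
        # root, because the root node's connector must always be AND.
        qs = Book.objects.filter(price__lt=10) | Book.objects.filter(pages__gt=100)
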
    520559        # Selection columns and extra extensions are those provided by 'rhs'.
    521560        self.select = []
    class Query(object):  
    735774        assert set(change_map.keys()).intersection(set(change_map.values())) == set()
    736775
    737776        # 1. Update references in "select" (normal columns plus aliases),
    738         # "group by", "where" and "having".
     777        # and "where".
    739778        self.where.relabel_aliases(change_map)
    740         self.having.relabel_aliases(change_map)
    741         for columns in [self.select, self.group_by or []]:
     779        for columns in [self.select]:
    742780            for pos, col in enumerate(columns):
    743781                if isinstance(col, (list, tuple)):
    744782                    old_alias = col[0]
    class Query(object):  
    803841        The 'exceptions' parameter is a container that holds alias names which
    804842        should not be changed.
    805843        """
     844        # We must make sure the leaf nodes of the where tree will be cloned,
     845        # as they will be relabeled.
     846        self.where = self.where.clone()
     847
    806848        current = ord(self.alias_prefix)
    807849        assert current < ord('Z')
    808850        prefix = chr(current + 1)
    class Query(object):  
    952994                self.unref_alias(alias)
    953995        self.included_inherited_models = {}
    954996
    955     def need_force_having(self, q_object):
    956         """
    957         Returns whether or not all elements of this q_object need to be put
    958         together in the HAVING clause.
    959         """
    960         for child in q_object.children:
    961             if isinstance(child, Node):
    962                 if self.need_force_having(child):
    963                     return True
    964             else:
    965                 if child[0].split(LOOKUP_SEP)[0] in self.aggregates:
    966                     return True
    967         return False
    968 
    969997    def add_aggregate(self, aggregate, model, alias, is_summary):
    970998        """
    971999        Adds a single aggregate expression to the Query
    class Query(object):  
    9821010                    aggregate.name, field_name, field_name))
    9831011        elif ((len(field_list) > 1) or
    9841012            (field_list[0] not in [i.name for i in opts.fields]) or
    985             self.group_by is None or
     1013            not self.group_by or
    9861014            not is_summary):
    9871015            # If:
    9881016            #   - the field descriptor has more than one part (foo__bar), or
    class Query(object):  
    10141042        # Add the aggregate to the query
    10151043        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
    10161044
     1045    def add_where_leaf(self, data, negated=False):
     1046        leaf_class = self.where.leaf_class()
     1047        self.where.add(leaf_class(data, negated), AND)
     1048
    10171049    def add_filter(self, filter_expr, connector=AND, negate=False, trim=False,
    1018             can_reuse=None, process_extras=True, force_having=False):
     1050            can_reuse=None, process_extras=True):
    10191051        """
    10201052        Add a single filter to the query. The 'filter_expr' is a pair:
    10211053        (filter_string, value). E.g. ('name__contains', 'fred')
    class Query(object):  
    10531085
    10541086        # By default, this is a WHERE clause. If an aggregate is referenced
    10551087        # in the value, the filter will be promoted to a HAVING
    1056         having_clause = False
    10571088
    10581089        # Interpret '__exact=None' as the sql 'is NULL'; otherwise, reject all
    10591090        # uses of None as a query value.
    class Query(object):  
    10671098        elif hasattr(value, 'evaluate'):
    10681099            # If value is a query expression, evaluate it
    10691100            value = SQLEvaluator(value, self)
    1070             having_clause = value.contains_aggregate
    1071 
    10721101        for alias, aggregate in self.aggregates.items():
    10731102            if alias in (parts[0], LOOKUP_SEP.join(parts)):
    1074                 entry = self.where_class()
    1075                 entry.add((aggregate, lookup_type, value), AND)
    1076                 if negate:
    1077                     entry.negate()
    1078                 self.having.add(entry, connector)
     1103                self.add_where_leaf((aggregate, lookup_type, value))
    10791104                return
    10801105
    10811106        opts = self.get_meta()
    class Query(object):  
    11421167            self.promote_alias_chain(join_it, join_promote)
    11431168            self.promote_alias_chain(table_it, table_promote or join_promote)
    11441169
    1145         if having_clause or force_having:
    1146             if (alias, col) not in self.group_by:
    1147                 self.group_by.append((alias, col))
    1148             self.having.add((Constraint(alias, col, field), lookup_type, value),
    1149                 connector)
    1150         else:
    1151             self.where.add((Constraint(alias, col, field), lookup_type, value),
    1152                 connector)
     1170        self.add_where_leaf((Constraint(alias, col, field), lookup_type, value))
    11531171
    11541172        if negate:
    11551173            self.promote_alias_chain(join_list)
    class Query(object):  
    11581176                    for alias in join_list:
    11591177                        if self.alias_map[alias][JOIN_TYPE] == self.LOUTER:
    11601178                            j_col = self.alias_map[alias][RHS_JOIN_COL]
    1161                             entry = self.where_class()
    1162                             entry.add(
     1179                            self.add_where_leaf(
    11631180                                (Constraint(alias, j_col, None), 'isnull', True),
    1164                                 AND
     1181                                negated=True
    11651182                            )
    1166                             entry.negate()
    1167                             self.where.add(entry, AND)
    11681183                            break
    11691184                if not (lookup_type == 'in'
    11701185                            and not hasattr(value, 'as_sql')
    class Query(object):  
    11741189                    # exclude the "foo__in=[]" case from this handling, because
    11751190                    # it's short-circuited in the Where class.
    11761191                    # We also need to handle the case where a subquery is provided
    1177                     self.where.add((Constraint(alias, col, None), 'isnull', False), AND)
     1192                    self.add_where_leaf((Constraint(alias, col, None), 'isnull', False))
    11781193
    11791194        if can_reuse is not None:
    11801195            can_reuse.update(join_list)
    class Query(object):  
    11831198                self.add_filter(filter, negate=negate, can_reuse=can_reuse,
    11841199                        process_extras=False)
    11851200
    1186     def add_q(self, q_object, used_aliases=None, force_having=False):
     1201    def add_q(self, q_object):
    11871202        """
    11881203        Adds a Q-object to the current filter.
    11891204
    11901205        Can also be used to add anything that has an 'add_to_query()' method.
     1206
     1207        When the add_to_query path is not taken, this method's main purpose
     1208        is to walk the q_object's internal nodes and manage the state of
     1209        self.where. Leaf nodes are handled by add_filter.
     1210
     1211        The self.where tree is managed by pushing new nodes to the tree. This
     1212        way self.where is always at the right node when add_filter adds items
     1213        to it.
     1214
     1215        We need to start a new subtree when:
     1216           - The connector of the q_object is different from the connector of
     1217             the where tree.
     1218           - The q_object is negated.
     1219
     1220        After calling this method with q_object=~Q(pk=1)&~Q(Q(pk=3)|Q(pk=2))
     1221        we should have the following tree:
     1222                      AND
     1223                     /   \
     1224                    NOT  NOT
     1225                     |     \
     1226                    pk=1   OR
     1227                          /  \
     1228                        pk=3 pk=2
     1229
     1230        This method calls itself recursively for those children of the
     1231        q_object that are Q objects, and calls add_filter for the leaf nodes.
     1232
     1233        All filters are added to self.where. When the query is executed, the
     1234        tree is split into where and having clauses.
    11911235        """
    1192         if used_aliases is None:
    1193             used_aliases = self.used_aliases
     1236
     1237        # Complex custom objects are responsible for adding themselves.
    11941238        if hasattr(q_object, 'add_to_query'):
    1195             # Complex custom objects are responsible for adding themselves.
    1196             q_object.add_to_query(self, used_aliases)
    1197         else:
    1198             if self.where and q_object.connector != AND and len(q_object) > 1:
    1199                 self.where.start_subtree(AND)
    1200                 subtree = True
     1239            q_object.add_to_query(self, self.used_aliases)
     1240            return
     1241
     1242        # Start a subtree if needed. At the end we check whether anything got
     1243        # added into the subtree; if not, it is pruned away.
     1244        connector = q_object.connector
     1245        subtree_parent = None
     1246        if self.where.connector != connector or q_object.negated:
     1247            subtree = self.where_class(connector=connector)
     1248            subtree_parent = self.where
     1249            self.where.add(subtree, self.where.connector)
     1250            self.where = subtree
     1251        if q_object.negated:
     1252            self.where.negate()
     1253
     1254        # Aliases that were newly added or not used at all need to
     1255        # be promoted to outer joins if they are nullable relations.
     1256        # (they shouldn't turn the whole conditional into the empty
     1257        # set just because they don't match anything). Take a snapshot of
     1258        # the alias refcounts before adding the children.
     1259        if connector == OR:
     1260            refcounts_before = self.alias_refcount.copy()
     1261
     1262        for child in q_object.children:
     1263            if isinstance(child, Node):
     1264                self.add_q(child)
    12011265            else:
    1202                 subtree = False
    1203             connector = AND
    1204             if q_object.connector == OR and not force_having:
    1205                 force_having = self.need_force_having(q_object)
    1206             for child in q_object.children:
    1207                 if connector == OR:
    1208                     refcounts_before = self.alias_refcount.copy()
    1209                 if force_having:
    1210                     self.having.start_subtree(connector)
    1211                 else:
    1212                     self.where.start_subtree(connector)
    1213                 if isinstance(child, Node):
    1214                     self.add_q(child, used_aliases, force_having=force_having)
    1215                 else:
    1216                     self.add_filter(child, connector, q_object.negated,
    1217                             can_reuse=used_aliases, force_having=force_having)
    1218                 if force_having:
    1219                     self.having.end_subtree()
    1220                 else:
    1221                     self.where.end_subtree()
    1222 
    1223                 if connector == OR:
    1224                     # Aliases that were newly added or not used at all need to
    1225                     # be promoted to outer joins if they are nullable relations.
    1226                     # (they shouldn't turn the whole conditional into the empty
    1227                     # set just because they don't match anything).
    1228                     self.promote_unused_aliases(refcounts_before, used_aliases)
    1229                 connector = q_object.connector
    1230             if q_object.negated:
    1231                 self.where.negate()
    1232             if subtree:
    1233                 self.where.end_subtree()
    1234         if self.filter_is_sticky:
    1235             self.used_aliases = used_aliases
     1266                self.add_filter(child, connector, q_object.negated,
     1267                        can_reuse=self.used_aliases)
     1268
     1269        if connector == OR:
     1270            self.promote_unused_aliases(refcounts_before, self.used_aliases)
     1271        if subtree_parent:
     1272            self.where = subtree_parent
     1273        self.where.prune_tree()
    12361274
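
    In terms of the public API, the tree from the docstring above is what two chained exclude() calls build up (a sketch, not taken from the patch's tests):

        from django.db.models import Q

        qs = Book.objects.exclude(pk=1).exclude(Q(pk=3) | Q(pk=2))
        # Each exclude() reaches add_q() as a negated Q object, so a negated
        # subtree is pushed onto query.where before add_filter() adds the
        # leaves, giving:
        #       AND
        #      /   \
        #    NOT   NOT
        #     |      \
        #    pk=1    OR
        #           /  \
        #         pk=3  pk=2
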
    12371275    def setup_joins(self, names, opts, alias, dupe_multis, allow_many=True,
    12381276            allow_explicit_fk=False, can_reuse=None, negate=False,
    class Query(object):  
    12541292        column (used for any 'where' constraint), the final 'opts' value and the
    12551293        list of tables joined.
    12561294        """
     1295       
    12571296        joins = [alias]
    12581297        last = [0]
    12591298        dupe_set = set()
    class Query(object):  
    15331572        # database from tripping over IN (...,NULL,...) selects and returning
    15341573        # nothing
    15351574        alias, col = query.select[0]
    1536         query.where.add((Constraint(alias, col, None), 'isnull', False), AND)
     1575        query.add_where_leaf((Constraint(alias, col, None), 'isnull', False))
    15371576
    15381577        self.add_filter(('%s__in' % prefix, query), negate=True, trim=True,
    15391578                can_reuse=can_reuse)
    class Query(object):  
    16591698        if force_empty:
    16601699            self.default_ordering = False
    16611700
    1662     def set_group_by(self):
    1663         """
    1664         Expands the GROUP BY clause required by the query.
    1665 
    1666         This will usually be the set of all non-aggregate fields in the
    1667         return data. If the database backend supports grouping by the
    1668         primary key, and the query would be equivalent, the optimization
    1669         will be made automatically.
    1670         """
    1671         self.group_by = []
    1672 
    1673         for sel in self.select:
    1674             self.group_by.append(sel)
    1675 
    16761701    def add_count_column(self):
    16771702        """
    16781703        Converts the query to do count(...) or count(distinct(pk)) in order to
    class Query(object):  
    17051730        # Clear out the select cache to reflect the new unmasked aggregates.
    17061731        self.aggregates = {None: count}
    17071732        self.set_aggregate_mask(None)
    1708         self.group_by = None
     1733        self.group_by = False
    17091734
    17101735    def add_select_related(self, fields):
    17111736        """
    class Query(object):  
    17481773            # This is order preserving, since self.extra_select is a SortedDict.
    17491774            self.extra.update(select_pairs)
    17501775        if where or params:
    1751             self.where.add(ExtraWhere(where, params), AND)
     1776            self.add_where_leaf(ExtraWhere(where, params))
    17521777        if tables:
    17531778            self.extra_tables += tuple(tables)
    17541779        if order_by:
    class Query(object):  
    18241849        target[model] = set([f.name for f in fields])
    18251850
    18261851    def set_aggregate_mask(self, names):
    18271854        "Set the mask of aggregates that will actually be returned by the SELECT"
    18281855        if names is None:
    18291856            self.aggregate_select_mask = None
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 1b03647..3a7774c 100644
    a b class DeleteQuery(Query):  
    3737            field = self.model._meta.pk
    3838        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
    3939            where = self.where_class()
    40             where.add((Constraint(None, field.column, field), 'in',
    41                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), AND)
     40            leaf = where.leaf_class()
     41            where.add(leaf((Constraint(None, field.column, field), 'in',
     42                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE])), AND)
    4243            self.do_query(self.model._meta.db_table, where, using=using)
    4344
    4445class UpdateQuery(Query):
    class UpdateQuery(Query):  
    7374        self.add_update_values(values)
    7475        for offset in range(0, len(pk_list), GET_ITERATOR_CHUNK_SIZE):
    7576            self.where = self.where_class()
    76             self.where.add((Constraint(None, pk_field.column, pk_field), 'in',
    77                     pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
    78                     AND)
     77            self.add_where_leaf((Constraint(None, pk_field.column, pk_field), 'in',
     78                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]))
    7979            self.get_compiler(using).execute_sql(None)
    8080
    8181    def add_update_values(self, values):
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 3e9dbf0..ea26f62 100644
    a b from itertools import repeat  
    66
    77from django.utils import tree
    88from django.db.models.fields import Field
    9 from datastructures import EmptyResultSet, FullResultSet
     9from django.db.models.sql.aggregates import Aggregate
    1010
    1111# Connection types
    1212AND = 'AND'
    1313OR = 'OR'
    1414
    15 class EmptyShortCircuit(Exception):
     15class WhereLeaf(object):
    1616    """
    17     Internal exception used to indicate that a "matches nothing" node should be
    18     added to the where-clause.
    19     """
    20     pass
     17    Represents a leaf node in a where tree. Contains a single constraint
     18    and knows how to turn it into SQL and params.
    2119
    22 class WhereNode(tree.Node):
     20    This class implements many of WhereNode's methods. Here the methods do
     21    the terminal work, while WhereNode's methods are mostly recursive in
     22    nature.
    2323    """
    24     Used to represent the SQL where-clause.
    25 
    26     The class is tied to the Query class that created it (in order to create
    27     the correct SQL).
    2824
    29     The children in this tree are usually either Q-like objects or lists of
    30     [table_alias, field_name, db_type, lookup_type, value_annotation,
    31     params]. However, a child could also be any class with as_sql() and
    32     relabel_aliases() methods.
    33     """
    34     default = AND
     25    # Fast and pretty way to test if the node is a leaf node.
     26    is_leaf = True
    3527
    36     def add(self, data, connector):
    37         """
    38         Add a node to the where-tree. If the data is a list or tuple, it is
    39         expected to be of the form (obj, lookup_type, value), where obj is
    40         a Constraint object, and is then slightly munged before being stored
    41         (to avoid storing any reference to field objects). Otherwise, the 'data'
    42         is stored unchanged and can be any class with an 'as_sql()' method.
    43         """
     28    def __init__(self, data, negated=False):
     29        self.sql = ''
     30        self.negated = negated
     31        self.params = []
     32        self.match_all = False
     33        self.match_nothing = False
    4434        if not isinstance(data, (list, tuple)):
    45             super(WhereNode, self).add(data, connector)
    46             return
    47 
    48         obj, lookup_type, value = data
    49         if hasattr(value, '__iter__') and hasattr(value, 'next'):
    50             # Consume any generators immediately, so that we can determine
    51             # emptiness and transform any non-empty values correctly.
    52             value = list(value)
    53 
    54         # The "annotation" parameter is used to pass auxilliary information
    55         # about the value(s) to the query construction. Specifically, datetime
    56         # and empty values need special handling. Other types could be used
    57         # here in the future (using Python types is suggested for consistency).
    58         if isinstance(value, datetime.datetime):
    59             annotation = datetime.datetime
    60         elif hasattr(value, 'value_annotation'):
    61             annotation = value.value_annotation
     35            self.data = data
    6236        else:
    63             annotation = bool(value)
    64 
    65         if hasattr(obj, "prepare"):
    66             value = obj.prepare(lookup_type, value)
    67             super(WhereNode, self).add((obj, lookup_type, annotation, value),
    68                 connector)
    69             return
    70 
    71         super(WhereNode, self).add((obj, lookup_type, annotation, value),
    72                 connector)
     37            # Preprocess the data
     38            obj, lookup_type, value = data
     39
     40            if hasattr(value, '__iter__') and hasattr(value, 'next'):
     41                # Consume any generators immediately, so that we can determine
     42                # emptiness and transform any non-empty values correctly.
     43                value = list(value)
     44
     45            # The "annotation" parameter is used to pass auxilliary information
     46            # about the value(s) to the query construction. Specifically, datetime
     47            # and empty values need special handling. Other types could be used
     48            # here in the future (using Python types is suggested for consistency).
     49            if isinstance(value, datetime.datetime):
     50                annotation = datetime.datetime
     51            elif hasattr(value, 'value_annotation'):
     52                annotation = value.value_annotation
     53            else:
     54                annotation = bool(value)
    7355
    74     def as_sql(self, qn, connection):
    75         """
    76         Returns the SQL version of the where clause and the value to be
    77         substituted in. Returns None, None if this node is empty.
     56            if hasattr(obj, "prepare"):
     57                value = obj.prepare(lookup_type, value)
     58            self.data = (obj, lookup_type, annotation, value)
    7859
    79         If 'node' is provided, that is the root of the SQL generation
    80         (generally not needed except by the internal implementation for
    81         recursion).
    82         """
    83         if not self.children:
    84             return None, []
    85         result = []
    86         result_params = []
    87         empty = True
    88         for child in self.children:
    89             try:
    90                 if hasattr(child, 'as_sql'):
    91                     sql, params = child.as_sql(qn=qn, connection=connection)
    92                 else:
    93                     # A leaf node in the tree.
    94                     sql, params = self.make_atom(child, qn, connection)
    95 
    96             except EmptyResultSet:
    97                 if self.connector == AND and not self.negated:
    98                     # We can bail out early in this particular case (only).
    99                     raise
    100                 elif self.negated:
    101                     empty = False
    102                 continue
    103             except FullResultSet:
    104                 if self.connector == OR:
    105                     if self.negated:
    106                         empty = True
    107                         break
    108                     # We match everything. No need for any constraints.
    109                     return '', []
    110                 if self.negated:
    111                     empty = True
    112                 continue
    113 
    114             empty = False
    115             if sql:
    116                 result.append(sql)
    117                 result_params.extend(params)
    118         if empty:
    119             raise EmptyResultSet
     60    def create_sql(self, qn, connection):
     61        if hasattr(self.data, 'as_sql'):
     62            self.sql, self.params = self.data.as_sql(qn, connection)
     63        else:
     64            self.sql, self.params = self.make_atom(qn, connection)
     65        if self.negated and self.sql:
     66            self.sql = 'NOT ' + self.sql
    12067
    121         conn = ' %s ' % self.connector
    122         sql_string = conn.join(result)
    123         if sql_string:
    124             if self.negated:
    125                 sql_string = 'NOT (%s)' % sql_string
    126             elif len(self.children) != 1:
    127                 sql_string = '(%s)' % sql_string
    128         return sql_string, result_params
     68    def as_sql(self):
     69        return self.sql, self.params
    12970
    130     def make_atom(self, child, qn, connection):
     71    def make_atom(self, qn, connection):
    13172        """
    13273        Turn a tuple (table_alias, column_name, db_type, lookup_type,
    13374        value_annot, params) into valid SQL.
    class WhereNode(tree.Node):  
    13576        Returns the string for the SQL fragment and the parameters to use for
    13677        it.
    13778        """
    138         lvalue, lookup_type, value_annot, params_or_value = child
     79        lvalue, lookup_type, value_annot, params_or_value = self.data
    13980        if hasattr(lvalue, 'process'):
     81            from django.db.models.base import ObjectDoesNotExist
    14082            try:
    14183                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
    142             except EmptyShortCircuit:
    143                 raise EmptyResultSet
     84            except ObjectDoesNotExist:
     85                self.set_sql_matches_nothing()
     86                return '', []
    14487        else:
    14588            params = Field().get_db_prep_lookup(lookup_type, params_or_value,
    14689                connection=connection, prepared=True)
    class WhereNode(tree.Node):  
    175118
    176119        if lookup_type == 'in':
    177120            if not value_annot:
    178                 raise EmptyResultSet
     121                self.set_sql_matches_nothing()
     122                return '', []
    179123            if extra:
    180124                return ('%s IN %s' % (field_sql, extra), params)
    181125            max_in_list_size = connection.ops.max_in_list_size()
    class WhereNode(tree.Node):  
    210154            return connection.ops.regex_lookup(lookup_type) % (field_sql, cast_sql), params
    211155
    212156        raise TypeError('Invalid lookup_type: %r' % lookup_type)
     157
     158
     159    def set_sql_matches_nothing(self):
     160        if self.negated:
     161            self.match_all = True
     162        else:
     163            self.match_nothing = True
    213164
     165    def subtree_contains_aggregate(self):
     166        """
     167        A leaf node contains an aggregate if its constraint references an
     168        aggregate, or if its value is a subquery that contains an aggregate.
     169        """
     170        return (isinstance(self.data[0], Aggregate) or
     171                   (len(self.data) == 4 and
     172                    hasattr(self.data[3], 'contains_aggregate') and
     173                    self.data[3].contains_aggregate))
     174   
    214175    def sql_for_columns(self, data, qn, connection):
    215176        """
    216177        Returns the SQL fragment used for the left-hand side of a column
    class WhereNode(tree.Node):  
    224185            lhs = qn(name)
    225186        return connection.ops.field_cast_sql(db_type) % lhs
    226187
    227     def relabel_aliases(self, change_map, node=None):
     188    def relabel_aliases(self, change_map):
     189        if hasattr(self.data, 'relabel_aliases'):
     190            self.data.relabel_aliases(change_map)
     191        elif isinstance(self.data[0], (list, tuple)):
     192            elt = list(self.data[0])
     193            if elt[0] in change_map:
     194                elt[0] = change_map[elt[0]]
     195                self.data = (tuple(elt),) + self.data[1:]
     196        else:
     197            self.data[0].relabel_aliases(change_map)
     198
     199            # Check if the query value also requires relabelling
     200            if hasattr(self.data[3], 'relabel_aliases'):
     201                self.data[3].relabel_aliases(change_map)
     202
     203    def get_group_by(self, group_by):
     204        if isinstance(self.data, tuple) and not isinstance(self.data[0], Aggregate):
     205            group_by.add((self.data[0].alias, self.data[0].col))
     206   
     207    def clone(self):
    228208        """
    229         Relabels the alias values of any children. 'change_map' is a dictionary
    230         mapping old (current) alias values to the new values.
      209        TODO: It is unfortunate that the data can be all sorts of things. It
      210        would be a good idea to make Constraint a somewhat larger class, so
      211        that it could also hold the lookup type and value. Then self.data
      212        would always hold something implementing a uniform interface.
    231213        """
    232         if not node:
    233             node = self
    234         for pos, child in enumerate(node.children):
    235             if hasattr(child, 'relabel_aliases'):
    236                 child.relabel_aliases(change_map)
    237             elif isinstance(child, tree.Node):
    238                 self.relabel_aliases(change_map, child)
    239             elif isinstance(child, (list, tuple)):
    240                 if isinstance(child[0], (list, tuple)):
    241                     elt = list(child[0])
    242                     if elt[0] in change_map:
    243                         elt[0] = change_map[elt[0]]
    244                         node.children[pos] = (tuple(elt),) + child[1:]
    245                 else:
    246                     child[0].relabel_aliases(change_map)
     214        clone = self.__class__(None, self.negated)
     215        if hasattr(self.data, 'clone'):
     216            clone.data = self.data.clone()
     217       
     218        else:
     219            if hasattr(self.data[3], 'clone'):
     220                new_data3 = self.data[3].clone()
     221            else:
     222                new_data3 = self.data[3]
     223            clone.data = (self.data[0].clone(), self.data[1], self.data[2], new_data3)
     224        return clone
     225
     226    def negate(self):
     227        self.negated = not self.negated
    247228
    248                 # Check if the query value also requires relabelling
    249                 if hasattr(child[3], 'relabel_aliases'):
    250                     child[3].relabel_aliases(change_map)
     229    def __str__(self):
     230        return "%s%s, %s, %s" % (self.negated and 'NOT: ' or '',
     231                                 self.data[0], self.data[1], self.data[3])
    251232
    252 class EverythingNode(object):
     233class WhereNode(tree.Node):
    253234    """
    254     A node that matches everything.
     235    Used to represent the SQL where-clause.
     236
     237    The class is tied to the Query class that created it (in order to create
     238    the correct SQL).
     239
      240    The children in this tree are usually WhereLeaf nodes, whose data holds a
      241    (constraint, lookup_type, value_annotation, params) tuple. However, a
      242    child could also be any class with as_sql() and
      243    relabel_aliases() methods.
    255244    """
    256245
    257     def as_sql(self, qn=None, connection=None):
    258         raise FullResultSet
     246    default = AND
     247    is_leaf = False
    259248
    260     def relabel_aliases(self, change_map, node=None):
    261         return
     249    def leaf_class(cls):
     250        # Subclass hook
     251        return WhereLeaf
     252    leaf_class = classmethod(leaf_class)
    262253
    263 class NothingNode(object):
    264     """
    265     A node that matches nothing.
    266     """
    267     def as_sql(self, qn=None, connection=None):
    268         raise EmptyResultSet
     254    def clone_internal(self):
     255        clone = self._new_instance()
     256        clone.negated = self.negated; clone.connector = self.connector
      257        clone.children = [c if c.is_leaf else c.clone() for c in self.children]
     258        return clone
     259                 
     260
     261    def final_prune(self, qn, connection):
     262        """
     263        This will do the final pruning of the tree, that is, removing parts
     264        of the tree that must match everything / nothing.
     265
      266        Because the only way to find that out is to call as_sql(), we turn
      267        the leaf nodes into SQL at the same time.
     268        """
      269        # These variables make sense only in the context of the final prune.
      270        # There is no need to clone them, and there is no need to have them
      271        # elsewhere. So, define them here instead of in __init__.
     272        self.match_all = False
     273        self.match_nothing = False
     274        for child in self.children[:]:
     275            if child.is_leaf:
     276                child.create_sql(qn, connection)
     277            else:
     278                child.final_prune(qn, connection)
     279            if child.match_all:
      280                if self.connector == OR:
      281                    self.match_all = True
      282                    break
      283                self.children.remove(child)
      284            if child.match_nothing:
      285                if self.connector == AND:
      286                    self.match_nothing = True
      287                    break
      288                self.children.remove(child)
     289        else:
     290            # We got through the loop without a break. Check if there are any
     291            # children left. If not, this node must be a match_all node.
     292            if not self.children:
     293                self.match_all = True
     294        if self.negated:
     295            # If the node is negated, then turn the tables around.
     296            self.match_all, self.match_nothing = self.match_nothing, self.match_all
     297   
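
A minimal standalone sketch (not part of the patch) of the pruning rules final_prune() applies above; the prune() name and the connector strings are just for this illustration, and the negation swap at the end is left out.

    def prune(connector, children):
        # children: one (match_all, match_nothing) flag pair per child.
        # Returns this node's own (match_all, match_nothing) after pruning.
        kept = []
        for match_all, match_nothing in children:
            if match_all:
                if connector == 'OR':
                    return True, False      # one tautology makes the OR trivially true
                continue                    # under AND a tautology is simply dropped
            elif match_nothing:
                if connector == 'AND':
                    return False, True      # one contradiction makes the AND trivially false
                continue                    # under OR a contradiction is simply dropped
            kept.append((match_all, match_nothing))
        # No children left means nothing restricts the result set.
        return (not kept), False

    assert prune('OR', [(False, True), (True, False)]) == (True, False)
    assert prune('AND', [(False, False), (False, True)]) == (False, True)
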
     298    def split_aggregates(self, having, parent=None):
     299        """
      300        Remove those parts of self that must go into the HAVING clause. A part
      301        must go into HAVING if:
      302          - it is connected to its parent with OR and its subtree contains
      303            an aggregate, or
      304          - it is a leaf node and contains an aggregate.
     305        """
     306        from django.conf import settings
     307        if self.connector == OR:
     308            if self.subtree_contains_aggregate():
     309                having.add(self, AND)
      310                # Note that OR cannot be the topmost node in the tree: a where
      311                # tree always has AND as its root connector, and as such parent
      312                # can't be None here.
     313                parent.children.remove(self)
     314        else:
     315            if self.negated:
      316                # TODO: I believe this might be broken. If it in fact isn't,
      317                # we need a comment explaining why not.
     318                neg_node = having._new_instance(negated=True)
     319                having.add(neg_node, AND)
     320                having = neg_node
     321            for child in self.children[:]:
     322                if child.is_leaf:
     323                    if child.subtree_contains_aggregate():
     324                        having.add(child, AND)
     325                        self.children.remove(child)
     326                else:
     327                    child.split_aggregates(having, self)
     328
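
A standalone sketch (not part of the patch) of the splitting rule implemented above: aggregate leaves move to HAVING, and an OR subtree moves to HAVING wholesale as soon as any part of it references an aggregate. The node shapes and names here are hypothetical; negation handling is omitted.

    def contains_aggregate(node):
        # Nodes are ('AND'|'OR', [children]); leaves are ('leaf', has_aggregate).
        if node[0] == 'leaf':
            return node[1]
        return any(contains_aggregate(child) for child in node[1])

    def goes_to_having(node):
        if node[0] == 'leaf':
            return node[1]                   # an aggregate leaf moves to HAVING
        if node[0] == 'OR':
            # An OR group cannot be split; it moves as a whole if any part
            # of it references an aggregate.
            return contains_aggregate(node)
        return False                         # AND nodes are split child by child

    # E.g. filtering on price > 10 AND (num_books > 2 OR name = 'x'), where
    # num_books is an aggregate annotation: the whole OR group lands in HAVING.
    or_group = ('OR', [('leaf', True), ('leaf', False)])
    assert goes_to_having(or_group) is True
    assert goes_to_having(('leaf', False)) is False
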
     329    def subtree_contains_aggregate(self):
     330        """
      331        Returns True if any part of this subtree contains an aggregate and
      332        thus needs to go into the HAVING clause.
     333        """
     334        for child in self.children:
     335            if child.subtree_contains_aggregate():
      336                return True
     337        return False
     338
     339    def as_sql(self):
     340        """
      341        Turns this tree into SQL and params, assuming the leaf nodes have
      342        already been turned into SQL by final_prune(). TODO: rename, and have
      343        as_sql implement the normal as_sql(qn, connection) interface.
     344        """
     345        if not self:
     346            return '', []
     347        sql_snippets, params = [], []
     348        for child in self.children:
     349            child_sql, child_params = child.as_sql()
     350            sql_snippets.append(child_sql); params.extend(child_params)
     351
     352        conn = ' %s ' % self.connector
     353        sql_string = conn.join(sql_snippets)
     354        if self.negated and sql_string:
     355            sql_string = 'NOT (%s)' % sql_string
     356        elif len(self.children) != 1:
     357            sql_string = '(%s)' % sql_string
     358        return sql_string, params
     359
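
A standalone sketch (not part of the patch) of how as_sql() above combines the already-compiled child fragments: join them with the connector, then wrap in NOT (...) or plain parentheses. The combine() helper is hypothetical.

    def combine(child_sql, connector='AND', negated=False):
        if not child_sql:
            return ''
        sql = (' %s ' % connector).join(child_sql)
        if negated:
            return 'NOT (%s)' % sql
        if len(child_sql) != 1:
            return '(%s)' % sql
        return sql

    assert combine(['"a"."x" = %s', '"a"."y" > %s']) == '("a"."x" = %s AND "a"."y" > %s)'
    assert combine(['"a"."x" = %s'], negated=True) == 'NOT ("a"."x" = %s)'
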
     360    def get_group_by(self, group_by):
     361        for child in self.children:
     362            child.get_group_by(group_by)
    269363
    270364    def relabel_aliases(self, change_map, node=None):
    271         return
     365        """
     366        Relabels the alias values of any children. 'change_map' is a dictionary
     367        mapping old (current) alias values to the new values.
     368        """
     369        for child in self.children:
     370            child.relabel_aliases(change_map)
    272371
    273372class ExtraWhere(object):
    274373    def __init__(self, sqls, params):
    275374        self.sqls = sqls
    276375        self.params = params
    277376
     377    def relabel_aliases(self, change_map):
     378        return
     379
    278380    def as_sql(self, qn=None, connection=None):
    279381        return " AND ".join(self.sqls), tuple(self.params or ())
    280382
     383    def clone(self):
     384        return self
     385
    281386class Constraint(object):
    282387    """
    283388    An object that can be passed to WhereNode.add() and knows how to
    284389    pre-process itself prior to including in the WhereNode.
    285390    """
     391   
    286392    def __init__(self, alias, col, field):
    287393        self.alias, self.col, self.field = alias, col, field
    288394
    class Constraint(object):  
    318424    def process(self, lookup_type, value, connection):
    319425        """
    320426        Returns a tuple of data suitable for inclusion in a WhereNode
    321         instance.
     427        instance. Can raise ObjectDoesNotExist
    322428        """
    323         # Because of circular imports, we need to import this here.
    324         from django.db.models.base import ObjectDoesNotExist
    325         try:
    326             if self.field:
    327                 params = self.field.get_db_prep_lookup(lookup_type, value,
    328                     connection=connection, prepared=True)
    329                 db_type = self.field.db_type(connection=connection)
    330             else:
    331                 # This branch is used at times when we add a comparison to NULL
    332                 # (we don't really want to waste time looking up the associated
    333                 # field object at the calling location).
    334                 params = Field().get_db_prep_lookup(lookup_type, value,
    335                     connection=connection, prepared=True)
    336                 db_type = None
    337         except ObjectDoesNotExist:
    338             raise EmptyShortCircuit
    339 
     429        if self.field:
     430            params = self.field.get_db_prep_lookup(lookup_type, value,
     431                connection=connection, prepared=True)
     432            db_type = self.field.db_type(connection=connection)
     433        else:
     434            # This branch is used at times when we add a comparison to NULL
     435            # (we don't really want to waste time looking up the associated
     436            # field object at the calling location).
     437            params = Field().get_db_prep_lookup(lookup_type, value,
     438                connection=connection, prepared=True)
     439            db_type = None
    340440        return (self.alias, self.col, db_type), params
    341441
    342442    def relabel_aliases(self, change_map):
    343443        if self.alias in change_map:
    344444            self.alias = change_map[self.alias]
     445
     446    def clone(self):
     447        return Constraint(self.alias, self.col, self.field)
     448
     449    def __str__(self):
     450        return "%s.%s" % (self.alias, self.col)
  • django/utils/tree.py

    diff --git a/django/utils/tree.py b/django/utils/tree.py
    index 36b5977..f733d1b 100644
    a b class Node(object):  
    1919        """
    2020        Constructs a new Node. If no connector is given, the default will be
    2121        used.
    22 
    23         Warning: You probably don't want to pass in the 'negated' parameter. It
    24         is NOT the same as constructing a node and calling negate() on the
    25         result.
    2622        """
    2723        self.children = children and children[:] or []
    2824        self.connector = connector or self.default
    29         self.subtree_parents = []
     25        self.parent = None
    3026        self.negated = negated
    3127
    3228    # We need this because of django.db.models.query_utils.Q. Q. __init__() is
    3329    # problematic, but it is a natural Node subclass in all other respects.
      30    # The __init__ of Q has a different signature, and thus _new_instance of Q
     31    # does call Q's version of __init__.
    3432    def _new_instance(cls, children=None, connector=None, negated=False):
     33        return cls(children, connector, negated)
     34    _new_instance = classmethod(_new_instance)
     35
     36    def clone(self):
    3537        """
    36         This is called to create a new instance of this class when we need new
    37         Nodes (or subclasses) in the internal code in this class. Normally, it
    38         just shadows __init__(). However, subclasses with an __init__ signature
    39         that is not an extension of Node.__init__ might need to implement this
    40         method to allow a Node to create a new instance of them (if they have
    41         any extra setting up to do).
      38        Clones the internal nodes of the tree. Leaf nodes (plain tuples) are
      39        not copied. This is a useful optimization for WhereNode
      40        because WhereLeaf nodes do not need copying except when relabel_aliases
      41        is called.
    4242        """
    43         obj = Node(children, connector, negated)
    44         obj.__class__ = cls
     43        obj = self._new_instance()
      44        obj.children = [
      45            c if isinstance(c, tuple) else c.clone() for c in self.children
      46        ]
     47        obj.connector = self.connector
     48        obj.negated = self.negated
    4549        return obj
    46     _new_instance = classmethod(_new_instance)
     50
     51    def __repr__(self):
     52        return self.as_subtree
    4753
    4854    def __str__(self):
    4955        if self.negated:
    class Node(object):  
    5258        return '(%s: %s)' % (self.connector, ', '.join([str(c) for c in
    5359                self.children]))
    5460
    55     def __deepcopy__(self, memodict):
    56         """
    57         Utility method used by copy.deepcopy().
    58         """
    59         obj = Node(connector=self.connector, negated=self.negated)
    60         obj.__class__ = self.__class__
    61         obj.children = copy.deepcopy(self.children, memodict)
    62         obj.subtree_parents = copy.deepcopy(self.subtree_parents, memodict)
    63         return obj
     61    def _as_subtree(self, indent=0):
     62        buf = []
     63        if self.negated:
     64            buf.append(" " * indent + "NOT")
     65        buf.append((" " * indent) + self.connector + ":")
     66        indent += 2
     67        for child in self.children:
     68            if isinstance(child, Node):
     69                buf.append(child._as_subtree(indent=indent))
     70            else:
     71                buf.append((" " * indent) + str(child))
     72        return "\n".join(buf)
     73    as_subtree = property(_as_subtree)
    6474
    6575    def __len__(self):
    6676        """
    class Node(object):  
    8292
    8393    def add(self, node, conn_type):
    8494        """
    85         Adds a new node to the tree. If the conn_type is the same as the root's
    86         current connector type, the node is added to the first level.
     95        Adds a new node to the tree. If the conn_type is the same as the
     96        root's current connector type, the node is added to the first level.
    8797        Otherwise, the whole tree is pushed down one level and a new root
    88         connector is created, connecting the existing tree and the new node.
     98        connector is created, connecting the existing tree and the added node.
    8999        """
    90100        if node in self.children and conn_type == self.connector:
    91101            return
    92         if len(self.children) < 2:
    93             self.connector = conn_type
    94102        if self.connector == conn_type:
    95             if isinstance(node, Node) and (node.connector == conn_type or
    96                     len(node) == 1):
    97                 self.children.extend(node.children)
    98             else:
    99                 self.children.append(node)
     103            self.children.append(node)
    100104        else:
    101             obj = self._new_instance(self.children, self.connector,
    102                     self.negated)
    103             self.connector = conn_type
    104             self.children = [obj, node]
     105            obj = self._new_instance([node], conn_type)
     106            self.children.append(obj)
    105107
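
A rough standalone illustration (not part of the patch) of the new add() semantics, using a throwaway MiniNode class rather than django.utils.tree: with a matching connector the node is appended at the top level, otherwise it is wrapped in a fresh subnode so the root connector never changes.

    class MiniNode(object):
        def __init__(self, children=None, connector='AND'):
            self.children = children or []
            self.connector = connector

        def add(self, node, conn_type):
            if self.connector == conn_type:
                self.children.append(node)
            else:
                # Wrap the new node instead of pushing the existing tree down.
                self.children.append(MiniNode([node], conn_type))

    root = MiniNode(['a', 'b'])
    root.add('c', 'AND')   # -> AND: ['a', 'b', 'c']
    root.add('d', 'OR')    # -> AND: ['a', 'b', 'c', OR: ['d']]
    assert root.connector == 'AND' and len(root.children) == 4
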
    106108    def negate(self):
    107109        """
    108         Negate the sense of the root connector. This reorganises the children
    109         so that the current node has a single child: a negated node containing
    110         all the previous children. This slightly odd construction makes adding
    111         new children behave more intuitively.
    112 
    113         Interpreting the meaning of this negate is up to client code. This
    114         method is useful for implementing "not" arrangements.
    115         """
    116         self.children = [self._new_instance(self.children, self.connector,
    117                 not self.negated)]
    118         self.connector = self.default
    119 
    120     def start_subtree(self, conn_type):
    121         """
    122         Sets up internal state so that new nodes are added to a subtree of the
    123         current node. The conn_type specifies how the sub-tree is joined to the
    124         existing children.
    125         """
    126         if len(self.children) == 1:
    127             self.connector = conn_type
    128         elif self.connector != conn_type:
    129             self.children = [self._new_instance(self.children, self.connector,
    130                     self.negated)]
    131             self.connector = conn_type
    132             self.negated = False
    133 
    134         self.subtree_parents.append(self.__class__(self.children,
    135                 self.connector, self.negated))
    136         self.connector = self.default
    137         self.negated = False
    138         self.children = []
    139 
    140     def end_subtree(self):
    141         """
    142         Closes off the most recently unmatched start_subtree() call.
    143 
    144         This puts the current state into a node of the parent tree and returns
    145         the current instances state to be the parent.
    146         """
    147         obj = self.subtree_parents.pop()
    148         node = self.__class__(self.children, self.connector)
    149         self.connector = obj.connector
    150         self.negated = obj.negated
    151         self.children = obj.children
    152         self.children.append(node)
    153 
     110        Negate the sense of this node.
     111        """
     112        self.negated = not self.negated
     113
     114    def prune_tree(self):
     115        """
      116        Removes empty child nodes and unnecessary intermediary
      117        nodes from this node.
     118        """
     119        for child in self.children[:]:
     120            if not child:
     121                self.children.remove(child)
     122            elif not child.is_leaf:
     123                child.prune_tree()
     124                if len(child) == 1:
      125                    # There is no need for this node. We can prune internal
      126                    # nodes with just one child.
     127                    grandchild = child.children[0]
     128                    if child.negated:
     129                        grandchild.negate()
     130                    self.children.remove(child)
     131                    self.children.append(grandchild)
     132                elif not child:
     133                    self.children.remove(child)
  • tests/regressiontests/aggregation_regress/tests.py

    diff --git a/tests/regressiontests/aggregation_regress/tests.py b/tests/regressiontests/aggregation_regress/tests.py
    index acdc59a..badc1cb 100644
    a b class AggregationTests(TestCase):  
    465465        # Regression for #15709 - Ensure each group_by field only exists once
    466466        # per query
    467467        qs = Book.objects.values('publisher').annotate(max_pages=Max('pages')).order_by()
    468         grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping()
     468        grouping, gb_params = qs.query.get_compiler(qs.db).get_grouping(set())
    469469        self.assertEqual(len(grouping), 1)
    470470
    471471    def test_duplicate_alias(self):
  • tests/regressiontests/queries/tests.py

    diff --git a/tests/regressiontests/queries/tests.py b/tests/regressiontests/queries/tests.py
    index d8fd5bc..4f505a3 100644
    a b class Queries1Tests(BaseQuerysetTest):  
    820820        q = Note.objects.filter(Q(extrainfo__author=self.a1)|Q(extrainfo=xx)).query
    821821        self.assertEqual(
    822822            len([x[2] for x in q.alias_map.values() if x[2] == q.LOUTER and q.alias_refcount[x[1]]]),
    823             1
     823            2
    824824        )
    825825
    826826