Ticket #14030: 14030-2.patch

File 14030-2.patch, 20.1 KB (added by Nate Bragg, 13 years ago)

Removed "aggregate that in fact is not an aggregate" hack

  • django/db/models/aggregates.py

    From cf61c2dd2f0afec3414de3dff8497fc461491b3c Mon Sep 17 00:00:00 2001
    From: Nate Bragg <jonathan.bragg@alum.rpi.edu>
    Date: Thu, 19 Jan 2012 21:01:32 -0500
    Subject: [PATCH] An attempt at rebasing out the changes required for
     supporting F expressions in aggregation from the more
     complex patch supporting conditional aggregation for
     #11305.
    
    Additional changes needed to make F expressions usable without
    being passed in inside an aggregation function.
    
    Also added some doc, and some tests.
    ---
     django/db/models/aggregates.py        |    2 +
     django/db/models/sql/aggregates.py    |   19 ++++++-
     django/db/models/sql/compiler.py      |   39 ++++++++-----
     django/db/models/sql/expressions.py   |    3 +
     django/db/models/sql/query.py         |   96 +++++++++++++++++++--------------
     django/db/models/sql/where.py         |   10 ++++
     django/test/testcases.py              |   12 ++++
     docs/ref/models/querysets.txt         |   23 ++++++++
     tests/modeltests/aggregation/tests.py |   31 +++++++++++
     9 files changed, 176 insertions(+), 59 deletions(-)
    
    diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py
    index a2349cf..61848fe 100644
    a b class Aggregate(object):  
    2020        self.extra = extra
    2121
    2222    def _default_alias(self):
     23        if hasattr(self.lookup, 'evaluate'):
     24             raise ValueError('When aggregating over an expression, you need to give an alias.')
    2325        return '%s__%s' % (self.lookup, self.name.lower())
    2426    default_alias = property(_default_alias)
    2527
  • django/db/models/sql/aggregates.py

    diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py
    index 207bc0c..27175fe 100644
    a b  
    11"""
    22Classes to represent the default SQL aggregate functions
    33"""
     4from django.db.models.sql.expressions import SQLEvaluator
    45
    56class AggregateField(object):
    67    """An internal field mockup used to identify aggregates in the
    class Aggregate(object):  
    6667                tmp = computed_aggregate_field
    6768            else:
    6869                tmp = tmp.source
    69 
     70       
     71        # We don't know the real source of this aggregate, and the
     72        # aggregate doesn't define ordinal or computed either. So
     73        # we default to computed for these cases.
     74        if tmp is None:
     75            tmp = computed_aggregate_field
    7076        self.field = tmp
    7177
    7278    def relabel_aliases(self, change_map):
    7379        if isinstance(self.col, (list, tuple)):
    7480            self.col = (change_map.get(self.col[0], self.col[0]), self.col[1])
     81        else:
     82            self.col.relabel_aliases(change_map)
    7583
    7684    def as_sql(self, qn, connection):
    7785        "Return the aggregate, rendered as SQL."
    7886
     87        col_params = []
    7988        if hasattr(self.col, 'as_sql'):
    80             field_name = self.col.as_sql(qn, connection)
     89            if isinstance(self.col, SQLEvaluator):
     90                field_name, col_params = self.col.as_sql(qn, connection)
     91            else:
     92                field_name = self.col.as_sql(qn, connection)
     93           
    8194        elif isinstance(self.col, (list, tuple)):
    8295            field_name = '.'.join([qn(c) for c in self.col])
    8396        else:
    class Aggregate(object):  
    89102        }
    90103        params.update(self.extra)
    91104
    92         return self.sql_template % params
     105        return (self.sql_template % params, col_params)
    93106
    94107
    95108class Avg(Aggregate):
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 72948f9..bf3cb25 100644
    a b class SQLCompiler(object):  
    6868        # as the pre_sql_setup will modify query state in a way that forbids
    6969        # another run of it.
    7070        self.refcounts_before = self.query.alias_refcount.copy()
    71         out_cols = self.get_columns(with_col_aliases)
     71        out_cols, c_params = self.get_columns(with_col_aliases)
    7272        ordering, ordering_group_by = self.get_ordering()
    7373
    7474        distinct_fields = self.get_distinct()
    class SQLCompiler(object):  
    8484        params = []
    8585        for val in self.query.extra_select.itervalues():
    8686            params.extend(val[1])
     87        # Extra-select comes before aggregation in the select list
     88        params.extend(c_params)
    8789
    8890        result = ['SELECT']
    8991
    class SQLCompiler(object):  
    178180        qn = self.quote_name_unless_alias
    179181        qn2 = self.connection.ops.quote_name
    180182        result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()]
     183        query_params = []
    181184        aliases = set(self.query.extra_select.keys())
    182185        if with_aliases:
    183186            col_aliases = aliases.copy()
    class SQLCompiler(object):  
    220223            aliases.update(new_aliases)
    221224
    222225        max_name_length = self.connection.ops.max_name_length()
    223         result.extend([
    224             '%s%s' % (
    225                 aggregate.as_sql(qn, self.connection),
    226                 alias is not None
    227                     and ' AS %s' % qn(truncate_name(alias, max_name_length))
    228                     or ''
     226        for alias, aggregate in self.query.aggregate_select.items():
     227            sql, params = aggregate.as_sql(qn, self.connection)
     228            result.append(
     229                '%s%s' % (
     230                    sql,
     231                    alias is not None
     232                       and ' AS %s' % qn(truncate_name(alias, max_name_length))
     233                       or ''
     234                )
    229235            )
    230             for alias, aggregate in self.query.aggregate_select.items()
    231         ])
     236            query_params.extend(params)
    232237
    233238        for table, col in self.query.related_select_cols:
    234239            r = '%s.%s' % (qn(table), qn(col))
    class SQLCompiler(object):  
    243248                col_aliases.add(col)
    244249
    245250        self._select_aliases = aliases
    246         return result
     251        return result, query_params
    247252
    248253    def get_default_columns(self, with_aliases=False, col_aliases=None,
    249254            start_alias=None, opts=None, as_pairs=False, local_only=False):
    class SQLAggregateCompiler(SQLCompiler):  
    10461051        """
    10471052        if qn is None:
    10481053            qn = self.quote_name_unless_alias
     1054        buf = []
     1055        a_params = []
     1056        for aggregate in self.query.aggregate_select.values():
     1057            sql, query_params = aggregate.as_sql(qn, self.connection)
     1058            buf.append(sql)
     1059            a_params.extend(query_params)
     1060        aggregate_sql = ', '.join(buf)
    10491061
    10501062        sql = ('SELECT %s FROM (%s) subquery' % (
    1051             ', '.join([
    1052                 aggregate.as_sql(qn, self.connection)
    1053                 for aggregate in self.query.aggregate_select.values()
    1054             ]),
     1063            aggregate_sql, 
    10551064            self.query.subquery)
    10561065        )
    1057         params = self.query.sub_params
     1066        params = tuple(a_params) + (self.query.sub_params)
    10581067        return (sql, params)
    10591068
    10601069class SQLDateCompiler(SQLCompiler):
  • django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    index 1bbf742..f9c23a9 100644
    a b class SQLEvaluator(object):  
    6565        for child in node.children:
    6666            if hasattr(child, 'evaluate'):
    6767                sql, params = child.evaluate(self, qn, connection)
     68                if isinstance(sql, tuple):
     69                    expression_params.extend(sql[1])
     70                    sql = sql[0]
    6871            else:
    6972                sql, params = '%s', (child,)
    7073
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index ed2bc06..5624446 100644
    a b from django.utils.encoding import force_unicode  
    1414from django.utils.tree import Node
    1515from django.db import connections, DEFAULT_DB_ALIAS
    1616from django.db.models import signals
     17from django.db.models.aggregates import Aggregate
    1718from django.db.models.expressions import ExpressionNode
    1819from django.db.models.fields import FieldDoesNotExist
    1920from django.db.models.query_utils import InvalidQuery
    class Query(object):  
    322323        This is required because of the predisposition of certain backends
    323324        to return Decimal and long types when they are not needed.
    324325        """
     326        is_ordinal = getattr(aggregate,"is_ordinal",False)
     327        is_computed = getattr(aggregate,"is_computed",True)
    325328        if value is None:
    326             if aggregate.is_ordinal:
     329            if is_ordinal:
    327330                return 0
    328331            # Return None as-is
    329332            return value
    330         elif aggregate.is_ordinal:
     333        elif is_ordinal:
    331334            # Any ordinal aggregate (e.g., count) returns an int
    332335            return int(value)
    333         elif aggregate.is_computed:
     336        elif is_computed:
    334337            # Any computed aggregate (e.g., avg) returns a float
    335338            return float(value)
    336339        else:
    class Query(object):  
    987990        Adds a single aggregate expression to the Query
    988991        """
    989992        opts = model._meta
    990         field_list = aggregate.lookup.split(LOOKUP_SEP)
    991         if len(field_list) == 1 and aggregate.lookup in self.aggregates:
    992             # Aggregate is over an annotation
    993             field_name = field_list[0]
    994             col = field_name
    995             source = self.aggregates[field_name]
    996             if not is_summary:
    997                 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
    998                     aggregate.name, field_name, field_name))
    999         elif ((len(field_list) > 1) or
    1000             (field_list[0] not in [i.name for i in opts.fields]) or
    1001             self.group_by is None or
    1002             not is_summary):
    1003             # If:
    1004             #   - the field descriptor has more than one part (foo__bar), or
    1005             #   - the field descriptor is referencing an m2m/m2o field, or
    1006             #   - this is a reference to a model field (possibly inherited), or
    1007             #   - this is an annotation over a model field
    1008             # then we need to explore the joins that are required.
    1009 
    1010             field, source, opts, join_list, last, _ = self.setup_joins(
    1011                 field_list, opts, self.get_initial_alias(), False)
    1012 
    1013             # Process the join chain to see if it can be trimmed
    1014             col, _, join_list = self.trim_joins(source, join_list, last, False)
    1015 
    1016             # If the aggregate references a model or field that requires a join,
    1017             # those joins must be LEFT OUTER - empty join rows must be returned
    1018             # in order for zeros to be returned for those aggregates.
    1019             for column_alias in join_list:
    1020                 self.promote_alias(column_alias, unconditional=True)
    1021 
    1022             col = (join_list[-1], col)
     993        if hasattr(aggregate, 'evaluate'):
     994            self.aggregates[alias] = SQLEvaluator(aggregate, self)
     995            return
     996        if hasattr(aggregate.lookup, 'evaluate'):
     997            # If lookup is a query expression, evaluate it
     998            col = SQLEvaluator(aggregate.lookup, self)
     999            # TODO: find out the real source of this field. If any field has
     1000            # is_computed, then source can be set to is_computed.
     1001            source = None
    10231002        else:
    1024             # The simplest cases. No joins required -
    1025             # just reference the provided column alias.
    1026             field_name = field_list[0]
    1027             source = opts.get_field(field_name)
    1028             col = field_name
     1003            field_list = aggregate.lookup.split(LOOKUP_SEP)
     1004            join_list = []
     1005            if len(field_list) == 1 and aggregate.lookup in self.aggregates:
     1006                # Aggregate is over an annotation
     1007                field_name = field_list[0]
     1008                col = field_name
     1009                source = self.aggregates[field_name]
     1010                if not is_summary:
     1011                    raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % (
     1012                        aggregate.name, field_name, field_name))
     1013            elif ((len(field_list) > 1) or
     1014                (field_list[0] not in [i.name for i in opts.fields]) or
     1015                self.group_by is None or
     1016                not is_summary):
     1017                # If:
     1018                #   - the field descriptor has more than one part (foo__bar), or
     1019                #   - the field descriptor is referencing an m2m/m2o field, or
     1020                #   - this is a reference to a model field (possibly inherited), or
     1021                #   - this is an annotation over a model field
     1022                # then we need to explore the joins that are required.
     1023
     1024                field, source, opts, join_list, last, _ = self.setup_joins(
     1025                    field_list, opts, self.get_initial_alias(), False)
     1026
     1027                # Process the join chain to see if it can be trimmed
     1028                col, _, join_list = self.trim_joins(source, join_list, last, False)
     1029
     1030                # If the aggregate references a model or field that requires a join,
     1031                # those joins must be LEFT OUTER - empty join rows must be returned
     1032                # in order for zeros to be returned for those aggregates.
     1033                for column_alias in join_list:
     1034                    self.promote_alias(column_alias, unconditional=True)
     1035
     1036                col = (join_list[-1], col)
     1037            else:
     1038                # The simplest cases. No joins required -
     1039                # just reference the provided column alias.
     1040                field_name = field_list[0]
     1041                source = opts.get_field(field_name)
     1042                col = field_name
    10291043
    10301044        # Add the aggregate to the query
    10311045        aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary)
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 1455ba6..8b530bd 100644
    a b class WhereNode(tree.Node):  
    139139        it.
    140140        """
    141141        lvalue, lookup_type, value_annot, params_or_value = child
     142        additional_params = []
    142143        if hasattr(lvalue, 'process'):
    143144            try:
    144145                lvalue, params = lvalue.process(lookup_type, params_or_value, connection)
    class WhereNode(tree.Node):  
    153154        else:
    154155            # A smart object with an as_sql() method.
    155156            field_sql = lvalue.as_sql(qn, connection)
     157            if isinstance(field_sql, tuple):
     158                # It returned also params
     159                additional_params.extend(field_sql[1])
     160                field_sql = field_sql[0]
    156161
    157162        if value_annot is datetime.datetime:
    158163            cast_sql = connection.ops.datetime_cast_sql()
    class WhereNode(tree.Node):  
    161166
    162167        if hasattr(params, 'as_sql'):
    163168            extra, params = params.as_sql(qn, connection)
     169            if isinstance(extra, tuple):
     170                params = params + tuple(extra[1])
     171                extra = extra[0]
    164172            cast_sql = ''
    165173        else:
    166174            extra = ''
    class WhereNode(tree.Node):  
    170178            lookup_type = 'isnull'
    171179            value_annot = True
    172180
     181        additional_params.extend(params)
     182        params = additional_params
    173183        if lookup_type in connection.operators:
    174184            format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),)
    175185            return (format % (field_sql,
  • django/test/testcases.py

    diff --git a/django/test/testcases.py b/django/test/testcases.py
    index af455a2..17326ee 100644
    a b class TransactionTestCase(SimpleTestCase):  
    774774            return self.assertEqual(set(map(transform, qs)), set(values))
    775775        return self.assertEqual(map(transform, qs), values)
    776776
     777    def assertQuerysetAlmostEqual(self, qs, values, transform=repr, ordered=True, places=7):
     778        # This could have been done with iterating zip(map(transform, qs), values),
     779        # checking each with assertAlmostEqual, which rounds the difference of each
     780        # pair, but this way you get much nicer error messages, and you can have an
     781        # unordered comparison, at the cost of a half a digit of accuracy.
     782        round_to = lambda v: round(v,places)
     783        tqs = map(round_to, map(transform, qs) )
     784        tvs = map(round_to, values)
     785        if not ordered:
     786            return self.assertEqual(set(tqs), set(tvs))
     787        return self.assertEqual(tqs, tvs)
     788
    777789    def assertNumQueries(self, num, func=None, *args, **kwargs):
    778790        using = kwargs.pop("using", DEFAULT_DB_ALIAS)
    779791        conn = connections[using]
  • docs/ref/models/querysets.txt

    diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt
    index 103cae1..a0aab2b 100644
    a b control the name of the annotation::  
    246246    >>> q[0].number_of_entries
    247247    42
    248248
     249In addition to aggregation functions, `:ref:`F() objects <query-expressions>`
     250can be used to perform a specific mathematical operation::
     251
     252    # The 1.0 is to force float conversion
     253    >>> q = Entry.objects.annotate(cpb_ratio=F('n_comments')*1.0/F('n_pingbacks'))
     254    # The ratio of comments to pingbacks for the first blog entry
     255    >>> q[0].cpb_ratio
     256    0.0625
     257
    249258For an in-depth discussion of aggregation, see :doc:`the topic guide on
    250259Aggregation </topics/db/aggregation>`.
    251260
    control the name of the aggregation value that is returned::  
    14831492    >>> q = Blog.objects.aggregate(number_of_entries=Count('entry'))
    14841493    {'number_of_entries': 16}
    14851494
     1495Inside aggregation functions, `:ref:`F() objects <query-expressions>`
     1496can be used to perform a specific mathematical operation::
     1497
     1498    # The 1.0 is to force float conversion
     1499    >>> q = Entry.objects.aggregate(avg_cpb_ratio=Avg(F('n_comments')*1.0/F('n_pingbacks')))
     1500    {'avg_cpb_ratio': 0.125}
     1501
    14861502For an in-depth discussion of aggregation, see :doc:`the topic guide on
    14871503Aggregation </topics/db/aggregation>`.
    14881504
    Django provides the following aggregation functions in the  
    21172133aggregate functions, see
    21182134:doc:`the topic guide on aggregation </topics/db/aggregation>`.
    21192135
     2136Note that in addition to taking a named field, aggregation
     2137functions can take `:ref:`F() objects <query-expressions>`.
     2138
     2139.. admonition:: Default aliases
     2140
     2141    When using ``F()`` objects, note that there is no default alias.
     2142
    21202143Avg
    21212144~~~
    21222145
  • tests/modeltests/aggregation/tests.py

    diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py
    index a35dbb3..a5d3a4e 100644
    a b import datetime  
    44from decimal import Decimal
    55
    66from django.db.models import Avg, Sum, Count, Max, Min
     7from django.db.models import F
    78from django.test import TestCase, Approximate
    89
    910from .models import Author, Publisher, Book, Store
    class BaseAggregateTestCase(TestCase):  
    6364        self.assertEqual(len(vals), 1)
    6465        self.assertAlmostEqual(vals["amazon_mean"], 4.08, places=2)
    6566
     67    def test_aggregate_f_expression(self):
     68        vals = Book.objects.all().aggregate(price_per_page=Avg(F('price')*1.0/F('pages')))
     69        self.assertEqual(len(vals), 1)
     70        self.assertAlmostEqual(vals["price_per_page"], 0.0745110754864109, places=2)
     71
     72    def test_annotate_f_expression(self):
     73        self.assertQuerysetAlmostEqual(
     74            Book.objects.all().annotate(price_per_page=F('price')*1.0/F('pages')), [
     75                0.0671140939597315,
     76                0.0437310606060606,
     77                0.0989666666666667,
     78                0.0848285714285714,
     79                0.0731448763250883,
     80                0.0792811839323467,
     81            ],
     82            lambda b: b.price_per_page,
     83            places=4
     84        )
     85
     86        self.assertQuerysetAlmostEqual(
     87            Publisher.objects.all().annotate(price_per_page=Avg(F('book__price')*1.0/F('book__pages'))), [
     88                0.0830403803131991,
     89                0.0437310606060606,
     90                0.0789867238768299,
     91                0.0792811839323467,
     92            ],
     93            lambda p: p.price_per_page,
     94            places=4
     95        )
     96
    6697    def test_annotate_basic(self):
    6798        self.assertQuerysetEqual(
    6899            Book.objects.annotate().order_by('pk'), [
Back to Top