Ticket #14030: 14030.patch
File 14030.patch, 19.8 KB (added by , 13 years ago) |
---|
-
django/db/models/aggregates.py
From c7a74c08def758c62997ba037eccfb8f73ba3efc Mon Sep 17 00:00:00 2001 From: Nate Bragg <jonathan.bragg@alum.rpi.edu> Date: Thu, 19 Jan 2012 21:01:32 -0500 Subject: [PATCH] An attempt at rebasing out the changes required for supporting F expressions in aggregation from the more complex patch supporting conditional aggregation for #11305. Additional changes needed to make F expressions usable without being passed in inside an aggregation function. Also added some doc, and some tests. --- django/db/models/aggregates.py | 2 + django/db/models/sql/aggregates.py | 20 ++++++- django/db/models/sql/compiler.py | 39 +++++++++----- django/db/models/sql/expressions.py | 3 + django/db/models/sql/query.py | 93 +++++++++++++++++++------------- django/db/models/sql/where.py | 10 ++++ django/test/testcases.py | 12 ++++ docs/ref/models/querysets.txt | 23 ++++++++ tests/modeltests/aggregation/tests.py | 31 +++++++++++ 9 files changed, 177 insertions(+), 56 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a2349cf..61848fe 100644
a b class Aggregate(object): 20 20 self.extra = extra 21 21 22 22 def _default_alias(self): 23 if hasattr(self.lookup, 'evaluate'): 24 raise ValueError('When aggregating over an expression, you need to give an alias.') 23 25 return '%s__%s' % (self.lookup, self.name.lower()) 24 26 default_alias = property(_default_alias) 25 27 -
django/db/models/sql/aggregates.py
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 207bc0c..7e131b9 100644
a b 1 1 """ 2 2 Classes to represent the default SQL aggregate functions 3 3 """ 4 from django.db.models.sql.expressions import SQLEvaluator 4 5 5 6 class AggregateField(object): 6 7 """An internal field mockup used to identify aggregates in the … … class Aggregate(object): 22 23 is_ordinal = False 23 24 is_computed = False 24 25 sql_template = '%(function)s(%(field)s)' 26 sql_function = '' 25 27 26 28 def __init__(self, col, source=None, is_summary=False, **extra): 27 29 """Instantiate an SQL aggregate … … class Aggregate(object): 66 68 tmp = computed_aggregate_field 67 69 else: 68 70 tmp = tmp.source 69 71 72 # We don't know the real source of this aggregate, and the 73 # aggregate doesn't define ordinal or computed either. So 74 # we default to computed for these cases. 75 if tmp is None: 76 tmp = computed_aggregate_field 70 77 self.field = tmp 71 78 72 79 def relabel_aliases(self, change_map): 73 80 if isinstance(self.col, (list, tuple)): 74 81 self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 82 else: 83 self.col.relabel_aliases(change_map) 75 84 76 85 def as_sql(self, qn, connection): 77 86 "Return the aggregate, rendered as SQL." 78 87 88 col_params = [] 79 89 if hasattr(self.col, 'as_sql'): 80 field_name = self.col.as_sql(qn, connection) 90 if isinstance(self.col, SQLEvaluator): 91 field_name, col_params = self.col.as_sql(qn, connection) 92 else: 93 field_name = self.col.as_sql(qn, connection) 94 81 95 elif isinstance(self.col, (list, tuple)): 82 96 field_name = '.'.join([qn(c) for c in self.col]) 83 97 else: … … class Aggregate(object): 89 103 } 90 104 params.update(self.extra) 91 105 92 return self.sql_template % params106 return (self.sql_template % params, col_params) 93 107 94 108 95 109 class Avg(Aggregate): -
django/db/models/sql/compiler.py
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 72948f9..bf3cb25 100644
a b class SQLCompiler(object): 68 68 # as the pre_sql_setup will modify query state in a way that forbids 69 69 # another run of it. 70 70 self.refcounts_before = self.query.alias_refcount.copy() 71 out_cols = self.get_columns(with_col_aliases)71 out_cols, c_params = self.get_columns(with_col_aliases) 72 72 ordering, ordering_group_by = self.get_ordering() 73 73 74 74 distinct_fields = self.get_distinct() … … class SQLCompiler(object): 84 84 params = [] 85 85 for val in self.query.extra_select.itervalues(): 86 86 params.extend(val[1]) 87 # Extra-select comes before aggregation in the select list 88 params.extend(c_params) 87 89 88 90 result = ['SELECT'] 89 91 … … class SQLCompiler(object): 178 180 qn = self.quote_name_unless_alias 179 181 qn2 = self.connection.ops.quote_name 180 182 result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] 183 query_params = [] 181 184 aliases = set(self.query.extra_select.keys()) 182 185 if with_aliases: 183 186 col_aliases = aliases.copy() … … class SQLCompiler(object): 220 223 aliases.update(new_aliases) 221 224 222 225 max_name_length = self.connection.ops.max_name_length() 223 result.extend([ 224 '%s%s' % ( 225 aggregate.as_sql(qn, self.connection), 226 alias is not None 227 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 228 or '' 226 for alias, aggregate in self.query.aggregate_select.items(): 227 sql, params = aggregate.as_sql(qn, self.connection) 228 result.append( 229 '%s%s' % ( 230 sql, 231 alias is not None 232 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 233 or '' 234 ) 229 235 ) 230 for alias, aggregate in self.query.aggregate_select.items() 231 ]) 236 query_params.extend(params) 232 237 233 238 for table, col in self.query.related_select_cols: 234 239 r = '%s.%s' % (qn(table), qn(col)) … … class SQLCompiler(object): 243 248 col_aliases.add(col) 244 249 245 250 self._select_aliases = aliases 246 return result 251 return result, query_params 247 252 248 
253 def get_default_columns(self, with_aliases=False, col_aliases=None, 249 254 start_alias=None, opts=None, as_pairs=False, local_only=False): … … class SQLAggregateCompiler(SQLCompiler): 1046 1051 """ 1047 1052 if qn is None: 1048 1053 qn = self.quote_name_unless_alias 1054 buf = [] 1055 a_params = [] 1056 for aggregate in self.query.aggregate_select.values(): 1057 sql, query_params = aggregate.as_sql(qn, self.connection) 1058 buf.append(sql) 1059 a_params.extend(query_params) 1060 aggregate_sql = ', '.join(buf) 1049 1061 1050 1062 sql = ('SELECT %s FROM (%s) subquery' % ( 1051 ', '.join([ 1052 aggregate.as_sql(qn, self.connection) 1053 for aggregate in self.query.aggregate_select.values() 1054 ]), 1063 aggregate_sql, 1055 1064 self.query.subquery) 1056 1065 ) 1057 params = self.query.sub_params1066 params = tuple(a_params) + (self.query.sub_params) 1058 1067 return (sql, params) 1059 1068 1060 1069 class SQLDateCompiler(SQLCompiler): -
django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 1bbf742..f9c23a9 100644
a b class SQLEvaluator(object): 65 65 for child in node.children: 66 66 if hasattr(child, 'evaluate'): 67 67 sql, params = child.evaluate(self, qn, connection) 68 if isinstance(sql, tuple): 69 expression_params.extend(sql[1]) 70 sql = sql[0] 68 71 else: 69 72 sql, params = '%s', (child,) 70 73 -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ed2bc06..2c0e973 100644
a b from django.utils.encoding import force_unicode 14 14 from django.utils.tree import Node 15 15 from django.db import connections, DEFAULT_DB_ALIAS 16 16 from django.db.models import signals 17 from django.db.models.aggregates import Aggregate 17 18 from django.db.models.expressions import ExpressionNode 18 19 from django.db.models.fields import FieldDoesNotExist 19 20 from django.db.models.query_utils import InvalidQuery … … class Query(object): 987 988 Adds a single aggregate expression to the Query 988 989 """ 989 990 opts = model._meta 990 field_list = aggregate.lookup.split(LOOKUP_SEP) 991 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 992 # Aggregate is over an annotation 993 field_name = field_list[0] 994 col = field_name 995 source = self.aggregates[field_name] 996 if not is_summary: 997 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 998 aggregate.name, field_name, field_name)) 999 elif ((len(field_list) > 1) or 1000 (field_list[0] not in [i.name for i in opts.fields]) or 1001 self.group_by is None or 1002 not is_summary): 1003 # If: 1004 # - the field descriptor has more than one part (foo__bar), or 1005 # - the field descriptor is referencing an m2m/m2o field, or 1006 # - this is a reference to a model field (possibly inherited), or 1007 # - this is an annotation over a model field 1008 # then we need to explore the joins that are required. 1009 1010 field, source, opts, join_list, last, _ = self.setup_joins( 1011 field_list, opts, self.get_initial_alias(), False) 1012 1013 # Process the join chain to see if it can be trimmed 1014 col, _, join_list = self.trim_joins(source, join_list, last, False) 1015 1016 # If the aggregate references a model or field that requires a join, 1017 # those joins must be LEFT OUTER - empty join rows must be returned 1018 # in order for zeros to be returned for those aggregates. 
1019 for column_alias in join_list: 1020 self.promote_alias(column_alias, unconditional=True) 1021 1022 col = (join_list[-1], col) 991 if hasattr(aggregate, 'evaluate'): 992 # If aggregate is a query expression, make it an aggregate 993 # This is a 'cheat' to make an empty aggregate - i.e., 994 # one that has no attached function. This is because 995 # no computation needs to be done outside that which the 996 # F expression represents 997 aggregate = Aggregate(aggregate) 998 aggregate.name = 'Aggregate' 999 if hasattr(aggregate.lookup, 'evaluate'): 1000 # If lookup is a query expression, evaluate it 1001 col = SQLEvaluator(aggregate.lookup, self) 1002 # TODO: find out the real source of this field. If any field has 1003 # is_computed, then source can be set to is_computed. 1004 source = None 1023 1005 else: 1024 # The simplest cases. No joins required - 1025 # just reference the provided column alias. 1026 field_name = field_list[0] 1027 source = opts.get_field(field_name) 1028 col = field_name 1006 field_list = aggregate.lookup.split(LOOKUP_SEP) 1007 join_list = [] 1008 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 1009 # Aggregate is over an annotation 1010 field_name = field_list[0] 1011 col = field_name 1012 source = self.aggregates[field_name] 1013 if not is_summary: 1014 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 1015 aggregate.name, field_name, field_name)) 1016 elif ((len(field_list) > 1) or 1017 (field_list[0] not in [i.name for i in opts.fields]) or 1018 self.group_by is None or 1019 not is_summary): 1020 # If: 1021 # - the field descriptor has more than one part (foo__bar), or 1022 # - the field descriptor is referencing an m2m/m2o field, or 1023 # - this is a reference to a model field (possibly inherited), or 1024 # - this is an annotation over a model field 1025 # then we need to explore the joins that are required. 
1026 1027 field, source, opts, join_list, last, _ = self.setup_joins( 1028 field_list, opts, self.get_initial_alias(), False) 1029 1030 # Process the join chain to see if it can be trimmed 1031 col, _, join_list = self.trim_joins(source, join_list, last, False) 1032 1033 # If the aggregate references a model or field that requires a join, 1034 # those joins must be LEFT OUTER - empty join rows must be returned 1035 # in order for zeros to be returned for those aggregates. 1036 for column_alias in join_list: 1037 self.promote_alias(column_alias, unconditional=True) 1038 1039 col = (join_list[-1], col) 1040 else: 1041 # The simplest cases. No joins required - 1042 # just reference the provided column alias. 1043 field_name = field_list[0] 1044 source = opts.get_field(field_name) 1045 col = field_name 1029 1046 1030 1047 # Add the aggregate to the query 1031 1048 aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 1455ba6..8b530bd 100644
a b class WhereNode(tree.Node): 139 139 it. 140 140 """ 141 141 lvalue, lookup_type, value_annot, params_or_value = child 142 additional_params = [] 142 143 if hasattr(lvalue, 'process'): 143 144 try: 144 145 lvalue, params = lvalue.process(lookup_type, params_or_value, connection) … … class WhereNode(tree.Node): 153 154 else: 154 155 # A smart object with an as_sql() method. 155 156 field_sql = lvalue.as_sql(qn, connection) 157 if isinstance(field_sql, tuple): 158 # It returned also params 159 additional_params.extend(field_sql[1]) 160 field_sql = field_sql[0] 156 161 157 162 if value_annot is datetime.datetime: 158 163 cast_sql = connection.ops.datetime_cast_sql() … … class WhereNode(tree.Node): 161 166 162 167 if hasattr(params, 'as_sql'): 163 168 extra, params = params.as_sql(qn, connection) 169 if isinstance(extra, tuple): 170 params = params + tuple(extra[1]) 171 extra = extra[0] 164 172 cast_sql = '' 165 173 else: 166 174 extra = '' … … class WhereNode(tree.Node): 170 178 lookup_type = 'isnull' 171 179 value_annot = True 172 180 181 additional_params.extend(params) 182 params = additional_params 173 183 if lookup_type in connection.operators: 174 184 format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 175 185 return (format % (field_sql, -
django/test/testcases.py
diff --git a/django/test/testcases.py b/django/test/testcases.py index 53ea02a..ba4f496 100644
a b class TransactionTestCase(SimpleTestCase): 646 646 return self.assertEqual(set(map(transform, qs)), set(values)) 647 647 return self.assertEqual(map(transform, qs), values) 648 648 649 def assertQuerysetAlmostEqual(self, qs, values, transform=repr, ordered=True, places=7): 650 # This could have been done with iterating zip(map(transform, qs), values), 651 # checking each with assertAlmostEqual, which rounds the difference of each 652 # pair, but this way you get much nicer error messages, and you can have an 653 # unordered comparison, at the cost of a half a digit of accuracy. 654 round_to = lambda v: round(v,places) 655 tqs = map(round_to, map(transform, qs) ) 656 tvs = map(round_to, values) 657 if not ordered: 658 return self.assertEqual(set(tqs), set(tvs)) 659 return self.assertEqual(tqs, tvs) 660 649 661 def assertNumQueries(self, num, func=None, *args, **kwargs): 650 662 using = kwargs.pop("using", DEFAULT_DB_ALIAS) 651 663 conn = connections[using] -
docs/ref/models/querysets.txt
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 7633555..d175f44 100644
a b control the name of the annotation:: 245 245 >>> q[0].number_of_entries 246 246 42 247 247 248 In addition to aggregation functions, :ref:`F() objects <query-expressions>` 249 can be used to perform a specific mathematical operation:: 250 251 # The 1.0 is to force float conversion 252 >>> q = Entry.objects.annotate(cpb_ratio=F('n_comments')*1.0/F('n_pingbacks')) 253 # The ratio of comments to pingbacks for the first blog entry 254 >>> q[0].cpb_ratio 255 0.0625 256 248 257 For an in-depth discussion of aggregation, see :doc:`the topic guide on 249 258 Aggregation </topics/db/aggregation>`. 250 259 … … control the name of the aggregation value that is returned:: 1482 1491 >>> q = Blog.objects.aggregate(number_of_entries=Count('entry')) 1483 1492 {'number_of_entries': 16} 1484 1493 1494 Inside aggregation functions, :ref:`F() objects <query-expressions>` 1495 can be used to perform a specific mathematical operation:: 1496 1497 # The 1.0 is to force float conversion 1498 >>> q = Entry.objects.aggregate(avg_cpb_ratio=Avg(F('n_comments')*1.0/F('n_pingbacks'))) 1499 {'avg_cpb_ratio': 0.125} 1500 1485 1501 For an in-depth discussion of aggregation, see :doc:`the topic guide on 1486 1502 Aggregation </topics/db/aggregation>`. 1487 1503 … … Django provides the following aggregation functions in the 2116 2132 aggregate functions, see 2117 2133 :doc:`the topic guide on aggregation </topics/db/aggregation>`. 2118 2134 2135 Note that in addition to taking a named field, aggregation 2136 functions can take :ref:`F() objects <query-expressions>`. 2137 2138 .. admonition:: Default aliases 2139 2140 When using ``F()`` objects, note that there is no default alias. 2141 2119 2142 Avg 2120 2143 ~~~ 2121 2144 -
tests/modeltests/aggregation/tests.py
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py index a35dbb3..a5d3a4e 100644
a b import datetime 4 4 from decimal import Decimal 5 5 6 6 from django.db.models import Avg, Sum, Count, Max, Min 7 from django.db.models import F 7 8 from django.test import TestCase, Approximate 8 9 9 10 from .models import Author, Publisher, Book, Store … … class BaseAggregateTestCase(TestCase): 63 64 self.assertEqual(len(vals), 1) 64 65 self.assertAlmostEqual(vals["amazon_mean"], 4.08, places=2) 65 66 67 def test_aggregate_f_expression(self): 68 vals = Book.objects.all().aggregate(price_per_page=Avg(F('price')*1.0/F('pages'))) 69 self.assertEqual(len(vals), 1) 70 self.assertAlmostEqual(vals["price_per_page"], 0.0745110754864109, places=2) 71 72 def test_annotate_f_expression(self): 73 self.assertQuerysetAlmostEqual( 74 Book.objects.all().annotate(price_per_page=F('price')*1.0/F('pages')), [ 75 0.0671140939597315, 76 0.0437310606060606, 77 0.0989666666666667, 78 0.0848285714285714, 79 0.0731448763250883, 80 0.0792811839323467, 81 ], 82 lambda b: b.price_per_page, 83 places=4 84 ) 85 86 self.assertQuerysetAlmostEqual( 87 Publisher.objects.all().annotate(price_per_page=Avg(F('book__price')*1.0/F('book__pages'))), [ 88 0.0830403803131991, 89 0.0437310606060606, 90 0.0789867238768299, 91 0.0792811839323467, 92 ], 93 lambda p: p.price_per_page, 94 places=4 95 ) 96 66 97 def test_annotate_basic(self): 67 98 self.assertQuerysetEqual( 68 99 Book.objects.annotate().order_by('pk'), [