Ticket #14030: 14030-2.patch
File 14030-2.patch, 20.1 KB (added by , 13 years ago) |
---|
-
django/db/models/aggregates.py
From cf61c2dd2f0afec3414de3dff8497fc461491b3c Mon Sep 17 00:00:00 2001 From: Nate Bragg <jonathan.bragg@alum.rpi.edu> Date: Thu, 19 Jan 2012 21:01:32 -0500 Subject: [PATCH] An attempt at rebasing out the changes required for supporting F expressions in aggregation from the more complex patch supporting conditional aggregation for #11305. Additional changes needed to make F expressions usable without being passed in inside an aggregation function. Also added some doc, and some tests. --- django/db/models/aggregates.py | 2 + django/db/models/sql/aggregates.py | 19 ++++++- django/db/models/sql/compiler.py | 39 ++++++++----- django/db/models/sql/expressions.py | 3 + django/db/models/sql/query.py | 96 +++++++++++++++++++-------------- django/db/models/sql/where.py | 10 ++++ django/test/testcases.py | 12 ++++ docs/ref/models/querysets.txt | 23 ++++++++ tests/modeltests/aggregation/tests.py | 31 +++++++++++ 9 files changed, 176 insertions(+), 59 deletions(-) diff --git a/django/db/models/aggregates.py b/django/db/models/aggregates.py index a2349cf..61848fe 100644
a b class Aggregate(object): 20 20 self.extra = extra 21 21 22 22 def _default_alias(self): 23 if hasattr(self.lookup, 'evaluate'): 24 raise ValueError('When aggregating over an expression, you need to give an alias.') 23 25 return '%s__%s' % (self.lookup, self.name.lower()) 24 26 default_alias = property(_default_alias) 25 27 -
django/db/models/sql/aggregates.py
diff --git a/django/db/models/sql/aggregates.py b/django/db/models/sql/aggregates.py index 207bc0c..27175fe 100644
a b 1 1 """ 2 2 Classes to represent the default SQL aggregate functions 3 3 """ 4 from django.db.models.sql.expressions import SQLEvaluator 4 5 5 6 class AggregateField(object): 6 7 """An internal field mockup used to identify aggregates in the … … class Aggregate(object): 66 67 tmp = computed_aggregate_field 67 68 else: 68 69 tmp = tmp.source 69 70 71 # We don't know the real source of this aggregate, and the 72 # aggregate doesn't define ordinal or computed either. So 73 # we default to computed for these cases. 74 if tmp is None: 75 tmp = computed_aggregate_field 70 76 self.field = tmp 71 77 72 78 def relabel_aliases(self, change_map): 73 79 if isinstance(self.col, (list, tuple)): 74 80 self.col = (change_map.get(self.col[0], self.col[0]), self.col[1]) 81 else: 82 self.col.relabel_aliases(change_map) 75 83 76 84 def as_sql(self, qn, connection): 77 85 "Return the aggregate, rendered as SQL." 78 86 87 col_params = [] 79 88 if hasattr(self.col, 'as_sql'): 80 field_name = self.col.as_sql(qn, connection) 89 if isinstance(self.col, SQLEvaluator): 90 field_name, col_params = self.col.as_sql(qn, connection) 91 else: 92 field_name = self.col.as_sql(qn, connection) 93 81 94 elif isinstance(self.col, (list, tuple)): 82 95 field_name = '.'.join([qn(c) for c in self.col]) 83 96 else: … … class Aggregate(object): 89 102 } 90 103 params.update(self.extra) 91 104 92 return self.sql_template % params105 return (self.sql_template % params, col_params) 93 106 94 107 95 108 class Avg(Aggregate): -
django/db/models/sql/compiler.py
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 72948f9..bf3cb25 100644
a b class SQLCompiler(object): 68 68 # as the pre_sql_setup will modify query state in a way that forbids 69 69 # another run of it. 70 70 self.refcounts_before = self.query.alias_refcount.copy() 71 out_cols = self.get_columns(with_col_aliases)71 out_cols, c_params = self.get_columns(with_col_aliases) 72 72 ordering, ordering_group_by = self.get_ordering() 73 73 74 74 distinct_fields = self.get_distinct() … … class SQLCompiler(object): 84 84 params = [] 85 85 for val in self.query.extra_select.itervalues(): 86 86 params.extend(val[1]) 87 # Extra-select comes before aggregation in the select list 88 params.extend(c_params) 87 89 88 90 result = ['SELECT'] 89 91 … … class SQLCompiler(object): 178 180 qn = self.quote_name_unless_alias 179 181 qn2 = self.connection.ops.quote_name 180 182 result = ['(%s) AS %s' % (col[0], qn2(alias)) for alias, col in self.query.extra_select.iteritems()] 183 query_params = [] 181 184 aliases = set(self.query.extra_select.keys()) 182 185 if with_aliases: 183 186 col_aliases = aliases.copy() … … class SQLCompiler(object): 220 223 aliases.update(new_aliases) 221 224 222 225 max_name_length = self.connection.ops.max_name_length() 223 result.extend([ 224 '%s%s' % ( 225 aggregate.as_sql(qn, self.connection), 226 alias is not None 227 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 228 or '' 226 for alias, aggregate in self.query.aggregate_select.items(): 227 sql, params = aggregate.as_sql(qn, self.connection) 228 result.append( 229 '%s%s' % ( 230 sql, 231 alias is not None 232 and ' AS %s' % qn(truncate_name(alias, max_name_length)) 233 or '' 234 ) 229 235 ) 230 for alias, aggregate in self.query.aggregate_select.items() 231 ]) 236 query_params.extend(params) 232 237 233 238 for table, col in self.query.related_select_cols: 234 239 r = '%s.%s' % (qn(table), qn(col)) … … class SQLCompiler(object): 243 248 col_aliases.add(col) 244 249 245 250 self._select_aliases = aliases 246 return result 251 return result, query_params 247 252 248 253 def get_default_columns(self, with_aliases=False, col_aliases=None, 249 254 start_alias=None, opts=None, as_pairs=False, local_only=False): … … class SQLAggregateCompiler(SQLCompiler): 1046 1051 """ 1047 1052 if qn is None: 1048 1053 qn = self.quote_name_unless_alias 1054 buf = [] 1055 a_params = [] 1056 for aggregate in self.query.aggregate_select.values(): 1057 sql, query_params = aggregate.as_sql(qn, self.connection) 1058 buf.append(sql) 1059 a_params.extend(query_params) 1060 aggregate_sql = ', '.join(buf) 1049 1061 1050 1062 sql = ('SELECT %s FROM (%s) subquery' % ( 1051 ', '.join([ 1052 aggregate.as_sql(qn, self.connection) 1053 for aggregate in self.query.aggregate_select.values() 1054 ]), 1063 aggregate_sql, 1055 1064 self.query.subquery) 1056 1065 ) 1057 params = self.query.sub_params1066 params = tuple(a_params) + (self.query.sub_params) 1058 1067 return (sql, params) 1059 1068 1060 1069 class SQLDateCompiler(SQLCompiler): -
django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py index 1bbf742..f9c23a9 100644
a b class SQLEvaluator(object): 65 65 for child in node.children: 66 66 if hasattr(child, 'evaluate'): 67 67 sql, params = child.evaluate(self, qn, connection) 68 if isinstance(sql, tuple): 69 expression_params.extend(sql[1]) 70 sql = sql[0] 68 71 else: 69 72 sql, params = '%s', (child,) 70 73 -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index ed2bc06..5624446 100644
a b from django.utils.encoding import force_unicode 14 14 from django.utils.tree import Node 15 15 from django.db import connections, DEFAULT_DB_ALIAS 16 16 from django.db.models import signals 17 from django.db.models.aggregates import Aggregate 17 18 from django.db.models.expressions import ExpressionNode 18 19 from django.db.models.fields import FieldDoesNotExist 19 20 from django.db.models.query_utils import InvalidQuery … … class Query(object): 322 323 This is required because of the predisposition of certain backends 323 324 to return Decimal and long types when they are not needed. 324 325 """ 326 is_ordinal = getattr(aggregate,"is_ordinal",False) 327 is_computed = getattr(aggregate,"is_computed",True) 325 328 if value is None: 326 if aggregate.is_ordinal:329 if is_ordinal: 327 330 return 0 328 331 # Return None as-is 329 332 return value 330 elif aggregate.is_ordinal:333 elif is_ordinal: 331 334 # Any ordinal aggregate (e.g., count) returns an int 332 335 return int(value) 333 elif aggregate.is_computed:336 elif is_computed: 334 337 # Any computed aggregate (e.g., avg) returns a float 335 338 return float(value) 336 339 else: … … class Query(object): 987 990 Adds a single aggregate expression to the Query 988 991 """ 989 992 opts = model._meta 990 field_list = aggregate.lookup.split(LOOKUP_SEP) 991 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 992 # Aggregate is over an annotation 993 field_name = field_list[0] 994 col = field_name 995 source = self.aggregates[field_name] 996 if not is_summary: 997 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 998 aggregate.name, field_name, field_name)) 999 elif ((len(field_list) > 1) or 1000 (field_list[0] not in [i.name for i in opts.fields]) or 1001 self.group_by is None or 1002 not is_summary): 1003 # If: 1004 # - the field descriptor has more than one part (foo__bar), or 1005 # - the field descriptor is referencing an m2m/m2o field, or 1006 # - this is a reference to a model field (possibly inherited), or 1007 # - this is an annotation over a model field 1008 # then we need to explore the joins that are required. 1009 1010 field, source, opts, join_list, last, _ = self.setup_joins( 1011 field_list, opts, self.get_initial_alias(), False) 1012 1013 # Process the join chain to see if it can be trimmed 1014 col, _, join_list = self.trim_joins(source, join_list, last, False) 1015 1016 # If the aggregate references a model or field that requires a join, 1017 # those joins must be LEFT OUTER - empty join rows must be returned 1018 # in order for zeros to be returned for those aggregates. 1019 for column_alias in join_list: 1020 self.promote_alias(column_alias, unconditional=True) 1021 1022 col = (join_list[-1], col) 993 if hasattr(aggregate, 'evaluate'): 994 self.aggregates[alias] = SQLEvaluator(aggregate, self) 995 return 996 if hasattr(aggregate.lookup, 'evaluate'): 997 # If lookup is a query expression, evaluate it 998 col = SQLEvaluator(aggregate.lookup, self) 999 # TODO: find out the real source of this field. If any field has 1000 # is_computed, then source can be set to is_computed. 1001 source = None 1023 1002 else: 1024 # The simplest cases. No joins required - 1025 # just reference the provided column alias. 1026 field_name = field_list[0] 1027 source = opts.get_field(field_name) 1028 col = field_name 1003 field_list = aggregate.lookup.split(LOOKUP_SEP) 1004 join_list = [] 1005 if len(field_list) == 1 and aggregate.lookup in self.aggregates: 1006 # Aggregate is over an annotation 1007 field_name = field_list[0] 1008 col = field_name 1009 source = self.aggregates[field_name] 1010 if not is_summary: 1011 raise FieldError("Cannot compute %s('%s'): '%s' is an aggregate" % ( 1012 aggregate.name, field_name, field_name)) 1013 elif ((len(field_list) > 1) or 1014 (field_list[0] not in [i.name for i in opts.fields]) or 1015 self.group_by is None or 1016 not is_summary): 1017 # If: 1018 # - the field descriptor has more than one part (foo__bar), or 1019 # - the field descriptor is referencing an m2m/m2o field, or 1020 # - this is a reference to a model field (possibly inherited), or 1021 # - this is an annotation over a model field 1022 # then we need to explore the joins that are required. 1023 1024 field, source, opts, join_list, last, _ = self.setup_joins( 1025 field_list, opts, self.get_initial_alias(), False) 1026 1027 # Process the join chain to see if it can be trimmed 1028 col, _, join_list = self.trim_joins(source, join_list, last, False) 1029 1030 # If the aggregate references a model or field that requires a join, 1031 # those joins must be LEFT OUTER - empty join rows must be returned 1032 # in order for zeros to be returned for those aggregates. 1033 for column_alias in join_list: 1034 self.promote_alias(column_alias, unconditional=True) 1035 1036 col = (join_list[-1], col) 1037 else: 1038 # The simplest cases. No joins required - 1039 # just reference the provided column alias. 1040 field_name = field_list[0] 1041 source = opts.get_field(field_name) 1042 col = field_name 1029 1043 1030 1044 # Add the aggregate to the query 1031 1045 aggregate.add_to_query(self, alias, col=col, source=source, is_summary=is_summary) -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 1455ba6..8b530bd 100644
a b class WhereNode(tree.Node): 139 139 it. 140 140 """ 141 141 lvalue, lookup_type, value_annot, params_or_value = child 142 additional_params = [] 142 143 if hasattr(lvalue, 'process'): 143 144 try: 144 145 lvalue, params = lvalue.process(lookup_type, params_or_value, connection) … … class WhereNode(tree.Node): 153 154 else: 154 155 # A smart object with an as_sql() method. 155 156 field_sql = lvalue.as_sql(qn, connection) 157 if isinstance(field_sql, tuple): 158 # It returned also params 159 additional_params.extend(field_sql[1]) 160 field_sql = field_sql[0] 156 161 157 162 if value_annot is datetime.datetime: 158 163 cast_sql = connection.ops.datetime_cast_sql() … … class WhereNode(tree.Node): 161 166 162 167 if hasattr(params, 'as_sql'): 163 168 extra, params = params.as_sql(qn, connection) 169 if isinstance(extra, tuple): 170 params = params + tuple(extra[1]) 171 extra = extra[0] 164 172 cast_sql = '' 165 173 else: 166 174 extra = '' … … class WhereNode(tree.Node): 170 178 lookup_type = 'isnull' 171 179 value_annot = True 172 180 181 additional_params.extend(params) 182 params = additional_params 173 183 if lookup_type in connection.operators: 174 184 format = "%s %%s %%s" % (connection.ops.lookup_cast(lookup_type),) 175 185 return (format % (field_sql, -
django/test/testcases.py
diff --git a/django/test/testcases.py b/django/test/testcases.py index af455a2..17326ee 100644
a b class TransactionTestCase(SimpleTestCase): 774 774 return self.assertEqual(set(map(transform, qs)), set(values)) 775 775 return self.assertEqual(map(transform, qs), values) 776 776 777 def assertQuerysetAlmostEqual(self, qs, values, transform=repr, ordered=True, places=7): 778 # This could have been done with iterating zip(map(transform, qs), values), 779 # checking each with assertAlmostEqual, which rounds the difference of each 780 # pair, but this way you get much nicer error messages, and you can have an 781 # unordered comparison, at the cost of a half a digit of accuracy. 782 round_to = lambda v: round(v,places) 783 tqs = map(round_to, map(transform, qs) ) 784 tvs = map(round_to, values) 785 if not ordered: 786 return self.assertEqual(set(tqs), set(tvs)) 787 return self.assertEqual(tqs, tvs) 788 777 789 def assertNumQueries(self, num, func=None, *args, **kwargs): 778 790 using = kwargs.pop("using", DEFAULT_DB_ALIAS) 779 791 conn = connections[using] -
docs/ref/models/querysets.txt
diff --git a/docs/ref/models/querysets.txt b/docs/ref/models/querysets.txt index 103cae1..a0aab2b 100644
a b control the name of the annotation:: 246 246 >>> q[0].number_of_entries 247 247 42 248 248 249 In addition to aggregation functions, `:ref:`F() objects <query-expressions>` 250 can be used to perform a specific mathematical operation:: 251 252 # The 1.0 is to force float conversion 253 >>> q = Entry.objects.annotate(cpb_ratio=F('n_comments')*1.0/F('n_pingbacks')) 254 # The ratio of comments to pingbacks for the first blog entry 255 >>> q[0].cpb_ratio 256 0.0625 257 249 258 For an in-depth discussion of aggregation, see :doc:`the topic guide on 250 259 Aggregation </topics/db/aggregation>`. 251 260 … … control the name of the aggregation value that is returned:: 1483 1492 >>> q = Blog.objects.aggregate(number_of_entries=Count('entry')) 1484 1493 {'number_of_entries': 16} 1485 1494 1495 Inside aggregation functions, `:ref:`F() objects <query-expressions>` 1496 can be used to perform a specific mathematical operation:: 1497 1498 # The 1.0 is to force float conversion 1499 >>> q = Entry.objects.aggregate(avg_cpb_ratio=Avg(F('n_comments')*1.0/F('n_pingbacks'))) 1500 {'avg_cpb_ratio': 0.125} 1501 1486 1502 For an in-depth discussion of aggregation, see :doc:`the topic guide on 1487 1503 Aggregation </topics/db/aggregation>`. 1488 1504 … … Django provides the following aggregation functions in the 2117 2133 aggregate functions, see 2118 2134 :doc:`the topic guide on aggregation </topics/db/aggregation>`. 2119 2135 2136 Note that in addition to taking a named field, aggregation 2137 functions can take `:ref:`F() objects <query-expressions>`. 2138 2139 .. admonition:: Default aliases 2140 2141 When using ``F()`` objects, note that there is no default alias. 2142 2120 2143 Avg 2121 2144 ~~~ 2122 2145 -
tests/modeltests/aggregation/tests.py
diff --git a/tests/modeltests/aggregation/tests.py b/tests/modeltests/aggregation/tests.py index a35dbb3..a5d3a4e 100644
a b import datetime 4 4 from decimal import Decimal 5 5 6 6 from django.db.models import Avg, Sum, Count, Max, Min 7 from django.db.models import F 7 8 from django.test import TestCase, Approximate 8 9 9 10 from .models import Author, Publisher, Book, Store … … class BaseAggregateTestCase(TestCase): 63 64 self.assertEqual(len(vals), 1) 64 65 self.assertAlmostEqual(vals["amazon_mean"], 4.08, places=2) 65 66 67 def test_aggregate_f_expression(self): 68 vals = Book.objects.all().aggregate(price_per_page=Avg(F('price')*1.0/F('pages'))) 69 self.assertEqual(len(vals), 1) 70 self.assertAlmostEqual(vals["price_per_page"], 0.0745110754864109, places=2) 71 72 def test_annotate_f_expression(self): 73 self.assertQuerysetAlmostEqual( 74 Book.objects.all().annotate(price_per_page=F('price')*1.0/F('pages')), [ 75 0.0671140939597315, 76 0.0437310606060606, 77 0.0989666666666667, 78 0.0848285714285714, 79 0.0731448763250883, 80 0.0792811839323467, 81 ], 82 lambda b: b.price_per_page, 83 places=4 84 ) 85 86 self.assertQuerysetAlmostEqual( 87 Publisher.objects.all().annotate(price_per_page=Avg(F('book__price')*1.0/F('book__pages'))), [ 88 0.0830403803131991, 89 0.0437310606060606, 90 0.0789867238768299, 91 0.0792811839323467, 92 ], 93 lambda p: p.price_per_page, 94 places=4 95 ) 96 66 97 def test_annotate_basic(self): 67 98 self.assertQuerysetEqual( 68 99 Book.objects.annotate().order_by('pk'), [