Ticket #7210: expr.diff
File expr.diff, 23.9 KB (added by , 16 years ago) |
---|
-
django/db/models/__init__.py
diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 86763d9..f5ab8eb 100644
a b from django.core import validators 4 4 from django.db import connection 5 5 from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models 6 6 from django.db.models.query import Q 7 from django.db.models.sql.expressions import F 7 8 from django.db.models.manager import Manager 8 9 from django.db.models.base import Model, AdminOptions 9 10 from django.db.models.fields import * -
new file django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py new file mode 100644 index 0000000..0a933f5
- + 1 from copy import deepcopy 2 from datetime import datetime 3 4 from django.utils import tree 5 from django.core.exceptions import FieldError 6 from django.db import connection 7 from django.db.models.fields import Field, FieldDoesNotExist 8 from django.db.models.query_utils import QueryWrapper 9 10 class Expression(object): 11 """ 12 Base class for all sql expressions, expected by QuerySet.update. 13 """ 14 # Arithmetic connection types 15 ADD = '+' 16 SUB = '-' 17 MUL = '*' 18 DIV = '/' 19 MOD = '%%' 20 21 # Bitwise connection types 22 AND = '&' 23 OR = '|' 24 25 def _combine(self, other, conn, reversed, node=None): 26 if reversed: 27 obj = ExpressionNode([Literal(other)], conn) 28 obj.add(node or self, conn) 29 else: 30 obj = node or ExpressionNode([self], conn) 31 if isinstance(other, Expression): 32 obj.add(other, conn) 33 else: 34 obj.add(Literal(other), conn) 35 return obj 36 37 def __add__(self, other): 38 return self._combine(other, self.ADD, False) 39 40 def __sub__(self, other): 41 return self._combine(other, self.SUB, False) 42 43 def __mul__(self, other): 44 return self._combine(other, self.MUL, False) 45 46 def __div__(self, other): 47 return self._combine(other, self.DIV, False) 48 49 def __mod__(self, other): 50 return self._combine(other, self.MOD, False) 51 52 def __and__(self, other): 53 return self._combine(other, self.AND, False) 54 55 def __or__(self, other): 56 return self._combine(other, self.OR, False) 57 58 def __radd__(self, other): 59 return self._combine(other, self.ADD, True) 60 61 def __rsub__(self, other): 62 return self._combine(other, self.SUB, True) 63 64 def __rmul__(self, other): 65 return self._combine(other, self.MUL, True) 66 67 def __rdiv__(self, other): 68 return self._combine(other, self.DIV, True) 69 70 def __rmod__(self, other): 71 return self._combine(other, self.MOD, True) 72 73 def __rand__(self, other): 74 return self._combine(other, self.AND, True) 75 76 def __ror__(self, other): 77 return self._combine(other, self.OR, True) 78 79 def as_sql(self, opts, field, prep_func=None, qn=None): 80 raise NotImplementedError 81 82 class ExpressionNode(Expression, tree.Node): 83 default = None 84 85 def __init__(self, children=None, connector=None, negated=False): 86 if children is not None and len(children) > 1 and connector is None: 87 raise TypeError('You have to specify a connector.') 88 super(ExpressionNode, self).__init__(children, connector, negated) 89 90 def _combine(self, *args, **kwargs): 91 return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 92 93 def as_sql(self, opts, field, prep_func=None, qn=None, node=None): 94 if not qn: 95 qn = connection.ops.quote_name 96 if node is None: 97 node = self 98 99 result = [] 100 result_params = [] 101 for child in node.children: 102 if hasattr(child, 'as_sql'): 103 sql, params = child.as_sql(opts, field, prep_func, qn) 104 format = '%s' 105 else: 106 sql, params = self.as_sql(opts, field, prep_func, qn, child) 107 if len(child.children) > 1: 108 format = '(%s)' 109 else: 110 format = '%s' 111 if sql: 112 result.append(format % sql) 113 result_params.extend(params) 114 conn = ' %s ' % node.connector 115 return conn.join(result), result_params 116 117 class Literal(Expression): 118 """ 119 An expression representing the given value. 120 """ 121 def __init__(self, value): 122 self.value = value 123 124 def as_sql(self, opts, field, prep_func=None, qn=None): 125 if self.value is None: 126 return 'NULL', () 127 128 if isinstance(self.value, datetime): 129 sql = connection.ops.datetime_cast_sql() 130 else: 131 sql = '%s' 132 params = prep_func and prep_func(self.value) or (self.value,) 133 134 if isinstance(params, QueryWrapper): 135 return params.data 136 return sql, params 137 138 class F(Expression): 139 """ 140 An expression representing the value of the given field. 141 """ 142 def __init__(self, name): 143 self.name = name 144 145 def as_sql(self, opts, field, prep_func=None, qn=None): 146 if not qn: 147 qn = connection.ops.quote_name 148 149 try: 150 src_field = opts.get_field(self.name) 151 except FieldDoesNotExist: 152 names = opts.get_all_field_names() 153 raise FieldError('Cannot resolve keyword %r into field. ' 154 'Choices are: %s' % (self.name, ', '.join(names))) 155 156 field_sql = connection.ops.field_cast_sql(src_field.db_type()) 157 lhs = '%s.%s' % (qn(opts.db_table), qn(src_field.column)) 158 return field_sql % lhs, () -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index 3044882..567879b 100644
a b class Query(object): 250 250 # get_from_clause() for details. 251 251 from_, f_params = self.get_from_clause() 252 252 253 where, w_params = self.where.as_sql( qn=self.quote_name_unless_alias)253 where, w_params = self.where.as_sql(self.get_meta(), qn=self.quote_name_unless_alias) 254 254 params = list(self.extra_select_params) 255 255 256 256 result = ['SELECT'] -
django/db/models/sql/subqueries.py
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 28436ab..b64639c 100644
a b from django.db.models.sql.constants import * 7 7 from django.db.models.sql.datastructures import Date 8 8 from django.db.models.sql.query import Query 9 9 from django.db.models.sql.where import AND 10 from django.db.models.sql.expressions import Expression, Literal 10 11 11 12 __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 12 13 'CountQuery'] … … class DeleteQuery(Query): 24 25 assert len(self.tables) == 1, \ 25 26 "Can only delete from one table at a time." 26 27 result = ['DELETE FROM %s' % self.quote_name_unless_alias(self.tables[0])] 27 where, params = self.where.as_sql( )28 where, params = self.where.as_sql(self.get_meta()) 28 29 result.append('WHERE %s' % where) 29 30 return ' '.join(result), tuple(params) 30 31 … … class UpdateQuery(Query): 126 127 result = ['UPDATE %s' % qn(table)] 127 128 result.append('SET') 128 129 values, update_params = [], [] 129 for name, val, placeholder in self.values: 130 if val is not None: 131 values.append('%s = %s' % (qn(name), placeholder)) 132 update_params.append(val) 133 else: 134 values.append('%s = NULL' % qn(name)) 130 for name, sql, params in self.values: 131 values.append('%s = %s' % (qn(name), sql)) 132 update_params.extend(params) 135 133 result.append(', '.join(values)) 136 where, params = self.where.as_sql( )134 where, params = self.where.as_sql(self.get_meta()) 137 135 if where: 138 136 result.append('WHERE %s' % where) 139 137 return ' '.join(result), tuple(update_params + params) … … class UpdateQuery(Query): 207 205 self.where.add((None, f.column, f, 'in', 208 206 pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 209 207 AND) 210 self.values = [(related_field.column, None, '%s')]208 self.values = [(related_field.column, 'NULL', ())] 211 209 self.execute_sql(None) 212 210 213 211 def add_update_values(self, values): … … class UpdateQuery(Query): 232 230 """ 233 231 from django.db.models.base import Model 234 232 for field, model, val in values_seq: 235 # FIXME: Some sort of db_prep_* is probably more appropriate here. 236 if field.rel and isinstance(val, Model): 237 val = val.pk 238 239 # Getting the placeholder for the field. 240 if hasattr(field, 'get_placeholder'): 241 placeholder = field.get_placeholder(val) 233 if isinstance(val, Expression): 234 expr = val 242 235 else: 243 placeholder = '%s' 236 expr = Literal(val) 237 238 sql, params = expr.as_sql( 239 self.get_meta(), 240 field, 241 lambda x: field.get_db_prep_lookup('exact', field.get_db_prep_save(x)), 242 self.connection.ops.quote_name) 244 243 245 244 if model: 246 self.add_related_update(model, field.column, val, placeholder)245 self.add_related_update(model, field.column, sql, params) 247 246 else: 248 self.values.append((field.column, val, placeholder))247 self.values.append((field.column, sql, params)) 249 248 250 def add_related_update(self, model, column, value, placeholder):249 def add_related_update(self, model, column, sql, params): 251 250 """ 252 251 Adds (name, value) to an update query for an ancestor model. 253 252 254 253 Updates are coalesced so that we only run one update query per ancestor. 255 254 """ 256 255 try: 257 self.related_updates[model].append((column, value, placeholder))256 self.related_updates[model].append((column, sql, params)) 258 257 except KeyError: 259 self.related_updates[model] = [(column, value, placeholder)]258 self.related_updates[model] = [(column, sql, params)] 260 259 261 260 def get_related_updates(self): 262 261 """ … … class InsertQuery(Query): 312 311 parameters. This provides a way to insert NULL and DEFAULT keywords 313 312 into the query, for example. 314 313 """ 315 placeholders, values = [],[]314 values = [] 316 315 for field, val in insert_values: 317 if hasattr(field, 'get_placeholder'):318 # Some fields (e.g. geo fields) need special munging before319 # they can be inserted.320 placeholders.append(field.get_placeholder(val))321 else:322 placeholders.append('%s')323 324 316 self.columns.append(field.column) 325 317 values.append(val) 326 318 if raw_values: 327 319 self.values.extend(values) 328 320 else: 329 321 self.params += tuple(values) 330 self.values.extend( placeholders)322 self.values.extend(['%s'] * len(values)) 331 323 332 324 class DateQuery(Query): 333 325 """ -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 14e5448..e98596b 100644
a b Code to manage the creation and SQL rendering of 'where' constraints. 4 4 import datetime 5 5 6 6 from django.utils import tree 7 from django.utils.functional import curry 7 8 from django.db import connection 8 9 from django.db.models.fields import Field 9 10 from django.db.models.query_utils import QueryWrapper 11 from django.db.models.sql.expressions import Expression, Literal 10 12 from datastructures import EmptyResultSet, FullResultSet 11 13 12 14 # Connection types … … class WhereNode(tree.Node): 26 28 """ 27 29 default = AND 28 30 29 def as_sql(self, node=None, qn=None):31 def as_sql(self, opts, node=None, qn=None): 30 32 """ 31 33 Returns the SQL version of the where clause and the value to be 32 34 substituted in. Returns None, None if this node is empty. … … class WhereNode(tree.Node): 47 49 for child in node.children: 48 50 try: 49 51 if hasattr(child, 'as_sql'): 50 sql, params = child.as_sql( qn=qn)52 sql, params = child.as_sql(opts, qn=qn) 51 53 format = '(%s)' 52 54 elif isinstance(child, tree.Node): 53 sql, params = self.as_sql( child, qn)55 sql, params = self.as_sql(opts, child, qn) 54 56 if child.negated: 55 57 format = 'NOT (%s)' 56 58 elif len(child.children) == 1: … … class WhereNode(tree.Node): 58 60 else: 59 61 format = '(%s)' 60 62 else: 61 sql, params = self.make_atom( child, qn)63 sql, params = self.make_atom(opts, child, qn) 62 64 format = '%s' 63 65 except EmptyResultSet: 64 66 if node.connector == AND and not node.negated: … … class WhereNode(tree.Node): 86 88 conn = ' %s ' % node.connector 87 89 return conn.join(result), result_params 88 90 89 def make_atom(self, child, qn):91 def make_atom(self, opts, child, qn): 90 92 """ 91 93 Turn a tuple (table_alias, field_name, field_class, lookup_type, value) 92 94 into valid SQL. … … class WhereNode(tree.Node): 99 101 lhs = '%s.%s' % (qn(table_alias), qn(name)) 100 102 else: 101 103 lhs = qn(name) 104 if not field: 105 field = Field() 102 106 db_type = field and field.db_type() or None 103 107 field_sql = connection.ops.field_cast_sql(db_type) % lhs 108 prep_func = curry(field.get_db_prep_lookup, lookup_type) 104 109 105 if isinstance(value, datetime.datetime): 106 cast_sql = connection.ops.datetime_cast_sql() 107 else: 108 cast_sql = '%s' 110 if lookup_type in connection.operators: 111 if isinstance(value, Expression): 112 sql, params = value.as_sql(opts, field, prep_func, qn) 113 else: 114 sql, params = Literal(value).as_sql(opts, field, prep_func, qn) 109 115 110 if field: 111 params = field.get_db_prep_lookup(lookup_type, value) 112 else: 113 params = Field().get_db_prep_lookup(lookup_type, value) 116 format = '%s %s' % ( 117 connection.ops.lookup_cast(lookup_type), 118 connection.operators[lookup_type]) 119 return format % (field_sql, sql), params 120 121 if isinstance(value, Expression): 122 raise TypeError('Invalid lookup_type for use with %s object: %r' % ( 123 value.__class__.__name__, lookup_type)) 124 125 params = prep_func(value) 114 126 if isinstance(params, QueryWrapper): 115 127 extra, params = params.data 116 128 else: 117 129 extra = '' 118 130 119 if lookup_type in connection.operators:120 format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type),121 extra)122 return (format % (field_sql,123 connection.operators[lookup_type] % cast_sql), params)124 125 131 if lookup_type == 'in': 126 132 if not value: 127 133 raise EmptyResultSet 128 134 if extra: 129 135 return ('%s IN %s' % (field_sql, extra), params) 130 return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), 131 params) 136 return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), params) 132 137 elif lookup_type in ('range', 'year'): 133 138 return ('%s BETWEEN %%s and %%s' % field_sql, params) 134 139 elif lookup_type in ('month', 'day'): … … class EverythingNode(object): 164 169 """ 165 170 A node that matches everything. 166 171 """ 167 def as_sql(self, qn=None):172 def as_sql(self, opts, qn=None): 168 173 raise FullResultSet 169 174 170 175 def relabel_aliases(self, change_map, node=None): -
new file tests/modeltests/expressions/models.py
diff --git a/tests/modeltests/expressions/models.py b/tests/modeltests/expressions/models.py new file mode 100644 index 0000000..bb92ee8
- + 1 """ 2 Tests for the update() queryset method that allows in-place, multi-object 3 updates. 4 """ 5 6 from django.db import models 7 8 # 9 # Model for testing arithmetic expressions. 10 # 11 12 class Number(models.Model): 13 integer = models.IntegerField() 14 float = models.FloatField(null=True) 15 16 def __unicode__(self): 17 return u'%i, %.3f' % (self.integer, self.float) 18 19 # 20 # A more ordinary use case. 21 # 22 23 class Employee(models.Model): 24 firstname = models.CharField(max_length=50) 25 lastname = models.CharField(max_length=50) 26 27 def __unicode__(self): 28 return u'%s %s' % (self.firstname, self.lastname) 29 30 class Company(models.Model): 31 name = models.CharField(max_length=100) 32 num_employees = models.PositiveIntegerField() 33 num_chairs = models.PositiveIntegerField() 34 ceo = models.ForeignKey( 35 Employee, 36 related_name='company_ceo_set') 37 point_of_contact = models.ForeignKey( 38 Employee, 39 related_name='company_point_of_contact_set', 40 null=True) 41 42 def __unicode__(self): 43 return self.name 44 45 46 __test__ = {'API_TESTS': """ 47 >>> from django.db.models import F 48 49 >>> Number(integer=-1).save() 50 >>> Number(integer=42).save() 51 >>> Number(integer=1337).save() 52 53 We can fill a value in all objects with an other value of the same object. 54 55 >>> Number.objects.update(float=F('integer')) 56 >>> Number.objects.all() 57 [<Number: -1, -1.000>, <Number: 42, 42.000>, <Number: 1337, 1337.000>] 58 59 We can increment a value of all objects in a query set. 60 61 >>> Number.objects.filter(integer__gt=0).update(integer=F('integer') + 1) 62 >>> Number.objects.all() 63 [<Number: -1, -1.000>, <Number: 43, 42.000>, <Number: 1338, 1337.000>] 64 65 We can filter for objects, where a value is not equals the value of an other field. 66 67 >>> Number.objects.exclude(float=F('integer')) 68 [<Number: 43, 42.000>, <Number: 1338, 1337.000>] 69 70 Complex expressions of different connection types are possible. 71 72 >>> n = Number.objects.create(integer=10, float=123.45) 73 >>> Number.objects.filter(pk=n.pk).update(float=F('integer') + F('float') * 2) 74 >>> Number.objects.get(pk=n.pk) 75 <Number: 10, 256.900> 76 77 All supported operators, work as expected in native and reverse order. 78 79 >>> from operator import add, sub, mul, div, mod, and_, or_ 80 >>> for op in (add, sub, mul, div, mod, and_, or_): 81 ... n = Number.objects.create(integer=42, float=15.) 82 ... Number.objects.filter(pk=n.pk).update( 83 ... integer=op(F('integer'), 15), float=op(42., F('float'))) 84 ... Number.objects.get(pk=n.pk) 85 <Number: 57, 57.000> 86 <Number: 27, 27.000> 87 <Number: 630, 630.000> 88 <Number: 3, 2.800> 89 <Number: 12, 12.000> 90 <Number: 10, 10.000> 91 <Number: 47, 47.000> 92 93 94 >>> Company(name='Example Inc.', num_employees=2300, num_chairs=5, 95 ... ceo=Employee.objects.create(firstname='Joe', lastname='Smith')).save() 96 >>> Company(name='Foobar Ltd.', num_employees=3, num_chairs=3, 97 ... ceo=Employee.objects.create(firstname='Frank', lastname='Meyer')).save() 98 >>> Company(name='Test GmbH', num_employees=32, num_chairs=1, 99 ... ceo=Employee.objects.create(firstname='Max', lastname='Mustermann')).save() 100 101 We can filter for companies where the number of employees is greater than the 102 number of chairs. 103 104 >>> Company.objects.filter(num_employees__gt=F('num_chairs')) 105 [<Company: Example Inc.>, <Company: Test GmbH>] 106 107 The relation of a foreign key can become copied over to an other foreign key. 108 109 >>> Company.objects.update(point_of_contact=F('ceo')) 110 >>> [c.point_of_contact for c in Company.objects.all()] 111 [<Employee: Joe Smith>, <Employee: Frank Meyer>, <Employee: Max Mustermann>] 112 113 """} -
tests/modeltests/update/models.py
diff --git a/tests/modeltests/update/models.py b/tests/modeltests/update/models.py index 8a35b61..1df1811 100644
a b updates. 4 4 """ 5 5 6 6 from django.db import models 7 from django.conf import settings 7 8 8 class DataPoint(models.Model):9 class Product(models.Model): 9 10 name = models.CharField(max_length=20) 10 value= models.CharField(max_length=20)11 another_value = models.CharField(max_length=20, blank=True)11 description = models.CharField(max_length=20) 12 expires = models.DateTimeField(null=True) 12 13 13 14 def __unicode__(self): 14 15 return unicode(self.name) 15 16 16 class RelatedP oint(models.Model):17 class RelatedProduct(models.Model): 17 18 name = models.CharField(max_length=20) 18 data = models.ForeignKey( DataPoint)19 data = models.ForeignKey(Product) 19 20 20 21 def __unicode__(self): 21 22 return unicode(self.name) 22 23 23 24 24 25 __test__ = {'API_TESTS': """ 25 >>> DataPoint(name="d0", value="apple").save() 26 >>> DataPoint(name="d2", value="banana").save() 27 >>> d3 = DataPoint(name="d3", value="banana") 28 >>> d3.save() 29 >>> RelatedPoint(name="r1", data=d3).save() 26 >>> from datetime import datetime 27 28 >>> Product(name="p0", description="apple").save() 29 >>> Product(name="p2", description="banana").save() 30 >>> p3 = Product(name="p3", description="banana") 31 >>> p3.save() 32 >>> RelatedProduct(name="r1", data=p3).save() 30 33 31 34 Objects are updated by first filtering the candidates into a queryset and then 32 35 calling the update() method. It executes immediately and returns nothing. 33 36 34 >>> DataPoint.objects.filter(value="apple").update(name="d1")35 >>> DataPoint.objects.filter(value="apple")36 [< DataPoint: d1>]37 >>> Product.objects.filter(description="apple").update(name="p1") 38 >>> Product.objects.filter(description="apple") 39 [<Product: p1>] 37 40 38 41 We can update multiple objects at once. 39 42 40 >>> DataPoint.objects.filter(value="banana").update(value="pineapple")41 >>> DataPoint.objects.get(name="d2").value43 >>> Product.objects.filter(description="banana").update(description="pineapple") 44 >>> Product.objects.get(name="p2").description 42 45 u'pineapple' 43 46 44 47 Foreign key fields can also be updated, although you can only update the object 45 48 referred to, not anything inside the related object. 46 49 47 >>> d = DataPoint.objects.get(name="d1") 48 >>> RelatedPoint.objects.filter(name="r1").update(data=d) 49 >>> RelatedPoint.objects.filter(data__name="d1") 50 [<RelatedPoint: r1>] 50 >>> p = Product.objects.get(name="p1") 51 >>> RelatedProduct.objects.filter(name="r1").update(data=p) 52 >>> RelatedProduct.objects.filter(data__name="p1") 53 [<RelatedProduct: r1>] 54 55 Multiple fields can be updated at once. If DATABASE_ENGINE is mysql microseconds 56 must be truncated. 57 58 >>> Product.objects.filter(description="pineapple").update( 59 ... description="fruit", 60 ... expires=datetime(2010, 1, 1, 12, 0, 0, 123456)) 61 >>> p = Product.objects.get(name="p2") 62 >>> p.description, p.expires 63 """} 51 64 52 Multiple fields can be updated at once 65 if settings.DATABASE_ENGINE == 'mysql': 66 __test__['API_TESTS'] += "(u'fruit', datetime.datetime(2010, 1, 1, 12, 0))" 67 else: 68 __test__['API_TESTS'] += "(u'fruit', datetime.datetime(2010, 1, 1, 12, 0, 0, 123456))" 53 69 54 >>> DataPoint.objects.filter(value="pineapple").update(value="fruit", another_value="peaches") 55 >>> d = DataPoint.objects.get(name="d2") 56 >>> d.value, d.another_value 57 (u'fruit', u'peaches') 70 __test__['API_TESTS'] += """ 58 71 59 72 In the rare case you want to update every instance of a model, update() is also 60 a manager method .73 a manager method and update with None works as well. 61 74 62 >>> DataPoint.objects.update(value='thing')63 >>> DataPoint.objects.values('value').distinct()64 [{' value': u'thing'}]75 >>> Product.objects.update(expires=None) 76 >>> Product.objects.values('expires').distinct() 77 [{'expires': None}] 65 78 66 79 We do not support update on already sliced query sets. 67 80 68 >>> DataPoint.objects.all()[:2].update(another_value='another thing')81 >>> Product.objects.all()[:2].update(another_value='another thing') 69 82 Traceback (most recent call last): 70 83 ... 71 84 AssertionError: Cannot update a query once a slice has been taken. 72 85 73 86 """ 74 }