Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.3.patch
File 0001-Added-expression-support-for-QuerySet.update.3.patch, 15.1 KB (added by , 17 years ago) |
---|
-
django/db/models/__init__.py
From a439f12ed3c4bcd3c2bc78b67bf36f47043c51ee Mon Sep 17 00:00:00 2001 From: Sebastian Noack <sebastian.noack@gmail.com> Date: Thu, 8 May 2008 14:30:19 +0200 Subject: [PATCH] Added expression support for QuerySet.update. --- django/db/models/__init__.py | 1 + django/db/models/sql/expressions.py | 186 +++++++++++++++++++++++++++++++++++ django/db/models/sql/query.py | 2 +- django/db/models/sql/subqueries.py | 36 +++---- django/db/models/sql/where.py | 38 ++++--- 5 files changed, 225 insertions(+), 38 deletions(-) create mode 100644 django/db/models/sql/expressions.py diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py index 86763d9..f5ab8eb 100644
a b from django.core import validators 4 4 from django.db import connection 5 5 from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models 6 6 from django.db.models.query import Q 7 from django.db.models.sql.expressions import F 7 8 from django.db.models.manager import Manager 8 9 from django.db.models.base import Model, AdminOptions 9 10 from django.db.models.fields import * -
new file django/db/models/sql/expressions.py
diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py new file mode 100644 index 0000000..2be6472
- + 1 from copy import deepcopy 2 from datetime import datetime 3 4 from django.utils import tree 5 from django.core.exceptions import FieldError 6 from django.db import connection 7 from django.db.models.fields import Field, FieldDoesNotExist 8 from django.db.models.query_utils import QueryWrapper 9 10 class Expression(object): 11 """ 12 Base class for all sql expressions, expected by QuerySet.update. 13 """ 14 # Arithmetic connection types 15 ADD = '+' 16 SUB = '-' 17 MUL = '*' 18 DIV = '/' 19 MOD = '%' 20 21 # Bitwise connection types 22 AND = '&' 23 OR = '|' 24 LSHIFT = '<<' 25 RSHIFT = '>>' 26 27 def _combine(self, other, conn, reversed, node=None): 28 if reversed: 29 obj = ExpressionNode([Literal(other)], conn) 30 obj.add(node or self, conn) 31 else: 32 obj = node or ExpressionNode([self], conn) 33 if isinstance(other, Expression): 34 obj.add(other, conn) 35 else: 36 obj.add(Literal(other), conn) 37 return obj 38 39 def __add__(self, other): 40 return self._combine(other, self.ADD, False) 41 42 def __sub__(self, other): 43 return self._combine(other, self.SUB, False) 44 45 def __mul__(self, other): 46 return self._combine(other, self.MUL, False) 47 48 def __div__(self, other): 49 return self._combine(other, self.DIV, False) 50 51 def __mod__(self, other): 52 return self._combine(other, self.MOD, False) 53 54 def __and__(self, other): 55 return self._combine(other, self.AND, False) 56 57 def __or__(self, other): 58 return self._combine(other, self.OR, False) 59 60 def __lshift__(self, other): 61 return self._combine(other, self.LSHIFT, False) 62 63 def __rshift__(self, other): 64 return self._combine(other, self.RSHIFT, False) 65 66 def __radd__(self, other): 67 return self._combine(other, self.ADD, True) 68 69 def __rsub__(self, other): 70 return self._combine(other, self.SUB, True) 71 72 def __rmul__(self, other): 73 return self._combine(other, self.MUL, True) 74 75 def __rdiv__(self, other): 76 return self._combine(other, self.DIV, True) 77 78 def __rmod__(self, other): 79 return self._combine(other, self.MOD, True) 80 81 def __rand__(self, other): 82 return self._combine(other, self.AND, True) 83 84 def __ror__(self, other): 85 return self._combine(other, self.OR, True) 86 87 def __lshift__(self, other): 88 return self._combine(other, self.LSHIFT, True) 89 90 def __rshift__(self, other): 91 return self._combine(other, self.RSHIFT, True) 92 93 def __invert__(self, node=None): 94 obj = node or ExpressionNode([self]) 95 obj.negate() 96 return obj 97 98 def as_sql(self, opts, field=None, lookup_type='exact', qn=None): 99 raise NotImplementedError 100 101 class ExpressionNode(Expression, tree.Node): 102 default = None 103 104 def __init__(self, children=None, connector=None, negated=False): 105 if children is not None and len(children) > 1 and connector is None: 106 raise TypeError('You have to specify a connector.') 107 super(ExpressionNode, self).__init__(children, connector, negated) 108 109 def _combine(self, *args, **kwargs): 110 return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 111 112 def __invert__(self): 113 return super(ExpressionNode, self).__invert__(node=deepcopy(self)) 114 115 def as_sql(self, opts, field=None, lookup_type='exact', qn=None, node=None): 116 if not qn: 117 qn = connection.ops.quote_name 118 if node is None: 119 node = self 120 121 result = [] 122 result_params = [] 123 for child in node.children: 124 if hasattr(child, 'as_sql'): 125 sql, params = child.as_sql(opts, field, qn=qn) 126 format = '%s' 127 else: 128 sql, params = self.as_sql(opts, field, qn=qn, node=child) 129 if child.negated: 130 format = '~%s' 131 else: 132 format = '%s' 133 if len(child.children) > 1: 134 format %= '(%s)' 135 if sql: 136 result.append(format % sql) 137 result_params.extend(params) 138 conn = ' %s ' % node.connector 139 return conn.join(result), result_params 140 141 class Literal(Expression): 142 """ 143 An expression representing the given value. 144 """ 145 def __init__(self, value): 146 self.value = value 147 148 def as_sql(self, opts, field=None, lookup_type='exact', qn=None): 149 if self.value is None: 150 return 'NULL', () 151 152 if field: 153 params = field.get_db_prep_lookup(lookup_type, self.value) 154 else: 155 params = Field().get_db_prep_lookup(lookup_type, self.value) 156 157 if isinstance(self.value, datetime): 158 sql = connection.ops.datetime_cast_sql() 159 else: 160 sql = '%s' 161 162 if isinstance(params, QueryWrapper): 163 return params.data 164 return sql, params 165 166 class F(Expression): 167 """ 168 An expression representing the value of the given field. 169 """ 170 def __init__(self, name): 171 self.name = name 172 173 def as_sql(self, opts, field=None, lookup_type='exact', qn=None): 174 if not qn: 175 qn = connection.ops.quote_name 176 177 try: 178 src_field = opts.get_field(self.name) 179 except FieldDoesNotExist: 180 names = opts.get_all_field_names() 181 raise FieldError('Cannot resolve keyword %r into field. ' 182 'Choices are: %s' % (self.name, ', '.join(names))) 183 184 field_sql = connection.ops.field_cast_sql(src_field.db_type()) 185 lhs = '%s.%s' % (qn(opts.db_table), qn(src_field.attname)) 186 return field_sql % lhs, () -
django/db/models/sql/query.py
diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py index a6957ba..bd1d030 100644
a b class Query(object): 250 250 # get_from_clause() for details. 251 251 from_, f_params = self.get_from_clause() 252 252 253 where, w_params = self.where.as_sql( qn=self.quote_name_unless_alias)253 where, w_params = self.where.as_sql(self.get_meta(), qn=self.quote_name_unless_alias) 254 254 params = list(self.extra_select_params) 255 255 256 256 result = ['SELECT'] -
django/db/models/sql/subqueries.py
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 7385cd0..251bab2 100644
a b from django.db.models.sql.constants import * 8 8 from django.db.models.sql.datastructures import RawValue, Date 9 9 from django.db.models.sql.query import Query 10 10 from django.db.models.sql.where import AND 11 from django.db.models.sql.expressions import Expression, Literal 11 12 12 13 __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 13 14 'CountQuery'] … … class UpdateQuery(Query): 126 127 result = ['UPDATE %s' % qn(table)] 127 128 result.append('SET') 128 129 values, update_params = [], [] 129 for name, val, placeholder in self.values: 130 if val is not None: 131 values.append('%s = %s' % (qn(name), placeholder)) 132 update_params.append(val) 133 else: 134 values.append('%s = NULL' % qn(name)) 130 for name, sql, params in self.values: 131 values.append('%s = %s' % (qn(name), sql)) 132 update_params.extend(params) 135 133 result.append(', '.join(values)) 136 134 where, params = self.where.as_sql() 137 135 if where: … … class UpdateQuery(Query): 207 205 self.where.add((None, f.column, f, 'in', 208 206 pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 209 207 AND) 210 self.values = [(related_field.column, None, '%s')]208 self.values = [(related_field.column, 'NULL', ())] 211 209 self.execute_sql(None) 212 210 213 211 def add_update_values(self, values): … … class UpdateQuery(Query): 232 230 """ 233 231 from django.db.models.base import Model 234 232 for field, model, val in values_seq: 235 # FIXME: Some sort of db_prep_* is probably more appropriate here. 236 if field.rel and isinstance(val, Model): 237 val = val.pk 238 239 # Getting the placeholder for the field. 240 if hasattr(field, 'get_placeholder'): 241 placeholder = field.get_placeholder(val) 233 if isinstance(val, Expression): 234 expr = val 242 235 else: 243 placeholder = '%s' 236 expr = Literal(val) 237 238 sql, params = expr.as_sql( 239 field, self.get_meta(), qn=self.connection.ops.quote_name) 244 240 245 241 if model: 246 self.add_related_update(model, field.column, val, placeholder)242 self.add_related_update(model, field.column, sql, params) 247 243 else: 248 self.values.append((field.column, val, placeholder))244 self.values.append((field.column, sql, params)) 249 245 250 def add_related_update(self, model, column, value, placeholder):246 def add_related_update(self, model, column, sql, params): 251 247 """ 252 248 Adds (name, value) to an update query for an ancestor model. 253 249 254 250 Updates are coalesced so that we only run one update query per ancestor. 255 251 """ 256 252 try: 257 self.related_updates[model].append((column, value, placeholder))253 self.related_updates[model].append((column, sql, params)) 258 254 except KeyError: 259 self.related_updates[model] = [(column, value, placeholder)]255 self.related_updates[model] = [(column, sql, params)] 260 256 261 257 def get_related_updates(self): 262 258 """ -
django/db/models/sql/where.py
diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py index 3e8bfed..d333899 100644
a b from django.utils import tree 7 7 from django.db import connection 8 8 from django.db.models.fields import Field 9 9 from django.db.models.query_utils import QueryWrapper 10 from django.db.models.sql.expressions import Expression, Literal 10 11 from datastructures import EmptyResultSet, FullResultSet 11 12 12 13 # Connection types … … class WhereNode(tree.Node): 26 27 """ 27 28 default = AND 28 29 29 def as_sql(self, node=None, qn=None):30 def as_sql(self, opts, node=None, qn=None): 30 31 """ 31 32 Returns the SQL version of the where clause and the value to be 32 33 substituted in. Returns None, None if this node is empty. … … class WhereNode(tree.Node): 47 48 for child in node.children: 48 49 try: 49 50 if hasattr(child, 'as_sql'): 50 sql, params = child.as_sql( qn=qn)51 sql, params = child.as_sql(opts, qn=qn) 51 52 format = '(%s)' 52 53 elif isinstance(child, tree.Node): 53 sql, params = self.as_sql( child, qn)54 sql, params = self.as_sql(opts, child, qn) 54 55 if child.negated: 55 56 format = 'NOT (%s)' 56 57 elif len(child.children) == 1: … … class WhereNode(tree.Node): 58 59 else: 59 60 format = '(%s)' 60 61 else: 61 sql, params = self.make_atom( child, qn)62 sql, params = self.make_atom(opts, child, qn) 62 63 format = '%s' 63 64 except EmptyResultSet: 64 65 if node.connector == AND and not node.negated: … … class WhereNode(tree.Node): 86 87 conn = ' %s ' % node.connector 87 88 return conn.join(result), result_params 88 89 89 def make_atom(self, child, qn):90 def make_atom(self, opts, child, qn): 90 91 """ 91 92 Turn a tuple (table_alias, field_name, field_class, lookup_type, value) 92 93 into valid SQL. … … class WhereNode(tree.Node): 102 103 db_type = field and field.db_type() or None 103 104 field_sql = connection.ops.field_cast_sql(db_type) % lhs 104 105 105 if isinstance(value, datetime.datetime): 106 cast_sql = connection.ops.datetime_cast_sql() 107 else: 108 cast_sql = '%s' 106 if lookup_type in connection.operators: 107 if isinstance(value, Expression): 108 sql, params = value.as_sql(opts, field, lookup_type, qn) 109 else: 110 sql, params = Literal(value).as_sql(opts, field, lookup_type, qn) 111 112 format= '%s %s' % ( 113 connection.ops.lookup_cast(lookup_type), 114 connection.operators[lookup_type]) 115 return format % (field_sql, sql), params 116 117 if isinstance(value, Expression): 118 TypeError('Invalid lookup_type for use with %s object: %r' % ( 119 value.__class__.__name__, lookup_type)) 109 120 110 121 if field: 111 122 params = field.get_db_prep_lookup(lookup_type, value) … … class WhereNode(tree.Node): 116 127 else: 117 128 extra = '' 118 129 119 if lookup_type in connection.operators:120 format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type),121 extra)122 return (format % (field_sql,123 connection.operators[lookup_type] % cast_sql), params)124 125 130 if lookup_type == 'in': 126 131 if not value: 127 132 raise EmptyResultSet 128 133 if extra: 129 134 return ('%s IN %s' % (field_sql, extra), params) 130 return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), 131 params) 135 return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), params) 132 136 elif lookup_type in ('range', 'year'): 133 137 return ('%s BETWEEN %%s and %%s' % field_sql, params) 134 138 elif lookup_type in ('month', 'day'):