Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.patch
File 0001-Added-expression-support-for-QuerySet.update.patch, 8.5 KB (added by , 17 years ago) |
---|
-
new file django/db/models/sql/expressions.py
From e1b81cccb9881c21626aec01fa5b050c972a1b0c Mon Sep 17 00:00:00 2001 From: Sebastian Noack <sebastian.noack@gmail.com> Date: Thu, 8 May 2008 14:30:19 +0200 Subject: [PATCH] Added expression support for QuerySet.update. --- django/db/models/sql/expressions.py | 133 +++++++++++++++++++++++++++++++++++ django/db/models/sql/subqueries.py | 38 +++++----- 2 files changed, 151 insertions(+), 20 deletions(-) create mode 100644 django/db/models/sql/expressions.py diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py new file mode 100644 index 0000000..d52cc46
- + 1 from copy import deepcopy 2 3 from django.db import connection 4 from django.db.models.fields import FieldDoesNotExist 5 from django.core.exceptions import FieldError 6 from django.utils import tree 7 8 class Expression(object): 9 """ 10 Base class for all sql expressions, expected by QuerySet.update. 11 """ 12 # Arithmetic connection types 13 ADD = '+' 14 SUB = '-' 15 MUL = '*' 16 DIV = '/' 17 MOD = '%' 18 19 # Logical connection types 20 AND = 'AND' 21 OR = 'OR' 22 23 def _combine(self, other, conn, node=None): 24 if not isinstance(other, Expression): 25 raise TypeError(other) 26 obj = node or ExpressionNode([self], conn) 27 obj.add(other, conn) 28 return obj 29 30 def __add__(self, other): 31 return self._combine(other, self.ADD) 32 33 def __sub__(self, other): 34 return self._combine(other, self.SUB) 35 36 def __mul__(self, other): 37 return self._combine(other, self.MUL) 38 39 def __div__(self, other): 40 return self._combine(other, self.DIV) 41 42 def __mod__(self, other): 43 return self._combine(other, self.MOD) 44 45 def __and__(self, other): 46 return self._combine(other, self.AND) 47 48 def __or__(self, other): 49 return self._combine(other, self.OR) 50 51 def __invert__(self, node=None): 52 obj = node or ExpressionNode([self]) 53 obj.negate() 54 return obj 55 56 def as_sql(self, field, opts, qn=None): 57 raise NotImplementedError 58 59 class ExpressionNode(Expression, tree.Node): 60 def __init__(self, children=None, connector=None, negated=False): 61 if children and len(children) > 1 and connector in (None, self.default): 62 raise TypeError('You have to specify a connector.') 63 super(ExpressionNode, self).__init__(children, connector, negated) 64 65 def _combine(self, *args, **kwargs): 66 return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 67 68 def __invert__(self): 69 return super(ExpressionNode, self).__invert__(node=deepcopy(self)) 70 71 def as_sql(self, field, opts, qn=None, node=None): 72 if node is None: 73 node = self 74 result = [] 75 result_params = [] 76 for child in node.children: 77 if hasattr(child, 'as_sql'): 78 sql, params = child.as_sql(field, opts, qn) 79 format = '%s' 80 else: 81 sql, params = self.as_sql(field, opts, qn, child) 82 if child.negated: 83 format = 'NOT %s' 84 else: 85 format = '%s' 86 if len(child.children) > 1: 87 format %= '(%s)' 88 if sql: 89 result.append(format % sql) 90 result_params.extend(params) 91 conn = ' %s ' % node.connector 92 return conn.join(result), result_params 93 94 class LiteralExpr(Expression): 95 """ 96 An expression representing the given value. 97 """ 98 def __init__(self, value): 99 self.value = value 100 101 def as_sql(self, field, opts, qn=None): 102 if self.value is None: 103 return 'NULL', () 104 if hasattr(field, 'get_placeholder'): 105 return field.get_placeholder(self.value), (self.value,) 106 return '%s', (self.value,) 107 108 class ColumnExpr(Expression): 109 """ 110 An expression representing the value of the given column. 111 """ 112 def __init__(self, column): 113 self.column = column 114 115 def as_sql(self, field, opts, qn=None): 116 if not qn: 117 qn = connection.ops.quote_name 118 try: 119 column = opts.get_field(self.column).attname 120 except FieldDoesNotExist: 121 names = opts.get_all_field_names() 122 raise FieldError('Cannot resolve keyword %r into field. ' 123 'Choices are: %s' % (self.column, ', '.join(names))) 124 return '%s.%s' % (qn(opts.db_table), qn(column)), () 125 126 class CurrentExpr(Expression): 127 """ 128 An expression representing the value of the current column. 129 """ 130 def as_sql(self, field, opts, qn=None): 131 if not qn: 132 qn = connection.ops.quote_name 133 return '%s.%s' % (qn(opts.db_table), qn(field.attname)), () -
django/db/models/sql/subqueries.py
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 7385cd0..0bcca96 100644
a b from django.db.models.sql.constants import * 8 8 from django.db.models.sql.datastructures import RawValue, Date 9 9 from django.db.models.sql.query import Query 10 10 from django.db.models.sql.where import AND 11 from django.db.models.sql.expressions import Expression, LiteralExpr 11 12 12 13 __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 13 14 'CountQuery'] … … class UpdateQuery(Query): 126 127 result = ['UPDATE %s' % qn(table)] 127 128 result.append('SET') 128 129 values, update_params = [], [] 129 for name, val, placeholder in self.values: 130 if val is not None: 131 values.append('%s = %s' % (qn(name), placeholder)) 132 update_params.append(val) 133 else: 134 values.append('%s = NULL' % qn(name)) 130 for name, sql, params in self.values: 131 values.append('%s = %s' % (qn(name), sql)) 132 update_params.extend(params) 135 133 result.append(', '.join(values)) 136 134 where, params = self.where.as_sql() 137 135 if where: … … class UpdateQuery(Query): 207 205 self.where.add((None, f.column, f, 'in', 208 206 pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 209 207 AND) 210 self.values = [(related_field.column, None, '%s')]208 self.values = [(related_field.column, 'NULL', ())] 211 209 self.execute_sql(None) 212 210 213 211 def add_update_values(self, values): … … class UpdateQuery(Query): 232 230 """ 233 231 from django.db.models.base import Model 234 232 for field, model, val in values_seq: 235 # FIXME: Some sort of db_prep_* is probably more appropriate here. 236 if field.rel and isinstance(val, Model): 237 val = val.pk 233 if isinstance(val, Expression): 234 expr = val 235 elif field.rel and isinstance(val, Model): # FIXME: Some sort of 236 expr = LiteralExpr(val.pk) # db_prep_* is probably 237 else: # more appropriate here. 238 expr = LiteralExpr(val) 238 239 239 # Getting the placeholder for the field. 240 if hasattr(field, 'get_placeholder'): 241 placeholder = field.get_placeholder(val) 242 else: 243 placeholder = '%s' 240 sql, params = expr.as_sql( 241 field, self.get_meta(), self.connection.ops.quote_name) 244 242 245 243 if model: 246 self.add_related_update(model, field.column, val, placeholder)244 self.add_related_update(model, field.column, sql, params) 247 245 else: 248 self.values.append((field.column, val, placeholder))246 self.values.append((field.column, sql, params)) 249 247 250 def add_related_update(self, model, column, value, placeholder):248 def add_related_update(self, model, column, sql, params): 251 249 """ 252 250 Adds (name, value) to an update query for an ancestor model. 253 251 254 252 Updates are coalesced so that we only run one update query per ancestor. 255 253 """ 256 254 try: 257 self.related_updates[model].append((column, value, placeholder))255 self.related_updates[model].append((column, sql, params)) 258 256 except KeyError: 259 self.related_updates[model] = [(column, value, placeholder)]257 self.related_updates[model] = [(column, sql, params)] 260 258 261 259 def get_related_updates(self): 262 260 """