Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.2.patch
File 0001-Added-expression-support-for-QuerySet.update.2.patch, 8.2 KB (added by , 17 years ago) |
---|
-
new file django/db/models/sql/expressions.py
From 2555dc3f525243548e8476bd1fe474171a350531 Mon Sep 17 00:00:00 2001 From: Sebastian Noack <sebastian.noack@gmail.com> Date: Thu, 8 May 2008 14:30:19 +0200 Subject: [PATCH] Added expression support for QuerySet.update. --- django/db/models/sql/expressions.py | 127 +++++++++++++++++++++++++++++++++++ django/db/models/sql/subqueries.py | 38 +++++------ 2 files changed, 145 insertions(+), 20 deletions(-) create mode 100644 django/db/models/sql/expressions.py diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py new file mode 100644 index 0000000..1532479
- + 1 from copy import deepcopy 2 3 from django.db import connection 4 from django.db.models.fields import FieldDoesNotExist 5 from django.core.exceptions import FieldError 6 from django.utils import tree 7 8 class Expression(object): 9 """ 10 Base class for all sql expressions, expected by QuerySet.update. 11 """ 12 # Arithmetic connection types 13 ADD = '+' 14 SUB = '-' 15 MUL = '*' 16 DIV = '/' 17 MOD = '%' 18 19 # Logical connection types 20 AND = 'AND' 21 OR = 'OR' 22 23 def _combine(self, other, conn, node=None): 24 obj = node or ExpressionNode([self], conn) 25 if isinstance(other, Expression): 26 obj.add(other, conn) 27 else: 28 obj.add(L(other), conn) 29 return obj 30 31 def __add__(self, other): 32 return self._combine(other, self.ADD) 33 34 def __sub__(self, other): 35 return self._combine(other, self.SUB) 36 37 def __mul__(self, other): 38 return self._combine(other, self.MUL) 39 40 def __div__(self, other): 41 return self._combine(other, self.DIV) 42 43 def __mod__(self, other): 44 return self._combine(other, self.MOD) 45 46 def __and__(self, other): 47 return self._combine(other, self.AND) 48 49 def __or__(self, other): 50 return self._combine(other, self.OR) 51 52 def __invert__(self, node=None): 53 obj = node or ExpressionNode([self]) 54 obj.negate() 55 return obj 56 57 def as_sql(self, field, opts, qn=None): 58 raise NotImplementedError 59 60 class ExpressionNode(Expression, tree.Node): 61 default = None 62 63 def __init__(self, children=None, connector=None, negated=False): 64 if children is not None and len(children) > 1 and connector is None: 65 raise TypeError('You have to specify a connector.') 66 super(ExpressionNode, self).__init__(children, connector, negated) 67 68 def _combine(self, *args, **kwargs): 69 return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs) 70 71 def __invert__(self): 72 return super(ExpressionNode, self).__invert__(node=deepcopy(self)) 73 74 def as_sql(self, field, opts, qn=None, node=None): 75 if node is None: 76 node = self 77 result = [] 78 result_params = [] 79 for child in node.children: 80 if hasattr(child, 'as_sql'): 81 sql, params = child.as_sql(field, opts, qn) 82 format = '%s' 83 else: 84 sql, params = self.as_sql(field, opts, qn, child) 85 if child.negated: 86 format = 'NOT %s' 87 else: 88 format = '%s' 89 if len(child.children) > 1: 90 format %= '(%s)' 91 if sql: 92 result.append(format % sql) 93 result_params.extend(params) 94 conn = ' %s ' % node.connector 95 return conn.join(result), result_params 96 97 class L(Expression): 98 """ 99 An expression representing the given value. 100 """ 101 def __init__(self, value): 102 self.value = value 103 104 def as_sql(self, field, opts, qn=None): 105 if self.value is None: 106 return 'NULL', () 107 if hasattr(field, 'get_placeholder'): 108 return field.get_placeholder(self.value), (self.value,) 109 return '%s', (self.value,) 110 111 class F(Expression): 112 """ 113 An expression representing the value of the given field. 114 """ 115 def __init__(self, name): 116 self.name = name 117 118 def as_sql(self, field, opts, qn=None): 119 if not qn: 120 qn = connection.ops.quote_name 121 try: 122 column = opts.get_field(self.name).attname 123 except FieldDoesNotExist: 124 names = opts.get_all_field_names() 125 raise FieldError('Cannot resolve keyword %r into field. ' 126 'Choices are: %s' % (self.name, ', '.join(names))) 127 return '%s.%s' % (qn(opts.db_table), qn(column)), () -
django/db/models/sql/subqueries.py
diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py index 7385cd0..d0d5393 100644
a b from django.db.models.sql.constants import * 8 8 from django.db.models.sql.datastructures import RawValue, Date 9 9 from django.db.models.sql.query import Query 10 10 from django.db.models.sql.where import AND 11 from django.db.models.sql.expressions import Expression, L 11 12 12 13 __all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery', 13 14 'CountQuery'] … … class UpdateQuery(Query): 126 127 result = ['UPDATE %s' % qn(table)] 127 128 result.append('SET') 128 129 values, update_params = [], [] 129 for name, val, placeholder in self.values: 130 if val is not None: 131 values.append('%s = %s' % (qn(name), placeholder)) 132 update_params.append(val) 133 else: 134 values.append('%s = NULL' % qn(name)) 130 for name, sql, params in self.values: 131 values.append('%s = %s' % (qn(name), sql)) 132 update_params.extend(params) 135 133 result.append(', '.join(values)) 136 134 where, params = self.where.as_sql() 137 135 if where: … … class UpdateQuery(Query): 207 205 self.where.add((None, f.column, f, 'in', 208 206 pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]), 209 207 AND) 210 self.values = [(related_field.column, None, '%s')]208 self.values = [(related_field.column, 'NULL', ())] 211 209 self.execute_sql(None) 212 210 213 211 def add_update_values(self, values): … … class UpdateQuery(Query): 232 230 """ 233 231 from django.db.models.base import Model 234 232 for field, model, val in values_seq: 235 # FIXME: Some sort of db_prep_* is probably more appropriate here. 236 if field.rel and isinstance(val, Model): 237 val = val.pk 233 if isinstance(val, Expression): 234 expr = val 235 elif field.rel and isinstance(val, Model): # FIXME: Some sort of 236 expr = L(val.pk) # db_prep_* is probably 237 else: # more appropriate here. 238 expr = L(val) 238 239 239 # Getting the placeholder for the field. 240 if hasattr(field, 'get_placeholder'): 241 placeholder = field.get_placeholder(val) 242 else: 243 placeholder = '%s' 240 sql, params = expr.as_sql( 241 field, self.get_meta(), self.connection.ops.quote_name) 244 242 245 243 if model: 246 self.add_related_update(model, field.column, val, placeholder)244 self.add_related_update(model, field.column, sql, params) 247 245 else: 248 self.values.append((field.column, val, placeholder))246 self.values.append((field.column, sql, params)) 249 247 250 def add_related_update(self, model, column, value, placeholder):248 def add_related_update(self, model, column, sql, params): 251 249 """ 252 250 Adds (name, value) to an update query for an ancestor model. 253 251 254 252 Updates are coalesced so that we only run one update query per ancestor. 255 253 """ 256 254 try: 257 self.related_updates[model].append((column, value, placeholder))255 self.related_updates[model].append((column, sql, params)) 258 256 except KeyError: 259 self.related_updates[model] = [(column, value, placeholder)]257 self.related_updates[model] = [(column, sql, params)] 260 258 261 259 def get_related_updates(self): 262 260 """