Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.2.patch

File 0001-Added-expression-support-for-QuerySet.update.2.patch, 8.2 KB (added by Sebastian Noack, 17 years ago)
  • new file django/db/models/sql/expressions.py

    From 2555dc3f525243548e8476bd1fe474171a350531 Mon Sep 17 00:00:00 2001
    From: Sebastian Noack <sebastian.noack@gmail.com>
    Date: Thu, 8 May 2008 14:30:19 +0200
    Subject: [PATCH] Added expression support for QuerySet.update.
    
    ---
     django/db/models/sql/expressions.py |  127 +++++++++++++++++++++++++++++++++++
     django/db/models/sql/subqueries.py  |   38 +++++------
     2 files changed, 145 insertions(+), 20 deletions(-)
     create mode 100644 django/db/models/sql/expressions.py
    
    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    new file mode 100644
    index 0000000..1532479
    - +  
     1from copy import deepcopy
     2
     3from django.db import connection
     4from django.db.models.fields import FieldDoesNotExist
     5from django.core.exceptions import FieldError
     6from django.utils import tree
     7
     8class Expression(object):
     9    """
     10    Base class for all sql expressions, expected by QuerySet.update.
     11    """
     12    # Arithmetic connection types
     13    ADD = '+'
     14    SUB = '-'
     15    MUL = '*'
     16    DIV = '/'
     17    MOD = '%'
     18
     19    # Logical connection types
     20    AND = 'AND'
     21    OR = 'OR'
     22
     23    def _combine(self, other, conn, node=None):
     24        obj = node or ExpressionNode([self], conn)
     25        if isinstance(other, Expression):
     26            obj.add(other, conn)
     27        else:
     28            obj.add(L(other), conn)
     29        return obj
     30
     31    def __add__(self, other):
     32        return self._combine(other, self.ADD)
     33
     34    def __sub__(self, other):
     35        return self._combine(other, self.SUB)
     36
     37    def __mul__(self, other):
     38        return self._combine(other, self.MUL)
     39
     40    def __div__(self, other):
     41        return self._combine(other, self.DIV)
     42
     43    def __mod__(self, other):
     44        return self._combine(other, self.MOD)
     45
     46    def __and__(self, other):
     47        return self._combine(other, self.AND)
     48
     49    def __or__(self, other):
     50        return self._combine(other, self.OR)
     51
     52    def __invert__(self, node=None):
     53        obj = node or ExpressionNode([self])
     54        obj.negate()
     55        return obj
     56
     57    def as_sql(self, field, opts, qn=None):
     58        raise NotImplementedError
     59
     60class ExpressionNode(Expression, tree.Node):
     61    default = None
     62
     63    def __init__(self, children=None, connector=None, negated=False):
     64        if children is not None and len(children) > 1 and connector is None:
     65            raise TypeError('You have to specify a connector.')
     66        super(ExpressionNode, self).__init__(children, connector, negated)
     67
     68    def _combine(self, *args, **kwargs):
     69        return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs)
     70
     71    def __invert__(self):
     72        return super(ExpressionNode, self).__invert__(node=deepcopy(self))
     73
     74    def as_sql(self, field, opts, qn=None, node=None):
     75        if node is None:
     76            node = self
     77        result = []
     78        result_params = []
     79        for child in node.children:
     80            if hasattr(child, 'as_sql'):
     81                sql, params = child.as_sql(field, opts, qn)
     82                format = '%s'
     83            else:
     84                sql, params = self.as_sql(field, opts, qn, child)
     85                if child.negated:
     86                    format = 'NOT %s'
     87                else:
     88                    format = '%s'
     89                if len(child.children) > 1:
     90                    format %= '(%s)'
     91            if sql:
     92                result.append(format % sql)
     93                result_params.extend(params)
     94        conn = ' %s ' % node.connector
     95        return conn.join(result), result_params
     96
     97class L(Expression):
     98    """
     99    An expression representing the given value.
     100    """
     101    def __init__(self, value):
     102        self.value = value
     103
     104    def as_sql(self, field, opts, qn=None):
     105        if self.value is None:
     106            return 'NULL', ()
     107        if hasattr(field, 'get_placeholder'):
     108            return field.get_placeholder(self.value), (self.value,)
     109        return '%s', (self.value,)
     110
     111class F(Expression):
     112    """
     113    An expression representing the value of the given field.
     114    """
     115    def __init__(self, name):
     116        self.name = name
     117
     118    def as_sql(self, field, opts, qn=None):
     119        if not qn:
     120            qn = connection.ops.quote_name
     121        try:
     122            column = opts.get_field(self.name).attname
     123        except FieldDoesNotExist:
     124            names = opts.get_all_field_names()
     125            raise FieldError('Cannot resolve keyword %r into field. '
     126                    'Choices are: %s' % (self.name, ', '.join(names)))
     127        return '%s.%s' % (qn(opts.db_table), qn(column)), ()
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 7385cd0..d0d5393 100644
    a b from django.db.models.sql.constants import *  
    88from django.db.models.sql.datastructures import RawValue, Date
    99from django.db.models.sql.query import Query
    1010from django.db.models.sql.where import AND
     11from django.db.models.sql.expressions import Expression, L
    1112
    1213__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
    1314        'CountQuery']
    class UpdateQuery(Query):  
    126127        result = ['UPDATE %s' % qn(table)]
    127128        result.append('SET')
    128129        values, update_params = [], []
    129         for name, val, placeholder in self.values:
    130             if val is not None:
    131                 values.append('%s = %s' % (qn(name), placeholder))
    132                 update_params.append(val)
    133             else:
    134                 values.append('%s = NULL' % qn(name))
     130        for name, sql, params in self.values:
     131            values.append('%s = %s' % (qn(name), sql))
     132            update_params.extend(params)
    135133        result.append(', '.join(values))
    136134        where, params = self.where.as_sql()
    137135        if where:
    class UpdateQuery(Query):  
    207205            self.where.add((None, f.column, f, 'in',
    208206                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
    209207                    AND)
    210             self.values = [(related_field.column, None, '%s')]
     208            self.values = [(related_field.column, 'NULL', ())]
    211209            self.execute_sql(None)
    212210
    213211    def add_update_values(self, values):
    class UpdateQuery(Query):  
    232230        """
    233231        from django.db.models.base import Model
    234232        for field, model, val in values_seq:
    235             # FIXME: Some sort of db_prep_* is probably more appropriate here.
    236             if field.rel and isinstance(val, Model):
    237                 val = val.pk
     233            if isinstance(val, Expression):
     234                expr = val
     235            elif field.rel and isinstance(val, Model):  # FIXME: Some sort of
     236                expr = L(val.pk)                        # db_prep_* is probably
     237            else:                                       # more appropriate here.
     238                expr = L(val)
    238239
    239             # Getting the placeholder for the field.
    240             if hasattr(field, 'get_placeholder'):
    241                 placeholder = field.get_placeholder(val)
    242             else:
    243                 placeholder = '%s'
     240            sql, params = expr.as_sql(
     241                field, self.get_meta(), self.connection.ops.quote_name)
    244242
    245243            if model:
    246                 self.add_related_update(model, field.column, val, placeholder)
     244                self.add_related_update(model, field.column, sql, params)
    247245            else:
    248                 self.values.append((field.column, val, placeholder))
     246                self.values.append((field.column, sql, params))
    249247
    250     def add_related_update(self, model, column, value, placeholder):
     248    def add_related_update(self, model, column, sql, params):
    251249        """
    252250        Adds (name, value) to an update query for an ancestor model.
    253251
    254252        Updates are coalesced so that we only run one update query per ancestor.
    255253        """
    256254        try:
    257             self.related_updates[model].append((column, value, placeholder))
     255            self.related_updates[model].append((column, sql, params))
    258256        except KeyError:
    259             self.related_updates[model] = [(column, value, placeholder)]
     257            self.related_updates[model] = [(column, sql, params)]
    260258
    261259    def get_related_updates(self):
    262260        """
Back to Top