Ticket #7210: 0001-Added-expression-support-for-QuerySet.update.3.patch

File 0001-Added-expression-support-for-QuerySet.update.3.patch, 15.1 KB (added by Sebastian Noack, 17 years ago)
  • django/db/models/__init__.py

    From a439f12ed3c4bcd3c2bc78b67bf36f47043c51ee Mon Sep 17 00:00:00 2001
    From: Sebastian Noack <sebastian.noack@gmail.com>
    Date: Thu, 8 May 2008 14:30:19 +0200
    Subject: [PATCH] Added expression support for QuerySet.update.
    
    ---
     django/db/models/__init__.py        |    1 +
     django/db/models/sql/expressions.py |  186 +++++++++++++++++++++++++++++++++++
     django/db/models/sql/query.py       |    2 +-
     django/db/models/sql/subqueries.py  |   36 +++----
     django/db/models/sql/where.py       |   38 ++++---
     5 files changed, 225 insertions(+), 38 deletions(-)
     create mode 100644 django/db/models/sql/expressions.py
    
    diff --git a/django/db/models/__init__.py b/django/db/models/__init__.py
    index 86763d9..f5ab8eb 100644
    a b from django.core import validators  
    44from django.db import connection
    55from django.db.models.loading import get_apps, get_app, get_models, get_model, register_models
    66from django.db.models.query import Q
     7from django.db.models.sql.expressions import F
    78from django.db.models.manager import Manager
    89from django.db.models.base import Model, AdminOptions
    910from django.db.models.fields import *
  • new file django/db/models/sql/expressions.py

    diff --git a/django/db/models/sql/expressions.py b/django/db/models/sql/expressions.py
    new file mode 100644
    index 0000000..2be6472
    - +  
     1from copy import deepcopy
     2from datetime import datetime
     3
     4from django.utils import tree
     5from django.core.exceptions import FieldError
     6from django.db import connection
     7from django.db.models.fields import Field, FieldDoesNotExist
     8from django.db.models.query_utils import QueryWrapper
     9
     10class Expression(object):
     11    """
     12    Base class for all sql expressions, expected by QuerySet.update.
     13    """
     14    # Arithmetic connection types
     15    ADD = '+'
     16    SUB = '-'
     17    MUL = '*'
     18    DIV = '/'
     19    MOD = '%'
     20
     21    # Bitwise connection types
     22    AND = '&'
     23    OR = '|'
     24    LSHIFT = '<<'
     25    RSHIFT = '>>'
     26
     27    def _combine(self, other, conn, reversed, node=None):
     28        if reversed:
     29            obj = ExpressionNode([Literal(other)], conn)
     30            obj.add(node or self, conn)
     31        else:
     32            obj = node or ExpressionNode([self], conn)
     33            if isinstance(other, Expression):
     34                obj.add(other, conn)
     35            else:
     36                obj.add(Literal(other), conn)
     37        return obj
     38
     39    def __add__(self, other):
     40        return self._combine(other, self.ADD, False)
     41
     42    def __sub__(self, other):
     43        return self._combine(other, self.SUB, False)
     44
     45    def __mul__(self, other):
     46        return self._combine(other, self.MUL, False)
     47
     48    def __div__(self, other):
     49        return self._combine(other, self.DIV, False)
     50
     51    def __mod__(self, other):
     52        return self._combine(other, self.MOD, False)
     53
     54    def __and__(self, other):
     55        return self._combine(other, self.AND, False)
     56
     57    def __or__(self, other):
     58        return self._combine(other, self.OR, False)
     59
     60    def __lshift__(self, other):
     61        return self._combine(other, self.LSHIFT, False)
     62
     63    def __rshift__(self, other):
     64        return self._combine(other, self.RSHIFT, False)
     65
     66    def __radd__(self, other):
     67        return self._combine(other, self.ADD, True)
     68
     69    def __rsub__(self, other):
     70        return self._combine(other, self.SUB, True)
     71
     72    def __rmul__(self, other):
     73        return self._combine(other, self.MUL, True)
     74
     75    def __rdiv__(self, other):
     76        return self._combine(other, self.DIV, True)
     77
     78    def __rmod__(self, other):
     79        return self._combine(other, self.MOD, True)
     80
     81    def __rand__(self, other):
     82        return self._combine(other, self.AND, True)
     83
     84    def __ror__(self, other):
     85        return self._combine(other, self.OR, True)
     86
     87    def __lshift__(self, other):
     88        return self._combine(other, self.LSHIFT, True)
     89
     90    def __rshift__(self, other):
     91        return self._combine(other, self.RSHIFT, True)
     92
     93    def __invert__(self, node=None):
     94        obj = node or ExpressionNode([self])
     95        obj.negate()
     96        return obj
     97
     98    def as_sql(self, opts, field=None, lookup_type='exact', qn=None):
     99        raise NotImplementedError
     100
     101class ExpressionNode(Expression, tree.Node):
     102    default = None
     103
     104    def __init__(self, children=None, connector=None, negated=False):
     105        if children is not None and len(children) > 1 and connector is None:
     106            raise TypeError('You have to specify a connector.')
     107        super(ExpressionNode, self).__init__(children, connector, negated)
     108
     109    def _combine(self, *args, **kwargs):
     110        return super(ExpressionNode, self)._combine(node=deepcopy(self), *args, **kwargs)
     111
     112    def __invert__(self):
     113        return super(ExpressionNode, self).__invert__(node=deepcopy(self))
     114
     115    def as_sql(self, opts, field=None, lookup_type='exact', qn=None, node=None):
     116        if not qn:
     117            qn = connection.ops.quote_name
     118        if node is None:
     119            node = self
     120
     121        result = []
     122        result_params = []
     123        for child in node.children:
     124            if hasattr(child, 'as_sql'):
     125                sql, params = child.as_sql(opts, field, qn=qn)
     126                format = '%s'
     127            else:
     128                sql, params = self.as_sql(opts, field, qn=qn, node=child)
     129                if child.negated:
     130                    format = '~%s'
     131                else:
     132                    format = '%s'
     133                if len(child.children) > 1:
     134                    format %= '(%s)'
     135            if sql:
     136                result.append(format % sql)
     137                result_params.extend(params)
     138        conn = ' %s ' % node.connector
     139        return conn.join(result), result_params
     140
     141class Literal(Expression):
     142    """
     143    An expression representing the given value.
     144    """
     145    def __init__(self, value):
     146        self.value = value
     147
     148    def as_sql(self, opts, field=None, lookup_type='exact', qn=None):
     149        if self.value is None:
     150            return 'NULL', ()
     151
     152        if field:
     153            params = field.get_db_prep_lookup(lookup_type, self.value)
     154        else:
     155            params = Field().get_db_prep_lookup(lookup_type, self.value)
     156
     157        if isinstance(self.value, datetime):
     158            sql = connection.ops.datetime_cast_sql()
     159        else:
     160            sql = '%s'
     161
     162        if isinstance(params, QueryWrapper):
     163            return params.data
     164        return sql, params
     165
     166class F(Expression):
     167    """
     168    An expression representing the value of the given field.
     169    """
     170    def __init__(self, name):
     171        self.name = name
     172
     173    def as_sql(self, opts, field=None, lookup_type='exact', qn=None):
     174        if not qn:
     175            qn = connection.ops.quote_name
     176
     177        try:
     178            src_field = opts.get_field(self.name)
     179        except FieldDoesNotExist:
     180            names = opts.get_all_field_names()
     181            raise FieldError('Cannot resolve keyword %r into field. '
     182                    'Choices are: %s' % (self.name, ', '.join(names)))
     183
     184        field_sql = connection.ops.field_cast_sql(src_field.db_type())
     185        lhs = '%s.%s' % (qn(opts.db_table), qn(src_field.attname))
     186        return field_sql % lhs, ()
  • django/db/models/sql/query.py

    diff --git a/django/db/models/sql/query.py b/django/db/models/sql/query.py
    index a6957ba..bd1d030 100644
    a b class Query(object):  
    250250        # get_from_clause() for details.
    251251        from_, f_params = self.get_from_clause()
    252252
    253         where, w_params = self.where.as_sql(qn=self.quote_name_unless_alias)
     253        where, w_params = self.where.as_sql(self.get_meta(), qn=self.quote_name_unless_alias)
    254254        params = list(self.extra_select_params)
    255255
    256256        result = ['SELECT']
  • django/db/models/sql/subqueries.py

    diff --git a/django/db/models/sql/subqueries.py b/django/db/models/sql/subqueries.py
    index 7385cd0..251bab2 100644
    a b from django.db.models.sql.constants import *  
    88from django.db.models.sql.datastructures import RawValue, Date
    99from django.db.models.sql.query import Query
    1010from django.db.models.sql.where import AND
     11from django.db.models.sql.expressions import Expression, Literal
    1112
    1213__all__ = ['DeleteQuery', 'UpdateQuery', 'InsertQuery', 'DateQuery',
    1314        'CountQuery']
    class UpdateQuery(Query):  
    126127        result = ['UPDATE %s' % qn(table)]
    127128        result.append('SET')
    128129        values, update_params = [], []
    129         for name, val, placeholder in self.values:
    130             if val is not None:
    131                 values.append('%s = %s' % (qn(name), placeholder))
    132                 update_params.append(val)
    133             else:
    134                 values.append('%s = NULL' % qn(name))
     130        for name, sql, params in self.values:
     131            values.append('%s = %s' % (qn(name), sql))
     132            update_params.extend(params)
    135133        result.append(', '.join(values))
    136134        where, params = self.where.as_sql()
    137135        if where:
    class UpdateQuery(Query):  
    207205            self.where.add((None, f.column, f, 'in',
    208206                    pk_list[offset : offset + GET_ITERATOR_CHUNK_SIZE]),
    209207                    AND)
    210             self.values = [(related_field.column, None, '%s')]
     208            self.values = [(related_field.column, 'NULL', ())]
    211209            self.execute_sql(None)
    212210
    213211    def add_update_values(self, values):
    class UpdateQuery(Query):  
    232230        """
    233231        from django.db.models.base import Model
    234232        for field, model, val in values_seq:
    235             # FIXME: Some sort of db_prep_* is probably more appropriate here.
    236             if field.rel and isinstance(val, Model):
    237                 val = val.pk
    238 
    239             # Getting the placeholder for the field.
    240             if hasattr(field, 'get_placeholder'):
    241                 placeholder = field.get_placeholder(val)
     233            if isinstance(val, Expression):
     234                expr = val
    242235            else:
    243                 placeholder = '%s'
     236                expr = Literal(val)
     237
     238            sql, params = expr.as_sql(
     239                field, self.get_meta(), qn=self.connection.ops.quote_name)
    244240
    245241            if model:
    246                 self.add_related_update(model, field.column, val, placeholder)
     242                self.add_related_update(model, field.column, sql, params)
    247243            else:
    248                 self.values.append((field.column, val, placeholder))
     244                self.values.append((field.column, sql, params))
    249245
    250     def add_related_update(self, model, column, value, placeholder):
     246    def add_related_update(self, model, column, sql, params):
    251247        """
    252248        Adds (name, value) to an update query for an ancestor model.
    253249
    254250        Updates are coalesced so that we only run one update query per ancestor.
    255251        """
    256252        try:
    257             self.related_updates[model].append((column, value, placeholder))
     253            self.related_updates[model].append((column, sql, params))
    258254        except KeyError:
    259             self.related_updates[model] = [(column, value, placeholder)]
     255            self.related_updates[model] = [(column, sql, params)]
    260256
    261257    def get_related_updates(self):
    262258        """
  • django/db/models/sql/where.py

    diff --git a/django/db/models/sql/where.py b/django/db/models/sql/where.py
    index 3e8bfed..d333899 100644
    a b from django.utils import tree  
    77from django.db import connection
    88from django.db.models.fields import Field
    99from django.db.models.query_utils import QueryWrapper
     10from django.db.models.sql.expressions import Expression, Literal
    1011from datastructures import EmptyResultSet, FullResultSet
    1112
    1213# Connection types
    class WhereNode(tree.Node):  
    2627    """
    2728    default = AND
    2829
    29     def as_sql(self, node=None, qn=None):
     30    def as_sql(self, opts, node=None, qn=None):
    3031        """
    3132        Returns the SQL version of the where clause and the value to be
    3233        substituted in. Returns None, None if this node is empty.
    class WhereNode(tree.Node):  
    4748        for child in node.children:
    4849            try:
    4950                if hasattr(child, 'as_sql'):
    50                     sql, params = child.as_sql(qn=qn)
     51                    sql, params = child.as_sql(opts, qn=qn)
    5152                    format = '(%s)'
    5253                elif isinstance(child, tree.Node):
    53                     sql, params = self.as_sql(child, qn)
     54                    sql, params = self.as_sql(opts, child, qn)
    5455                    if child.negated:
    5556                        format = 'NOT (%s)'
    5657                    elif len(child.children) == 1:
    class WhereNode(tree.Node):  
    5859                    else:
    5960                        format = '(%s)'
    6061                else:
    61                     sql, params = self.make_atom(child, qn)
     62                    sql, params = self.make_atom(opts, child, qn)
    6263                    format = '%s'
    6364            except EmptyResultSet:
    6465                if node.connector == AND and not node.negated:
    class WhereNode(tree.Node):  
    8687        conn = ' %s ' % node.connector
    8788        return conn.join(result), result_params
    8889
    89     def make_atom(self, child, qn):
     90    def make_atom(self, opts, child, qn):
    9091        """
    9192        Turn a tuple (table_alias, field_name, field_class, lookup_type, value)
    9293        into valid SQL.
    class WhereNode(tree.Node):  
    102103        db_type = field and field.db_type() or None
    103104        field_sql = connection.ops.field_cast_sql(db_type) % lhs
    104105
    105         if isinstance(value, datetime.datetime):
    106             cast_sql = connection.ops.datetime_cast_sql()
    107         else:
    108             cast_sql = '%s'
     106        if lookup_type in connection.operators:
     107            if isinstance(value, Expression):
     108                sql, params = value.as_sql(opts, field, lookup_type, qn)
     109            else:
     110                sql, params = Literal(value).as_sql(opts, field, lookup_type, qn)
     111
     112            format= '%s %s' % (
     113                connection.ops.lookup_cast(lookup_type),
     114                connection.operators[lookup_type])
     115            return format % (field_sql, sql), params
     116
     117        if isinstance(value, Expression):
     118            TypeError('Invalid lookup_type for use with %s object: %r' % (
     119                value.__class__.__name__, lookup_type))
    109120
    110121        if field:
    111122            params = field.get_db_prep_lookup(lookup_type, value)
    class WhereNode(tree.Node):  
    116127        else:
    117128            extra = ''
    118129
    119         if lookup_type in connection.operators:
    120             format = "%s %%s %s" % (connection.ops.lookup_cast(lookup_type),
    121                     extra)
    122             return (format % (field_sql,
    123                     connection.operators[lookup_type] % cast_sql), params)
    124 
    125130        if lookup_type == 'in':
    126131            if not value:
    127132                raise EmptyResultSet
    128133            if extra:
    129134                return ('%s IN %s' % (field_sql, extra), params)
    130             return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))),
    131                     params)
     135            return ('%s IN (%s)' % (field_sql, ', '.join(['%s'] * len(value))), params)
    132136        elif lookup_type in ('range', 'year'):
    133137            return ('%s BETWEEN %%s and %%s' % field_sql, params)
    134138        elif lookup_type in ('month', 'day'):
Back to Top