Ticket #1133: query_args.r1799.v2.patch

File query_args.r1799.v2.patch, 8.8 KB (added by freakboy@…, 19 years ago)

Updated patch after merge to r1799

  • django/db/models/manager.py

     
    5050           self.creation_counter < klass._default_manager.creation_counter:
    5151                klass._default_manager = self
    5252
    53     def _get_sql_clause(self, **kwargs):
     53    def _get_sql_clause(self, *args, **kwargs):
    5454        def quote_only_if_word(word):
    5555            if ' ' in word:
    5656                return word
     
    6262        # Construct the fundamental parts of the query: SELECT X FROM Y WHERE Z.
    6363        select = ["%s.%s" % (backend.quote_name(opts.db_table), backend.quote_name(f.column)) for f in opts.fields]
    6464        tables = (kwargs.get('tables') and [quote_only_if_word(t) for t in kwargs['tables']] or [])
     65        joins = {}
    6566        where = kwargs.get('where') and kwargs['where'][:] or []
    6667        params = kwargs.get('params') and kwargs['params'][:] or []
    6768
     69        # Convert all the args into SQL.
     70        table_count = 0
     71        for arg in args:
     72            # check that the provided argument is a Query (i.e., it has a get_sql method)
     73            if not hasattr(arg, 'get_sql'):
     74                raise TypeError, "got unknown query argument '%s'" % str(arg)
     75
     76            tables2, joins2, where2, params2 = arg.get_sql(opts)
     77            tables.extend(tables2)
     78            joins.update(joins2)
     79            where.extend(where2)
     80            params.extend(params2)
     81
    6882        # Convert the kwargs into SQL.
    69         tables2, joins, where2, params2 = parse_lookup(kwargs.items(), opts)
     83        tables2, joins2, where2, params2 = parse_lookup(kwargs.items(), opts)
    7084        tables.extend(tables2)
     85        joins.update(joins2)
    7186        where.extend(where2)
    7287        params.extend(params2)
    7388
     
    132147
    133148        return select, " ".join(sql), params
    134149
    135     def get_iterator(self, **kwargs):
     150    def get_iterator(self, *args, **kwargs):
    136151        # kwargs['select'] is a dictionary, and dictionaries' key order is
    137152        # undefined, so we convert it to a list of tuples internally.
    138153        kwargs['select'] = kwargs.get('select', {}).items()
    139154
    140155        cursor = connection.cursor()
    141         select, sql, params = self._get_sql_clause(**kwargs)
     156        select, sql, params = self._get_sql_clause(*args, **kwargs)
    142157        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params)
    143158        fill_cache = kwargs.get('select_related')
    144159        index_end = len(self.klass._meta.fields)
     
    155170                    setattr(obj, k[0], row[index_end+i])
    156171                yield obj
    157172
    158     def get_list(self, **kwargs):
    159         return list(self.get_iterator(**kwargs))
     173    def get_list(self, *args, **kwargs):
     174        return list(self.get_iterator(*args, **kwargs))
    160175
    161     def get_count(self, **kwargs):
     176    def get_count(self, *args, **kwargs):
    162177        kwargs['order_by'] = []
    163178        kwargs['offset'] = None
    164179        kwargs['limit'] = None
    165180        kwargs['select_related'] = False
    166         _, sql, params = self._get_sql_clause(**kwargs)
     181        _, sql, params = self._get_sql_clause(*args, **kwargs)
    167182        cursor = connection.cursor()
    168183        cursor.execute("SELECT COUNT(*)" + sql, params)
    169184        return cursor.fetchone()[0]
    170185
    171     def get_object(self, **kwargs):
    172         obj_list = self.get_list(**kwargs)
     186    def get_object(self, *args, **kwargs):
     187        obj_list = self.get_list(*args, **kwargs)
    173188        if len(obj_list) < 1:
    174189            raise self.klass.DoesNotExist, "%s does not exist for %s" % (self.klass._meta.object_name, kwargs)
    175190        assert len(obj_list) == 1, "get_object() returned more than one %s -- it returned %s! Lookup parameters were %s" % (self.klass._meta.object_name, len(obj_list), kwargs)
    176191        return obj_list[0]
    177192
    178193    def get_in_bulk(self, *args, **kwargs):
    179         id_list = args and args[0] or kwargs['id_list']
    180         assert id_list != [], "get_in_bulk() cannot be passed an empty list."
     194        # Separate any list arguments: these will be added together to provide the id list
     195        id_args = filter(lambda arg: isinstance(arg, list), args)
     196        # Separate any non-list arguments: these are assumed to be query arguments
     197        sql_args = filter(lambda arg: not isinstance(arg, list), args)
     198
     199        id_list = id_args and id_args[0] or kwargs.get('id_list', [])
     200        assert id_list != [], "get_in_bulk() cannot be passed an empty ID list."
    181201        kwargs['where'] = ["%s.%s IN (%s)" % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(self.klass._meta.pk.column), ",".join(['%s'] * len(id_list)))]
    182202        kwargs['params'] = id_list
    183         obj_list = self.get_list(**kwargs)
     203        obj_list = self.get_list(*sql_args, **kwargs)
    184204        return dict([(getattr(o, self.klass._meta.pk.attname), o) for o in obj_list])
    185205
    186     def get_values_iterator(self, **kwargs):
     206    def get_values_iterator(self, *args, **kwargs):
    187207        # select_related and select aren't supported in get_values().
    188208        kwargs['select_related'] = False
    189209        kwargs['select'] = {}
     
    195215            fields = [f.column for f in self.klass._meta.fields]
    196216
    197217        cursor = connection.cursor()
    198         _, sql, params = self._get_sql_clause(**kwargs)
     218        _, sql, params = self._get_sql_clause(*args, **kwargs)
    199219        select = ['%s.%s' % (backend.quote_name(self.klass._meta.db_table), backend.quote_name(f)) for f in fields]
    200220        cursor.execute("SELECT " + (kwargs.get('distinct') and "DISTINCT " or "") + ",".join(select) + sql, params)
    201221        while 1:
     
    205225            for row in rows:
    206226                yield dict(zip(fields, row))
    207227
    208     def get_values(self, **kwargs):
    209         return list(self.get_values_iterator(**kwargs))
     228    def get_values(self, *args, **kwargs):
     229        return list(self.get_values_iterator(*args, **kwargs))
    210230
    211     def __get_latest(self, **kwargs):
     231    def __get_latest(self, *args, **kwargs):
    212232        kwargs['order_by'] = ('-' + self.klass._meta.get_latest_by,)
    213233        kwargs['limit'] = 1
    214         return self.get_object(**kwargs)
     234        return self.get_object(*args, **kwargs)
    215235
    216236    def __get_date_list(self, field, *args, **kwargs):
     237        # Separate any string arguments: the first will be used as the kind
     238        kind_args = filter(lambda arg: isinstance(arg, str), args)
     239        # Separate any non-list arguments: these are assumed to be query arguments
     240        sql_args = filter(lambda arg: not isinstance(arg, str), args)
     241
    217242        from django.db.backends.util import typecast_timestamp
    218         kind = args and args[0] or kwargs['kind']
     243        kind = kind_args and kind_args[0] or kwargs.get('kind', "")
    219244        assert kind in ("month", "year", "day"), "'kind' must be one of 'year', 'month' or 'day'."
    220245        order = 'ASC'
    221246        if kwargs.has_key('order'):
     
    226251        if field.null:
    227252            kwargs.setdefault('where', []).append('%s.%s IS NOT NULL' % \
    228253                (backend.quote_name(self.klass._meta.db_table), backend.quote_name(field.column)))
    229         select, sql, params = self._get_sql_clause(**kwargs)
     254        select, sql, params = self._get_sql_clause(*sql_args, **kwargs)
    230255        sql = 'SELECT %s %s GROUP BY 1 ORDER BY 1 %s' % \
    231256            (backend.get_date_trunc_sql(kind, '%s.%s' % (backend.quote_name(self.klass._meta.db_table),
    232257            backend.quote_name(field.column))), sql, order)
  • django/db/models/query.py

     
    197197        if kwarg_value is None:
    198198            continue
    199199        if kwarg == 'complex':
     200            if not hasattr(kwarg_value, 'get_sql'):
     201                raise TypeError, "got unknown query argument '%s'" % str(arg)   
    200202            tables2, joins2, where2, params2 = kwarg_value.get_sql(opts)
    201203            tables.extend(tables2)
    202204            joins.update(joins2)
  • tests/modeltests/or_lookups/models.py

     
    5454>>> Article.objects.get_list(complex=(Q(pk=1) | Q(pk=2) | Q(pk=3)))
    5555[Hello, Goodbye, Hello and goodbye]
    5656
     57>>> Article.objects.get_list(Q(headline__startswith='Hello'))
     58[Hello, Hello and goodbye]
     59
     60>>> Article.objects.get_list(Q(headline__startswith='Hello'), Q(headline__contains='bye'))
     61[Hello and goodbye]
     62
     63>>> Article.objects.get_list(Q(headline__startswith='Hello') & Q(headline__contains='bye'))
     64[Hello and goodbye]
     65
     66>>> Article.objects.get_list(Q(headline__contains='bye'), headline__startswith='Hello')
     67[Hello and goodbye]
     68
     69>>> Article.objects.get_list(Q(headline__contains='Hello') | Q(headline__contains='bye'))
     70[Hello, Goodbye, Hello and goodbye]
     71
    5772"""
Back to Top