Ticket #3275: query.py.diff

File query.py.diff, 5.9 KB (added by David Cramer <dcramer@…>, 18 years ago)

diffs for django/db/models/query.py

  • query.py

     
    8080        self._filters = Q()
    8181        self._order_by = None        # Ordering, e.g. ('date', '-name'). If None, use model's ordering.
    8282        self._select_related = False # Whether to fill cache for related objects.
     83        self._recurse_depth = 0      # Used to track how deep we are following for select_related()
     84        self._recurse_fields = []    # Fields to recurse through for select_related()
    8385        self._distinct = False       # Whether the query should use SELECT DISTINCT.
    8486        self._select = {}            # Dictionary of attname -> SQL.
    8587        self._where = []             # List of extra WHERE clauses to use.
     
    178180                raise StopIteration
    179181            for row in rows:
    180182                if fill_cache:
    181                     obj, index_end = get_cached_row(self.model, row, 0)
     183                    obj, index_end = get_cached_row(self.model, row, 0, self._recurse_fields, self._recurse_depth)
    182184                else:
    183185                    obj = self.model(*row[:index_end])
    184186                for i, k in enumerate(extra_select):
     
    194196        counter._select_related = False
    195197        select, sql, params = counter._get_sql_clause()
    196198        cursor = connection.cursor()
     199        id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
     200                backend.quote_name(self.model._meta.pk.column))
    197201        if self._distinct:
    198             id_col = "%s.%s" % (backend.quote_name(self.model._meta.db_table),
    199                     backend.quote_name(self.model._meta.pk.column))
    200202            cursor.execute("SELECT COUNT(DISTINCT(%s))" % id_col + sql, params)
    201203        else:
    202             cursor.execute("SELECT COUNT(*)" + sql, params)
     204            cursor.execute("SELECT COUNT(%s)" % id_col + sql, params)
    203205        return cursor.fetchone()[0]
    204206
    205207    def get(self, *args, **kwargs):
     
    359361        else:
    360362            return self._filter_or_exclude(None, **filter_obj)
    361363
    362     def select_related(self, true_or_false=True):
     364    # fields should be a list of field names in the root table, if specified, it modifies depth to 1
     365    # depth is the maximum number of children to recurse through, defaults to infinite
     366    def select_related(self, true_or_false=True, depth=0, fields=[]):
    363367        "Returns a new QuerySet instance with '_select_related' modified."
    364         return self._clone(_select_related=true_or_false)
     368        if fields != []:
     369            depth = 1
     370        return self._clone(_select_related=true_or_false, _recurse_depth=depth, _recurse_fields=fields)
    365371
    366372    def order_by(self, *field_names):
    367373        "Returns a new QuerySet instance with the ordering changed."
     
    395401        c._filters = self._filters
    396402        c._order_by = self._order_by
    397403        c._select_related = self._select_related
     404        c._recurse_fields = self._recurse_fields
     405        c._recurse_depth = self._recurse_depth
    398406        c._distinct = self._distinct
    399407        c._select = self._select.copy()
    400408        c._where = self._where[:]
     
    448456
    449457        # Add additional tables and WHERE clauses based on select_related.
    450458        if self._select_related:
    451             fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table])
     459            fill_table_cache(opts, select, tables, where, opts.db_table, [opts.db_table], self._recurse_depth, self._recurse_fields)
    452460
    453461        # Add any additional SELECTs.
    454462        if self._select:
     
    660668        return backend.get_fulltext_search_sql(table_prefix + field_name)
    661669    raise TypeError, "Got invalid lookup_type: %s" % repr(lookup_type)
    662670
    663 def get_cached_row(klass, row, index_start):
     671def get_cached_row(klass, row, index_start, fields=[], max_depth=0, cur_depth=0):
    664672    "Helper function that recursively returns an object with cache filled"
     673    if max_depth and cur_depth > max_depth:
     674        return None
    665675    index_end = index_start + len(klass._meta.fields)
    666676    obj = klass(*row[index_start:index_end])
    667677    for f in klass._meta.fields:
    668         if f.rel and not f.null:
    669             rel_obj, index_end = get_cached_row(f.rel.to, row, index_end)
    670             setattr(obj, f.get_cache_name(), rel_obj)
     678        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields)):
     679            cached_row = get_cached_row(f.rel.to, row, index_end, fields, max_depth, cur_depth+1)
     680            if cached_row:
     681                    rel_obj, index_end = cached_row
     682                    setattr(obj, f.get_cache_name(), rel_obj)
    671683    return obj, index_end
    672684
    673 def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen):
     685def fill_table_cache(opts, select, tables, where, old_prefix, cache_tables_seen, max_depth=0, fields=[], cur_depth=0):
    674686    """
    675687    Helper function that recursively populates the select, tables and where (in
    676688    place) for select_related queries.
    677689    """
    678690    qn = backend.quote_name
     691    if max_depth and cur_depth > max_depth:
     692        return
    679693    for f in opts.fields:
    680         if f.rel and not f.null:
     694        if f.rel and not f.null and (fields == [] or (cur_depth == 0 and f.name in fields)):
    681695            db_table = f.rel.to._meta.db_table
    682696            if db_table not in cache_tables_seen:
    683697                tables.append(qn(db_table))
     
    689703            where.append('%s.%s = %s.%s' % \
    690704                (qn(old_prefix), qn(f.column), qn(db_table), qn(f.rel.get_related_field().column)))
    691705            select.extend(['%s.%s' % (qn(db_table), qn(f2.column)) for f2 in f.rel.to._meta.fields])
    692             fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen)
     706            fill_table_cache(f.rel.to._meta, select, tables, where, db_table, cache_tables_seen, max_depth, fields, cur_depth+1)
    693707
    694708def parse_lookup(kwarg_items, opts):
    695709    # Helper function that handles converting API kwargs
Back to Top