Ticket #13781: select_related_subclass_patch.diff

File select_related_subclass_patch.diff, 9.3 KB (added by David Bennett, 13 years ago)

Tests and patch (trunk)

  • django/db/models/query.py

    From 2b446904bfeb8bb79762426eb5ee0b97f57e6844 Mon Sep 17 00:00:00 2001
    From: David Bennett <david@dbinit.com>
    Date: Mon, 30 Jan 2012 12:54:44 -0600
    Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
    
    ---
     django/db/models/query.py                          |   19 ++++++++++----
     django/db/models/sql/compiler.py                   |   12 ++++++---
     .../select_related_onetoone/models.py              |   26 ++++++++++++++++++++
     .../select_related_onetoone/tests.py               |   23 ++++++++++++++++-
     4 files changed, 70 insertions(+), 10 deletions(-)
    
    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 41c24c7..c76d6d0 100644
    a b class EmptyQuerySet(QuerySet):  
    12381238    value_annotation = False
    12391239
    12401240def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
    1241                    only_load=None, local_only=False):
     1241                   only_load=None, local_only=False, last_klass=None):
    12421242    """
    12431243    Helper function that recursively returns an information for a klass, to be
    12441244    used in get_cached_row.  It exists just to compute this information only
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    12601260       the full field list for `klass` can be assumed.
    12611261     * local_only - Only populate local fields. This is used when
    12621262       following reverse select-related relations
     1263     * last_klass - the last class seen when following reverse
     1264       select-related relations
    12631265    """
    12641266    if max_depth and requested is None and cur_depth > max_depth:
    12651267        # We've recursed deeply enough; stop now.
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    13051307        # But kwargs version of Model.__init__ is slower, so we should avoid using
    13061308        # it when it is not really neccesary.
    13071309        if local_only and len(klass._meta.local_fields) != len(klass._meta.fields):
    1308             field_count = len(klass._meta.local_fields)
    1309             field_names = [f.attname for f in klass._meta.local_fields]
     1310            parents = [p for p in klass._meta.get_parent_list()
     1311                       if p is not last_klass]
     1312            field_names = [f.attname for f in klass._meta.fields
     1313                           if f in klass._meta.local_fields
     1314                           or f.model in parents]
     1315            field_count = len(field_names)
     1316            if field_count == len(klass._meta.fields):
     1317                field_names = ()
    13101318        else:
    13111319            field_count = len(klass._meta.fields)
    13121320            field_names = ()
    def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,  
    13301338            if o.field.unique and select_related_descend(o.field, restricted, requested, reverse=True):
    13311339                next = requested[o.field.related_query_name()]
    13321340                klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1,
    1333                                             requested=next, only_load=only_load, local_only=True)
     1341                                            requested=next, only_load=only_load, local_only=True,
     1342                                            last_klass=klass)
    13341343                reverse_related_fields.append((o.field, klass_info))
    13351344
    13361345    return klass, field_names, field_count, related_fields, reverse_related_fields
    def get_cached_row(row, index_start, using, klass_info, offset=0):  
    14161425                # Now populate all the non-local field values
    14171426                # on the related object
    14181427                for rel_field, rel_model in rel_obj._meta.get_fields_with_model():
    1419                     if rel_model is not None:
     1428                    if rel_model is not None and isinstance(obj, rel_model):
    14201429                        setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
    14211430                        # populate the field cache for any related object
    14221431                        # that has already been retrieved
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index 72948f9..c773a3b 100644
    a b class SQLCompiler(object):  
    246246        return result
    247247
    248248    def get_default_columns(self, with_aliases=False, col_aliases=None,
    249             start_alias=None, opts=None, as_pairs=False, local_only=False):
     249            start_alias=None, opts=None, as_pairs=False, local_only=False,
     250            last_opts=None):
    250251        """
    251252        Computes the default columns for selecting every field in the base
    252253        model. Will sometimes be called to pull in related models (e.g. via
    class SQLCompiler(object):  
    270271
    271272        if start_alias:
    272273            seen = {None: start_alias}
     274        parents = [p for p in opts.get_parent_list() if p._meta is not last_opts]
    273275        for field, model in opts.get_fields_with_model():
    274             if local_only and model is not None:
     276            if local_only and model is not None and model not in parents:
    275277                continue
    276278            if start_alias:
    277279                try:
    class SQLCompiler(object):  
    282284                    else:
    283285                        link_field = opts.get_ancestor_link(model)
    284286                        alias = self.query.join((start_alias, model._meta.db_table,
    285                                 link_field.column, model._meta.pk.column))
     287                                link_field.column, model._meta.pk.column),
     288                                promote=(model in parents))
    286289                    seen[model] = alias
    287290            else:
    288291                # If we're starting from the base model of the queryset, the
    class SQLCompiler(object):  
    728731                )
    729732                used.add(alias)
    730733                columns, aliases = self.get_default_columns(start_alias=alias,
    731                     opts=model._meta, as_pairs=True, local_only=True)
     734                    opts=model._meta, as_pairs=True, local_only=True,
     735                    last_opts=opts)
    732736                self.query.related_select_cols.extend(columns)
    733737                self.query.related_select_fields.extend(model._meta.fields)
    734738
  • tests/regressiontests/select_related_onetoone/models.py

    diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
    index 3d6da9b..4bfad1d 100644
    a b class StatDetails(models.Model):  
    4545class AdvancedUserStat(UserStat):
    4646    karma = models.IntegerField()
    4747
     48
    4849class Image(models.Model):
    4950    name = models.CharField(max_length=100)
    5051
    class Image(models.Model):  
    5253class Product(models.Model):
    5354    name = models.CharField(max_length=100)
    5455    image = models.OneToOneField(Image, null=True)
     56
     57
     58class Parent1(models.Model):
     59    name1 = models.CharField(max_length=50)
     60    def __unicode__(self):
     61        return self.name1
     62
     63
     64class Parent2(models.Model):
     65    name2 = models.CharField(max_length=50)
     66    def __unicode__(self):
     67        return self.name2
     68
     69
     70class Child1(Parent1, Parent2):
     71    other = models.CharField(max_length=50)
     72    def __unicode__(self):
     73        return self.name1
     74
     75
     76class Child2(Parent1):
     77    parent2 = models.OneToOneField(Parent2)
     78    other = models.CharField(max_length=50)
     79    def __unicode__(self):
     80        return self.name1
  • tests/regressiontests/select_related_onetoone/tests.py

    diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
    index 643a0ff..a57142c 100644
    a b from __future__ import with_statement, absolute_import  
    33from django.test import TestCase
    44
    55from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
    6     AdvancedUserStat, Image, Product)
     6    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2)
    77
    88
    99class ReverseSelectRelatedTestCase(TestCase):
    class ReverseSelectRelatedTestCase(TestCase):  
    2121        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
    2222                                                  results=results2)
    2323        StatDetails.objects.create(base_stats=advstat, comments=250)
     24        p1 = Parent1(name1="Only Parent1")
     25        p1.save()
     26        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2")
     27        c1.save()
     28        p2 = Parent2(name2="Child2 Parent2")
     29        p2.save()
     30        c2 = Child2(name1="Child2 Parent1", parent2=p2)
     31        c2.save()
    2432
    2533    def test_basic(self):
    2634        with self.assertNumQueries(1):
    class ReverseSelectRelatedTestCase(TestCase):  
    8088        p2 = Product.objects.create(name="Talking Django Plushie")
    8189
    8290        self.assertEqual(len(Product.objects.select_related("image")), 2)
     91
     92    def test_parent_only(self):
     93        Parent1.objects.select_related('child1').get(name1="Only Parent1")
     94
     95    def test_multiple_subclass(self):
     96        with self.assertNumQueries(1):
     97            p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
     98            self.assertEqual(p.child1.name2, u'Child1 Parent2')
     99
     100    def test_onetoone_with_subclass(self):
     101        with self.assertNumQueries(1):
     102            p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
     103            self.assertEqual(p.child2.name1, u'Child2 Parent1')
Back to Top