Ticket #13781: select_related_subclass_patch_1.3.X.diff

File select_related_subclass_patch_1.3.X.diff, 9.0 KB (added by David Bennett, 13 years ago)

Tests and patch (1.3.X)

  • django/db/models/query.py

    From 9282d1d27823e780ccad7ddb006182e27e66262e Mon Sep 17 00:00:00 2001
    From: David Bennett <david@dbinit.com>
    Date: Mon, 30 Jan 2012 12:42:57 -0600
    Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
    
    ---
     django/db/models/query.py                          |   16 +++++++++---
     django/db/models/sql/compiler.py                   |   12 ++++++---
     .../select_related_onetoone/models.py              |   26 ++++++++++++++++++++
     .../select_related_onetoone/tests.py               |   23 ++++++++++++++++-
     4 files changed, 68 insertions(+), 9 deletions(-)
    
    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 324554e..ede2581 100644
    a b class EmptyQuerySet(QuerySet):  
    11261126
    11271127
    11281128def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
    1129                    requested=None, offset=0, only_load=None, local_only=False):
     1129                   requested=None, offset=0, only_load=None, local_only=False,
     1130                   last_klass=None):
    11301131    """
    11311132    Helper function that recursively returns an object with the specified
    11321133    related attributes already populated.
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,  
    11561157       the full field list for `klass` can be assumed.
    11571158     * local_only - Only populate local fields. This is used when building
    11581159       following reverse select-related relations
     1160     * last_klass - the last class seen when following reverse
     1161       select-related relations
    11591162    """
    11601163    if max_depth and requested is None and cur_depth > max_depth:
    11611164        # We've recursed deeply enough; stop now.
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,  
    12021205    else:
    12031206        # Load all fields on klass
    12041207        if local_only:
    1205             field_names = [f.attname for f in klass._meta.local_fields]
     1208            parents = [p for p in klass._meta.get_parent_list()
     1209                       if p is not last_klass]
     1210            field_names = [f.attname for f in klass._meta.fields
     1211                           if f in klass._meta.local_fields
     1212                           or f.model in parents]
    12061213        else:
    12071214            field_names = [f.attname for f in klass._meta.fields]
    12081215        field_count = len(field_names)
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,  
    12611268            next = requested[f.related_query_name()]
    12621269            # Recursively retrieve the data for the related object
    12631270            cached_row = get_cached_row(model, row, index_end, using,
    1264                 max_depth, cur_depth+1, next, only_load=only_load, local_only=True)
     1271                max_depth, cur_depth+1, next, only_load=only_load, local_only=True,
     1272                last_klass=klass)
    12651273            # If the recursive descent found an object, populate the
    12661274            # descriptor caches relevant to the object
    12671275            if cached_row:
    def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,  
    12771285                    # Now populate all the non-local field values
    12781286                    # on the related object
    12791287                    for rel_field,rel_model in rel_obj._meta.get_fields_with_model():
    1280                         if rel_model is not None:
     1288                        if rel_model is not None and isinstance(obj, rel_model):
    12811289                            setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname))
    12821290                            # populate the field cache for any related object
    12831291                            # that has already been retrieved
  • django/db/models/sql/compiler.py

    diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
    index d425c8b..7c0ff75 100644
    a b class SQLCompiler(object):  
    216216        return result
    217217
    218218    def get_default_columns(self, with_aliases=False, col_aliases=None,
    219             start_alias=None, opts=None, as_pairs=False, local_only=False):
     219            start_alias=None, opts=None, as_pairs=False, local_only=False,
     220            last_opts=None):
    220221        """
    221222        Computes the default columns for selecting every field in the base
    222223        model. Will sometimes be called to pull in related models (e.g. via
    class SQLCompiler(object):  
    240241
    241242        if start_alias:
    242243            seen = {None: start_alias}
     244        parents = [p for p in opts.get_parent_list() if p._meta is not last_opts]
    243245        for field, model in opts.get_fields_with_model():
    244             if local_only and model is not None:
     246            if local_only and model is not None and model not in parents:
    245247                continue
    246248            if start_alias:
    247249                try:
    class SQLCompiler(object):  
    252254                    else:
    253255                        link_field = opts.get_ancestor_link(model)
    254256                        alias = self.query.join((start_alias, model._meta.db_table,
    255                                 link_field.column, model._meta.pk.column))
     257                                link_field.column, model._meta.pk.column),
     258                                promote=(model in parents))
    256259                    seen[model] = alias
    257260            else:
    258261                # If we're starting from the base model of the queryset, the
    class SQLCompiler(object):  
    650653                )
    651654                used.add(alias)
    652655                columns, aliases = self.get_default_columns(start_alias=alias,
    653                     opts=model._meta, as_pairs=True, local_only=True)
     656                    opts=model._meta, as_pairs=True, local_only=True,
     657                    last_opts=opts)
    654658                self.query.related_select_cols.extend(columns)
    655659                self.query.related_select_fields.extend(model._meta.fields)
    656660
  • tests/regressiontests/select_related_onetoone/models.py

    diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
    index 3d6da9b..4bfad1d 100644
    a b class StatDetails(models.Model):  
    4545class AdvancedUserStat(UserStat):
    4646    karma = models.IntegerField()
    4747
     48
    4849class Image(models.Model):
    4950    name = models.CharField(max_length=100)
    5051
    class Image(models.Model):  
    5253class Product(models.Model):
    5354    name = models.CharField(max_length=100)
    5455    image = models.OneToOneField(Image, null=True)
     56
     57
     58class Parent1(models.Model):
     59    name1 = models.CharField(max_length=50)
     60    def __unicode__(self):
     61        return self.name1
     62
     63
     64class Parent2(models.Model):
     65    name2 = models.CharField(max_length=50)
     66    def __unicode__(self):
     67        return self.name2
     68
     69
     70class Child1(Parent1, Parent2):
     71    other = models.CharField(max_length=50)
     72    def __unicode__(self):
     73        return self.name1
     74
     75
     76class Child2(Parent1):
     77    parent2 = models.OneToOneField(Parent2)
     78    other = models.CharField(max_length=50)
     79    def __unicode__(self):
     80        return self.name1
  • tests/regressiontests/select_related_onetoone/tests.py

    diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
    index ab35fec..407cdc9 100644
    a b from django.conf import settings  
    33from django.test import TestCase
    44
    55from models import (User, UserProfile, UserStat, UserStatResult, StatDetails,
    6     AdvancedUserStat, Image, Product)
     6    AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2)
    77
    88class ReverseSelectRelatedTestCase(TestCase):
    99    def setUp(self):
    class ReverseSelectRelatedTestCase(TestCase):  
    2020        advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5,
    2121                                                  results=results2)
    2222        StatDetails.objects.create(base_stats=advstat, comments=250)
     23        p1 = Parent1(name1="Only Parent1")
     24        p1.save()
     25        c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2")
     26        c1.save()
     27        p2 = Parent2(name2="Child2 Parent2")
     28        p2.save()
     29        c2 = Child2(name1="Child2 Parent1", parent2=p2)
     30        c2.save()
    2331
    2432    def test_basic(self):
    2533        def test():
    class ReverseSelectRelatedTestCase(TestCase):  
    8896        p2 = Product.objects.create(name="Talking Django Plushie")
    8997
    9098        self.assertEqual(len(Product.objects.select_related("image")), 2)
     99
     100    def test_parent_only(self):
     101        Parent1.objects.select_related('child1').get(name1="Only Parent1")
     102
     103    def test_multiple_subclass(self):
     104        with self.assertNumQueries(1):
     105            p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1")
     106            self.assertEqual(p.child1.name2, u"Child1 Parent2")
     107
     108    def test_onetoone_with_subclass(self):
     109        with self.assertNumQueries(1):
     110            p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2")
     111            self.assertEqual(p.child2.name1, u"Child2 Parent1")
Back to Top