From 9282d1d27823e780ccad7ddb006182e27e66262e Mon Sep 17 00:00:00 2001
From: David Bennett <david@dbinit.com>
Date: Mon, 30 Jan 2012 12:42:57 -0600
Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
---
django/db/models/query.py | 16 +++++++++---
django/db/models/sql/compiler.py | 12 ++++++---
.../select_related_onetoone/models.py | 26 ++++++++++++++++++++
.../select_related_onetoone/tests.py | 23 ++++++++++++++++-
4 files changed, 68 insertions(+), 9 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 324554e..ede2581 100644
a
|
b
|
class EmptyQuerySet(QuerySet):
|
1126 | 1126 | |
1127 | 1127 | |
1128 | 1128 | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, |
1129 | | requested=None, offset=0, only_load=None, local_only=False): |
| 1129 | requested=None, offset=0, only_load=None, local_only=False, |
| 1130 | last_klass=None): |
1130 | 1131 | """ |
1131 | 1132 | Helper function that recursively returns an object with the specified |
1132 | 1133 | related attributes already populated. |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1156 | 1157 | the full field list for `klass` can be assumed. |
1157 | 1158 | * local_only - Only populate local fields. This is used when building |
1158 | 1159 | following reverse select-related relations |
| 1160 | * last_klass - the last class seen when following reverse |
| 1161 | select-related relations |
1159 | 1162 | """ |
1160 | 1163 | if max_depth and requested is None and cur_depth > max_depth: |
1161 | 1164 | # We've recursed deeply enough; stop now. |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1202 | 1205 | else: |
1203 | 1206 | # Load all fields on klass |
1204 | 1207 | if local_only: |
1205 | | field_names = [f.attname for f in klass._meta.local_fields] |
| 1208 | parents = [p for p in klass._meta.get_parent_list() |
| 1209 | if p is not last_klass] |
| 1210 | field_names = [f.attname for f in klass._meta.fields |
| 1211 | if f in klass._meta.local_fields |
| 1212 | or f.model in parents] |
1206 | 1213 | else: |
1207 | 1214 | field_names = [f.attname for f in klass._meta.fields] |
1208 | 1215 | field_count = len(field_names) |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1261 | 1268 | next = requested[f.related_query_name()] |
1262 | 1269 | # Recursively retrieve the data for the related object |
1263 | 1270 | cached_row = get_cached_row(model, row, index_end, using, |
1264 | | max_depth, cur_depth+1, next, only_load=only_load, local_only=True) |
| 1271 | max_depth, cur_depth+1, next, only_load=only_load, local_only=True, |
| 1272 | last_klass=klass) |
1265 | 1273 | # If the recursive descent found an object, populate the |
1266 | 1274 | # descriptor caches relevant to the object |
1267 | 1275 | if cached_row: |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1277 | 1285 | # Now populate all the non-local field values |
1278 | 1286 | # on the related object |
1279 | 1287 | for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): |
1280 | | if rel_model is not None: |
| 1288 | if rel_model is not None and isinstance(obj, rel_model): |
1281 | 1289 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
1282 | 1290 | # populate the field cache for any related object |
1283 | 1291 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index d425c8b..7c0ff75 100644
a
|
b
|
class SQLCompiler(object):
|
216 | 216 | return result |
217 | 217 | |
218 | 218 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
219 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| 219 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| 220 | last_opts=None): |
220 | 221 | """ |
221 | 222 | Computes the default columns for selecting every field in the base |
222 | 223 | model. Will sometimes be called to pull in related models (e.g. via |
… |
… |
class SQLCompiler(object):
|
240 | 241 | |
241 | 242 | if start_alias: |
242 | 243 | seen = {None: start_alias} |
| 244 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
243 | 245 | for field, model in opts.get_fields_with_model(): |
244 | | if local_only and model is not None: |
| 246 | if local_only and model is not None and model not in parents: |
245 | 247 | continue |
246 | 248 | if start_alias: |
247 | 249 | try: |
… |
… |
class SQLCompiler(object):
|
252 | 254 | else: |
253 | 255 | link_field = opts.get_ancestor_link(model) |
254 | 256 | alias = self.query.join((start_alias, model._meta.db_table, |
255 | | link_field.column, model._meta.pk.column)) |
| 257 | link_field.column, model._meta.pk.column), |
| 258 | promote=(model in parents)) |
256 | 259 | seen[model] = alias |
257 | 260 | else: |
258 | 261 | # If we're starting from the base model of the queryset, the |
… |
… |
class SQLCompiler(object):
|
650 | 653 | ) |
651 | 654 | used.add(alias) |
652 | 655 | columns, aliases = self.get_default_columns(start_alias=alias, |
653 | | opts=model._meta, as_pairs=True, local_only=True) |
| 656 | opts=model._meta, as_pairs=True, local_only=True, |
| 657 | last_opts=opts) |
654 | 658 | self.query.related_select_cols.extend(columns) |
655 | 659 | self.query.related_select_fields.extend(model._meta.fields) |
656 | 660 | |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3d6da9b..4bfad1d 100644
a
|
b
|
class StatDetails(models.Model):
|
45 | 45 | class AdvancedUserStat(UserStat): |
46 | 46 | karma = models.IntegerField() |
47 | 47 | |
| 48 | |
48 | 49 | class Image(models.Model): |
49 | 50 | name = models.CharField(max_length=100) |
50 | 51 | |
… |
… |
class Image(models.Model):
|
52 | 53 | class Product(models.Model): |
53 | 54 | name = models.CharField(max_length=100) |
54 | 55 | image = models.OneToOneField(Image, null=True) |
| 56 | |
| 57 | |
| 58 | class Parent1(models.Model): |
| 59 | name1 = models.CharField(max_length=50) |
| 60 | def __unicode__(self): |
| 61 | return self.name1 |
| 62 | |
| 63 | |
| 64 | class Parent2(models.Model): |
| 65 | name2 = models.CharField(max_length=50) |
| 66 | def __unicode__(self): |
| 67 | return self.name2 |
| 68 | |
| 69 | |
| 70 | class Child1(Parent1, Parent2): |
| 71 | other = models.CharField(max_length=50) |
| 72 | def __unicode__(self): |
| 73 | return self.name1 |
| 74 | |
| 75 | |
| 76 | class Child2(Parent1): |
| 77 | parent2 = models.OneToOneField(Parent2) |
| 78 | other = models.CharField(max_length=50) |
| 79 | def __unicode__(self): |
| 80 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index ab35fec..407cdc9 100644
a
|
b
|
from django.conf import settings
|
3 | 3 | from django.test import TestCase |
4 | 4 | |
5 | 5 | from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
6 | | AdvancedUserStat, Image, Product) |
| 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
7 | 7 | |
8 | 8 | class ReverseSelectRelatedTestCase(TestCase): |
9 | 9 | def setUp(self): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
20 | 20 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
21 | 21 | results=results2) |
22 | 22 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| 23 | p1 = Parent1(name1="Only Parent1") |
| 24 | p1.save() |
| 25 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| 26 | c1.save() |
| 27 | p2 = Parent2(name2="Child2 Parent2") |
| 28 | p2.save() |
| 29 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| 30 | c2.save() |
23 | 31 | |
24 | 32 | def test_basic(self): |
25 | 33 | def test(): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
88 | 96 | p2 = Product.objects.create(name="Talking Django Plushie") |
89 | 97 | |
90 | 98 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| 99 | |
| 100 | def test_parent_only(self): |
| 101 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| 102 | |
| 103 | def test_multiple_subclass(self): |
| 104 | with self.assertNumQueries(1): |
| 105 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| 106 | self.assertEqual(p.child1.name2, u"Child1 Parent2") |
| 107 | |
| 108 | def test_onetoone_with_subclass(self): |
| 109 | with self.assertNumQueries(1): |
| 110 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| 111 | self.assertEqual(p.child2.name1, u"Child2 Parent1") |