From bbd8e472f24db8a3134687d4deae9cf39faa437c Mon Sep 17 00:00:00 2001
From: David Bennett <david@dbinit.com>
Date: Mon, 30 Jan 2012 13:02:13 -0600
Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
---
django/db/models/query.py | 16 +++++++++---
django/db/models/sql/compiler.py | 12 ++++++---
.../select_related_onetoone/models.py | 26 ++++++++++++++++++++
.../select_related_onetoone/tests.py | 23 ++++++++++++++++-
4 files changed, 68 insertions(+), 9 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index a2d7ffb..cb50647 100644
a
|
b
|
class EmptyQuerySet(QuerySet):
|
1142 | 1142 | |
1143 | 1143 | |
1144 | 1144 | def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0, |
1145 | | requested=None, offset=0, only_load=None, local_only=False): |
| 1145 | requested=None, offset=0, only_load=None, local_only=False, |
| 1146 | last_klass=None): |
1146 | 1147 | """ |
1147 | 1148 | Helper function that recursively returns an object with the specified |
1148 | 1149 | related attributes already populated. |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1172 | 1173 | the full field list for `klass` can be assumed. |
1173 | 1174 | * local_only - Only populate local fields. This is used when building |
1174 | 1175 | following reverse select-related relations |
| 1176 | * last_klass - the last class seen when following reverse |
| 1177 | select-related relations |
1175 | 1178 | """ |
1176 | 1179 | if max_depth and requested is None and cur_depth > max_depth: |
1177 | 1180 | # We've recursed deeply enough; stop now. |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1218 | 1221 | else: |
1219 | 1222 | # Load all fields on klass |
1220 | 1223 | if local_only: |
1221 | | field_names = [f.attname for f in klass._meta.local_fields] |
| 1224 | parents = [p for p in klass._meta.get_parent_list() |
| 1225 | if p is not last_klass] |
| 1226 | field_names = [f.attname for f in klass._meta.fields |
| 1227 | if f in klass._meta.local_fields |
| 1228 | or f.model in parents] |
1222 | 1229 | else: |
1223 | 1230 | field_names = [f.attname for f in klass._meta.fields] |
1224 | 1231 | field_count = len(field_names) |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1277 | 1284 | next = requested[f.related_query_name()] |
1278 | 1285 | # Recursively retrieve the data for the related object |
1279 | 1286 | cached_row = get_cached_row(model, row, index_end, using, |
1280 | | max_depth, cur_depth+1, next, only_load=only_load, local_only=True) |
| 1287 | max_depth, cur_depth+1, next, only_load=only_load, local_only=True, |
| 1288 | last_klass=klass) |
1281 | 1289 | # If the recursive descent found an object, populate the |
1282 | 1290 | # descriptor caches relevant to the object |
1283 | 1291 | if cached_row: |
… |
… |
def get_cached_row(klass, row, index_start, using, max_depth=0, cur_depth=0,
|
1293 | 1301 | # Now populate all the non-local field values |
1294 | 1302 | # on the related object |
1295 | 1303 | for rel_field,rel_model in rel_obj._meta.get_fields_with_model(): |
1296 | | if rel_model is not None: |
| 1304 | if rel_model is not None and isinstance(obj, rel_model): |
1297 | 1305 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
1298 | 1306 | # populate the field cache for any related object |
1299 | 1307 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index fb9674c..bcfd014 100644
a
|
b
|
class SQLCompiler(object):
|
213 | 213 | return result |
214 | 214 | |
215 | 215 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
216 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| 216 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| 217 | last_opts=None): |
217 | 218 | """ |
218 | 219 | Computes the default columns for selecting every field in the base |
219 | 220 | model. Will sometimes be called to pull in related models (e.g. via |
… |
… |
class SQLCompiler(object):
|
237 | 238 | |
238 | 239 | if start_alias: |
239 | 240 | seen = {None: start_alias} |
| 241 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
240 | 242 | for field, model in opts.get_fields_with_model(): |
241 | | if local_only and model is not None: |
| 243 | if local_only and model is not None and model not in parents: |
242 | 244 | continue |
243 | 245 | if start_alias: |
244 | 246 | try: |
… |
… |
class SQLCompiler(object):
|
249 | 251 | else: |
250 | 252 | link_field = opts.get_ancestor_link(model) |
251 | 253 | alias = self.query.join((start_alias, model._meta.db_table, |
252 | | link_field.column, model._meta.pk.column)) |
| 254 | link_field.column, model._meta.pk.column), |
| 255 | promote=(model in parents)) |
253 | 256 | seen[model] = alias |
254 | 257 | else: |
255 | 258 | # If we're starting from the base model of the queryset, the |
… |
… |
class SQLCompiler(object):
|
647 | 650 | ) |
648 | 651 | used.add(alias) |
649 | 652 | columns, aliases = self.get_default_columns(start_alias=alias, |
650 | | opts=model._meta, as_pairs=True, local_only=True) |
| 653 | opts=model._meta, as_pairs=True, local_only=True, |
| 654 | last_opts=opts) |
651 | 655 | self.query.related_select_cols.extend(columns) |
652 | 656 | self.query.related_select_fields.extend(model._meta.fields) |
653 | 657 | |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3d6da9b..4bfad1d 100644
a
|
b
|
class StatDetails(models.Model):
|
45 | 45 | class AdvancedUserStat(UserStat): |
46 | 46 | karma = models.IntegerField() |
47 | 47 | |
| 48 | |
48 | 49 | class Image(models.Model): |
49 | 50 | name = models.CharField(max_length=100) |
50 | 51 | |
… |
… |
class Image(models.Model):
|
52 | 53 | class Product(models.Model): |
53 | 54 | name = models.CharField(max_length=100) |
54 | 55 | image = models.OneToOneField(Image, null=True) |
| 56 | |
| 57 | |
| 58 | class Parent1(models.Model): |
| 59 | name1 = models.CharField(max_length=50) |
| 60 | def __unicode__(self): |
| 61 | return self.name1 |
| 62 | |
| 63 | |
| 64 | class Parent2(models.Model): |
| 65 | name2 = models.CharField(max_length=50) |
| 66 | def __unicode__(self): |
| 67 | return self.name2 |
| 68 | |
| 69 | |
| 70 | class Child1(Parent1, Parent2): |
| 71 | other = models.CharField(max_length=50) |
| 72 | def __unicode__(self): |
| 73 | return self.name1 |
| 74 | |
| 75 | |
| 76 | class Child2(Parent1): |
| 77 | parent2 = models.OneToOneField(Parent2) |
| 78 | other = models.CharField(max_length=50) |
| 79 | def __unicode__(self): |
| 80 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 4ccb584..b2f8549 100644
a
|
b
|
from django.conf import settings
|
3 | 3 | from django.test import TestCase |
4 | 4 | |
5 | 5 | from models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
6 | | AdvancedUserStat, Image, Product) |
| 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
7 | 7 | |
8 | 8 | class ReverseSelectRelatedTestCase(TestCase): |
9 | 9 | def setUp(self): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
25 | 25 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
26 | 26 | results=results2) |
27 | 27 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| 28 | p1 = Parent1(name1="Only Parent1") |
| 29 | p1.save() |
| 30 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| 31 | c1.save() |
| 32 | p2 = Parent2(name2="Child2 Parent2") |
| 33 | p2.save() |
| 34 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| 35 | c2.save() |
28 | 36 | |
29 | 37 | db.reset_queries() |
30 | 38 | |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
92 | 100 | p2 = Product.objects.create(name="Talking Django Plushie") |
93 | 101 | |
94 | 102 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| 103 | |
| 104 | def test_parent_only(self): |
| 105 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| 106 | |
| 107 | def test_multiple_subclass(self): |
| 108 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| 109 | self.assertEqual(p.child1.name2, u'Child1 Parent2') |
| 110 | self.assertQueries(1) |
| 111 | |
| 112 | def test_onetoone_with_subclass(self): |
| 113 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| 114 | self.assertEqual(p.child2.name1, u'Child2 Parent1') |
| 115 | self.assertQueries(1) |