From 2b446904bfeb8bb79762426eb5ee0b97f57e6844 Mon Sep 17 00:00:00 2001
From: David Bennett <david@dbinit.com>
Date: Mon, 30 Jan 2012 12:54:44 -0600
Subject: [PATCH] Fixed #13781 -- select_related and multiple inheritance
---
django/db/models/query.py | 19 ++++++++++----
django/db/models/sql/compiler.py | 12 ++++++---
.../select_related_onetoone/models.py | 26 ++++++++++++++++++++
.../select_related_onetoone/tests.py | 23 ++++++++++++++++-
4 files changed, 70 insertions(+), 10 deletions(-)
diff --git a/django/db/models/query.py b/django/db/models/query.py
index 41c24c7..c76d6d0 100644
a
|
b
|
class EmptyQuerySet(QuerySet):
|
1238 | 1238 | value_annotation = False |
1239 | 1239 | |
1240 | 1240 | def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, |
1241 | | only_load=None, local_only=False): |
| 1241 | only_load=None, local_only=False, last_klass=None): |
1242 | 1242 | """ |
1243 | 1243 | Helper function that recursively returns an information for a klass, to be |
1244 | 1244 | used in get_cached_row. It exists just to compute this information only |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1260 | 1260 | the full field list for `klass` can be assumed. |
1261 | 1261 | * local_only - Only populate local fields. This is used when |
1262 | 1262 | following reverse select-related relations |
| 1263 | * last_klass - the last class seen when following reverse |
| 1264 | select-related relations |
1263 | 1265 | """ |
1264 | 1266 | if max_depth and requested is None and cur_depth > max_depth: |
1265 | 1267 | # We've recursed deeply enough; stop now. |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1305 | 1307 | # But kwargs version of Model.__init__ is slower, so we should avoid using |
1306 | 1308 | # it when it is not really neccesary. |
1307 | 1309 | if local_only and len(klass._meta.local_fields) != len(klass._meta.fields): |
1308 | | field_count = len(klass._meta.local_fields) |
1309 | | field_names = [f.attname for f in klass._meta.local_fields] |
| 1310 | parents = [p for p in klass._meta.get_parent_list() |
| 1311 | if p is not last_klass] |
| 1312 | field_names = [f.attname for f in klass._meta.fields |
| 1313 | if f in klass._meta.local_fields |
| 1314 | or f.model in parents] |
| 1315 | field_count = len(field_names) |
| 1316 | if field_count == len(klass._meta.fields): |
| 1317 | field_names = () |
1310 | 1318 | else: |
1311 | 1319 | field_count = len(klass._meta.fields) |
1312 | 1320 | field_names = () |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1330 | 1338 | if o.field.unique and select_related_descend(o.field, restricted, requested, reverse=True): |
1331 | 1339 | next = requested[o.field.related_query_name()] |
1332 | 1340 | klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, |
1333 | | requested=next, only_load=only_load, local_only=True) |
| 1341 | requested=next, only_load=only_load, local_only=True, |
| 1342 | last_klass=klass) |
1334 | 1343 | reverse_related_fields.append((o.field, klass_info)) |
1335 | 1344 | |
1336 | 1345 | return klass, field_names, field_count, related_fields, reverse_related_fields |
… |
… |
def get_cached_row(row, index_start, using, klass_info, offset=0):
|
1416 | 1425 | # Now populate all the non-local field values |
1417 | 1426 | # on the related object |
1418 | 1427 | for rel_field, rel_model in rel_obj._meta.get_fields_with_model(): |
1419 | | if rel_model is not None: |
| 1428 | if rel_model is not None and isinstance(obj, rel_model): |
1420 | 1429 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
1421 | 1430 | # populate the field cache for any related object |
1422 | 1431 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index 72948f9..c773a3b 100644
a
|
b
|
class SQLCompiler(object):
|
246 | 246 | return result |
247 | 247 | |
248 | 248 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
249 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| 249 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| 250 | last_opts=None): |
250 | 251 | """ |
251 | 252 | Computes the default columns for selecting every field in the base |
252 | 253 | model. Will sometimes be called to pull in related models (e.g. via |
… |
… |
class SQLCompiler(object):
|
270 | 271 | |
271 | 272 | if start_alias: |
272 | 273 | seen = {None: start_alias} |
| 274 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
273 | 275 | for field, model in opts.get_fields_with_model(): |
274 | | if local_only and model is not None: |
| 276 | if local_only and model is not None and model not in parents: |
275 | 277 | continue |
276 | 278 | if start_alias: |
277 | 279 | try: |
… |
… |
class SQLCompiler(object):
|
282 | 284 | else: |
283 | 285 | link_field = opts.get_ancestor_link(model) |
284 | 286 | alias = self.query.join((start_alias, model._meta.db_table, |
285 | | link_field.column, model._meta.pk.column)) |
| 287 | link_field.column, model._meta.pk.column), |
| 288 | promote=(model in parents)) |
286 | 289 | seen[model] = alias |
287 | 290 | else: |
288 | 291 | # If we're starting from the base model of the queryset, the |
… |
… |
class SQLCompiler(object):
|
728 | 731 | ) |
729 | 732 | used.add(alias) |
730 | 733 | columns, aliases = self.get_default_columns(start_alias=alias, |
731 | | opts=model._meta, as_pairs=True, local_only=True) |
| 734 | opts=model._meta, as_pairs=True, local_only=True, |
| 735 | last_opts=opts) |
732 | 736 | self.query.related_select_cols.extend(columns) |
733 | 737 | self.query.related_select_fields.extend(model._meta.fields) |
734 | 738 | |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3d6da9b..4bfad1d 100644
a
|
b
|
class StatDetails(models.Model):
|
45 | 45 | class AdvancedUserStat(UserStat): |
46 | 46 | karma = models.IntegerField() |
47 | 47 | |
| 48 | |
48 | 49 | class Image(models.Model): |
49 | 50 | name = models.CharField(max_length=100) |
50 | 51 | |
… |
… |
class Image(models.Model):
|
52 | 53 | class Product(models.Model): |
53 | 54 | name = models.CharField(max_length=100) |
54 | 55 | image = models.OneToOneField(Image, null=True) |
| 56 | |
| 57 | |
| 58 | class Parent1(models.Model): |
| 59 | name1 = models.CharField(max_length=50) |
| 60 | def __unicode__(self): |
| 61 | return self.name1 |
| 62 | |
| 63 | |
| 64 | class Parent2(models.Model): |
| 65 | name2 = models.CharField(max_length=50) |
| 66 | def __unicode__(self): |
| 67 | return self.name2 |
| 68 | |
| 69 | |
| 70 | class Child1(Parent1, Parent2): |
| 71 | other = models.CharField(max_length=50) |
| 72 | def __unicode__(self): |
| 73 | return self.name1 |
| 74 | |
| 75 | |
| 76 | class Child2(Parent1): |
| 77 | parent2 = models.OneToOneField(Parent2) |
| 78 | other = models.CharField(max_length=50) |
| 79 | def __unicode__(self): |
| 80 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 643a0ff..a57142c 100644
a
|
b
|
from __future__ import with_statement, absolute_import
|
3 | 3 | from django.test import TestCase |
4 | 4 | |
5 | 5 | from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
6 | | AdvancedUserStat, Image, Product) |
| 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
7 | 7 | |
8 | 8 | |
9 | 9 | class ReverseSelectRelatedTestCase(TestCase): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
21 | 21 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
22 | 22 | results=results2) |
23 | 23 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| 24 | p1 = Parent1(name1="Only Parent1") |
| 25 | p1.save() |
| 26 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| 27 | c1.save() |
| 28 | p2 = Parent2(name2="Child2 Parent2") |
| 29 | p2.save() |
| 30 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| 31 | c2.save() |
24 | 32 | |
25 | 33 | def test_basic(self): |
26 | 34 | with self.assertNumQueries(1): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
80 | 88 | p2 = Product.objects.create(name="Talking Django Plushie") |
81 | 89 | |
82 | 90 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| 91 | |
| 92 | def test_parent_only(self): |
| 93 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| 94 | |
| 95 | def test_multiple_subclass(self): |
| 96 | with self.assertNumQueries(1): |
| 97 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| 98 | self.assertEqual(p.child1.name2, u'Child1 Parent2') |
| 99 | |
| 100 | def test_onetoone_with_subclass(self): |
| 101 | with self.assertNumQueries(1): |
| 102 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| 103 | self.assertEqual(p.child2.name1, u'Child2 Parent1') |