diff --git a/django/db/models/query.py b/django/db/models/query.py
index 0210a79..e144956 100644
a
|
b
|
class EmptyQuerySet(QuerySet):
|
1276 | 1276 | value_annotation = False |
1277 | 1277 | |
1278 | 1278 | def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None, |
1279 | | only_load=None, local_only=False): |
| 1279 | only_load=None, local_only=False, last_klass=None): |
1280 | 1280 | """ |
1281 | 1281 | Helper function that recursively returns an information for a klass, to be |
1282 | 1282 | used in get_cached_row. It exists just to compute this information only |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1298 | 1298 | the full field list for `klass` can be assumed. |
1299 | 1299 | * local_only - Only populate local fields. This is used when |
1300 | 1300 | following reverse select-related relations |
| 1301 | * last_klass - the last class seen when following reverse |
| 1302 | select-related relations |
1301 | 1303 | """ |
1302 | 1304 | if max_depth and requested is None and cur_depth > max_depth: |
1303 | 1305 | # We've recursed deeply enough; stop now. |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1343 | 1345 | # But kwargs version of Model.__init__ is slower, so we should avoid using |
1344 | 1346 | # it when it is not really neccesary. |
1345 | 1347 | if local_only and len(klass._meta.local_fields) != len(klass._meta.fields): |
1346 | | field_count = len(klass._meta.local_fields) |
1347 | | field_names = [f.attname for f in klass._meta.local_fields] |
| 1348 | parents = [p for p in klass._meta.get_parent_list() |
| 1349 | if p is not last_klass] |
| 1350 | field_names = [f.attname for f in klass._meta.fields |
| 1351 | if f in klass._meta.local_fields |
| 1352 | or f.model in parents] |
| 1353 | field_count = len(field_names) |
| 1354 | if field_count == len(klass._meta.fields): |
| 1355 | field_names = () |
1348 | 1356 | else: |
1349 | 1357 | field_count = len(klass._meta.fields) |
1350 | 1358 | field_names = () |
… |
… |
def get_klass_info(klass, max_depth=0, cur_depth=0, requested=None,
|
1369 | 1377 | only_load.get(o.model), reverse=True): |
1370 | 1378 | next = requested[o.field.related_query_name()] |
1371 | 1379 | klass_info = get_klass_info(o.model, max_depth=max_depth, cur_depth=cur_depth+1, |
1372 | | requested=next, only_load=only_load, local_only=True) |
| 1380 | requested=next, only_load=only_load, local_only=True, |
| 1381 | last_klass=klass) |
1373 | 1382 | reverse_related_fields.append((o.field, klass_info)) |
1374 | 1383 | |
1375 | 1384 | return klass, field_names, field_count, related_fields, reverse_related_fields |
… |
… |
def get_cached_row(row, index_start, using, klass_info, offset=0):
|
1455 | 1464 | # Now populate all the non-local field values |
1456 | 1465 | # on the related object |
1457 | 1466 | for rel_field, rel_model in rel_obj._meta.get_fields_with_model(): |
1458 | | if rel_model is not None: |
| 1467 | if rel_model is not None and isinstance(obj, rel_model): |
1459 | 1468 | setattr(rel_obj, rel_field.attname, getattr(obj, rel_field.attname)) |
1460 | 1469 | # populate the field cache for any related object |
1461 | 1470 | # that has already been retrieved |
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py
index a68f6e0..b8cf74d 100644
a
|
b
|
class SQLCompiler(object):
|
249 | 249 | return result |
250 | 250 | |
251 | 251 | def get_default_columns(self, with_aliases=False, col_aliases=None, |
252 | | start_alias=None, opts=None, as_pairs=False, local_only=False): |
| 252 | start_alias=None, opts=None, as_pairs=False, local_only=False, |
| 253 | last_opts=None): |
253 | 254 | """ |
254 | 255 | Computes the default columns for selecting every field in the base |
255 | 256 | model. Will sometimes be called to pull in related models (e.g. via |
… |
… |
class SQLCompiler(object):
|
273 | 274 | |
274 | 275 | if start_alias: |
275 | 276 | seen = {None: start_alias} |
| 277 | parents = [p for p in opts.get_parent_list() if p._meta is not last_opts] |
276 | 278 | for field, model in opts.get_fields_with_model(): |
277 | | if local_only and model is not None: |
| 279 | if local_only and model is not None and model not in parents: |
278 | 280 | continue |
279 | 281 | if start_alias: |
280 | 282 | try: |
… |
… |
class SQLCompiler(object):
|
282 | 284 | except KeyError: |
283 | 285 | link_field = opts.get_ancestor_link(model) |
284 | 286 | alias = self.query.join((start_alias, model._meta.db_table, |
285 | | link_field.column, model._meta.pk.column)) |
| 287 | link_field.column, model._meta.pk.column), |
| 288 | promote=(model in parents)) |
286 | 289 | seen[model] = alias |
287 | 290 | else: |
288 | 291 | # If we're starting from the base model of the queryset, the |
… |
… |
class SQLCompiler(object):
|
733 | 736 | ) |
734 | 737 | used.add(alias) |
735 | 738 | columns, aliases = self.get_default_columns(start_alias=alias, |
736 | | opts=model._meta, as_pairs=True, local_only=True) |
| 739 | opts=model._meta, as_pairs=True, local_only=True, |
| 740 | last_opts=opts) |
737 | 741 | self.query.related_select_cols.extend(columns) |
738 | 742 | self.query.related_select_fields.extend(model._meta.fields) |
739 | 743 | |
diff --git a/tests/django b/tests/django
new file mode 120000
index 0000000..8016dee
-
|
+
|
|
| 1 | ../django |
| 2 | No newline at end of file |
diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py
index 3284def..6216e0c 100644
a
|
b
|
class StatDetails(models.Model):
|
51 | 51 | class AdvancedUserStat(UserStat): |
52 | 52 | karma = models.IntegerField() |
53 | 53 | |
| 54 | |
54 | 55 | class Image(models.Model): |
55 | 56 | name = models.CharField(max_length=100) |
56 | 57 | |
… |
… |
class Image(models.Model):
|
58 | 59 | class Product(models.Model): |
59 | 60 | name = models.CharField(max_length=100) |
60 | 61 | image = models.OneToOneField(Image, null=True) |
| 62 | |
| 63 | |
| 64 | class Parent1(models.Model): |
| 65 | name1 = models.CharField(max_length=50) |
| 66 | def __unicode__(self): |
| 67 | return self.name1 |
| 68 | |
| 69 | |
| 70 | class Parent2(models.Model): |
| 71 | name2 = models.CharField(max_length=50) |
| 72 | def __unicode__(self): |
| 73 | return self.name2 |
| 74 | |
| 75 | |
| 76 | class Child1(Parent1, Parent2): |
| 77 | other = models.CharField(max_length=50) |
| 78 | def __unicode__(self): |
| 79 | return self.name1 |
| 80 | |
| 81 | |
| 82 | class Child2(Parent1): |
| 83 | parent2 = models.OneToOneField(Parent2) |
| 84 | other = models.CharField(max_length=50) |
| 85 | def __unicode__(self): |
| 86 | return self.name1 |
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py
index 1373f04..3c8623f 100644
a
|
b
|
from __future__ import absolute_import
|
3 | 3 | from django.test import TestCase |
4 | 4 | |
5 | 5 | from .models import (User, UserProfile, UserStat, UserStatResult, StatDetails, |
6 | | AdvancedUserStat, Image, Product) |
| 6 | AdvancedUserStat, Image, Product, Parent1, Parent2, Child1, Child2) |
7 | 7 | |
8 | 8 | |
9 | 9 | class ReverseSelectRelatedTestCase(TestCase): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
21 | 21 | advstat = AdvancedUserStat.objects.create(user=user2, posts=200, karma=5, |
22 | 22 | results=results2) |
23 | 23 | StatDetails.objects.create(base_stats=advstat, comments=250) |
| 24 | p1 = Parent1(name1="Only Parent1") |
| 25 | p1.save() |
| 26 | c1 = Child1(name1="Child1 Parent1", name2="Child1 Parent2") |
| 27 | c1.save() |
| 28 | p2 = Parent2(name2="Child2 Parent2") |
| 29 | p2.save() |
| 30 | c2 = Child2(name1="Child2 Parent1", parent2=p2) |
| 31 | c2.save() |
24 | 32 | |
25 | 33 | def test_basic(self): |
26 | 34 | with self.assertNumQueries(1): |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
79 | 87 | p1 = Product.objects.create(name="Django Plushie", image=im) |
80 | 88 | p2 = Product.objects.create(name="Talking Django Plushie") |
81 | 89 | |
| 90 | self.assertEqual(len(Product.objects.select_related("image")), 2) |
| 91 | |
82 | 92 | with self.assertNumQueries(1): |
83 | 93 | result = sorted(Product.objects.select_related("image"), key=lambda x: x.name) |
84 | 94 | self.assertEqual([p.name for p in result], ["Django Plushie", "Talking Django Plushie"]) |
… |
… |
class ReverseSelectRelatedTestCase(TestCase):
|
108 | 118 | image = Image.objects.select_related('product').get() |
109 | 119 | with self.assertRaises(Product.DoesNotExist): |
110 | 120 | image.product |
| 121 | |
| 122 | def test_parent_only(self): |
| 123 | Parent1.objects.select_related('child1').get(name1="Only Parent1") |
| 124 | |
| 125 | def test_multiple_subclass(self): |
| 126 | with self.assertNumQueries(1): |
| 127 | p = Parent1.objects.select_related('child1').get(name1="Child1 Parent1") |
| 128 | self.assertEqual(p.child1.name2, u'Child1 Parent2') |
| 129 | |
| 130 | def test_onetoone_with_subclass(self): |
| 131 | with self.assertNumQueries(1): |
| 132 | p = Parent2.objects.select_related('child2').get(name2="Child2 Parent2") |
| 133 | self.assertEqual(p.child2.name1, u'Child2 Parent1') |