Ticket #7270: reverse_select_related.diff
File reverse_select_related.diff, 11.5 KB (added by , 15 years ago) |
---|
-
django/db/models/fields/related.py
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 8fec836..8e27244 100644
a b class SingleRelatedObjectDescriptor(object): 188 188 # SingleRelatedObjectDescriptor instance. 189 189 def __init__(self, related): 190 190 self.related = related 191 self.cache_name = '_%s_cache' % related.get_accessor_name()191 self.cache_name = related.get_accessor_cache() 192 192 193 193 def __get__(self, instance, instance_type=None): 194 194 if instance is None: … … class ReverseSingleRelatedObjectDescriptor(object): 307 307 # cache. This cache also might not exist if the related object 308 308 # hasn't been accessed yet. 309 309 if related: 310 cache_name = '_%s_cache' % self.field.related.get_accessor_name()310 cache_name = self.field.related.get_accessor_cache() 311 311 try: 312 312 delattr(related, cache_name) 313 313 except AttributeError: -
django/db/models/query.py
diff --git a/django/db/models/query.py b/django/db/models/query.py index 4e3326a..afbcc24 100644
a b def get_cached_row(klass, row, index_start, max_depth=0, cur_depth=0, 1147 1147 rel_obj, index_end = cached_row 1148 1148 if obj is not None: 1149 1149 setattr(obj, f.get_cache_name(), rel_obj) 1150 if f.unique: 1151 setattr(rel_obj, f.related.get_accessor_cache(), obj) 1152 1153 if restricted: 1154 related_fields = [(o.field, o.model) for o in klass._meta.get_all_related_objects() 1155 if o.field.unique and o.field.related_query_name() in requested] 1156 for f, model in related_fields: 1157 next = requested.get(f.related_query_name(), {}) 1158 cached_row = get_cached_row(model, row, index_end, max_depth, 1159 cur_depth+1, next) 1160 if cached_row: 1161 rel_obj, index_end = cached_row 1162 if obj is not None: 1163 setattr(obj, f.related.get_accessor_cache(), rel_obj) 1164 if rel_obj is not None: 1165 setattr(rel_obj, f.get_cache_name(), obj) 1166 1167 1150 1168 return obj, index_end 1151 1169 1152 1170 def delete_objects(seen_objs, using): -
django/db/models/related.py
diff --git a/django/db/models/related.py b/django/db/models/related.py index afdf3f7..54258ca 100644
a b class RelatedObject(object): 45 45 return self.field.rel.related_name or (self.opts.object_name.lower() + '_set') 46 46 else: 47 47 return self.field.rel.related_name or (self.opts.object_name.lower()) 48 49 def get_accessor_cache(self): 50 return "_%s_cache" % self.get_accessor_name() -
django/db/models/sql/compiler.py
diff --git a/django/db/models/sql/compiler.py b/django/db/models/sql/compiler.py index 6a95d32..99d4c7e 100644
a b class SQLCompiler(object): 520 520 521 521 # Setup for the case when only particular related fields should be 522 522 # included in the related selection. 523 if requested is None and restricted is not False:523 if requested is None: 524 524 if isinstance(self.query.select_related, dict): 525 525 requested = self.query.select_related 526 526 restricted = True … … class SQLCompiler(object): 600 600 self.fill_related_selections(f.rel.to._meta, alias, cur_depth + 1, 601 601 used, next, restricted, new_nullable, dupe_set, avoid) 602 602 603 if restricted and requested is not None: 604 related_fields = [(o.field, o.model) for o in opts.get_all_related_objects() 605 if o.field.unique and o.field.related_query_name() in requested 606 ] 607 for f, model in related_fields: 608 table = model._meta.db_table 609 int_opts = opts 610 alias = root_alias 611 alias_chain = [] 612 chain = opts.get_base_chain(f.rel.to) 613 avoid = avoid_set.copy() 614 if chain is not None: 615 for int_model in chain: 616 if not int_opts.parents[int_model]: 617 int_opts = int_model._meta 618 continue 619 lhs_col = int_opts.parents[int_model].column 620 dedupe = lhs_col in opts.duplicate_targets 621 if dedupe: 622 avoid.update(self.query.dupe_avoidance.get(id(opts), lhs_col), ()) 623 dupe_set.add((opts, lhs_col)) 624 int_opts = int_model._meta 625 alias = self.query.join( 626 (alias, int_opts.db_table, lhs_col, int_opts.pk.column), 627 exclusions=used, promote=True, reuse=used 628 ) 629 alias_chain.append(alias) 630 for dupe_opts, dupe_col in dupe_set: 631 self.query.update_dupe_avoidance(dupe_opts, dupe_col, alias) 632 dedupe = f.column in opts.duplicate_targets 633 if dupe_set or dedupe: 634 avoid.update(self.query.dupe_avoidance.get((id(opts), f.column), ())) 635 if dedupe: 636 dupe_set.add((opts, f.column)) 637 alias = self.query.join( 638 (alias, table, f.rel.get_related_field().column, f.column), 639 exclusions=used.union(avoid), 640 promote=True 641 ) 642 used.add(alias) 643 columns, aliases = self.get_default_columns(start_alias=alias, 644 opts=model._meta, as_pairs=True) 645 self.query.related_select_cols.extend(columns) 646 self.query.related_select_fields.extend(model._meta.fields) 647 648 next = requested.get(f.related_query_name(), {}) 649 new_nullable = f.null or None 650 651 self.fill_related_selections(model._meta, table, cur_depth+1, 652 used, next, restricted, new_nullable) 653 603 654 def deferred_to_columns(self): 604 655 """ 605 656 Converts the self.deferred_loading data structure to mapping of table -
new file tests/regressiontests/select_related_onetoone/models.py
diff --git a/tests/regressiontests/select_related_onetoone/__init__.py b/tests/regressiontests/select_related_onetoone/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/regressiontests/select_related_onetoone/models.py b/tests/regressiontests/select_related_onetoone/models.py new file mode 100644 index 0000000..05efadf
- + 1 from django.db import models 2 3 4 class User(models.Model): 5 username = models.CharField(max_length=100) 6 email = models.EmailField() 7 8 def __unicode__(self): 9 return self.username 10 11 12 class UserProfile(models.Model): 13 user = models.OneToOneField(User) 14 city = models.CharField(max_length=100) 15 state = models.CharField(max_length=2) 16 17 def __unicode__(self): 18 return "%s, %s" % (self.city, self.state) 19 20 21 class UserStatResult(models.Model): 22 results = models.CharField(max_length=50) 23 24 def __unicode__(self): 25 return 'UserStatResults, results = %s' % (self.results,) 26 27 28 class UserStat(models.Model): 29 user = models.OneToOneField(User, primary_key=True) 30 posts = models.IntegerField() 31 results = models.ForeignKey(UserStatResult) 32 33 def __unicode__(self): 34 return 'UserStat, posts = %s' % (self.posts,) 35 36 class StatDetails(models.Model): 37 base_stats = models.OneToOneField(UserStat) 38 comments = models.IntegerField() 39 40 def __unicode__(self): 41 return 'StatDetails, comments = %s' % (self.comments,) 42 43 class AdvancedUserStat(UserStat): 44 pass -
new file tests/regressiontests/select_related_onetoone/tests.py
diff --git a/tests/regressiontests/select_related_onetoone/tests.py b/tests/regressiontests/select_related_onetoone/tests.py new file mode 100644 index 0000000..08e798b
- + 1 from django import db 2 from django.conf import settings 3 from django.test import TestCase 4 5 from models import User, UserProfile, UserStat, UserStatResult, StatDetails, AdvancedUserStat 6 7 class ReverseSelectRelatedTestCase(TestCase): 8 def setUp(self): 9 self.old_debug = settings.DEBUG 10 settings.DEBUG = True 11 12 user = User.objects.create(username="test") 13 userprofile = UserProfile.objects.create(user=user, state="KS", 14 city="Lawrence") 15 results = UserStatResult.objects.create(results='first results') 16 userstat = UserStat.objects.create(user=user, posts=150, 17 results=results) 18 details = StatDetails.objects.create(base_stats=userstat, comments=259) 19 20 user2 = User.objects.create(username="bob") 21 results2 = UserStatResult.objects.create(results='moar results') 22 advstat = AdvancedUserStat.objects.create(user=user2, posts=200, 23 results=results2) 24 StatDetails.objects.create(base_stats=advstat, comments=250) 25 26 db.reset_queries() 27 28 def assertQueries(self, queries): 29 self.assertEqual(len(db.connection.queries), queries) 30 31 def tearDown(self): 32 settings.DEBUG = self.old_debug 33 34 def test_basic(self): 35 u = User.objects.select_related("userprofile").get(username="test") 36 self.assertEqual(u.userprofile.state, "KS") 37 self.assertQueries(1) 38 39 def test_follow_next_level(self): 40 u = User.objects.select_related("userstat__results").get(username="test") 41 self.assertEqual(u.userstat.posts, 150) 42 self.assertEqual(u.userstat.results.results, 'first results') 43 self.assertQueries(1) 44 45 def test_follow_two(self): 46 u = User.objects.select_related("userprofile", "userstat").get(username="test") 47 self.assertEqual(u.userprofile.state, "KS") 48 self.assertEqual(u.userstat.posts, 150) 49 self.assertQueries(1) 50 51 def test_follow_two_next_level(self): 52 u = User.objects.select_related("userstat__results", "userstat__statdetails").get(username="test") 53 self.assertEqual(u.userstat.results.results, 'first results') 54 self.assertEqual(u.userstat.statdetails.comments, 259) 55 self.assertQueries(1) 56 57 def test_forward_and_back(self): 58 stat = UserStat.objects.select_related("user__userprofile").get(user__username="test") 59 self.assertEqual(stat.user.userprofile.state, 'KS') 60 self.assertEqual(stat.user.userstat.posts, 150) 61 self.assertQueries(1) 62 63 def test_back_and_forward(self): 64 u = User.objects.select_related("userstat").get(username="test") 65 self.assertEqual(u.userstat.user.username, 'test') 66 self.assertQueries(1) 67 68 def test_not_followed_by_default(self): 69 u = User.objects.select_related().get(username="test") 70 self.assertEqual(u.userstat.posts, 150) 71 self.assertQueries(2) 72 73 def test_follow_from_child_class(self): 74 stat = AdvancedUserStat.objects.select_related("statdetails").get(posts=200) 75 self.assertEqual(stat.statdetails.comments, 250) 76 self.assertQueries(1) 77 78 def test_follow_inheritance(self): 79 stat = UserStat.objects.select_related('advanceduserstat').get(posts=200) 80 self.assertEqual(stat.advanceduserstat.posts, 200) 81 self.assertQueries(1)