Ticket #18177: 18177.patch

File 18177.patch, 14.9 KB (added by Aymeric Augustin, 12 years ago)
  • django/db/models/fields/related.py

    diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py
    index a16f955..c4f95a1 100644
    a b class SingleRelatedObjectDescriptor(object):  
    237237        return self.related.model._base_manager.using(db)
    238238
    239239    def get_prefetch_query_set(self, instances):
    240         vals = set(instance._get_pk_val() for instance in instances)
    241         params = {'%s__pk__in' % self.related.field.name: vals}
    242         return (self.get_query_set(instance=instances[0]).filter(**params),
    243                 attrgetter(self.related.field.attname),
    244                 lambda obj: obj._get_pk_val(),
    245                 True,
    246                 self.cache_name)
     240        rel_obj_attr = attrgetter(self.related.field.attname)
     241        instance_attr = lambda obj: obj._get_pk_val()
     242        instances_dict = dict((instance_attr(inst), inst) for inst in instances)
     243        params = {'%s__pk__in' % self.related.field.name: instances_dict.keys()}
     244        qs = self.get_query_set(instance=instances[0]).filter(**params)
     245        # Since we're going to assign directly in the cache,
     246        # we must manage the reverse relation cache manually.
     247        rel_obj_cache_name = self.related.field.get_cache_name()
     248        for rel_obj in qs:
     249            instance = instances_dict[rel_obj_attr(rel_obj)]
     250            setattr(rel_obj, rel_obj_cache_name, instance)
     251        return qs, rel_obj_attr, instance_attr, True, self.cache_name
    247252
    248253    def __get__(self, instance, instance_type=None):
    249254        if instance is None:
    class ReverseSingleRelatedObjectDescriptor(object):  
    324329            return QuerySet(self.field.rel.to).using(db)
    325330
    326331    def get_prefetch_query_set(self, instances):
    327         vals = set(getattr(instance, self.field.attname) for instance in instances)
     332        rel_obj_attr = attrgetter(self.field.rel.field_name)
     333        instance_attr = attrgetter(self.field.attname)
     334        instances_dict = dict((instance_attr(inst), inst) for inst in instances)
    328335        other_field = self.field.rel.get_related_field()
    329336        if other_field.rel:
    330             params = {'%s__pk__in' % self.field.rel.field_name: vals}
     337            params = {'%s__pk__in' % self.field.rel.field_name: instances_dict.keys()}
    331338        else:
    332             params = {'%s__in' % self.field.rel.field_name: vals}
    333         return (self.get_query_set(instance=instances[0]).filter(**params),
    334                 attrgetter(self.field.rel.field_name),
    335                 attrgetter(self.field.attname),
    336                 True,
    337                 self.cache_name)
     339            params = {'%s__in' % self.field.rel.field_name: instances_dict.keys()}
     340        qs = self.get_query_set(instance=instances[0]).filter(**params)
     341        # Since we're going to assign directly in the cache,
     342        # we must manage the reverse relation cache manually.
     343        if not self.field.rel.multiple:
     344            rel_obj_cache_name = self.field.related.get_cache_name()
     345            for rel_obj in qs:
     346                instance = instances_dict[rel_obj_attr(rel_obj)]
     347                setattr(rel_obj, rel_obj_cache_name, instance)
     348        return qs, rel_obj_attr, instance_attr, True, self.cache_name
    338349
    339350    def __get__(self, instance, instance_type=None):
    340351        if instance is None:
    class ForeignRelatedObjectsDescriptor(object):  
    467478                    return self.instance._prefetched_objects_cache[rel_field.related_query_name()]
    468479                except (AttributeError, KeyError):
    469480                    db = self._db or router.db_for_read(self.model, instance=self.instance)
    470                     return super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     481                    qs = super(RelatedManager, self).get_query_set().using(db).filter(**self.core_filters)
     482                    qs._known_related_object = (rel_field.name, self.instance)
     483                    return qs
    471484
    472485            def get_prefetch_query_set(self, instances):
     486                rel_obj_attr = attrgetter(rel_field.get_attname())
     487                instance_attr = attrgetter(attname)
     488                instances_dict = dict((instance_attr(inst), inst) for inst in instances)
    473489                db = self._db or router.db_for_read(self.model, instance=instances[0])
    474                 query = {'%s__%s__in' % (rel_field.name, attname):
    475                              set(getattr(obj, attname) for obj in instances)}
     490                query = {'%s__%s__in' % (rel_field.name, attname): instances_dict.keys()}
    476491                qs = super(RelatedManager, self).get_query_set().using(db).filter(**query)
    477                 return (qs,
    478                         attrgetter(rel_field.get_attname()),
    479                         attrgetter(attname),
    480                         False,
    481                         rel_field.related_query_name())
     492                # Since we just bypassed this class' get_query_set(), we must manage
     493                # the reverse relation manually.
     494                for rel_obj in qs:
     495                    instance = instances_dict[rel_obj_attr(rel_obj)]
     496                    setattr(rel_obj, rel_field.name, instance)
     497                cache_name = rel_field.related_query_name()
     498                return qs, rel_obj_attr, instance_attr, False, cache_name
    482499
    483500            def add(self, *objs):
    484501                for obj in objs:
  • django/db/models/query.py

    diff --git a/django/db/models/query.py b/django/db/models/query.py
    index 65a3697..755820c 100644
    a b class QuerySet(object):  
    4141        self._for_write = False
    4242        self._prefetch_related_lookups = []
    4343        self._prefetch_done = False
     44        self._known_related_object = None       # (attname, rel_obj)
    4445
    4546    ########################
    4647    # PYTHON MAGIC METHODS #
    class QuerySet(object):  
    282283                    init_list.append(field.attname)
    283284            model_cls = deferred_class_factory(self.model, skip)
    284285
    285         # Cache db and model outside the loop
     286        # Cache db, model and known_related_object outside the loop
    286287        db = self.db
    287288        model = self.model
     289        kro_attname, kro_instance = self._known_related_object or (None, None)
    288290        compiler = self.query.get_compiler(using=db)
    289291        if fill_cache:
    290292            klass_info = get_klass_info(model, max_depth=max_depth,
    class QuerySet(object):  
    294296                obj, _ = get_cached_row(row, index_start, db, klass_info,
    295297                                        offset=len(aggregate_select))
    296298            else:
     299                # Omit aggregates in object creation.
     300                row_data = row[index_start:aggregate_start]
    297301                if skip:
    298                     row_data = row[index_start:aggregate_start]
    299302                    obj = model_cls(**dict(zip(init_list, row_data)))
    300303                else:
    301                     # Omit aggregates in object creation.
    302                     obj = model(*row[index_start:aggregate_start])
     304                    obj = model(*row_data)
    303305
    304306                # Store the source database of the object
    305307                obj._state.db = db
    class QuerySet(object):  
    313315            # Add the aggregates to the model
    314316            if aggregate_select:
    315317                for i, aggregate in enumerate(aggregate_select):
    316                     setattr(obj, aggregate, row[i+aggregate_start])
     318                    setattr(obj, aggregate, row[i + aggregate_start])
     319
     320            # Add the known related object to the model, if there is one
     321            if kro_instance:
     322                setattr(obj, kro_attname, kro_instance)
    317323
    318324            yield obj
    319325
    class QuerySet(object):  
    864870        c = klass(model=self.model, query=query, using=self._db)
    865871        c._for_write = self._for_write
    866872        c._prefetch_related_lookups = self._prefetch_related_lookups[:]
     873        c._known_related_object = self._known_related_object
    867874        c.__dict__.update(kwargs)
    868875        if setup and hasattr(c, '_setup_query'):
    869876            c._setup_query()
    def prefetch_one_level(instances, prefetcher, attname):  
    17811788    rel_obj_cache = {}
    17821789    for rel_obj in all_related_objects:
    17831790        rel_attr_val = rel_obj_attr(rel_obj)
    1784         if rel_attr_val not in rel_obj_cache:
    1785             rel_obj_cache[rel_attr_val] = []
    1786         rel_obj_cache[rel_attr_val].append(rel_obj)
     1791        rel_obj_cache.setdefault(rel_attr_val, []).append(rel_obj)
    17871792
    17881793    for obj in instances:
    17891794        instance_attr_val = instance_attr(obj)
  • new file tests/modeltests/known_related_objects/fixtures/tournament.json

    diff --git a/tests/modeltests/known_related_objects/__init__.py b/tests/modeltests/known_related_objects/__init__.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/tests/modeltests/known_related_objects/fixtures/tournament.json b/tests/modeltests/known_related_objects/fixtures/tournament.json
    new file mode 100644
    index 0000000..2f2b1c5
    - +  
     1[
     2    {
     3        "pk": 1,
     4        "model": "known_related_objects.tournament",
     5        "fields": {
     6            "name": "Tourney 1"
     7            }
     8        },
     9    {
     10        "pk": 2,
     11        "model": "known_related_objects.tournament",
     12        "fields": {
     13            "name": "Tourney 2"
     14            }
     15        },
     16    {
     17        "pk": 1,
     18        "model": "known_related_objects.pool",
     19        "fields": {
     20            "tournament": 1,
     21            "name": "T1 Pool 1"
     22            }
     23        },
     24    {
     25        "pk": 2,
     26        "model": "known_related_objects.pool",
     27        "fields": {
     28            "tournament": 1,
     29            "name": "T1 Pool 2"
     30            }
     31        },
     32    {
     33        "pk": 3,
     34        "model": "known_related_objects.pool",
     35        "fields": {
     36            "tournament": 2,
     37            "name": "T2 Pool 1"
     38            }
     39        },
     40    {
     41        "pk": 4,
     42        "model": "known_related_objects.pool",
     43        "fields": {
     44            "tournament": 2,
     45            "name": "T2 Pool 2"
     46            }
     47        },
     48    {
     49        "pk": 1,
     50        "model": "known_related_objects.poolstyle",
     51        "fields": {
     52            "name": "T1 Pool 2 Style",
     53            "pool": 2
     54            }
     55        },
     56    {
     57        "pk": 2,
     58        "model": "known_related_objects.poolstyle",
     59        "fields": {
     60            "name": "T2 Pool 1 Style",
     61            "pool": 3
     62            }
     63        }
     64]
     65
  • new file tests/modeltests/known_related_objects/models.py

    diff --git a/tests/modeltests/known_related_objects/models.py b/tests/modeltests/known_related_objects/models.py
    new file mode 100644
    index 0000000..4c516dd
    - +  
     1"""
     2Existing related object instance caching.
     3
     4Test that queries are not redone when going back through known relations.
     5"""
     6
     7from django.db import models
     8
     9class Tournament(models.Model):
     10    name = models.CharField(max_length=30)
     11
     12class Pool(models.Model):
     13    name = models.CharField(max_length=30)
     14    tournament = models.ForeignKey(Tournament)
     15
     16class PoolStyle(models.Model):
     17    name = models.CharField(max_length=30)
     18    pool = models.OneToOneField(Pool)
     19
  • new file tests/modeltests/known_related_objects/tests.py

    diff --git a/tests/modeltests/known_related_objects/tests.py b/tests/modeltests/known_related_objects/tests.py
    new file mode 100644
    index 0000000..24feab2
    - +  
     1from __future__ import absolute_import
     2
     3from django.test import TestCase
     4
     5from .models import Tournament, Pool, PoolStyle
     6
     7class ExistingRelatedInstancesTests(TestCase):
     8    fixtures = ['tournament.json']
     9
     10    def test_foreign_key(self):
     11        with self.assertNumQueries(2):
     12            tournament = Tournament.objects.get(pk=1)
     13            pool = tournament.pool_set.all()[0]
     14            self.assertIs(tournament, pool.tournament)
     15
     16    def test_foreign_key_prefetch_related(self):
     17        with self.assertNumQueries(2):
     18            tournament = (Tournament.objects.prefetch_related('pool_set').get(pk=1))
     19            pool = tournament.pool_set.all()[0]
     20            self.assertIs(tournament, pool.tournament)
     21
     22    def test_foreign_key_multiple_prefetch(self):
     23        with self.assertNumQueries(2):
     24            tournaments = list(Tournament.objects.prefetch_related('pool_set'))
     25            pool1 = tournaments[0].pool_set.all()[0]
     26            self.assertIs(tournaments[0], pool1.tournament)
     27            pool2 = tournaments[1].pool_set.all()[0]
     28            self.assertIs(tournaments[1], pool2.tournament)
     29
     30    def test_one_to_one(self):
     31        with self.assertNumQueries(2):
     32            style = PoolStyle.objects.get(pk=1)
     33            pool = style.pool
     34            self.assertIs(style, pool.poolstyle)
     35
     36    def test_one_to_one_select_related(self):
     37        with self.assertNumQueries(1):
     38            style = PoolStyle.objects.select_related('pool').get(pk=1)
     39            pool = style.pool
     40            self.assertIs(style, pool.poolstyle)
     41
     42    def test_one_to_one_multi_select_related(self):
     43        with self.assertNumQueries(1):
     44            poolstyles = list(PoolStyle.objects.select_related('pool'))
     45            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
     46            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
     47
     48    def test_one_to_one_prefetch_related(self):
     49        with self.assertNumQueries(2):
     50            style = PoolStyle.objects.prefetch_related('pool').get(pk=1)
     51            pool = style.pool
     52            self.assertIs(style, pool.poolstyle)
     53
     54    def test_one_to_one_multi_prefetch_related(self):
     55        with self.assertNumQueries(2):
     56            poolstyles = list(PoolStyle.objects.prefetch_related('pool'))
     57            self.assertIs(poolstyles[0], poolstyles[0].pool.poolstyle)
     58            self.assertIs(poolstyles[1], poolstyles[1].pool.poolstyle)
     59
     60    def test_reverse_one_to_one(self):
     61        with self.assertNumQueries(2):
     62            pool = Pool.objects.get(pk=2)
     63            style = pool.poolstyle
     64            self.assertIs(pool, style.pool)
     65
     66    def test_reverse_one_to_one_select_related(self):
     67        with self.assertNumQueries(1):
     68            pool = Pool.objects.select_related('poolstyle').get(pk=2)
     69            style = pool.poolstyle
     70            self.assertIs(pool, style.pool)
     71
     72    def test_reverse_one_to_one_prefetch_related(self):
     73        with self.assertNumQueries(2):
     74            pool = Pool.objects.prefetch_related('poolstyle').get(pk=2)
     75            style = pool.poolstyle
     76            self.assertIs(pool, style.pool)
     77
     78    def test_reverse_one_to_one_multi_select_related(self):
     79        with self.assertNumQueries(1):
     80            pools = list(Pool.objects.select_related('poolstyle'))
     81            self.assertIs(pools[1], pools[1].poolstyle.pool)
     82            self.assertIs(pools[2], pools[2].poolstyle.pool)
     83
     84    def test_reverse_one_to_one_multi_prefetch_related(self):
     85        with self.assertNumQueries(2):
     86            pools = list(Pool.objects.prefetch_related('poolstyle'))
     87            self.assertIs(pools[1], pools[1].poolstyle.pool)
     88            self.assertIs(pools[2], pools[2].poolstyle.pool)
Back to Top