diff --git a/django/contrib/sites/managers.py b/django/contrib/sites/managers.py
index 3df485a..cf38d16 100644
a
|
b
|
|
1 | 1 | from django.conf import settings |
2 | 2 | from django.db import models |
3 | 3 | from django.db.models.fields import FieldDoesNotExist |
| 4 | from django.db.models.sql import constants |
4 | 5 | |
5 | 6 | class CurrentSiteManager(models.Manager): |
6 | 7 | "Use this to limit objects to those associated with the current site." |
… |
… |
class CurrentSiteManager(models.Manager):
|
10 | 11 | self.__is_validated = False |
11 | 12 | |
12 | 13 | def _validate_field_name(self): |
13 | | field_names = self.model._meta.get_all_field_names() |
14 | | |
15 | | # If a custom name is provided, make sure the field exists on the model |
16 | | if self.__field_name is not None and self.__field_name not in field_names: |
17 | | raise ValueError("%s couldn't find a field named %s in %s." % \ |
18 | | (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) |
19 | | |
20 | | # Otherwise, see if there is a field called either 'site' or 'sites' |
21 | | else: |
| 14 | """ |
| 15 | Given the field identifier, goes down the chain to check that each |
| 16 | specified field |
| 17 | |
| 18 | a) exists, |
| 19 | b) is of type ForeignKey or ManyToManyField |
| 20 | |
| 21 | If no field name is specified when instantiating |
| 22 | CurrentSiteManager, it tries to find either 'site' or 'sites' as |
| 23 | the site link. |
| 24 | """ |
| 25 | if self.__field_name is None: |
| 26 | # Guess at field name |
| 27 | field_names = self.model._meta.get_all_field_names() |
22 | 28 | for potential_name in ['site', 'sites']: |
23 | 29 | if potential_name in field_names: |
24 | 30 | self.__field_name = potential_name |
25 | | self.__is_validated = True |
26 | 31 | break |
27 | | |
28 | | # Now do a type check on the field (FK or M2M only) |
| 32 | else: |
| 33 | raise ValueError( |
| 34 | "%s couldn't find a field named either 'site' or 'sites' in %s." % |
| 35 | (self.__class__.__name__, self.model._meta.object_name) |
| 36 | ) |
| 37 | |
| 38 | fieldname_chain = self.__field_name.split(constants.LOOKUP_SEP) |
| 39 | model = self.model |
| 40 | |
| 41 | for fieldname in fieldname_chain: |
| 42 | # Throws an exception if anything goes bad |
| 43 | self._validate_single_field_name(model, fieldname) |
| 44 | model = self._get_related_model(model, fieldname) |
| 45 | |
| 46 | # If we get this far without an exception, everything is good |
| 47 | self.__is_validated = True |
| 48 | |
| 49 | def _validate_single_field_name(self, model, field_name): |
| 50 | """ |
| 51 | Checks if the given field name can be used to make a link between a |
| 52 | model and a site with the CurrentSiteManager class |
| 53 | """ |
29 | 54 | try: |
30 | | field = self.model._meta.get_field(self.__field_name) |
| 55 | field = model._meta.get_field(field_name) |
31 | 56 | if not isinstance(field, (models.ForeignKey, models.ManyToManyField)): |
32 | | raise TypeError("%s must be a ForeignKey or ManyToManyField." %self.__field_name) |
| 57 | raise TypeError( |
| 58 | "Field %s of model %s must be a ForeignKey or ManyToManyField." % ( |
| 59 | field_name, |
| 60 | model._meta.object_name |
| 61 | ) |
| 62 | ) |
33 | 63 | except FieldDoesNotExist: |
34 | 64 | raise ValueError("%s couldn't find a field named %s in %s." % \ |
35 | | (self.__class__.__name__, self.__field_name, self.model._meta.object_name)) |
36 | | self.__is_validated = True |
| 65 | (self.__class__.__name__, field_name, model._meta.object_name)) |
37 | 66 | |
| 67 | def _get_related_model(self, model, fieldname): |
| 68 | """ |
| 69 | Given a model and the name of a ForeignKey or ManyToManyField field |
| 70 | name as a string, returns the associated model. |
| 71 | """ |
| 72 | return model._meta.get_field_by_name(fieldname)[0].rel.to |
| 73 | |
38 | 74 | def get_query_set(self): |
39 | 75 | if not self.__is_validated: |
40 | 76 | self._validate_field_name() |
diff --git a/docs/ref/contrib/sites.txt b/docs/ref/contrib/sites.txt
index 8fc434b..6d3ae02 100644
a
|
b
|
demonstrates this::
|
349 | 349 | If you attempt to use :class:`~django.contrib.sites.managers.CurrentSiteManager` |
350 | 350 | and pass a field name that doesn't exist, Django will raise a ``ValueError``. |
351 | 351 | |
| 352 | .. versionchanged:: 1.5 |
| 353 | |
| 354 | :class:`~django.contrib.sites.managers.CurrentSiteManager` can span |
| 355 | multiple models by using the same syntax as queries, as per the |
| 356 | :ref:`models and database queries documentation<field-lookups-intro>`. For |
| 357 | example, using the ``Photo`` model defined above:: |
| 358 | |
| 359 | from django.db import models |
| 360 | from django.contrib.sites.managers import CurrentSiteManager |
| 361 | |
| 362 | class PhotoLocation(models.Model): |
| 363 | name = models.CharField(max_length=100) |
| 364 | photo = models.ForeignKey(Photo) |
| 365 | objects = models.Manager() |
| 366 | on_site = CurrentSiteManager('photo__publish_on') |
| 367 | |
| 368 | ``PhotoLocation.on_site.all()`` will return all ``PhotoLocation`` objects |
| 369 | in the database associated with ``Photo`` objects which themselves are |
| 370 | associated with the current site. |
| 371 | |
352 | 372 | Finally, note that you'll probably want to keep a normal |
353 | 373 | (non-site-specific) ``Manager`` on your model, even if you use |
354 | 374 | :class:`~django.contrib.sites.managers.CurrentSiteManager`. As |
diff --git a/tests/regressiontests/sites_framework/models.py b/tests/regressiontests/sites_framework/models.py
index 9ecc3e6..90d26f1 100644
a
|
b
|
class InvalidArticle(AbstractArticle):
|
34 | 34 | |
35 | 35 | class ConfusedArticle(AbstractArticle): |
36 | 36 | site = models.IntegerField() |
| 37 | |
| 38 | class ArticleComment(models.Model): |
| 39 | parent_article = models.ForeignKey(ExclusiveArticle) |
| 40 | text = models.CharField(max_length=50) |
| 41 | |
| 42 | objects = models.Manager() |
| 43 | on_site = CurrentSiteManager("parent_article__site") |
diff --git a/tests/regressiontests/sites_framework/tests.py b/tests/regressiontests/sites_framework/tests.py
index 8e664fd..2ac0635 100644
a
|
b
|
from django.contrib.sites.models import Site
|
5 | 5 | from django.test import TestCase |
6 | 6 | |
7 | 7 | from .models import (SyndicatedArticle, ExclusiveArticle, CustomArticle, |
8 | | InvalidArticle, ConfusedArticle) |
| 8 | InvalidArticle, ConfusedArticle, ArticleComment) |
9 | 9 | |
10 | 10 | |
11 | 11 | class SitesFrameworkTestCase(TestCase): |
… |
… |
class SitesFrameworkTestCase(TestCase):
|
18 | 18 | self.assertEqual(ExclusiveArticle.on_site.all().get(), article) |
19 | 19 | |
20 | 20 | def test_sites_m2m(self): |
21 | | article = SyndicatedArticle.objects.create(title="Fresh News!") |
22 | | article.sites.add(Site.objects.get(id=settings.SITE_ID)) |
23 | | article.sites.add(Site.objects.get(id=settings.SITE_ID+1)) |
24 | | article2 = SyndicatedArticle.objects.create(title="More News!") |
25 | | article2.sites.add(Site.objects.get(id=settings.SITE_ID+1)) |
26 | | self.assertEqual(SyndicatedArticle.on_site.all().get(), article) |
| 21 | first_article = SyndicatedArticle.objects.create(title="Fresh News!") |
| 22 | first_article.sites.add(Site.objects.get(id=settings.SITE_ID)) |
| 23 | first_article.sites.add(Site.objects.get(id=settings.SITE_ID+1)) |
| 24 | second_article = SyndicatedArticle.objects.create(title="More News!") |
| 25 | second_article.sites.add(Site.objects.get(id=settings.SITE_ID+1)) |
| 26 | self.assertEqual(SyndicatedArticle.on_site.all().get(), first_article) |
27 | 27 | |
28 | 28 | def test_custom_named_field(self): |
29 | 29 | article = CustomArticle.objects.create(title="Tantalizing News!", places_this_article_should_appear_id=settings.SITE_ID) |
… |
… |
class SitesFrameworkTestCase(TestCase):
|
36 | 36 | def test_invalid_field_type(self): |
37 | 37 | article = ConfusedArticle.objects.create(title="More Bad News!", site=settings.SITE_ID) |
38 | 38 | self.assertRaises(TypeError, ConfusedArticle.on_site.all) |
| 39 | |
| 40 | def test_indirect_link(self): |
| 41 | first_article = ExclusiveArticle.objects.create( |
| 42 | title="Breaking News!", |
| 43 | site_id=settings.SITE_ID |
| 44 | ) |
| 45 | second_article = ExclusiveArticle.objects.create( |
| 46 | title="More News!", |
| 47 | site_id=settings.SITE_ID + 1 |
| 48 | ) |
| 49 | |
| 50 | comment = ArticleComment.objects.create( |
| 51 | parent_article=first_article, |
| 52 | text="First post!" |
| 53 | ) |
| 54 | comment2 = ArticleComment.objects.create( |
| 55 | parent_article=second_article, |
| 56 | text="Second post." |
| 57 | ) |
| 58 | |
| 59 | self.assertEqual( |
| 60 | ArticleComment.on_site.all().get(), |
| 61 | comment |
| 62 | ) |