Ticket #18823: m2m_through_field.patch
File m2m_through_field.patch, 5.5 KB (added by , 12 years ago) |
---|
-
django/db/models/fields/related.py
diff --git a/django/db/models/fields/related.py b/django/db/models/fields/related.py index 08cc0a7..c6b9f90 100644
a b def create_many_related_manager(superclass, rel): 558 558 self.reverse = reverse 559 559 self.through = through 560 560 self.prefetch_cache_name = prefetch_cache_name 561 self._pk_val = self. instance.pk561 self._pk_val = self._get_fk_val(self.instance, source_field_name) 562 562 if self._pk_val is None: 563 563 raise ValueError("%r instance needs to have a primary key value before a many-to-many relationship can be used." % instance.__class__.__name__) 564 564 565 def _get_fk_val(self, obj, field_name): 566 # Get's the correct value for this relationship 567 # takes to_field into account 568 fk = self.through._meta.get_field(field_name) 569 value = obj.pk 570 if fk.rel.field_name and fk.rel.field_name != fk.rel.to._meta.pk.attname: 571 attname = fk.rel.get_related_field().get_attname() 572 value = fk.get_prep_lookup('exact', getattr(obj, attname)) 573 return value 574 565 575 def get_query_set(self): 566 576 try: 567 577 return self.instance._prefetched_objects_cache[self.prefetch_cache_name] … … def create_many_related_manager(superclass, rel): 662 672 if not router.allow_relation(obj, self.instance): 663 673 raise ValueError('Cannot add "%r": instance is on database "%s", value is on database "%s"' % 664 674 (obj, self.instance._state.db, obj._state.db)) 665 new_ids.add( obj.pk)675 new_ids.add(self._get_fk_val(obj, target_field_name)) 666 676 elif isinstance(obj, Model): 667 677 raise TypeError("'%s' instance expected, got %r" % (self.model._meta.object_name, obj)) 668 678 else: … … def create_many_related_manager(superclass, rel): 689 699 }) 690 700 for obj_id in new_ids 691 701 ]) 702 692 703 if self.reverse or source_field_name == self.source_field_name: 693 704 # Don't send the signal when we are inserting the 694 705 # duplicate data row for symmetrical reverse entries. … … def create_many_related_manager(superclass, rel): 707 718 old_ids = set() 708 719 for obj in objs: 709 720 if isinstance(obj, self.model): 710 old_ids.add( obj.pk)721 old_ids.add(self._get_fk_val(obj, target_field_name)) 711 722 else: 712 723 old_ids.add(obj) 713 724 # Work out what DB we're operating on -
tests/regressiontests/m2m_through_regress/models.py
diff --git a/tests/regressiontests/m2m_through_regress/models.py b/tests/regressiontests/m2m_through_regress/models.py index 47c24ed..73a4645 100644
a b class CarDriver(models.Model): 80 80 car = models.ForeignKey('Car', to_field='make') 81 81 driver = models.ForeignKey('Driver', to_field='name') 82 82 83 class Meta: 84 auto_created = Car 85 83 86 def __str__(self): 84 87 return "pk=%s car=%s driver=%s" % (str(self.pk), self.car, self.driver) -
tests/regressiontests/m2m_through_regress/tests.py
diff --git a/tests/regressiontests/m2m_through_regress/tests.py b/tests/regressiontests/m2m_through_regress/tests.py index 458c194..9808bd4 100644
a b class ToFieldThroughTests(TestCase): 136 136 ["<Car: Toyota>"] 137 137 ) 138 138 139 def test_to_field_clear_reverse(self): 140 self.driver.car_set.clear() 141 self.assertQuerysetEqual( 142 self.driver.car_set.all(),[]) 143 144 def test_to_field_clear(self): 145 self.car.drivers.clear() 146 self.assertQuerysetEqual( 147 self.car.drivers.all(),[]) 148 149 class AutoToFieldThroughTests(TestCase): 150 def setUp(self): 151 self.car = Car.objects.create(make="Toyota") 152 self.driver = Driver.objects.create(name="Ryan Briscoe") 153 154 def test_add(self): 155 self.assertQuerysetEqual( 156 self.car.drivers.all(),[]) 157 self.car.drivers.add(self.driver) 158 self.assertQuerysetEqual( 159 self.car.drivers.all(), 160 ["<Driver: Ryan Briscoe>"] 161 ) 162 163 def test_add_reverse(self): 164 self.assertQuerysetEqual( 165 self.driver.car_set.all(),[]) 166 self.driver.car_set.add(self.car) 167 self.assertQuerysetEqual( 168 self.driver.car_set.all(), 169 ["<Car: Toyota>"] 170 ) 171 172 def test_remove(self): 173 CarDriver.objects.create(car=self.car, driver=self.driver) 174 self.assertQuerysetEqual( 175 self.car.drivers.all(), 176 ["<Driver: Ryan Briscoe>"] 177 ) 178 self.car.drivers.remove(self.driver) 179 self.assertQuerysetEqual( 180 self.car.drivers.all(),[]) 181 182 183 def test_remove_reverse(self): 184 CarDriver.objects.create(car=self.car, driver=self.driver) 185 self.assertQuerysetEqual( 186 self.driver.car_set.all(), 187 ["<Car: Toyota>"] 188 ) 189 self.driver.car_set.remove(self.car) 190 self.assertQuerysetEqual( 191 self.driver.car_set.all(),[]) 192 193 139 194 class ThroughLoadDataTestCase(TestCase): 140 195 fixtures = ["m2m_through"] 141 196