Ticket #29636: position_field.py

File position_field.py, 10.6 KB (added by Petr Boros, 6 years ago)

PositionField source

Line 
1import datetime
2import warnings
3
4from django.db import models
5from django.db.models.signals import post_delete, post_save, pre_delete
6
# Prefer Django's ``timezone.now`` when it can be imported; otherwise fall back
# to the naive standard-library clock.
try:
    from django.utils.timezone import now as _tz_now
except ImportError:
    now = datetime.datetime.now
else:
    now = _tz_now
11
# ``basestring`` is a builtin only on Python 2. On Python 3, stand in with a
# tuple of both string types so that ``isinstance(value, basestring)`` still works.
try:
    _string_types = basestring
except NameError:
    _string_types = (str, bytes)
basestring = _string_types
17
18
class PositionField(models.IntegerField):
    """
    Integer model field that keeps instances ordered within a collection.

    Saving an instance with a new position shifts the other members of its
    collection so that positions stay contiguous. Negative positions count from
    the end like Python list indices (-1 is the last position). Positions past
    the end are clamped to the last slot; negative positions before the start
    are clamped to 0.

    :param collection: name, or sequence of names, of the model fields whose
        values group rows into independent orderings. ``None`` means every row
        belongs to a single collection.
    :param parent_link: name of a relation field; collection queries are run
        against the model that relation points to instead of the field's own
        model.
    :param unique_for_field: deprecated alias for ``collection``.
    :param unique_for_fields: deprecated alias for ``collection``.
    :raises TypeError: if ``unique`` is passed, if incompatible aliases are
        combined, or if the field appears in ``unique_together``.
    """

    def __init__(self, verbose_name=None, name=None, default=-1, collection=None,
                 parent_link=None, unique_for_field=None, unique_for_fields=None,
                 *args, **kwargs):
        if 'unique' in kwargs:
            raise TypeError("%s can't have a unique constraint." % self.__class__.__name__)
        super(PositionField, self).__init__(verbose_name, name, default=default, *args, **kwargs)

        # Backwards-compatibility mess begins here.
        if collection is not None and unique_for_field is not None:
            raise TypeError("'collection' and 'unique_for_field' are incompatible arguments.")

        if collection is not None and unique_for_fields is not None:
            raise TypeError("'collection' and 'unique_for_fields' are incompatible arguments.")

        if unique_for_field is not None:
            warnings.warn("The 'unique_for_field' argument is deprecated. Please use 'collection' instead.", DeprecationWarning)
            if unique_for_fields is not None:
                raise TypeError("'unique_for_field' and 'unique_for_fields' are incompatible arguments.")
            collection = unique_for_field

        if unique_for_fields is not None:
            warnings.warn("The 'unique_for_fields' argument is deprecated. Please use 'collection' instead.", DeprecationWarning)
            collection = unique_for_fields
        # Backwards-compatibility mess ends here.

        # A single field name is normalised to a one-element tuple.
        if isinstance(collection, basestring):
            collection = (collection,)
        self.collection = collection
        self.parent_link = parent_link

    def get_cache_name(self):
        """Return the instance attribute holding the ``(current, updated)`` position pair."""
        return '_%s_cache' % self.name

    def _collection_changed_attname(self):
        """Return the instance attribute carrying the collection-changed flag from pre_save to post_save."""
        return '_%s_collection_changed' % self.name

    def contribute_to_class(self, cls, name, *args, **kwargs):
        """
        Attach the field to ``cls``, install it as the attribute descriptor and
        connect the signal handlers that keep sibling positions consistent.

        Extra arguments (e.g. ``private_only``) are forwarded to Django.

        :raises TypeError: if the field is part of a ``unique_together`` constraint.
        """
        super(PositionField, self).contribute_to_class(cls, name, *args, **kwargs)
        for constraint in cls._meta.unique_together:
            if self.name in constraint:
                raise TypeError("%s can't be part of a unique constraint." % self.__class__.__name__)
        # Fields with auto_now are stamped whenever positions are shifted in bulk.
        # NOTE(review): only fields already added to ``cls`` are visible at this
        # point, so an auto_now field declared after this one may be missed — confirm.
        self.auto_now_fields = [
            field for field in cls._meta.fields if getattr(field, 'auto_now', False)
        ]
        # Install the field itself as the class attribute so instance reads and
        # writes go through __get__/__set__ below.
        setattr(cls, self.name, self)
        pre_delete.connect(self.prepare_delete, sender=cls)
        post_delete.connect(self.update_on_delete, sender=cls)
        post_save.connect(self.update_on_save, sender=cls)

    def _collection_differs(self, instance, previous_instance):
        """Return True if any collection field value differs between ``instance`` and ``previous_instance``."""
        for field_name in self.collection:
            field = instance._meta.get_field(field_name)
            current_field_value = getattr(instance, field.attname)
            previous_field_value = getattr(previous_instance, field.attname)
            if previous_field_value != current_field_value:
                return True
        return False

    def pre_save(self, model_instance, add):
        """
        Return the position value to store for ``model_instance``.

        When a collection field changed, the instance is first removed from its
        old collection. The requested position (possibly negative or out of
        range) is then resolved against the size of the target collection. What
        ``update_on_save`` must shift afterwards is left in the position cache.
        """
        cache_name = self.get_cache_name()

        # Check whether the instance moved to another collection; if it has,
        # close the gap it leaves in the old one.
        previous_instance = None
        collection_changed = False
        if not add and self.collection is not None:
            try:
                previous_instance = type(model_instance)._default_manager.get(pk=model_instance.pk)
                collection_changed = self._collection_differs(model_instance, previous_instance)
            except models.ObjectDoesNotExist:
                # Primary key is set but no row exists yet: treat as an insert.
                add = True

        # Stored on the instance being saved (not on this field object, which is
        # shared by every instance of the model) for update_on_save to consume.
        setattr(model_instance, self._collection_changed_attname(), collection_changed)
        if collection_changed:
            self.remove_from_collection(previous_instance)

        current, updated = getattr(model_instance, cache_name)

        if collection_changed:
            # The old position has no meaning in the new collection.
            current = None

        if add:
            # For an unsaved instance the cached "current" value is the
            # position it was constructed with, i.e. the requested one.
            if updated is None:
                updated = current
            current = None

        # Existing instance, position not modified: no cleanup required.
        if current is not None and updated is None:
            return current

        # Position still unknown (new object or collection changed): append.
        if updated is None:
            updated = -1

        collection_count = self.get_collection(model_instance).count()
        # New or moved instances (current is None) are not counted in
        # collection_count, so they may also take the slot one past the end.
        if current is None:
            max_position = collection_count
        else:
            max_position = collection_count - 1
        min_position = 0

        # New instance appended at the end: no cleanup required on post_save.
        if add and (updated == -1 or updated >= max_position):
            setattr(model_instance, cache_name, (max_position, None))
            return max_position

        if max_position >= updated >= min_position:
            # Non-negative, valid index.
            position = updated
        elif updated > max_position:
            # Past the end: clamp to the last slot.
            position = max_position
        elif abs(updated) <= (max_position + 1):
            # Negative, valid index. Adding 1 to max_position makes -1 mean the
            # last position, like a Python list index.
            position = max_position + 1 + updated
        else:
            # Negative index before the start: clamp to the first slot.
            position = min_position

        # Inserted or moved: update_on_save must shift the other members.
        setattr(model_instance, cache_name, (current, position))
        return position

    def __get__(self, instance, owner):
        """
        Return the pending position if one has been assigned, else the current one.

        :raises AttributeError: when accessed on the class instead of an instance.
        """
        if instance is None:
            raise AttributeError("%s must be accessed via instance." % self.name)
        current, updated = getattr(instance, self.get_cache_name())
        return current if updated is None else updated

    def __set__(self, instance, value):
        """
        Record an assigned position in the instance's position cache.

        ``None`` is replaced by the field default. The first assignment on an
        instance (no cache yet) becomes the current position; later assignments
        become the pending ``updated`` position applied by the next save.
        """
        if instance is None:
            raise AttributeError("%s must be accessed via instance." % self.name)
        if value is None:
            value = self.default
        cache_name = self.get_cache_name()
        try:
            current, updated = getattr(instance, cache_name)
        except AttributeError:
            current, updated = value, None
        else:
            updated = value

        instance.__dict__[self.name] = value  # Django 1.10 fix for deferred fields
        setattr(instance, cache_name, (current, updated))

    @staticmethod
    def _related_model(relation):
        """Return the model that the relation field ``relation`` points to."""
        remote_field = getattr(relation, 'remote_field', None)
        if remote_field is not None:
            return remote_field.model  # Django >= 1.9
        # Older Django only: ``Field.rel`` no longer exists in Django >= 2.0.
        return relation.rel.to

    def get_collection(self, instance):
        """
        Return a queryset of every row in the same collection as ``instance``.

        Each collection field is matched on its value, or with ``__isnull`` when
        a nullable field holds None. With ``parent_link`` set, the queryset is
        built on the model that relation points to instead of ``type(instance)``.
        """
        filters = {}
        if self.collection is not None:
            for field_name in self.collection:
                field = instance._meta.get_field(field_name)
                field_value = getattr(instance, field.attname)
                if field.null and field_value is None:
                    filters['%s__isnull' % field.name] = True
                else:
                    filters[field.name] = field_value
        model = type(instance)
        if self.parent_link is not None:
            model = self._related_model(model._meta.get_field(self.parent_link))
        return model._default_manager.filter(**filters)

    def get_next_sibling(self, instance):
        """
        Returns the next sibling of this instance.

        That is the collection member with the lowest position greater than the
        instance's cached current position. Returns None when there is no such
        member, or when the instance has no cached current position.
        """
        try:
            current = getattr(instance, self.get_cache_name())[0]
        except AttributeError:
            return None
        if current is None:
            return None
        siblings = self.get_collection(instance).filter(**{'%s__gt' % self.name: current})
        # Order explicitly; otherwise which later member comes first depends on
        # the model's default ordering or the database.
        return siblings.order_by(self.name).first()

    def _auto_now_updates(self):
        """Return an ``update()`` kwargs dict stamping every auto_now field with the current time."""
        if not self.auto_now_fields:
            return {}
        right_now = now()
        return {field.name: right_now for field in self.auto_now_fields}

    def _close_gap(self, queryset, position):
        """Decrement the position of every row in ``queryset`` positioned after ``position``."""
        updates = self._auto_now_updates()
        updates[self.name] = models.F(self.name) - 1
        queryset.filter(**{'%s__gt' % self.name: position}).update(**updates)

    def remove_from_collection(self, instance):
        """
        Removes a positioned item from the collection.

        Every member of ``instance``'s collection after the instance's cached
        current position moves down by one.
        """
        current = getattr(instance, self.get_cache_name())[0]
        self._close_gap(self.get_collection(instance), current)

    def prepare_delete(self, sender, instance, **kwargs):
        """
        pre_delete handler: remember the pk of the instance's next sibling.

        update_on_delete derives the collection from this surviving sibling
        rather than from the deleted instance.
        """
        next_sibling = self.get_next_sibling(instance)
        if next_sibling is not None:
            setattr(instance, '_next_sibling_pk', next_sibling.pk)
        else:
            setattr(instance, '_next_sibling_pk', None)

    def update_on_delete(self, sender, instance, **kwargs):
        """
        post_delete handler: close the gap left by the deleted ``instance``.

        Re-fetches the sibling recorded by prepare_delete and decrements every
        member of that sibling's collection positioned after the deleted
        instance. Nothing happens if no sibling was recorded or it no longer
        exists.
        """
        next_sibling_pk = getattr(instance, '_next_sibling_pk', None)
        if next_sibling_pk is not None:
            try:
                next_sibling = type(instance)._default_manager.get(pk=next_sibling_pk)
            except models.ObjectDoesNotExist:
                next_sibling = None
            if next_sibling is not None:
                current = getattr(instance, self.get_cache_name())[0]
                self._close_gap(self.get_collection(next_sibling), current)
        setattr(instance, '_next_sibling_pk', None)

    def update_on_save(self, sender, instance, created, **kwargs):
        """
        post_save handler: shift the other collection members around ``instance``.

        Uses the ``(current, updated)`` pair and the collection-changed flag
        that pre_save left on the instance:

        * created, or moved from another collection: increment members at or
          after ``updated``;
        * moved later in the same collection: decrement members in
          ``(current, updated]``;
        * otherwise: increment members in ``[updated, current)``.

        Finally records ``updated`` as the instance's current position.
        """
        # Consume the flag so it cannot influence a later save of this instance.
        collection_changed = instance.__dict__.pop(self._collection_changed_attname(), None)

        current, updated = getattr(instance, self.get_cache_name())

        if current is None:
            current = 0

        if updated is None and not collection_changed:
            return None

        queryset = self.get_collection(instance).exclude(pk=instance.pk)

        updates = self._auto_now_updates()

        if updated is None and created:
            updated = -1

        if created or collection_changed:
            # Increment positions gte updated (new node, or node moved from another collection).
            queryset = queryset.filter(**{'%s__gte' % self.name: updated})
            updates[self.name] = models.F(self.name) + 1
        elif updated > current:
            # Decrement positions gt current and lte updated.
            queryset = queryset.filter(**{'%s__gt' % self.name: current, '%s__lte' % self.name: updated})
            updates[self.name] = models.F(self.name) - 1
        else:
            # Increment positions lt current and gte updated.
            queryset = queryset.filter(**{'%s__lt' % self.name: current, '%s__gte' % self.name: updated})
            updates[self.name] = models.F(self.name) + 1

        queryset.update(**updates)
        setattr(instance, self.get_cache_name(), (updated, None))
Back to Top