Skip to content

Commit 7a775a4

Browse files
committed
Do not persist the context in validators
Fixes #5760
1 parent 4d57d46 commit 7a775a4

File tree

5 files changed

+95
-82
lines changed

5 files changed

+95
-82
lines changed

docs/api-guide/fields.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ If set, this gives the default value that will be used for the field if no input
4747

4848
The `default` is not applied during partial update operations. In the partial update case only fields that are provided in the incoming data will have a validated value returned.
4949

50-
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument. This works the same way as for [validators](validators.md#using-set_context).
50+
May be set to a function or other callable, in which case the value will be evaluated each time it is used. When called, it will receive no arguments. If the callable has a `set_context` method, that will be called each time before getting the value with the field instance as only argument.
5151

5252
When serializing the instance, default will be used if the the object attribute or dictionary key is not present in the instance.
5353

docs/api-guide/validators.md

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -290,13 +290,18 @@ To write a class-based validator, use the `__call__` method. Class-based validat
290290
message = 'This field must be a multiple of %d.' % self.base
291291
raise serializers.ValidationError(message)
292292

293-
#### Using `set_context()`
293+
#### Accessing the context
294294

295-
In some advanced cases you might want a validator to be passed the serializer field it is being used with as additional context. You can do so by declaring a `set_context` method on a class-based validator.
295+
In some advanced cases you might want a validator to be passed the serializer
296+
field it is being used with as additional context. You can do so by using
297+
`rest_framework.validators.ContextBasedValidator` as a base class for the
298+
validator. The `__call__` method will then be called with the `serializer_field`
299+
or `serializer` as an additional argument.
296300

297-
def set_context(self, serializer_field):
301+
def __call__(self, value, serializer_field):
298302
# Determine if this is an update or a create operation.
299-
# In `__call__` we can then use that information to modify the validation behavior.
300-
self.is_update = serializer_field.parent.instance is not None
303+
is_update = serializer_field.parent.instance is not None
304+
305+
pass # implementation of the validator that uses `is_update`
301306

302307
[cite]: https://docs.djangoproject.com/en/stable/ref/validators/

rest_framework/fields.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import inspect
99
import re
1010
import uuid
11+
import warnings
1112
from collections import OrderedDict
1213

1314
from django.conf import settings
@@ -541,13 +542,25 @@ def run_validators(self, value):
541542
Test the given value against all the validators on the field,
542543
and either raise a `ValidationError` or simply return.
543544
"""
545+
from rest_framework.validators import ContextBasedValidator
546+
544547
errors = []
545548
for validator in self.validators:
546549
if hasattr(validator, 'set_context'):
550+
warnings.warn(
551+
"Method `set_context` on validators is deprecated and will "
552+
"no longer be called starting with 3.10. Instead derive the "
553+
"validator from `rest_framwork.validators.ContextBasedValidator` "
554+
"and accept the context as additional argument.",
555+
DeprecationWarning, stacklevel=2
556+
)
547557
validator.set_context(self)
548558

549559
try:
550-
validator(value)
560+
if isinstance(validator, ContextBasedValidator):
561+
validator(value, self)
562+
else:
563+
validator(value)
551564
except ValidationError as exc:
552565
# If the validation error contains a mapping of fields to
553566
# errors then simply raise it immediately rather than

rest_framework/validators.py

Lines changed: 66 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,16 @@ def qs_filter(queryset, **kwargs):
3333
return queryset.none()
3434

3535

36-
class UniqueValidator(object):
36+
class ContextBasedValidator(object):
37+
"""Base class for validators that need a context during evaluation.
38+
39+
In extension to regular validators their `__call__` method must not only
40+
accept a value, but also an instance of a serializer.
41+
"""
42+
pass
43+
44+
45+
class UniqueValidator(ContextBasedValidator):
3746
"""
3847
Validator that corresponds to `unique=True` on a model field.
3948
@@ -47,37 +56,32 @@ def __init__(self, queryset, message=None, lookup='exact'):
4756
self.message = message or self.message
4857
self.lookup = lookup
4958

50-
def set_context(self, serializer_field):
51-
"""
52-
This hook is called by the serializer instance,
53-
prior to the validation call being made.
54-
"""
55-
# Determine the underlying model field name. This may not be the
56-
# same as the serializer field name if `source=<>` is set.
57-
self.field_name = serializer_field.source_attrs[-1]
58-
# Determine the existing instance, if this is an update operation.
59-
self.instance = getattr(serializer_field.parent, 'instance', None)
60-
61-
def filter_queryset(self, value, queryset):
59+
def filter_queryset(self, value, queryset, field_name):
6260
"""
6361
Filter the queryset to all instances matching the given attribute.
6462
"""
65-
filter_kwargs = {'%s__%s' % (self.field_name, self.lookup): value}
63+
filter_kwargs = {'%s__%s' % (field_name, self.lookup): value}
6664
return qs_filter(queryset, **filter_kwargs)
6765

68-
def exclude_current_instance(self, queryset):
66+
def exclude_current_instance(self, queryset, instance):
6967
"""
7068
If an instance is being updated, then do not include
7169
that instance itself as a uniqueness conflict.
7270
"""
73-
if self.instance is not None:
74-
return queryset.exclude(pk=self.instance.pk)
71+
if instance is not None:
72+
return queryset.exclude(pk=instance.pk)
7573
return queryset
7674

77-
def __call__(self, value):
75+
def __call__(self, value, serializer_field):
76+
# Determine the underlying model field name. This may not be the
77+
# same as the serializer field name if `source=<>` is set.
78+
field_name = serializer_field.source_attrs[-1]
79+
# Determine the existing instance, if this is an update operation.
80+
instance = getattr(serializer_field.parent, 'instance', None)
81+
7882
queryset = self.queryset
79-
queryset = self.filter_queryset(value, queryset)
80-
queryset = self.exclude_current_instance(queryset)
83+
queryset = self.filter_queryset(value, queryset, field_name)
84+
queryset = self.exclude_current_instance(queryset, instance)
8185
if qs_exists(queryset):
8286
raise ValidationError(self.message, code='unique')
8387

@@ -88,7 +92,7 @@ def __repr__(self):
8892
))
8993

9094

91-
class UniqueTogetherValidator(object):
95+
class UniqueTogetherValidator(ContextBasedValidator):
9296
"""
9397
Validator that corresponds to `unique_together = (...)` on a model class.
9498
@@ -103,20 +107,12 @@ def __init__(self, queryset, fields, message=None):
103107
self.serializer_field = None
104108
self.message = message or self.message
105109

106-
def set_context(self, serializer):
107-
"""
108-
This hook is called by the serializer instance,
109-
prior to the validation call being made.
110-
"""
111-
# Determine the existing instance, if this is an update operation.
112-
self.instance = getattr(serializer, 'instance', None)
113-
114-
def enforce_required_fields(self, attrs):
110+
def enforce_required_fields(self, attrs, instance):
115111
"""
116112
The `UniqueTogetherValidator` always forces an implied 'required'
117113
state on the fields it applies to.
118114
"""
119-
if self.instance is not None:
115+
if instance is not None:
120116
return
121117

122118
missing_items = {
@@ -127,16 +123,16 @@ def enforce_required_fields(self, attrs):
127123
if missing_items:
128124
raise ValidationError(missing_items, code='required')
129125

130-
def filter_queryset(self, attrs, queryset):
126+
def filter_queryset(self, attrs, queryset, instance):
131127
"""
132128
Filter the queryset to all instances matching the given attributes.
133129
"""
134130
# If this is an update, then any unprovided field should
135131
# have it's value set based on the existing instance attribute.
136-
if self.instance is not None:
132+
if instance is not None:
137133
for field_name in self.fields:
138134
if field_name not in attrs:
139-
attrs[field_name] = getattr(self.instance, field_name)
135+
attrs[field_name] = getattr(instance, field_name)
140136

141137
# Determine the filter keyword arguments and filter the queryset.
142138
filter_kwargs = {
@@ -145,20 +141,23 @@ def filter_queryset(self, attrs, queryset):
145141
}
146142
return qs_filter(queryset, **filter_kwargs)
147143

148-
def exclude_current_instance(self, attrs, queryset):
144+
def exclude_current_instance(self, attrs, queryset, instance):
149145
"""
150146
If an instance is being updated, then do not include
151147
that instance itself as a uniqueness conflict.
152148
"""
153-
if self.instance is not None:
154-
return queryset.exclude(pk=self.instance.pk)
149+
if instance is not None:
150+
return queryset.exclude(pk=instance.pk)
155151
return queryset
156152

157-
def __call__(self, attrs):
158-
self.enforce_required_fields(attrs)
153+
def __call__(self, attrs, serializer):
154+
# Determine the existing instance, if this is an update operation.
155+
instance = getattr(serializer, 'instance', None)
156+
157+
self.enforce_required_fields(attrs, instance)
159158
queryset = self.queryset
160-
queryset = self.filter_queryset(attrs, queryset)
161-
queryset = self.exclude_current_instance(attrs, queryset)
159+
queryset = self.filter_queryset(attrs, queryset, instance)
160+
queryset = self.exclude_current_instance(attrs, queryset, instance)
162161

163162
# Ignore validation if any field is None
164163
checked_values = [
@@ -177,7 +176,7 @@ def __repr__(self):
177176
))
178177

179178

180-
class BaseUniqueForValidator(object):
179+
class BaseUniqueForValidator(ContextBasedValidator):
181180
message = None
182181
missing_message = _('This field is required.')
183182

@@ -187,18 +186,6 @@ def __init__(self, queryset, field, date_field, message=None):
187186
self.date_field = date_field
188187
self.message = message or self.message
189188

190-
def set_context(self, serializer):
191-
"""
192-
This hook is called by the serializer instance,
193-
prior to the validation call being made.
194-
"""
195-
# Determine the underlying model field names. These may not be the
196-
# same as the serializer field names if `source=<>` is set.
197-
self.field_name = serializer.fields[self.field].source_attrs[-1]
198-
self.date_field_name = serializer.fields[self.date_field].source_attrs[-1]
199-
# Determine the existing instance, if this is an update operation.
200-
self.instance = getattr(serializer, 'instance', None)
201-
202189
def enforce_required_fields(self, attrs):
203190
"""
204191
The `UniqueFor<Range>Validator` classes always force an implied
@@ -212,23 +199,30 @@ def enforce_required_fields(self, attrs):
212199
if missing_items:
213200
raise ValidationError(missing_items, code='required')
214201

215-
def filter_queryset(self, attrs, queryset):
202+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
216203
raise NotImplementedError('`filter_queryset` must be implemented.')
217204

218-
def exclude_current_instance(self, attrs, queryset):
205+
def exclude_current_instance(self, attrs, queryset, instance):
219206
"""
220207
If an instance is being updated, then do not include
221208
that instance itself as a uniqueness conflict.
222209
"""
223-
if self.instance is not None:
224-
return queryset.exclude(pk=self.instance.pk)
210+
if instance is not None:
211+
return queryset.exclude(pk=instance.pk)
225212
return queryset
226213

227-
def __call__(self, attrs):
214+
def __call__(self, attrs, serializer):
215+
# Determine the underlying model field names. These may not be the
216+
# same as the serializer field names if `source=<>` is set.
217+
field_name = serializer.fields[self.field].source_attrs[-1]
218+
date_field_name = serializer.fields[self.date_field].source_attrs[-1]
219+
# Determine the existing instance, if this is an update operation.
220+
instance = getattr(serializer, 'instance', None)
221+
228222
self.enforce_required_fields(attrs)
229223
queryset = self.queryset
230-
queryset = self.filter_queryset(attrs, queryset)
231-
queryset = self.exclude_current_instance(attrs, queryset)
224+
queryset = self.filter_queryset(attrs, queryset, field_name, date_field_name)
225+
queryset = self.exclude_current_instance(attrs, queryset, instance)
232226
if qs_exists(queryset):
233227
message = self.message.format(date_field=self.date_field)
234228
raise ValidationError({
@@ -247,39 +241,39 @@ def __repr__(self):
247241
class UniqueForDateValidator(BaseUniqueForValidator):
248242
message = _('This field must be unique for the "{date_field}" date.')
249243

250-
def filter_queryset(self, attrs, queryset):
244+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
251245
value = attrs[self.field]
252246
date = attrs[self.date_field]
253247

254248
filter_kwargs = {}
255-
filter_kwargs[self.field_name] = value
256-
filter_kwargs['%s__day' % self.date_field_name] = date.day
257-
filter_kwargs['%s__month' % self.date_field_name] = date.month
258-
filter_kwargs['%s__year' % self.date_field_name] = date.year
249+
filter_kwargs[field_name] = value
250+
filter_kwargs['%s__day' % date_field_name] = date.day
251+
filter_kwargs['%s__month' % date_field_name] = date.month
252+
filter_kwargs['%s__year' % date_field_name] = date.year
259253
return qs_filter(queryset, **filter_kwargs)
260254

261255

262256
class UniqueForMonthValidator(BaseUniqueForValidator):
263257
message = _('This field must be unique for the "{date_field}" month.')
264258

265-
def filter_queryset(self, attrs, queryset):
259+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
266260
value = attrs[self.field]
267261
date = attrs[self.date_field]
268262

269263
filter_kwargs = {}
270-
filter_kwargs[self.field_name] = value
271-
filter_kwargs['%s__month' % self.date_field_name] = date.month
264+
filter_kwargs[field_name] = value
265+
filter_kwargs['%s__month' % date_field_name] = date.month
272266
return qs_filter(queryset, **filter_kwargs)
273267

274268

275269
class UniqueForYearValidator(BaseUniqueForValidator):
276270
message = _('This field must be unique for the "{date_field}" year.')
277271

278-
def filter_queryset(self, attrs, queryset):
272+
def filter_queryset(self, attrs, queryset, field_name, date_field_name):
279273
value = attrs[self.field]
280274
date = attrs[self.date_field]
281275

282276
filter_kwargs = {}
283-
filter_kwargs[self.field_name] = value
284-
filter_kwargs['%s__year' % self.date_field_name] = date.year
277+
filter_kwargs[field_name] = value
278+
filter_kwargs['%s__year' % date_field_name] = date.year
285279
return qs_filter(queryset, **filter_kwargs)

tests/test_validators.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -361,8 +361,7 @@ def filter(self, **kwargs):
361361
queryset = MockQueryset()
362362
validator = UniqueTogetherValidator(queryset, fields=('race_name',
363363
'position'))
364-
validator.instance = self.instance
365-
validator.filter_queryset(attrs=data, queryset=queryset)
364+
validator.filter_queryset(attrs=data, queryset=queryset, instance=self.instance)
366365
assert queryset.called_with == {'race_name': 'bar', 'position': 1}
367366

368367

@@ -586,4 +585,6 @@ def test_validator_raises_error_when_abstract_method_called(self):
586585
validator = BaseUniqueForValidator(queryset=object(), field='foo',
587586
date_field='bar')
588587
with pytest.raises(NotImplementedError):
589-
validator.filter_queryset(attrs=None, queryset=None)
588+
validator.filter_queryset(
589+
attrs=None, queryset=None, field_name='', date_field_name=''
590+
)

0 commit comments

Comments
 (0)