Ticket #9200: ticket9200-1.diff

File ticket9200-1.diff, 138.4 KB (added by Jannis Leidel, 14 years ago)

Updated patch (including changes for secure cookie from #12417)

  • django/conf/global_settings.py

    diff --git a/django/conf/global_settings.py b/django/conf/global_settings.py
    index 88aa5a3..c98cab7 100644
    a b LOGIN_REDIRECT_URL = '/accounts/profile/'  
    476476# The number of days a password reset link is valid for
    477477PASSWORD_RESET_TIMEOUT_DAYS = 3
    478478
     479###########
     480# SIGNING #
     481###########
     482
     483SIGNING_BACKEND = 'django.core.signing.TimestampSigner'
     484
    479485########
    480486# CSRF #
    481487########
  • django/contrib/formtools/tests/__init__.py

    diff --git a/django/contrib/formtools/tests/__init__.py b/django/contrib/formtools/tests/__init__.py
    index be0372a..911bc65 100644
    a b from django.test import TestCase  
    88from django.test.utils import get_warnings_state, restore_warnings_state
    99from django.utils import unittest
    1010
     11from django.contrib.formtools.wizard.tests import *
     12
    1113
    1214success_string = "Done was called!"
    1315
  • deleted file django/contrib/formtools/wizard.py

    diff --git a/django/contrib/formtools/wizard.py b/django/contrib/formtools/wizard.py
    deleted file mode 100644
    index c19578c..0000000
    + -  
    1 """
    2 FormWizard class -- implements a multi-page form, validating between each
    3 step and storing the form's state as HTML hidden fields so that no state is
    4 stored on the server side.
    5 """
    6 
    7 try:
    8     import cPickle as pickle
    9 except ImportError:
    10     import pickle
    11 
    12 from django import forms
    13 from django.conf import settings
    14 from django.contrib.formtools.utils import form_hmac
    15 from django.http import Http404
    16 from django.shortcuts import render_to_response
    17 from django.template.context import RequestContext
    18 from django.utils.crypto import constant_time_compare
    19 from django.utils.translation import ugettext_lazy as _
    20 from django.utils.decorators import method_decorator
    21 from django.views.decorators.csrf import csrf_protect
    22 
    23 
    24 class FormWizard(object):
    25     # The HTML (and POST data) field name for the "step" variable.
    26     step_field_name="wizard_step"
    27 
    28     # METHODS SUBCLASSES SHOULDN'T OVERRIDE ###################################
    29 
    30     def __init__(self, form_list, initial=None):
    31         """
    32         Start a new wizard with a list of forms.
    33 
    34         form_list should be a list of Form classes (not instances).
    35         """
    36         self.form_list = form_list[:]
    37         self.initial = initial or {}
    38 
    39         # Dictionary of extra template context variables.
    40         self.extra_context = {}
    41 
    42         # A zero-based counter keeping track of which step we're in.
    43         self.step = 0
    44 
    45     def __repr__(self):
    46         return "step: %d\nform_list: %s\ninitial_data: %s" % (self.step, self.form_list, self.initial)
    47 
    48     def get_form(self, step, data=None):
    49         "Helper method that returns the Form instance for the given step."
    50         # Sanity check.
    51         if step >= self.num_steps():
    52             raise Http404('Step %s does not exist' % step)
    53         return self.form_list[step](data, prefix=self.prefix_for_step(step), initial=self.initial.get(step, None))
    54 
    55     def num_steps(self):
    56         "Helper method that returns the number of steps."
    57         # You might think we should just set "self.num_steps = len(form_list)"
    58         # in __init__(), but this calculation needs to be dynamic, because some
    59         # hook methods might alter self.form_list.
    60         return len(self.form_list)
    61 
    62     def _check_security_hash(self, token, request, form):
    63         expected = self.security_hash(request, form)
    64         return constant_time_compare(token, expected)
    65 
    66     @method_decorator(csrf_protect)
    67     def __call__(self, request, *args, **kwargs):
    68         """
    69         Main method that does all the hard work, conforming to the Django view
    70         interface.
    71         """
    72         if 'extra_context' in kwargs:
    73             self.extra_context.update(kwargs['extra_context'])
    74         current_step = self.determine_step(request, *args, **kwargs)
    75         self.parse_params(request, *args, **kwargs)
    76 
    77         # Validate and process all the previous forms before instantiating the
    78         # current step's form in case self.process_step makes changes to
    79         # self.form_list.
    80 
    81         # If any of them fails validation, that must mean the validator relied
    82         # on some other input, such as an external Web site.
    83 
    84         # It is also possible that alidation might fail under certain attack
    85         # situations: an attacker might be able to bypass previous stages, and
    86         # generate correct security hashes for all the skipped stages by virtue
    87         # of:
    88         #  1) having filled out an identical form which doesn't have the
    89         #     validation (and does something different at the end),
    90         #  2) or having filled out a previous version of the same form which
    91         #     had some validation missing,
    92         #  3) or previously having filled out the form when they had more
    93         #     privileges than they do now.
    94         #
    95         # Since the hashes only take into account values, and not other other
    96         # validation the form might do, we must re-do validation now for
    97         # security reasons.
    98         previous_form_list = []
    99         for i in range(current_step):
    100             f = self.get_form(i, request.POST)
    101             if not self._check_security_hash(request.POST.get("hash_%d" % i, ''),
    102                                              request, f):
    103                 return self.render_hash_failure(request, i)
    104 
    105             if not f.is_valid():
    106                 return self.render_revalidation_failure(request, i, f)
    107             else:
    108                 self.process_step(request, f, i)
    109                 previous_form_list.append(f)
    110 
    111         # Process the current step. If it's valid, go to the next step or call
    112         # done(), depending on whether any steps remain.
    113         if request.method == 'POST':
    114             form = self.get_form(current_step, request.POST)
    115         else:
    116             form = self.get_form(current_step)
    117 
    118         if form.is_valid():
    119             self.process_step(request, form, current_step)
    120             next_step = current_step + 1
    121 
    122             if next_step == self.num_steps():
    123                 return self.done(request, previous_form_list + [form])
    124             else:
    125                 form = self.get_form(next_step)
    126                 self.step = current_step = next_step
    127 
    128         return self.render(form, request, current_step)
    129 
    130     def render(self, form, request, step, context=None):
    131         "Renders the given Form object, returning an HttpResponse."
    132         old_data = request.POST
    133         prev_fields = []
    134         if old_data:
    135             hidden = forms.HiddenInput()
    136             # Collect all data from previous steps and render it as HTML hidden fields.
    137             for i in range(step):
    138                 old_form = self.get_form(i, old_data)
    139                 hash_name = 'hash_%s' % i
    140                 prev_fields.extend([bf.as_hidden() for bf in old_form])
    141                 prev_fields.append(hidden.render(hash_name, old_data.get(hash_name, self.security_hash(request, old_form))))
    142         return self.render_template(request, form, ''.join(prev_fields), step, context)
    143 
    144     # METHODS SUBCLASSES MIGHT OVERRIDE IF APPROPRIATE ########################
    145 
    146     def prefix_for_step(self, step):
    147         "Given the step, returns a Form prefix to use."
    148         return str(step)
    149 
    150     def render_hash_failure(self, request, step):
    151         """
    152         Hook for rendering a template if a hash check failed.
    153 
    154         step is the step that failed. Any previous step is guaranteed to be
    155         valid.
    156 
    157         This default implementation simply renders the form for the given step,
    158         but subclasses may want to display an error message, etc.
    159         """
    160         return self.render(self.get_form(step), request, step, context={'wizard_error': _('We apologize, but your form has expired. Please continue filling out the form from this page.')})
    161 
    162     def render_revalidation_failure(self, request, step, form):
    163         """
    164         Hook for rendering a template if final revalidation failed.
    165 
    166         It is highly unlikely that this point would ever be reached, but See
    167         the comment in __call__() for an explanation.
    168         """
    169         return self.render(form, request, step)
    170 
    171     def security_hash(self, request, form):
    172         """
    173         Calculates the security hash for the given HttpRequest and Form instances.
    174 
    175         Subclasses may want to take into account request-specific information,
    176         such as the IP address.
    177         """
    178         return form_hmac(form)
    179 
    180     def determine_step(self, request, *args, **kwargs):
    181         """
    182         Given the request object and whatever *args and **kwargs were passed to
    183         __call__(), returns the current step (which is zero-based).
    184 
    185         Note that the result should not be trusted. It may even be a completely
    186         invalid number. It's not the job of this method to validate it.
    187         """
    188         if not request.POST:
    189             return 0
    190         try:
    191             step = int(request.POST.get(self.step_field_name, 0))
    192         except ValueError:
    193             return 0
    194         return step
    195 
    196     def parse_params(self, request, *args, **kwargs):
    197         """
    198         Hook for setting some state, given the request object and whatever
    199         *args and **kwargs were passed to __call__(), sets some state.
    200 
    201         This is called at the beginning of __call__().
    202         """
    203         pass
    204 
    205     def get_template(self, step):
    206         """
    207         Hook for specifying the name of the template to use for a given step.
    208 
    209         Note that this can return a tuple of template names if you'd like to
    210         use the template system's select_template() hook.
    211         """
    212         return 'forms/wizard.html'
    213 
    214     def render_template(self, request, form, previous_fields, step, context=None):
    215         """
    216         Renders the template for the given step, returning an HttpResponse object.
    217 
    218         Override this method if you want to add a custom context, return a
    219         different MIME type, etc. If you only need to override the template
    220         name, use get_template() instead.
    221 
    222         The template will be rendered with the following context:
    223             step_field -- The name of the hidden field containing the step.
    224             step0      -- The current step (zero-based).
    225             step       -- The current step (one-based).
    226             step_count -- The total number of steps.
    227             form       -- The Form instance for the current step (either empty
    228                           or with errors).
    229             previous_fields -- A string representing every previous data field,
    230                           plus hashes for completed forms, all in the form of
    231                           hidden fields. Note that you'll need to run this
    232                           through the "safe" template filter, to prevent
    233                           auto-escaping, because it's raw HTML.
    234         """
    235         context = context or {}
    236         context.update(self.extra_context)
    237         return render_to_response(self.get_template(step), dict(context,
    238             step_field=self.step_field_name,
    239             step0=step,
    240             step=step + 1,
    241             step_count=self.num_steps(),
    242             form=form,
    243             previous_fields=previous_fields
    244         ), context_instance=RequestContext(request))
    245 
    246     def process_step(self, request, form, step):
    247         """
    248         Hook for modifying the FormWizard's internal state, given a fully
    249         validated Form object. The Form is guaranteed to have clean, valid
    250         data.
    251 
    252         This method should *not* modify any of that data. Rather, it might want
    253         to set self.extra_context or dynamically alter self.form_list, based on
    254         previously submitted forms.
    255 
    256         Note that this method is called every time a page is rendered for *all*
    257         submitted steps.
    258         """
    259         pass
    260 
    261     # METHODS SUBCLASSES MUST OVERRIDE ########################################
    262 
    263     def done(self, request, form_list):
    264         """
    265         Hook for doing something with the validated data. This is responsible
    266         for the final processing.
    267 
    268         form_list is a list of Form instances, each containing clean, valid
    269         data.
    270         """
    271         raise NotImplementedError("Your %s class has not defined a done() method, which is required." % self.__class__.__name__)
  • new file django/contrib/formtools/wizard/__init__.py

    diff --git a/django/contrib/formtools/wizard/__init__.py b/django/contrib/formtools/wizard/__init__.py
    new file mode 100644
    index 0000000..c19578c
    - +  
     1"""
     2FormWizard class -- implements a multi-page form, validating between each
     3step and storing the form's state as HTML hidden fields so that no state is
     4stored on the server side.
     5"""
     6
     7try:
     8    import cPickle as pickle
     9except ImportError:
     10    import pickle
     11
     12from django import forms
     13from django.conf import settings
     14from django.contrib.formtools.utils import form_hmac
     15from django.http import Http404
     16from django.shortcuts import render_to_response
     17from django.template.context import RequestContext
     18from django.utils.crypto import constant_time_compare
     19from django.utils.translation import ugettext_lazy as _
     20from django.utils.decorators import method_decorator
     21from django.views.decorators.csrf import csrf_protect
     22
     23
     24class FormWizard(object):
     25    # The HTML (and POST data) field name for the "step" variable.
     26    step_field_name="wizard_step"
     27
     28    # METHODS SUBCLASSES SHOULDN'T OVERRIDE ###################################
     29
     30    def __init__(self, form_list, initial=None):
     31        """
     32        Start a new wizard with a list of forms.
     33
     34        form_list should be a list of Form classes (not instances).
     35        """
     36        self.form_list = form_list[:]
     37        self.initial = initial or {}
     38
     39        # Dictionary of extra template context variables.
     40        self.extra_context = {}
     41
     42        # A zero-based counter keeping track of which step we're in.
     43        self.step = 0
     44
     45    def __repr__(self):
     46        return "step: %d\nform_list: %s\ninitial_data: %s" % (self.step, self.form_list, self.initial)
     47
     48    def get_form(self, step, data=None):
     49        "Helper method that returns the Form instance for the given step."
     50        # Sanity check.
     51        if step >= self.num_steps():
     52            raise Http404('Step %s does not exist' % step)
     53        return self.form_list[step](data, prefix=self.prefix_for_step(step), initial=self.initial.get(step, None))
     54
     55    def num_steps(self):
     56        "Helper method that returns the number of steps."
     57        # You might think we should just set "self.num_steps = len(form_list)"
     58        # in __init__(), but this calculation needs to be dynamic, because some
     59        # hook methods might alter self.form_list.
     60        return len(self.form_list)
     61
     62    def _check_security_hash(self, token, request, form):
     63        expected = self.security_hash(request, form)
     64        return constant_time_compare(token, expected)
     65
     66    @method_decorator(csrf_protect)
     67    def __call__(self, request, *args, **kwargs):
     68        """
     69        Main method that does all the hard work, conforming to the Django view
     70        interface.
     71        """
     72        if 'extra_context' in kwargs:
     73            self.extra_context.update(kwargs['extra_context'])
     74        current_step = self.determine_step(request, *args, **kwargs)
     75        self.parse_params(request, *args, **kwargs)
     76
     77        # Validate and process all the previous forms before instantiating the
     78        # current step's form in case self.process_step makes changes to
     79        # self.form_list.
     80
     81        # If any of them fails validation, that must mean the validator relied
     82        # on some other input, such as an external Web site.
     83
     84        # It is also possible that alidation might fail under certain attack
     85        # situations: an attacker might be able to bypass previous stages, and
     86        # generate correct security hashes for all the skipped stages by virtue
     87        # of:
     88        #  1) having filled out an identical form which doesn't have the
     89        #     validation (and does something different at the end),
     90        #  2) or having filled out a previous version of the same form which
     91        #     had some validation missing,
     92        #  3) or previously having filled out the form when they had more
     93        #     privileges than they do now.
     94        #
     95        # Since the hashes only take into account values, and not other other
     96        # validation the form might do, we must re-do validation now for
     97        # security reasons.
     98        previous_form_list = []
     99        for i in range(current_step):
     100            f = self.get_form(i, request.POST)
     101            if not self._check_security_hash(request.POST.get("hash_%d" % i, ''),
     102                                             request, f):
     103                return self.render_hash_failure(request, i)
     104
     105            if not f.is_valid():
     106                return self.render_revalidation_failure(request, i, f)
     107            else:
     108                self.process_step(request, f, i)
     109                previous_form_list.append(f)
     110
     111        # Process the current step. If it's valid, go to the next step or call
     112        # done(), depending on whether any steps remain.
     113        if request.method == 'POST':
     114            form = self.get_form(current_step, request.POST)
     115        else:
     116            form = self.get_form(current_step)
     117
     118        if form.is_valid():
     119            self.process_step(request, form, current_step)
     120            next_step = current_step + 1
     121
     122            if next_step == self.num_steps():
     123                return self.done(request, previous_form_list + [form])
     124            else:
     125                form = self.get_form(next_step)
     126                self.step = current_step = next_step
     127
     128        return self.render(form, request, current_step)
     129
     130    def render(self, form, request, step, context=None):
     131        "Renders the given Form object, returning an HttpResponse."
     132        old_data = request.POST
     133        prev_fields = []
     134        if old_data:
     135            hidden = forms.HiddenInput()
     136            # Collect all data from previous steps and render it as HTML hidden fields.
     137            for i in range(step):
     138                old_form = self.get_form(i, old_data)
     139                hash_name = 'hash_%s' % i
     140                prev_fields.extend([bf.as_hidden() for bf in old_form])
     141                prev_fields.append(hidden.render(hash_name, old_data.get(hash_name, self.security_hash(request, old_form))))
     142        return self.render_template(request, form, ''.join(prev_fields), step, context)
     143
     144    # METHODS SUBCLASSES MIGHT OVERRIDE IF APPROPRIATE ########################
     145
     146    def prefix_for_step(self, step):
     147        "Given the step, returns a Form prefix to use."
     148        return str(step)
     149
     150    def render_hash_failure(self, request, step):
     151        """
     152        Hook for rendering a template if a hash check failed.
     153
     154        step is the step that failed. Any previous step is guaranteed to be
     155        valid.
     156
     157        This default implementation simply renders the form for the given step,
     158        but subclasses may want to display an error message, etc.
     159        """
     160        return self.render(self.get_form(step), request, step, context={'wizard_error': _('We apologize, but your form has expired. Please continue filling out the form from this page.')})
     161
     162    def render_revalidation_failure(self, request, step, form):
     163        """
     164        Hook for rendering a template if final revalidation failed.
     165
     166        It is highly unlikely that this point would ever be reached, but See
     167        the comment in __call__() for an explanation.
     168        """
     169        return self.render(form, request, step)
     170
     171    def security_hash(self, request, form):
     172        """
     173        Calculates the security hash for the given HttpRequest and Form instances.
     174
     175        Subclasses may want to take into account request-specific information,
     176        such as the IP address.
     177        """
     178        return form_hmac(form)
     179
     180    def determine_step(self, request, *args, **kwargs):
     181        """
     182        Given the request object and whatever *args and **kwargs were passed to
     183        __call__(), returns the current step (which is zero-based).
     184
     185        Note that the result should not be trusted. It may even be a completely
     186        invalid number. It's not the job of this method to validate it.
     187        """
     188        if not request.POST:
     189            return 0
     190        try:
     191            step = int(request.POST.get(self.step_field_name, 0))
     192        except ValueError:
     193            return 0
     194        return step
     195
     196    def parse_params(self, request, *args, **kwargs):
     197        """
     198        Hook for setting some state, given the request object and whatever
     199        *args and **kwargs were passed to __call__(), sets some state.
     200
     201        This is called at the beginning of __call__().
     202        """
     203        pass
     204
     205    def get_template(self, step):
     206        """
     207        Hook for specifying the name of the template to use for a given step.
     208
     209        Note that this can return a tuple of template names if you'd like to
     210        use the template system's select_template() hook.
     211        """
     212        return 'forms/wizard.html'
     213
     214    def render_template(self, request, form, previous_fields, step, context=None):
     215        """
     216        Renders the template for the given step, returning an HttpResponse object.
     217
     218        Override this method if you want to add a custom context, return a
     219        different MIME type, etc. If you only need to override the template
     220        name, use get_template() instead.
     221
     222        The template will be rendered with the following context:
     223            step_field -- The name of the hidden field containing the step.
     224            step0      -- The current step (zero-based).
     225            step       -- The current step (one-based).
     226            step_count -- The total number of steps.
     227            form       -- The Form instance for the current step (either empty
     228                          or with errors).
     229            previous_fields -- A string representing every previous data field,
     230                          plus hashes for completed forms, all in the form of
     231                          hidden fields. Note that you'll need to run this
     232                          through the "safe" template filter, to prevent
     233                          auto-escaping, because it's raw HTML.
     234        """
     235        context = context or {}
     236        context.update(self.extra_context)
     237        return render_to_response(self.get_template(step), dict(context,
     238            step_field=self.step_field_name,
     239            step0=step,
     240            step=step + 1,
     241            step_count=self.num_steps(),
     242            form=form,
     243            previous_fields=previous_fields
     244        ), context_instance=RequestContext(request))
     245
     246    def process_step(self, request, form, step):
     247        """
     248        Hook for modifying the FormWizard's internal state, given a fully
     249        validated Form object. The Form is guaranteed to have clean, valid
     250        data.
     251
     252        This method should *not* modify any of that data. Rather, it might want
     253        to set self.extra_context or dynamically alter self.form_list, based on
     254        previously submitted forms.
     255
     256        Note that this method is called every time a page is rendered for *all*
     257        submitted steps.
     258        """
     259        pass
     260
     261    # METHODS SUBCLASSES MUST OVERRIDE ########################################
     262
     263    def done(self, request, form_list):
     264        """
     265        Hook for doing something with the validated data. This is responsible
     266        for the final processing.
     267
     268        form_list is a list of Form instances, each containing clean, valid
     269        data.
     270        """
     271        raise NotImplementedError("Your %s class has not defined a done() method, which is required." % self.__class__.__name__)
  • new file django/contrib/formtools/wizard/storage/__init__.py

    diff --git a/django/contrib/formtools/wizard/storage/__init__.py b/django/contrib/formtools/wizard/storage/__init__.py
    new file mode 100644
    index 0000000..7f03028
    - +  
     1from django.core.exceptions import ImproperlyConfigured
     2from django.utils.importlib import import_module
     3
     4from django.contrib.formtools.wizard.storage.base import BaseStorage
     5
     6class MissingStorageModule(ImproperlyConfigured):
     7    pass
     8
     9class MissingStorageClass(ImproperlyConfigured):
     10    pass
     11
     12class NoFileStorageConfigured(ImproperlyConfigured):
     13    pass
     14
     15def get_storage(path, *args, **kwargs):
     16    i = path.rfind('.')
     17    module, attr = path[:i], path[i+1:]
     18    try:
     19        mod = import_module(module)
     20    except ImportError, e:
     21        raise MissingStorageModule(
     22            'Error loading storage %s: "%s"' % (module, e))
     23    try:
     24        storage_class = getattr(mod, attr)
     25    except AttributeError:
     26        raise MissingStorageClass(
     27            'Module "%s" does not define a storage named "%s"' % (module, attr))
     28    return storage_class(*args, **kwargs)
     29
  • new file django/contrib/formtools/wizard/storage/base.py

    diff --git a/django/contrib/formtools/wizard/storage/base.py b/django/contrib/formtools/wizard/storage/base.py
    new file mode 100644
    index 0000000..0e9c677
    - +  
     1class BaseStorage(object):
     2    def __init__(self, prefix):
     3        self.prefix = 'wizard_%s' % prefix
     4
     5    def get_current_step(self):
     6        raise NotImplementedError
     7
     8    def set_current_step(self, step):
     9        raise NotImplementedError
     10
     11    def get_step_data(self, step):
     12        raise NotImplementedError
     13
     14    def get_current_step_data(self):
     15        raise NotImplementedError
     16
     17    def set_step_data(self, step, cleaned_data):
     18        raise NotImplementedError
     19
     20    def get_step_files(self, step):
     21        raise NotImplementedError
     22
     23    def set_step_files(self, step, files):
     24        raise NotImplementedError
     25
     26    def get_extra_context_data(self):
     27        raise NotImplementedError
     28
     29    def set_extra_context_data(self, extra_context):
     30        raise NotImplementedError
     31
     32    def reset(self):
     33        raise NotImplementedError
     34
     35    def update_response(self, response):
     36        raise NotImplementedError
     37
  • new file django/contrib/formtools/wizard/storage/cookie.py

    diff --git a/django/contrib/formtools/wizard/storage/cookie.py b/django/contrib/formtools/wizard/storage/cookie.py
    new file mode 100644
    index 0000000..f11cd15
    - +  
     1from django.core.exceptions import SuspiciousOperation
     2from django.core.signing import BadSignature
     3from django.core.files.uploadedfile import UploadedFile
     4from django.utils import simplejson as json
     5
     6from django.contrib.formtools.wizard.storage import (BaseStorage,
     7                                                     NoFileStorageConfigured)
     8
     9class CookieStorage(BaseStorage):
     10    step_cookie_key = 'step'
     11    step_data_cookie_key = 'step_data'
     12    step_files_cookie_key = 'step_files'
     13    extra_context_cookie_key = 'extra_context'
     14
     15    def __init__(self, prefix, request, file_storage, *args, **kwargs):
     16        super(CookieStorage, self).__init__(prefix)
     17        self.file_storage = file_storage
     18        self.request = request
     19        self.cookie_data = self.load_cookie_data()
     20        if self.cookie_data is None:
     21            self.init_storage()
     22
     23    def init_storage(self):
     24        self.cookie_data = {
     25            self.step_cookie_key: None,
     26            self.step_data_cookie_key: {},
     27            self.step_files_cookie_key: {},
     28            self.extra_context_cookie_key: {},
     29        }
     30        return True
     31
     32    def get_current_step(self):
     33        return self.cookie_data[self.step_cookie_key]
     34
     35    def set_current_step(self, step):
     36        self.cookie_data[self.step_cookie_key] = step
     37        return True
     38
     39    def get_step_data(self, step):
     40        return self.cookie_data[self.step_data_cookie_key].get(step, None)
     41
     42    def get_current_step_data(self):
     43        return self.get_step_data(self.get_current_step())
     44
     45    def set_step_data(self, step, cleaned_data):
     46        self.cookie_data[self.step_data_cookie_key][step] = cleaned_data
     47        return True
     48
     49    def set_step_files(self, step, files):
     50        if files and not self.file_storage:
     51            raise NoFileStorageConfigured
     52
     53        if step not in self.cookie_data[self.step_files_cookie_key]:
     54            self.cookie_data[self.step_files_cookie_key][step] = {}
     55
     56        for field, field_file in (files or {}).items():
     57            tmp_filename = self.file_storage.save(field_file.name, field_file)
     58            file_dict = {
     59                'tmp_name': tmp_filename,
     60                'name': field_file.name,
     61                'content_type': field_file.content_type,
     62                'size': field_file.size,
     63                'charset': field_file.charset
     64            }
     65            self.cookie_data[self.step_files_cookie_key][step][field] = file_dict
     66
     67        return True
     68
     69    def get_current_step_files(self):
     70        return self.get_step_files(self.get_current_step())
     71
     72    def get_step_files(self, step):
     73        session_files = self.cookie_data[self.step_files_cookie_key].get(step, {})
     74
     75        if session_files and not self.file_storage:
     76            raise NoFileStorageConfigured
     77
     78        files = {}
     79        for field, field_dict in session_files.items():
     80            files[field] = UploadedFile(
     81                file=self.file_storage.open(field_dict['tmp_name']),
     82                name=field_dict['name'],
     83                content_type=field_dict['content_type'],
     84                size=field_dict['size'],
     85                charset=field_dict['charset'],
     86            )
     87        return files or None
     88
     89    def get_extra_context_data(self):
     90        return self.cookie_data[self.extra_context_cookie_key] or {}
     91
     92    def set_extra_context_data(self, extra_context):
     93        self.cookie_data[self.extra_context_cookie_key] = extra_context
     94        return True
     95
     96    def reset(self):
     97        return self.init_storage()
     98
     99    def update_response(self, response):
     100        if len(self.cookie_data) > 0:
     101            response.set_signed_cookie(self.prefix,
     102                self.create_cookie_data(self.cookie_data))
     103        else:
     104            response.delete_cookie(self.prefix)
     105        return response
     106
     107    def load_cookie_data(self):
     108        try:
     109            data = self.request.get_signed_cookie(self.prefix)
     110        except KeyError:
     111            data = None
     112        except BadSignature:
     113            raise SuspiciousOperation('FormWizard cookie manipulated')
     114
     115        if data is None:
     116            return None
     117
     118        return json.loads(data, cls=json.JSONDecoder)
     119
     120    def create_cookie_data(self, data):
     121        encoder = json.JSONEncoder(separators=(',', ':'))
     122        return encoder.encode(data)
     123
  • new file django/contrib/formtools/wizard/storage/session.py

    diff --git a/django/contrib/formtools/wizard/storage/session.py b/django/contrib/formtools/wizard/storage/session.py
    new file mode 100644
    index 0000000..35468e7
    - +  
     1from django.core.files.uploadedfile import UploadedFile
     2
     3from django.contrib.formtools.wizard.storage import (BaseStorage,
     4                                                     NoFileStorageConfigured)
     5
     6class SessionStorage(BaseStorage):
     7    step_session_key = 'step'
     8    step_data_session_key = 'step_data'
     9    step_files_session_key = 'step_files'
     10    extra_context_session_key = 'extra_context'
     11
     12    def __init__(self, prefix, request, file_storage=None, *args, **kwargs):
     13        super(SessionStorage, self).__init__(prefix)
     14        self.request = request
     15        self.file_storage = file_storage
     16        if self.prefix not in self.request.session:
     17            self.init_storage()
     18
     19    def init_storage(self):
     20        self.request.session[self.prefix] = {
     21            self.step_session_key: None,
     22            self.step_data_session_key: {},
     23            self.step_files_session_key: {},
     24            self.extra_context_session_key: {},
     25        }
     26        self.request.session.modified = True
     27        return True
     28
     29    def get_current_step(self):
     30        return self.request.session[self.prefix][self.step_session_key]
     31
     32    def set_current_step(self, step):
     33        self.request.session[self.prefix][self.step_session_key] = step
     34        self.request.session.modified = True
     35        return True
     36
     37    def get_step_data(self, step):
     38        return self.request.session[self.prefix][self.step_data_session_key].get(step, None)
     39
     40    def get_current_step_data(self):
     41        return self.get_step_data(self.get_current_step())
     42
     43    def set_step_data(self, step, cleaned_data):
     44        self.request.session[self.prefix][self.step_data_session_key][step] = cleaned_data
     45        self.request.session.modified = True
     46        return True
     47
     48    def set_step_files(self, step, files):
     49        if files and not self.file_storage:
     50            raise NoFileStorageConfigured
     51
     52        if step not in self.request.session[self.prefix][self.step_files_session_key]:
     53            self.request.session[self.prefix][self.step_files_session_key][step] = {}
     54
     55        for field, field_file in (files or {}).items():
     56            tmp_filename = self.file_storage.save(field_file.name, field_file)
     57            file_dict = {
     58                'tmp_name': tmp_filename,
     59                'name': field_file.name,
     60                'content_type': field_file.content_type,
     61                'size': field_file.size,
     62                'charset': field_file.charset
     63            }
     64            self.request.session[self.prefix][self.step_files_session_key][step][field] = file_dict
     65
     66        self.request.session.modified = True
     67        return True
     68
     69    def get_current_step_files(self):
     70        return self.get_step_files(self.get_current_step())
     71
     72    def get_step_files(self, step):
     73        session_files = self.request.session[self.prefix][self.step_files_session_key].get(step, {})
     74
     75        if session_files and not self.file_storage:
     76            raise NoFileStorageConfigured
     77
     78        files = {}
     79        for field, field_dict in session_files.items():
     80            files[field] = UploadedFile(
     81                file=self.file_storage.open(field_dict['tmp_name']),
     82                name=field_dict['name'],
     83                content_type=field_dict['content_type'],
     84                size=field_dict['size'],
     85                charset=field_dict['charset'],
     86            )
     87        return files or None
     88
     89    def get_extra_context_data(self):
     90        return self.request.session[self.prefix][self.extra_context_session_key] or {}
     91
     92    def set_extra_context_data(self, extra_context):
     93        self.request.session[self.prefix][self.extra_context_session_key] = extra_context
     94        self.request.session.modified = True
     95        return True
     96
     97    def reset(self):
     98        if self.file_storage:
     99            for step_fields in self.request.session[self.prefix][self.step_files_session_key].values():
     100                for file_dict in step_fields.values():
     101                    self.file_storage.delete(file_dict['tmp_name'])
     102        return self.init_storage()
     103
     104    def update_response(self, response):
     105        return response
     106
  • new file django/contrib/formtools/wizard/templates/formtools/wizard/wizard.html

    diff --git a/django/contrib/formtools/wizard/templates/formtools/wizard/wizard.html b/django/contrib/formtools/wizard/templates/formtools/wizard/wizard.html
    new file mode 100644
    index 0000000..6981312
    - +  
     1{% load i18n %}
     2{% csrf_token %}
     3{% if form.forms %}
     4    {{ form.management_form }}
     5    {% for fs in form.forms %}
     6        {{ fs.as_p }}
     7    {% endfor %}
     8{% else %}
     9    {{ form.as_p }}
     10{% endif %}
     11
     12{% if form_prev_step %}
     13<button name="form_prev_step" value="{{ form_first_step }}">{% trans "first step" %}</button>
     14<button name="form_prev_step" value="{{ form_prev_step }}">{% trans "prev step" %}</button>
     15{% endif %}
     16<input type="submit" name="submit" value="{% trans "submit" %}" />
  • new file django/contrib/formtools/wizard/tests/__init__.py

    diff --git a/django/contrib/formtools/wizard/tests/__init__.py b/django/contrib/formtools/wizard/tests/__init__.py
    new file mode 100644
    index 0000000..22fd8bc
    - +  
     1from django.contrib.formtools.wizard.tests.formtests import *
     2from django.contrib.formtools.wizard.tests.basestoragetests import *
     3from django.contrib.formtools.wizard.tests.sessionstoragetests import *
     4from django.contrib.formtools.wizard.tests.cookiestoragetests import *
     5from django.contrib.formtools.wizard.tests.loadstoragetests import *
     6from django.contrib.formtools.wizard.tests.wizardtests import *
     7from django.contrib.formtools.wizard.tests.namedwizardtests import *
  • new file django/contrib/formtools/wizard/tests/basestoragetests.py

    diff --git a/django/contrib/formtools/wizard/tests/basestoragetests.py b/django/contrib/formtools/wizard/tests/basestoragetests.py
    new file mode 100644
    index 0000000..4e46dba
    - +  
     1from django.test import TestCase
     2from django.contrib.formtools.wizard.storage.base import BaseStorage
     3
     4class TestBaseStorage(TestCase):
     5    def setUp(self):
     6        self.storage = BaseStorage('wizard1')
     7
     8    def test_get_current_step(self):
     9        self.assertRaises(NotImplementedError,
     10                          self.storage.get_current_step)
     11
     12    def test_set_current_step(self):
     13        self.assertRaises(NotImplementedError,
     14                          self.storage.set_current_step, None)
     15
     16    def test_get_step_data(self):
     17        self.assertRaises(NotImplementedError,
     18                          self.storage.get_step_data, None)
     19
     20    def test_set_step_data(self):
     21        self.assertRaises(NotImplementedError,
     22                          self.storage.set_step_data, None, None)
     23
     24    def test_get_extra_context_data(self):
     25        self.assertRaises(NotImplementedError,
     26                          self.storage.get_extra_context_data)
     27
     28    def test_set_extra_context_data(self):
     29        self.assertRaises(NotImplementedError,
     30                          self.storage.set_extra_context_data, None)
     31
     32    def test_reset(self):
     33        self.assertRaises(NotImplementedError,
     34                          self.storage.reset)
     35
     36    def test_update_response(self):
     37        self.assertRaises(NotImplementedError,
     38                          self.storage.update_response, None)
     39
  • new file django/contrib/formtools/wizard/tests/cookiestoragetests.py

    diff --git a/django/contrib/formtools/wizard/tests/cookiestoragetests.py b/django/contrib/formtools/wizard/tests/cookiestoragetests.py
    new file mode 100644
    index 0000000..945df5c
    - +  
     1from django.test import TestCase
     2from django.core import signing
     3from django.core.exceptions import SuspiciousOperation
     4from django.http import HttpResponse
     5
     6from django.contrib.formtools.wizard.storage.cookie import CookieStorage
     7from django.contrib.formtools.wizard.tests.storagetests import *
     8
     9class TestCookieStorage(TestStorage, TestCase):
     10    def get_storage(self):
     11        return CookieStorage
     12
     13    def test_manipulated_cookie(self):
     14        request = get_request()
     15        storage = self.get_storage()('wizard1', request, None)
     16
     17        cookie_signer = signing.get_cookie_signer()
     18
     19        storage.request.COOKIES[storage.prefix] = cookie_signer.sign(
     20            storage.create_cookie_data({'key1': 'value1'}),
     21            salt=storage.prefix)
     22
     23        self.assertEqual(storage.load_cookie_data(), {'key1': 'value1'})
     24
     25        storage.request.COOKIES[storage.prefix] = 'i_am_manipulated'
     26        self.assertRaises(SuspiciousOperation, storage.load_cookie_data)
     27
     28        #raise SuspiciousOperation('FormWizard cookie manipulated')
     29
     30    def test_delete_cookie(self):
     31        request = get_request()
     32        storage = self.get_storage()('wizard1', request, None)
     33
     34        storage.cookie_data = {'key1': 'value1'}
     35
     36        response = HttpResponse()
     37        storage.update_response(response)
     38
     39        cookie_signer = signing.get_cookie_signer()
     40        signed_cookie_data = cookie_signer.sign(
     41            storage.create_cookie_data(storage.cookie_data),
     42            salt=storage.prefix)
     43
     44        self.assertEqual(response.cookies[storage.prefix].value,
     45            signed_cookie_data)
     46
     47        storage.cookie_data = {}
     48        storage.update_response(response)
     49        self.assertEqual(response.cookies[storage.prefix].value, '')
  • new file django/contrib/formtools/wizard/tests/formtests.py

    diff --git a/django/contrib/formtools/wizard/tests/formtests.py b/django/contrib/formtools/wizard/tests/formtests.py
    new file mode 100644
    index 0000000..b600eb3
    - +  
     1from django import forms, http
     2from django.conf import settings
     3from django.test import TestCase
     4from django.template.response import TemplateResponse
     5from django.utils.importlib import import_module
     6
     7from django.contrib.auth.models import User
     8
     9from django.contrib.formtools.wizard.views import (WizardView,
     10                                                   SessionWizardView,
     11                                                   CookieWizardView)
     12
     13
     14class DummyRequest(http.HttpRequest):
     15    def __init__(self, POST=None):
     16        super(DummyRequest, self).__init__()
     17        self.method = POST and "POST" or "GET"
     18        if POST is not None:
     19            self.POST.update(POST)
     20        self.session = {}
     21        self._dont_enforce_csrf_checks = True
     22
     23def get_request(*args, **kwargs):
     24    request = DummyRequest(*args, **kwargs)
     25    engine = import_module(settings.SESSION_ENGINE)
     26    request.session = engine.SessionStore(None)
     27    return request
     28
     29class Step1(forms.Form):
     30    name = forms.CharField()
     31
     32class Step2(forms.Form):
     33    name = forms.CharField()
     34
     35class Step3(forms.Form):
     36    data = forms.CharField()
     37
     38class UserForm(forms.ModelForm):
     39    class Meta:
     40        model = User
     41
     42UserFormSet = forms.models.modelformset_factory(User, form=UserForm, extra=2)
     43
     44class TestWizard(WizardView):
     45    storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage'
     46
     47    def dispatch(self, request, *args, **kwargs):
     48        response = super(TestWizard, self).dispatch(request, *args, **kwargs)
     49        return response, self
     50
     51class FormTests(TestCase):
     52    def test_form_init(self):
     53        testform = TestWizard.get_initkwargs([Step1, Step2])
     54        self.assertEquals(testform['form_list'], {u'0': Step1, u'1': Step2})
     55
     56        testform = TestWizard.get_initkwargs([('start', Step1), ('step2', Step2)])
     57        self.assertEquals(
     58            testform['form_list'], {u'start': Step1, u'step2': Step2})
     59
     60        testform = TestWizard.get_initkwargs([Step1, Step2, ('finish', Step3)])
     61        self.assertEquals(
     62            testform['form_list'], {u'0': Step1, u'1': Step2, u'finish': Step3})
     63
     64    def test_first_step(self):
     65        request = get_request()
     66
     67        testform = TestWizard.as_view([Step1, Step2])
     68        response, instance = testform(request)
     69        self.assertEquals(instance.determine_step(), u'0')
     70
     71        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     72        response, instance = testform(request)
     73
     74        self.assertEquals(instance.determine_step(), 'start')
     75
     76    def test_persistence(self):
     77        request = get_request({'name': 'data1'})
     78
     79        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     80        response, instance = testform(request)
     81        self.assertEquals(instance.determine_step(), 'start')
     82        instance.storage.set_current_step('step2')
     83
     84        testform2 = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     85        response, instance = testform2(request)
     86        self.assertEquals(instance.determine_step(), 'step2')
     87
     88    def test_form_condition(self):
     89        request = get_request()
     90
     91        testform = TestWizard.as_view(
     92            [('start', Step1), ('step2', Step2), ('step3', Step3)],
     93            condition_list={'step2': True})
     94        response, instance = testform(request)
     95        self.assertEquals(instance.get_next_step(), 'step2')
     96
     97        testform = TestWizard.as_view(
     98            [('start', Step1), ('step2', Step2), ('step3', Step3)],
     99            condition_list={'step2': False})
     100        response, instance = testform(request)
     101        self.assertEquals(instance.get_next_step(), 'step3')
     102
     103    def test_add_extra_context(self):
     104        request = get_request()
     105
     106        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     107        response, instance = testform(
     108            request, extra_context={'key1': 'value1'})
     109        self.assertEqual(instance.get_extra_context(), {'key1': 'value1'})
     110
     111        request.method = 'POST'
     112        response, instance = testform(
     113            request, extra_context={'key1': 'value1'})
     114        self.assertEqual(instance.get_extra_context(), {'key1': 'value1'})
     115
     116    def test_form_prefix(self):
     117        request = get_request()
     118
     119        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     120        response, instance = testform(request)
     121
     122        self.assertEqual(instance.get_form_prefix(), 'start')
     123        self.assertEqual(instance.get_form_prefix('another'), 'another')
     124
     125    def test_form_initial(self):
     126        request = get_request()
     127
     128        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)],
     129            initial_list={'start': {'name': 'value1'}})
     130        response, instance = testform(request)
     131
     132        self.assertEqual(instance.get_form_initial('start'), {'name': 'value1'})
     133        self.assertEqual(instance.get_form_initial('step2'), {})
     134
     135    def test_form_instance(self):
     136        request = get_request()
     137        the_instance = User()
     138        testform = TestWizard.as_view([('start', UserForm), ('step2', Step2)],
     139            instance_list={'start': the_instance})
     140        response, instance = testform(request)
     141
     142        self.assertEqual(
     143            instance.get_form_instance('start'),
     144            the_instance)
     145        self.assertEqual(
     146            instance.get_form_instance('non_exist_instance'),
     147            None)
     148
     149    def test_formset_instance(self):
     150        request = get_request()
     151        the_instance1, created = User.objects.get_or_create(
     152            username='testuser1')
     153        the_instance2, created = User.objects.get_or_create(
     154            username='testuser2')
     155        testform = TestWizard.as_view([('start', UserFormSet), ('step2', Step2)],
     156            instance_list={'start': User.objects.filter(username='testuser1')})
     157        response, instance = testform(request)
     158
     159        self.assertEqual(list(instance.get_form_instance('start')), [the_instance1])
     160        self.assertEqual(instance.get_form_instance('non_exist_instance'), None)
     161
     162        self.assertEqual(instance.get_form().initial_form_count(), 1)
     163
     164    def test_done(self):
     165        request = get_request()
     166
     167        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     168        response, instance = testform(request)
     169
     170        self.assertRaises(NotImplementedError, instance.done, None)
     171
     172    def test_revalidation(self):
     173        request = get_request()
     174
     175        testform = TestWizard.as_view([('start', Step1), ('step2', Step2)])
     176        response, instance = testform(request)
     177        instance.render_done(None)
     178        self.assertEqual(instance.storage.get_current_step(), 'start')
     179
     180    def test_form_refresh(self):
     181        testform = TestWizard.as_view([('start', Step1), ('step2', UserFormSet)])
     182        request = get_request({'start-name': 'foo'})
     183        request.method = 'POST'
     184
     185        response, instance = testform(request)
     186        self.assertEqual(instance.storage.get_current_step(), 'step2')
     187        # refresh form
     188        response, instance = testform(request)
     189        self.assertEqual(instance.storage.get_current_step(), 'step2')
     190
     191
     192class SessionFormTests(TestCase):
     193    def test_init(self):
     194        request = get_request()
     195        testform = SessionWizardView.as_view([('start', Step1)])
     196        self.assertTrue(isinstance(testform(request), TemplateResponse))
     197
     198
     199class CookieFormTests(TestCase):
     200    def test_init(self):
     201        request = get_request()
     202        testform = CookieWizardView.as_view([('start', Step1)])
     203        self.assertTrue(isinstance(testform(request), TemplateResponse))
     204
  • new file django/contrib/formtools/wizard/tests/loadstoragetests.py

    diff --git a/django/contrib/formtools/wizard/tests/loadstoragetests.py b/django/contrib/formtools/wizard/tests/loadstoragetests.py
    new file mode 100644
    index 0000000..267dee0
    - +  
     1from django.test import TestCase
     2
     3from django.contrib.formtools.wizard.storage import (get_storage,
     4                                                     MissingStorageModule,
     5                                                     MissingStorageClass)
     6from django.contrib.formtools.wizard.storage.base import BaseStorage
     7
     8
     9class TestLoadStorage(TestCase):
     10    def test_load_storage(self):
     11        self.assertEqual(
     12            type(get_storage('django.contrib.formtools.wizard.storage.base.BaseStorage', 'wizard1')),
     13            BaseStorage)
     14
     15    def test_missing_module(self):
     16        self.assertRaises(MissingStorageModule, get_storage,
     17            'django.contrib.formtools.wizard.storage.idontexist.IDontExistStorage', 'wizard1')
     18
     19    def test_missing_class(self):
     20        self.assertRaises(MissingStorageClass, get_storage,
     21            'django.contrib.formtools.wizard.storage.base.IDontExistStorage', 'wizard1')
     22
  • new file django/contrib/formtools/wizard/tests/namedwizardtests/__init__.py

    diff --git a/django/contrib/formtools/wizard/tests/namedwizardtests/__init__.py b/django/contrib/formtools/wizard/tests/namedwizardtests/__init__.py
    new file mode 100644
    index 0000000..4387356
    - +  
     1from django.contrib.formtools.wizard.tests.namedwizardtests.tests import *
     2 No newline at end of file
  • new file django/contrib/formtools/wizard/tests/namedwizardtests/forms.py

    diff --git a/django/contrib/formtools/wizard/tests/namedwizardtests/forms.py b/django/contrib/formtools/wizard/tests/namedwizardtests/forms.py
    new file mode 100644
    index 0000000..ae98126
    - +  
     1from django import forms
     2from django.forms.formsets import formset_factory
     3from django.http import HttpResponse
     4from django.template import Template, Context
     5
     6from django.contrib.auth.models import User
     7
     8from django.contrib.formtools.wizard.views import NamedUrlWizardView
     9
     10class Page1(forms.Form):
     11    name = forms.CharField(max_length=100)
     12    user = forms.ModelChoiceField(queryset=User.objects.all())
     13    thirsty = forms.NullBooleanField()
     14
     15class Page2(forms.Form):
     16    address1 = forms.CharField(max_length=100)
     17    address2 = forms.CharField(max_length=100)
     18
     19class Page3(forms.Form):
     20    random_crap = forms.CharField(max_length=100)
     21
     22Page4 = formset_factory(Page3, extra=2)
     23
     24class ContactWizard(NamedUrlWizardView):
     25    def done(self, form_list, **kwargs):
     26        c = Context({
     27            'form_list': [x.cleaned_data for x in form_list],
     28            'all_cleaned_data': self.get_all_cleaned_data()
     29        })
     30
     31        for form in self.form_list.keys():
     32            c[form] = self.get_cleaned_data_for_step(form)
     33
     34        c['this_will_fail'] = self.get_cleaned_data_for_step('this_will_fail')
     35        return HttpResponse(Template('').render(c))
     36
     37class SessionContactWizard(ContactWizard):
     38    storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage'
     39
     40class CookieContactWizard(ContactWizard):
     41    storage_name = 'django.contrib.formtools.wizard.storage.cookie.CookieStorage'
     42
  • new file django/contrib/formtools/wizard/tests/namedwizardtests/tests.py

    diff --git a/django/contrib/formtools/wizard/tests/namedwizardtests/models.py b/django/contrib/formtools/wizard/tests/namedwizardtests/models.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/django/contrib/formtools/wizard/tests/namedwizardtests/tests.py b/django/contrib/formtools/wizard/tests/namedwizardtests/tests.py
    new file mode 100644
    index 0000000..de83764
    - +  
     1import os
     2
     3from django.core.urlresolvers import reverse
     4from django.http import QueryDict
     5from django.test import TestCase
     6from django.conf import settings
     7
     8from django.contrib.auth.models import User
     9
     10from django.contrib.formtools import wizard
     11
     12from django.contrib.formtools.wizard.views import (NamedUrlSessionWizardView,
     13                                                   NamedUrlCookieWizardView)
     14from django.contrib.formtools.wizard.tests.formtests import (get_request,
     15                                                             Step1,
     16                                                             Step2)
     17
     18class NamedWizardTests(object):
     19    urls = 'django.contrib.formtools.wizard.tests.namedwizardtests.urls'
     20
     21    wizard_step_data = (
     22        {
     23            'form1-name': 'Pony',
     24            'form1-thirsty': '2',
     25        },
     26        {
     27            'form2-address1': '123 Main St',
     28            'form2-address2': 'Djangoland',
     29        },
     30        {
     31            'form3-random_crap': 'blah blah',
     32        },
     33        {
     34            'form4-INITIAL_FORMS': '0',
     35            'form4-TOTAL_FORMS': '2',
     36            'form4-MAX_NUM_FORMS': '0',
     37            'form4-0-random_crap': 'blah blah',
     38            'form4-1-random_crap': 'blah blah',
     39        }
     40    )
     41
     42    def setUp(self):
     43        self.testuser, created = User.objects.get_or_create(username='testuser1')
     44        self.wizard_step_data[0]['form1-user'] = self.testuser.pk
     45
     46        wizard_template_dirs = [os.path.join(os.path.dirname(wizard.__file__), 'templates')]
     47        settings.TEMPLATE_DIRS = list(settings.TEMPLATE_DIRS) + wizard_template_dirs
     48
     49    def tearDown(self):
     50        del settings.TEMPLATE_DIRS[-1]
     51
     52    def test_initial_call(self):
     53        response = self.client.get(reverse('%s_start' % self.wizard_urlname))
     54        self.assertEqual(response.status_code, 302)
     55        response = self.client.get(response['Location'])
     56        self.assertEqual(response.status_code, 200)
     57        self.assertEqual(response.context['form_step'], 'form1')
     58        self.assertEqual(response.context['form_step0'], 0)
     59        self.assertEqual(response.context['form_step1'], 1)
     60        self.assertEqual(response.context['form_last_step'], 'form4')
     61        self.assertEqual(response.context['form_prev_step'], None)
     62        self.assertEqual(response.context['form_next_step'], 'form2')
     63        self.assertEqual(response.context['form_step_count'], 4)
     64
     65    def test_initial_call_with_params(self):
     66        get_params = {'getvar1': 'getval1', 'getvar2': 'getval2'}
     67        response = self.client.get(reverse('%s_start' % self.wizard_urlname),
     68                                   get_params)
     69        self.assertEqual(response.status_code, 302)
     70
     71        # Test for proper redirect GET parameters
     72        location = response['Location']
     73        self.assertNotEqual(location.find('?'), -1)
     74        querydict = QueryDict(location[location.find('?') + 1:])
     75        self.assertEqual(dict(querydict.items()), get_params)
     76
     77    def test_form_post_error(self):
     78        response = self.client.post(
     79            reverse(self.wizard_urlname, kwargs={'step':'form1'}))
     80
     81        self.assertEqual(response.status_code, 200)
     82        self.assertEqual(response.context['form_step'], 'form1')
     83        self.assertEqual(response.context['form'].errors,
     84                         {'name': [u'This field is required.'],
     85                          'user': [u'This field is required.']})
     86
     87    def test_form_post_success(self):
     88        response = self.client.post(
     89            reverse(self.wizard_urlname, kwargs={'step':'form1'}),
     90            self.wizard_step_data[0])
     91        response = self.client.get(response['Location'])
     92
     93        self.assertEqual(response.status_code, 200)
     94        self.assertEqual(response.context['form_step'], 'form2')
     95        self.assertEqual(response.context['form_step0'], 1)
     96        self.assertEqual(response.context['form_prev_step'], 'form1')
     97        self.assertEqual(response.context['form_next_step'], 'form3')
     98
     99    def test_form_stepback(self):
     100        response = self.client.get(
     101            reverse(self.wizard_urlname, kwargs={'step':'form1'}))
     102
     103        self.assertEqual(response.status_code, 200)
     104        self.assertEqual(response.context['form_step'], 'form1')
     105
     106        response = self.client.post(
     107            reverse(self.wizard_urlname, kwargs={'step':'form1'}),
     108            self.wizard_step_data[0])
     109        response = self.client.get(response['Location'])
     110
     111        self.assertEqual(response.status_code, 200)
     112        self.assertEqual(response.context['form_step'], 'form2')
     113
     114        response = self.client.post(
     115            reverse(self.wizard_urlname,
     116                    kwargs={'step': response.context['form_step']}),
     117            {'form_prev_step': response.context['form_prev_step']})
     118        response = self.client.get(response['Location'])
     119
     120        self.assertEqual(response.status_code, 200)
     121        self.assertEqual(response.context['form_step'], 'form1')
     122
     123    def test_form_jump(self):
     124        response = self.client.get(
     125            reverse(self.wizard_urlname, kwargs={'step':'form1'}))
     126
     127        self.assertEqual(response.status_code, 200)
     128        self.assertEqual(response.context['form_step'], 'form1')
     129
     130        response = self.client.get(
     131            reverse(self.wizard_urlname, kwargs={'step':'form3'}))
     132        self.assertEqual(response.status_code, 200)
     133        self.assertEqual(response.context['form_step'], 'form3')
     134
     135    def test_form_finish(self):
     136        response = self.client.get(
     137            reverse(self.wizard_urlname, kwargs={'step': 'form1'}))
     138
     139        self.assertEqual(response.status_code, 200)
     140        self.assertEqual(response.context['form_step'], 'form1')
     141
     142        response = self.client.post(
     143            reverse(self.wizard_urlname,
     144                    kwargs={'step': response.context['form_step']}),
     145            self.wizard_step_data[0])
     146        response = self.client.get(response['Location'])
     147
     148        self.assertEqual(response.status_code, 200)
     149        self.assertEqual(response.context['form_step'], 'form2')
     150
     151        response = self.client.post(
     152            reverse(self.wizard_urlname,
     153                    kwargs={'step': response.context['form_step']}),
     154            self.wizard_step_data[1])
     155        response = self.client.get(response['Location'])
     156
     157        self.assertEqual(response.status_code, 200)
     158        self.assertEqual(response.context['form_step'], 'form3')
     159
     160        response = self.client.post(
     161            reverse(self.wizard_urlname,
     162                    kwargs={'step': response.context['form_step']}),
     163            self.wizard_step_data[2])
     164        response = self.client.get(response['Location'])
     165
     166        self.assertEqual(response.status_code, 200)
     167        self.assertEqual(response.context['form_step'], 'form4')
     168
     169        response = self.client.post(
     170            reverse(self.wizard_urlname,
     171                    kwargs={'step': response.context['form_step']}),
     172            self.wizard_step_data[3])
     173        response = self.client.get(response['Location'])
     174        self.assertEqual(response.status_code, 200)
     175
     176        self.assertEqual(response.context['form_list'], [
     177            {'name': u'Pony', 'thirsty': True, 'user': self.testuser},
     178            {'address1': u'123 Main St', 'address2': u'Djangoland'},
     179            {'random_crap': u'blah blah'},
     180            [{'random_crap': u'blah blah'}, {'random_crap': u'blah blah'}]])
     181
     182    def test_cleaned_data(self):
     183        response = self.client.get(
     184            reverse(self.wizard_urlname, kwargs={'step': 'form1'}))
     185        self.assertEqual(response.status_code, 200)
     186
     187        response = self.client.post(
     188            reverse(self.wizard_urlname,
     189                    kwargs={'step': response.context['form_step']}),
     190            self.wizard_step_data[0])
     191        response = self.client.get(response['Location'])
     192        self.assertEqual(response.status_code, 200)
     193
     194        response = self.client.post(
     195            reverse(self.wizard_urlname,
     196                    kwargs={'step': response.context['form_step']}),
     197            self.wizard_step_data[1])
     198        response = self.client.get(response['Location'])
     199        self.assertEqual(response.status_code, 200)
     200
     201        response = self.client.post(
     202            reverse(self.wizard_urlname,
     203                    kwargs={'step': response.context['form_step']}),
     204            self.wizard_step_data[2])
     205        response = self.client.get(response['Location'])
     206        self.assertEqual(response.status_code, 200)
     207
     208        response = self.client.post(
     209            reverse(self.wizard_urlname,
     210                    kwargs={'step': response.context['form_step']}),
     211            self.wizard_step_data[3])
     212        response = self.client.get(response['Location'])
     213        self.assertEqual(response.status_code, 200)
     214
     215        self.assertEqual(
     216            response.context['all_cleaned_data'],
     217            {'name': u'Pony', 'thirsty': True, 'user': self.testuser,
     218             'address1': u'123 Main St', 'address2': u'Djangoland',
     219             'random_crap': u'blah blah', 'formset-form4': [
     220                 {'random_crap': u'blah blah'},
     221                 {'random_crap': u'blah blah'}
     222             ]})
     223
     224    def test_manipulated_data(self):
     225        response = self.client.get(
     226            reverse(self.wizard_urlname, kwargs={'step': 'form1'}))
     227        self.assertEqual(response.status_code, 200)
     228
     229        response = self.client.post(
     230            reverse(self.wizard_urlname,
     231                    kwargs={'step': response.context['form_step']}),
     232            self.wizard_step_data[0])
     233        response = self.client.get(response['Location'])
     234        self.assertEqual(response.status_code, 200)
     235
     236        response = self.client.post(
     237            reverse(self.wizard_urlname,
     238                    kwargs={'step': response.context['form_step']}),
     239            self.wizard_step_data[1])
     240        response = self.client.get(response['Location'])
     241        self.assertEqual(response.status_code, 200)
     242
     243        response = self.client.post(
     244            reverse(self.wizard_urlname,
     245                    kwargs={'step': response.context['form_step']}),
     246            self.wizard_step_data[2])
     247        response = self.client.get(response['Location'])
     248        self.assertEqual(response.status_code, 200)
     249
     250        self.client.cookies.pop('sessionid', None)
     251        self.client.cookies.pop('wizard_cookie_contact_wizard', None)
     252
     253        response = self.client.post(
     254            reverse(self.wizard_urlname,
     255                    kwargs={'step': response.context['form_step']}),
     256            self.wizard_step_data[3])
     257        self.assertEqual(response.status_code, 200)
     258        self.assertEqual(response.context.get('form_step', None), 'form1')
     259
     260    def test_form_reset(self):
     261        response = self.client.post(
     262            reverse(self.wizard_urlname, kwargs={'step':'form1'}),
     263            self.wizard_step_data[0])
     264        response = self.client.get(response['Location'])
     265        self.assertEqual(response.status_code, 200)
     266        self.assertEqual(response.context['form_step'], 'form2')
     267
     268        response = self.client.get(
     269            '%s?reset=1' % reverse('%s_start' % self.wizard_urlname))
     270        self.assertEqual(response.status_code, 302)
     271
     272        response = self.client.get(response['Location'])
     273        self.assertEqual(response.status_code, 200)
     274        self.assertEqual(response.context['form_step'], 'form1')
     275
     276class NamedSessionWizardTests(NamedWizardTests, TestCase):
     277    wizard_urlname = 'nwiz_session'
     278
     279class NamedCookieWizardTests(NamedWizardTests, TestCase):
     280    wizard_urlname = 'nwiz_cookie'
     281
     282class NamedFormTests(object):
     283    urls = 'django.contrib.formtools.wizard.tests.namedwizardtests.urls'
     284
     285    def test_add_extra_context(self):
     286        request = get_request()
     287
     288        testform = self.formwizard_class.as_view(
     289            [('start', Step1), ('step2', Step2)],
     290            url_name=self.wizard_urlname)
     291
     292        response, instance = testform(request,
     293                                      step='form1',
     294                                      extra_context={'key1': 'value1'})
     295        self.assertEqual(instance.get_extra_context(), {'key1': 'value1'})
     296
     297        instance.reset_wizard()
     298
     299        response, instance = testform(request,
     300                                      extra_context={'key2': 'value2'})
     301        self.assertEqual(instance.get_extra_context(), {'key2': 'value2'})
     302
     303    def test_revalidation(self):
     304        request = get_request()
     305
     306        testform = self.formwizard_class.as_view(
     307            [('start', Step1), ('step2', Step2)],
     308            url_name=self.wizard_urlname)
     309        response, instance = testform(request, step='done')
     310
     311        instance.render_done(None)
     312        self.assertEqual(instance.storage.get_current_step(), 'start')
     313
     314class TestNamedUrlSessionFormWizard(NamedUrlSessionWizardView):
     315
     316    def dispatch(self, request, *args, **kwargs):
     317        response = super(TestNamedUrlSessionFormWizard, self).dispatch(request, *args, **kwargs)
     318        return response, self
     319
     320class TestNamedUrlCookieFormWizard(NamedUrlCookieWizardView):
     321
     322    def dispatch(self, request, *args, **kwargs):
     323        response = super(TestNamedUrlCookieFormWizard, self).dispatch(request, *args, **kwargs)
     324        return response, self
     325
     326
     327class NamedSessionFormTests(NamedFormTests, TestCase):
     328    formwizard_class = TestNamedUrlSessionFormWizard
     329    wizard_urlname = 'nwiz_session'
     330
     331class NamedCookieFormTests(NamedFormTests, TestCase):
     332    formwizard_class = TestNamedUrlCookieFormWizard
     333    wizard_urlname = 'nwiz_cookie'
     334
  • new file django/contrib/formtools/wizard/tests/namedwizardtests/urls.py

    diff --git a/django/contrib/formtools/wizard/tests/namedwizardtests/urls.py b/django/contrib/formtools/wizard/tests/namedwizardtests/urls.py
    new file mode 100644
    index 0000000..a97ca98
    - +  
     1from django.conf.urls.defaults import *
     2from django.contrib.formtools.wizard.tests.namedwizardtests.forms import (
     3    SessionContactWizard, CookieContactWizard, Page1, Page2, Page3, Page4)
     4
     5def get_named_session_wizard():
     6    return SessionContactWizard.as_view(
     7        [('form1', Page1), ('form2', Page2), ('form3', Page3), ('form4', Page4)],
     8        url_name='nwiz_session',
     9        done_step_name='nwiz_session_done'
     10    )
     11
     12def get_named_cookie_wizard():
     13    return CookieContactWizard.as_view(
     14        [('form1', Page1), ('form2', Page2), ('form3', Page3), ('form4', Page4)],
     15        url_name='nwiz_cookie',
     16        done_step_name='nwiz_cookie_done'
     17    )
     18
     19urlpatterns = patterns('',
     20    url(r'^nwiz_session/(?P<step>.+)/$', get_named_session_wizard(), name='nwiz_session'),
     21    url(r'^nwiz_session/$', get_named_session_wizard(), name='nwiz_session_start'),
     22    url(r'^nwiz_cookie/(?P<step>.+)/$', get_named_cookie_wizard(), name='nwiz_cookie'),
     23    url(r'^nwiz_cookie/$', get_named_cookie_wizard(), name='nwiz_cookie_start'),
     24)
  • new file django/contrib/formtools/wizard/tests/sessionstoragetests.py

    diff --git a/django/contrib/formtools/wizard/tests/sessionstoragetests.py b/django/contrib/formtools/wizard/tests/sessionstoragetests.py
    new file mode 100644
    index 0000000..b89e9c2
    - +  
     1from django.test import TestCase
     2
     3from django.contrib.formtools.wizard.tests.storagetests import *
     4from django.contrib.formtools.wizard.storage.session import SessionStorage
     5
     6class TestSessionStorage(TestStorage, TestCase):
     7    def get_storage(self):
     8        return SessionStorage
     9
  • new file django/contrib/formtools/wizard/tests/storagetests.py

    diff --git a/django/contrib/formtools/wizard/tests/storagetests.py b/django/contrib/formtools/wizard/tests/storagetests.py
    new file mode 100644
    index 0000000..897d062
    - +  
     1from datetime import datetime
     2
     3from django.http import HttpRequest
     4from django.conf import settings
     5from django.utils.importlib import import_module
     6
     7from django.contrib.auth.models import User
     8
     9def get_request():
     10    request = HttpRequest()
     11    engine = import_module(settings.SESSION_ENGINE)
     12    request.session = engine.SessionStore(None)
     13    return request
     14
     15class TestStorage(object):
     16    def setUp(self):
     17        self.testuser, created = User.objects.get_or_create(username='testuser1')
     18
     19    def test_current_step(self):
     20        request = get_request()
     21        storage = self.get_storage()('wizard1', request, None)
     22        my_step = 2
     23
     24        self.assertEqual(storage.get_current_step(), None)
     25
     26        storage.set_current_step(my_step)
     27        self.assertEqual(storage.get_current_step(), my_step)
     28
     29        storage.reset()
     30        self.assertEqual(storage.get_current_step(), None)
     31
     32        storage.set_current_step(my_step)
     33        storage2 = self.get_storage()('wizard2', request, None)
     34        self.assertEqual(storage2.get_current_step(), None)
     35
     36    def test_step_data(self):
     37        request = get_request()
     38        storage = self.get_storage()('wizard1', request, None)
     39        step1 = 'start'
     40        step_data1 = {'field1': 'data1',
     41                      'field2': 'data2',
     42                      'field3': datetime.now(),
     43                      'field4': self.testuser}
     44
     45        self.assertEqual(storage.get_step_data(step1), None)
     46
     47        storage.set_step_data(step1, step_data1)
     48        self.assertEqual(storage.get_step_data(step1), step_data1)
     49
     50        storage.reset()
     51        self.assertEqual(storage.get_step_data(step1), None)
     52
     53        storage.set_step_data(step1, step_data1)
     54        storage2 = self.get_storage()('wizard2', request, None)
     55        self.assertEqual(storage2.get_step_data(step1), None)
     56
     57    def test_extra_context(self):
     58        request = get_request()
     59        storage = self.get_storage()('wizard1', request, None)
     60        extra_context = {'key1': 'data1',
     61                         'key2': 'data2',
     62                         'key3': datetime.now(),
     63                         'key4': self.testuser}
     64
     65        self.assertEqual(storage.get_extra_context_data(), {})
     66
     67        storage.set_extra_context_data(extra_context)
     68        self.assertEqual(storage.get_extra_context_data(), extra_context)
     69
     70        storage.reset()
     71        self.assertEqual(storage.get_extra_context_data(), {})
     72
     73        storage.set_extra_context_data(extra_context)
     74        storage2 = self.get_storage()('wizard2', request, None)
     75        self.assertEqual(storage2.get_extra_context_data(), {})
     76
  • new file django/contrib/formtools/wizard/tests/wizardtests/__init__.py

    diff --git a/django/contrib/formtools/wizard/tests/wizardtests/__init__.py b/django/contrib/formtools/wizard/tests/wizardtests/__init__.py
    new file mode 100644
    index 0000000..9173cd8
    - +  
     1from django.contrib.formtools.wizard.tests.wizardtests.tests import *
     2 No newline at end of file
  • new file django/contrib/formtools/wizard/tests/wizardtests/forms.py

    diff --git a/django/contrib/formtools/wizard/tests/wizardtests/forms.py b/django/contrib/formtools/wizard/tests/wizardtests/forms.py
    new file mode 100644
    index 0000000..971ff4d
    - +  
     1import tempfile
     2
     3from django import forms
     4from django.core.files.storage import FileSystemStorage
     5from django.forms.formsets import formset_factory
     6from django.http import HttpResponse
     7from django.template import Template, Context
     8
     9from django.contrib.auth.models import User
     10
     11from django.contrib.formtools.wizard.views import WizardView
     12
     13temp_storage_location = tempfile.mkdtemp()
     14temp_storage = FileSystemStorage(location=temp_storage_location)
     15
     16class Page1(forms.Form):
     17    name = forms.CharField(max_length=100)
     18    user = forms.ModelChoiceField(queryset=User.objects.all())
     19    thirsty = forms.NullBooleanField()
     20
     21class Page2(forms.Form):
     22    address1 = forms.CharField(max_length=100)
     23    address2 = forms.CharField(max_length=100)
     24    file1 = forms.FileField()
     25
     26class Page3(forms.Form):
     27    random_crap = forms.CharField(max_length=100)
     28
     29Page4 = formset_factory(Page3, extra=2)
     30
     31class ContactWizard(WizardView):
     32    file_storage = temp_storage
     33
     34    def done(self, form_list, **kwargs):
     35        c = Context({
     36            'form_list': [x.cleaned_data for x in form_list],
     37            'all_cleaned_data': self.get_all_cleaned_data()
     38        })
     39
     40        for form in self.form_list.keys():
     41            c[form] = self.get_cleaned_data_for_step(form)
     42
     43        c['this_will_fail'] = self.get_cleaned_data_for_step('this_will_fail')
     44        return HttpResponse(Template('').render(c))
     45
     46    def get_context_data(self, form, **kwargs):
     47        context = super(ContactWizard, self).get_context_data(form, **kwargs)
     48        if self.storage.get_current_step() == 'form2':
     49            context.update({'another_var': True})
     50        return context
     51
     52class SessionContactWizard(ContactWizard):
     53    storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage'
     54
     55class CookieContactWizard(ContactWizard):
     56    storage_name = 'django.contrib.formtools.wizard.storage.cookie.CookieStorage'
     57
  • new file django/contrib/formtools/wizard/tests/wizardtests/tests.py

    diff --git a/django/contrib/formtools/wizard/tests/wizardtests/models.py b/django/contrib/formtools/wizard/tests/wizardtests/models.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/django/contrib/formtools/wizard/tests/wizardtests/tests.py b/django/contrib/formtools/wizard/tests/wizardtests/tests.py
    new file mode 100644
    index 0000000..2dc8fa0
    - +  
     1import os
     2
     3from django.test import TestCase
     4from django.conf import settings
     5from django.contrib.auth.models import User
     6
     7from django.contrib.formtools import wizard
     8
     9class WizardTests(object):
     10    urls = 'django.contrib.formtools.wizard.tests.wizardtests.urls'
     11
     12    wizard_step_data = (
     13        {
     14            'form1-name': 'Pony',
     15            'form1-thirsty': '2',
     16        },
     17        {
     18            'form2-address1': '123 Main St',
     19            'form2-address2': 'Djangoland',
     20        },
     21        {
     22            'form3-random_crap': 'blah blah',
     23        },
     24        {
     25            'form4-INITIAL_FORMS': '0',
     26            'form4-TOTAL_FORMS': '2',
     27            'form4-MAX_NUM_FORMS': '0',
     28            'form4-0-random_crap': 'blah blah',
     29            'form4-1-random_crap': 'blah blah',
     30        }
     31    )
     32
     33    def setUp(self):
     34        self.testuser, created = User.objects.get_or_create(username='testuser1')
     35        self.wizard_step_data[0]['form1-user'] = self.testuser.pk
     36
     37        wizard_template_dirs = [os.path.join(os.path.dirname(wizard.__file__), 'templates')]
     38        settings.TEMPLATE_DIRS = list(settings.TEMPLATE_DIRS) + wizard_template_dirs
     39
     40    def tearDown(self):
     41        del settings.TEMPLATE_DIRS[-1]
     42
     43    def test_initial_call(self):
     44        response = self.client.get(self.wizard_url)
     45
     46        self.assertEqual(response.status_code, 200)
     47        self.assertEqual(response.context['form_step'], 'form1')
     48        self.assertEqual(response.context['form_step0'], 0)
     49        self.assertEqual(response.context['form_step1'], 1)
     50        self.assertEqual(response.context['form_last_step'], 'form4')
     51        self.assertEqual(response.context['form_prev_step'], None)
     52        self.assertEqual(response.context['form_next_step'], 'form2')
     53        self.assertEqual(response.context['form_step_count'], 4)
     54
     55    def test_form_post_error(self):
     56        response = self.client.post(self.wizard_url)
     57
     58        self.assertEqual(response.status_code, 200)
     59        self.assertEqual(response.context['form_step'], 'form1')
     60        self.assertEqual(response.context['form'].errors,
     61                         {'name': [u'This field is required.'],
     62                          'user': [u'This field is required.']})
     63
     64    def test_form_post_success(self):
     65        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     66
     67        self.assertEqual(response.status_code, 200)
     68        self.assertEqual(response.context['form_step'], 'form2')
     69        self.assertEqual(response.context['form_step0'], 1)
     70        self.assertEqual(response.context['form_prev_step'], 'form1')
     71        self.assertEqual(response.context['form_next_step'], 'form3')
     72
     73    def test_form_stepback(self):
     74        response = self.client.get(self.wizard_url)
     75
     76        self.assertEqual(response.status_code, 200)
     77        self.assertEqual(response.context['form_step'], 'form1')
     78
     79        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     80
     81        self.assertEqual(response.status_code, 200)
     82        self.assertEqual(response.context['form_step'], 'form2')
     83
     84        response = self.client.post(
     85            self.wizard_url,
     86            {'form_prev_step': response.context['form_prev_step']})
     87
     88        self.assertEqual(response.status_code, 200)
     89        self.assertEqual(response.context['form_step'], 'form1')
     90
     91    def test_template_context(self):
     92        response = self.client.get(self.wizard_url)
     93
     94        self.assertEqual(response.status_code, 200)
     95        self.assertEqual(response.context['form_step'], 'form1')
     96        self.assertEqual(response.context.get('another_var', None), None)
     97
     98        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     99
     100        self.assertEqual(response.status_code, 200)
     101        self.assertEqual(response.context['form_step'], 'form2')
     102        self.assertEqual(response.context.get('another_var', None), True)
     103
     104    def test_form_finish(self):
     105        response = self.client.get(self.wizard_url)
     106
     107        self.assertEqual(response.status_code, 200)
     108        self.assertEqual(response.context['form_step'], 'form1')
     109
     110        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     111
     112        self.assertEqual(response.status_code, 200)
     113        self.assertEqual(response.context['form_step'], 'form2')
     114
     115        post_data = self.wizard_step_data[1]
     116        post_data['form2-file1'] = open(__file__)
     117        response = self.client.post(self.wizard_url, post_data)
     118
     119        self.assertEqual(response.status_code, 200)
     120        self.assertEqual(response.context['form_step'], 'form3')
     121
     122        response = self.client.post(self.wizard_url, self.wizard_step_data[2])
     123
     124        self.assertEqual(response.status_code, 200)
     125        self.assertEqual(response.context['form_step'], 'form4')
     126
     127        response = self.client.post(self.wizard_url, self.wizard_step_data[3])
     128        self.assertEqual(response.status_code, 200)
     129
     130        all_data = response.context['form_list']
     131        self.assertEqual(all_data[1]['file1'].read(), open(__file__).read())
     132        del all_data[1]['file1']
     133        self.assertEqual(all_data, [
     134            {'name': u'Pony', 'thirsty': True, 'user': self.testuser},
     135            {'address1': u'123 Main St', 'address2': u'Djangoland'},
     136            {'random_crap': u'blah blah'},
     137            [{'random_crap': u'blah blah'},
     138             {'random_crap': u'blah blah'}]])
     139
     140    def test_cleaned_data(self):
     141        response = self.client.get(self.wizard_url)
     142        self.assertEqual(response.status_code, 200)
     143        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     144        self.assertEqual(response.status_code, 200)
     145        post_data = self.wizard_step_data[1]
     146        post_data['form2-file1'] = open(__file__)
     147        response = self.client.post(self.wizard_url, post_data)
     148        self.assertEqual(response.status_code, 200)
     149        response = self.client.post(self.wizard_url, self.wizard_step_data[2])
     150        self.assertEqual(response.status_code, 200)
     151        response = self.client.post(self.wizard_url, self.wizard_step_data[3])
     152        self.assertEqual(response.status_code, 200)
     153
     154        all_data = response.context['all_cleaned_data']
     155        self.assertEqual(all_data['file1'].read(), open(__file__).read())
     156        del all_data['file1']
     157        self.assertEqual(all_data, {
     158            'name': u'Pony', 'thirsty': True, 'user': self.testuser,
     159            'address1': u'123 Main St', 'address2': u'Djangoland',
     160            'random_crap': u'blah blah', 'formset-form4': [
     161                {'random_crap': u'blah blah'},
     162                {'random_crap': u'blah blah'}]})
     163
     164    def test_manipulated_data(self):
     165        response = self.client.get(self.wizard_url)
     166        self.assertEqual(response.status_code, 200)
     167        response = self.client.post(self.wizard_url, self.wizard_step_data[0])
     168        self.assertEqual(response.status_code, 200)
     169        post_data = self.wizard_step_data[1]
     170        post_data['form2-file1'] = open(__file__)
     171        response = self.client.post(self.wizard_url, post_data)
     172        self.assertEqual(response.status_code, 200)
     173        response = self.client.post(self.wizard_url, self.wizard_step_data[2])
     174        self.assertEqual(response.status_code, 200)
     175        self.client.cookies.pop('sessionid', None)
     176        self.client.cookies.pop('wizard_cookie_contact_wizard', None)
     177        response = self.client.post(self.wizard_url, self.wizard_step_data[3])
     178        self.assertEqual(response.status_code, 200)
     179        self.assertEqual(response.context.get('form_step', None), 'form1')
     180
     181class SessionWizardTests(WizardTests, TestCase):
     182    wizard_url = '/wiz_session/'
     183
     184class CookieWizardTests(WizardTests, TestCase):
     185    wizard_url = '/wiz_cookie/'
     186
  • new file django/contrib/formtools/wizard/tests/wizardtests/urls.py

    diff --git a/django/contrib/formtools/wizard/tests/wizardtests/urls.py b/django/contrib/formtools/wizard/tests/wizardtests/urls.py
    new file mode 100644
    index 0000000..e305397
    - +  
     1from django.conf.urls.defaults import *
     2from django.contrib.formtools.wizard.tests.wizardtests.forms import (
     3    SessionContactWizard, CookieContactWizard, Page1, Page2, Page3, Page4)
     4
     5urlpatterns = patterns('',
     6    url(r'^wiz_session/$', SessionContactWizard.as_view(
     7        [('form1', Page1),
     8         ('form2', Page2),
     9         ('form3', Page3),
     10         ('form4', Page4)])),
     11    url(r'^wiz_cookie/$', CookieContactWizard.as_view(
     12        [('form1', Page1),
     13         ('form2', Page2),
     14         ('form3', Page3),
     15         ('form4', Page4)])),
     16)
  • new file django/contrib/formtools/wizard/views.py

    diff --git a/django/contrib/formtools/wizard/views.py b/django/contrib/formtools/wizard/views.py
    new file mode 100644
    index 0000000..f00a428
    - +  
     1import re
     2
     3from django import forms
     4from django.core.urlresolvers import reverse
     5from django.forms import formsets
     6from django.http import HttpResponseRedirect
     7from django.views.generic import TemplateView
     8from django.utils.datastructures import SortedDict
     9from django.utils.decorators import classonlymethod
     10
     11from django.contrib.formtools.wizard.storage import get_storage, NoFileStorageConfigured
     12
     13def normalize_name(name):
     14    new = re.sub('(((?<=[a-z])[A-Z])|([A-Z](?![A-Z]|$)))', '_\\1', name)
     15    return new.lower().strip('_')
     16
     17class WizardView(TemplateView):
     18    """
     19    The WizardView is used to create multi-page forms and handles all the
     20    storage and validation stuff. The wizard is based on Django's generic
     21    class based views.
     22    """
     23    storage_name = None
     24    form_list = None
     25    initial_list = None
     26    instance_list = None
     27    condition_list = None
     28    template_name = 'formtools/wizard/wizard.html'
     29
     30    @classonlymethod
     31    def as_view(cls, *args, **kwargs):
     32        """
     33        This method is used within urls.py to create unique formwizard
     34        instances for every request. We need to override this method because
     35        we add some kwargs which are needed to make the formwizard usable.
     36        """
     37        initkwargs = cls.get_initkwargs(*args, **kwargs)
     38        return super(WizardView, cls).as_view(**initkwargs)
     39
     40    @classmethod
     41    def get_initkwargs(cls, form_list,
     42            initial_list=None, instance_list=None, condition_list=None):
     43        """
     44        Creates a dict with all needed parameters for the form wizard instances.
     45
     46        * `form_list` - is a list of forms. The list entries can be single form
     47          classes or tuples of (`step_name`, `form_class`). If you pass a list
     48          of forms, the formwizard will convert the class list to
     49          (`zero_based_counter`, `form_class`). This is needed to access the
     50          form for a specific step.
     51        * `initial_list` - contains a dictionary of initial data dictionaries.
     52          The key should be equal to the `step_name` in the `form_list` (or
     53          the str of the zero based counter - if no step_names added in the
     54          `form_list`)
     55        * `instance_list` - contains a dictionary of instance objects. This list
     56          is only used when `ModelForm`s are used. The key should be equal to
     57          the `step_name` in the `form_list`. Same rules as for `initial_list`
     58          apply.
     59        * `condition_list` - contains a dictionary of boolean values or
     60          callables. If the value of for a specific `step_name` is callable it
     61          will be called with the formwizard instance as the only argument.
     62          If the return value is true, the step's form will be used.
     63        """
     64        kwargs = {
     65            'initial_list': initial_list or {},
     66            'instance_list': instance_list or {},
     67            'condition_list': condition_list or {},
     68        }
     69        init_form_list = SortedDict()
     70
     71        assert len(form_list) > 0, 'at least one form is needed'
     72
     73        # walk through the passed form list
     74        for i, form in enumerate(form_list):
     75            if isinstance(form, (list, tuple)):
     76                # if the element is a tuple, add the tuple to the new created
     77                # sorted dictionary.
     78                init_form_list[unicode(form[0])] = form[1]
     79            else:
     80                # if not, add the form with a zero based counter as unicode
     81                init_form_list[unicode(i)] = form
     82
     83        # walk through the ne created list of forms
     84        for form in init_form_list.values():
     85            if issubclass(form, formsets.BaseFormSet):
     86                # if the element is based on BaseFormSet (FormSet/ModelFormSet)
     87                # we need to override the form variable.
     88                form = form.form
     89            # check if any form contains a FileField, if yes, we need a
     90            # file_storage added to the formwizard (by subclassing).
     91            for field in form.base_fields.values():
     92                if (isinstance(field, forms.FileField) and
     93                        not hasattr(cls, 'file_storage')):
     94                    raise NoFileStorageConfigured
     95
     96        # build the kwargs for the formwizard instances
     97        kwargs['form_list'] = init_form_list
     98        return kwargs
     99
     100    def __repr__(self):
     101        return '<%s: form_list: %s, initial_list: %s>' % (
     102            self.__class__.__name__, self.form_list, self.initial_list)
     103
     104    def dispatch(self, request, *args, **kwargs):
     105        """
     106        This method gets called by the routing engine. The first argument is
     107        `request` which contains a `HttpRequest` instance.
     108        The request is stored in `self.request` for later use. The storage
     109        instance is stored in `self.storage`.
     110
     111        After processing the request using the `dispatch` method, the
     112        response gets updated by the storage engine (for example add cookies).
     113        """
     114        # add the storage engine to the current formwizard instance
     115        self.storage = get_storage(
     116            self.storage_name, normalize_name(self.__class__.__name__),
     117            request, getattr(self, 'file_storage', None))
     118        response = super(WizardView, self).dispatch(request, *args, **kwargs)
     119
     120        # update the response (e.g. adding cookies)
     121        self.storage.update_response(response)
     122        return response
     123
     124    def get_form_list(self):
     125        """
     126        This method returns a form_list based on the initial form list but
     127        checks if there is a condition method/value in the condition_list.
     128        If an entry exists in the condition list, it will call/read the value
     129        and respect the result. (True means add the form, False means ignore
     130        the form)
     131
     132        The form_list is always generated on the fly because condition methods
     133        could use data from other (maybe previous forms).
     134        """
     135        form_list = SortedDict()
     136        for form_key, form_class in self.form_list.items():
     137            # try to fetch the value from condition list, by default, the form
     138            # gets passed to the new list.
     139            condition = self.condition_list.get(form_key, True)
     140            if callable(condition):
     141                # call the value if needed, passes the current instance.
     142                condition = condition(self)
     143            if condition:
     144                form_list[form_key] = form_class
     145        return form_list
     146
     147    def get(self, request, *args, **kwargs):
     148        """
     149        This method handles GET requests.
     150
     151        If a GET request reaches this point, the wizard assumes that the user
     152        just starts at the first step or wants to restart the process.
     153        The data of the wizard will be resetted before rendering the first step.
     154        """
     155        self.reset_wizard()
     156
     157        # if there is an extra_context item in the kwars, pass the data to the
     158        # storage engine.
     159        self.update_extra_context(kwargs.get('extra_context', {}))
     160
     161        # reset the current step to the first step.
     162        self.storage.set_current_step(self.get_first_step())
     163        return self.render(self.get_form())
     164
     165    def post(self, *args, **kwargs):
     166        """
     167        This method handles POST requests.
     168
     169        The wizard will render either the current step (if form validation
     170        wasn't successful), the next step (if the current step was stored
     171        successful) or the done view (if no more steps are available)
     172        """
     173        # if there is an extra_context item in the kwargs,
     174        # pass the data to the storage engine.
     175        self.update_extra_context(kwargs.get('extra_context', {}))
     176
     177        # Look for a form_prev_step element in the posted data which contains
     178        # a valid step name. If one was found, render the requested form.
     179        # (This makes stepping back a lot easier).
     180        form_prev_step = self.request.POST.get('form_prev_step', None)
     181        if form_prev_step and form_prev_step in self.get_form_list():
     182            self.storage.set_current_step(form_prev_step)
     183            current_step = self.determine_step()
     184            form = self.get_form(data=self.storage.get_step_data(current_step),
     185                files=self.storage.get_step_files(current_step))
     186        else:
     187            # TODO: refactor the form-was-refreshed code
     188            # Check if form was refreshed
     189            current_step = self.determine_step()
     190            prev_step = self.get_prev_step(step=current_step)
     191            for value in self.request.POST:
     192                if (prev_step and not value.startswith(current_step) and
     193                        value.startswith(prev_step)):
     194                    # form refreshed, change current step
     195                    self.storage.set_current_step(prev_step)
     196                    break
     197
     198            # get the form for the current step
     199            form = self.get_form(data=self.request.POST,
     200                                 files=self.request.FILES)
     201
     202            # and try to validate
     203            if form.is_valid():
     204                # if the form is valid, store the cleaned data and files.
     205                current_step = self.determine_step()
     206                self.storage.set_step_data(current_step, self.process_step(form))
     207                self.storage.set_step_files(current_step, self.process_step_files(form))
     208
     209                # check if the current step is the last step
     210                if current_step == self.get_last_step():
     211                    # no more steps, render done view
     212                    return self.render_done(form, **kwargs)
     213                else:
     214                    # proceed to the next step
     215                    return self.render_next_step(form)
     216        return self.render(form)
     217
     218    def render_next_step(self, form, **kwargs):
     219        """
     220        THis method gets called when the next step/form should be rendered.
     221        `form` contains the last/current form.
     222        """
     223        next_step = self.get_next_step()
     224        # get the form instance based on the data from the storage backend
     225        # (if available).
     226        new_form = self.get_form(next_step,
     227                                 data=self.storage.get_step_data(next_step),
     228                                 files=self.storage.get_step_files(next_step))
     229
     230        # change the stored current step
     231        self.storage.set_current_step(next_step)
     232        return self.render(new_form, **kwargs)
     233
     234    def render_done(self, form, **kwargs):
     235        """
     236        This method gets called when all forms passed. The method should also
     237        re-validate all steps to prevent manipulation. If any form don't
     238        validate, `render_revalidation_failure` should get called.
     239        If everything is fine call `done`.
     240        """
     241        final_form_list = []
     242        # walk through the form list and try to validate the data again.
     243        for form_key in self.get_form_list():
     244            form_obj = self.get_form(
     245                step=form_key,
     246                data=self.storage.get_step_data(form_key),
     247                files=self.storage.get_step_files(form_key)
     248            )
     249            if not form_obj.is_valid():
     250                return self.render_revalidation_failure(form_key,
     251                                                        form_obj,
     252                                                        **kwargs)
     253            final_form_list.append(form_obj)
     254
     255        # render the done view and reset the wizard before returning the
     256        # response. This is needed to prevent from rendering done with the
     257        # same data twice.
     258        done_response = self.done(final_form_list, **kwargs)
     259        self.reset_wizard()
     260        return done_response
     261
     262    def get_form_prefix(self, step=None, form=None):
     263        """
     264        Returns the prefix which will be used when calling the actual form for
     265        the given step. `step` contains the step-name, `form` the form which
     266        will be called with the returned prefix.
     267
     268        If no step is given, the form_prefix will determine the current step
     269        automatically.
     270        """
     271        if step is None:
     272            step = self.determine_step()
     273        return str(step)
     274
     275    def get_form_initial(self, step):
     276        """
     277        Returns a dictionary which will be passed to the form for `step`
     278        as `initial`. If no initial data was provied while initializing the
     279        form wizard, a empty dictionary will be returned.
     280        """
     281        return self.initial_list.get(step, {})
     282
     283    def get_form_instance(self, step):
     284        """
     285        Returns a object which will be passed to the form for `step`
     286        as `instance`. If no instance object was provied while initializing
     287        the form wizard, None be returned.
     288        """
     289        return self.instance_list.get(step, None)
     290
     291    def get_form(self, step=None, data=None, files=None):
     292        """
     293        Constructs the form for a given `step`. If no `step` is defined, the
     294        current step will be determined automatically.
     295
     296        The form will be initialized using the `data` argument to prefill the
     297        new form. If needed, instance or queryset (for `ModelForm` or
     298        `ModelFormSet`) will be added too.
     299        """
     300        if step is None:
     301            step = self.determine_step()
     302
     303        # prepare the kwargs for the form instance.
     304        kwargs = {
     305            'data': data,
     306            'files': files,
     307            'prefix': self.get_form_prefix(step, self.form_list[step]),
     308            'initial': self.get_form_initial(step),
     309        }
     310        if issubclass(self.form_list[step], forms.ModelForm):
     311            # If the form is based on ModelForm, add instance if available.
     312            kwargs.update({'instance': self.get_form_instance(step)})
     313        elif issubclass(self.form_list[step], forms.models.BaseModelFormSet):
     314            # If the form is based on ModelFormSet, add queryset if available.
     315            kwargs.update({'queryset': self.get_form_instance(step)})
     316        return self.form_list[step](**kwargs)
     317
     318    def process_step(self, form):
     319        """
     320        This method is used to postprocess the form data. By default, it
     321        returns the raw `form.data` dictionary.
     322        """
     323        return self.get_form_step_data(form)
     324
     325    def process_step_files(self, form):
     326        """
     327        This method is used to postprocess the form files. By default, it
     328        returns the raw `form.files` dictionary.
     329        """
     330        return self.get_form_step_files(form)
     331
     332    def render_revalidation_failure(self, step, form, **kwargs):
     333        """
     334        Gets called when a form doesn't validate when rendering the done
     335        view. By default, it changed the current step to failing forms step
     336        and renders the form.
     337        """
     338        self.storage.set_current_step(step)
     339        return self.render(form, **kwargs)
     340
     341    def get_form_step_data(self, form):
     342        """
     343        Is used to return the raw form data. You may use this method to
     344        manipulate the data.
     345        """
     346        return form.data
     347
     348    def get_form_step_files(self, form):
     349        """
     350        Is used to return the raw form files. You may use this method to
     351        manipulate the data.
     352        """
     353        return form.files
     354
     355    def get_all_cleaned_data(self):
     356        """
     357        Returns a merged dictionary of all step cleaned_data dictionaries.
     358        If a step contains a `FormSet`, the key will be prefixed with formset
     359        and contain a list of the formset' cleaned_data dictionaries.
     360        """
     361        cleaned_data = {}
     362        for form_key in self.get_form_list():
     363            form_obj = self.get_form(
     364                step=form_key,
     365                data=self.storage.get_step_data(form_key),
     366                files=self.storage.get_step_files(form_key)
     367            )
     368            if form_obj.is_valid():
     369                if isinstance(form_obj.cleaned_data, (tuple, list)):
     370                    cleaned_data.update({
     371                        'formset-%s' % form_key: form_obj.cleaned_data
     372                    })
     373                else:
     374                    cleaned_data.update(form_obj.cleaned_data)
     375        return cleaned_data
     376
     377    def get_cleaned_data_for_step(self, step):
     378        """
     379        Returns the cleaned data for a given `step`. Before returning the
     380        cleaned data, the stored values are being revalidated through the
     381        form. If the data doesn't validate, None will be returned.
     382        """
     383        if step in self.form_list:
     384            form_obj = self.get_form(step=step,
     385                data=self.storage.get_step_data(step),
     386                files=self.storage.get_step_files(step))
     387            if form_obj.is_valid():
     388                return form_obj.cleaned_data
     389        return None
     390
     391    def determine_step(self):
     392        """
     393        Returns the current step. If no current step is stored in the storage
     394        backend, the first step will be returned.
     395        """
     396        return self.storage.get_current_step() or self.get_first_step()
     397
     398    def get_first_step(self):
     399        """
     400        Returns the name of the first step.
     401        """
     402        return self.get_form_list().keys()[0]
     403
     404    def get_last_step(self):
     405        """
     406        Returns the name of the last step.
     407        """
     408        return self.get_form_list().keys()[-1]
     409
     410    def get_next_step(self, step=None):
     411        """
     412        Returns the next step after the given `step`. If no more steps are
     413        available, None will be returned. If the `step` argument is None, the
     414        current step will be determined automatically.
     415        """
     416        if step is None:
     417            step = self.determine_step()
     418        form_list = self.get_form_list()
     419        key = form_list.keyOrder.index(step) + 1
     420        if len(form_list.keyOrder) > key:
     421            return form_list.keyOrder[key]
     422        return None
     423
     424    def get_prev_step(self, step=None):
     425        """
     426        Returns the previous step before the given `step`. If there are no
     427        steps available, None will be returned. If the `step` argument is
     428        None, the current step will be determined automatically.
     429        """
     430        if step is None:
     431            step = self.determine_step()
     432        form_list = self.get_form_list()
     433        key = form_list.keyOrder.index(step) - 1
     434        if key >= 0:
     435            return form_list.keyOrder[key]
     436        return None
     437
     438    def get_step_index(self, step=None):
     439        """
     440        Returns the index for the given `step` name. If no step is given,
     441        the current step will be used to get the index.
     442        """
     443        if step is None:
     444            step = self.determine_step()
     445        return self.get_form_list().keyOrder.index(step)
     446
     447    def get_num_steps(self):
     448        """
     449        Returns the total number of steps/forms in this the wizard.
     450        """
     451        return len(self.get_form_list())
     452
     453    def reset_wizard(self):
     454        """
     455        Resets the user-state of the wizard.
     456        """
     457        self.storage.reset()
     458
     459    def get_context_data(self, form, *args, **kwargs):
     460        """
     461        Returns the template context for a step. You can overwrite this method
     462        to add more data for all or some steps.
     463        Example:
     464
     465        .. code-block:: python
     466
     467            class MyWizard(FormWizard):
     468                def get_context_data(self, form, **kwargs):
     469                    context = super(MyWizard, self).get_context_data(form, **kwargs)
     470                    if self.storage.get_current_step() == 'my_step_name':
     471                        context.update({'another_var': True})
     472                    return context
     473        """
     474        context = super(WizardView, self).get_context_data(*args, **kwargs)
     475        context.update({
     476            'extra_context': self.get_extra_context(),
     477            'form_step': self.determine_step(),
     478            'form_first_step': self.get_first_step(),
     479            'form_last_step': self.get_last_step(),
     480            'form_prev_step': self.get_prev_step(),
     481            'form_next_step': self.get_next_step(),
     482            'form_step0': int(self.get_step_index()),
     483            'form_step1': int(self.get_step_index()) + 1,
     484            'form_step_count': self.get_num_steps(),
     485            'form': form,
     486        })
     487        # if there is an extra_context item in the kwars, pass the data to the
     488        # storage engine.
     489        self.update_extra_context(kwargs.get('extra_context', {}))
     490        return context
     491
     492    def get_extra_context(self):
     493        """
     494        Returns the extra data currently stored in the storage backend.
     495        """
     496        return self.storage.get_extra_context_data()
     497
     498    def update_extra_context(self, new_context):
     499        """
     500        Updates the currently stored extra context data. Already stored extra
     501        context will be kept!
     502        """
     503        context = self.get_extra_context()
     504        context.update(new_context)
     505        return self.storage.set_extra_context_data(context)
     506
     507    def render(self, form, **kwargs):
     508        """
     509        Renders the acutal `form`. This method can be used to pre-process data
     510        or conditionally skip steps.
     511        """
     512        return self.render_template(form, **kwargs)
     513
     514    def render_template(self, form=None, **kwargs):
     515        """
     516        Returns a `HttpResponse` containing the rendered form step. Available
     517        template context variables are:
     518
     519         * `extra_context` - current extra context data
     520         * `form_step` - name of the current step
     521         * `form_first_step` - name of the first step
     522         * `form_last_step` - name of the last step
     523         * `form_prev_step`- name of the previous step
     524         * `form_next_step` - name of the next step
     525         * `form_step0` - index of the current step
     526         * `form_step1` - index of the current step as a 1-index
     527         * `form_step_count` - total number of steps
     528         * `form` - form instance of the current step
     529        """
     530
     531        form = form or self.get_form()
     532        context = self.get_context_data(form, **kwargs)
     533        return self.render_to_response(context)
     534
     535    def done(self, form_list, **kwargs):
     536        """
     537        This method muss be overrided by a subclass to process to form data
     538        after processing all steps.
     539        """
     540        raise NotImplementedError("Your %s class has not defined a done() "
     541            "method, which is required." % self.__class__.__name__)
     542
     543
     544class SessionWizardView(WizardView):
     545    """
     546    A WizardView with pre-configured SessionStorage backend.
     547    """
     548    storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage'
     549
     550
     551class CookieWizardView(WizardView):
     552    """
     553    A WizardView with pre-configured CookieStorage backend.
     554    """
     555    storage_name = 'django.contrib.formtools.wizard.storage.cookie.CookieStorage'
     556
     557
     558class NamedUrlWizardView(WizardView):
     559    """
     560    A WizardView with url-named steps support.
     561    """
     562    url_name = None
     563    done_step_name = None
     564
     565    @classmethod
     566    def get_initkwargs(cls, *args, **kwargs):
     567        """
     568        We require a url_name to reverse urls later. Additionally users can
     569        pass a done_step_name to change the url-name of the "done" view.
     570        """
     571        extra_kwargs = {
     572            'done_step_name': 'done'
     573        }
     574        assert 'url_name' in kwargs, 'url name is needed to resolve correct wizard urls'
     575        extra_kwargs['url_name'] = kwargs['url_name']
     576        del kwargs['url_name']
     577
     578        if 'done_step_name' in kwargs:
     579            extra_kwargs['done_step_name'] = kwargs['done_step_name']
     580            del kwargs['done_step_name']
     581
     582        initkwargs = super(NamedUrlWizardView, cls).get_initkwargs(*args, **kwargs)
     583        initkwargs.update(extra_kwargs)
     584
     585        assert initkwargs['done_step_name'] not in initkwargs['form_list'], \
     586            'step name "%s" is reserved for "done" view' % initkwargs['done_step_name']
     587
     588        return initkwargs
     589
     590    def get(self, *args, **kwargs):
     591        """
     592        This renders the form or, if needed, does the http redirects.
     593        """
     594        self.update_extra_context(kwargs.get('extra_context', {}))
     595        step_url = kwargs.get('step', None)
     596        if step_url is None:
     597            if 'reset' in self.request.GET:
     598                self.reset_wizard()
     599                self.storage.set_current_step(self.get_first_step())
     600
     601            if self.request.GET:
     602                query_string = "?%s" % self.request.GET.urlencode()
     603            else:
     604                query_string = ""
     605            next_step_url = reverse(self.url_name, kwargs={
     606                'step': self.determine_step()
     607            }) + query_string
     608            return HttpResponseRedirect(next_step_url)
     609        else:
     610            # is the current step the "done" name/view?
     611            if step_url == self.done_step_name:
     612                last_step = self.get_last_step()
     613                return self.render_done(self.get_form(step=last_step,
     614                    data=self.storage.get_step_data(last_step),
     615                    files=self.storage.get_step_files(last_step)
     616                ), **kwargs)
     617
     618            # is the url step name not equal to the step in the storage?
     619            # if yes, change the step in the storage (if name exists)
     620            if step_url == self.determine_step():
     621                # url step name and storage step name are equal, render!
     622                return self.render(self.get_form(
     623                    data=self.storage.get_current_step_data(),
     624                    files=self.storage.get_current_step_files()
     625                ), **kwargs)
     626            if step_url in self.get_form_list():
     627                self.storage.set_current_step(step_url)
     628                return self.render(self.get_form(
     629                    data=self.storage.get_current_step_data(),
     630                    files=self.storage.get_current_step_files()
     631                ), **kwargs)
     632            else:
     633                # invalid step name, reset to first and redirect.
     634                self.storage.set_current_step(self.get_first_step())
     635                first_step_url = reverse(self.url_name, kwargs={
     636                    'step': self.storage.get_current_step()
     637                })
     638                return HttpResponseRedirect(first_step_url)
     639
     640    def post(self, *args, **kwargs):
     641        """
     642        Do a redirect if user presses the prev. step button. The rest of this
     643        is super'd from FormWizard.
     644        """
     645        prev_step = self.request.POST.get('form_prev_step', None)
     646        if prev_step and prev_step in self.get_form_list():
     647            self.storage.set_current_step(prev_step)
     648            current_step_url = reverse(self.url_name, kwargs={
     649                'step': self.storage.get_current_step(),
     650            })
     651            return HttpResponseRedirect(current_step_url)
     652        return super(NamedUrlWizardView, self).post(*args, **kwargs)
     653
     654    def render_next_step(self, form, **kwargs):
     655        """
     656        When using the NamedUrlFormWizard, we have to redirect to update the
     657        browser's url to match the shown step.
     658        """
     659        next_step = self.get_next_step()
     660        next_step_url = reverse(self.url_name, kwargs={
     661            'step': next_step,
     662        })
     663        self.storage.set_current_step(next_step)
     664        return HttpResponseRedirect(next_step_url)
     665
     666    def render_revalidation_failure(self, failed_step, form, **kwargs):
     667        """
     668        When a step fails, we have to redirect the user to the first failing
     669        step.
     670        """
     671        self.storage.set_current_step(failed_step)
     672        return HttpResponseRedirect(reverse(self.url_name, kwargs={
     673            'step': self.storage.get_current_step()
     674        }))
     675
     676    def render_done(self, form, **kwargs):
     677        """
     678        When rendering the done view, we have to redirect first (if the url
     679        name doesn't fit).
     680        """
     681        step_url = kwargs.get('step', None)
     682        if step_url != self.done_step_name:
     683            return HttpResponseRedirect(reverse(self.url_name, kwargs={
     684                'step': self.done_step_name
     685            }))
     686        return super(NamedUrlWizardView, self).render_done(form, **kwargs)
     687
     688class NamedUrlSessionWizardView(NamedUrlWizardView):
     689    """
     690    A NamedUrlWizardView with pre-configured SessionStorage backend.
     691    """
     692    storage_name = 'django.contrib.formtools.wizard.storage.session.SessionStorage'
     693
     694
     695class NamedUrlCookieWizardView(NamedUrlWizardView):
     696    """
     697    A NamedUrlFormWizard with pre-configured CookieStorageBackend.
     698    """
     699    storage_name = 'django.contrib.formtools.wizard.storage.cookie.CookieStorage'
     700
  • new file django/core/signing.py

    diff --git a/django/core/signing.py b/django/core/signing.py
    new file mode 100644
    index 0000000..70fcc44
    - +  
     1"""
     2Functions for creating and restoring url-safe signed JSON objects.
     3
     4The format used looks like this:
     5
     6>>> signed.dumps("hello")
     7'ImhlbGxvIg.RjVSUCt6S64WBilMYxG89-l0OA8'
     8
     9There are two components here, separatad by a '.'. The first component is a
     10URLsafe base64 encoded JSON of the object passed to dumps(). The second
     11component is a base64 encoded hmac/SHA1 hash of "$first_component.$secret"
     12
     13signed.loads(s) checks the signature and returns the deserialised object.
     14If the signature fails, a BadSignature exception is raised.
     15
     16>>> signed.loads("ImhlbGxvIg.RjVSUCt6S64WBilMYxG89-l0OA8")
     17u'hello'
     18>>> signed.loads("ImhlbGxvIg.RjVSUCt6S64WBilMYxG89-l0OA8-modified")
     19...
     20BadSignature: Signature failed: RjVSUCt6S64WBilMYxG89-l0OA8-modified
     21
     22You can optionally compress the JSON prior to base64 encoding it to save
     23space, using the compress=True argument. This checks if compression actually
     24helps and only applies compression if the result is a shorter string:
     25
     26>>> signed.dumps(range(1, 20), compress=True)
     27'.eJwFwcERACAIwLCF-rCiILN47r-GyZVJsNgkxaFxoDgxcOHGxMKD_T7vhAml.oFq6lAAEbkHXBHfGnVX7Qx6NlZ8'
     28
     29The fact that the string is compressed is signalled by the prefixed '.' at the
     30start of the base64 JSON.
     31
     32There are 65 url-safe characters: the 64 used by url-safe base64 and the '.'.
     33These functions make use of all of them.
     34"""
     35import hmac
     36import base64
     37import time
     38
     39from django.conf import settings
     40from django.utils.hashcompat import sha_constructor
     41from django.utils import baseconv, simplejson
     42from django.utils.crypto import constant_time_compare
     43from django.utils.encoding import force_unicode, smart_str
     44from django.utils.importlib import import_module
     45
     46class BadSignature(Exception):
     47    """
     48    Signature does not match
     49    """
     50    pass
     51
     52
     53class SignatureExpired(BadSignature):
     54    """
     55    Signature timestamp is older than required max_age
     56    """
     57    pass
     58
     59
     60def b64_encode(s):
     61    return base64.urlsafe_b64encode(s).strip('=')
     62
     63
     64def b64_decode(s):
     65    pad = '=' * (-len(s) % 4)
     66    return base64.urlsafe_b64decode(s + pad)
     67
     68
     69def base64_hmac(value, key):
     70    return b64_encode((hmac.new(key, value, sha_constructor).digest()))
     71
     72
     73def get_cookie_signer():
     74    modpath = settings.SIGNING_BACKEND
     75    module, attr = modpath.rsplit('.', 1)
     76    try:
     77        mod = import_module(module)
     78    except ImportError, e:
     79        raise ImproperlyConfigured(
     80            'Error importing cookie signer %s: "%s"' % (modpath, e))
     81    try:
     82        Signer = getattr(mod, attr)
     83    except AttributeError, e:
     84        raise ImproperlyConfigured(
     85            'Error importing cookie signer %s: "%s"' % (modpath, e))
     86    return Signer('django.http.cookies' + settings.SECRET_KEY)
     87
     88
     89def dumps(obj, key=None, salt='', compress=False):
     90    """
     91    Returns URL-safe, sha1 signed base64 compressed JSON string. If key is
     92    None, settings.SECRET_KEY is used instead.
     93
     94    If compress is True (not the default) checks if compressing using zlib can
     95    save some space. Prepends a '.' to signify compression. This is included
     96    in the signature, to protect against zip bombs.
     97
     98    salt can be used to further salt the hash, in case you're worried
     99    that the NSA might try to brute-force your SHA-1 protected secret.
     100    """
     101    json = simplejson.dumps(obj, separators=(',', ':'))
     102
     103    # Flag for if it's been compressed or not
     104    is_compressed = False
     105
     106    if compress:
     107        # Avoid zlib dependency unless compress is being used
     108        import zlib
     109        compressed = zlib.compress(json)
     110        if len(compressed) < (len(json) - 1):
     111            json = compressed
     112            is_compressed = True
     113    base64d = b64_encode(json)
     114    if is_compressed:
     115        base64d = '.' + base64d
     116    return TimestampSigner(key).sign(base64d, salt=salt)
     117
     118
     119def loads(s, key=None, salt='', max_age=None):
     120    """
     121    Reverse of dumps(), raises BadSignature if signature fails
     122    """
     123    base64d = smart_str(
     124        TimestampSigner(key).unsign(s, salt=salt, max_age=max_age))
     125    decompress = False
     126    if base64d[0] == '.':
     127        # It's compressed; uncompress it first
     128        base64d = base64d[1:]
     129        decompress = True
     130    json = b64_decode(base64d)
     131    if decompress:
     132        import zlib
     133        jsond = zlib.decompress(json)
     134    return simplejson.loads(json)
     135
     136
     137class Signer(object):
     138    def __init__(self, key=None, sep=':'):
     139        self.sep = sep
     140        self.key = key or settings.SECRET_KEY
     141
     142    def signature(self, value, salt=''):
     143        # Derive a new key from the SECRET_KEY, using the optional salt
     144        key = sha_constructor(salt + 'signer' + self.key).hexdigest()
     145        return base64_hmac(value, key)
     146
     147    def sign(self, value, salt=''):
     148        value = smart_str(value)
     149        return '%s%s%s' % (value, self.sep, self.signature(value, salt=salt))
     150
     151    def unsign(self, signed_value, salt=''):
     152        signed_value = smart_str(signed_value)
     153        if not self.sep in signed_value:
     154            raise BadSignature('No "%s" found in value' % self.sep)
     155        value, sig = signed_value.rsplit(self.sep, 1)
     156        expected = self.signature(value, salt=salt)
     157        if constant_time_compare(sig, expected):
     158            return force_unicode(value)
     159        # Important: do NOT include the expected sig in the exception
     160        # message, since it might leak up to an attacker!
     161        # TODO: Can we enforce this in the Django debug templates?
     162        raise BadSignature('Signature "%s" does not match' % sig)
     163
     164
     165class TimestampSigner(Signer):
     166    def timestamp(self):
     167        return baseconv.base62.from_int(int(time.time()))
     168
     169    def sign(self, value, salt=''):
     170        value = smart_str('%s%s%s' % (value, self.sep, self.timestamp()))
     171        return '%s%s%s' % (value, self.sep, self.signature(value, salt=salt))
     172
     173    def unsign(self, value, salt='', max_age=None):
     174        value, timestamp = super(TimestampSigner, self).unsign(
     175            value, salt=salt).rsplit(self.sep, 1)
     176        timestamp = baseconv.base62.to_int(timestamp)
     177        if max_age is not None:
     178            # Check timestamp is not older than max_age
     179            age = time.time() - timestamp
     180            if age > max_age:
     181                raise SignatureExpired(
     182                    'Signature age %s > %s seconds' % (age, max_age))
     183        return value
  • django/http/__init__.py

    diff --git a/django/http/__init__.py b/django/http/__init__.py
    index 0d28ec0..0a0d665 100644
    a b from django.utils.encoding import smart_str, iri_to_uri, force_unicode  
    122122from django.utils.http import cookie_date
    123123from django.http.multipartparser import MultiPartParser
    124124from django.conf import settings
     125from django.core import signing
    125126from django.core.files import uploadhandler
    126127from utils import *
    127128
    absolute_http_url_re = re.compile(r"^https?://", re.I)  
    132133class Http404(Exception):
    133134    pass
    134135
     136RAISE_ERROR = object()
     137
    135138class HttpRequest(object):
    136139    """A basic HTTP request."""
    137140
    class HttpRequest(object):  
    170173        # Rather than crash if this doesn't happen, we encode defensively.
    171174        return '%s%s' % (self.path, self.META.get('QUERY_STRING', '') and ('?' + iri_to_uri(self.META.get('QUERY_STRING', ''))) or '')
    172175
     176    def get_signed_cookie(self, key, default=RAISE_ERROR, salt='',
     177                          max_age=None):
     178        """
     179        Attempts to return a signed cookie. If the signature fails or the
     180        cookie has expired, raises an exception... unless you provide the
     181        default argument in which case that value will be returned instead.
     182        """
     183        try:
     184            cookie_value = self.COOKIES[key].encode('utf-8')
     185        except KeyError:
     186            if default is not RAISE_ERROR:
     187                return default
     188            else:
     189                raise
     190        try:
     191            value = signing.get_cookie_signer().unsign(
     192                cookie_value, salt=key + salt, max_age=max_age)
     193        except signing.BadSignature:
     194            if default is not RAISE_ERROR:
     195                return default
     196            else:
     197                raise
     198        return value
     199
    173200    def build_absolute_uri(self, location=None):
    174201        """
    175202        Builds an absolute URI from the location and the variables available in
    class HttpResponse(object):  
    584611        if httponly:
    585612            self.cookies[key]['httponly'] = True
    586613
     614    def set_signed_cookie(self, key, value, salt='', **kwargs):
     615        value = signing.get_cookie_signer().sign(value, salt=key + salt)
     616        return self.set_cookie(key, value, **kwargs)
     617
    587618    def delete_cookie(self, key, path='/', domain=None):
    588619        self.set_cookie(key, max_age=0, path=path, domain=domain,
    589620                        expires='Thu, 01-Jan-1970 00:00:00 GMT')
    def str_to_unicode(s, encoding):  
    686717        return unicode(s, encoding, 'replace')
    687718    else:
    688719        return s
    689 
  • new file django/utils/baseconv.py

    diff --git a/django/utils/baseconv.py b/django/utils/baseconv.py
    new file mode 100644
    index 0000000..db152f7
    - +  
     1"""
     2Convert numbers from base 10 integers to base X strings and back again.
     3
     4Sample usage:
     5
     6>>> base20 = BaseConverter('0123456789abcdefghij')
     7>>> base20.from_int(1234)
     8'31e'
     9>>> base20.to_int('31e')
     101234
     11"""
     12
     13
     14class BaseConverter(object):
     15    decimal_digits = "0123456789"
     16
     17    def __init__(self, digits):
     18        self.digits = digits
     19
     20    def from_int(self, i):
     21        return self.convert(i, self.decimal_digits, self.digits)
     22
     23    def to_int(self, s):
     24        return int(self.convert(s, self.digits, self.decimal_digits))
     25
     26    def convert(number, fromdigits, todigits):
     27        # Based on http://code.activestate.com/recipes/111286/
     28        if str(number)[0] == '-':
     29            number = str(number)[1:]
     30            neg = 1
     31        else:
     32            neg = 0
     33
     34        # make an integer out of the number
     35        x = 0
     36        for digit in str(number):
     37            x = x * len(fromdigits) + fromdigits.index(digit)
     38
     39        # create the result in base 'len(todigits)'
     40        if x == 0:
     41            res = todigits[0]
     42        else:
     43            res = ""
     44            while x > 0:
     45                digit = x % len(todigits)
     46                res = todigits[digit] + res
     47                x = int(x / len(todigits))
     48            if neg:
     49                res = '-' + res
     50        return res
     51    convert = staticmethod(convert)
     52
     53base2 = BaseConverter('01')
     54base16 = BaseConverter('0123456789ABCDEF')
     55base36 = BaseConverter('0123456789abcdefghijklmnopqrstuvwxyz')
     56base62 = BaseConverter(
     57    '0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'
     58)
  • docs/index.txt

    diff --git a/docs/index.txt b/docs/index.txt
    index 9135d32..8b4ae53 100644
    a b Other batteries included  
    171171    * :doc:`Comments <ref/contrib/comments/index>` | :doc:`Moderation <ref/contrib/comments/moderation>` | :doc:`Custom comments <ref/contrib/comments/custom>`
    172172    * :doc:`Content types <ref/contrib/contenttypes>`
    173173    * :doc:`Cross Site Request Forgery protection <ref/contrib/csrf>`
     174    * :doc:`Cryptographic signing <topics/signing>`
    174175    * :doc:`Databrowse <ref/contrib/databrowse>`
    175176    * :doc:`E-mail (sending) <topics/email>`
    176177    * :doc:`Flatpages <ref/contrib/flatpages>`
  • docs/ref/request-response.txt

    diff --git a/docs/ref/request-response.txt b/docs/ref/request-response.txt
    index 6281120..e17c0a7 100644
    a b Methods  
    240240
    241241   Example: ``"http://example.com/music/bands/the_beatles/?print=true"``
    242242
     243.. method:: HttpRequest.get_signed_cookie(key, default=RAISE_ERROR, salt='', max_age=None)
     244
     245   .. versionadded:: 1.4
     246
     247   Returns a cookie value for a signed cookie, or raises a
     248   :class:`~django.core.signing.BadSignature` exception if the signature is
     249   no longer valid. If you provide the ``default`` argument the exception
     250   will be suppressed and that default value will be returned instead.
     251
     252   The optional ``salt`` argument can be used to provide extra protection
     253   against brute force attacks on your secret key. If supplied, the
     254   ``max_age`` argument will be checked against the signed timestamp
     255   attached to the cookie value to ensure the cookie is not older than
     256   ``max_age`` seconds.
     257
     258   For example::
     259
     260          >>> request.get_signed_cookie('name')
     261          'Tony'
     262          >>> request.get_signed_cookie('name', salt='name-salt')
     263          'Tony' # assuming cookie was set using the same salt
     264          >>> request.get_signed_cookie('non-existing-cookie')
     265          ...
     266          KeyError: 'non-existing-cookie'
     267          >>> request.get_signed_cookie('non-existing-cookie', False)
     268          False
     269          >>> request.get_signed_cookie('cookie-that-was-tampered-with')
     270          ...
     271          BadSignature: ...
     272          >>> request.get_signed_cookie('name', max_age=60)
     273          ...
     274          SignatureExpired: Signature age 1677.3839159 > 60 seconds
     275          >>> request.get_signed_cookie('name', False, max_age=60)
     276          False
     277
     278   See :ref:`cryptographic signing <topics-signing>` for more information.
     279
    243280.. method:: HttpRequest.is_secure()
    244281
    245282   Returns ``True`` if the request is secure; that is, if it was made with
    Methods  
    618655    .. _`cookie Morsel`: http://docs.python.org/library/cookie.html#Cookie.Morsel
    619656    .. _HTTPOnly: http://www.owasp.org/index.php/HTTPOnly
    620657
     658.. method:: HttpResponse.set_signed_cookie(key, value='', salt='', max_age=None, expires=None, path='/', domain=None, secure=None, httponly=False)
     659
     660    .. versionadded:: 1.4
     661
     662    Like :meth:`~HttpResponse.set_cookie()`, but
     663    :ref:`cryptographically signs <topics-signing>` the cookie before setting
     664    it. Use in conjunction with :meth:`HttpRequest.get_signed_cookie`.
     665    You can use the optional ``salt`` argument for added key strength, but
     666    you will need to remember to pass it to the corresponding
     667    :meth:`HttpRequest.get_signed_cookie` call.
     668
    621669.. method:: HttpResponse.delete_cookie(key, path='/', domain=None)
    622670
    623671    Deletes the cookie with the given key. Fails silently if the key doesn't
  • docs/ref/settings.txt

    diff --git a/docs/ref/settings.txt b/docs/ref/settings.txt
    index f5f1226..38977e8 100644
    a b See :tfilter:`allowed date format strings <date>`.  
    16471647
    16481648See also ``DATE_FORMAT`` and ``SHORT_DATETIME_FORMAT``.
    16491649
     1650.. setting:: SIGNING_BACKEND
     1651
     1652SIGNING_BACKEND
     1653---------------
     1654
     1655.. versionadded:: 1.4
     1656
     1657Default: 'django.core.signing.TimestampSigner'
     1658
     1659The backend used for signing cookies and other data.
     1660
     1661See also the :ref:`topics-signing` documentation.
     1662
    16501663.. setting:: SITE_ID
    16511664
    16521665SITE_ID
  • docs/topics/index.txt

    diff --git a/docs/topics/index.txt b/docs/topics/index.txt
    index 49a03be..84f9e9f 100644
    a b Introductions to all the key parts of Django you'll need to know:  
    1818   auth
    1919   cache
    2020   conditional-view-processing
     21   signing
    2122   email
    2223   i18n/index
    2324   logging
  • new file docs/topics/signing.txt

    diff --git a/docs/topics/signing.txt b/docs/topics/signing.txt
    new file mode 100644
    index 0000000..c94462c
    - +  
     1.. _topics-signing:
     2
     3=====================
     4Cryptographic signing
     5=====================
     6
     7.. module:: django.core.signing
     8   :synopsis: Django's signing framework.
     9
     10.. versionadded:: 1.4
     11
     12The golden rule of Web application security is to never trust data from
     13untrusted sources. Sometimes it can be useful to pass data through an
     14untrusted medium. Cryptographically signed values can be passed through an
     15untrusted channel safe in the knowledge that any tampering will be detected.
     16
     17Django provides both a low-level API for signing values and a high-level API
     18for setting and reading signed cookies, one of the most common uses of
     19signing in Web applications.
     20
     21You may also find signing useful for the following:
     22
     23    * Generating "recover my account" URLs for sending to users who have
     24      lost their password.
     25
     26    * Ensuring data stored in hidden form fields has not been tampered with.
     27
     28    * Generating one-time secret URLs for allowing temporary access to a
     29      protected resource, for example a downloadable file that a user has
     30      paid for.
     31
     32Protecting the SECRET_KEY
     33=========================
     34
     35When you create a new Django project using :djadmin:`startproject`, the
     36``settings.py`` file it generates automatically gets a random
     37:setting:`SECRET_KEY` value. This value is the key to securing signed
     38data -- it is vital you keep this secure, or attackers could use it to
     39generate their own signed values.
     40
     41Using the low-level API
     42=======================
     43
     44.. class:: Signer
     45
     46Django's signing methods live in the ``django.core.signing`` module.
     47To sign a value, first instantiate a ``Signer`` instance::
     48
     49    >>> from django.core.signing import Signer
     50    >>> signer = Signer()
     51    >>> value = signer.sign('My string')
     52    >>> value
     53    'My string:GdMGD6HNQ_qdgxYP8yBZAdAIV1w'
     54
     55The signature is appended to the end of the string, following the colon.
     56You can retrieve the original value using the ``unsign`` method::
     57
     58    >>> original = signer.unsign(value)
     59    >>> original
     60    u'My string'
     61
     62If the signature or value have been altered in any way, a
     63``django.core.signing.BadSigature`` exception will be raised::
     64
     65    >>> value += 'm'
     66    >>> try:
     67    ...    original = signer.unsign(value)
     68    ... except signing.BadSignature:
     69    ...    print "Tampering detected!"
     70
     71By default, the ``Signer`` class uses the :setting:`SECRET_KEY` setting to
     72generate signatures. You can use a different secret by passing it to the
     73``Signer`` constructor::
     74
     75    >>> signer = Signer('my-other-secret')
     76    >>> value = signer.sign('My string')
     77    >>> value
     78    'My string:EkfQJafvGyiofrdGnuthdxImIJw'
     79
     80Using the salt argument
     81-----------------------
     82
     83If you do not wish to use the same key for every signing operation in your
     84application, you can use the optional ``salt`` argument to the ``sign`` and
     85``unsign`` methods to further strengthen your :setting:`SECRET_KEY` against
     86brute force attacks. Using a salt will cause a new key to be derived from
     87both the salt and your :setting:`SECRET_KEY`::
     88
     89    >>> signer = Signer()
     90    >>> signer.sign('My string')
     91    'My string:GdMGD6HNQ_qdgxYP8yBZAdAIV1w'
     92    >>> signer.sign('My string', salt='extra')
     93    'My string:Ee7vGi-ING6n02gkcJ-QLHg6vFw'
     94    >>> signer.unsign('My string:Ee7vGi-ING6n02gkcJ-QLHg6vFw', salt='extra')
     95    u'My string'
     96
     97Unlike your :setting:`SECRET_KEY`, your salt argument does not need to stay
     98secret.
     99
     100Verifying timestamped values
     101----------------------------
     102
     103.. class:: TimestampSigner
     104
     105``TimestampSigner`` is a subclass of :class:`~Signer` that appends a signed
     106timestamp to the value. This allows you to confirm that a signed value was
     107created within a specified period of time::
     108
     109    >>> from django.core.signing import TimestampSigner
     110    >>> signer = TimestampSigner()
     111    >>> value = signer.sign('hello')
     112    >>> value
     113    'hello:1NMg5H:oPVuCqlJWmChm1rA2lyTUtelC-c'
     114    >>> signer.unsign(value)
     115    u'hello'
     116    >>> signer.unsign(value, max_age=10)
     117    ...
     118    SignatureExpired: Signature age 15.5289158821 > 10 seconds
     119    >>> signer.unsign(value, max_age=20)
     120    u'hello'
     121
     122Protecting complex data structures
     123----------------------------------
     124
     125If you wish to protect a list, tuple or dictionary you can do so using the
     126signing module's dumps and loads functions. These imitate Python's pickle
     127module, but uses JSON serialization under the hood. JSON ensures that even
     128if your :setting:`SECRET_KEY` is stolen an attacker will not be able to
     129execute arbitrary commands by exploiting the pickle format.::
     130
     131    >>> from django.core import signing
     132    >>> value = signing.dumps({"foo": "bar"})
     133    >>> value
     134    'eyJmb28iOiJiYXIifQ:1NMg1b:zGcDE4-TCkaeGzLeW9UQwZesciI'
     135    >>> signing.loads(value)
     136    {'foo': 'bar'}
  • new file tests/regressiontests/signed_cookies_tests/models.py

    diff --git a/tests/regressiontests/signed_cookies_tests/__init__.py b/tests/regressiontests/signed_cookies_tests/__init__.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/tests/regressiontests/signed_cookies_tests/models.py b/tests/regressiontests/signed_cookies_tests/models.py
    new file mode 100644
    index 0000000..71abcc5
    - +  
     1# models.py file for tests to run.
  • new file tests/regressiontests/signed_cookies_tests/tests.py

    diff --git a/tests/regressiontests/signed_cookies_tests/tests.py b/tests/regressiontests/signed_cookies_tests/tests.py
    new file mode 100644
    index 0000000..c28892a
    - +  
     1import time
     2
     3from django.core import signing
     4from django.http import HttpRequest, HttpResponse
     5from django.test import TestCase
     6
     7class SignedCookieTest(TestCase):
     8
     9    def test_can_set_and_read_signed_cookies(self):
     10        response = HttpResponse()
     11        response.set_signed_cookie('c', 'hello')
     12        self.assertIn('c', response.cookies)
     13        self.assertTrue(response.cookies['c'].value.startswith('hello:'))
     14        request = HttpRequest()
     15        request.COOKIES['c'] = response.cookies['c'].value
     16        value = request.get_signed_cookie('c')
     17        self.assertEqual(value, u'hello')
     18
     19    def test_can_use_salt(self):
     20        response = HttpResponse()
     21        response.set_signed_cookie('a', 'hello', salt='one')
     22        request = HttpRequest()
     23        request.COOKIES['a'] = response.cookies['a'].value
     24        value = request.get_signed_cookie('a', salt='one')
     25        self.assertEqual(value, u'hello')
     26        self.assertRaises(signing.BadSignature,
     27            request.get_signed_cookie, 'a', salt='two')
     28
     29    def test_detects_tampering(self):
     30        response = HttpResponse()
     31        response.set_signed_cookie('c', 'hello')
     32        request = HttpRequest()
     33        request.COOKIES['c'] = response.cookies['c'].value[:-2] + '$$'
     34        self.assertRaises(signing.BadSignature,
     35            request.get_signed_cookie, 'c')
     36
     37    def test_default_argument_supresses_exceptions(self):
     38        response = HttpResponse()
     39        response.set_signed_cookie('c', 'hello')
     40        request = HttpRequest()
     41        request.COOKIES['c'] = response.cookies['c'].value[:-2] + '$$'
     42        self.assertEqual(request.get_signed_cookie('c', default=None), None)
     43
     44    def test_max_age_argument(self):
     45        value = u'hello'
     46        _time = time.time
     47        time.time = lambda: 123456789
     48        try:
     49            response = HttpResponse()
     50            response.set_signed_cookie('c', value)
     51            request = HttpRequest()
     52            request.COOKIES['c'] = response.cookies['c'].value
     53            self.assertEqual(request.get_signed_cookie('c'), value)
     54
     55            time.time = lambda: 123456800
     56            self.assertEqual(request.get_signed_cookie('c', max_age=12), value)
     57            self.assertEqual(request.get_signed_cookie('c', max_age=11), value)
     58            self.assertRaises(signing.SignatureExpired,
     59                request.get_signed_cookie, 'c', max_age = 10)
     60        finally:
     61            time.time = _time
  • new file tests/regressiontests/signing/models.py

    diff --git a/tests/regressiontests/signing/__init__.py b/tests/regressiontests/signing/__init__.py
    new file mode 100644
    index 0000000..e69de29
    diff --git a/tests/regressiontests/signing/models.py b/tests/regressiontests/signing/models.py
    new file mode 100644
    index 0000000..71abcc5
    - +  
     1# models.py file for tests to run.
  • new file tests/regressiontests/signing/tests.py

    diff --git a/tests/regressiontests/signing/tests.py b/tests/regressiontests/signing/tests.py
    new file mode 100644
    index 0000000..0c28f53
    - +  
     1import time
     2
     3from django.core import signing
     4from django.test import TestCase
     5from django.utils.encoding import force_unicode
     6from django.utils.hashcompat import sha_constructor
     7
     8class TestSigner(TestCase):
     9
     10    def test_signature(self):
     11        "signature() method should generate a signature"
     12        signer = signing.Signer('predictable-secret')
     13        signer2 = signing.Signer('predictable-secret2')
     14        for s in (
     15            'hello',
     16            '3098247:529:087:',
     17            u'\u2019'.encode('utf8'),
     18        ):
     19            self.assertEqual(
     20                signer.signature(s),
     21                signing.base64_hmac(s, sha_constructor(
     22                    'signer' + 'predictable-secret'
     23                ).hexdigest())
     24            )
     25            self.assertNotEqual(signer.signature(s), signer2.signature(s))
     26
     27    def test_signature_with_salt(self):
     28        "signature(value, salt=...) should work"
     29        signer = signing.Signer('predictable-secret')
     30        self.assertEqual(
     31            signer.signature('hello', salt='extra-salt'),
     32            signing.base64_hmac('hello', sha_constructor(
     33                'extra-salt' + 'signer' + 'predictable-secret'
     34            ).hexdigest())
     35        )
     36        self.assertNotEqual(
     37            signer.signature('hello', salt='one'),
     38            signer.signature('hello', salt='two'))
     39
     40    def test_sign_unsign(self):
     41        "sign/unsign should be reversible"
     42        signer = signing.Signer('predictable-secret')
     43        examples = (
     44            'q;wjmbk;wkmb',
     45            '3098247529087',
     46            '3098247:529:087:',
     47            'jkw osanteuh ,rcuh nthu aou oauh ,ud du',
     48            u'\u2019',
     49        )
     50        for example in examples:
     51            self.assertNotEqual(
     52                force_unicode(example), force_unicode(signer.sign(example)))
     53            self.assertEqual(example, signer.unsign(signer.sign(example)))
     54
     55    def unsign_detects_tampering(self):
     56        "unsign should raise an exception if the value has been tampered with"
     57        signer = signing.Signer('predictable-secret')
     58        value = 'Another string'
     59        signed_value = signer.sign(value)
     60        transforms = (
     61            lambda s: s.upper(),
     62            lambda s: s + 'a',
     63            lambda s: 'a' + s[1:],
     64            lambda s: s.replace(':', ''),
     65        )
     66        self.assertEqual(value, signer.unsign(signed_value))
     67        for transform in transforms:
     68            self.assertRaises(
     69                signing.BadSignature, signer.unsign, transform(signed_value))
     70
     71    def test_dumps_loads(self):
     72        "dumps and loads be reversible for any JSON serializable object"
     73        objects = (
     74            ['a', 'list'],
     75            'a string',
     76            u'a unicode string \u2019',
     77            {'a': 'dictionary'},
     78        )
     79        for o in objects:
     80            self.assertNotEqual(o, signing.dumps(o))
     81            self.assertEqual(o, signing.loads(signing.dumps(o)))
     82
     83    def test_decode_detects_tampering(self):
     84        "loads should raise exception for tampered objects"
     85        transforms = (
     86            lambda s: s.upper(),
     87            lambda s: s + 'a',
     88            lambda s: 'a' + s[1:],
     89            lambda s: s.replace(':', ''),
     90        )
     91        value = {
     92            'foo': 'bar',
     93            'baz': 1,
     94        }
     95        encoded = signing.dumps(value)
     96        self.assertEqual(value, signing.loads(encoded))
     97        for transform in transforms:
     98            self.assertRaises(
     99                signing.BadSignature, signing.loads, transform(encoded))
     100
     101class TestTimestampSigner(TestCase):
     102
     103    def test_timestamp_signer(self):
     104        value = u'hello'
     105        _time = time.time
     106        time.time = lambda: 123456789
     107        try:
     108            signer = signing.TimestampSigner('predictable-key')
     109            ts = signer.sign(value)
     110            self.assertNotEqual(ts,
     111                signing.Signer('predictable-key').sign(value))
     112
     113            self.assertEqual(signer.unsign(ts), value)
     114            time.time = lambda: 123456800
     115            self.assertEqual(signer.unsign(ts, max_age=12), value)
     116            self.assertEqual(signer.unsign(ts, max_age=11), value)
     117            self.assertRaises(
     118                signing.SignatureExpired, signer.unsign, ts, max_age=10)
     119        finally:
     120            time.time = _time
  • new file tests/regressiontests/utils/baseconv.py

    diff --git a/tests/regressiontests/utils/baseconv.py b/tests/regressiontests/utils/baseconv.py
    new file mode 100644
    index 0000000..90fe77f
    - +  
     1from unittest import TestCase
     2from django.utils.baseconv import base2, base16, base36, base62
     3
     4class TestBaseConv(TestCase):
     5
     6    def test_baseconv(self):
     7        nums = [-10 ** 10, 10 ** 10] + range(-100, 100)
     8        for convertor in [base2, base16, base36, base62]:
     9            for i in nums:
     10                self.assertEqual(
     11                    i, convertor.to_int(convertor.from_int(i))
     12                )
     13
  • tests/regressiontests/utils/tests.py

    diff --git a/tests/regressiontests/utils/tests.py b/tests/regressiontests/utils/tests.py
    index 5c4c060..2b61627 100644
    a b from timesince import *  
    1717from datastructures import *
    1818from tzinfo import *
    1919from datetime_safe import *
     20from baseconv import *
Back to Top