Michael Merickel
2017-04-30 3f14d63c009ae7f101b7aeb4525bab2dfe66fa11
restore the ``ICSRFStoragePolicy.check_csrf_token`` api
3 files modified
166 ■■■■■ changed files
pyramid/csrf.py 35 ●●●● patch | view | raw | blame | history
pyramid/interfaces.py 10 ●●●●● patch | view | raw | blame | history
pyramid/tests/test_csrf.py 121 ●●●●● patch | view | raw | blame | history
pyramid/csrf.py
@@ -47,6 +47,12 @@
        generating a new one if needed."""
        return request.session.get_csrf_token()
    def check_csrf_token(self, request, supplied_token):
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
@implementer(ICSRFStoragePolicy)
class SessionCSRFStoragePolicy(object):
@@ -81,6 +87,12 @@
        if not token:
            token = self.new_csrf_token(request)
        return token
    def check_csrf_token(self, request, supplied_token):
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
@implementer(ICSRFStoragePolicy)
@@ -133,6 +145,12 @@
            token = self.new_csrf_token(request)
        return token
    def check_csrf_token(self, request, supplied_token):
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
def get_csrf_token(request):
    """ Get the currently active CSRF token for the request passed, generating
@@ -140,6 +158,7 @@
    calls the equivalent method in the chosen CSRF protection implementation.
    .. versionadded :: 1.9
    """
    registry = request.registry
    csrf = registry.getUtility(ICSRFStoragePolicy)
@@ -152,6 +171,7 @@
    chosen CSRF protection implementation.
    .. versionadded :: 1.9
    """
    registry = request.registry
    csrf = registry.getUtility(ICSRFStoragePolicy)
@@ -171,9 +191,8 @@
    function, the string ``X-CSRF-Token`` will be used to look up the token in
    ``request.headers``.
    If the value supplied by post or by header doesn't match the value supplied
    by ``policy.get_csrf_token()`` (where ``policy`` is an implementation of
    :class:`pyramid.interfaces.ICSRFStoragePolicy`), and ``raises`` is
    If the value supplied by post or by header cannot be verified by the
    :class:`pyramid.interfaces.ICSRFStoragePolicy`, and ``raises`` is
    ``True``, this function will raise an
    :exc:`pyramid.exceptions.BadCSRFToken` exception. If the values differ
    and ``raises`` is ``False``, this function will return ``False``.  If the
@@ -191,7 +210,10 @@
       a header.
    .. versionchanged:: 1.9
       Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf`
       Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` and updated
       to use the configured :class:`pyramid.interfaces.ICSRFStoragePolicy` to
       verify the CSRF token.
    """
    supplied_token = ""
    # We first check the headers for a csrf token, as that is significantly
@@ -207,8 +229,8 @@
    if supplied_token == "" and token is not None:
        supplied_token = request.POST.get(token, "")
    expected_token = get_csrf_token(request)
    if strings_differ(bytes_(expected_token), bytes_(supplied_token)):
    policy = request.registry.getUtility(ICSRFStoragePolicy)
    if not policy.check_csrf_token(request, text_(supplied_token)):
        if raises:
            raise BadCSRFToken('check_csrf_token(): Invalid token')
        return False
@@ -239,6 +261,7 @@
    .. versionchanged:: 1.9
       Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf`
    """
    def _fail(reason):
        if raises:
pyramid/interfaces.py
@@ -1010,6 +1010,16 @@
        """
    def check_csrf_token(request, token):
        """ Determine if the supplied ``token`` is valid. Most implementations
        should simply compare the ``token`` to the current value of
        ``get_csrf_token``, but an implementation may verify the token using
        any mechanism it needs.

        ``request`` is the request being checked. ``token`` is the value the
        client supplied. :func:`pyramid.csrf.check_csrf_token` calls this
        method with the token it read from the request headers or the POST
        body, converted to text.

        Returns ``True`` if the ``token`` is valid, otherwise ``False``.
        """
class IIntrospector(Interface):
    def get(category_name, discriminator, default=None):
pyramid/tests/test_csrf.py
@@ -1,61 +1,20 @@
import unittest
from zope.interface.interfaces import ComponentLookupError
from pyramid import testing
from pyramid.config import Configurator
from pyramid.events import BeforeRender
class Test_get_csrf_token(unittest.TestCase):
    """ Tests for :func:`pyramid.csrf.get_csrf_token`."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so the configuration pushed for this test
        # does not leak into subsequent tests.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import get_csrf_token
        return get_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: the call must succeed without this test registering
        # its own CSRF storage policy.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        # DummyCSRF (defined elsewhere in this module) is expected to hand
        # out this fixed token.
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()
        csrf_token = self._callFUT(request)
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')
class Test_new_csrf_token(unittest.TestCase):
    """ Tests for :func:`pyramid.csrf.new_csrf_token`."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so the configuration pushed for this test
        # does not leak into subsequent tests.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import new_csrf_token
        return new_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: the call must succeed without this test registering
        # its own CSRF storage policy.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        # DummyCSRF (defined elsewhere in this module) is expected to hand
        # out this fixed token for new_csrf_token.
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()
        csrf_token = self._callFUT(request)
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')
class TestLegacySessionCSRFStoragePolicy(unittest.TestCase):
    class MockSession(object):
        def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'):
            self.current_token = current_token
        def new_csrf_token(self):
            return 'e5e9e30a08b34ff9842ff7d2b958c14b'
            self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b'
            return self.current_token
        def get_csrf_token(self):
            return '02821185e4c94269bdc38e6eeae0a2f8'
            return self.current_token
    def _makeOne(self):
        from pyramid.csrf import LegacySessionCSRFStoragePolicy
@@ -85,6 +44,13 @@
            policy.new_csrf_token(request),
            'e5e9e30a08b34ff9842ff7d2b958c14b'
        )
    def test_check_csrf_token(self):
        request = DummyRequest(session=self.MockSession('foo'))
        policy = self._makeOne()
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
class TestSessionCSRFStoragePolicy(unittest.TestCase):
@@ -120,6 +86,16 @@
        token = policy.new_csrf_token(request)
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))
    def test_check_csrf_token(self):
        request = DummyRequest(session={})
        policy = self._makeOne()
        self.assertFalse(policy.check_csrf_token(request, 'foo'))
        request.session = {'_csrft_': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
class TestCookieCSRFStoragePolicy(unittest.TestCase):
@@ -189,6 +165,57 @@
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))
    def test_check_csrf_token(self):
        request = DummyRequest()
        policy = self._makeOne()
        self.assertFalse(policy.check_csrf_token(request, 'foo'))
        request.cookies = {'csrf_token': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
class Test_get_csrf_token(unittest.TestCase):
    """ Tests for :func:`pyramid.csrf.get_csrf_token`."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so the configuration pushed for this test
        # does not leak into subsequent tests.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import get_csrf_token
        return get_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: the call must succeed without this test registering
        # its own CSRF storage policy.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        # DummyCSRF (defined elsewhere in this module) is expected to hand
        # out this fixed token.
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()
        csrf_token = self._callFUT(request)
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')
class Test_new_csrf_token(unittest.TestCase):
    """ Tests for :func:`pyramid.csrf.new_csrf_token`."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so the configuration pushed for this test
        # does not leak into subsequent tests.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import new_csrf_token
        return new_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: the call must succeed without this test registering
        # its own CSRF storage policy.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        # DummyCSRF (defined elsewhere in this module) is expected to hand
        # out this fixed token for new_csrf_token.
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()
        csrf_token = self._callFUT(request)
        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')
class Test_check_csrf_token(unittest.TestCase):
    def setUp(self):