pyramid/csrf.py | ●●●●● patch | view | raw | blame | history | |
pyramid/interfaces.py | ●●●●● patch | view | raw | blame | history | |
pyramid/tests/test_csrf.py | ●●●●● patch | view | raw | blame | history |
pyramid/csrf.py
@@ -47,6 +47,12 @@ generating a new one if needed.""" return request.session.get_csrf_token() def check_csrf_token(self, request, supplied_token): """ Returns ``True`` if the ``supplied_token`` is valid.""" expected_token = self.get_csrf_token(request) return not strings_differ( bytes_(expected_token), bytes_(supplied_token)) @implementer(ICSRFStoragePolicy) class SessionCSRFStoragePolicy(object): @@ -81,6 +87,12 @@ if not token: token = self.new_csrf_token(request) return token def check_csrf_token(self, request, supplied_token): """ Returns ``True`` if the ``supplied_token`` is valid.""" expected_token = self.get_csrf_token(request) return not strings_differ( bytes_(expected_token), bytes_(supplied_token)) @implementer(ICSRFStoragePolicy) @@ -133,6 +145,12 @@ token = self.new_csrf_token(request) return token def check_csrf_token(self, request, supplied_token): """ Returns ``True`` if the ``supplied_token`` is valid.""" expected_token = self.get_csrf_token(request) return not strings_differ( bytes_(expected_token), bytes_(supplied_token)) def get_csrf_token(request): """ Get the currently active CSRF token for the request passed, generating @@ -140,6 +158,7 @@ calls the equivalent method in the chosen CSRF protection implementation. .. versionadded :: 1.9 """ registry = request.registry csrf = registry.getUtility(ICSRFStoragePolicy) @@ -152,6 +171,7 @@ chosen CSRF protection implementation. .. versionadded :: 1.9 """ registry = request.registry csrf = registry.getUtility(ICSRFStoragePolicy) @@ -171,9 +191,8 @@ function, the string ``X-CSRF-Token`` will be used to look up the token in ``request.headers``. If the value supplied by post or by header doesn't match the value supplied by ``policy.get_csrf_token()`` (where ``policy`` is an implementation of :class:`pyramid.interfaces.ICSRFStoragePolicy`), and ``raises`` is If the value supplied by post or by header cannot be verified by the :class:`pyramid.interfaces.ICSRFStoragePolicy`, and ``raises`` is ``True``, this function will raise an :exc:`pyramid.exceptions.BadCSRFToken` exception. If the values differ and ``raises`` is ``False``, this function will return ``False``. If the @@ -191,7 +210,10 @@ a header. .. versionchanged:: 1.9 Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` and updated to use the configured :class:`pyramid.interfaces.ICSRFStoragePolicy` to verify the CSRF token. """ supplied_token = "" # We first check the headers for a csrf token, as that is significantly @@ -207,8 +229,8 @@ if supplied_token == "" and token is not None: supplied_token = request.POST.get(token, "") expected_token = get_csrf_token(request) if strings_differ(bytes_(expected_token), bytes_(supplied_token)): policy = request.registry.getUtility(ICSRFStoragePolicy) if not policy.check_csrf_token(request, text_(supplied_token)): if raises: raise BadCSRFToken('check_csrf_token(): Invalid token') return False @@ -239,6 +261,7 @@ .. versionchanged:: 1.9 Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` """ def _fail(reason): if raises: pyramid/interfaces.py
@@ -1010,6 +1010,16 @@ """ def check_csrf_token(request, token): """ Determine if the supplied ``token`` is valid. Most implementations should simply compare the ``token`` to the current value of ``get_csrf_token`` but it is possible to verify the token using any mechanism necessary using this method. Returns ``True`` if the ``token`` is valid, otherwise ``False``. """ class IIntrospector(Interface): def get(category_name, discriminator, default=None): pyramid/tests/test_csrf.py
@@ -1,61 +1,20 @@ import unittest from zope.interface.interfaces import ComponentLookupError from pyramid import testing from pyramid.config import Configurator from pyramid.events import BeforeRender class Test_get_csrf_token(unittest.TestCase): def setUp(self): self.config = testing.setUp() def _callFUT(self, *args, **kwargs): from pyramid.csrf import get_csrf_token return get_csrf_token(*args, **kwargs) def test_no_override_csrf_utility_registered(self): request = testing.DummyRequest() self._callFUT(request) def test_success(self): self.config.set_csrf_storage_policy(DummyCSRF()) request = testing.DummyRequest() csrf_token = self._callFUT(request) self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8') class Test_new_csrf_token(unittest.TestCase): def setUp(self): self.config = testing.setUp() def _callFUT(self, *args, **kwargs): from pyramid.csrf import new_csrf_token return new_csrf_token(*args, **kwargs) def test_no_override_csrf_utility_registered(self): request = testing.DummyRequest() self._callFUT(request) def test_success(self): self.config.set_csrf_storage_policy(DummyCSRF()) request = testing.DummyRequest() csrf_token = self._callFUT(request) self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b') class TestLegacySessionCSRFStoragePolicy(unittest.TestCase): class MockSession(object): def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'): self.current_token = current_token def new_csrf_token(self): return 'e5e9e30a08b34ff9842ff7d2b958c14b' self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b' return self.current_token def get_csrf_token(self): return '02821185e4c94269bdc38e6eeae0a2f8' return self.current_token def _makeOne(self): from pyramid.csrf import LegacySessionCSRFStoragePolicy @@ -85,6 +44,13 @@ policy.new_csrf_token(request), 'e5e9e30a08b34ff9842ff7d2b958c14b' ) def test_check_csrf_token(self): request = DummyRequest(session=self.MockSession('foo')) policy = self._makeOne() self.assertTrue(policy.check_csrf_token(request, 'foo')) self.assertFalse(policy.check_csrf_token(request, 'bar')) class TestSessionCSRFStoragePolicy(unittest.TestCase): @@ -120,6 +86,16 @@ token = policy.new_csrf_token(request) self.assertNotEqual(token, 'foo') self.assertEqual(token, policy.get_csrf_token(request)) def test_check_csrf_token(self): request = DummyRequest(session={}) policy = self._makeOne() self.assertFalse(policy.check_csrf_token(request, 'foo')) request.session = {'_csrft_': 'foo'} self.assertTrue(policy.check_csrf_token(request, 'foo')) self.assertFalse(policy.check_csrf_token(request, 'bar')) class TestCookieCSRFStoragePolicy(unittest.TestCase): @@ -189,6 +165,57 @@ self.assertNotEqual(token, 'foo') self.assertEqual(token, policy.get_csrf_token(request)) def test_check_csrf_token(self): request = DummyRequest() policy = self._makeOne() self.assertFalse(policy.check_csrf_token(request, 'foo')) request.cookies = {'csrf_token': 'foo'} self.assertTrue(policy.check_csrf_token(request, 'foo')) self.assertFalse(policy.check_csrf_token(request, 'bar')) class Test_get_csrf_token(unittest.TestCase): def setUp(self): self.config = testing.setUp() def _callFUT(self, *args, **kwargs): from pyramid.csrf import get_csrf_token return get_csrf_token(*args, **kwargs) def test_no_override_csrf_utility_registered(self): request = testing.DummyRequest() self._callFUT(request) def test_success(self): self.config.set_csrf_storage_policy(DummyCSRF()) request = testing.DummyRequest() csrf_token = self._callFUT(request) self.assertEquals(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8') class Test_new_csrf_token(unittest.TestCase): def setUp(self): self.config = testing.setUp() def _callFUT(self, *args, **kwargs): from pyramid.csrf import new_csrf_token return new_csrf_token(*args, **kwargs) def test_no_override_csrf_utility_registered(self): request = testing.DummyRequest() self._callFUT(request) def test_success(self): self.config.set_csrf_storage_policy(DummyCSRF()) request = testing.DummyRequest() csrf_token = self._callFUT(request) self.assertEquals(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b') class Test_check_csrf_token(unittest.TestCase): def setUp(self):