import unittest
|
|
from pyramid import testing
|
from pyramid.config import Configurator
|
|
|
class TestLegacySessionCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``LegacySessionCSRFStoragePolicy``.

    The policy is expected to delegate token creation, lookup and checking
    to the ``get_csrf_token`` / ``new_csrf_token`` methods of
    ``request.session``.
    """

    class MockSession(object):
        """Session stand-in exposing only the two CSRF token methods."""

        def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'):
            self.current_token = current_token

        def new_csrf_token(self):
            # Always rotates to the same fixed value so tests can assert it.
            self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b'
            return self.current_token

        def get_csrf_token(self):
            return self.current_token

    def _makeOne(self):
        from pyramid.csrf import LegacySessionCSRFStoragePolicy

        return LegacySessionCSRFStoragePolicy()

    def test_register_session_csrf_policy(self):
        from pyramid.csrf import LegacySessionCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)).
        self.assertIsInstance(policy, LegacySessionCSRFStoragePolicy)

    def test_session_csrf_implementation_delegates_to_session(self):
        policy = self._makeOne()
        request = DummyRequest(session=self.MockSession())

        self.assertEqual(
            policy.get_csrf_token(request), '02821185e4c94269bdc38e6eeae0a2f8'
        )
        self.assertEqual(
            policy.new_csrf_token(request), 'e5e9e30a08b34ff9842ff7d2b958c14b'
        )
        # After rotation, lookups must see the session's new token.
        self.assertEqual(
            policy.get_csrf_token(request), 'e5e9e30a08b34ff9842ff7d2b958c14b'
        )

    def test_check_csrf_token(self):
        request = DummyRequest(session=self.MockSession('foo'))

        policy = self._makeOne()
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
|
|
|
class TestSessionCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``SessionCSRFStoragePolicy``.

    The policy stores its token in the session under the ``'_csrft_'`` key.
    Plain dicts stand in for the session.
    """

    def _makeOne(self, **kw):
        from pyramid.csrf import SessionCSRFStoragePolicy

        return SessionCSRFStoragePolicy(**kw)

    def test_register_session_csrf_policy(self):
        from pyramid.csrf import SessionCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(policy, SessionCSRFStoragePolicy)

    def test_it_creates_a_new_token(self):
        request = DummyRequest(session={})

        policy = self._makeOne()
        # Replace the random token source so the generated value is known.
        policy._token_factory = lambda: 'foo'
        self.assertEqual(policy.get_csrf_token(request), 'foo')
        # The freshly generated token must be persisted in the session.
        self.assertEqual(request.session.get('_csrft_'), 'foo')

    def test_get_csrf_token_returns_the_new_token(self):
        request = DummyRequest(session={'_csrft_': 'foo'})

        policy = self._makeOne()
        self.assertEqual(policy.get_csrf_token(request), 'foo')

        token = policy.new_csrf_token(request)
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))

    def test_check_csrf_token(self):
        request = DummyRequest(session={})

        policy = self._makeOne()
        self.assertFalse(policy.check_csrf_token(request, 'foo'))

        request.session = {'_csrft_': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
|
|
|
class TestCookieCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``CookieCSRFStoragePolicy``.

    The token is read from ``request.cookies['csrf_token']``. Writes are
    checked through the ``Set-Cookie`` header that the registered response
    callback appends to a ``MockResponse``.
    """

    def _makeOne(self, **kw):
        from pyramid.csrf import CookieCSRFStoragePolicy

        return CookieCSRFStoragePolicy(**kw)

    def test_register_cookie_csrf_policy(self):
        from pyramid.csrf import CookieCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance reports the actual type on failure.
        self.assertIsInstance(policy, CookieCSRFStoragePolicy)

    def test_get_cookie_csrf_with_no_existing_cookie_sets_cookies(self):
        response = MockResponse()
        request = DummyRequest()

        policy = self._makeOne()
        token = policy.get_csrf_token(request)
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [
                (
                    'Set-Cookie',
                    'csrf_token={}; Path=/; SameSite=Lax'.format(token),
                )
            ],
        )

    def test_get_cookie_csrf_nondefault_samesite(self):
        response = MockResponse()
        request = DummyRequest()

        policy = self._makeOne(samesite=None)
        token = policy.get_csrf_token(request)
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [('Set-Cookie', 'csrf_token={}; Path=/'.format(token))],
        )

    def test_existing_cookie_csrf_does_not_set_cookie(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'}

        policy = self._makeOne()
        token = policy.get_csrf_token(request)

        self.assertEqual(token, 'e6f325fee5974f3da4315a8ccf4513d2')
        # Reading an existing cookie must not schedule a Set-Cookie.
        self.assertIsNone(request.response_callback)

    def test_new_cookie_csrf_with_existing_cookie_sets_cookies(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'}

        policy = self._makeOne()
        token = policy.new_csrf_token(request)

        response = MockResponse()
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [
                (
                    'Set-Cookie',
                    'csrf_token={}; Path=/; SameSite=Lax'.format(token),
                )
            ],
        )

    def test_get_csrf_token_returns_the_new_token(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'foo'}

        policy = self._makeOne()
        self.assertEqual(policy.get_csrf_token(request), 'foo')

        token = policy.new_csrf_token(request)
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))

    def test_check_csrf_token(self):
        request = DummyRequest()

        policy = self._makeOne()
        self.assertFalse(policy.check_csrf_token(request, 'foo'))

        request.cookies = {'csrf_token': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))
|
|
|
class Test_get_csrf_token(unittest.TestCase):
    """Tests for the module-level ``pyramid.csrf.get_csrf_token`` helper."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Pop the registry pushed by testing.setUp() so it does not leak
        # into tests that run afterwards.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import get_csrf_token

        return get_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # With no custom storage policy configured, the default policy is
        # expected to hand back the dummy session's token.
        request = testing.DummyRequest()
        token = self._callFUT(request)
        self.assertEqual(token, request.session.get_csrf_token())

    def test_success(self):
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()

        csrf_token = self._callFUT(request)

        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')
|
|
|
class Test_new_csrf_token(unittest.TestCase):
    """Tests for the module-level ``pyramid.csrf.new_csrf_token`` helper."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Pop the registry pushed by testing.setUp() so it does not leak
        # into tests that run afterwards.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import new_csrf_token

        return new_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # With no custom storage policy configured, the freshly minted token
        # is expected to be the one the dummy session now reports.
        request = testing.DummyRequest()
        token = self._callFUT(request)
        self.assertEqual(token, request.session.get_csrf_token())

    def test_success(self):
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()

        csrf_token = self._callFUT(request)

        # assertEquals is a deprecated alias removed in Python 3.12.
        self.assertEqual(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')
|
|
|
class Test_check_csrf_token(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_token`` with explicit CSRF
    defaults configured.

    The token may come from a POST field or from a request header.
    """

    def setUp(self):
        self.config = testing.setUp()

        # Configure CSRF defaults explicitly (automatic view-level checking
        # disabled) so check_csrf_token runs against an explicit setup.
        self.config.set_default_csrf_options(require_csrf=False)

    def tearDown(self):
        # Pop the registry pushed by testing.setUp() so it does not leak
        # into tests that run afterwards.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_token

        return check_csrf_token(*args, **kwargs)

    def test_success_token(self):
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request, token='csrf_token'), True)

    def test_success_header(self):
        request = testing.DummyRequest()
        request.headers['X-CSRF-Token'] = request.session.get_csrf_token()
        self.assertEqual(self._callFUT(request, header='X-CSRF-Token'), True)

    def test_success_default_token(self):
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request), True)

    def test_success_default_header(self):
        request = testing.DummyRequest()
        request.headers['X-CSRF-Token'] = request.session.get_csrf_token()
        self.assertEqual(self._callFUT(request), True)

    def test_failure_raises(self):
        from pyramid.exceptions import BadCSRFToken

        request = testing.DummyRequest()
        self.assertRaises(BadCSRFToken, self._callFUT, request, 'csrf_token')

    def test_failure_no_raises(self):
        request = testing.DummyRequest()
        result = self._callFUT(request, 'csrf_token', raises=False)
        self.assertEqual(result, False)
|
|
|
class Test_check_csrf_token_without_defaults_configured(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_token`` when no
    ``set_default_csrf_options`` call has been made."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Pop the registry pushed by testing.setUp() so it does not leak
        # into tests that run afterwards.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_token

        return check_csrf_token(*args, **kwargs)

    def test_success_token(self):
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request, token='csrf_token'), True)

    def test_failure_raises(self):
        from pyramid.exceptions import BadCSRFToken

        request = testing.DummyRequest()
        self.assertRaises(BadCSRFToken, self._callFUT, request, 'csrf_token')

    def test_failure_no_raises(self):
        request = testing.DummyRequest()
        result = self._callFUT(request, 'csrf_token', raises=False)
        self.assertEqual(result, False)
|
|
|
class Test_check_csrf_origin(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_origin``.

    Plain-HTTP requests are accepted outright. HTTPS requests need an
    ``Origin`` header or ``Referer`` matching the request host, or a host
    listed in ``pyramid.csrf_trusted_origins``.
    """

    def setUp(self):
        # Without a pushed test registry, testing.DummyRequest().registry
        # falls back to the process-global registry. Assigning
        # ``request.registry.settings`` below would then mutate shared state
        # (e.g. leak trusted origins into other tests).
        testing.setUp()

    def tearDown(self):
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_origin

        return check_csrf_origin(*args, **kwargs)

    def test_success_with_http(self):
        request = testing.DummyRequest()
        request.scheme = "http"
        self.assertTrue(self._callFUT(request))

    def test_success_with_https_and_referrer(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://example.com/login/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_success_with_https_and_origin(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.headers = {"Origin": "https://example.com/"}
        # The Origin header should take precedence over a mismatched Referer.
        request.referrer = "https://not-example.com/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_success_with_additional_trusted_host(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://not-example.com/login/"
        request.registry.settings = {
            "pyramid.csrf_trusted_origins": ["not-example.com"]
        }
        self.assertTrue(self._callFUT(request))

    def test_success_with_nonstandard_port(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com:8080"
        request.host_port = "8080"
        request.referrer = "https://example.com:8080/login/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_fails_with_wrong_host(self):
        from pyramid.exceptions import BadCSRFOrigin

        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://not-example.com/login/"
        request.registry.settings = {}
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_with_no_origin(self):
        from pyramid.exceptions import BadCSRFOrigin

        request = testing.DummyRequest()
        request.scheme = "https"
        request.referrer = None
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_when_http_to_https(self):
        from pyramid.exceptions import BadCSRFOrigin

        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "http://example.com/evil/"
        request.registry.settings = {}
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_with_nonstandard_port(self):
        from pyramid.exceptions import BadCSRFOrigin

        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com:8080"
        request.host_port = "8080"
        request.referrer = "https://example.com/login/"
        request.registry.settings = {}
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))
|
|
|
class DummyRequest(object):
    """Minimal request double used by the storage-policy tests above.

    It carries only ``registry``, ``session`` and ``cookies``, plus the most
    recently registered response callback.
    """

    # Class-level defaults. ``response_callback`` stays ``None`` until
    # ``add_response_callback`` is called on an instance.
    registry = None
    session = None
    response_callback = None

    def __init__(self, registry=None, session=None):
        self.cookies = {}
        self.session = session
        self.registry = registry

    def add_response_callback(self, callback):
        """Store *callback*, replacing any previously registered one."""
        setattr(self, 'response_callback', callback)
|
|
|
class MockResponse(object):
    """Response double that exposes only a mutable ``headerlist``."""

    def __init__(self):
        # Starts empty; tests inspect it after a response callback has run.
        self.headerlist = list()
|
|
|
class DummyCSRF(object):
    """Fake CSRF storage policy that ignores the request and returns
    fixed tokens."""

    _CURRENT_TOKEN = '02821185e4c94269bdc38e6eeae0a2f8'
    _ROTATED_TOKEN = 'e5e9e30a08b34ff9842ff7d2b958c14b'

    def get_csrf_token(self, request):
        """Return the fixed "current" token."""
        return self._CURRENT_TOKEN

    def new_csrf_token(self, request):
        """Return the fixed "rotated" token; nothing is stored."""
        return self._ROTATED_TOKEN
|