# Provenance (from the blame view this listing was captured from):
#   Michael Merickel
#   2018-10-15 dd3cc81f75dcb5ff96e0751653071722a15f46c2
#   blame columns: commit | author | age
import unittest

from pyramid import testing
from pyramid.config import Configurator


class TestLegacySessionCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``pyramid.csrf.LegacySessionCSRFStoragePolicy``."""

    class MockSession(object):
        """Session stand-in exposing the two CSRF methods the tests call.

        :param current_token: value returned by ``get_csrf_token`` until
            ``new_csrf_token`` replaces it.
        """

        def __init__(self, current_token='02821185e4c94269bdc38e6eeae0a2f8'):
            self.current_token = current_token

        def new_csrf_token(self):
            # Always rotates to the same fixed value so tests can assert it.
            self.current_token = 'e5e9e30a08b34ff9842ff7d2b958c14b'
            return self.current_token

        def get_csrf_token(self):
            return self.current_token

    def _makeOne(self):
        from pyramid.csrf import LegacySessionCSRFStoragePolicy
        return LegacySessionCSRFStoragePolicy()

    def test_register_session_csrf_policy(self):
        from pyramid.csrf import LegacySessionCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance reports the actual type on failure, unlike
        # assertTrue(isinstance(...)) which only says "False is not true".
        self.assertIsInstance(policy, LegacySessionCSRFStoragePolicy)

    def test_session_csrf_implementation_delegates_to_session(self):
        policy = self._makeOne()
        request = DummyRequest(session=self.MockSession())

        # Both values come from MockSession, not from the policy itself.
        self.assertEqual(
            policy.get_csrf_token(request),
            '02821185e4c94269bdc38e6eeae0a2f8'
        )
        self.assertEqual(
            policy.new_csrf_token(request),
            'e5e9e30a08b34ff9842ff7d2b958c14b'
        )

    def test_check_csrf_token(self):
        request = DummyRequest(session=self.MockSession('foo'))

        policy = self._makeOne()
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))


class TestSessionCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``pyramid.csrf.SessionCSRFStoragePolicy``.

    The session is modelled as a plain ``dict``; the tests that pre-seed a
    token store it under the ``'_csrft_'`` key.
    """

    def _makeOne(self, **kw):
        from pyramid.csrf import SessionCSRFStoragePolicy
        return SessionCSRFStoragePolicy(**kw)

    def test_register_session_csrf_policy(self):
        from pyramid.csrf import SessionCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance gives a useful failure message (shows the type).
        self.assertIsInstance(policy, SessionCSRFStoragePolicy)

    def test_it_creates_a_new_token(self):
        request = DummyRequest(session={})

        policy = self._makeOne()
        # Pin the generated token so the assertion is deterministic.
        policy._token_factory = lambda: 'foo'
        self.assertEqual(policy.get_csrf_token(request), 'foo')

    def test_get_csrf_token_returns_the_new_token(self):
        request = DummyRequest(session={'_csrft_': 'foo'})

        policy = self._makeOne()
        self.assertEqual(policy.get_csrf_token(request), 'foo')

        # After rotation the old value must be gone and the getter must
        # agree with what new_csrf_token returned.
        token = policy.new_csrf_token(request)
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))

    def test_check_csrf_token(self):
        request = DummyRequest(session={})

        policy = self._makeOne()
        # Empty session: the supplied token must not validate.
        self.assertFalse(policy.check_csrf_token(request, 'foo'))

        request.session = {'_csrft_': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))


class TestCookieCSRFStoragePolicy(unittest.TestCase):
    """Tests for ``pyramid.csrf.CookieCSRFStoragePolicy``.

    Cookie emission is observed by invoking the callback the policy
    registers through ``DummyRequest.add_response_callback`` and inspecting
    ``MockResponse.headerlist`` afterwards.
    """

    def _makeOne(self, **kw):
        from pyramid.csrf import CookieCSRFStoragePolicy
        return CookieCSRFStoragePolicy(**kw)

    def test_register_cookie_csrf_policy(self):
        from pyramid.csrf import CookieCSRFStoragePolicy
        from pyramid.interfaces import ICSRFStoragePolicy

        config = Configurator()
        config.set_csrf_storage_policy(self._makeOne())
        config.commit()

        policy = config.registry.queryUtility(ICSRFStoragePolicy)

        # assertIsInstance gives a useful failure message (shows the type).
        self.assertIsInstance(policy, CookieCSRFStoragePolicy)

    def test_get_cookie_csrf_with_no_existing_cookie_sets_cookies(self):
        response = MockResponse()
        request = DummyRequest()

        policy = self._makeOne()
        token = policy.get_csrf_token(request)
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [('Set-Cookie',
              'csrf_token={}; Path=/; SameSite=Lax'.format(token))]
        )

    def test_get_cookie_csrf_nondefault_samesite(self):
        response = MockResponse()
        request = DummyRequest()

        # samesite=None: the expected header carries no SameSite attribute.
        policy = self._makeOne(samesite=None)
        token = policy.get_csrf_token(request)
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [('Set-Cookie', 'csrf_token={}; Path=/'.format(token))]
        )

    def test_existing_cookie_csrf_does_not_set_cookie(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'}

        policy = self._makeOne()
        token = policy.get_csrf_token(request)

        self.assertEqual(
            token,
            'e6f325fee5974f3da4315a8ccf4513d2'
        )
        # No callback registered means no cookie would be written.
        self.assertIsNone(request.response_callback)

    def test_new_cookie_csrf_with_existing_cookie_sets_cookies(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'e6f325fee5974f3da4315a8ccf4513d2'}

        policy = self._makeOne()
        token = policy.new_csrf_token(request)

        response = MockResponse()
        request.response_callback(request, response)
        self.assertEqual(
            response.headerlist,
            [('Set-Cookie',
              'csrf_token={}; Path=/; SameSite=Lax'.format(token))]
        )

    def test_get_csrf_token_returns_the_new_token(self):
        request = DummyRequest()
        request.cookies = {'csrf_token': 'foo'}

        policy = self._makeOne()
        self.assertEqual(policy.get_csrf_token(request), 'foo')

        # After rotation the getter must return the new token even though
        # request.cookies still holds the old one.
        token = policy.new_csrf_token(request)
        self.assertNotEqual(token, 'foo')
        self.assertEqual(token, policy.get_csrf_token(request))

    def test_check_csrf_token(self):
        request = DummyRequest()

        policy = self._makeOne()
        # No cookie present: the supplied token must not validate.
        self.assertFalse(policy.check_csrf_token(request, 'foo'))

        request.cookies = {'csrf_token': 'foo'}
        self.assertTrue(policy.check_csrf_token(request, 'foo'))
        self.assertFalse(policy.check_csrf_token(request, 'bar'))


class Test_get_csrf_token(unittest.TestCase):
    """Tests for the module-level ``pyramid.csrf.get_csrf_token`` helper."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so this test's registry does not leak into
        # whatever test runs next.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import get_csrf_token
        return get_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: no storage policy is registered here and the call
        # simply must not raise.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()

        csrf_token = self._callFUT(request)

        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, '02821185e4c94269bdc38e6eeae0a2f8')


class Test_new_csrf_token(unittest.TestCase):
    """Tests for the module-level ``pyramid.csrf.new_csrf_token`` helper."""

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so this test's registry does not leak into
        # whatever test runs next.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import new_csrf_token
        return new_csrf_token(*args, **kwargs)

    def test_no_override_csrf_utility_registered(self):
        # Smoke test: no storage policy is registered here and the call
        # simply must not raise.
        request = testing.DummyRequest()
        self._callFUT(request)

    def test_success(self):
        self.config.set_csrf_storage_policy(DummyCSRF())
        request = testing.DummyRequest()

        csrf_token = self._callFUT(request)

        # assertEqual, not the deprecated assertEquals alias (removed in
        # Python 3.12).
        self.assertEqual(csrf_token, 'e5e9e30a08b34ff9842ff7d2b958c14b')


class Test_check_csrf_token(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_token``.

    ``setUp`` configures default CSRF options (``require_csrf=False``)
    before each test; the sibling class
    ``Test_check_csrf_token_without_defaults_configured`` covers the case
    where that call is not made.
    """

    def setUp(self):
        self.config = testing.setUp()

        # set up CSRF
        self.config.set_default_csrf_options(require_csrf=False)

    def tearDown(self):
        # Undo testing.setUp() so the configured options do not leak into
        # whatever test runs next.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_token
        return check_csrf_token(*args, **kwargs)

    def test_success_token(self):
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request, token='csrf_token'), True)

    def test_success_header(self):
        request = testing.DummyRequest()
        request.headers['X-CSRF-Token'] = request.session.get_csrf_token()
        self.assertEqual(self._callFUT(request, header='X-CSRF-Token'), True)

    def test_success_default_token(self):
        # Same as test_success_token but relying on the default token name.
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request), True)

    def test_success_default_header(self):
        # Same as test_success_header but relying on the default header name.
        request = testing.DummyRequest()
        request.headers['X-CSRF-Token'] = request.session.get_csrf_token()
        self.assertEqual(self._callFUT(request), True)

    def test_failure_raises(self):
        from pyramid.exceptions import BadCSRFToken
        # The request carries no token at all.
        request = testing.DummyRequest()
        self.assertRaises(BadCSRFToken, self._callFUT, request,
                          'csrf_token')

    def test_failure_no_raises(self):
        request = testing.DummyRequest()
        result = self._callFUT(request, 'csrf_token', raises=False)
        self.assertEqual(result, False)


class Test_check_csrf_token_without_defaults_configured(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_token`` on a bare test config.

    Unlike ``Test_check_csrf_token``, ``setUp`` here never calls
    ``set_default_csrf_options``.
    """

    def setUp(self):
        self.config = testing.setUp()

    def tearDown(self):
        # Undo testing.setUp() so this test's registry does not leak into
        # whatever test runs next.
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_token
        return check_csrf_token(*args, **kwargs)

    def test_success_token(self):
        request = testing.DummyRequest()
        request.method = "POST"
        request.POST = {'csrf_token': request.session.get_csrf_token()}
        self.assertEqual(self._callFUT(request, token='csrf_token'), True)

    def test_failure_raises(self):
        from pyramid.exceptions import BadCSRFToken
        # The request carries no token at all.
        request = testing.DummyRequest()
        self.assertRaises(BadCSRFToken, self._callFUT, request,
                          'csrf_token')

    def test_failure_no_raises(self):
        request = testing.DummyRequest()
        result = self._callFUT(request, 'csrf_token', raises=False)
        self.assertEqual(result, False)


class Test_check_csrf_origin(unittest.TestCase):
    """Tests for ``pyramid.csrf.check_csrf_origin``.

    Each test builds a ``testing.DummyRequest`` and sets scheme, host,
    port, referrer/Origin and registry settings by hand.
    """

    def setUp(self):
        # Several tests assign request.registry.settings. Push a per-test
        # registry (as the sibling test classes do) so that assignment is
        # discarded in tearDown instead of persisting across tests.
        self.config = testing.setUp()

    def tearDown(self):
        testing.tearDown()

    def _callFUT(self, *args, **kwargs):
        from pyramid.csrf import check_csrf_origin
        return check_csrf_origin(*args, **kwargs)

    def test_success_with_http(self):
        request = testing.DummyRequest()
        request.scheme = "http"
        self.assertTrue(self._callFUT(request))

    def test_success_with_https_and_referrer(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://example.com/login/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_success_with_https_and_origin(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        # Origin matches the host while the referrer does not; the check
        # is still expected to pass.
        request.headers = {"Origin": "https://example.com/"}
        request.referrer = "https://not-example.com/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_success_with_additional_trusted_host(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://not-example.com/login/"
        request.registry.settings = {
            "pyramid.csrf_trusted_origins": ["not-example.com"],
        }
        self.assertTrue(self._callFUT(request))

    def test_success_with_nonstandard_port(self):
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com:8080"
        request.host_port = "8080"
        request.referrer = "https://example.com:8080/login/"
        request.registry.settings = {}
        self.assertTrue(self._callFUT(request))

    def test_fails_with_wrong_host(self):
        from pyramid.exceptions import BadCSRFOrigin
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        request.referrer = "https://not-example.com/login/"
        request.registry.settings = {}
        # Default mode raises; raises=False turns the failure into False.
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_with_no_origin(self):
        from pyramid.exceptions import BadCSRFOrigin
        request = testing.DummyRequest()
        request.scheme = "https"
        request.referrer = None
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_when_http_to_https(self):
        from pyramid.exceptions import BadCSRFOrigin
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com"
        request.host_port = "443"
        # Same host, but the referrer is plain http.
        request.referrer = "http://example.com/evil/"
        request.registry.settings = {}
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))

    def test_fails_with_nonstandard_port(self):
        from pyramid.exceptions import BadCSRFOrigin
        request = testing.DummyRequest()
        request.scheme = "https"
        request.host = "example.com:8080"
        request.host_port = "8080"
        # Referrer omits the :8080 port the request host carries.
        request.referrer = "https://example.com/login/"
        request.registry.settings = {}
        self.assertRaises(BadCSRFOrigin, self._callFUT, request)
        self.assertFalse(self._callFUT(request, raises=False))


class DummyRequest(object):
    """Bare-bones request double used by the storage-policy tests.

    :param registry: stored as ``self.registry`` (default ``None``).
    :param session: stored as ``self.session`` (default ``None``).

    ``cookies`` starts as a fresh empty dict per instance, and
    ``response_callback`` stays ``None`` until ``add_response_callback``
    is called.
    """

    registry = None
    session = None
    response_callback = None

    def __init__(self, registry=None, session=None):
        self.registry = registry
        self.session = session
        self.cookies = {}

    def add_response_callback(self, callback):
        # Only the most recent callback is kept; tests invoke it directly
        # as ``request.response_callback(request, response)``.
        self.response_callback = callback


class MockResponse(object):
    """Response double exposing only a mutable ``headerlist``.

    Tests hand an instance to the request's response callback and then
    compare ``headerlist`` with the expected ``Set-Cookie`` entries.
    """

    def __init__(self):
        # Fresh list per instance so responses never share headers.
        self.headerlist = []


class DummyCSRF(object):
    """CSRF storage-policy double returning fixed, distinguishable tokens.

    The ``request`` argument is accepted and ignored by both methods.
    """

    def new_csrf_token(self, request):
        return 'e5e9e30a08b34ff9842ff7d2b958c14b'

    def get_csrf_token(self, request):
        return '02821185e4c94269bdc38e6eeae0a2f8'