Michael Merickel
2016-04-18 2b1a90fd82f804eef5b3e57091f5a7df81aa4ac9
disable CSRF checking on all exception views unless explicitly turned on
2 files modified
83 lines changed
pyramid/tests/test_viewderivers.py 59 ●●●●● patch | view | raw | blame | history
pyramid/viewderivers.py 24 ●●●●● patch | view | raw | blame | history
pyramid/tests/test_viewderivers.py
@@ -1297,6 +1297,64 @@
        result = view(None, request)
        self.assertTrue(result is response)
    def test_csrf_view_skipped_by_default_on_exception_view(self):
        from pyramid.request import Request
        def view(request):
            raise ValueError
        def excview(request):
            return 'hello'
        self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
        self.config.set_session_factory(
            lambda request: DummySession({'csrf_token': 'foo'}))
        self.config.add_view(view, name='foo', require_csrf=False)
        self.config.add_view(excview, context=ValueError, renderer='string')
        app = self.config.make_wsgi_app()
        request = Request.blank('/foo', base_url='http://example.com')
        request.method = 'POST'
        response = request.get_response(app)
        self.assertTrue(b'hello' in response.body)
    def test_csrf_view_failed_on_explicit_exception_view(self):
        from pyramid.exceptions import BadCSRFToken
        from pyramid.request import Request
        def view(request):
            raise ValueError
        def excview(request): pass
        self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
        self.config.set_session_factory(
            lambda request: DummySession({'csrf_token': 'foo'}))
        self.config.add_view(view, name='foo', require_csrf=False)
        self.config.add_view(excview, context=ValueError, renderer='string',
                             require_csrf=True)
        app = self.config.make_wsgi_app()
        request = Request.blank('/foo', base_url='http://example.com')
        request.method = 'POST'
        try:
            request.get_response(app)
        except BadCSRFToken:
            pass
        else: # pragma: no cover
            raise AssertionError
    def test_csrf_view_passed_on_explicit_exception_view(self):
        from pyramid.request import Request
        def view(request):
            raise ValueError
        def excview(request):
            return 'hello'
        self.config.add_settings({'pyramid.require_default_csrf': 'yes'})
        self.config.set_session_factory(
            lambda request: DummySession({'csrf_token': 'foo'}))
        self.config.add_view(view, name='foo', require_csrf=False)
        self.config.add_view(excview, context=ValueError, renderer='string',
                             require_csrf=True)
        app = self.config.make_wsgi_app()
        request = Request.blank('/foo', base_url='http://example.com')
        request.method = 'POST'
        request.headers['X-CSRF-Token'] = 'foo'
        response = request.get_response(app)
        self.assertTrue(b'hello' in response.body)
class TestDerivationOrder(unittest.TestCase):
    def setUp(self):
@@ -1554,7 +1612,6 @@
        from pyramid.interfaces import IRequest
        from pyramid.interfaces import IView
        from pyramid.interfaces import IViewClassifier
        from pyramid.interfaces import IExceptionViewClassifier
        classifier = IViewClassifier
        if ctx_iface is None:
            ctx_iface = Interface
pyramid/viewderivers.py
@@ -483,21 +483,29 @@
    default_val = _parse_csrf_setting(
        info.settings.get('pyramid.require_default_csrf'),
        'Config setting "pyramid.require_default_csrf"')
    val = _parse_csrf_setting(
    explicit_val = _parse_csrf_setting(
        info.options.get('require_csrf'),
        'View option "require_csrf"')
    if (val is True and default_val) or val is None:
        val = default_val
    if val is True:
        val = 'csrf_token'
    resolved_val = explicit_val
    if (explicit_val is True and default_val) or explicit_val is None:
        resolved_val = default_val
    if resolved_val is True:
        resolved_val = 'csrf_token'
    wrapped_view = view
    if val:
    if resolved_val:
        def csrf_view(context, request):
            # Assume that anything not defined as 'safe' by RFC2616 needs
            # protection
            if request.method not in SAFE_REQUEST_METHODS:
            if (
                request.method not in SAFE_REQUEST_METHODS and
                (
                    # skip exception views unless value is explicitly defined
                    getattr(request, 'exception', None) is None or
                    explicit_val is not None
                )
            ):
                check_csrf_origin(request, raises=True)
                check_csrf_token(request, val, raises=True)
                check_csrf_token(request, resolved_val, raises=True)
            return view(context, request)
        wrapped_view = csrf_view
    return wrapped_view