Michael Merickel
2018-10-15 bda1306749c62ef4f11cfe567ed7d56c8ad94240
src/pyramid/csrf.py
@@ -4,22 +4,11 @@
from zope.interface import implementer
from pyramid.compat import (
    bytes_,
    urlparse,
    text_,
)
from pyramid.exceptions import (
    BadCSRFOrigin,
    BadCSRFToken,
)
from pyramid.compat import bytes_, urlparse, text_
from pyramid.exceptions import BadCSRFOrigin, BadCSRFToken
from pyramid.interfaces import ICSRFStoragePolicy
from pyramid.settings import aslist
from pyramid.util import (
    SimpleSerializer,
    is_same_domain,
    strings_differ
)
from pyramid.util import SimpleSerializer, is_same_domain, strings_differ
@implementer(ICSRFStoragePolicy)
@@ -37,6 +26,7 @@
    .. versionadded:: 1.9
    """
    def new_csrf_token(self, request):
        """ Sets a new CSRF token into the session and returns it. """
        return request.session.new_csrf_token()
@@ -50,7 +40,8 @@
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
            bytes_(expected_token), bytes_(supplied_token)
        )
@implementer(ICSRFStoragePolicy)
@@ -68,6 +59,7 @@
    .. versionadded:: 1.9
    """
    _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex))
    def __init__(self, key='_csrft_'):
@@ -91,7 +83,8 @@
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
            bytes_(expected_token), bytes_(supplied_token)
        )
@implementer(ICSRFStoragePolicy)
@@ -111,10 +104,19 @@
       Added the ``samesite`` option and made the default ``'Lax'``.
    """
    _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex))
    def __init__(self, cookie_name='csrf_token', secure=False, httponly=False,
                 domain=None, max_age=None, path='/', samesite='Lax'):
    def __init__(
        self,
        cookie_name='csrf_token',
        secure=False,
        httponly=False,
        domain=None,
        max_age=None,
        path='/',
        samesite='Lax',
    ):
        serializer = SimpleSerializer()
        self.cookie_profile = CookieProfile(
            cookie_name=cookie_name,
@@ -132,11 +134,10 @@
        """ Sets a new CSRF token into the request and returns it. """
        token = self._token_factory()
        request.cookies[self.cookie_name] = token
        def set_cookie(request, response):
            self.cookie_profile.set_cookies(
                response,
                token,
            )
            self.cookie_profile.set_cookies(response, token)
        request.add_response_callback(set_cookie)
        return token
@@ -153,7 +154,8 @@
        """ Returns ``True`` if the ``supplied_token`` is valid."""
        expected_token = self.get_csrf_token(request)
        return not strings_differ(
            bytes_(expected_token), bytes_(supplied_token))
            bytes_(expected_token), bytes_(supplied_token)
        )
def get_csrf_token(request):
@@ -182,10 +184,9 @@
    return csrf.new_csrf_token(request)
def check_csrf_token(request,
                     token='csrf_token',
                     header='X-CSRF-Token',
                     raises=True):
def check_csrf_token(
    request, token='csrf_token', header='X-CSRF-Token', raises=True
):
    """ Check the CSRF token returned by the
    :class:`pyramid.interfaces.ICSRFStoragePolicy` implementation against the
    value in ``request.POST.get(token)`` (if a POST request) or
@@ -246,8 +247,8 @@
    Check the ``Origin`` of the request to see if it is a cross site request or
    not.
    If the value supplied by the ``Origin`` or ``Referer`` header isn't one of the
    trusted origins and ``raises`` is ``True``, this function will raise a
    If the value supplied by the ``Origin`` or ``Referer`` header isn't one of
    the trusted origins and ``raises`` is ``True``, this function will raise a
    :exc:`pyramid.exceptions.BadCSRFOrigin` exception, but if ``raises`` is
    ``False``, this function will return ``False`` instead. If the CSRF origin
    checks are successful this function will return ``True`` unconditionally.
@@ -267,6 +268,7 @@
       Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf`
    """
    def _fail(reason):
        if raises:
            raise BadCSRFOrigin(reason)
@@ -315,7 +317,8 @@
        if trusted_origins is None:
            trusted_origins = aslist(
                request.registry.settings.get(
                    "pyramid.csrf_trusted_origins", [])
                    "pyramid.csrf_trusted_origins", []
                )
            )
        if request.host_port not in set(["80", "443"]):
@@ -325,8 +328,9 @@
        # Actually check to see if the request's origin matches any of our
        # trusted origins.
        if not any(is_same_domain(originp.netloc, host)
                   for host in trusted_origins):
        if not any(
            is_same_domain(originp.netloc, host) for host in trusted_origins
        ):
            reason = (
                "Referer checking failed - {0} does not match any trusted "
                "origins."