| | |
| | | from zope.interface import implementer |
| | | |
| | | |
| | | from pyramid.compat import ( |
| | | bytes_, |
| | | urlparse, |
| | | text_, |
| | | ) |
| | | from pyramid.exceptions import ( |
| | | BadCSRFOrigin, |
| | | BadCSRFToken, |
| | | ) |
| | | from pyramid.compat import bytes_, urlparse, text_ |
| | | from pyramid.exceptions import BadCSRFOrigin, BadCSRFToken |
| | | from pyramid.interfaces import ICSRFStoragePolicy |
| | | from pyramid.settings import aslist |
| | | from pyramid.util import ( |
| | | SimpleSerializer, |
| | | is_same_domain, |
| | | strings_differ |
| | | ) |
| | | from pyramid.util import SimpleSerializer, is_same_domain, strings_differ |
| | | |
| | | |
| | | @implementer(ICSRFStoragePolicy) |
| | |
| | | .. versionadded:: 1.9 |
| | | |
| | | """ |
| | | |
| | | def new_csrf_token(self, request): |
| | | """ Sets a new CSRF token into the session and returns it. """ |
| | | return request.session.new_csrf_token() |
| | |
| | | """ Returns ``True`` if the ``supplied_token`` is valid.""" |
| | | expected_token = self.get_csrf_token(request) |
| | | return not strings_differ( |
| | | bytes_(expected_token), bytes_(supplied_token)) |
| | | bytes_(expected_token), bytes_(supplied_token) |
| | | ) |
| | | |
| | | |
| | | @implementer(ICSRFStoragePolicy) |
| | |
| | | .. versionadded:: 1.9 |
| | | |
| | | """ |
| | | |
| | | _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex)) |
| | | |
| | | def __init__(self, key='_csrft_'): |
| | |
| | | """ Returns ``True`` if the ``supplied_token`` is valid.""" |
| | | expected_token = self.get_csrf_token(request) |
| | | return not strings_differ( |
| | | bytes_(expected_token), bytes_(supplied_token)) |
| | | bytes_(expected_token), bytes_(supplied_token) |
| | | ) |
| | | |
| | | |
| | | @implementer(ICSRFStoragePolicy) |
| | |
| | | Added the ``samesite`` option and made the default ``'Lax'``. |
| | | |
| | | """ |
| | | |
| | | _token_factory = staticmethod(lambda: text_(uuid.uuid4().hex)) |
| | | |
| | | def __init__(self, cookie_name='csrf_token', secure=False, httponly=False, |
| | | domain=None, max_age=None, path='/', samesite='Lax'): |
| | | def __init__( |
| | | self, |
| | | cookie_name='csrf_token', |
| | | secure=False, |
| | | httponly=False, |
| | | domain=None, |
| | | max_age=None, |
| | | path='/', |
| | | samesite='Lax', |
| | | ): |
| | | serializer = SimpleSerializer() |
| | | self.cookie_profile = CookieProfile( |
| | | cookie_name=cookie_name, |
| | |
| | | """ Sets a new CSRF token into the request and returns it. """ |
| | | token = self._token_factory() |
| | | request.cookies[self.cookie_name] = token |
| | | |
| | | def set_cookie(request, response): |
| | | self.cookie_profile.set_cookies( |
| | | response, |
| | | token, |
| | | ) |
| | | self.cookie_profile.set_cookies(response, token) |
| | | |
| | | request.add_response_callback(set_cookie) |
| | | return token |
| | | |
| | |
| | | """ Returns ``True`` if the ``supplied_token`` is valid.""" |
| | | expected_token = self.get_csrf_token(request) |
| | | return not strings_differ( |
| | | bytes_(expected_token), bytes_(supplied_token)) |
| | | bytes_(expected_token), bytes_(supplied_token) |
| | | ) |
| | | |
| | | |
| | | def get_csrf_token(request): |
| | |
| | | return csrf.new_csrf_token(request) |
| | | |
| | | |
| | | def check_csrf_token(request, |
| | | token='csrf_token', |
| | | header='X-CSRF-Token', |
| | | raises=True): |
| | | def check_csrf_token( |
| | | request, token='csrf_token', header='X-CSRF-Token', raises=True |
| | | ): |
| | | """ Check the CSRF token returned by the |
| | | :class:`pyramid.interfaces.ICSRFStoragePolicy` implementation against the |
| | | value in ``request.POST.get(token)`` (if a POST request) or |
| | |
| | | Check the ``Origin`` of the request to see if it is a cross site request or |
| | | not. |
| | | |
| | | If the value supplied by the ``Origin`` or ``Referer`` header isn't one of the |
| | | trusted origins and ``raises`` is ``True``, this function will raise a |
| | | If the value supplied by the ``Origin`` or ``Referer`` header isn't one of |
| | | the trusted origins and ``raises`` is ``True``, this function will raise a |
| | | :exc:`pyramid.exceptions.BadCSRFOrigin` exception, but if ``raises`` is |
| | | ``False``, this function will return ``False`` instead. If the CSRF origin |
| | | checks are successful this function will return ``True`` unconditionally. |
| | |
| | | Moved from :mod:`pyramid.session` to :mod:`pyramid.csrf` |
| | | |
| | | """ |
| | | |
| | | def _fail(reason): |
| | | if raises: |
| | | raise BadCSRFOrigin(reason) |
| | |
| | | if trusted_origins is None: |
| | | trusted_origins = aslist( |
| | | request.registry.settings.get( |
| | | "pyramid.csrf_trusted_origins", []) |
| | | "pyramid.csrf_trusted_origins", [] |
| | | ) |
| | | ) |
| | | |
| | | if request.host_port not in set(["80", "443"]): |
| | |
| | | |
| | | # Actually check to see if the request's origin matches any of our |
| | | # trusted origins. |
| | | if not any(is_same_domain(originp.netloc, host) |
| | | for host in trusted_origins): |
| | | if not any( |
| | | is_same_domain(originp.netloc, host) for host in trusted_origins |
| | | ): |
| | | reason = ( |
| | | "Referer checking failed - {0} does not match any trusted " |
| | | "origins." |