From 8226534e173df938c533ebab6db8cd08a60901b9 Mon Sep 17 00:00:00 2001
From: Michael Merickel <michael@merickel.org>
Date: Sun, 18 Jun 2017 06:53:01 +0200
Subject: [PATCH] add a router.request_context context manager

---
 pyramid/threadlocal.py       |   27 ++++++
 pyramid/interfaces.py        |   50 ++++++++++--
 pyramid/router.py            |   81 ++++++++++----------
 pyramid/tests/test_router.py |   69 ++++++++++------
 4 files changed, 148 insertions(+), 79 deletions(-)

diff --git a/pyramid/interfaces.py b/pyramid/interfaces.py
index c6fbe3a..e9cc007 100644
--- a/pyramid/interfaces.py
+++ b/pyramid/interfaces.py
@@ -679,18 +679,41 @@
         """
 
 class IRouter(Interface):
-    """ WSGI application which routes requests to 'view' code based on
-    a view registry."""
+    """
+    WSGI application which routes requests to 'view' code based on
+    a view registry.
+
+    """
     registry = Attribute(
         """Component architecture registry local to this application.""")
 
-    def make_request(environ):
+    def request_context(environ):
         """
-        Create a new request object.
+        Create a new request context from a WSGI environ.
 
-        This method initializes a new :class:`pyramid.interfaces.IRequest`
-        object using the application's
-        :class:`pyramid.interfaces.IRequestFactory`.
+        The request context is used to push/pop the threadlocals required
+        when processing the request. It also contains an initialized
+        :class:`pyramid.interfaces.IRequest` instance using the registered
+        :class:`pyramid.interfaces.IRequestFactory`. The context may be
+        used as a context manager to control the threadlocal lifecycle:
+
+        .. code-block:: python
+
+            with router.request_context(environ) as request:
+                ...
+
+        Alternatively, the context may be used without the ``with`` statement
+        by manually invoking its ``begin()`` and ``end()`` methods.
+
+        .. code-block:: python
+
+            ctx = router.request_context(environ)
+            request = ctx.begin()
+            try:
+                ...
+            finally:
+                ctx.end()
+
         """
 
     def invoke_request(request):
@@ -698,6 +721,10 @@
         Invoke the :app:`Pyramid` request pipeline.
 
         See :ref:`router_chapter` for information on the request pipeline.
+
+        The output should be a :class:`pyramid.interfaces.IResponse` object
+        or a raised exception.
+
         """
 
 class IExecutionPolicy(Interface):
@@ -716,13 +743,16 @@
         object or an exception that will be handled by WSGI middleware.
 
         The default execution policy simply creates a request and sends it
-        through the pipeline:
+        through the pipeline, attempting to render any exception that escapes:
 
         .. code-block:: python
 
             def simple_execution_policy(environ, router):
-                request = router.make_request(environ)
-                return router.invoke_request(request)
+                with router.request_context(environ) as request:
+                    try:
+                        return router.invoke_request(request)
+                    except Exception:
+                        return request.invoke_exception_view(reraise=True)
         """
 
 class ISettings(IDict):
diff --git a/pyramid/router.py b/pyramid/router.py
index a02ff17..49b7b60 100644
--- a/pyramid/router.py
+++ b/pyramid/router.py
@@ -1,4 +1,3 @@
-import sys
 from zope.interface import (
     implementer,
     providedBy,
@@ -25,12 +24,11 @@
     BeforeTraversal,
     )
 
-from pyramid.compat import reraise
 from pyramid.httpexceptions import HTTPNotFound
 from pyramid.request import Request
 from pyramid.view import _call_view
 from pyramid.request import apply_request_extensions
-from pyramid.threadlocal import manager
+from pyramid.threadlocal import RequestContext
 
 from pyramid.traversal import (
     DefaultRootFactory,
@@ -42,8 +40,6 @@
 
     debug_notfound = False
     debug_routematch = False
-
-    threadlocal_manager = manager
 
     def __init__(self, registry):
         q = registry.queryUtility
@@ -195,16 +191,35 @@
         extensions = self.request_extensions
         if extensions is not None:
             apply_request_extensions(request, extensions=extensions)
-        return self.invoke_request(request, _use_tweens=use_tweens)
+        with RequestContext(request):
+            return self.invoke_request(request, _use_tweens=use_tweens)
 
-    def make_request(self, environ):
+    def request_context(self, environ):
         """
-        Configure a request object for use by the router.
+        Create a new request context from a WSGI environ.
 
-        The request is created using the configured
-        :class:`pyramid.interfaces.IRequestFactory` and will have any
-        configured request methods / properties added that were set by
-        :meth:`pyramid.config.Configurator.add_request_method`.
+        The request context is used to push/pop the threadlocals required
+        when processing the request. It also contains an initialized
+        :class:`pyramid.interfaces.IRequest` instance using the registered
+        :class:`pyramid.interfaces.IRequestFactory`. The context may be
+        used as a context manager to control the threadlocal lifecycle:
+
+        .. code-block:: python
+
+            with router.request_context(environ) as request:
+                ...
+
+        Alternatively, the context may be used without the ``with`` statement
+        by manually invoking its ``begin()`` and ``end()`` methods.
+
+        .. code-block:: python
+
+            ctx = router.request_context(environ)
+            request = ctx.begin()
+            try:
+                ...
+            finally:
+                ctx.end()
 
         """
         request = self.request_factory(environ)
@@ -213,7 +228,7 @@
         extensions = self.request_extensions
         if extensions is not None:
             apply_request_extensions(request, extensions=extensions)
-        return request
+        return RequestContext(request)
 
     def invoke_request(self, request, _use_tweens=True):
         """
@@ -222,11 +237,8 @@
 
         """
         registry = self.registry
-        has_listeners = self.registry.has_listeners
-        notify = self.registry.notify
-        threadlocals = {'registry': registry, 'request': request}
-        manager = self.threadlocal_manager
-        manager.push(threadlocals)
+        has_listeners = registry.has_listeners
+        notify = registry.notify
 
         if _use_tweens:
             handle_request = self.handle_request
@@ -234,23 +246,18 @@
             handle_request = self.orig_handle_request
 
         try:
+            response = handle_request(request)
 
-            try:
-                response = handle_request(request)
+            if request.response_callbacks:
+                request._process_response_callbacks(response)
 
-                if request.response_callbacks:
-                    request._process_response_callbacks(response)
+            has_listeners and notify(NewResponse(request, response))
 
-                has_listeners and notify(NewResponse(request, response))
-
-                return response
-
-            finally:
-                if request.finished_callbacks:
-                    request._process_finished_callbacks()
+            return response
 
         finally:
-            manager.pop()
+            if request.finished_callbacks:
+                request._process_finished_callbacks()
 
     def __call__(self, environ, start_response):
         """
@@ -264,14 +271,8 @@
         return response(environ, start_response)
 
 def default_execution_policy(environ, router):
-    request = router.make_request(environ)
-    try:
-        return router.invoke_request(request)
-    except Exception:
-        exc_info = sys.exc_info()
+    with router.request_context(environ) as request:
         try:
-            return request.invoke_exception_view(exc_info)
-        except HTTPNotFound:
-            reraise(*exc_info)
-        finally:
-            del exc_info  # avoid local ref cycle
+            return router.invoke_request(request)
+        except Exception:
+            return request.invoke_exception_view(reraise=True)
diff --git a/pyramid/tests/test_router.py b/pyramid/tests/test_router.py
index bd02382..6097018 100644
--- a/pyramid/tests/test_router.py
+++ b/pyramid/tests/test_router.py
@@ -641,22 +641,6 @@
         result = router(environ, start_response)
         self.assertEqual(result, exception_response.app_iter)
 
-    def test_call_pushes_and_pops_threadlocal_manager(self):
-        from pyramid.interfaces import IViewClassifier
-        context = DummyContext()
-        self._registerTraverserFactory(context)
-        response = DummyResponse()
-        response.app_iter = ['Hello world']
-        view = DummyView(response)
-        environ = self._makeEnviron()
-        self._registerView(view, '', IViewClassifier, None, None)
-        router = self._makeOne()
-        start_response = DummyStartResponse()
-        router.threadlocal_manager = DummyThreadLocalManager()
-        router(environ, start_response)
-        self.assertEqual(len(router.threadlocal_manager.pushed), 1)
-        self.assertEqual(len(router.threadlocal_manager.popped), 1)
-
     def test_call_route_matches_and_has_factory(self):
         from pyramid.interfaces import IViewClassifier
         logger = self._registerLogger()
@@ -1311,6 +1295,48 @@
         result = router(environ, start_response)
         self.assertEqual(result, ["Hello, world"])
 
+    def test_request_context_with_statement(self):
+        from pyramid.threadlocal import get_current_request
+        from pyramid.interfaces import IExecutionPolicy
+        from pyramid.request import Request
+        from pyramid.response import Response
+        registry = self.config.registry
+        result = []
+        def dummy_policy(environ, router):
+            with router.request_context(environ):
+                result.append(get_current_request())
+            result.append(get_current_request())
+            return Response(status=200, body=b'foo')
+        registry.registerUtility(dummy_policy, IExecutionPolicy)
+        router = self._makeOne()
+        resp = Request.blank('/test_path').get_response(router)
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.body, b'foo')
+        self.assertEqual(result[0].path_info, '/test_path')
+        self.assertEqual(result[1], None)
+
+    def test_request_context_manually(self):
+        from pyramid.threadlocal import get_current_request
+        from pyramid.interfaces import IExecutionPolicy
+        from pyramid.request import Request
+        from pyramid.response import Response
+        registry = self.config.registry
+        result = []
+        def dummy_policy(environ, router):
+            ctx = router.request_context(environ)
+            ctx.begin()
+            result.append(get_current_request())
+            ctx.end()
+            result.append(get_current_request())
+            return Response(status=200, body=b'foo')
+        registry.registerUtility(dummy_policy, IExecutionPolicy)
+        router = self._makeOne()
+        resp = Request.blank('/test_path').get_response(router)
+        self.assertEqual(resp.status_code, 200)
+        self.assertEqual(resp.body, b'foo')
+        self.assertEqual(result[0].path_info, '/test_path')
+        self.assertEqual(result[1], None)
+
 class DummyPredicate(object):
     def __call__(self, info, request):
         return True
@@ -1361,17 +1387,6 @@
         self.environ = environ
         start_response(self.status, self.headerlist)
         return self.app_iter
-    
-class DummyThreadLocalManager:
-    def __init__(self):
-        self.pushed = []
-        self.popped = []
-
-    def push(self, val):
-        self.pushed.append(val)
-
-    def pop(self):
-        self.popped.append(True)
     
 class DummyAuthenticationPolicy:
     pass
diff --git a/pyramid/threadlocal.py b/pyramid/threadlocal.py
index 9429fe9..e8f8257 100644
--- a/pyramid/threadlocal.py
+++ b/pyramid/threadlocal.py
@@ -36,7 +36,8 @@
 manager = ThreadLocalManager(default=defaults)
 
 def get_current_request():
-    """Return the currently active request or ``None`` if no request
+    """
+    Return the currently active request or ``None`` if no request
     is currently active.
 
     This function should be used *extremely sparingly*, usually only
@@ -44,11 +45,13 @@
     ``get_current_request`` outside a testing context because its
     usage makes it possible to write code that can be neither easily
     tested nor scripted.
+
     """
     return manager.get()['request']
 
 def get_current_registry(context=None): # context required by getSiteManager API
-    """Return the currently active :term:`application registry` or the
+    """
+    Return the currently active :term:`application registry` or the
     global application registry if no request is currently active.
 
     This function should be used *extremely sparingly*, usually only
@@ -56,5 +59,25 @@
     ``get_current_registry`` outside a testing context because its
     usage makes it possible to write code that can be neither easily
     tested nor scripted.
+
     """
     return manager.get()['registry']
+
+class RequestContext(object):
+    def __init__(self, request):
+        self.request = request
+
+    def begin(self):
+        request = self.request
+        registry = request.registry
+        manager.push({'registry': registry, 'request': request})
+        return request
+
+    def end(self):
+        manager.pop()
+
+    def __enter__(self):
+        return self.begin()
+
+    def __exit__(self, *args):
+        self.end()

--
Gitblit v1.9.3