Tres Seaver
2016-05-31 455778d138ea623d224c9206e5001fd2a1fd7e1c
repoze/who/tests/test_middleware.py
@@ -1,5 +1,6 @@
import unittest
class TestMiddleware(unittest.TestCase):
    def _getTargetClass(self):
@@ -223,6 +224,29 @@
        self.assertEqual(start_response.status, '200 OK')
        self.assertEqual(start_response.headers, headers)
    def test_call_200_no_challengers_app_calls_forget(self):
        # See https://github.com/repoze/repoze.who/issues/21
        environ = self._makeEnviron()
        remember_headers = [('remember', '1')]
        forget_headers = [('forget', '1')]
        app = DummyLogoutApp('200 OK')
        credentials = {'login':'chris', 'password':'password'}
        identifier = DummyIdentifier(
            credentials,
            remember_headers=remember_headers,
            forget_headers=forget_headers)
        identifiers = [ ('identifier', identifier) ]
        authenticator = DummyAuthenticator()
        authenticators = [ ('authenticator', authenticator) ]
        mw = self._makeOne(
            app=app, identifiers=identifiers, authenticators=authenticators)
        start_response = DummyStartResponse()
        result = mw(environ, start_response)
        self.assertEqual(mw.app.environ, environ)
        self.assertEqual(result, ['body'])
        self.assertEqual(start_response.status, '200 OK')
        self.assertEqual(start_response.headers, forget_headers)
    def test_call_401_no_identifiers(self):
        from webob.exc import HTTPUnauthorized
        environ = self._makeEnviron()
@@ -233,9 +257,9 @@
        challengers = [ ('challenge', challenge) ]
        mw = self._makeOne(app=app, challengers=challengers)
        start_response = DummyStartResponse()
        result = mw(environ, start_response)
        result = b''.join(mw(environ, start_response)).decode('ascii')
        self.assertEqual(environ['challenged'], challenge_app)
        self.failUnless(result[0].startswith('401 Unauthorized'))
        self.assertTrue(result.startswith('401 Unauthorized'))
    def test_call_401_challenger_and_identifier_no_authenticator(self):
        from webob.exc import HTTPUnauthorized
@@ -252,9 +276,9 @@
                           identifiers=identifiers)
        start_response = DummyStartResponse()
        result = mw(environ, start_response)
        result = b''.join(mw(environ, start_response)).decode('ascii')
        self.assertEqual(environ['challenged'], challenge_app)
        self.failUnless(result[0].startswith('401 Unauthorized'))
        self.assertTrue(result.startswith('401 Unauthorized'))
        self.assertEqual(identifier.forgotten, False)
        self.assertEqual(environ.get('REMOTE_USER'), None)
@@ -275,9 +299,9 @@
                           identifiers=identifiers,
                           authenticators=authenticators)
        start_response = DummyStartResponse()
        result = mw(environ, start_response)
        result = b''.join(mw(environ, start_response)).decode('ascii')
        self.assertEqual(environ['challenged'], challenge_app)
        self.failUnless(result[0].startswith('401 Unauthorized'))
        self.assertTrue(result.startswith('401 Unauthorized'))
        # XXX: the assertion below is disabled and still needs to be repaired.
##         self.assertEqual(identifier.forgotten, identifier.credentials)
        self.assertEqual(environ['REMOTE_USER'], 'chris')
@@ -385,8 +409,8 @@
                           authenticators=authenticators,
                           mdproviders=mdproviders)
        start_response = DummyStartResponse()
        result = ''.join(mw(environ, start_response))
        self.failUnless(result.startswith('302 Found'))
        result = b''.join(mw(environ, start_response)).decode('ascii')
        self.assertTrue(result.startswith('302 Found'))
        self.assertEqual(start_response.status, '302 Found')
        headers = start_response.headers
        #self.assertEqual(len(headers), 3, headers)
@@ -397,7 +421,7 @@
        self.assertEqual(headers[3],
                         ('a', '1'))
        self.assertEqual(start_response.exc_info, None)
        self.failIf('repoze.who.application' in environ)
        self.assertFalse('repoze.who.application' in environ)
    def test_call_app_doesnt_call_start_response(self):
        from webob.exc import HTTPUnauthorized
@@ -443,9 +467,9 @@
                           authenticators=authenticators,
                           mdproviders=mdproviders)
        start_response = DummyStartResponse()
        result = mw(environ, start_response)
        self.failUnless(result[0].startswith('401 Unauthorized'))
        self.failUnless(app._iterable._closed)
        result = b''.join(mw(environ, start_response)).decode('ascii')
        self.assertTrue(result.startswith('401 Unauthorized'))
        self.assertTrue(app._iterable._closed)
    def test_call_w_challenge_but_no_challenger_still_closes_iterable(self):
        environ = self._makeEnviron()
@@ -465,7 +489,7 @@
                           mdproviders=mdproviders)
        start_response = DummyStartResponse()
        self.assertRaises(RuntimeError, mw, environ, start_response)
        self.failUnless(app._iterable._closed)
        self.assertTrue(app._iterable._closed)
    # XXX need more call tests:
    #  - auth_id sorting
@@ -484,7 +508,7 @@
        wrapper = self._makeOne(None)
        self.assertEqual(wrapper.start_response, None)
        self.assertEqual(wrapper.headers, [])
        self.failUnless(wrapper.buffer)
        self.assertTrue(wrapper.buffer)
    def test_finish_response(self):
        from repoze.who._compat import StringIO
@@ -531,29 +555,38 @@
        self.assertEqual(L, ['yo!'])
        self.assertEqual(list(newgen), ['a', 'b'])
    def test_w_empty_generator(self):
        def gen():
            if False:
                yield 'a'  # pragma: no cover
        newgen = self._callFUT(gen())
        self.assertEqual(list(newgen), [])
    def test_w_iterator_having_close(self):
        def gen():
            yield 'a'
            yield 'b'
        iterable = DummyIterableWithClose(gen())
        newgen = self._callFUT(iterable)
        self.failIf(iterable._closed)
        self.assertFalse(iterable._closed)
        self.assertEqual(list(newgen), ['a', 'b'])
        self.failUnless(iterable._closed)
        self.assertTrue(iterable._closed)
class TestMakeTestMiddleware(unittest.TestCase):
    def setUp(self):
        import os
        self._old_WHO_LOG = os.environ.get('WHO_LOG')
        try:
            del os.environ['WHO_LOG']
        except KeyError:
            pass
    def tearDown(self):
        import os
        if self._old_WHO_LOG is not None:
            os.environ['WHO_LOG'] = self._old_WHO_LOG
        else:
            if 'WHO_LOG' in os.environ:
                del os.environ['WHO_LOG']
        try:
            del os.environ['WHO_LOG']
        except KeyError:
            pass
    def _getFactory(self):
        from repoze.who.middleware import make_test_middleware
@@ -581,13 +614,13 @@
        middleware = factory(app, global_conf)
        self.assertEqual(middleware.logger.getEffectiveLevel(), logging.DEBUG)
class DummyApp(object):
    """Minimal WSGI application stub.

    Records the environ it was called with and returns an empty body;
    it never calls ``start_response``.
    """
    # Stays None until the app has been called once.
    environ = None

    def __call__(self, environ, start_response):
        # Keep the environ so tests can inspect what was passed downstream.
        self.environ = environ
        return []
class DummyWorkingApp:
class DummyWorkingApp(object):
    def __init__(self, status, headers):
        self.status = status
        self.headers = headers
@@ -597,7 +630,18 @@
        start_response(self.status, self.headers)
        return ['body']
class DummyGeneratorApp:
class DummyLogoutApp(object):
    """WSGI application stub that calls ``logout()`` on the API object
    found under ``environ['repoze.who.api']`` and responds with the
    headers that call returns."""

    def __init__(self, status):
        self.status = status

    def __call__(self, environ, start_response):
        # Record the environ first so it is available even if the lookup fails.
        self.environ = environ
        logout_headers = environ['repoze.who.api'].logout()
        start_response(self.status, logout_headers)
        return ['body']
class DummyGeneratorApp(object):
    def __init__(self, status, headers):
        self.status = status
        self.headers = headers
@@ -609,7 +653,7 @@
            yield 'body'
        return gen()
class DummyIterableWithClose:
class DummyIterableWithClose(object):
    _closed = False
    def __init__(self, iterable):
        self._iterable = iterable
@@ -618,7 +662,7 @@
    def close(self):
        self._closed = True
class DummyIterableWithCloseApp:
class DummyIterableWithCloseApp(object):
    def __init__(self, status, headers):
        self.status = status
        self.headers = headers
@@ -629,7 +673,7 @@
        start_response(self.status, self.headers)
        return self._iterable
class DummyIdentityResetApp:
class DummyIdentityResetApp(object):
    def __init__(self, status, headers, new_identity):
        self.status = status
        self.headers = headers
@@ -642,7 +686,7 @@
        start_response(self.status, self.headers)
        return ['body']
class DummyChallenger:
class DummyChallenger(object):
    def __init__(self, app=None):
        self.app = app
@@ -650,7 +694,7 @@
        environ['challenged'] = self.app
        return self.app
class DummyIdentifier:
class DummyIdentifier(object):
    forgotten = False
    remembered = False
@@ -674,51 +718,29 @@
        self.remembered = identity
        return self.remember_headers
class DummyAuthenticator(object):
    """Authenticator stub.

    Returns the ``userid`` given at construction time, or, when none was
    given, the ``'login'`` entry of the credentials being authenticated.
    """

    def __init__(self, userid=None):
        self.userid = userid

    def authenticate(self, environ, credentials):
        # Without a fixed userid, echo back whoever the credentials claim.
        if self.userid is None:
            return credentials['login']
        return self.userid
class DummyFailAuthenticator:
    """Authenticator stub whose ``authenticate()`` always returns None."""

    def authenticate(self, environ, credentials):
        # Both arguments are ignored on purpose.
        return
class DummyRequestClassifier(object):
    """Request classifier stub that labels every request ``'browser'``."""

    def __call__(self, environ):
        # The environ is ignored; the classification is constant.
        return 'browser'
class DummyChallengeDecider(object):
    """Challenge decider stub.

    Returns True when the status line starts with ``'401 '`` and None
    for every other status.
    """

    def __call__(self, environ, status, headers):
        if status.startswith('401 '):
            return True
        # Callers get None (not False) for non-401 responses.
        return None
class DummyNoResultsIdentifier:
    """Identifier stub: ``identify()``, ``remember()`` and ``forget()``
    accept their arguments and all return None."""

    def identify(self, environ):
        return

    def remember(self, *arg, **kw):
        return None

    def forget(self, *arg, **kw):
        return None
class DummyStartResponse(object):
    """``start_response`` stub that records the arguments of its last call.

    After a call, ``status``, ``headers`` and ``exc_info`` are available
    as attributes; before the first call none of them exist.
    """

    def __call__(self, status, headers, exc_info=None):
        self.status = status
        self.headers = headers
        self.exc_info = exc_info
        return []
class DummyMDProvider(object):
    """Metadata provider stub that merges a fixed mapping into the identity."""

    def __init__(self, metadata=None):
        self._metadata = metadata

    def add_metadata(self, environ, identity):
        # update() mutates ``identity`` in place; its return value (None)
        # is passed straight back to the caller.
        return identity.update(self._metadata)
class DummyMultiPlugin:
    """Empty placeholder class; it defines no attributes or methods."""