Chris McDonough
2013-07-31 5fc0d36724a6197c8c0106e846d8e78e1219b1fe
commit | author | age
a00621 1 import re
CM 2
3 from pyramid.exceptions import ConfigurationError
4
c25a8f 5 from pyramid.compat import is_nonstr_iter
CM 6
a00621 7 from pyramid.traversal import (
CM 8     find_interface,
9     traversal_path,
c25a8f 10     resource_path_tuple
a00621 11     )
CM 12
13 from pyramid.urldispatch import _compile_route
9c8ec5 14 from pyramid.util import object_description
643a83 15 from pyramid.session import check_csrf_token
c7337b 16 from pyramid.security import effective_principals
643a83 17
a00621 18 from .util import as_sorted_tuple
CM 19
267dbd 20 _marker = object()
CM 21
a00621 22 class XHRPredicate(object):
9c8ec5 23     def __init__(self, val, config):
a00621 24         self.val = bool(val)
CM 25
4d2602 26     def text(self):
4f0b02 27         return 'xhr = %s' % self.val
a00621 28
4d2602 29     phash = text
a00621 30
CM 31     def __call__(self, context, request):
4f0b02 32         return bool(request.is_xhr) is self.val
a00621 33
CM 34 class RequestMethodPredicate(object):
9c8ec5 35     def __init__(self, val, config):
d98612 36         request_method = as_sorted_tuple(val)
CM 37         if 'GET' in request_method and 'HEAD' not in request_method:
38             # GET implies HEAD too
39             request_method = as_sorted_tuple(request_method + ('HEAD',))
40         self.val = request_method
a00621 41
4d2602 42     def text(self):
CM 43         return 'request_method = %s' % (','.join(self.val))
a00621 44
4d2602 45     phash = text
a00621 46
CM 47     def __call__(self, context, request):
48         return request.method in self.val
49
50 class PathInfoPredicate(object):
5fc0d3 51     negatable = True
9c8ec5 52     def __init__(self, val, config):
a00621 53         self.orig = val
CM 54         try:
55             val = re.compile(val)
56         except re.error as why:
57             raise ConfigurationError(why.args[0])
58         self.val = val
59
4d2602 60     def text(self):
a00621 61         return 'path_info = %s' % (self.orig,)
CM 62
4d2602 63     phash = text
a00621 64
CM 65     def __call__(self, context, request):
66         return self.val.match(request.upath_info) is not None
67     
68 class RequestParamPredicate(object):
5fc0d3 69     negatable = True
9c8ec5 70     def __init__(self, val, config):
4b8cf2 71         val = as_sorted_tuple(val)
3fb934 72         reqs = []
MM 73         for p in val:
74             k = p
75             v = None
76             if '=' in p:
77                 k, v = p.split('=', 1)
78                 k, v = k.strip(), v.strip()
79             reqs.append((k, v))
4b8cf2 80         self.val = val
3fb934 81         self.reqs = reqs
a00621 82
4d2602 83     def text(self):
3fb934 84         return 'request_param %s' % ','.join(
5507b8 85             ['%s=%s' % (x,y) if y else x for x, y in self.reqs]
3fb934 86         )
a00621 87
4d2602 88     phash = text
a00621 89
CM 90     def __call__(self, context, request):
3fb934 91         for k, v in self.reqs:
MM 92             actual = request.params.get(k)
93             if actual is None:
94                 return False
95             if v is not None and actual != v:
96                 return False
97         return True
a00621 98
CM 99 class HeaderPredicate(object):
5fc0d3 100     negatable = True
9c8ec5 101     def __init__(self, val, config):
a00621 102         name = val
CM 103         v = None
104         if ':' in name:
b1d2a3 105             name, val_str = name.split(':', 1)
a00621 106             try:
b1d2a3 107                 v = re.compile(val_str)
a00621 108             except re.error as why:
CM 109                 raise ConfigurationError(why.args[0])
110         if v is None:
4d2602 111             self._text = 'header %s' % (name,)
a00621 112         else:
b1d2a3 113             self._text = 'header %s=%s' % (name, val_str)
a00621 114         self.name = name
CM 115         self.val = v
116
4d2602 117     def text(self):
CM 118         return self._text
a00621 119
4d2602 120     phash = text
a00621 121
CM 122     def __call__(self, context, request):
123         if self.val is None:
124             return self.name in request.headers
125         val = request.headers.get(self.name)
126         if val is None:
127             return False
128         return self.val.match(val) is not None
129
130 class AcceptPredicate(object):
9c8ec5 131     def __init__(self, val, config):
a00621 132         self.val = val
CM 133
4d2602 134     def text(self):
a00621 135         return 'accept = %s' % (self.val,)
CM 136
4d2602 137     phash = text
a00621 138
CM 139     def __call__(self, context, request):
140         return self.val in request.accept
141
142 class ContainmentPredicate(object):
5fc0d3 143     negatable = True
9c8ec5 144     def __init__(self, val, config):
CM 145         self.val = config.maybe_dotted(val)
a00621 146
4d2602 147     def text(self):
a00621 148         return 'containment = %s' % (self.val,)
CM 149
4d2602 150     phash = text
a00621 151
CM 152     def __call__(self, context, request):
153         ctx = getattr(request, 'context', context)
154         return find_interface(ctx, self.val) is not None
155     
156 class RequestTypePredicate(object):
5fc0d3 157     negatable = True
9c8ec5 158     def __init__(self, val, config):
a00621 159         self.val = val
CM 160
4d2602 161     def text(self):
a00621 162         return 'request_type = %s' % (self.val,)
CM 163
4d2602 164     phash = text
a00621 165
CM 166     def __call__(self, context, request):
167         return self.val.providedBy(request)
168     
169 class MatchParamPredicate(object):
5fc0d3 170     negatable = True
9c8ec5 171     def __init__(self, val, config):
4b8cf2 172         val = as_sorted_tuple(val)
a00621 173         self.val = val
4d2602 174         reqs = [ p.split('=', 1) for p in val ]
CM 175         self.reqs = [ (x.strip(), y.strip()) for x, y in reqs ]
a00621 176
4d2602 177     def text(self):
CM 178         return 'match_param %s' % ','.join(
179             ['%s=%s' % (x,y) for x, y in self.reqs]
180             )
a00621 181
4d2602 182     phash = text
a00621 183
CM 184     def __call__(self, context, request):
267dbd 185         if not request.matchdict:
CM 186             # might be None
187             return False
a00621 188         for k, v in self.reqs:
CM 189             if request.matchdict.get(k) != v:
190                 return False
191         return True
192     
193 class CustomPredicate(object):
9c8ec5 194     def __init__(self, func, config):
a00621 195         self.func = func
CM 196
4d2602 197     def text(self):
9c8ec5 198         return getattr(
CM 199             self.func,
200             '__text__',
201             'custom predicate: %s' % object_description(self.func)
202             )
a00621 203
4d2602 204     def phash(self):
9c8ec5 205         # using hash() here rather than id() is intentional: we
CM 206         # want to allow custom predicates that are part of
207         # frameworks to be able to define custom __hash__
208         # functions for custom predicates, so that the hash output
209         # of predicate instances which are "logically the same"
210         # may compare equal.
a00621 211         return 'custom:%r' % hash(self.func)
CM 212
213     def __call__(self, context, request):
214         return self.func(context, request)
215     
216     
217 class TraversePredicate(object):
9c8ec5 218     # Can only be used as a *route* "predicate"; it adds 'traverse' to the
CM 219     # matchdict if it's specified in the routing args.  This causes the
220     # ResourceTreeTraverser to use the resolved traverse pattern as the
221     # traversal path.
222     def __init__(self, val, config):
a00621 223         _, self.tgenerate = _compile_route(val)
CM 224         self.val = val
225         
4d2602 226     def text(self):
a00621 227         return 'traverse matchdict pseudo-predicate'
CM 228
4d2602 229     def phash(self):
9c8ec5 230         # This isn't actually a predicate, it's just a infodict modifier that
CM 231         # injects ``traverse`` into the matchdict.  As a result, we don't
232         # need to update the hash.
a00621 233         return ''
CM 234
235     def __call__(self, context, request):
236         if 'traverse' in context:
237             return True
238         m = context['match']
9c8ec5 239         tvalue = self.tgenerate(m)  # tvalue will be urlquoted string
a00621 240         m['traverse'] = traversal_path(tvalue)
9c8ec5 241         # This isn't actually a predicate, it's just a infodict modifier that
CM 242         # injects ``traverse`` into the matchdict.  As a result, we just
243         # return True.
a00621 244         return True
643a83 245
CM 246 class CheckCSRFTokenPredicate(object):
247
248     check_csrf_token = staticmethod(check_csrf_token) # testing
249     
250     def __init__(self, val, config):
251         self.val = val
252
253     def text(self):
254         return 'check_csrf = %s' % (self.val,)
255
256     phash = text
257
258     def __call__(self, context, request):
259         val = self.val
260         if val:
261             if val is True:
262                 val = 'csrf_token'
263             return self.check_csrf_token(request, val, raises=False)
264         return True
265
c25a8f 266 class PhysicalPathPredicate(object):
5fc0d3 267     negatable = True
c25a8f 268     def __init__(self, val, config):
CM 269         if is_nonstr_iter(val):
270             self.val = tuple(val)
271         else:
272             val = tuple(filter(None, val.split('/')))
273             self.val = ('',) + val
274
275     def text(self):
276         return 'physical_path = %s' % (self.val,)
277
278     phash = text
279
280     def __call__(self, context, request):
267dbd 281         if getattr(context, '__name__', _marker) is not _marker:
CM 282             return resource_path_tuple(context) == self.val
283         return False
c25a8f 284
c7337b 285 class EffectivePrincipalsPredicate(object):
5fc0d3 286     negatable = True
c7337b 287     def __init__(self, val, config):
CM 288         if is_nonstr_iter(val):
289             self.val = set(val)
290         else:
291             self.val = set((val,))
292
293     def text(self):
294         return 'effective_principals = %s' % sorted(list(self.val))
295
296     phash = text
297
298     def __call__(self, context, request):
299         req_principals = effective_principals(request)
300         if is_nonstr_iter(req_principals):
301             rpset = set(req_principals)
302             if self.val.issubset(rpset):
303                 return True
304         return False
32333e 305