Michael Merickel
2018-10-17 e14661417e7ceb50d5cf83bbd6abd6b133e473ba
commit | author | age
c7974f 1 import re
MM 2
3 from pyramid.exceptions import ConfigurationError
4
5 from pyramid.compat import is_nonstr_iter
6
a2c7c7 7 from pyramid.csrf import check_csrf_token
c7974f 8 from pyramid.traversal import (
MM 9     find_interface,
10     traversal_path,
0c29cf 11     resource_path_tuple,
52fde9 12 )
c7974f 13
0c29cf 14 from pyramid.urldispatch import _compile_route
MM 15 from pyramid.util import as_sorted_tuple, object_description
16
c7974f 17 _marker = object()
0c29cf 18
c7974f 19
MM 20 class XHRPredicate(object):
21     def __init__(self, val, config):
22         self.val = bool(val)
23
24     def text(self):
25         return 'xhr = %s' % self.val
26
27     phash = text
28
29     def __call__(self, context, request):
30         return bool(request.is_xhr) is self.val
0c29cf 31
c7974f 32
MM 33 class RequestMethodPredicate(object):
34     def __init__(self, val, config):
35         request_method = as_sorted_tuple(val)
36         if 'GET' in request_method and 'HEAD' not in request_method:
37             # GET implies HEAD too
38             request_method = as_sorted_tuple(request_method + ('HEAD',))
39         self.val = request_method
40
41     def text(self):
42         return 'request_method = %s' % (','.join(self.val))
43
44     phash = text
45
46     def __call__(self, context, request):
47         return request.method in self.val
48
0c29cf 49
c7974f 50 class PathInfoPredicate(object):
MM 51     def __init__(self, val, config):
52         self.orig = val
53         try:
54             val = re.compile(val)
55         except re.error as why:
56             raise ConfigurationError(why.args[0])
57         self.val = val
58
59     def text(self):
60         return 'path_info = %s' % (self.orig,)
61
62     phash = text
63
64     def __call__(self, context, request):
65         return self.val.match(request.upath_info) is not None
0c29cf 66
MM 67
c7974f 68 class RequestParamPredicate(object):
MM 69     def __init__(self, val, config):
70         val = as_sorted_tuple(val)
71         reqs = []
72         for p in val:
73             k = p
74             v = None
75             if p.startswith('='):
76                 if '=' in p[1:]:
77                     k, v = p[1:].split('=', 1)
78                     k = '=' + k
79                     k, v = k.strip(), v.strip()
80             elif '=' in p:
81                 k, v = p.split('=', 1)
82                 k, v = k.strip(), v.strip()
83             reqs.append((k, v))
84         self.val = val
85         self.reqs = reqs
86
87     def text(self):
88         return 'request_param %s' % ','.join(
0c29cf 89             ['%s=%s' % (x, y) if y else x for x, y in self.reqs]
c7974f 90         )
MM 91
92     phash = text
93
94     def __call__(self, context, request):
95         for k, v in self.reqs:
96             actual = request.params.get(k)
97             if actual is None:
98                 return False
99             if v is not None and actual != v:
100                 return False
101         return True
0c29cf 102
c7974f 103
MM 104 class HeaderPredicate(object):
105     def __init__(self, val, config):
106         name = val
107         v = None
108         if ':' in name:
109             name, val_str = name.split(':', 1)
110             try:
111                 v = re.compile(val_str)
112             except re.error as why:
113                 raise ConfigurationError(why.args[0])
114         if v is None:
115             self._text = 'header %s' % (name,)
116         else:
117             self._text = 'header %s=%s' % (name, val_str)
118         self.name = name
119         self.val = v
120
121     def text(self):
122         return self._text
123
124     phash = text
125
126     def __call__(self, context, request):
127         if self.val is None:
128             return self.name in request.headers
129         val = request.headers.get(self.name)
130         if val is None:
131             return False
132         return self.val.match(val) is not None
133
0c29cf 134
c7974f 135 class AcceptPredicate(object):
4a9f4f 136     _is_using_deprecated_ranges = False
MM 137
138     def __init__(self, values, config):
139         if not is_nonstr_iter(values):
140             values = (values,)
141         # deprecated media ranges were only supported in versions of the
142         # predicate that didn't support lists, so check it here
143         if len(values) == 1 and '*' in values[0]:
144             self._is_using_deprecated_ranges = True
145         self.values = values
c7974f 146
MM 147     def text(self):
30f79d 148         return 'accept = %s' % (', '.join(self.values),)
c7974f 149
MM 150     phash = text
151
152     def __call__(self, context, request):
4a9f4f 153         if self._is_using_deprecated_ranges:
MM 154             return self.values[0] in request.accept
2497d0 155         return bool(request.accept.acceptable_offers(self.values))
c7974f 156
0c29cf 157
c7974f 158 class ContainmentPredicate(object):
MM 159     def __init__(self, val, config):
160         self.val = config.maybe_dotted(val)
161
162     def text(self):
163         return 'containment = %s' % (self.val,)
164
165     phash = text
166
167     def __call__(self, context, request):
168         ctx = getattr(request, 'context', context)
169         return find_interface(ctx, self.val) is not None
0c29cf 170
MM 171
c7974f 172 class RequestTypePredicate(object):
MM 173     def __init__(self, val, config):
174         self.val = val
175
176     def text(self):
177         return 'request_type = %s' % (self.val,)
178
179     phash = text
180
181     def __call__(self, context, request):
182         return self.val.providedBy(request)
0c29cf 183
MM 184
c7974f 185 class MatchParamPredicate(object):
MM 186     def __init__(self, val, config):
187         val = as_sorted_tuple(val)
188         self.val = val
0c29cf 189         reqs = [p.split('=', 1) for p in val]
MM 190         self.reqs = [(x.strip(), y.strip()) for x, y in reqs]
c7974f 191
MM 192     def text(self):
193         return 'match_param %s' % ','.join(
0c29cf 194             ['%s=%s' % (x, y) for x, y in self.reqs]
MM 195         )
c7974f 196
MM 197     phash = text
198
199     def __call__(self, context, request):
200         if not request.matchdict:
201             # might be None
202             return False
203         for k, v in self.reqs:
204             if request.matchdict.get(k) != v:
205                 return False
206         return True
0c29cf 207
MM 208
c7974f 209 class CustomPredicate(object):
MM 210     def __init__(self, func, config):
211         self.func = func
212
213     def text(self):
214         return getattr(
215             self.func,
216             '__text__',
0c29cf 217             'custom predicate: %s' % object_description(self.func),
MM 218         )
c7974f 219
MM 220     def phash(self):
221         # using hash() here rather than id() is intentional: we
222         # want to allow custom predicates that are part of
223         # frameworks to be able to define custom __hash__
224         # functions for custom predicates, so that the hash output
225         # of predicate instances which are "logically the same"
226         # may compare equal.
227         return 'custom:%r' % hash(self.func)
228
229     def __call__(self, context, request):
230         return self.func(context, request)
0c29cf 231
MM 232
c7974f 233 class TraversePredicate(object):
MM 234     # Can only be used as a *route* "predicate"; it adds 'traverse' to the
235     # matchdict if it's specified in the routing args.  This causes the
236     # ResourceTreeTraverser to use the resolved traverse pattern as the
237     # traversal path.
238     def __init__(self, val, config):
239         _, self.tgenerate = _compile_route(val)
240         self.val = val
0c29cf 241
c7974f 242     def text(self):
MM 243         return 'traverse matchdict pseudo-predicate'
244
245     def phash(self):
246         # This isn't actually a predicate, it's just a infodict modifier that
247         # injects ``traverse`` into the matchdict.  As a result, we don't
248         # need to update the hash.
249         return ''
250
251     def __call__(self, context, request):
252         if 'traverse' in context:
253             return True
254         m = context['match']
255         tvalue = self.tgenerate(m)  # tvalue will be urlquoted string
256         m['traverse'] = traversal_path(tvalue)
257         # This isn't actually a predicate, it's just a infodict modifier that
258         # injects ``traverse`` into the matchdict.  As a result, we just
259         # return True.
260         return True
261
0c29cf 262
c7974f 263 class CheckCSRFTokenPredicate(object):
MM 264
0c29cf 265     check_csrf_token = staticmethod(check_csrf_token)  # testing
MM 266
c7974f 267     def __init__(self, val, config):
MM 268         self.val = val
269
270     def text(self):
271         return 'check_csrf = %s' % (self.val,)
272
273     phash = text
274
275     def __call__(self, context, request):
276         val = self.val
277         if val:
278             if val is True:
279                 val = 'csrf_token'
280             return self.check_csrf_token(request, val, raises=False)
281         return True
0c29cf 282
c7974f 283
MM 284 class PhysicalPathPredicate(object):
285     def __init__(self, val, config):
286         if is_nonstr_iter(val):
287             self.val = tuple(val)
288         else:
289             val = tuple(filter(None, val.split('/')))
290             self.val = ('',) + val
291
292     def text(self):
293         return 'physical_path = %s' % (self.val,)
294
295     phash = text
296
297     def __call__(self, context, request):
298         if getattr(context, '__name__', _marker) is not _marker:
299             return resource_path_tuple(context) == self.val
300         return False
0c29cf 301
c7974f 302
MM 303 class EffectivePrincipalsPredicate(object):
304     def __init__(self, val, config):
305         if is_nonstr_iter(val):
306             self.val = set(val)
307         else:
308             self.val = set((val,))
309
310     def text(self):
311         return 'effective_principals = %s' % sorted(list(self.val))
312
313     phash = text
314
315     def __call__(self, context, request):
316         req_principals = request.effective_principals
317         if is_nonstr_iter(req_principals):
318             rpset = set(req_principals)
319             if self.val.issubset(rpset):
320                 return True
321         return False
322
0c29cf 323
52fde9 324 class Notted(object):
MM 325     def __init__(self, predicate):
326         self.predicate = predicate
327
328     def _notted_text(self, val):
329         # if the underlying predicate doesnt return a value, it's not really
330         # a predicate, it's just something pretending to be a predicate,
331         # so dont update the hash
332         if val:
333             val = '!' + val
334         return val
335
336     def text(self):
337         return self._notted_text(self.predicate.text())
338
339     def phash(self):
340         return self._notted_text(self.predicate.phash())
341
342     def __call__(self, context, request):
343         result = self.predicate(context, request)
344         phash = self.phash()
345         if phash:
346             result = not result
347         return result