Michael Merickel
2018-10-15 433efe06191a7007ca8c5bf8fafee5c7c1439ebb
commit | author | age
c7974f 1 import re
MM 2
3 from pyramid.exceptions import ConfigurationError
4
5 from pyramid.compat import is_nonstr_iter
6
a2c7c7 7 from pyramid.csrf import check_csrf_token
c7974f 8 from pyramid.traversal import (
MM 9     find_interface,
10     traversal_path,
11     resource_path_tuple
12     )
13
14 from pyramid.urldispatch import _compile_route
52fde9 15 from pyramid.util import (
MM 16     as_sorted_tuple,
17     object_description,
18 )
c7974f 19
MM 20 _marker = object()
21
22 class XHRPredicate(object):
23     def __init__(self, val, config):
24         self.val = bool(val)
25
26     def text(self):
27         return 'xhr = %s' % self.val
28
29     phash = text
30
31     def __call__(self, context, request):
32         return bool(request.is_xhr) is self.val
33
34 class RequestMethodPredicate(object):
35     def __init__(self, val, config):
36         request_method = as_sorted_tuple(val)
37         if 'GET' in request_method and 'HEAD' not in request_method:
38             # GET implies HEAD too
39             request_method = as_sorted_tuple(request_method + ('HEAD',))
40         self.val = request_method
41
42     def text(self):
43         return 'request_method = %s' % (','.join(self.val))
44
45     phash = text
46
47     def __call__(self, context, request):
48         return request.method in self.val
49
50 class PathInfoPredicate(object):
51     def __init__(self, val, config):
52         self.orig = val
53         try:
54             val = re.compile(val)
55         except re.error as why:
56             raise ConfigurationError(why.args[0])
57         self.val = val
58
59     def text(self):
60         return 'path_info = %s' % (self.orig,)
61
62     phash = text
63
64     def __call__(self, context, request):
65         return self.val.match(request.upath_info) is not None
66     
67 class RequestParamPredicate(object):
68     def __init__(self, val, config):
69         val = as_sorted_tuple(val)
70         reqs = []
71         for p in val:
72             k = p
73             v = None
74             if p.startswith('='):
75                 if '=' in p[1:]:
76                     k, v = p[1:].split('=', 1)
77                     k = '=' + k
78                     k, v = k.strip(), v.strip()
79             elif '=' in p:
80                 k, v = p.split('=', 1)
81                 k, v = k.strip(), v.strip()
82             reqs.append((k, v))
83         self.val = val
84         self.reqs = reqs
85
86     def text(self):
87         return 'request_param %s' % ','.join(
88             ['%s=%s' % (x,y) if y else x for x, y in self.reqs]
89         )
90
91     phash = text
92
93     def __call__(self, context, request):
94         for k, v in self.reqs:
95             actual = request.params.get(k)
96             if actual is None:
97                 return False
98             if v is not None and actual != v:
99                 return False
100         return True
101
102 class HeaderPredicate(object):
103     def __init__(self, val, config):
104         name = val
105         v = None
106         if ':' in name:
107             name, val_str = name.split(':', 1)
108             try:
109                 v = re.compile(val_str)
110             except re.error as why:
111                 raise ConfigurationError(why.args[0])
112         if v is None:
113             self._text = 'header %s' % (name,)
114         else:
115             self._text = 'header %s=%s' % (name, val_str)
116         self.name = name
117         self.val = v
118
119     def text(self):
120         return self._text
121
122     phash = text
123
124     def __call__(self, context, request):
125         if self.val is None:
126             return self.name in request.headers
127         val = request.headers.get(self.name)
128         if val is None:
129             return False
130         return self.val.match(val) is not None
131
132 class AcceptPredicate(object):
4a9f4f 133     _is_using_deprecated_ranges = False
MM 134
135     def __init__(self, values, config):
136         if not is_nonstr_iter(values):
137             values = (values,)
138         # deprecated media ranges were only supported in versions of the
139         # predicate that didn't support lists, so check it here
140         if len(values) == 1 and '*' in values[0]:
141             self._is_using_deprecated_ranges = True
142         self.values = values
c7974f 143
MM 144     def text(self):
30f79d 145         return 'accept = %s' % (', '.join(self.values),)
c7974f 146
MM 147     phash = text
148
149     def __call__(self, context, request):
4a9f4f 150         if self._is_using_deprecated_ranges:
MM 151             return self.values[0] in request.accept
2497d0 152         return bool(request.accept.acceptable_offers(self.values))
c7974f 153
MM 154 class ContainmentPredicate(object):
155     def __init__(self, val, config):
156         self.val = config.maybe_dotted(val)
157
158     def text(self):
159         return 'containment = %s' % (self.val,)
160
161     phash = text
162
163     def __call__(self, context, request):
164         ctx = getattr(request, 'context', context)
165         return find_interface(ctx, self.val) is not None
166     
167 class RequestTypePredicate(object):
168     def __init__(self, val, config):
169         self.val = val
170
171     def text(self):
172         return 'request_type = %s' % (self.val,)
173
174     phash = text
175
176     def __call__(self, context, request):
177         return self.val.providedBy(request)
178     
179 class MatchParamPredicate(object):
180     def __init__(self, val, config):
181         val = as_sorted_tuple(val)
182         self.val = val
183         reqs = [ p.split('=', 1) for p in val ]
184         self.reqs = [ (x.strip(), y.strip()) for x, y in reqs ]
185
186     def text(self):
187         return 'match_param %s' % ','.join(
188             ['%s=%s' % (x,y) for x, y in self.reqs]
189             )
190
191     phash = text
192
193     def __call__(self, context, request):
194         if not request.matchdict:
195             # might be None
196             return False
197         for k, v in self.reqs:
198             if request.matchdict.get(k) != v:
199                 return False
200         return True
201     
202 class CustomPredicate(object):
203     def __init__(self, func, config):
204         self.func = func
205
206     def text(self):
207         return getattr(
208             self.func,
209             '__text__',
210             'custom predicate: %s' % object_description(self.func)
211             )
212
213     def phash(self):
214         # using hash() here rather than id() is intentional: we
215         # want to allow custom predicates that are part of
216         # frameworks to be able to define custom __hash__
217         # functions for custom predicates, so that the hash output
218         # of predicate instances which are "logically the same"
219         # may compare equal.
220         return 'custom:%r' % hash(self.func)
221
222     def __call__(self, context, request):
223         return self.func(context, request)
224     
225     
226 class TraversePredicate(object):
227     # Can only be used as a *route* "predicate"; it adds 'traverse' to the
228     # matchdict if it's specified in the routing args.  This causes the
229     # ResourceTreeTraverser to use the resolved traverse pattern as the
230     # traversal path.
231     def __init__(self, val, config):
232         _, self.tgenerate = _compile_route(val)
233         self.val = val
234         
235     def text(self):
236         return 'traverse matchdict pseudo-predicate'
237
238     def phash(self):
239         # This isn't actually a predicate, it's just a infodict modifier that
240         # injects ``traverse`` into the matchdict.  As a result, we don't
241         # need to update the hash.
242         return ''
243
244     def __call__(self, context, request):
245         if 'traverse' in context:
246             return True
247         m = context['match']
248         tvalue = self.tgenerate(m)  # tvalue will be urlquoted string
249         m['traverse'] = traversal_path(tvalue)
250         # This isn't actually a predicate, it's just a infodict modifier that
251         # injects ``traverse`` into the matchdict.  As a result, we just
252         # return True.
253         return True
254
255 class CheckCSRFTokenPredicate(object):
256
257     check_csrf_token = staticmethod(check_csrf_token) # testing
258     
259     def __init__(self, val, config):
260         self.val = val
261
262     def text(self):
263         return 'check_csrf = %s' % (self.val,)
264
265     phash = text
266
267     def __call__(self, context, request):
268         val = self.val
269         if val:
270             if val is True:
271                 val = 'csrf_token'
272             return self.check_csrf_token(request, val, raises=False)
273         return True
274
275 class PhysicalPathPredicate(object):
276     def __init__(self, val, config):
277         if is_nonstr_iter(val):
278             self.val = tuple(val)
279         else:
280             val = tuple(filter(None, val.split('/')))
281             self.val = ('',) + val
282
283     def text(self):
284         return 'physical_path = %s' % (self.val,)
285
286     phash = text
287
288     def __call__(self, context, request):
289         if getattr(context, '__name__', _marker) is not _marker:
290             return resource_path_tuple(context) == self.val
291         return False
292
293 class EffectivePrincipalsPredicate(object):
294     def __init__(self, val, config):
295         if is_nonstr_iter(val):
296             self.val = set(val)
297         else:
298             self.val = set((val,))
299
300     def text(self):
301         return 'effective_principals = %s' % sorted(list(self.val))
302
303     phash = text
304
305     def __call__(self, context, request):
306         req_principals = request.effective_principals
307         if is_nonstr_iter(req_principals):
308             rpset = set(req_principals)
309             if self.val.issubset(rpset):
310                 return True
311         return False
312
52fde9 313 class Notted(object):
MM 314     def __init__(self, predicate):
315         self.predicate = predicate
316
317     def _notted_text(self, val):
318         # if the underlying predicate doesnt return a value, it's not really
319         # a predicate, it's just something pretending to be a predicate,
320         # so dont update the hash
321         if val:
322             val = '!' + val
323         return val
324
325     def text(self):
326         return self._notted_text(self.predicate.text())
327
328     def phash(self):
329         return self._notted_text(self.predicate.phash())
330
331     def __call__(self, context, request):
332         result = self.predicate(context, request)
333         phash = self.phash()
334         if phash:
335             result = not result
336         return result