Steve Piercy
2018-09-22 2a45fe74f9598b4e726ab17ce17948d4e709894b
commit | author | age
c7974f 1 import re
MM 2
3 from pyramid.exceptions import ConfigurationError
4
5 from pyramid.compat import is_nonstr_iter
6
a2c7c7 7 from pyramid.csrf import check_csrf_token
c7974f 8 from pyramid.traversal import (
MM 9     find_interface,
10     traversal_path,
11     resource_path_tuple
12     )
13
14 from pyramid.urldispatch import _compile_route
15 from pyramid.util import object_description
16 from pyramid.util import as_sorted_tuple
17
18 _marker = object()
19
20 class XHRPredicate(object):
21     def __init__(self, val, config):
22         self.val = bool(val)
23
24     def text(self):
25         return 'xhr = %s' % self.val
26
27     phash = text
28
29     def __call__(self, context, request):
30         return bool(request.is_xhr) is self.val
31
32 class RequestMethodPredicate(object):
33     def __init__(self, val, config):
34         request_method = as_sorted_tuple(val)
35         if 'GET' in request_method and 'HEAD' not in request_method:
36             # GET implies HEAD too
37             request_method = as_sorted_tuple(request_method + ('HEAD',))
38         self.val = request_method
39
40     def text(self):
41         return 'request_method = %s' % (','.join(self.val))
42
43     phash = text
44
45     def __call__(self, context, request):
46         return request.method in self.val
47
48 class PathInfoPredicate(object):
49     def __init__(self, val, config):
50         self.orig = val
51         try:
52             val = re.compile(val)
53         except re.error as why:
54             raise ConfigurationError(why.args[0])
55         self.val = val
56
57     def text(self):
58         return 'path_info = %s' % (self.orig,)
59
60     phash = text
61
62     def __call__(self, context, request):
63         return self.val.match(request.upath_info) is not None
64     
65 class RequestParamPredicate(object):
66     def __init__(self, val, config):
67         val = as_sorted_tuple(val)
68         reqs = []
69         for p in val:
70             k = p
71             v = None
72             if p.startswith('='):
73                 if '=' in p[1:]:
74                     k, v = p[1:].split('=', 1)
75                     k = '=' + k
76                     k, v = k.strip(), v.strip()
77             elif '=' in p:
78                 k, v = p.split('=', 1)
79                 k, v = k.strip(), v.strip()
80             reqs.append((k, v))
81         self.val = val
82         self.reqs = reqs
83
84     def text(self):
85         return 'request_param %s' % ','.join(
86             ['%s=%s' % (x,y) if y else x for x, y in self.reqs]
87         )
88
89     phash = text
90
91     def __call__(self, context, request):
92         for k, v in self.reqs:
93             actual = request.params.get(k)
94             if actual is None:
95                 return False
96             if v is not None and actual != v:
97                 return False
98         return True
99
100 class HeaderPredicate(object):
101     def __init__(self, val, config):
102         name = val
103         v = None
104         if ':' in name:
105             name, val_str = name.split(':', 1)
106             try:
107                 v = re.compile(val_str)
108             except re.error as why:
109                 raise ConfigurationError(why.args[0])
110         if v is None:
111             self._text = 'header %s' % (name,)
112         else:
113             self._text = 'header %s=%s' % (name, val_str)
114         self.name = name
115         self.val = v
116
117     def text(self):
118         return self._text
119
120     phash = text
121
122     def __call__(self, context, request):
123         if self.val is None:
124             return self.name in request.headers
125         val = request.headers.get(self.name)
126         if val is None:
127             return False
128         return self.val.match(val) is not None
129
130 class AcceptPredicate(object):
131     def __init__(self, val, config):
132         self.val = val
133
134     def text(self):
135         return 'accept = %s' % (self.val,)
136
137     phash = text
138
139     def __call__(self, context, request):
140         return self.val in request.accept
141
142 class ContainmentPredicate(object):
143     def __init__(self, val, config):
144         self.val = config.maybe_dotted(val)
145
146     def text(self):
147         return 'containment = %s' % (self.val,)
148
149     phash = text
150
151     def __call__(self, context, request):
152         ctx = getattr(request, 'context', context)
153         return find_interface(ctx, self.val) is not None
154     
155 class RequestTypePredicate(object):
156     def __init__(self, val, config):
157         self.val = val
158
159     def text(self):
160         return 'request_type = %s' % (self.val,)
161
162     phash = text
163
164     def __call__(self, context, request):
165         return self.val.providedBy(request)
166     
167 class MatchParamPredicate(object):
168     def __init__(self, val, config):
169         val = as_sorted_tuple(val)
170         self.val = val
171         reqs = [ p.split('=', 1) for p in val ]
172         self.reqs = [ (x.strip(), y.strip()) for x, y in reqs ]
173
174     def text(self):
175         return 'match_param %s' % ','.join(
176             ['%s=%s' % (x,y) for x, y in self.reqs]
177             )
178
179     phash = text
180
181     def __call__(self, context, request):
182         if not request.matchdict:
183             # might be None
184             return False
185         for k, v in self.reqs:
186             if request.matchdict.get(k) != v:
187                 return False
188         return True
189     
190 class CustomPredicate(object):
191     def __init__(self, func, config):
192         self.func = func
193
194     def text(self):
195         return getattr(
196             self.func,
197             '__text__',
198             'custom predicate: %s' % object_description(self.func)
199             )
200
201     def phash(self):
202         # using hash() here rather than id() is intentional: we
203         # want to allow custom predicates that are part of
204         # frameworks to be able to define custom __hash__
205         # functions for custom predicates, so that the hash output
206         # of predicate instances which are "logically the same"
207         # may compare equal.
208         return 'custom:%r' % hash(self.func)
209
210     def __call__(self, context, request):
211         return self.func(context, request)
212     
213     
214 class TraversePredicate(object):
215     # Can only be used as a *route* "predicate"; it adds 'traverse' to the
216     # matchdict if it's specified in the routing args.  This causes the
217     # ResourceTreeTraverser to use the resolved traverse pattern as the
218     # traversal path.
219     def __init__(self, val, config):
220         _, self.tgenerate = _compile_route(val)
221         self.val = val
222         
223     def text(self):
224         return 'traverse matchdict pseudo-predicate'
225
226     def phash(self):
227         # This isn't actually a predicate, it's just a infodict modifier that
228         # injects ``traverse`` into the matchdict.  As a result, we don't
229         # need to update the hash.
230         return ''
231
232     def __call__(self, context, request):
233         if 'traverse' in context:
234             return True
235         m = context['match']
236         tvalue = self.tgenerate(m)  # tvalue will be urlquoted string
237         m['traverse'] = traversal_path(tvalue)
238         # This isn't actually a predicate, it's just a infodict modifier that
239         # injects ``traverse`` into the matchdict.  As a result, we just
240         # return True.
241         return True
242
243 class CheckCSRFTokenPredicate(object):
244
245     check_csrf_token = staticmethod(check_csrf_token) # testing
246     
247     def __init__(self, val, config):
248         self.val = val
249
250     def text(self):
251         return 'check_csrf = %s' % (self.val,)
252
253     phash = text
254
255     def __call__(self, context, request):
256         val = self.val
257         if val:
258             if val is True:
259                 val = 'csrf_token'
260             return self.check_csrf_token(request, val, raises=False)
261         return True
262
263 class PhysicalPathPredicate(object):
264     def __init__(self, val, config):
265         if is_nonstr_iter(val):
266             self.val = tuple(val)
267         else:
268             val = tuple(filter(None, val.split('/')))
269             self.val = ('',) + val
270
271     def text(self):
272         return 'physical_path = %s' % (self.val,)
273
274     phash = text
275
276     def __call__(self, context, request):
277         if getattr(context, '__name__', _marker) is not _marker:
278             return resource_path_tuple(context) == self.val
279         return False
280
281 class EffectivePrincipalsPredicate(object):
282     def __init__(self, val, config):
283         if is_nonstr_iter(val):
284             self.val = set(val)
285         else:
286             self.val = set((val,))
287
288     def text(self):
289         return 'effective_principals = %s' % sorted(list(self.val))
290
291     phash = text
292
293     def __call__(self, context, request):
294         req_principals = request.effective_principals
295         if is_nonstr_iter(req_principals):
296             rpset = set(req_principals)
297             if self.val.issubset(rpset):
298                 return True
299         return False
300