import re

from pyramid.exceptions import ConfigurationError
from pyramid.compat import is_nonstr_iter
from pyramid.csrf import check_csrf_token
from pyramid.traversal import (
    find_interface,
    traversal_path,
    resource_path_tuple,
)
from pyramid.urldispatch import _compile_route
from pyramid.util import as_sorted_tuple, object_description

# Sentinel used by PhysicalPathPredicate to tell "context has no __name__
# attribute" apart from any real attribute value (including None).
_marker = object()


class XHRPredicate(object):
    """Match when ``bool(request.is_xhr)`` equals the configured boolean."""

    def __init__(self, val, config):
        self.val = bool(val)

    def text(self):
        return 'xhr = %s' % self.val

    phash = text

    def __call__(self, context, request):
        return bool(request.is_xhr) is self.val


class RequestMethodPredicate(object):
    """Match when ``request.method`` is one of the configured methods.

    ``val`` is normalized with ``as_sorted_tuple``.  If ``GET`` is allowed,
    ``HEAD`` is added automatically.
    """

    def __init__(self, val, config):
        request_method = as_sorted_tuple(val)
        if 'GET' in request_method and 'HEAD' not in request_method:
            # GET implies HEAD too
            request_method = as_sorted_tuple(request_method + ('HEAD',))
        self.val = request_method

    def text(self):
        return 'request_method = %s' % (','.join(self.val))

    phash = text

    def __call__(self, context, request):
        return request.method in self.val


class PathInfoPredicate(object):
    """Match when regex ``val`` matches at the start of ``request.upath_info``.

    Raises:
        ConfigurationError: if ``val`` is not a valid regular expression.
    """

    def __init__(self, val, config):
        # Keep the original pattern text for ``text()``/``phash``.
        self.orig = val
        try:
            val = re.compile(val)
        except re.error as why:
            raise ConfigurationError(why.args[0])
        self.val = val

    def text(self):
        return 'path_info = %s' % (self.orig,)

    phash = text

    def __call__(self, context, request):
        # ``match`` anchors at the beginning of the string only.
        return self.val.match(request.upath_info) is not None


class RequestParamPredicate(object):
    """Match on the presence (and optionally value) of request parameters.

    Each entry of ``val`` is either ``'key'`` (the parameter must be present)
    or ``'key=value'`` (the parameter must equal ``value``; key and value are
    stripped).  A leading ``'='`` is treated as part of the key, so
    ``'=key=value'`` requires parameter ``'=key'`` to equal ``value``.
    """

    def __init__(self, val, config):
        val = as_sorted_tuple(val)
        reqs = []
        for p in val:
            k = p
            v = None
            if p.startswith('='):
                # A leading '=' belongs to the key; only a later '=' splits.
                if '=' in p[1:]:
                    k, v = p[1:].split('=', 1)
                    k = '=' + k
                    k, v = k.strip(), v.strip()
            elif '=' in p:
                k, v = p.split('=', 1)
                k, v = k.strip(), v.strip()
            reqs.append((k, v))
        self.val = val
        self.reqs = reqs

    def text(self):
        # Test ``y is not None`` rather than truthiness: ``'foo='`` (param
        # must equal the empty string) must not render identically to
        # ``'foo'`` (param only has to be present), otherwise two predicates
        # that match different requests would share the same text/phash.
        return 'request_param %s' % ','.join(
            ['%s=%s' % (x, y) if y is not None else x for x, y in self.reqs]
        )

    phash = text

    def __call__(self, context, request):
        for k, v in self.reqs:
            actual = request.params.get(k)
            if actual is None:
                return False
            if v is not None and actual != v:
                return False
        return True


class HeaderPredicate(object):
    """Match on a request header.

    ``val`` is either a header name (match if the header is present) or
    ``'name:regex'`` (match if the header value matches ``regex`` at its
    start).

    Raises:
        ConfigurationError: if the regex part is not a valid pattern.
    """

    def __init__(self, val, config):
        name = val
        v = None
        if ':' in name:
            name, val_str = name.split(':', 1)
            try:
                v = re.compile(val_str)
            except re.error as why:
                raise ConfigurationError(why.args[0])
        if v is None:
            self._text = 'header %s' % (name,)
        else:
            self._text = 'header %s=%s' % (name, val_str)
        self.name = name
        self.val = v

    def text(self):
        return self._text

    phash = text

    def __call__(self, context, request):
        if self.val is None:
            return self.name in request.headers
        val = request.headers.get(self.name)
        if val is None:
            return False
        return self.val.match(val) is not None


class AcceptPredicate(object):
    """Match when the request's Accept header allows one of ``values``.

    A single string is wrapped into a 1-tuple.  A single value containing
    ``'*'`` selects the deprecated media-range behaviour (``value in
    request.accept``); otherwise ``request.accept.acceptable_offers`` is used.
    """

    _is_using_deprecated_ranges = False

    def __init__(self, values, config):
        if not is_nonstr_iter(values):
            values = (values,)
        # deprecated media ranges were only supported in versions of the
        # predicate that didn't support lists, so check it here
        if len(values) == 1 and '*' in values[0]:
            self._is_using_deprecated_ranges = True
        self.values = values

    def text(self):
        return 'accept = %s' % (', '.join(self.values),)

    phash = text

    def __call__(self, context, request):
        if self._is_using_deprecated_ranges:
            return self.values[0] in request.accept
        return bool(request.accept.acceptable_offers(self.values))


class ContainmentPredicate(object):
    """Match when ``find_interface`` locates ``val`` from the context.

    ``val`` is resolved through ``config.maybe_dotted``.  The context used is
    ``request.context`` when the request has one, else the ``context``
    argument.
    """

    def __init__(self, val, config):
        self.val = config.maybe_dotted(val)

    def text(self):
        return 'containment = %s' % (self.val,)

    phash = text

    def __call__(self, context, request):
        ctx = getattr(request, 'context', context)
        return find_interface(ctx, self.val) is not None


class RequestTypePredicate(object):
    """Match when ``val.providedBy(request)`` is true.

    ``val`` is presumably an interface object; it must expose ``providedBy``.
    """

    def __init__(self, val, config):
        self.val = val

    def text(self):
        return 'request_type = %s' % (self.val,)

    phash = text

    def __call__(self, context, request):
        return self.val.providedBy(request)


class MatchParamPredicate(object):
    """Match when every ``'key=value'`` pair is present in the matchdict.

    Each entry of ``val`` must contain ``'='``; key and value are stripped.
    A missing or empty ``request.matchdict`` never matches.
    """

    def __init__(self, val, config):
        val = as_sorted_tuple(val)
        self.val = val
        reqs = [p.split('=', 1) for p in val]
        self.reqs = [(x.strip(), y.strip()) for x, y in reqs]

    def text(self):
        return 'match_param %s' % ','.join(
            ['%s=%s' % (x, y) for x, y in self.reqs]
        )

    phash = text

    def __call__(self, context, request):
        if not request.matchdict:
            # might be None
            return False
        for k, v in self.reqs:
            if request.matchdict.get(k) != v:
                return False
        return True


class CustomPredicate(object):
    """Wrap an arbitrary ``func(context, request)`` callable as a predicate.

    ``text()`` uses ``func.__text__`` when defined, else a generated
    description.
    """

    def __init__(self, func, config):
        self.func = func

    def text(self):
        return getattr(
            self.func,
            '__text__',
            'custom predicate: %s' % object_description(self.func),
        )

    def phash(self):
        # using hash() here rather than id() is intentional: we
        # want to allow custom predicates that are part of
        # frameworks to be able to define custom __hash__
        # functions for custom predicates, so that the hash output
        # of predicate instances which are "logically the same"
        # may compare equal.
        return 'custom:%r' % hash(self.func)

    def __call__(self, context, request):
        return self.func(context, request)


class TraversePredicate(object):
    # Can only be used as a *route* "predicate"; it adds 'traverse' to the
    # matchdict if it's specified in the routing args.  This causes the
    # ResourceTreeTraverser to use the resolved traverse pattern as the
    # traversal path.
    def __init__(self, val, config):
        # Only the URL generator half of the compiled route is needed.
        _, self.tgenerate = _compile_route(val)
        self.val = val

    def text(self):
        return 'traverse matchdict pseudo-predicate'

    def phash(self):
        # This isn't actually a predicate, it's just an infodict modifier that
        # injects ``traverse`` into the matchdict.  As a result, we don't
        # need to update the hash.
        return ''

    def __call__(self, context, request):
        # Here ``context`` is the route info dict, not a resource.
        if 'traverse' in context:
            return True
        m = context['match']
        tvalue = self.tgenerate(m)  # tvalue will be urlquoted string
        m['traverse'] = traversal_path(tvalue)
        # This isn't actually a predicate, it's just an infodict modifier that
        # injects ``traverse`` into the matchdict.  As a result, we just
        # return True.
        return True


class CheckCSRFTokenPredicate(object):
    """Check the request's CSRF token when ``val`` is truthy.

    ``True`` means the token name ``'csrf_token'``; any other truthy value
    is passed through as the second argument to ``check_csrf_token``.
    A falsy ``val`` always matches.
    """

    check_csrf_token = staticmethod(check_csrf_token)  # testing

    def __init__(self, val, config):
        self.val = val

    def text(self):
        return 'check_csrf = %s' % (self.val,)

    phash = text

    def __call__(self, context, request):
        val = self.val
        if val:
            if val is True:
                val = 'csrf_token'
            return self.check_csrf_token(request, val, raises=False)
        return True


class PhysicalPathPredicate(object):
    """Match when ``resource_path_tuple(context)`` equals the configured path.

    ``val`` may be a non-string iterable of segments (used as-is) or a
    ``'/'``-separated string.  For a string, empty segments are dropped and
    the root segment ``''`` is prepended.  A context without ``__name__``
    never matches.
    """

    def __init__(self, val, config):
        if is_nonstr_iter(val):
            self.val = tuple(val)
        else:
            val = tuple(filter(None, val.split('/')))
            self.val = ('',) + val

    def text(self):
        return 'physical_path = %s' % (self.val,)

    phash = text

    def __call__(self, context, request):
        if getattr(context, '__name__', _marker) is not _marker:
            return resource_path_tuple(context) == self.val
        return False


class EffectivePrincipalsPredicate(object):
    """Match when all configured principals are effective for the request.

    ``val`` may be a single principal or a non-string iterable of them.  If
    ``request.effective_principals`` is not a non-string iterable, the
    predicate does not match.
    """

    def __init__(self, val, config):
        if is_nonstr_iter(val):
            self.val = set(val)
        else:
            self.val = set((val,))

    def text(self):
        # Sort so the text/phash is stable regardless of set ordering.
        return 'effective_principals = %s' % sorted(self.val)

    phash = text

    def __call__(self, context, request):
        req_principals = request.effective_principals
        if is_nonstr_iter(req_principals):
            rpset = set(req_principals)
            if self.val.issubset(rpset):
                return True
        return False


class Notted(object):
    """Invert the result of a wrapped predicate.

    If the wrapped predicate's ``phash()`` is empty (a pseudo-predicate such
    as ``TraversePredicate``), its result is passed through un-inverted and
    its text/phash are left unprefixed.
    """

    def __init__(self, predicate):
        self.predicate = predicate

    def _notted_text(self, val):
        # if the underlying predicate doesn't return a value, it's not really
        # a predicate, it's just something pretending to be a predicate,
        # so don't update the hash
        if val:
            val = '!' + val
        return val

    def text(self):
        return self._notted_text(self.predicate.text())

    def phash(self):
        return self._notted_text(self.predicate.phash())

    def __call__(self, context, request):
        result = self.predicate(context, request)
        phash = self.phash()
        if phash:
            result = not result
        return result