203 lines
		
	
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
		
		
			
		
	
	
			203 lines
		
	
	
	
		
			7.7 KiB
		
	
	
	
		
			Python
		
	
	
	
	
	
|   | from __future__ import annotations | ||
|  | 
 | ||
|  | import re | ||
|  | import typing as t | ||
|  | from dataclasses import dataclass | ||
|  | from dataclasses import field | ||
|  | 
 | ||
|  | from .converters import ValidationError | ||
|  | from .exceptions import NoMatch | ||
|  | from .exceptions import RequestAliasRedirect | ||
|  | from .exceptions import RequestPath | ||
|  | from .rules import Rule | ||
|  | from .rules import RulePart | ||
|  | 
 | ||
|  | 
 | ||
|  | class SlashRequired(Exception): | ||
|  |     pass | ||
|  | 
 | ||
|  | 
 | ||
|  | @dataclass | ||
|  | class State: | ||
|  |     """A representation of a rule state.
 | ||
|  | 
 | ||
|  |     This includes the *rules* that correspond to the state and the | ||
|  |     possible *static* and *dynamic* transitions to the next state. | ||
|  |     """
 | ||
|  | 
 | ||
|  |     dynamic: list[tuple[RulePart, State]] = field(default_factory=list) | ||
|  |     rules: list[Rule] = field(default_factory=list) | ||
|  |     static: dict[str, State] = field(default_factory=dict) | ||
|  | 
 | ||
|  | 
 | ||
|  | class StateMachineMatcher: | ||
|  |     def __init__(self, merge_slashes: bool) -> None: | ||
|  |         self._root = State() | ||
|  |         self.merge_slashes = merge_slashes | ||
|  | 
 | ||
|  |     def add(self, rule: Rule) -> None: | ||
|  |         state = self._root | ||
|  |         for part in rule._parts: | ||
|  |             if part.static: | ||
|  |                 state.static.setdefault(part.content, State()) | ||
|  |                 state = state.static[part.content] | ||
|  |             else: | ||
|  |                 for test_part, new_state in state.dynamic: | ||
|  |                     if test_part == part: | ||
|  |                         state = new_state | ||
|  |                         break | ||
|  |                 else: | ||
|  |                     new_state = State() | ||
|  |                     state.dynamic.append((part, new_state)) | ||
|  |                     state = new_state | ||
|  |         state.rules.append(rule) | ||
|  | 
 | ||
|  |     def update(self) -> None: | ||
|  |         # For every state the dynamic transitions should be sorted by | ||
|  |         # the weight of the transition | ||
|  |         state = self._root | ||
|  | 
 | ||
|  |         def _update_state(state: State) -> None: | ||
|  |             state.dynamic.sort(key=lambda entry: entry[0].weight) | ||
|  |             for new_state in state.static.values(): | ||
|  |                 _update_state(new_state) | ||
|  |             for _, new_state in state.dynamic: | ||
|  |                 _update_state(new_state) | ||
|  | 
 | ||
|  |         _update_state(state) | ||
|  | 
 | ||
|  |     def match( | ||
|  |         self, domain: str, path: str, method: str, websocket: bool | ||
|  |     ) -> tuple[Rule, t.MutableMapping[str, t.Any]]: | ||
|  |         # To match to a rule we need to start at the root state and | ||
|  |         # try to follow the transitions until we find a match, or find | ||
|  |         # there is no transition to follow. | ||
|  | 
 | ||
|  |         have_match_for = set() | ||
|  |         websocket_mismatch = False | ||
|  | 
 | ||
|  |         def _match( | ||
|  |             state: State, parts: list[str], values: list[str] | ||
|  |         ) -> tuple[Rule, list[str]] | None: | ||
|  |             # This function is meant to be called recursively, and will attempt | ||
|  |             # to match the head part to the state's transitions. | ||
|  |             nonlocal have_match_for, websocket_mismatch | ||
|  | 
 | ||
|  |             # The base case is when all parts have been matched via | ||
|  |             # transitions. Hence if there is a rule with methods & | ||
|  |             # websocket that work return it and the dynamic values | ||
|  |             # extracted. | ||
|  |             if parts == []: | ||
|  |                 for rule in state.rules: | ||
|  |                     if rule.methods is not None and method not in rule.methods: | ||
|  |                         have_match_for.update(rule.methods) | ||
|  |                     elif rule.websocket != websocket: | ||
|  |                         websocket_mismatch = True | ||
|  |                     else: | ||
|  |                         return rule, values | ||
|  | 
 | ||
|  |                 # Test if there is a match with this path with a | ||
|  |                 # trailing slash, if so raise an exception to report | ||
|  |                 # that matching is possible with an additional slash | ||
|  |                 if "" in state.static: | ||
|  |                     for rule in state.static[""].rules: | ||
|  |                         if websocket == rule.websocket and ( | ||
|  |                             rule.methods is None or method in rule.methods | ||
|  |                         ): | ||
|  |                             if rule.strict_slashes: | ||
|  |                                 raise SlashRequired() | ||
|  |                             else: | ||
|  |                                 return rule, values | ||
|  |                 return None | ||
|  | 
 | ||
|  |             part = parts[0] | ||
|  |             # To match this part try the static transitions first | ||
|  |             if part in state.static: | ||
|  |                 rv = _match(state.static[part], parts[1:], values) | ||
|  |                 if rv is not None: | ||
|  |                     return rv | ||
|  |             # No match via the static transitions, so try the dynamic | ||
|  |             # ones. | ||
|  |             for test_part, new_state in state.dynamic: | ||
|  |                 target = part | ||
|  |                 remaining = parts[1:] | ||
|  |                 # A final part indicates a transition that always | ||
|  |                 # consumes the remaining parts i.e. transitions to a | ||
|  |                 # final state. | ||
|  |                 if test_part.final: | ||
|  |                     target = "/".join(parts) | ||
|  |                     remaining = [] | ||
|  |                 match = re.compile(test_part.content).match(target) | ||
|  |                 if match is not None: | ||
|  |                     if test_part.suffixed: | ||
|  |                         # If a part_isolating=False part has a slash suffix, remove the | ||
|  |                         # suffix from the match and check for the slash redirect next. | ||
|  |                         suffix = match.groups()[-1] | ||
|  |                         if suffix == "/": | ||
|  |                             remaining = [""] | ||
|  | 
 | ||
|  |                     converter_groups = sorted( | ||
|  |                         match.groupdict().items(), key=lambda entry: entry[0] | ||
|  |                     ) | ||
|  |                     groups = [ | ||
|  |                         value | ||
|  |                         for key, value in converter_groups | ||
|  |                         if key[:11] == "__werkzeug_" | ||
|  |                     ] | ||
|  |                     rv = _match(new_state, remaining, values + groups) | ||
|  |                     if rv is not None: | ||
|  |                         return rv | ||
|  | 
 | ||
|  |             # If there is no match and the only part left is a | ||
|  |             # trailing slash ("") consider rules that aren't | ||
|  |             # strict-slashes as these should match if there is a final | ||
|  |             # slash part. | ||
|  |             if parts == [""]: | ||
|  |                 for rule in state.rules: | ||
|  |                     if rule.strict_slashes: | ||
|  |                         continue | ||
|  |                     if rule.methods is not None and method not in rule.methods: | ||
|  |                         have_match_for.update(rule.methods) | ||
|  |                     elif rule.websocket != websocket: | ||
|  |                         websocket_mismatch = True | ||
|  |                     else: | ||
|  |                         return rule, values | ||
|  | 
 | ||
|  |             return None | ||
|  | 
 | ||
|  |         try: | ||
|  |             rv = _match(self._root, [domain, *path.split("/")], []) | ||
|  |         except SlashRequired: | ||
|  |             raise RequestPath(f"{path}/") from None | ||
|  | 
 | ||
|  |         if self.merge_slashes and rv is None: | ||
|  |             # Try to match again, but with slashes merged | ||
|  |             path = re.sub("/{2,}?", "/", path) | ||
|  |             try: | ||
|  |                 rv = _match(self._root, [domain, *path.split("/")], []) | ||
|  |             except SlashRequired: | ||
|  |                 raise RequestPath(f"{path}/") from None | ||
|  |             if rv is None or rv[0].merge_slashes is False: | ||
|  |                 raise NoMatch(have_match_for, websocket_mismatch) | ||
|  |             else: | ||
|  |                 raise RequestPath(f"{path}") | ||
|  |         elif rv is not None: | ||
|  |             rule, values = rv | ||
|  | 
 | ||
|  |             result = {} | ||
|  |             for name, value in zip(rule._converters.keys(), values): | ||
|  |                 try: | ||
|  |                     value = rule._converters[name].to_python(value) | ||
|  |                 except ValidationError: | ||
|  |                     raise NoMatch(have_match_for, websocket_mismatch) from None | ||
|  |                 result[str(name)] = value | ||
|  |             if rule.defaults: | ||
|  |                 result.update(rule.defaults) | ||
|  | 
 | ||
|  |             if rule.alias and rule.map.redirect_defaults: | ||
|  |                 raise RequestAliasRedirect(result, rule.endpoint) | ||
|  | 
 | ||
|  |             return rule, result | ||
|  | 
 | ||
|  |         raise NoMatch(have_match_for, websocket_mismatch) |