diff options
Diffstat (limited to 'accounts/utils')
-rw-r--r-- | accounts/utils/__init__.py | 13 | ||||
-rw-r--r-- | accounts/utils/confirmation.py | 17 | ||||
-rw-r--r-- | accounts/utils/console.py | 47 | ||||
-rw-r--r-- | accounts/utils/login.py | 16 | ||||
-rw-r--r-- | accounts/utils/sessions.py | 21 |
5 files changed, 68 insertions, 46 deletions
diff --git a/accounts/utils/__init__.py b/accounts/utils/__init__.py index da92528..1f79953 100644 --- a/accounts/utils/__init__.py +++ b/accounts/utils/__init__.py @@ -15,8 +15,9 @@ def templated(template: Optional[str] = None): template_name = template if template_name is None: if request.endpoint: - template_name = request.endpoint \ - .replace('.', '/') + '.html' + template_name = ( + request.endpoint.replace(".", "/") + ".html" + ) else: template_name = "error.html" ctx = f(*args, **kwargs) @@ -25,7 +26,9 @@ def templated(template: Optional[str] = None): elif not isinstance(ctx, dict): return ctx return render_template(template_name, **ctx) + return templated__ + return templated_ @@ -35,15 +38,15 @@ class NotRegexp(Regexp): """ def __call__(self, form, field): - if self.regex.match(field.data or ''): + if self.regex.match(field.data or ""): if self.message is None: - self.message: str = field.gettext('Invalid input.') + self.message: str = field.gettext("Invalid input.") raise ValidationError(self.message) def get_backend(path: str, app: Flask): module = path.rsplit(".", 1).pop() - class_name = '%sBackend' % module.title() + class_name = "%sBackend" % module.title() backend_class = getattr(importlib.import_module(path), class_name) return backend_class(app) diff --git a/accounts/utils/confirmation.py b/accounts/utils/confirmation.py index 60967de..62f14ad 100644 --- a/accounts/utils/confirmation.py +++ b/accounts/utils/confirmation.py @@ -10,13 +10,16 @@ class Confirmation(URLSafeTimedSerializer): def __init__(self, realm: str, key=None, **kwargs): if key is None: - key = accounts_app.config['SECRET_KEY'] + key = accounts_app.config["SECRET_KEY"] super(Confirmation, self).__init__(key, salt=realm, **kwargs) - def loads_http(self, s: Union[str, bytes], - max_age: Optional[int] = None, - return_timestamp: bool = False, - salt: Optional[bytes] = None) -> Any: + def loads_http( + self, + s: Union[str, bytes], + max_age: Optional[int] = None, + return_timestamp: bool = False, + salt: Optional[bytes] = None, + ) -> Any: """ Like `Confirmation.loads`, but raise HTTP exceptions with appropriate messages instead of `BadSignature` or `SignatureExpired`. @@ -25,6 +28,6 @@ class Confirmation(URLSafeTimedSerializer): try: return self.loads(s, max_age, return_timestamp, salt) except BadSignature: - raise Forbidden('Ungültiger Bestätigungslink.') + raise Forbidden("Ungültiger Bestätigungslink.") except SignatureExpired: - raise Forbidden('Bestätigungslink ist zu alt.') + raise Forbidden("Bestätigungslink ist zu alt.") diff --git a/accounts/utils/console.py b/accounts/utils/console.py index 823ec33..1b539bd 100644 --- a/accounts/utils/console.py +++ b/accounts/utils/console.py @@ -1,13 +1,14 @@ # -*- coding: utf-8 -*- -class TablePrinter(): + +class TablePrinter: separator: str - def __init__(self, headers=None, separator='|'): + def __init__(self, headers=None, separator="|"): self.headers = headers self.separator = separator - self.format_string = '' + self.format_string = "" self.widths = list() if headers is not None: @@ -26,17 +27,19 @@ class TablePrinter(): while len(self.widths) < columns: self.widths.append(0) - self.widths = [_column_width(column, width) - for column, width - in zip(list(zip(*rows)), self.widths)] + self.widths = [ + _column_width(column, width) + for column, width in zip(list(zip(*rows)), self.widths) + ] self._update_format_string() def _update_format_string(self) -> None: - sep = ' %s ' % self.separator - self.format_string = '%s %s %s' % ( + sep = " %s " % self.separator + self.format_string = "%s %s %s" % ( + self.separator, + sep.join(["%%-%ds" % width for width in self.widths]), self.separator, - sep.join(['%%-%ds' % width for width in self.widths]), - self.separator) + ) def output(self, rows): if len(rows) > 0: @@ -50,10 +53,18 @@ class TablePrinter(): self._print_row(row) def _print_headline(self) -> None: - print(('%s%s%s' % ( - self.separator, - self.separator.join(['-' * (width + 2) for width in self.widths]), - self.separator))) + print( + ( + "%s%s%s" + % ( + self.separator, + self.separator.join( + ["-" * (width + 2) for width in self.widths] + ), + self.separator, + ) + ) + ) def _print_row(self, row) -> None: print((self.format_string % tuple(row))) @@ -63,7 +74,7 @@ class ConsoleForm(object): _ready = False def __init__(self, formcls, **kwargs): - self.form = formcls(meta={'csrf': False}) + self.form = formcls(meta={"csrf": False}) self._fill(kwargs) self._ready = True @@ -76,11 +87,11 @@ class ConsoleForm(object): def print_errors(self): for field, errors in list(self.form.errors.items()): if len(errors) > 1: - print(('%s:' % field)) + print(("%s:" % field)) for error in errors: - print((' %s' % error)) + print((" %s" % error)) else: - print(('%s: %s' % (field, errors[0]))) + print(("%s: %s" % (field, errors[0]))) def __getattr__(self, name): return getattr(self.form, name) diff --git a/accounts/utils/login.py b/accounts/utils/login.py index 07953e3..938268f 100644 --- a/accounts/utils/login.py +++ b/accounts/utils/login.py @@ -24,16 +24,18 @@ class _compact_json: def create_login_manager() -> flask_login.login_manager.LoginManager: login_manager = LoginManager() - login_manager.login_message = 'Bitte einloggen' - login_manager.login_view = 'login.login' + login_manager.login_message = "Bitte einloggen" + login_manager.login_view = "login.login" @login_manager.user_loader def load_user(user_id: str) -> LoginManager: try: username, password = parse_userid(user_id) return accounts_app.user_backend.auth(username, password) - except (accounts_app.user_backend.NoSuchUserError, - accounts_app.user_backend.InvalidPasswordError): + except ( + accounts_app.user_backend.NoSuchUserError, + accounts_app.user_backend.InvalidPasswordError, + ): return None return login_manager @@ -52,7 +54,9 @@ def logout_required(f): @wraps(f) def logout_required_(*args, **kwargs): if current_user.is_authenticated: - raise Forbidden('Diese Seite ist nur für nicht eingeloggte Benutzer gedacht!') + raise Forbidden( + "Diese Seite ist nur für nicht eingeloggte Benutzer gedacht!" + ) return f(*args, **kwargs) - return logout_required_ + return logout_required_ diff --git a/accounts/utils/sessions.py b/accounts/utils/sessions.py index a452fe1..47580bd 100644 --- a/accounts/utils/sessions.py +++ b/accounts/utils/sessions.py @@ -18,7 +18,7 @@ def _pad(value: str, block_size: int) -> bytes: def _unpad(value: str) -> str: - pad_length = ord(value[len(value)-1:]) + pad_length = ord(value[len(value) - 1 :]) return value[:-pad_length] @@ -29,7 +29,7 @@ class EncryptedSerializer(TaggedJSONSerializer): self.block_size = AES.block_size def _cipher(self, iv: bytes): - key = accounts_app.config['SESSION_ENCRYPTION_KEY'] + key = accounts_app.config["SESSION_ENCRYPTION_KEY"] assert len(key) == 32 return AES.new(key, AES.MODE_CBC, iv) @@ -50,10 +50,9 @@ class EncryptedSerializer(TaggedJSONSerializer): `config.SESSION_ENCRYPTION_KEY`. """ decoded = b64decode(value.encode("UTF-8")) - iv = decoded[:self.block_size] - raw = self._cipher(iv).decrypt(decoded[AES.block_size:]) - return super(EncryptedSerializer, self) \ - .loads(_unpad(raw)) + iv = decoded[: self.block_size] + raw = self._cipher(iv).decrypt(decoded[AES.block_size :]) + return super(EncryptedSerializer, self).loads(_unpad(raw)) class EncryptedSessionInterface(SecureCookieSessionInterface): @@ -63,13 +62,15 @@ class EncryptedSessionInterface(SecureCookieSessionInterface): session = None try: parent = super(EncryptedSessionInterface, self) - session = cast(EncryptedSessionInterface, parent) \ - .open_session(app, request) + session = cast(EncryptedSessionInterface, parent).open_session( + app, request + ) except BadPayload: session = self.session_class() if session is not None: - session.permanent = \ - app.config.get('PERMANENT_SESSION_LIFETIME') is not None + session.permanent = ( + app.config.get("PERMANENT_SESSION_LIFETIME") is not None + ) return session |