summaryrefslogtreecommitdiffstats
path: root/accounts/utils
diff options
context:
space:
mode:
Diffstat (limited to 'accounts/utils')
-rw-r--r--accounts/utils/__init__.py13
-rw-r--r--accounts/utils/confirmation.py17
-rw-r--r--accounts/utils/console.py47
-rw-r--r--accounts/utils/login.py16
-rw-r--r--accounts/utils/sessions.py21
5 files changed, 68 insertions, 46 deletions
diff --git a/accounts/utils/__init__.py b/accounts/utils/__init__.py
index da92528..1f79953 100644
--- a/accounts/utils/__init__.py
+++ b/accounts/utils/__init__.py
@@ -15,8 +15,9 @@ def templated(template: Optional[str] = None):
template_name = template
if template_name is None:
if request.endpoint:
- template_name = request.endpoint \
- .replace('.', '/') + '.html'
+ template_name = (
+ request.endpoint.replace(".", "/") + ".html"
+ )
else:
template_name = "error.html"
ctx = f(*args, **kwargs)
@@ -25,7 +26,9 @@ def templated(template: Optional[str] = None):
elif not isinstance(ctx, dict):
return ctx
return render_template(template_name, **ctx)
+
return templated__
+
return templated_
@@ -35,15 +38,15 @@ class NotRegexp(Regexp):
"""
def __call__(self, form, field):
- if self.regex.match(field.data or ''):
+ if self.regex.match(field.data or ""):
if self.message is None:
- self.message: str = field.gettext('Invalid input.')
+ self.message: str = field.gettext("Invalid input.")
raise ValidationError(self.message)
def get_backend(path: str, app: Flask):
module = path.rsplit(".", 1).pop()
- class_name = '%sBackend' % module.title()
+ class_name = "%sBackend" % module.title()
backend_class = getattr(importlib.import_module(path), class_name)
return backend_class(app)
diff --git a/accounts/utils/confirmation.py b/accounts/utils/confirmation.py
index 60967de..62f14ad 100644
--- a/accounts/utils/confirmation.py
+++ b/accounts/utils/confirmation.py
@@ -10,13 +10,16 @@ class Confirmation(URLSafeTimedSerializer):
def __init__(self, realm: str, key=None, **kwargs):
if key is None:
- key = accounts_app.config['SECRET_KEY']
+ key = accounts_app.config["SECRET_KEY"]
super(Confirmation, self).__init__(key, salt=realm, **kwargs)
- def loads_http(self, s: Union[str, bytes],
- max_age: Optional[int] = None,
- return_timestamp: bool = False,
- salt: Optional[bytes] = None) -> Any:
+ def loads_http(
+ self,
+ s: Union[str, bytes],
+ max_age: Optional[int] = None,
+ return_timestamp: bool = False,
+ salt: Optional[bytes] = None,
+ ) -> Any:
"""
Like `Confirmation.loads`, but raise HTTP exceptions with appropriate
messages instead of `BadSignature` or `SignatureExpired`.
@@ -25,6 +28,6 @@ class Confirmation(URLSafeTimedSerializer):
try:
return self.loads(s, max_age, return_timestamp, salt)
except BadSignature:
- raise Forbidden('Ungültiger Bestätigungslink.')
+ raise Forbidden("Ungültiger Bestätigungslink.")
except SignatureExpired:
- raise Forbidden('Bestätigungslink ist zu alt.')
+ raise Forbidden("Bestätigungslink ist zu alt.")
diff --git a/accounts/utils/console.py b/accounts/utils/console.py
index 823ec33..1b539bd 100644
--- a/accounts/utils/console.py
+++ b/accounts/utils/console.py
@@ -1,13 +1,14 @@
# -*- coding: utf-8 -*-
-class TablePrinter():
+
+class TablePrinter:
separator: str
- def __init__(self, headers=None, separator='|'):
+ def __init__(self, headers=None, separator="|"):
self.headers = headers
self.separator = separator
- self.format_string = ''
+ self.format_string = ""
self.widths = list()
if headers is not None:
@@ -26,17 +27,19 @@ class TablePrinter():
while len(self.widths) < columns:
self.widths.append(0)
- self.widths = [_column_width(column, width)
- for column, width
- in zip(list(zip(*rows)), self.widths)]
+ self.widths = [
+ _column_width(column, width)
+ for column, width in zip(list(zip(*rows)), self.widths)
+ ]
self._update_format_string()
def _update_format_string(self) -> None:
- sep = ' %s ' % self.separator
- self.format_string = '%s %s %s' % (
+ sep = " %s " % self.separator
+ self.format_string = "%s %s %s" % (
+ self.separator,
+ sep.join(["%%-%ds" % width for width in self.widths]),
self.separator,
- sep.join(['%%-%ds' % width for width in self.widths]),
- self.separator)
+ )
def output(self, rows):
if len(rows) > 0:
@@ -50,10 +53,18 @@ class TablePrinter():
self._print_row(row)
def _print_headline(self) -> None:
- print(('%s%s%s' % (
- self.separator,
- self.separator.join(['-' * (width + 2) for width in self.widths]),
- self.separator)))
+ print(
+ (
+ "%s%s%s"
+ % (
+ self.separator,
+ self.separator.join(
+ ["-" * (width + 2) for width in self.widths]
+ ),
+ self.separator,
+ )
+ )
+ )
def _print_row(self, row) -> None:
print((self.format_string % tuple(row)))
@@ -63,7 +74,7 @@ class ConsoleForm(object):
_ready = False
def __init__(self, formcls, **kwargs):
- self.form = formcls(meta={'csrf': False})
+ self.form = formcls(meta={"csrf": False})
self._fill(kwargs)
self._ready = True
@@ -76,11 +87,11 @@ class ConsoleForm(object):
def print_errors(self):
for field, errors in list(self.form.errors.items()):
if len(errors) > 1:
- print(('%s:' % field))
+ print(("%s:" % field))
for error in errors:
- print((' %s' % error))
+ print((" %s" % error))
else:
- print(('%s: %s' % (field, errors[0])))
+ print(("%s: %s" % (field, errors[0])))
def __getattr__(self, name):
return getattr(self.form, name)
diff --git a/accounts/utils/login.py b/accounts/utils/login.py
index 07953e3..938268f 100644
--- a/accounts/utils/login.py
+++ b/accounts/utils/login.py
@@ -24,16 +24,18 @@ class _compact_json:
def create_login_manager() -> flask_login.login_manager.LoginManager:
login_manager = LoginManager()
- login_manager.login_message = 'Bitte einloggen'
- login_manager.login_view = 'login.login'
+ login_manager.login_message = "Bitte einloggen"
+ login_manager.login_view = "login.login"
@login_manager.user_loader
def load_user(user_id: str) -> LoginManager:
try:
username, password = parse_userid(user_id)
return accounts_app.user_backend.auth(username, password)
- except (accounts_app.user_backend.NoSuchUserError,
- accounts_app.user_backend.InvalidPasswordError):
+ except (
+ accounts_app.user_backend.NoSuchUserError,
+ accounts_app.user_backend.InvalidPasswordError,
+ ):
return None
return login_manager
@@ -52,7 +54,9 @@ def logout_required(f):
@wraps(f)
def logout_required_(*args, **kwargs):
if current_user.is_authenticated:
- raise Forbidden('Diese Seite ist nur für nicht eingeloggte Benutzer gedacht!')
+ raise Forbidden(
+ "Diese Seite ist nur für nicht eingeloggte Benutzer gedacht!"
+ )
return f(*args, **kwargs)
- return logout_required_
+ return logout_required_
diff --git a/accounts/utils/sessions.py b/accounts/utils/sessions.py
index a452fe1..47580bd 100644
--- a/accounts/utils/sessions.py
+++ b/accounts/utils/sessions.py
@@ -18,7 +18,7 @@ def _pad(value: str, block_size: int) -> bytes:
def _unpad(value: str) -> str:
- pad_length = ord(value[len(value)-1:])
+ pad_length = ord(value[len(value) - 1 :])
return value[:-pad_length]
@@ -29,7 +29,7 @@ class EncryptedSerializer(TaggedJSONSerializer):
self.block_size = AES.block_size
def _cipher(self, iv: bytes):
- key = accounts_app.config['SESSION_ENCRYPTION_KEY']
+ key = accounts_app.config["SESSION_ENCRYPTION_KEY"]
assert len(key) == 32
return AES.new(key, AES.MODE_CBC, iv)
@@ -50,10 +50,9 @@ class EncryptedSerializer(TaggedJSONSerializer):
`config.SESSION_ENCRYPTION_KEY`.
"""
decoded = b64decode(value.encode("UTF-8"))
- iv = decoded[:self.block_size]
- raw = self._cipher(iv).decrypt(decoded[AES.block_size:])
- return super(EncryptedSerializer, self) \
- .loads(_unpad(raw))
+ iv = decoded[: self.block_size]
+ raw = self._cipher(iv).decrypt(decoded[AES.block_size :])
+ return super(EncryptedSerializer, self).loads(_unpad(raw))
class EncryptedSessionInterface(SecureCookieSessionInterface):
@@ -63,13 +62,15 @@ class EncryptedSessionInterface(SecureCookieSessionInterface):
session = None
try:
parent = super(EncryptedSessionInterface, self)
- session = cast(EncryptedSessionInterface, parent) \
- .open_session(app, request)
+ session = cast(EncryptedSessionInterface, parent).open_session(
+ app, request
+ )
except BadPayload:
session = self.session_class()
if session is not None:
- session.permanent = \
- app.config.get('PERMANENT_SESSION_LIFETIME') is not None
+ session.permanent = (
+ app.config.get("PERMANENT_SESSION_LIFETIME") is not None
+ )
return session