From dcad1d5ca832ea05ababa3d38de9a82fc361f2ba Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonah=20Br=C3=BCchert?= Date: Fri, 29 Mar 2024 02:13:57 +0100 Subject: Enforce types in variables --- accounts/app.py | 2 ++ accounts/models.py | 2 +- accounts/utils/__init__.py | 10 +++++++--- accounts/utils/console.py | 17 +++++++++-------- accounts/utils/login.py | 4 ++-- accounts/views/admin/__init__.py | 2 +- accounts/views/default/__init__.py | 2 +- mypy.ini | 1 + 8 files changed, 24 insertions(+), 16 deletions(-) diff --git a/accounts/app.py b/accounts/app.py index eaab824..d39ce03 100644 --- a/accounts/app.py +++ b/accounts/app.py @@ -1,4 +1,5 @@ from flask import Flask, current_app +from flask_login import LoginManager from typing import TYPE_CHECKING, cast if TYPE_CHECKING: @@ -11,6 +12,7 @@ class AccountsFlask(Flask): username_blacklist: list[str] user_backend: "user.Backend" mail_backend: "mail.Backend" + login_manager: LoginManager accounts_app: AccountsFlask = cast(AccountsFlask, current_app) diff --git a/accounts/models.py b/accounts/models.py index c897815..7bbb235 100644 --- a/accounts/models.py +++ b/accounts/models.py @@ -71,7 +71,7 @@ class Account(UserMixin): def _set_attribute(self, key: str, value: Any) -> None: self.attributes[key] = value - def change_email(self, new_mail: str): + def change_email(self, new_mail: str) -> None: """ Changes the mail address of an account. You have to use the AccountService class to make changes permanent. diff --git a/accounts/utils/__init__.py b/accounts/utils/__init__.py index 6adf317..da92528 100644 --- a/accounts/utils/__init__.py +++ b/accounts/utils/__init__.py @@ -14,8 +14,11 @@ def templated(template: Optional[str] = None): def templated__(*args, **kwargs): template_name = template if template_name is None: - template_name = request.endpoint \ - .replace('.', '/') + '.html' + if request.endpoint: + template_name = request.endpoint \ + .replace('.', '/') + '.html' + else: + template_name = "error.html" ctx = f(*args, **kwargs) if ctx is None: ctx = {} @@ -30,10 +33,11 @@ class NotRegexp(Regexp): """ Like wtforms.validators.Regexp, but rejects data that DOES match the regex. """ + def __call__(self, form, field): if self.regex.match(field.data or ''): if self.message is None: - self.message = field.gettext('Invalid input.') + self.message: str = field.gettext('Invalid input.') raise ValidationError(self.message) diff --git a/accounts/utils/console.py b/accounts/utils/console.py index c63480a..823ec33 100644 --- a/accounts/utils/console.py +++ b/accounts/utils/console.py @@ -1,6 +1,7 @@ # -*- coding: utf-8 -*- -class TablePrinter(object): +class TablePrinter(): + separator: str def __init__(self, headers=None, separator='|'): self.headers = headers @@ -12,11 +13,11 @@ class TablePrinter(object): if headers is not None: self._calulate_widths([headers]) - def _calulate_widths(self, rows): - def _get_column_count(rows): + def _calulate_widths(self, rows) -> None: + def _get_column_count(rows: list): return min([len(row) for row in rows]) - def _column_width(column, width): + def _column_width(column: tuple, width: int) -> int: widths = [len(str(elem)) for elem in column] widths.append(width) return max(widths) @@ -30,7 +31,7 @@ class TablePrinter(object): in zip(list(zip(*rows)), self.widths)] self._update_format_string() - def _update_format_string(self): + def _update_format_string(self) -> None: sep = ' %s ' % self.separator self.format_string = '%s %s %s' % ( self.separator, @@ -48,13 +49,13 @@ class TablePrinter(object): for row in rows: self._print_row(row) - def _print_headline(self): + def _print_headline(self) -> None: print(('%s%s%s' % ( self.separator, self.separator.join(['-' * (width + 2) for width in self.widths]), self.separator))) - def _print_row(self, row): + def _print_row(self, row) -> None: print((self.format_string % tuple(row))) @@ -66,7 +67,7 @@ class ConsoleForm(object): self._fill(kwargs) self._ready = True - def _fill(self, data): + def _fill(self, data) -> None: for key, value in list(data.items()): field = getattr(self.form, key, None) if field is not None: diff --git a/accounts/utils/login.py b/accounts/utils/login.py index 64e6ce8..07953e3 100644 --- a/accounts/utils/login.py +++ b/accounts/utils/login.py @@ -7,7 +7,7 @@ import json import flask_login.login_manager from accounts.app import accounts_app -from typing import Union, Any +from typing import Union, Any, Optional class _compact_json: @@ -39,7 +39,7 @@ def create_login_manager() -> flask_login.login_manager.LoginManager: return login_manager -def create_userid(username: str, password: str) -> bytes: +def create_userid(username: str, password: Optional[str]) -> bytes: userid = (username, password) return base64_encode(_compact_json.dumps(userid)) diff --git a/accounts/views/admin/__init__.py b/accounts/views/admin/__init__.py index 7378e38..938033b 100644 --- a/accounts/views/admin/__init__.py +++ b/accounts/views/admin/__init__.py @@ -66,7 +66,7 @@ def disable_account(): if 'uid' in request.args: form = AdminDisableAccountForm(username=request.args['uid']) - if form.validate_on_submit(): + if form.validate_on_submit() and form.user: random_pw = str(uuid4()) form.user.change_password(random_pw) for service in accounts_app.all_services: diff --git a/accounts/views/default/__init__.py b/accounts/views/default/__init__.py index 0b7065d..bba20fd 100644 --- a/accounts/views/default/__init__.py +++ b/accounts/views/default/__init__.py @@ -84,7 +84,7 @@ def register_complete(token: str): @logout_required def lost_password(): form = LostPasswordForm() - if form.validate_on_submit(): + if form.validate_on_submit() and form.user: #TODO: make the link only usable once (e.g include a hash of the old pw) # atm the only thing we do is make the link valid for only little time accounts_app.mail_backend.send( diff --git a/mypy.ini b/mypy.ini index b75cab3..9a65410 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,5 +1,6 @@ [mypy] disallow_untyped_calls = True +check_untyped_defs = True [mypy-flask_login.*] ignore_missing_imports = True -- cgit v1.2.3-1-g7c22