From 41551f9ff74c692b3db7818364a9b0966e5a08be Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Jonah=20Br=C3=BCchert?= Date: Fri, 29 Mar 2024 01:57:13 +0100 Subject: Enforce types in function calls --- accounts/__init__.py | 2 +- accounts/backend/user/__init__.py | 4 ++-- accounts/backend/user/dummy.py | 4 ++-- accounts/backend/user/ldap.py | 4 ++-- accounts/models.py | 4 ++-- accounts/utils/confirmation.py | 7 ++++++- accounts/utils/login.py | 12 +++++++----- accounts/utils/sessions.py | 2 +- mypy.ini | 3 +++ 9 files changed, 26 insertions(+), 16 deletions(-) diff --git a/accounts/__init__.py b/accounts/__init__.py index c9310fe..883d45c 100644 --- a/accounts/__init__.py +++ b/accounts/__init__.py @@ -14,7 +14,7 @@ from typing import Optional def absolute_paths(app: Flask, config: str) -> None: - def handle_option(dirname, name): + def handle_option(dirname: str, name: str): if app.config.get(name): app.config[name] = os.path.join(dirname, app.config[name]) diff --git a/accounts/backend/user/__init__.py b/accounts/backend/user/__init__.py index 1504e41..70f973a 100644 --- a/accounts/backend/user/__init__.py +++ b/accounts/backend/user/__init__.py @@ -135,13 +135,13 @@ class Backend(object): """ raise NotImplementedError() - def _store(self, account): + def _store(self, account: Account) -> None: """ Persists an account in the backend. """ raise NotImplementedError() - def _get_next_uidNumber(self): + def _get_next_uidNumber(self) -> int: """ Get the next free uid number. diff --git a/accounts/backend/user/dummy.py b/accounts/backend/user/dummy.py index 3d0dcca..fd6620a 100644 --- a/accounts/backend/user/dummy.py +++ b/accounts/backend/user/dummy.py @@ -57,13 +57,13 @@ class DummyBackend(Backend): self._next_uidNumber = 4 - def _get_accounts(self): + def _get_accounts(self) -> list[Account]: accounts = [] for uid, attrs in self._storage.items(): accounts.append( Account( uid, - attrs["mail"], + str(attrs["mail"]), uidNumber=attrs["uidNumber"] ) ) diff --git a/accounts/backend/user/ldap.py b/accounts/backend/user/ldap.py index 217dcba..99080a4 100644 --- a/accounts/backend/user/ldap.py +++ b/accounts/backend/user/ldap.py @@ -167,7 +167,7 @@ class LdapBackend(Backend): return ','.join(dn) def _connect(self, user: Optional[str] = None, - password: Optional[str] = None): + password: Optional[str] = None) -> Connection: server = ldap3.Server(self.host) conn = ldap3.Connection(server, user, password, raise_exceptions=True) @@ -178,7 +178,7 @@ class LdapBackend(Backend): return conn - def _connect_as_admin(self): + def _connect_as_admin(self) -> Connection: admin_dn = self._format_dn([('cn', self.admin_user)]) return self._connect(admin_dn, self.admin_pass) diff --git a/accounts/models.py b/accounts/models.py index fe6c500..c897815 100644 --- a/accounts/models.py +++ b/accounts/models.py @@ -68,7 +68,7 @@ class Account(UserMixin): else: self.new_password_services[service] = (old_password, new_password) - def _set_attribute(self, key, value): + def _set_attribute(self, key: str, value: Any) -> None: self.attributes[key] = value def change_email(self, new_mail: str): @@ -85,7 +85,7 @@ class Account(UserMixin): raise AttributeError("'%s' object has no attribute '%s'" % (self.__class__.__name__, name)) - def __setattr__(self, name, value): + def __setattr__(self, name: str, value: Any): if self._ready and name not in self.__dict__: self._set_attribute(name, value) else: diff --git a/accounts/utils/confirmation.py b/accounts/utils/confirmation.py index b75716d..60967de 100644 --- a/accounts/utils/confirmation.py +++ b/accounts/utils/confirmation.py @@ -3,6 +3,8 @@ from itsdangerous import BadSignature, SignatureExpired, URLSafeTimedSerializer from werkzeug.exceptions import Forbidden from accounts.app import accounts_app +from typing import Union, Optional, Any + class Confirmation(URLSafeTimedSerializer): @@ -11,7 +13,10 @@ class Confirmation(URLSafeTimedSerializer): key = accounts_app.config['SECRET_KEY'] super(Confirmation, self).__init__(key, salt=realm, **kwargs) - def loads_http(self, s, max_age=None, return_timestamp=False, salt=None): + def loads_http(self, s: Union[str, bytes], + max_age: Optional[int] = None, + return_timestamp: bool = False, + salt: Optional[bytes] = None) -> Any: """ Like `Confirmation.loads`, but raise HTTP exceptions with appropriate messages instead of `BadSignature` or `SignatureExpired`. diff --git a/accounts/utils/login.py b/accounts/utils/login.py index 0cd1dc4..64e6ce8 100644 --- a/accounts/utils/login.py +++ b/accounts/utils/login.py @@ -7,14 +7,16 @@ import json import flask_login.login_manager from accounts.app import accounts_app +from typing import Union, Any + class _compact_json: @staticmethod - def loads(payload): + def loads(payload: Union[bytes, str, bytearray]) -> Any: return json.loads(payload) @staticmethod - def dumps(obj, **kwargs): + def dumps(obj: Union[list, dict, tuple], **kwargs): kwargs.setdefault("ensure_ascii", False) kwargs.setdefault("separators", (",", ":")) return json.dumps(obj, **kwargs) @@ -26,7 +28,7 @@ def create_login_manager() -> flask_login.login_manager.LoginManager: login_manager.login_view = 'login.login' @login_manager.user_loader - def load_user(user_id: str): + def load_user(user_id: str) -> LoginManager: try: username, password = parse_userid(user_id) return accounts_app.user_backend.auth(username, password) @@ -37,12 +39,12 @@ def create_login_manager() -> flask_login.login_manager.LoginManager: return login_manager -def create_userid(username: str, password: str): +def create_userid(username: str, password: str) -> bytes: userid = (username, password) return base64_encode(_compact_json.dumps(userid)) -def parse_userid(value: str): +def parse_userid(value: str) -> Any: return _compact_json.loads(base64_decode(value)) diff --git a/accounts/utils/sessions.py b/accounts/utils/sessions.py index 007c928..a452fe1 100644 --- a/accounts/utils/sessions.py +++ b/accounts/utils/sessions.py @@ -24,7 +24,7 @@ def _unpad(value: str) -> str: class EncryptedSerializer(TaggedJSONSerializer): - def __init__(self): + def __init__(self) -> None: super(EncryptedSerializer, self).__init__() self.block_size = AES.block_size diff --git a/mypy.ini b/mypy.ini index 95a93c1..b75cab3 100644 --- a/mypy.ini +++ b/mypy.ini @@ -1,3 +1,6 @@ +[mypy] +disallow_untyped_calls = True + [mypy-flask_login.*] ignore_missing_imports = True -- cgit v1.2.3-1-g7c22