# -*- coding: utf-8 -*-
"""Flask session interface that AES-encrypts the serialized session payload.

The session is serialized with Flask's tagged-JSON serializer, padded
(PKCS#7), encrypted with AES-CBC under ``SESSION_ENCRYPTION_KEY`` using a
random IV per cookie, and base64-encoded. The result is then handed to
``SecureCookieSessionInterface``, which signs it.

NOTE: CBC mode provides no integrity protection on its own; tamper
resistance comes from the cookie signer (the app's ``SECRET_KEY``).
"""
from base64 import b64encode, b64decode
from typing import Any, AnyStr, Union, cast

from Crypto import Random
from Crypto.Cipher import AES
from flask import Flask, Request
from flask.sessions import TaggedJSONSerializer, SecureCookieSessionInterface
from itsdangerous import BadPayload

from accounts.app import accounts_app


def _pad(value: Union[str, bytes], block_size: int) -> bytes:
    """Return *value* as UTF-8 bytes with PKCS#7 padding to *block_size*.

    The padding is computed on the encoded byte length, so non-ASCII text
    still yields a whole number of cipher blocks.
    """
    data = value.encode("UTF-8") if isinstance(value, str) else value
    padding = block_size - len(data) % block_size
    # Always 1..block_size bytes, each equal to the pad length (PKCS#7).
    return data + bytes([padding]) * padding


def _unpad(value: AnyStr, block_size: int = AES.block_size) -> AnyStr:
    """Strip PKCS#7 padding from *value* (bytes, or str for compatibility).

    Raises:
        ValueError: if *value* is empty or its padding is malformed, e.g.
            because it was decrypted with the wrong key.
    """
    if not value:
        raise ValueError("cannot unpad empty data")
    last = value[-1]
    # Indexing bytes yields an int; indexing str yields a 1-char str.
    pad_length = last if isinstance(last, int) else ord(last)
    if not 1 <= pad_length <= min(block_size, len(value)):
        raise ValueError("invalid padding length")
    if value[-pad_length:] != value[-1:] * pad_length:
        raise ValueError("invalid padding bytes")
    return value[:-pad_length]


class EncryptedSerializer(TaggedJSONSerializer):
    """Tagged-JSON serializer whose output is AES-CBC encrypted and base64'd."""

    def __init__(self):
        super().__init__()
        self.block_size = AES.block_size

    def _cipher(self, iv: bytes):
        """Build a fresh AES-CBC cipher for *iv* from the configured key.

        Raises:
            ValueError: if ``SESSION_ENCRYPTION_KEY`` is not 32 bytes long.
        """
        key = accounts_app.config['SESSION_ENCRYPTION_KEY']
        if isinstance(key, str):
            key = key.encode("UTF-8")
        # Explicit check instead of `assert`, which is stripped under -O.
        if len(key) != 32:
            raise ValueError(
                "SESSION_ENCRYPTION_KEY must be exactly 32 bytes long")
        return AES.new(key, AES.MODE_CBC, iv)

    def dumps(self, value: Any) -> str:
        """
        Encrypt the serialized values with `config.SESSION_ENCRYPTION_KEY`.
        The key must be 32 bytes long.

        Returns base64(IV || ciphertext) as text.
        """
        serialized_value = super().dumps(value)
        raw = _pad(serialized_value, self.block_size)
        # A new random IV per cookie so equal sessions encrypt differently.
        iv: bytes = Random.new().read(self.block_size)
        return b64encode(iv + self._cipher(iv).encrypt(raw)).decode("UTF-8")

    def loads(self, value: str) -> Any:
        """
        Decrypt the given serialized session data with
        `config.SESSION_ENCRYPTION_KEY`.

        Raises:
            ValueError: if the payload is too short, not block-aligned, or
                decrypts to malformed padding.
        """
        decoded = b64decode(value.encode("UTF-8"))
        # Need one IV block plus at least one ciphertext block.
        if len(decoded) < 2 * self.block_size or len(decoded) % self.block_size:
            raise ValueError("encrypted session payload has an invalid length")
        iv, ciphertext = decoded[:self.block_size], decoded[self.block_size:]
        raw = self._cipher(iv).decrypt(ciphertext)
        plaintext = _unpad(raw, self.block_size).decode("UTF-8")
        return super().loads(plaintext)


class EncryptedSessionInterface(SecureCookieSessionInterface):
    """Cookie session interface that uses :class:`EncryptedSerializer`."""

    serializer = EncryptedSerializer()

    def open_session(self, app: Flask, request: Request):
        """Load the session, falling back to an empty one if it can't be decrypted.

        The session's ``permanent`` flag is set whenever
        ``PERMANENT_SESSION_LIFETIME`` is configured (not None).
        """
        parent = cast(EncryptedSessionInterface, super())
        try:
            session = parent.open_session(app, request)
        except BadPayload:
            # The signature was valid but the serializer failed to decrypt or
            # decode the payload (presumably a changed key); start fresh.
            session = self.session_class()
        if session is not None:
            session.permanent = \
                app.config.get('PERMANENT_SESSION_LIFETIME') is not None
        return session