# -*- coding: utf-8 -*-
"""Flask session interface whose cookie payload is AES-256-CBC encrypted.

The session is serialized with Flask's tagged JSON format. It is then
PKCS#7-padded and encrypted with ``SESSION_ENCRYPTION_KEY`` from
``accounts_app.config`` before being handed to the
``SecureCookieSessionInterface`` machinery.
"""
from typing import cast

from Crypto import Random
from Crypto.Cipher import AES
from flask import Flask
from flask.sessions import TaggedJSONSerializer, SecureCookieSessionInterface
from itsdangerous import BadPayload

from accounts.app import accounts_app


def _pad(value, block_size):
    """Return *value* as UTF-8 bytes with PKCS#7 padding to *block_size*.

    *value* may be ``str`` (encoded as UTF-8 first) or ``bytes``.

    The padding length is computed on the *encoded* bytes. Computing it on
    the characters, then encoding, yields a non-block-aligned result as soon
    as *value* contains a multi-byte (non-ASCII) character.
    """
    if isinstance(value, str):
        value = value.encode("UTF-8")
    padding = block_size - len(value) % block_size
    # Always 1..block_size bytes, each equal to the padding length.
    return value + bytes([padding]) * padding


def _unpad(value):
    """Strip the PKCS#7 padding added by `_pad` and return the payload.

    Raises:
        ValueError: if *value* is empty or its padding is malformed.
    """
    if not value:
        raise ValueError("cannot unpad empty data")
    # ``ord`` of a length-1 slice works for both ``bytes`` and ``str``.
    pad_length = ord(value[-1:])
    if (not 1 <= pad_length <= len(value)
            or value[-pad_length:] != value[-1:] * pad_length):
        raise ValueError("invalid padding")
    return value[:-pad_length]


class EncryptedSerializer(TaggedJSONSerializer):
    """Tagged JSON serializer that AES-CBC encrypts its output.

    Output layout is ``iv || ciphertext``. The IV is one AES block of random
    bytes, freshly generated for every ``dumps`` call.
    """

    def __init__(self):
        super().__init__()
        self.block_size = AES.block_size

    def _cipher(self, iv):
        """Build an AES-CBC cipher keyed by ``SESSION_ENCRYPTION_KEY``.

        Raises:
            ValueError: if the configured key is not exactly 32 bytes long.
        """
        key = accounts_app.config['SESSION_ENCRYPTION_KEY']
        # Explicit check instead of ``assert`` so it survives ``python -O``.
        if len(key) != 32:
            raise ValueError("SESSION_ENCRYPTION_KEY must be 32 bytes long")
        return AES.new(key, AES.MODE_CBC, iv)

    def dumps(self, value):
        """
        Encrypt the serialized values with `config.SESSION_ENCRYPTION_KEY`.
        The key must be 32 bytes long.

        Returns ``iv + ciphertext`` as ``bytes``.
        """
        serialized_value = super().dumps(value)
        raw = _pad(serialized_value, self.block_size)
        iv = Random.new().read(self.block_size)
        return iv + self._cipher(iv).encrypt(raw)

    def loads(self, value):
        """
        Decrypt the given serialized session data with
        `config.SESSION_ENCRYPTION_KEY`.

        *value* is expected to be the ``iv + ciphertext`` bytes produced by
        `dumps`.
        """
        iv = value[:self.block_size]
        raw = self._cipher(iv).decrypt(value[self.block_size:])
        return super().loads(_unpad(raw))


class EncryptedSessionInterface(SecureCookieSessionInterface):
    """Cookie session interface that stores encrypted session data."""

    serializer = EncryptedSerializer()

    def open_session(self, app: Flask, request):
        """Open the session for *request*, decrypting the cookie payload.

        A cookie whose payload cannot be loaded (``BadPayload``) is replaced
        by a fresh, empty ``self.session_class()`` instead of raising.
        NOTE(review): presumably this covers decryption/unpadding/JSON
        failures in `EncryptedSerializer.loads` wrapped by itsdangerous.

        When a session is returned, it is marked permanent iff
        ``PERMANENT_SESSION_LIFETIME`` is configured (not ``None``).
        Returns ``None`` when the parent implementation returns ``None``.
        """
        session = None
        try:
            parent = super()
            # ``cast`` has no runtime effect. NOTE(review): it appears to
            # exist only to satisfy a static type checker.
            session = cast(EncryptedSessionInterface, parent) \
                .open_session(app, request)
        except BadPayload:
            session = self.session_class()
        if session is not None:
            session.permanent = \
                app.config.get('PERMANENT_SESSION_LIFETIME') is not None
        return session