summaryrefslogtreecommitdiffstats
path: root/accounts/utils/sessions.py
blob: 007c9284e1b77fe8f394a97dc241e5d00fd48df5 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# -*- coding: utf-8 -*-


from base64 import b64encode, b64decode
from typing import AnyStr, cast

from Crypto import Random
from Crypto.Cipher import AES
from flask import Flask, Request
from flask.sessions import TaggedJSONSerializer, SecureCookieSessionInterface
from itsdangerous import BadPayload

from accounts.app import accounts_app


def _pad(value: str, block_size: int) -> bytes:
    padding = block_size - len(value) % block_size
    return (value + (padding * chr(padding))).encode("UTF-8")


def _unpad(value: str) -> str:
    pad_length = ord(value[len(value)-1:])
    return value[:-pad_length]


class EncryptedSerializer(TaggedJSONSerializer):
    """Tagged-JSON serializer whose output is AES-CBC encrypted.

    ``dumps`` returns ``base64(iv || ciphertext)`` as text and ``loads``
    reverses it. The key is read from
    ``accounts_app.config['SESSION_ENCRYPTION_KEY']`` on every call and must
    be 32 bytes long. No MAC is computed here; integrity presumably relies on
    the signature Flask's ``SecureCookieSessionInterface`` wraps around the
    serializer output (see ``EncryptedSessionInterface``) — confirm before
    reusing this serializer elsewhere.
    """

    def __init__(self):
        super().__init__()
        # Used both as the cipher block size and as the IV length.
        self.block_size = AES.block_size

    def _cipher(self, iv: bytes):
        """Return a new AES-CBC cipher object for *iv* with the session key.

        Raises:
            ValueError: if the configured key is not 32 bytes long.
        """
        key = accounts_app.config['SESSION_ENCRYPTION_KEY']
        # Explicit check rather than `assert`, so it survives `python -O`.
        if len(key) != 32:
            raise ValueError(
                "SESSION_ENCRYPTION_KEY must be 32 bytes long, got %d"
                % len(key))
        return AES.new(key, AES.MODE_CBC, iv)

    def dumps(self, value: object) -> str:
        """
        Serialize *value* to tagged JSON, then encrypt it with
        `config.SESSION_ENCRYPTION_KEY` (which must be 32 bytes long).

        A fresh random IV is generated per call and prepended to the
        ciphertext; the combination is returned base64-encoded as text.
        """
        serialized_value = super().dumps(value)

        raw = _pad(serialized_value, self.block_size)
        # New IV every call so equal sessions don't encrypt identically.
        iv: bytes = Random.new().read(self.block_size)
        return b64encode(iv + self._cipher(iv).encrypt(raw)).decode("UTF-8")

    def loads(self, value: str):
        """
        Decrypt session data produced by `dumps` using
        `config.SESSION_ENCRYPTION_KEY` and deserialize the tagged JSON.

        Raises:
            ValueError: if the decoded payload is not an IV followed by a
                non-empty, block-aligned ciphertext, or its padding is bad.
                (Presumably itsdangerous converts exceptions raised here into
                ``BadPayload``, which ``EncryptedSessionInterface`` catches —
                verify against the installed itsdangerous version.)
        """
        decoded = b64decode(value.encode("UTF-8"))
        iv = decoded[:self.block_size]
        ciphertext = decoded[self.block_size:]
        if len(iv) != self.block_size or not ciphertext \
                or len(ciphertext) % self.block_size:
            raise ValueError("malformed encrypted session payload")
        raw = self._cipher(iv).decrypt(ciphertext)
        return super().loads(_unpad(raw))


class EncryptedSessionInterface(SecureCookieSessionInterface):
    """Cookie session interface that stores the session encrypted.

    Replaces the parent's serializer with `EncryptedSerializer`. A cookie
    whose payload cannot be loaded (``BadPayload``) is treated as an empty
    session instead of an error.
    """

    serializer = EncryptedSerializer()

    def open_session(self, app: Flask, request: Request):
        """Load the session for *request*, falling back to an empty one.

        Returns the parent's result, which may be ``None``. If the parent
        raises ``BadPayload`` (e.g. a payload that does not decrypt or
        deserialize), a fresh ``self.session_class()`` is used instead. For
        any non-None session, ``permanent`` is set to whether
        ``PERMANENT_SESSION_LIFETIME`` is configured to a non-None value.
        """
        try:
            session = super().open_session(app, request)
        except BadPayload:
            session = self.session_class()

        if session is not None:
            session.permanent = (
                app.config.get('PERMANENT_SESSION_LIFETIME') is not None)

        return session