aboutsummaryrefslogtreecommitdiff
path: root/app/app_sessions.py
blob: 9f4640466da10cc15ea28feb0d2bcac4fa1e43d8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
# UGE / L2 / Intro to relational databases / Python project prototype
# Author: Pacien TRAN-GIRARD
# Licence: EUPL-1.2

from os import environ
from functools import partial

from fastapi import Request, HTTPException, status
from starlette.middleware.sessions import SessionMiddleware


# Sessions are kept client-side in a signed cookie. The signing key comes from
# the environment and is read once, when this module is imported: a missing
# COOKIE_SECRET_KEY variable makes the import fail with a KeyError.
# The default SameSite policy offers some protection against CSRF attacks.
cookie_key = environ['COOKIE_SECRET_KEY']
SessionManager = partial(
    SessionMiddleware,
    secret_key=cookie_key,
)


class FlashMessageQueue:
    """
    Session proxy for managing session flash messages to be displayed to the
    user from one page to another. This suits confirmation and error messages.
    Messages are stored in the session cookie, which is limited in size to
    about 4kb.

    The queue is its own iterator and iterating over it consumes it: each
    message is removed from the session as it is yielded, so it is shown
    only once. Every yielded item is a ``(class_, message)`` tuple.
    """

    def __init__(self, request: Request):
        # Hold a reference to the list stored in the session (creating an
        # empty one if needed), so in-place appends/pops below are what gets
        # saved back with the session.
        self._messages = request.session.setdefault('messages', [])

    def add(self, class_: str, message: str):
        """
        Queue ``message`` for display, tagged with the display class
        ``class_``.
        """
        self._messages.append((class_, message))

    def __iter__(self):
        """Return the queue itself: a single-pass, draining iterator."""
        return self

    def __next__(self):
        """
        Remove and return the oldest queued message as a
        ``(class_, message)`` tuple.

        Raises StopIteration once no message is left.
        """
        if not self._messages:
            raise StopIteration

        # Messages queued in an earlier request have round-tripped through
        # the session cookie serialization (JSON in Starlette), which turns
        # tuples into lists. Normalize so callers always get the same tuple
        # shape as messages added during the current request.
        return tuple(self._messages.pop(0))


class UserSession:
    """
    Session proxy for managing user login sessions.

    The logged-in user's identifier is kept under a single key of the request
    session; a user counts as logged in whenever that key is present.
    """

    # Session key under which the logged-in user's identifier is stored.
    _USER_ID_KEY = 'user_id'

    def __init__(self, request: Request):
        self._session = request.session

    def is_logged_in(self) -> bool:
        """Tell whether a user identifier is stored in the session."""
        return self._USER_ID_KEY in self._session

    def get_user_id(self) -> int:
        """
        Return the stored user identifier.

        Raises KeyError when no user is logged in.
        """
        return self._session[self._USER_ID_KEY]

    def login(self, user_id: int):
        """
        Record ``user_id`` as the logged-in user, replacing any previously
        stored identifier.
        """
        self._session[self._USER_ID_KEY] = user_id

    def logout(self):
        """
        Forget the logged-in user, if any. Other session data is left intact,
        and calling this when nobody is logged in is a no-op.
        """
        if self._USER_ID_KEY in self._session:
            del self._session[self._USER_ID_KEY]

    @classmethod
    def authenticated(cls, request: Request) -> 'UserSession':
        """
        Build the session proxy for ``request`` and return it if a user is
        logged in; otherwise reject the request by raising an HTTPException
        with status 401 Unauthorized.
        """
        user_session = cls(request)
        if user_session.is_logged_in():
            return user_session

        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)