forked from benavlabs/FastAPI-boilerplate
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathsecurity.py
More file actions
136 lines (104 loc) · 4.81 KB
/
security.py
File metadata and controls
136 lines (104 loc) · 4.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import asyncio
from datetime import UTC, datetime, timedelta
from enum import Enum
from functools import cache
from typing import Any, Literal

import bcrypt
import jwt
from fastapi.security import OAuth2PasswordBearer
from pydantic import SecretStr
from sqlalchemy.ext.asyncio import AsyncSession

from ..crud.crud_users import crud_users
from .config import settings
from .db.crud_token_blacklist import crud_token_blacklist
from .schemas import TokenBlacklistCreate, TokenData
# JWT signing key, kept wrapped in a pydantic SecretStr; it is unwrapped with
# get_secret_value() only at the jwt.encode / jwt.decode call sites.
SECRET_KEY: SecretStr = settings.SECRET_KEY
# JWT signing algorithm, passed to both jwt.encode and jwt.decode.
ALGORITHM = settings.ALGORITHM
# Default access-token lifetime in minutes, used when the caller supplies no expires_delta.
ACCESS_TOKEN_EXPIRE_MINUTES = settings.ACCESS_TOKEN_EXPIRE_MINUTES
# Default refresh-token lifetime in days, used when the caller supplies no expires_delta.
REFRESH_TOKEN_EXPIRE_DAYS = settings.REFRESH_TOKEN_EXPIRE_DAYS
# FastAPI OAuth2 password-flow bearer scheme; tokenUrl is the login route it advertises.
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="/api/v1/login")
class TokenType(str, Enum):
    """Value of the ``token_type`` claim embedded in the JWTs issued by this module.

    Mixing in ``str`` makes each member compare equal to its plain string value
    (``TokenType.ACCESS == "access"``). verify_token relies on this when it
    compares the decoded ``token_type`` claim (a plain str) with the expected member.
    """

    ACCESS = "access"
    REFRESH = "refresh"
async def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Check a plain-text password against a bcrypt hash.

    bcrypt is deliberately CPU-expensive. Running ``bcrypt.checkpw`` directly
    inside this coroutine would stall the event loop, and every other request,
    for the duration of the hash. The check is therefore run in a worker thread
    via ``asyncio.to_thread``.

    Parameters
    ----------
    plain_password: str
        Password supplied by the client; UTF-8 encoded before hashing.
    hashed_password: str
        Stored bcrypt hash (as produced by get_password_hash).

    Returns
    -------
    bool
        True if the password matches the hash, False otherwise.

    Raises
    ------
    Exception
        Whatever ``bcrypt.checkpw`` raises (for example on a malformed hash)
        propagates unchanged from the worker thread.
    """
    correct_password: bool = await asyncio.to_thread(
        bcrypt.checkpw, plain_password.encode(), hashed_password.encode()
    )
    return correct_password
def get_password_hash(password: str) -> str:
    """Hash ``password`` with bcrypt using a freshly generated salt.

    The returned string embeds the salt and cost factor, so it can later be
    passed to verify_password as-is.

    NOTE(review): bcrypt only considers the first 72 bytes of its input, and
    newer bcrypt releases reject longer inputs instead. Confirm that password
    length is validated upstream.

    Parameters
    ----------
    password: str
        Plain-text password; UTF-8 encoded before hashing.

    Returns
    -------
    str
        The bcrypt hash, decoded to text.
    """
    salt = bcrypt.gensalt()
    digest = bcrypt.hashpw(password.encode(), salt)
    return digest.decode()
@cache
def _dummy_password_hash() -> str:
    """Return a bcrypt hash that no real account uses, computed once per process.

    It is built with get_password_hash, so it has the same cost factor as real
    hashes. Checking a password against it therefore takes about as long as
    checking against a real user's hash. The first call also pays for
    generating the hash.
    """
    return get_password_hash("timing-equalization-dummy-password")


async def authenticate_user(username_or_email: str, password: str, db: AsyncSession) -> dict[str, Any] | Literal[False]:
    """Look up a non-deleted user and verify their password.

    Parameters
    ----------
    username_or_email: str
        Treated as an email address if it contains ``"@"``, otherwise as a username.
    password: str
        Plain-text password supplied by the client.
    db: AsyncSession
        Database session for performing database operations.

    Returns
    -------
    dict[str, Any] | Literal[False]
        The user record on success. ``False`` if no matching non-deleted user
        exists or the password does not match.
    """
    if "@" in username_or_email:
        db_user = await crud_users.get(db=db, email=username_or_email, is_deleted=False)
    else:
        db_user = await crud_users.get(db=db, username=username_or_email, is_deleted=False)

    if not db_user:
        # Still perform a full bcrypt check, so "unknown user" takes about as long
        # as "wrong password". Otherwise response timing reveals which accounts exist.
        await verify_password(password, _dummy_password_hash())
        return False

    if not await verify_password(password, db_user["hashed_password"]):
        return False
    return db_user
def _create_token(
    data: dict[str, Any],
    expires_delta: timedelta | None,
    default_lifetime: timedelta,
    token_type: TokenType,
) -> str:
    """Encode ``data`` as a signed JWT carrying ``exp`` and ``token_type`` claims.

    ``data`` is copied, never mutated. ``expires_delta`` overrides
    ``default_lifetime`` whenever it is not ``None``. This includes
    ``timedelta(0)``, which yields an immediately-expiring token instead of
    silently falling back to the default.
    """
    to_encode = data.copy()
    lifetime = default_lifetime if expires_delta is None else expires_delta
    expire = datetime.now(UTC) + lifetime
    # ``exp`` is written as integer seconds since the epoch. ``token_type`` lets
    # verify_token tell access and refresh tokens apart.
    to_encode.update({"exp": int(expire.timestamp()), "token_type": token_type})
    encoded_jwt: str = jwt.encode(to_encode, SECRET_KEY.get_secret_value(), algorithm=ALGORITHM)
    return encoded_jwt


async def create_access_token(data: dict[str, Any], expires_delta: timedelta | None = None) -> str:
    """Create a signed access token.

    Parameters
    ----------
    data: dict[str, Any]
        Claims to include (e.g. ``sub``); copied, not mutated.
    expires_delta: timedelta | None
        Token lifetime. ``None`` means ACCESS_TOKEN_EXPIRE_MINUTES minutes.

    Returns
    -------
    str
        The encoded JWT, with ``token_type`` set to ``TokenType.ACCESS``.
    """
    return _create_token(
        data, expires_delta, timedelta(minutes=ACCESS_TOKEN_EXPIRE_MINUTES), TokenType.ACCESS
    )


async def create_refresh_token(data: dict[str, Any], expires_delta: timedelta | None = None) -> str:
    """Create a signed refresh token.

    Parameters
    ----------
    data: dict[str, Any]
        Claims to include (e.g. ``sub``); copied, not mutated.
    expires_delta: timedelta | None
        Token lifetime. ``None`` means REFRESH_TOKEN_EXPIRE_DAYS days.

    Returns
    -------
    str
        The encoded JWT, with ``token_type`` set to ``TokenType.REFRESH``.
    """
    return _create_token(
        data, expires_delta, timedelta(days=REFRESH_TOKEN_EXPIRE_DAYS), TokenType.REFRESH
    )
async def verify_token(token: str, expected_token_type: TokenType, db: AsyncSession) -> TokenData | None:
    """Verify a JWT token and return TokenData if valid.

    A token is accepted only if all of the following hold:

    - ``jwt.decode`` succeeds: valid signature and, when an ``exp`` claim is
      present, not expired.
    - It carries a ``sub`` claim.
    - Its ``token_type`` claim equals ``expected_token_type``.
    - It is not in the token blacklist.

    The cheap, local checks run first. The blacklist is queried only for tokens
    that are otherwise valid, so malformed or forged tokens never cost a
    database round-trip.

    Parameters
    ----------
    token: str
        The JWT token to be verified.
    expected_token_type: TokenType
        The expected type of token (access or refresh).
    db: AsyncSession
        Database session for performing database operations.

    Returns
    -------
    TokenData | None
        TokenData instance if the token is valid, None otherwise.
    """
    try:
        payload = jwt.decode(token, SECRET_KEY.get_secret_value(), algorithms=[ALGORITHM])
    except jwt.PyJWTError:
        return None

    username_or_email: str | None = payload.get("sub")
    token_type: str | None = payload.get("token_type")
    # TokenType mixes in str, so a plain-string claim compares equal to the matching member.
    if username_or_email is None or token_type != expected_token_type:
        return None

    if await crud_token_blacklist.exists(db, token=token):
        return None

    return TokenData(username_or_email=username_or_email)
async def blacklist_tokens(access_token: str, refresh_token: str, db: AsyncSession) -> None:
    """Blacklist both access and refresh tokens.

    Each token goes through ``blacklist_token``, so single- and pair-logout
    share one implementation. Tokens are processed in order (access first). If
    decoding the access token raises, the refresh token is not blacklisted.

    Parameters
    ----------
    access_token: str
        The access token to blacklist.
    refresh_token: str
        The refresh token to blacklist.
    db: AsyncSession
        Database session for performing database operations.

    Raises
    ------
    jwt.PyJWTError
        Propagated from ``jwt.decode`` if either token is invalid or expired.
    """
    for token in (access_token, refresh_token):
        await blacklist_token(token, db)
async def blacklist_token(token: str, db: AsyncSession) -> None:
    """Record a single JWT in the token blacklist.

    Parameters
    ----------
    token: str
        The JWT to blacklist. It is decoded and verified with ``jwt.decode``
        before anything is stored.
    db: AsyncSession
        Database session for performing database operations.

    Raises
    ------
    jwt.PyJWTError
        Propagated from ``jwt.decode`` when the token cannot be decoded
        (bad signature, malformed, expired, ...).

    Notes
    -----
    A token without an ``exp`` claim is not recorded.
    """
    claims = jwt.decode(token, SECRET_KEY.get_secret_value(), algorithms=[ALGORITHM])
    expiry = claims.get("exp")
    if expiry is None:
        return
    # NOTE(review): fromtimestamp() without tz returns a *naive* datetime in the
    # server's local time zone, unlike the aware UTC datetimes used when tokens
    # are created. Confirm what the blacklist column and any cleanup job expect
    # before switching to datetime.fromtimestamp(expiry, tz=UTC).
    record = TokenBlacklistCreate(token=token, expires_at=datetime.fromtimestamp(expiry))
    await crud_token_blacklist.create(db, object=record)