-
Notifications
You must be signed in to change notification settings - Fork 25
Expand file tree
/
Copy pathservices.py
More file actions
136 lines (101 loc) · 4.08 KB
/
services.py
File metadata and controls
136 lines (101 loc) · 4.08 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
import uuid
from datetime import datetime, timedelta, timezone
from typing import Annotated, Any
import jwt
from fastapi import Depends, HTTPException, Request
from fastapi.security import OAuth2PasswordBearer
from jwt.exceptions import InvalidTokenError
from passlib.context import CryptContext
from pydantic import ValidationError
from sqlmodel import Session, select
from src.auth.schemas import TokenPayload
from src.core.config import settings
from src.core.db import get_db
from src.users.models import User
from src.users.schemas import UserPublic
# JWT signing algorithm shared by create_access_token (encode) and
# get_user_from_token (decode).
ALGORITHM = "HS256"
# FastAPI bearer-token security scheme; tokenUrl advertises the token-issuing
# endpoint in the OpenAPI schema. auto_error is left at its default, so a
# request without a bearer token is rejected by this dependency itself.
reusable_oauth2 = OAuth2PasswordBearer(tokenUrl=f"{settings.API_V1_STR}/tokens")
# Password hashing context. Only bcrypt is configured; deprecated="auto" would
# flag any non-default scheme as deprecated if more schemes were added.
pwd_context = CryptContext(schemes=["bcrypt"], deprecated="auto")
# Reusable dependency aliases for route/function signatures.
# NOTE(review): get_db presumably yields one DB session per request — confirm.
SessionDep = Annotated[Session, Depends(get_db)]
TokenDep = Annotated[str, Depends(reusable_oauth2)]
def get_user_from_session(request: Request, session: SessionDep) -> User:
    """Resolve the authenticated user from the request's session data.

    Reads ``user_id`` from ``request.session``, parses it as a UUID and loads
    the matching user.

    Args:
        request: Incoming request whose ``session`` mapping holds ``user_id``.
        session: Database session used to look the user up.

    Returns:
        The active ``User`` for the session's ``user_id`` (the same type that
        ``get_user_from_token`` returns, so ``CurrentUser`` is consistent).

    Raises:
        HTTPException: 401 when the session has no ``user_id``, when the stored
            value is not a valid UUID, or when the user is missing or inactive.
    """
    raw_user_id = request.session.get("user_id")
    if not raw_user_id:
        raise HTTPException(status_code=401, detail="Not authenticated (no session)")

    # Imported lazily, presumably to avoid a circular import with src.users.
    from src.users.services import get_user_by_id

    # Keep only the UUID parse inside the try: pydantic's ValidationError is a
    # ValueError subclass, so a wider try would mislabel unrelated failures as
    # a bad session ID. str() makes non-string session values fail with
    # ValueError instead of an unhandled AttributeError from uuid.UUID.
    try:
        user_uuid = uuid.UUID(str(raw_user_id))
    except ValueError:
        raise HTTPException(
            status_code=401, detail="Invalid user ID in session"
        ) from None

    user = get_user_by_id(session=session, user_id=user_uuid)
    if not user or not user.is_active:
        raise HTTPException(status_code=401, detail="Invalid session user")
    return user
def get_user_from_token(
    session: SessionDep,
    token: Annotated[str, Depends(reusable_oauth2)],
) -> User:
    """Resolve the authenticated user from a bearer JWT.

    The token is verified with ``settings.SECRET_KEY`` and ``ALGORITHM``, its
    claims are validated as ``TokenPayload``, and ``sub`` is used as the
    primary key for the ``User`` lookup.

    Raises:
        HTTPException: 403 when PyJWT rejects the token or the claims do not
            validate as ``TokenPayload``; 401 when the referenced user does not
            exist or is inactive.
    """
    try:
        claims = jwt.decode(token, settings.SECRET_KEY, algorithms=[ALGORITHM])
        subject_id = TokenPayload(**claims).sub
    except (InvalidTokenError, ValidationError):
        raise HTTPException(status_code=403, detail="Invalid token")

    account = session.get(User, subject_id)
    if not account or not account.is_active:
        raise HTTPException(status_code=401, detail="Invalid user")
    return account
def get_current_user(
    request: Request,
    session: SessionDep,
    token: Annotated[
        str | None,
        Depends(
            OAuth2PasswordBearer(
                tokenUrl=f"{settings.API_V1_STR}/tokens", auto_error=False
            )
        ),
    ] = None,
) -> User:
    """Return the caller's user, via bearer token or session.

    A bearer token, when supplied, takes precedence. An invalid token is
    rejected by ``get_user_from_token`` and does not fall back to the session.
    Without a token, the session's ``user_id`` is used.

    Raises:
        HTTPException: 401 when neither a token nor a session ``user_id`` is
            present, plus whatever the delegated resolver raises.
    """
    # auto_error=False above makes a missing Authorization header yield None
    # instead of an immediate 401, so session auth gets a chance.
    if not token:
        if not request.session.get("user_id"):
            raise HTTPException(status_code=401, detail="Not authenticated")
        return get_user_from_session(request, session)
    return get_user_from_token(session, token)
# Dependency alias for route parameters: resolves the caller through
# get_current_user (bearer token first, otherwise the session's user_id).
CurrentUser = Annotated[User, Depends(get_current_user)]
def verify_password(plain_password: str, hashed_password: str) -> bool:
    """Return True if *plain_password* matches *hashed_password*.

    Delegates to the module's passlib ``CryptContext``. Per passlib's docs,
    an unrecognized hash format raises ``ValueError`` rather than returning
    False.
    """
    matches = pwd_context.verify(secret=plain_password, hash=hashed_password)
    return matches
def authenticate(*, session: Session, email: str, password: str) -> User | None:
    """Return the user matching *email* and *password*, or None.

    Args:
        session: Database session used for the email lookup.
        email: Login email.
        password: Plaintext password to check against the stored hash.

    Returns:
        The matching ``User`` on success. Returns None when the email is
        unknown, when the account has no password hash (e.g. Auth0 users), or
        when the password does not match.
    """
    # Imported lazily, presumably to avoid a circular import with src.users.
    from src.users.services import get_user_by_email

    db_user = get_user_by_email(session=session, email=email)
    if not db_user or not db_user.hashed_password:
        # Spend roughly one hash verification even when there is nothing to
        # verify, so response timing does not reveal which emails exist.
        pwd_context.dummy_verify()
        return None
    if not verify_password(password, db_user.hashed_password):
        return None
    return db_user
def create_access_token(subject: str | Any, expires_delta: timedelta) -> str:
    """Issue a signed JWT for *subject*.

    The payload carries ``exp`` (current UTC time plus *expires_delta*) and
    ``sub`` (``str(subject)``). It is signed with ``settings.SECRET_KEY``
    using ``ALGORITHM``.
    """
    claims = {
        "exp": datetime.now(timezone.utc) + expires_delta,
        "sub": str(subject),
    }
    return jwt.encode(claims, settings.SECRET_KEY, algorithm=ALGORITHM)
def get_password_hash(password: str) -> str:
    """Hash *password* with the module's passlib ``CryptContext`` (bcrypt)."""
    hashed = pwd_context.hash(secret=password)
    return hashed
def get_or_create_user_by_email(
    session: Session,
    email: str,
    defaults: dict[str, Any] | None = None,
) -> User:
    """Return the user with *email*, creating and committing one if absent.

    Args:
        session: Database session; committed when a new user is created.
        email: Email to look up (exact match on ``User.email``).
        defaults: Extra ``User`` field values used only when creating. None
            (the default) means no extra fields. An ``email`` key here is
            overridden by the *email* argument.

    Returns:
        The existing user, or the newly created one refreshed after commit.
    """
    existing = session.exec(select(User).where(User.email == email)).first()
    if existing:
        return existing
    # `**None` raises TypeError, so normalize the documented None default.
    # Putting email last lets the explicit argument win over defaults["email"]
    # instead of raising "got multiple values for keyword argument 'email'".
    # NOTE(review): concurrent callers can both miss the lookup and insert;
    # if User.email is unique, the second commit would fail — not handled here.
    fields = {**(defaults or {}), "email": email}
    user = User(**fields)
    session.add(user)
    session.commit()
    session.refresh(user)
    return user