diff --git a/tests/test_session.py b/tests/test_session.py index c700b507..254c9cf0 100644 --- a/tests/test_session.py +++ b/tests/test_session.py @@ -34,29 +34,44 @@ def TEST_CONSTANTS(): encryption_algorithm=serialization.NoEncryption(), ) + current_datetime = datetime.now(timezone.utc) + current_timestamp = str(current_datetime) + + token_claims = { + "sid": "session_123", + "org_id": "organization_123", + "role": "admin", + "permissions": ["read"], + "entitlements": ["feature_1"], + "exp": int(current_datetime.timestamp()) + 3600, + "iat": int(current_datetime.timestamp()), + } + + user_id = "user_123" + return { "COOKIE_PASSWORD": "pfSqwTFXUTGEBBD1RQh2kt/oNJYxBgaoZan4Z8sMrKU=", "SESSION_DATA": "session_data", "CLIENT_ID": "client_123", - "USER_ID": "user_123", + "USER_ID": user_id, "SESSION_ID": "session_123", "ORGANIZATION_ID": "organization_123", - "CURRENT_TIMESTAMP": str(datetime.now(timezone.utc)), + "CURRENT_DATETIME": current_datetime, + "CURRENT_TIMESTAMP": current_timestamp, "PRIVATE_KEY": private_pem, "PUBLIC_KEY": public_key, - "TEST_TOKEN": jwt.encode( - { - "sid": "session_123", - "org_id": "organization_123", - "role": "admin", - "permissions": ["read"], - "entitlements": ["feature_1"], - "exp": int(datetime.now(timezone.utc).timestamp()) + 3600, - "iat": int(datetime.now(timezone.utc).timestamp()), - }, - private_pem, - algorithm="RS256", - ), + "TEST_TOKEN": jwt.encode(token_claims, private_pem, algorithm="RS256"), + "TEST_TOKEN_CLAIMS": token_claims, + "TEST_USER": { + "object": "user", + "id": user_id, + "email": "user@example.com", + "first_name": "Test", + "last_name": "User", + "email_verified": True, + "created_at": current_timestamp, + "updated_at": current_timestamp, + }, } @@ -145,6 +160,30 @@ def test_authenticate_invalid_jwt(TEST_CONSTANTS, mock_user_management): assert response.reason == AuthenticateWithSessionCookieFailureReason.INVALID_JWT +@with_jwks_mock +def test_authenticate_jwt_with_aud_claim(TEST_CONSTANTS, mock_user_management): + access_token = jwt.encode( + {**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}}, + TEST_CONSTANTS["PRIVATE_KEY"], + algorithm="RS256", + ) + + session_data = Session.seal_data( + {"access_token": access_token, "user": TEST_CONSTANTS["TEST_USER"]}, + TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.authenticate() + + assert isinstance(response, AuthenticateWithSessionCookieSuccessResponse) + + @with_jwks_mock def test_authenticate_success(TEST_CONSTANTS, mock_user_management): session = Session( @@ -229,19 +268,8 @@ def test_refresh_invalid_session_cookie(TEST_CONSTANTS, mock_user_management): @with_jwks_mock def test_refresh_success(TEST_CONSTANTS, mock_user_management): - test_user = { - "object": "user", - "id": TEST_CONSTANTS["USER_ID"], - "email": "user@example.com", - "first_name": "Test", - "last_name": "User", - "email_verified": True, - "created_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - "updated_at": TEST_CONSTANTS["CURRENT_TIMESTAMP"], - } - session_data = Session.seal_data( - {"refresh_token": "refresh_token_12345", "user": test_user}, + {"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]}, TEST_CONSTANTS["COOKIE_PASSWORD"], ) @@ -249,7 +277,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): "access_token": TEST_CONSTANTS["TEST_TOKEN"], "refresh_token": "refresh_token_123", "sealed_session": session_data, - "user": test_user, + "user": TEST_CONSTANTS["TEST_USER"], } mock_user_management.authenticate_with_refresh_token.return_value = ( @@ -278,7 +306,7 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): assert isinstance(response, RefreshWithSessionCookieSuccessResponse) assert response.authenticated is True - assert response.user.id == test_user["id"] + assert response.user.id == TEST_CONSTANTS["TEST_USER"]["id"] # Verify the refresh token was used correctly mock_user_management.authenticate_with_refresh_token.assert_called_once_with( @@ -291,6 +319,42 @@ def test_refresh_success(TEST_CONSTANTS, mock_user_management): ) +@with_jwks_mock +def test_refresh_success_with_aud_claim(TEST_CONSTANTS, mock_user_management): + session_data = Session.seal_data( + {"refresh_token": "refresh_token_12345", "user": TEST_CONSTANTS["TEST_USER"]}, + TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + access_token = jwt.encode( + {**TEST_CONSTANTS["TEST_TOKEN_CLAIMS"], **{"aud": TEST_CONSTANTS["CLIENT_ID"]}}, + TEST_CONSTANTS["PRIVATE_KEY"], + algorithm="RS256", + ) + + mock_response = { + "access_token": access_token, + "refresh_token": "refresh_token_123", + "sealed_session": session_data, + "user": TEST_CONSTANTS["TEST_USER"], + } + + mock_user_management.authenticate_with_refresh_token.return_value = ( + RefreshTokenAuthenticationResponse(**mock_response) + ) + + session = Session( + user_management=mock_user_management, + client_id=TEST_CONSTANTS["CLIENT_ID"], + session_data=session_data, + cookie_password=TEST_CONSTANTS["COOKIE_PASSWORD"], + ) + + response = session.refresh() + + assert isinstance(response, RefreshWithSessionCookieSuccessResponse) + + def test_seal_data(TEST_CONSTANTS): test_data = {"test": "data"} sealed = Session.seal_data(test_data, TEST_CONSTANTS["COOKIE_PASSWORD"]) diff --git a/workos/session.py b/workos/session.py index c4ab1c35..abee954c 100644 --- a/workos/session.py +++ b/workos/session.py @@ -79,7 +79,10 @@ def authenticate( signing_key = self.jwks.get_signing_key_from_jwt(session["access_token"]) decoded = jwt.decode( - session["access_token"], signing_key.key, algorithms=self.jwk_algorithms + session["access_token"], + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, ) return AuthenticateWithSessionCookieSuccessResponse( @@ -141,6 +144,7 @@ def refresh( auth_response.access_token, signing_key.key, algorithms=self.jwk_algorithms, + options={"verify_aud": False}, ) return RefreshWithSessionCookieSuccessResponse( @@ -176,7 +180,12 @@ def get_logout_url(self, return_to: Optional[str] = None) -> str: def _is_valid_jwt(self, token: str) -> bool: try: signing_key = self.jwks.get_signing_key_from_jwt(token) - jwt.decode(token, signing_key.key, algorithms=self.jwk_algorithms) + jwt.decode( + token, + signing_key.key, + algorithms=self.jwk_algorithms, + options={"verify_aud": False}, + ) return True except jwt.exceptions.InvalidTokenError: return False