diff --git a/.changeset/fuzzy-pandas-scope.md b/.changeset/fuzzy-pandas-scope.md new file mode 100644 index 00000000..7e1e359c --- /dev/null +++ b/.changeset/fuzzy-pandas-scope.md @@ -0,0 +1,5 @@ +--- +'pypi/posthog': minor +--- + +Add context helper methods to custom PostHog client instances. diff --git a/posthog/client.py b/posthog/client.py index 8c4520c0..8f334778 100644 --- a/posthog/client.py +++ b/posthog/client.py @@ -24,7 +24,13 @@ get_context_device_id, get_context_distinct_id, get_context_session_id, + get_tags as _context_get_tags, + identify_context as _context_identify_context, + _scoped as _context_scoped, new_context, + set_context_device_id as _context_set_context_device_id, + set_context_session as _context_set_context_session, + tag as _context_tag, ) from posthog.exception_capture import ExceptionCapture from posthog.exception_utils import ( @@ -486,9 +492,9 @@ def new_context(self, fresh=False, capture_exceptions: Optional[bool] = None): Examples: ```python - with posthog.new_context(): - identify_context('') - posthog.capture('event_name') + with client.new_context(): + client.identify_context('') + client.capture('event_name') ``` Category: @@ -498,6 +504,83 @@ def new_context(self, fresh=False, capture_exceptions: Optional[bool] = None): fresh=fresh, capture_exceptions=capture_exceptions, client=self ) + def scoped(self, fresh=False, capture_exceptions: Optional[bool] = None): + """ + Decorator that creates a new context for the wrapped function using this client. + + Args: + fresh: Whether to create a fresh context that doesn't inherit from parent. + capture_exceptions: Whether to automatically capture exceptions in this context. If omitted, defaults to this client's exception autocapture setting. + + Category: + Contexts + """ + + return _context_scoped( + fresh=fresh, capture_exceptions=capture_exceptions, client=self + ) + + def tag(self, name: str, value: Any) -> None: + """ + Add a tag to the current context. + + Args: + name: The tag key. + value: The tag value. + + Category: + Contexts + """ + _context_tag(name, value) + + def get_tags(self) -> Dict[str, Any]: + """ + Get all tags from the current context. + + Returns: + Dict of all tags in the current context. + + Category: + Contexts + """ + return _context_get_tags() + + def identify_context(self, distinct_id: str) -> None: + """ + Identify the current context with a distinct ID. + + Args: + distinct_id: The distinct ID to associate with the current context and its children. + + Category: + Contexts + """ + _context_identify_context(distinct_id) + + def set_context_session(self, session_id: str) -> None: + """ + Set the session ID for the current context. + + Args: + session_id: The session ID to associate with the current context and its children. + + Category: + Contexts + """ + _context_set_context_session(session_id) + + def set_context_device_id(self, device_id: str) -> None: + """ + Set the device ID for the current context. + + Args: + device_id: The device ID to associate with the current context and its children. + + Category: + Contexts + """ + _context_set_context_device_id(device_id) + @property def feature_flags(self): """ diff --git a/posthog/contexts.py b/posthog/contexts.py index 99b21562..5a31f80a 100644 --- a/posthog/contexts.py +++ b/posthog/contexts.py @@ -393,6 +393,38 @@ def get_code_variables_ignore_patterns_context() -> Optional[list]: F = TypeVar("F", bound=Callable[..., Any]) +def _scoped( + fresh: bool = False, + capture_exceptions: Optional[bool] = None, + client: Optional["Client"] = None, +): + def decorator(func: F) -> F: + from functools import wraps + from inspect import iscoroutinefunction + + if iscoroutinefunction(func): + + @wraps(func) + async def async_wrapper(*args, **kwargs): + with new_context( + fresh=fresh, capture_exceptions=capture_exceptions, client=client + ): + return await func(*args, **kwargs) + + return cast(F, async_wrapper) + + @wraps(func) + def wrapper(*args, **kwargs): + with new_context( + fresh=fresh, capture_exceptions=capture_exceptions, client=client + ): + return func(*args, **kwargs) + + return cast(F, wrapper) + + return decorator + + def scoped(fresh: bool = False, capture_exceptions: Optional[bool] = None): """ Decorator that creates a new context for the function. Simply wraps @@ -424,25 +456,4 @@ async def middleware(request, call_next): Category: Contexts """ - - def decorator(func: F) -> F: - from functools import wraps - from inspect import iscoroutinefunction - - if iscoroutinefunction(func): - - @wraps(func) - async def async_wrapper(*args, **kwargs): - with new_context(fresh=fresh, capture_exceptions=capture_exceptions): - return await func(*args, **kwargs) - - return cast(F, async_wrapper) - - @wraps(func) - def wrapper(*args, **kwargs): - with new_context(fresh=fresh, capture_exceptions=capture_exceptions): - return func(*args, **kwargs) - - return cast(F, wrapper) - - return decorator + return _scoped(fresh=fresh, capture_exceptions=capture_exceptions) diff --git a/posthog/test/test_client.py b/posthog/test/test_client.py index cb8d4b5d..d871733a 100644 --- a/posthog/test/test_client.py +++ b/posthog/test/test_client.py @@ -1,3 +1,4 @@ +import asyncio import time import unittest from datetime import datetime @@ -2361,6 +2362,35 @@ def test_device_id_from_context_is_used_in_flags_request(self, patch_flags): flag_keys_to_evaluate=["random_key"], ) + @mock.patch("posthog.client.flags") + def test_client_set_context_device_id_is_used_in_flags_request(self, patch_flags): + patch_flags.return_value = { + "featureFlags": { + "beta-feature": "random-variant", + } + } + client = Client( + FAKE_TEST_API_KEY, + on_error=self.set_fail, + ) + + with client.new_context(): + client.set_context_device_id("client-context-device-id") + client.get_feature_flag("random_key", "some_id") + + patch_flags.assert_called_with( + "random_key", + "https://us.i.posthog.com", + timeout=3, + distinct_id="some_id", + groups={}, + person_properties={"distinct_id": "some_id"}, + group_properties={}, + geoip_disable=True, + device_id="client-context-device-id", + flag_keys_to_evaluate=["random_key"], + ) + @parameterized.expand( [ # name, sys_platform, version_info, expected_runtime, expected_version, expected_os, expected_os_version, expected_os_distro, platform_method, platform_return @@ -2534,6 +2564,77 @@ def test_set_context_session_with_capture(self): msg["properties"]["$session_id"], "context-session-123" ) + @parameterized.expand([("new_context",), ("scoped",)]) + def test_client_context_helpers_apply_to_capture(self, context_helper): + with mock.patch("posthog.client.batch_post") as mock_post: + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, sync_mode=True) + + def capture_in_context(): + client.tag("client_tag", "tag-value") + client.identify_context("context-user") + client.set_context_session("context-session-123") + + self.assertEqual(client.get_tags(), {"client_tag": "tag-value"}) + + return client.capture( + "test_event", + properties={"custom_prop": "value"}, + ) + + if context_helper == "new_context": + with client.new_context(fresh=True): + msg_uuid = capture_in_context() + else: + + @client.scoped(fresh=True) + def scoped_capture(): + return capture_in_context() + + msg_uuid = scoped_capture() + + self.assertIsNotNone(msg_uuid) + mock_post.assert_called_once() + batch_data = mock_post.call_args[1]["batch"] + msg = batch_data[0] + + self.assertEqual(msg["distinct_id"], "context-user") + self.assertEqual(msg["properties"]["client_tag"], "tag-value") + self.assertEqual(msg["properties"]["custom_prop"], "value") + self.assertEqual(msg["properties"]["$session_id"], "context-session-123") + self.assertCountEqual(msg["properties"]["$context_tags"], ["client_tag"]) + self.assertEqual(client.get_tags(), {}) + + def test_client_scoped_context_helpers_apply_to_capture_async(self): + with mock.patch("posthog.client.batch_post") as mock_post: + client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, sync_mode=True) + + @client.scoped(fresh=True) + async def scoped_capture(): + client.tag("async_scoped_tag", "async-scoped-value") + client.identify_context("async-scoped-user") + client.set_context_session("async-scoped-session-123") + await asyncio.sleep(0) + return client.capture("async_scoped_event") + + msg_uuid = asyncio.run(scoped_capture()) + + self.assertIsNotNone(msg_uuid) + mock_post.assert_called_once() + batch_data = mock_post.call_args[1]["batch"] + msg = batch_data[0] + + self.assertEqual(msg["distinct_id"], "async-scoped-user") + self.assertEqual( + msg["properties"]["async_scoped_tag"], "async-scoped-value" + ) + self.assertEqual( + msg["properties"]["$session_id"], "async-scoped-session-123" + ) + self.assertCountEqual( + msg["properties"]["$context_tags"], ["async_scoped_tag"] + ) + self.assertEqual(client.get_tags(), {}) + def test_set_context_session_with_page_explicit_properties(self): with mock.patch("posthog.client.batch_post") as mock_post: client = Client(FAKE_TEST_API_KEY, on_error=self.set_fail, sync_mode=True) diff --git a/references/public_api_snapshot.txt b/references/public_api_snapshot.txt index f6a1e30e..f8f1a031 100644 --- a/references/public_api_snapshot.txt +++ b/references/public_api_snapshot.txt @@ -1059,13 +1059,19 @@ method posthog.client.Client.get_feature_payloads(distinct_id, groups=None, pers method posthog.client.Client.get_feature_variants(distinct_id, groups=None, person_properties=None, group_properties=None, disable_geoip=None, flag_keys_to_evaluate: Optional[list[str]] = None, device_id: Optional[str] = None) -> dict[str, Union[bool, str]] method posthog.client.Client.get_flags_decision(distinct_id: Optional[ID_TYPES] = None, groups: Optional[dict] = None, person_properties=None, group_properties=None, disable_geoip=None, flag_keys_to_evaluate: Optional[list[str]] = None, device_id: Optional[str] = None) -> FlagsResponse method posthog.client.Client.get_remote_config_payload(key: str) +method posthog.client.Client.get_tags() -> Dict[str, Any] method posthog.client.Client.group_identify(group_type: str, group_key: str, properties: Optional[Dict[str, Any]] = None, timestamp: Optional[Union[datetime, str]] = None, uuid: Optional[Union[str, UUID]] = None, disable_geoip: Optional[bool] = None, distinct_id: Optional[ID_TYPES] = None) -> Optional[str] +method posthog.client.Client.identify_context(distinct_id: str) -> None method posthog.client.Client.join() -> None method posthog.client.Client.load_feature_flags() method posthog.client.Client.new_context(fresh=False, capture_exceptions: Optional[bool] = None) +method posthog.client.Client.scoped(fresh=False, capture_exceptions: Optional[bool] = None) method posthog.client.Client.set(**kwargs: Unpack[OptionalSetArgs]) -> Optional[str] +method posthog.client.Client.set_context_device_id(device_id: str) -> None +method posthog.client.Client.set_context_session(session_id: str) -> None method posthog.client.Client.set_once(**kwargs: Unpack[OptionalSetArgs]) -> Optional[str] method posthog.client.Client.shutdown() -> None +method posthog.client.Client.tag(name: str, value: Any) -> None method posthog.consumer.Consumer.next() method posthog.consumer.Consumer.pause() method posthog.consumer.Consumer.request(batch)