diff --git a/src/apify/_actor.py b/src/apify/_actor.py index 6b556cf3..ba41c115 100644 --- a/src/apify/_actor.py +++ b/src/apify/_actor.py @@ -32,7 +32,7 @@ from apify._crypto import decrypt_input_secrets, load_private_key from apify._models import ActorRun from apify._proxy_configuration import ProxyConfiguration -from apify._utils import docs_group, docs_name, get_system_info, is_running_in_ipython +from apify._utils import docs_group, docs_name, ensure_context, get_system_info, is_running_in_ipython from apify.events import ApifyEventManager, EventManager, LocalEventManager from apify.log import _configure_logging, logger from apify.storage_clients import ApifyStorageClient, SmartApifyStorageClient @@ -53,6 +53,8 @@ MainReturnType = TypeVar('MainReturnType') +_ensure_context = ensure_context('_active') + @docs_name('Actor') @docs_group('Actor') @@ -139,8 +141,8 @@ def __init__( # Keep track of all used state stores to persist their values on exit self._use_state_stores: set[str | None] = set() - self._is_initialized = False - """Whether any Actor instance is currently initialized.""" + self._active = False + """Whether the Actor instance is currently active (initialized and within context).""" self._is_rebooting = False """Whether the Actor is currently rebooting.""" @@ -161,7 +163,7 @@ async def __aenter__(self) -> Self: This method must be called exactly once per Actor instance. Re-initializing an Actor or having multiple active Actor instances is not standard usage and may lead to warnings or unexpected behavior. """ - if self._is_initialized: + if self._active: raise RuntimeError('The Actor was already initialized!') # Initialize configuration first - it's required for the next steps. @@ -198,7 +200,7 @@ async def __aenter__(self) -> Self: self.log.debug('Charging manager initialized') # Mark initialization as complete and update global state. - self._is_initialized = True + self._active = True if not Actor.is_at_home(): # Make sure that the input related KVS is initialized to ensure that the input aware client is used @@ -225,7 +227,8 @@ async def __aexit__( if self._is_exiting: return - self._raise_if_not_initialized() + if not self._active: + raise RuntimeError('The _ActorType is not active. Use it within the async context.') if exc_value and not is_running_in_ipython(): # In IPython, we don't run `sys.exit()` during Actor exits, @@ -257,7 +260,7 @@ async def finalize() -> None: except TimeoutError: self.log.exception('Actor cleanup timed out') finally: - self._is_initialized = False + self._active = False if self._exit_process: sys.exit(self.exit_code) @@ -513,6 +516,7 @@ def new_client( timeout_secs=int(timeout.total_seconds()) if timeout else None, ) + @_ensure_context async def open_dataset( self, *, @@ -540,7 +544,6 @@ async def open_dataset( Returns: An instance of the `Dataset` class for the given ID or name. """ - self._raise_if_not_initialized() return await Dataset.open( id=id, name=name, @@ -548,6 +551,7 @@ async def open_dataset( storage_client=self._storage_client.get_suitable_storage_client(force_cloud=force_cloud), ) + @_ensure_context async def open_key_value_store( self, *, @@ -574,7 +578,6 @@ async def open_key_value_store( Returns: An instance of the `KeyValueStore` class for the given ID or name. """ - self._raise_if_not_initialized() return await KeyValueStore.open( id=id, name=name, @@ -582,6 +585,7 @@ async def open_key_value_store( storage_client=self._storage_client.get_suitable_storage_client(force_cloud=force_cloud), ) + @_ensure_context async def open_request_queue( self, *, @@ -610,7 +614,6 @@ async def open_request_queue( Returns: An instance of the `RequestQueue` class for the given ID or name. """ - self._raise_if_not_initialized() return await RequestQueue.open( id=id, name=name, @@ -622,6 +625,7 @@ async def open_request_queue( async def push_data(self, data: dict | list[dict]) -> None: ... @overload async def push_data(self, data: dict | list[dict], charged_event_name: str) -> ChargeResult: ... + @_ensure_context async def push_data(self, data: dict | list[dict], charged_event_name: str | None = None) -> ChargeResult | None: """Store an object or a list of objects to the default dataset of the current Actor run. @@ -630,8 +634,6 @@ async def push_data(self, data: dict | list[dict], charged_event_name: str | Non charged_event_name: If provided and if the Actor uses the pay-per-event pricing model, the method will attempt to charge for the event for each pushed item. """ - self._raise_if_not_initialized() - if not data: return None @@ -665,10 +667,9 @@ async def push_data(self, data: dict | list[dict], charged_event_name: str | Non count=pushed_items_count, ) + @_ensure_context async def get_input(self) -> Any: """Get the Actor input value from the default key-value store associated with the current Actor run.""" - self._raise_if_not_initialized() - input_value = await self.get_value(self.configuration.input_key) input_secrets_private_key = self.configuration.input_secrets_private_key_file input_secrets_key_passphrase = self.configuration.input_secrets_private_key_passphrase @@ -681,6 +682,7 @@ async def get_input(self) -> Any: return input_value + @_ensure_context async def get_value(self, key: str, default_value: Any = None) -> Any: """Get a value from the default key-value store associated with the current Actor run. @@ -688,11 +690,10 @@ async def get_value(self, key: str, default_value: Any = None) -> Any: key: The key of the record which to retrieve. default_value: Default value returned in case the record does not exist. """ - self._raise_if_not_initialized() - key_value_store = await self.open_key_value_store() return await key_value_store.get_value(key, default_value) + @_ensure_context async def set_value( self, key: str, @@ -707,16 +708,15 @@ async def set_value( value: The value of the record which to set, or None, if the record should be deleted. content_type: The content type which should be set to the value. """ - self._raise_if_not_initialized() - key_value_store = await self.open_key_value_store() return await key_value_store.set_value(key, value, content_type=content_type) + @_ensure_context def get_charging_manager(self) -> ChargingManager: """Retrieve the charging manager to access granular pricing information.""" - self._raise_if_not_initialized() return self._charging_manager_implementation + @_ensure_context async def charge(self, event_name: str, count: int = 1) -> ChargeResult: """Charge for a specified number of events - sub-operations of the Actor. @@ -726,7 +726,6 @@ async def charge(self, event_name: str, count: int = 1) -> ChargeResult: event_name: Name of the event to be charged for. count: Number of events to charge for. """ - self._raise_if_not_initialized() # Acquire lock to prevent race conditions with concurrent charge/push_data calls. async with self._charge_lock: return await self.get_charging_manager().charge(event_name, count) @@ -754,6 +753,7 @@ def on( @overload def on(self, event_name: Event, listener: EventListener[None]) -> EventListener[Any]: ... + @_ensure_context def on(self, event_name: Event, listener: EventListener[Any]) -> EventListener[Any]: """Add an event listener to the Actor's event manager. @@ -778,8 +778,6 @@ def on(self, event_name: Event, listener: EventListener[Any]) -> EventListener[A event_name: The Actor event to listen for. listener: The function to be called when the event is emitted (can be async). """ - self._raise_if_not_initialized() - self.event_manager.on(event=event_name, listener=listener) return listener @@ -796,6 +794,7 @@ def off(self, event_name: Literal[Event.EXIT], listener: EventListener[EventExit @overload def off(self, event_name: Event, listener: EventListener[None]) -> None: ... + @_ensure_context def off(self, event_name: Event, listener: Callable | None = None) -> None: """Remove a listener, or all listeners, from an Actor event. @@ -804,14 +803,13 @@ def off(self, event_name: Event, listener: Callable | None = None) -> None: listener: The listener which is supposed to be removed. If not passed, all listeners of this event are removed. """ - self._raise_if_not_initialized() - self.event_manager.off(event=event_name, listener=listener) def is_at_home(self) -> bool: """Return `True` when the Actor is running on the Apify platform, and `False` otherwise (e.g. local run).""" return self.configuration.is_at_home + @_ensure_context def get_env(self) -> dict: """Return a dictionary with information parsed from all the `APIFY_XXX` environment variables. @@ -819,8 +817,6 @@ def get_env(self) -> dict: [Actor documentation](https://docs.apify.com/actors/development/environment-variables). If some variables are not defined or are invalid, the corresponding value in the resulting dictionary will be None. """ - self._raise_if_not_initialized() - config = dict[str, Any]() for field_name, field in Configuration.model_fields.items(): if field.deprecated: @@ -841,6 +837,7 @@ def get_env(self) -> dict: env_vars = {env_var.value.lower(): env_var.name.lower() for env_var in [*ActorEnvVars, *ApifyEnvVars]} return {option_name: config[env_var] for env_var, option_name in env_vars.items() if env_var in config} + @_ensure_context async def start( self, actor_id: str, @@ -879,8 +876,6 @@ async def start( Returns: Info about the started Actor run """ - self._raise_if_not_initialized() - client = self.new_client(token=token) if token else self.apify_client if webhooks: @@ -919,6 +914,7 @@ async def start( return ActorRun.model_validate(api_result) + @_ensure_context async def abort( self, run_id: str, @@ -942,8 +938,6 @@ async def abort( Returns: Info about the aborted Actor run. """ - self._raise_if_not_initialized() - client = self.new_client(token=token) if token else self.apify_client if status_message: @@ -953,6 +947,7 @@ async def abort( return ActorRun.model_validate(api_result) + @_ensure_context async def call( self, actor_id: str, @@ -995,8 +990,6 @@ async def call( Returns: Info about the started Actor run. """ - self._raise_if_not_initialized() - client = self.new_client(token=token) if token else self.apify_client if webhooks: @@ -1037,6 +1030,7 @@ async def call( return ActorRun.model_validate(api_result) + @_ensure_context async def call_task( self, task_id: str, @@ -1077,8 +1071,6 @@ async def call_task( Returns: Info about the started Actor run. """ - self._raise_if_not_initialized() - client = self.new_client(token=token) if token else self.apify_client if webhooks: @@ -1108,6 +1100,7 @@ async def call_task( return ActorRun.model_validate(api_result) + @_ensure_context async def metamorph( self, target_actor_id: str, @@ -1132,8 +1125,6 @@ async def metamorph( content_type: The content type of the input. custom_after_sleep: How long to sleep for after the metamorph, to wait for the container to be stopped. """ - self._raise_if_not_initialized() - if not self.is_at_home(): self.log.error('Actor.metamorph() is only supported when running on the Apify platform.') return @@ -1155,6 +1146,7 @@ async def metamorph( if custom_after_sleep: await asyncio.sleep(custom_after_sleep.total_seconds()) + @_ensure_context async def reboot( self, *, @@ -1169,8 +1161,6 @@ async def reboot( event_listeners_timeout: How long should the Actor wait for Actor event listeners to finish before exiting. custom_after_sleep: How long to sleep for after the reboot, to wait for the container to be stopped. """ - self._raise_if_not_initialized() - if not self.is_at_home(): self.log.error('Actor.reboot() is only supported when running on the Apify platform.') return @@ -1210,6 +1200,7 @@ async def reboot( if custom_after_sleep: await asyncio.sleep(custom_after_sleep.total_seconds()) + @_ensure_context async def add_webhook( self, webhook: Webhook, @@ -1237,8 +1228,6 @@ async def add_webhook( Returns: The created webhook. """ - self._raise_if_not_initialized() - if not self.is_at_home(): self.log.error('Actor.add_webhook() is only supported when running on the Apify platform.') return @@ -1257,6 +1246,7 @@ async def add_webhook( idempotency_key=idempotency_key, ) + @_ensure_context async def set_status_message( self, status_message: str, @@ -1272,8 +1262,6 @@ async def set_status_message( Returns: The updated Actor run object. """ - self._raise_if_not_initialized() - if not self.is_at_home(): title = 'Terminal status message' if is_terminal else 'Status message' self.log.info(f'[{title}]: {status_message}') @@ -1289,6 +1277,7 @@ async def set_status_message( return ActorRun.model_validate(api_result) + @_ensure_context async def create_proxy_configuration( self, *, @@ -1321,8 +1310,6 @@ async def create_proxy_configuration( ProxyConfiguration object with the passed configuration, or None, if no proxy should be used based on the configuration. """ - self._raise_if_not_initialized() - if actor_proxy_input is not None: if actor_proxy_input.get('useApifyProxy', False): country_code = country_code or actor_proxy_input.get('apifyProxyCountry') @@ -1346,6 +1333,7 @@ async def create_proxy_configuration( return proxy_configuration + @_ensure_context async def use_state( self, default_value: dict[str, JsonSerializable] | None = None, @@ -1365,8 +1353,6 @@ async def use_state( Returns: The state dictionary with automatic persistence. """ - self._raise_if_not_initialized() - self._use_state_stores.add(kvs_name) kvs = await self.open_key_value_store(name=kvs_name) return await kvs.get_auto_saved_value(key or self._ACTOR_STATE_KEY, default_value) @@ -1376,10 +1362,6 @@ async def _save_actor_state(self) -> None: store = await self.open_key_value_store(name=kvs_name) await store.persist_autosaved_values() - def _raise_if_not_initialized(self) -> None: - if not self._is_initialized: - raise RuntimeError('The Actor was not initialized!') - def _get_default_exit_process(self) -> bool: """Return False for IPython and Scrapy environments, True otherwise.""" if is_running_in_ipython(): diff --git a/src/apify/_charging.py b/src/apify/_charging.py index f7a17b2b..c5ed7826 100644 --- a/src/apify/_charging.py +++ b/src/apify/_charging.py @@ -8,8 +8,6 @@ from pydantic import TypeAdapter -from crawlee._utils.context import ensure_context - from apify._models import ( ActorRun, FlatPricePerMonthActorPricingInfo, @@ -18,7 +16,7 @@ PricePerDatasetItemActorPricingInfo, PricingModel, ) -from apify._utils import docs_group +from apify._utils import docs_group, ensure_context from apify.log import logger from apify.storages import Dataset @@ -31,6 +29,8 @@ run_validator = TypeAdapter[ActorRun | None](ActorRun | None) +_ensure_context = ensure_context('active') + @docs_group('Charging') class ChargingManager(Protocol): @@ -201,7 +201,7 @@ async def __aexit__( self.active = False - @ensure_context + @_ensure_context async def charge(self, event_name: str, count: int = 1) -> ChargeResult: def calculate_chargeable() -> dict[str, int | None]: """Calculate the maximum number of events of each type that can be charged within the current budget.""" @@ -291,14 +291,14 @@ def calculate_chargeable() -> dict[str, int | None]: chargeable_within_limit=calculate_chargeable(), ) - @ensure_context + @_ensure_context def calculate_total_charged_amount(self) -> Decimal: return sum( (item.total_charged_amount for item in self._charging_state.values()), start=Decimal(), ) - @ensure_context + @_ensure_context def calculate_max_event_charge_count_within_limit(self, event_name: str) -> int | None: pricing_info = self._pricing_info.get(event_name) @@ -315,7 +315,7 @@ def calculate_max_event_charge_count_within_limit(self, event_name: str) -> int result = (self._max_total_charge_usd - self.calculate_total_charged_amount()) / price return max(0, math.floor(result)) if result.is_finite() else None - @ensure_context + @_ensure_context def get_pricing_info(self) -> ActorPricingInfo: return ActorPricingInfo( pricing_model=self._pricing_model, @@ -328,12 +328,12 @@ def get_pricing_info(self) -> ActorPricingInfo: }, ) - @ensure_context + @_ensure_context def get_charged_event_count(self, event_name: str) -> int: item = self._charging_state.get(event_name) return item.charge_count if item is not None else 0 - @ensure_context + @_ensure_context def get_max_total_charge_usd(self) -> Decimal: return self._max_total_charge_usd diff --git a/src/apify/_utils.py b/src/apify/_utils.py index 4198bf5d..33cb3d54 100644 --- a/src/apify/_utils.py +++ b/src/apify/_utils.py @@ -1,13 +1,48 @@ from __future__ import annotations import builtins +import inspect import sys +from collections.abc import Callable from enum import Enum +from functools import wraps from importlib import metadata -from typing import TYPE_CHECKING, Any, Literal +from typing import Any, Literal, TypeVar, cast -if TYPE_CHECKING: - from collections.abc import Callable +T = TypeVar('T', bound=Callable[..., Any]) + + +def ensure_context(attribute_name: str) -> Callable[[T], T]: + """Create a decorator that ensures the context manager is initialized before executing the method. + + The decorator checks if the calling instance has the specified attribute and verifies that it is set to `True`. + If the instance is inactive, it raises a `RuntimeError`. Works for both synchronous and asynchronous methods. + + Args: + attribute_name: The name of the boolean attribute to check on the instance. + + Returns: + A decorator that wraps methods with context checking. + """ + + def decorator(method: T) -> T: + @wraps(method) + def sync_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if not getattr(self, attribute_name, False): + raise RuntimeError(f'The {self.__class__.__name__} is not active. Use it within the context.') + + return method(self, *args, **kwargs) + + @wraps(method) + async def async_wrapper(self: Any, *args: Any, **kwargs: Any) -> Any: + if not getattr(self, attribute_name, False): + raise RuntimeError(f'The {self.__class__.__name__} is not active. Use it within the async context.') + + return await method(self, *args, **kwargs) + + return cast('T', async_wrapper if inspect.iscoroutinefunction(method) else sync_wrapper) + + return decorator def get_system_info() -> dict: diff --git a/tests/e2e/conftest.py b/tests/e2e/conftest.py index 4ae56fc6..bdcc9883 100644 --- a/tests/e2e/conftest.py +++ b/tests/e2e/conftest.py @@ -69,7 +69,7 @@ def _prepare_test_env() -> None: if hasattr(apify._actor.Actor, '__wrapped__'): delattr(apify._actor.Actor, '__wrapped__') - apify._actor.Actor._is_initialized = False + apify._actor.Actor._active = False # Set the environment variable for the local storage directory to the temporary path. monkeypatch.setenv(ApifyEnvVars.LOCAL_STORAGE_DIR, str(tmp_path)) diff --git a/tests/e2e/test_actor_lifecycle.py b/tests/e2e/test_actor_lifecycle.py index 47fa06e0..983b8ca3 100644 --- a/tests/e2e/test_actor_lifecycle.py +++ b/tests/e2e/test_actor_lifecycle.py @@ -118,24 +118,24 @@ async def test_actor_sequential_contexts(make_actor: MakeActorFunction, run_acto async def main() -> None: async with Actor as actor: actor._exit_process = False - assert actor._is_initialized is True + assert actor._active is True # Actor after Actor. async with Actor as actor: actor._exit_process = False - assert actor._is_initialized is True + assert actor._active is True # Actor() after Actor. async with Actor(exit_process=False) as actor: - assert actor._is_initialized is True + assert actor._active is True # Actor() after Actor(). async with Actor(exit_process=False) as actor: - assert actor._is_initialized is True + assert actor._active is True # Actor after Actor(). async with Actor as actor: - assert actor._is_initialized is True + assert actor._active is True actor = await make_actor(label='actor-sequential-contexts', main_func=main) run_result = await run_actor(actor) diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index f0fd1e0e..30aa077d 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -60,7 +60,7 @@ def _prepare_test_env() -> None: if hasattr(apify._actor.Actor, '__wrapped__'): delattr(apify._actor.Actor, '__wrapped__') - apify._actor.Actor._is_initialized = False + apify._actor.Actor._active = False # Set the environment variable for the local storage directory to the temporary path. monkeypatch.setenv(ApifyEnvVars.LOCAL_STORAGE_DIR, str(tmp_path)) diff --git a/tests/unit/actor/test_actor_lifecycle.py b/tests/unit/actor/test_actor_lifecycle.py index 09338514..03fdd00e 100644 --- a/tests/unit/actor/test_actor_lifecycle.py +++ b/tests/unit/actor/test_actor_lifecycle.py @@ -68,41 +68,41 @@ async def test_actor_init_instance_manual() -> None: """Test that Actor instance can be properly initialized and cleaned up manually.""" actor = Actor() await actor.init() - assert actor._is_initialized is True + assert actor._active is True await actor.exit() - assert actor._is_initialized is False + assert actor._active is False async def test_actor_init_instance_async_with() -> None: """Test that Actor instance can be properly initialized and cleaned up using async context manager.""" actor = Actor() async with actor: - assert actor._is_initialized is True + assert actor._active is True - assert actor._is_initialized is False + assert actor._active is False async def test_actor_init_class_manual() -> None: """Test that Actor class can be properly initialized and cleaned up manually.""" await Actor.init() - assert Actor._is_initialized is True + assert Actor._active is True await Actor.exit() - assert not Actor._is_initialized + assert not Actor._active async def test_actor_init_class_async_with() -> None: """Test that Actor class can be properly initialized and cleaned up using async context manager.""" async with Actor: - assert Actor._is_initialized is True + assert Actor._active is True - assert not Actor._is_initialized + assert not Actor._active async def test_fail_properly_deinitializes_actor(actor: _ActorType) -> None: """Test that fail() method properly deinitializes the Actor.""" - assert actor._is_initialized + assert actor._active await actor.fail() - assert actor._is_initialized is False + assert actor._active is False async def test_actor_handles_exceptions_and_cleans_up_properly() -> None: @@ -111,16 +111,16 @@ async def test_actor_handles_exceptions_and_cleans_up_properly() -> None: with contextlib.suppress(Exception): async with Actor() as actor: - assert actor._is_initialized + assert actor._active raise Exception('Failed') # noqa: TRY002 assert actor is not None - assert actor._is_initialized is False + assert actor._active is False async def test_double_init_raises_runtime_error(actor: _ActorType) -> None: """Test that attempting to initialize an already initialized Actor raises RuntimeError.""" - assert actor._is_initialized + assert actor._active with pytest.raises(RuntimeError): await actor.init() @@ -196,7 +196,7 @@ def on_event(event_type: Event) -> Callable: actor = Actor() async with actor: - assert actor._is_initialized + assert actor._active actor.on(Event.PERSIST_STATE, on_event(Event.PERSIST_STATE)) actor.on(Event.SYSTEM_INFO, on_event(Event.SYSTEM_INFO)) await asyncio.sleep(1) @@ -249,12 +249,12 @@ async def test_actor_sequential_contexts(*, first_with_call: bool, second_with_c mock = AsyncMock() async with Actor(exit_process=False) if first_with_call else Actor as actor: await mock() - assert actor._is_initialized is True + assert actor._active is True # After exiting the context, new Actor instance can be created without conflicts. async with Actor() if second_with_call else Actor as actor: await mock() - assert actor._is_initialized is True + assert actor._active is True # The mock should have been called twice, once in each context. assert mock.call_count == 2 diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index 3f792ad8..8d8297c5 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -62,7 +62,7 @@ def prepare_test_env(monkeypatch: pytest.MonkeyPatch, tmp_path: Path) -> Callabl def _prepare_test_env() -> None: if hasattr(apify._actor.Actor, '__wrapped__'): delattr(apify._actor.Actor, '__wrapped__') - apify._actor.Actor._is_initialized = False + apify._actor.Actor._active = False # Set the environment variable for the local storage directory to the temporary path. monkeypatch.setenv(ApifyEnvVars.LOCAL_STORAGE_DIR, str(tmp_path))