Skip to content

Commit 0a65d2b

Browse files
committed
Add tracking for Celery canvas primitives (map/starmap/chunks)
Previously, tasks spawned via Celery canvas primitives like map, starmap, and chunks were not tracked because these execute as built-in celery.map or celery.starmap tasks, which were filtered out by the before_task_publish signal handler.

This change adds a _maybe_create_task function, called during task_prerun, that:

- Creates TaskBadger tasks for canvas primitives using the inner task's name (not "celery.map")
- Includes metadata about canvas_type and item_count
- Optionally includes celery_task_items when record_task_args is enabled
- Carefully orders checks to avoid thread-local Badger state pollution
- Respects existing tracking intent via the taskbadger_track header

For example, when using task.chunks(items, 2).group(), each chunk now creates a TaskBadger task named after the user's task, with data showing that it was executed via celery.starmap and how many items it processed.
1 parent 925af21 commit 0a65d2b

File tree

2 files changed

+158
-10
lines changed

2 files changed

+158
-10
lines changed

taskbadger/celery.py

Lines changed: 80 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,8 +212,88 @@ def task_publish_handler(sender=None, headers=None, body=None, **kwargs):
212212
ctask.update_state(task_id=headers["id"], state="PENDING", meta=meta)
213213

214214

215+
def _maybe_create_task(signal_sender):
    """Create a TaskBadger task for ``signal_sender`` if one doesn't exist yet.

    This handles cases where ``before_task_publish`` didn't fire or was skipped:

    - Eager mode (``before_task_publish`` doesn't fire).
    - Canvas primitives like map/starmap/chunks, which run as the built-in
      ``celery.map`` / ``celery.starmap`` tasks.

    The checks are ordered so that ``Badger`` is only touched once the task is
    known to be a candidate for tracking.

    Args:
        signal_sender: The Celery task that sent ``task_prerun``. Its
            ``request`` is read for an existing TaskBadger task ID, the
            headers and the kwargs, and is updated with the new task's ID
            when a task is created.
    """
    request = signal_sender.request

    # Check if the task was already created FIRST (before accessing Badger).
    # This avoids initializing thread-local Badger state for tasks like celery.ping.
    if _get_taskbadger_task_id(request):
        return

    canvas_type = signal_sender.name
    is_canvas_task = canvas_type in ("celery.map", "celery.starmap")

    # Skip built-in celery tasks that we don't track (like celery.ping).
    # Only celery.map and celery.starmap are handled specially.
    if canvas_type.startswith("celery.") and not is_canvas_task:
        return

    # For non-canvas tasks, only create if there was an explicit intent to track
    # (indicated by the taskbadger_track header). This prevents creating tasks when
    # Badger wasn't configured at publish time but has stale config in the worker.
    headers = request.headers or {}
    if not is_canvas_task and not headers.get("taskbadger_track"):
        return

    # NOW it's safe to check Badger configuration.
    if not Badger.is_configured():
        return

    celery_system = Badger.current.settings.get_system_by_id("celery")
    task_name = canvas_type
    inner_task = None
    canvas_kwargs = {}

    if is_canvas_task:
        # Report the canvas under the wrapped task's name, not "celery.map".
        canvas_kwargs = request.kwargs or {}
        inner_task_info = canvas_kwargs.get("task")
        if inner_task_info:
            # inner_task_info can be a dict (serialized signature) or a Signature object.
            if isinstance(inner_task_info, dict):
                task_name = inner_task_info.get("task") or task_name
            elif hasattr(inner_task_info, "name"):
                task_name = inner_task_info.name
            # Look up the registered task so its base class can be checked below.
            inner_task = celery.current_app.tasks.get(task_name)

    # Decide whether to track before doing any work on the canvas items.
    auto_track = celery_system and celery_system.track_task(task_name)
    # The task (or the inner task for map/starmap) may use our Task base class.
    manual_track = isinstance(inner_task or signal_sender, Task)
    if not manual_track and not auto_track:
        return

    data = None
    if is_canvas_task:
        items = canvas_kwargs.get("it") or []
        # Materialize non-sequence iterables so they can be counted and recorded.
        items_list = items if isinstance(items, (list, tuple)) else list(items)
        data = {"canvas_type": canvas_type, "item_count": len(items_list)}

        # Include the task items if record_task_args is enabled.
        if celery_system and celery_system.record_task_args:
            try:
                _, _, value = serialization.dumps({"items": items_list}, serializer="json")
                data["celery_task_items"] = json.loads(value)["items"]
            except Exception:
                # Best effort: the task is still created, just without its items.
                # NOTE(review): exc_info assumes `log` is a stdlib logging.Logger - confirm.
                log.warning("Error serializing canvas items for task '%s'", task_name, exc_info=True)

    enter_session()

    task = create_task_safe(task_name, status=StatusEnum.PENDING, data=data)
    if task:
        # Store the task ID in the request so _update_task can find it.
        request.update({TB_TASK_ID: task.id})
        safe_get_task.cache.set((task.id,), task)
292+
293+
215294
@task_prerun.connect
def task_prerun_handler(sender=None, **kwargs):
    """Handle Celery's ``task_prerun`` signal for ``sender``.

    First gives ``_maybe_create_task`` the chance to create a TaskBadger task
    for ``sender``, then reports the ``PROCESSING`` status through
    ``_update_task``.
    """
    processing = StatusEnum.PROCESSING
    _maybe_create_task(sender)
    _update_task(sender, processing)
218298

219299

tests/test_celery.py

Lines changed: 78 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -393,32 +393,100 @@ def task_signature(self, a):
393393

394394
@pytest.mark.usefixtures("_bind_settings")
def test_task_map(celery_session_worker):
    """Tasks executed via the map canvas primitive should be tracked.

    Note: The individual function calls within map are not separate Celery tasks,
    so we track the map operation itself with the inner task's name.
    """

    @celery.shared_task(bind=True, base=Task)
    def task_map_fn(self, a):
        return a * 2

    celery_session_worker.reload()

    map_canvas = task_map_fn.map(list(range(5)))

    with (
        mock.patch("taskbadger.celery.create_task_safe") as create,
        mock.patch("taskbadger.celery.update_task_safe") as update,
        mock.patch("taskbadger.celery.get_task"),
    ):
        create.return_value = task_for_test()
        result = map_canvas.delay()
        assert result.get(timeout=10, propagate=True) == [0, 2, 4, 6, 8]

    # The map operation should create exactly one TaskBadger task.
    assert create.call_count == 1
    # The task name uses the inner task's name, not "celery.map".
    assert "task_map_fn" in create.call_args.args[0]
    # The canvas metadata records how the task ran and how many items it handled.
    data = create.call_args.kwargs["data"]
    assert data["canvas_type"] == "celery.map"
    assert data["item_count"] == 5
    assert update.call_count == 2  # PROCESSING and SUCCESS
    assert Badger.current.session().client is None
420427

421428

429+
@pytest.mark.usefixtures("_bind_settings")
def test_task_starmap(celery_session_worker):
    """Tasks executed via the starmap canvas primitive should be tracked."""

    @celery.shared_task(bind=True, base=Task)
    def task_starmap_fn(self, a, b):
        return a + b

    celery_session_worker.reload()

    starmap_canvas = task_starmap_fn.starmap([(1, 2), (3, 4), (5, 6)])

    with (
        mock.patch("taskbadger.celery.create_task_safe") as create,
        mock.patch("taskbadger.celery.update_task_safe") as update,
        mock.patch("taskbadger.celery.get_task"),
    ):
        create.return_value = task_for_test()
        result = starmap_canvas.delay()
        assert result.get(timeout=10, propagate=True) == [3, 7, 11]

    # The starmap operation should create exactly one TaskBadger task.
    assert create.call_count == 1
    # The task name uses the inner task's name, not "celery.starmap".
    assert "task_starmap_fn" in create.call_args.args[0]
    # The canvas metadata records how the task ran and how many items it handled.
    data = create.call_args.kwargs["data"]
    assert data["canvas_type"] == "celery.starmap"
    assert data["item_count"] == 3
    assert update.call_count == 2  # PROCESSING and SUCCESS
457+
458+
459+
@pytest.mark.usefixtures("_bind_settings")
def test_task_chunks(celery_session_worker):
    """Tasks executed via the chunks canvas primitive should be tracked."""

    @celery.shared_task(bind=True, base=Task)
    def task_chunks_fn(self, a):
        return a * 2

    celery_session_worker.reload()

    # chunks creates multiple starmap tasks
    chunks_canvas = task_chunks_fn.chunks([(x,) for x in range(6)], 2).group()

    with (
        mock.patch("taskbadger.celery.create_task_safe") as create,
        mock.patch("taskbadger.celery.update_task_safe") as update,
        mock.patch("taskbadger.celery.get_task"),
    ):
        create.return_value = task_for_test()
        result = chunks_canvas.delay()
        assert result.get(timeout=10, propagate=True) == [[0, 2], [4, 6], [8, 10]]

    # Each chunk should create a TaskBadger task (3 chunks of 2).
    assert create.call_count == 3
    for call in create.call_args_list:
        # The task names use the inner task's name, not "celery.starmap".
        assert "task_chunks_fn" in call.args[0]
        # Every chunk is executed as its own starmap over two items.
        assert call.kwargs["data"]["canvas_type"] == "celery.starmap"
        assert call.kwargs["data"]["item_count"] == 2
    assert update.call_count == 6  # 3 tasks * 2 updates each
488+
489+
422490
@pytest.mark.usefixtures("_bind_settings")
423491
def test_celery_task_already_in_terminal_state(celery_session_worker):
424492
@celery.shared_task(bind=True, base=Task)

0 commit comments

Comments
 (0)