From c46223d92d4419d6416e590ad28d49a97a9608c8 Mon Sep 17 00:00:00 2001 From: Saul Cooperman Date: Sun, 22 Feb 2026 17:28:12 +0000 Subject: [PATCH 1/2] Support subinterpreters when reporting Python stacks Iterate over all interpreters in the linked list when determining the PID offset, fixing an issue where pystack only inspected the first interpreter, which is not guaranteed to be the main interpreter. This ensures the correct TID is located even when multiple subinterpreters are present. Add support for subinterpreters in pure Python stack reporting. Threads running in subinterpreters are now detected and grouped by interpreter ID. Native stack reporting for subinterpreters is not yet supported. Signed-off-by: Saul Cooperman --- Dockerfile | 3 + news/279.bugfix.rst | 5 + news/279.feature.rst | 3 + setup.py | 1 + src/pystack/__init__.py | 2 - src/pystack/__main__.py | 10 +- src/pystack/_pystack.pyx | 18 +- src/pystack/_pystack/CMakeLists.txt | 3 +- src/pystack/_pystack/cpython/interpreter.h | 4 +- src/pystack/_pystack/interpreter.cpp | 36 +++ src/pystack/_pystack/interpreter.h | 24 ++ src/pystack/_pystack/interpreter.pxd | 13 + src/pystack/_pystack/process.cpp | 1 + src/pystack/_pystack/pythread.cpp | 133 ++++---- src/pystack/_pystack/version.cpp | 47 ++- src/pystack/_pystack/version.h | 1 + src/pystack/traceback_formatter.py | 33 +- src/pystack/types.py | 1 + tests/integration/test_subinterpreters.py | 349 +++++++++++++++++++++ tests/unit/test_main.py | 204 +++++++----- tests/unit/test_traceback_formatter.py | 274 +++++++++++++++- tests/utils.py | 7 + 22 files changed, 1011 insertions(+), 161 deletions(-) create mode 100644 news/279.bugfix.rst create mode 100644 news/279.feature.rst create mode 100644 src/pystack/_pystack/interpreter.cpp create mode 100644 src/pystack/_pystack/interpreter.h create mode 100644 src/pystack/_pystack/interpreter.pxd create mode 100644 tests/integration/test_subinterpreters.py diff --git a/Dockerfile b/Dockerfile index aa7a050f..96b153c1 100644 --- a/Dockerfile +++ b/Dockerfile @@ -66,6 +66,9 @@ RUN apt-get update \ python3.13-dev \ python3.13-dbg \ python3.13-venv \ + python3.14-dev \ + python3.14-dbg \ + python3.14-venv \ make \ cmake \ gdb \ diff --git a/news/279.bugfix.rst b/news/279.bugfix.rst new file mode 100644 index 00000000..99dfcba3 --- /dev/null +++ b/news/279.bugfix.rst @@ -0,0 +1,5 @@ +Fix an issue where the PID offset could not be determined when multiple +subinterpreters were present. Previously, pystack only checked the first +interpreter in the linked list, which was not guaranteed to be the main +interpreter. The fix now iterates over all interpreters and correctly locates +the TID. diff --git a/news/279.feature.rst b/news/279.feature.rst new file mode 100644 index 00000000..f4b6f28c --- /dev/null +++ b/news/279.feature.rst @@ -0,0 +1,3 @@ +Add support for subinterpreters when reporting pure Python stacks. Threads +running in subinterpreters are now identified and grouped by interpreter ID. +Native stack reporting for subinterpreters is not yet supported. diff --git a/setup.py b/setup.py index d76c9375..58bc40d2 100644 --- a/setup.py +++ b/setup.py @@ -77,6 +77,7 @@ "src/pystack/_pystack.pyx", "src/pystack/_pystack/corefile.cpp", "src/pystack/_pystack/elf_common.cpp", + "src/pystack/_pystack/interpreter.cpp", "src/pystack/_pystack/logging.cpp", "src/pystack/_pystack/mem.cpp", "src/pystack/_pystack/process.cpp", diff --git a/src/pystack/__init__.py b/src/pystack/__init__.py index e973464d..a98c6c19 100644 --- a/src/pystack/__init__.py +++ b/src/pystack/__init__.py @@ -1,7 +1,5 @@ from ._version import __version__ -from .traceback_formatter import print_thread __all__ = [ "__version__", - "print_thread", ] diff --git a/src/pystack/__main__.py b/src/pystack/__main__.py index 55102c91..5744d7f7 100644 --- a/src/pystack/__main__.py +++ b/src/pystack/__main__.py @@ -19,13 +19,13 @@ from pystack.process import is_gzip from . import errors -from . import print_thread from .colors import colored from .engine import CoreFileAnalyzer from .engine import NativeReportingMode from .engine import StackMethod from .engine import get_process_threads from .engine import get_process_threads_for_core +from .traceback_formatter import TracebackPrinter PERMISSION_ERROR_MSG = "Operation not permitted" NO_SUCH_PROCESS_ERROR_MSG = "No such process" @@ -285,6 +285,9 @@ def process_remote(parser: argparse.ArgumentParser, args: argparse.Namespace) -> if not args.block and args.native_mode != NativeReportingMode.OFF: parser.error("Native traces are only available in blocking mode") + printer = TracebackPrinter( + native_mode=args.native_mode, include_subinterpreters=True + ) for thread in get_process_threads( args.pid, stop_process=args.block, @@ -292,7 +295,7 @@ def process_remote(parser: argparse.ArgumentParser, args: argparse.Namespace) -> locals=args.locals, method=StackMethod.ALL if args.exhaustive else StackMethod.AUTO, ): - print_thread(thread, args.native_mode) + printer.print_thread(thread) def format_psinfo_information(psinfo: Dict[str, Any]) -> str: @@ -414,6 +417,7 @@ def process_core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> N elf_id if elf_id else "", ) + printer = TracebackPrinter(args.native_mode, include_subinterpreters=True) for thread in get_process_threads_for_core( corefile, executable, @@ -422,7 +426,7 @@ def process_core(parser: argparse.ArgumentParser, args: argparse.Namespace) -> N locals=args.locals, method=StackMethod.ALL if args.exhaustive else StackMethod.AUTO, ): - print_thread(thread, args.native_mode) + printer.print_thread(thread) if __name__ == "__main__": # pragma: no cover diff --git a/src/pystack/_pystack.pyx b/src/pystack/_pystack.pyx index de16701d..d9a0488b 100644 --- a/src/pystack/_pystack.pyx +++ b/src/pystack/_pystack.pyx @@ -22,6 +22,7 @@ from _pystack.elf_common cimport CoreFileAnalyzer as NativeCoreFileAnalyzer from _pystack.elf_common cimport ProcessAnalyzer as NativeProcessAnalyzer from _pystack.elf_common cimport SectionInfo from _pystack.elf_common cimport getSectionInfo +from _pystack.interpreter cimport InterpreterUtils from _pystack.logging cimport initializePythonLoggerInterface from _pystack.mem cimport AbstractRemoteMemoryManager from _pystack.mem cimport MemoryMapInformation as CppMemoryMapInformation @@ -462,6 +463,7 @@ cdef object _construct_threads_from_interpreter_state( bint add_native_traces, bint resolve_locals, ): + interpreter_id = InterpreterUtils.getInterpreterId(manager, head) LOGGER.info("Fetching Python threads") threads = [] @@ -486,6 +488,7 @@ cdef object _construct_threads_from_interpreter_state( current_thread.isGilHolder(), current_thread.isGCCollecting(), python_version, + interpreter_id, name=get_thread_name(pid, current_thread.Tid()), ) ) @@ -622,7 +625,7 @@ def _get_process_threads( ) all_tids = list(manager.get().Tids()) - if head: + while head: add_native_traces = native_mode != NativeReportingMode.OFF for thread in _construct_threads_from_interpreter_state( manager, @@ -635,6 +638,7 @@ def _get_process_threads( if thread.tid in all_tids: all_tids.remove(thread.tid) yield thread + head = InterpreterUtils.getNextInterpreter(manager, head) if native_mode == NativeReportingMode.ALL: yield from _construct_os_threads(manager, pid, all_tids) @@ -769,14 +773,20 @@ def _get_process_threads_for_core( all_tids = list(manager.get().Tids()) - if head: - native = native_mode in {NativeReportingMode.PYTHON, NativeReportingMode.ALL} + while head: + add_native_traces = native_mode != NativeReportingMode.OFF for thread in _construct_threads_from_interpreter_state( - manager, head, pymanager.pid, pymanager.python_version, native, locals + manager, + head, + pymanager.pid, + pymanager.python_version, + add_native_traces, + locals, ): if thread.tid in all_tids: all_tids.remove(thread.tid) yield thread + head = InterpreterUtils.getNextInterpreter(manager, head) if native_mode == NativeReportingMode.ALL: yield from _construct_os_threads(manager, pymanager.pid, all_tids) diff --git a/src/pystack/_pystack/CMakeLists.txt b/src/pystack/_pystack/CMakeLists.txt index 5a0fd8a8..74183d11 100644 --- a/src/pystack/_pystack/CMakeLists.txt +++ b/src/pystack/_pystack/CMakeLists.txt @@ -21,6 +21,7 @@ add_library(_pystack STATIC pythread.cpp version.cpp elf_common.cpp - pytypes.cpp) + pytypes.cpp + interpreter.cpp) set_property(TARGET _pystack PROPERTY POSITION_INDEPENDENT_CODE ON) include_directories("." "cpython" ${PYTHON_INCLUDE_DIRS}) diff --git a/src/pystack/_pystack/cpython/interpreter.h b/src/pystack/_pystack/cpython/interpreter.h index b35b26ba..d75558fe 100644 --- a/src/pystack/_pystack/cpython/interpreter.h +++ b/src/pystack/_pystack/cpython/interpreter.h @@ -375,10 +375,10 @@ struct _gil_runtime_state int locked; unsigned long switch_number; pthread_cond_t cond; - pthread_cond_t mutex; + pthread_mutex_t mutex; #ifdef FORCE_SWITCHING pthread_cond_t switch_cond; - pthread_cond_t switch_mutex; + pthread_mutex_t switch_mutex; #endif }; diff --git a/src/pystack/_pystack/interpreter.cpp b/src/pystack/_pystack/interpreter.cpp new file mode 100644 index 00000000..4f52e043 --- /dev/null +++ b/src/pystack/_pystack/interpreter.cpp @@ -0,0 +1,36 @@ +#include + +#include "interpreter.h" +#include "logging.h" +#include "process.h" +#include "structure.h" +#include "version.h" + +namespace pystack { + +remote_addr_t +InterpreterUtils::getNextInterpreter( + const std::shared_ptr& manager, + remote_addr_t interpreter_addr) +{ + Structure is(manager, interpreter_addr); + return is.getField(&py_is_v::o_next); +} + +int64_t +InterpreterUtils::getInterpreterId( + const std::shared_ptr& manager, + remote_addr_t interpreter_addr) +{ + if (!manager->versionIsAtLeast(3, 7)) { + // No support for subinterpreters so the only interpreter is ID 0. + return 0; + } + + Structure is(manager, interpreter_addr); + int64_t id_value = is.getField(&py_is_v::o_id); + + return id_value; +} + +} // namespace pystack diff --git a/src/pystack/_pystack/interpreter.h b/src/pystack/_pystack/interpreter.h new file mode 100644 index 00000000..0138ff84 --- /dev/null +++ b/src/pystack/_pystack/interpreter.h @@ -0,0 +1,24 @@ +#pragma once + +#include +#include + +#include "mem.h" +#include "process.h" + +namespace pystack { + +class InterpreterUtils +{ + public: + // Static Methods + static remote_addr_t getNextInterpreter( + const std::shared_ptr& manager, + remote_addr_t interpreter_addr); + + static int64_t getInterpreterId( + const std::shared_ptr& manager, + remote_addr_t interpreter_addr); +}; + +} // namespace pystack diff --git a/src/pystack/_pystack/interpreter.pxd b/src/pystack/_pystack/interpreter.pxd new file mode 100644 index 00000000..0248f468 --- /dev/null +++ b/src/pystack/_pystack/interpreter.pxd @@ -0,0 +1,13 @@ +from _pystack.mem cimport remote_addr_t +from _pystack.process cimport AbstractProcessManager +from libc.stdint cimport int64_t +from libcpp.memory cimport shared_ptr + + +cdef extern from "interpreter.h" namespace "pystack": + cdef cppclass InterpreterUtils: + @staticmethod + remote_addr_t getNextInterpreter(shared_ptr[AbstractProcessManager] manager, remote_addr_t interpreter_addr) except + + + @staticmethod + int64_t getInterpreterId(shared_ptr[AbstractProcessManager] manager, remote_addr_t interpreter_addr) except + diff --git a/src/pystack/_pystack/process.cpp b/src/pystack/_pystack/process.cpp index 56e32bae..761509a0 100644 --- a/src/pystack/_pystack/process.cpp +++ b/src/pystack/_pystack/process.cpp @@ -964,6 +964,7 @@ AbstractProcessManager::copyDebugOffsets(Structure& py_runtime, py set_offset(py_is.o_sysdict, &py_runtime_v::o_dbg_off_interpreter_state_sysdict); set_offset(py_is.o_builtins, &py_runtime_v::o_dbg_off_interpreter_state_builtins); set_offset(py_is.o_gil_runtime_state, &py_runtime_v::o_dbg_off_interpreter_state_ceval_gil); + set_offset(py_is.o_id, &py_runtime_v::o_dbg_off_interpreter_state_id); set_size(py_thread, &py_runtime_v::o_dbg_off_thread_state_struct_size); set_offset(py_thread.o_prev, &py_runtime_v::o_dbg_off_thread_state_prev); diff --git a/src/pystack/_pystack/pythread.cpp b/src/pystack/_pystack/pythread.cpp index d50e4126..b6b90a24 100644 --- a/src/pystack/_pystack/pythread.cpp +++ b/src/pystack/_pystack/pythread.cpp @@ -2,6 +2,8 @@ #include #include +#include "cpython/pthread.h" +#include "interpreter.h" #include "logging.h" #include "mem.h" #include "native_frame.h" @@ -11,8 +13,6 @@ #include "structure.h" #include "version.h" -#include "cpython/pthread.h" - namespace pystack { Thread::Thread(pid_t pid, pid_t tid) @@ -47,79 +47,88 @@ findPthreadTidOffset( remote_addr_t interp_state_addr) { LOG(DEBUG) << "Attempting to locate tid offset in pthread structure"; - Structure is(manager, interp_state_addr); - auto current_thread_addr = is.getField(&py_is_v::o_tstate_head); + // If interp_state_addr does not point to the main interpreter (id 0) we won't find the + // PID == TID in the interpreter threads. Hence, we traverse the linked list of interpreters. The + // main interpreter is not necessarily the head of the linked lists of interpreters. + + while (interp_state_addr != 0) { + Structure is(manager, interp_state_addr); + + auto current_thread_addr = is.getField(&py_is_v::o_tstate_head); - auto thread_head = current_thread_addr; + auto thread_head = current_thread_addr; - // Iterate over all Python threads until we find a thread that has a tid equal to - // the process pid. This works because in the main thread the tid is equal to the pid, - // so when this happens it has to happen on the main thread. Note that the main thread - // is not necessarily at the head of the Python thread linked list + // Iterate over all Python threads until we find a thread that has a tid equal to + // the process pid. This works because in the main thread the tid is equal to the pid, + // so when this happens it has to happen on the main thread. Note that the main thread + // is not necessarily at the head of the Python thread linked list #if defined(__GLIBC__) - // If we detect GLIBC, we can try the two main known structs for 'struct - // pthread' that we know about to avoid having to do guess-work by doing a - // linear scan over the struct. - while (current_thread_addr != (remote_addr_t) nullptr) { - Structure current_thread(manager, current_thread_addr); - auto pthread_id_addr = current_thread.getField(&py_thread_v::o_thread_id); - - pid_t the_tid; - std::vector glibc_pthread_offset_candidates = { - offsetof(_pthread_structure_with_simple_header, tid), - offsetof(_pthread_structure_with_tcbhead, tid)}; - for (off_t candidate : glibc_pthread_offset_candidates) { - manager->copyObjectFromProcess((remote_addr_t)(pthread_id_addr + candidate), &the_tid); - if (the_tid == manager->Pid()) { - LOG(DEBUG) << "Tid offset located using GLIBC offsets at offset " << std::showbase - << std::hex << candidate << " in pthread structure"; - return candidate; + // If we detect GLIBC, we can try the two main known structs for 'struct + // pthread' that we know about to avoid having to do guess-work by doing a + // linear scan over the struct. + while (current_thread_addr != (remote_addr_t) nullptr) { + Structure current_thread(manager, current_thread_addr); + auto pthread_id_addr = current_thread.getField(&py_thread_v::o_thread_id); + + pid_t the_tid; + std::vector glibc_pthread_offset_candidates = { + offsetof(_pthread_structure_with_simple_header, tid), + offsetof(_pthread_structure_with_tcbhead, tid)}; + for (off_t candidate : glibc_pthread_offset_candidates) { + manager->copyObjectFromProcess((remote_addr_t)(pthread_id_addr + candidate), &the_tid); + if (the_tid == manager->Pid()) { + LOG(DEBUG) << "Tid offset located using GLIBC offsets at offset " << std::showbase + << std::hex << candidate << " in pthread structure"; + return candidate; + } } + remote_addr_t next_thread_addr = current_thread.getField(&py_thread_v::o_next); + if (next_thread_addr == current_thread_addr) { + break; + } + current_thread_addr = next_thread_addr; } - remote_addr_t next_thread_addr = current_thread.getField(&py_thread_v::o_next); - if (next_thread_addr == current_thread_addr) { - break; - } - current_thread_addr = next_thread_addr; - } #endif - current_thread_addr = thread_head; - - while (current_thread_addr != (remote_addr_t) nullptr) { - Structure current_thread(manager, current_thread_addr); - auto pthread_id_addr = current_thread.getField(&py_thread_v::o_thread_id); - - // Attempt to locate a field in the pthread struct that's equal to the pid. - uintptr_t buffer[100]; - size_t buffer_size = sizeof(buffer); - while (buffer_size > 0) { - try { - LOG(DEBUG) << "Trying to copy a buffer of " << buffer_size << " bytes to get pthread ID"; - manager->copyMemoryFromProcess(pthread_id_addr, buffer_size, &buffer); - break; - } catch (const RemoteMemCopyError& ex) { - LOG(DEBUG) << "Failed to copy buffer to get pthread ID"; - buffer_size /= 2; + current_thread_addr = thread_head; + + while (current_thread_addr != (remote_addr_t) nullptr) { + Structure current_thread(manager, current_thread_addr); + auto pthread_id_addr = current_thread.getField(&py_thread_v::o_thread_id); + + // Attempt to locate a field in the pthread struct that's equal to the pid. + uintptr_t buffer[100]; + size_t buffer_size = sizeof(buffer); + while (buffer_size > 0) { + try { + LOG(DEBUG) << "Trying to copy a buffer of " << buffer_size + << " bytes to get pthread ID"; + manager->copyMemoryFromProcess(pthread_id_addr, buffer_size, &buffer); + break; + } catch (const RemoteMemCopyError& ex) { + LOG(DEBUG) << "Failed to copy buffer to get pthread ID"; + buffer_size /= 2; + } } - } - LOG(DEBUG) << "Copied a buffer of " << buffer_size << " bytes to get pthread ID"; - for (size_t i = 0; i < buffer_size / sizeof(uintptr_t); i++) { - if (static_cast(buffer[i]) == manager->Pid()) { - off_t offset = sizeof(uintptr_t) * i; - LOG(DEBUG) << "Tid offset located by scanning at offset " << std::showbase << std::hex - << offset << " in pthread structure"; - return offset; + LOG(DEBUG) << "Copied a buffer of " << buffer_size << " bytes to get pthread ID"; + for (size_t i = 0; i < buffer_size / sizeof(uintptr_t); i++) { + if (static_cast(buffer[i]) == manager->Pid()) { + off_t offset = sizeof(uintptr_t) * i; + LOG(DEBUG) << "Tid offset located by scanning at offset " << std::showbase + << std::hex << offset << " in pthread structure"; + return offset; + } } - } - remote_addr_t next_thread_addr = current_thread.getField(&py_thread_v::o_next); - if (next_thread_addr == current_thread_addr) { - break; + remote_addr_t next_thread_addr = current_thread.getField(&py_thread_v::o_next); + if (next_thread_addr == current_thread_addr) { + break; + } + current_thread_addr = next_thread_addr; } - current_thread_addr = next_thread_addr; + interp_state_addr = InterpreterUtils::getNextInterpreter(manager, interp_state_addr); } LOG(ERROR) << "Could not find tid offset in pthread structure"; return 0; diff --git a/src/pystack/_pystack/version.cpp b/src/pystack/_pystack/version.cpp index f58ff878..a31ad1fc 100644 --- a/src/pystack/_pystack/version.cpp +++ b/src/pystack/_pystack/version.cpp @@ -179,6 +179,23 @@ py_is() }; } +template +constexpr py_is_v +py_isv7() +{ + return { + sizeof(T), + {offsetof(T, next)}, + {offsetof(T, tstate_head)}, + {offsetof(T, gc)}, + {offsetof(T, modules)}, + {offsetof(T, sysdict)}, + {offsetof(T, builtins)}, + {0}, + {offsetof(T, id)}, + }; +} + template constexpr py_is_v py_isv311() @@ -191,6 +208,8 @@ py_isv311() {offsetof(T, modules)}, {offsetof(T, sysdict)}, {offsetof(T, builtins)}, + {0}, + {offsetof(T, id)}, }; } @@ -207,6 +226,24 @@ py_isv312() {offsetof(T, sysdict)}, {offsetof(T, builtins)}, {offsetof(T, ceval.gil)}, + {offsetof(T, id)}, + }; +} + +template +constexpr py_is_v +py_isv314() +{ + return { + sizeof(T), + {offsetof(T, next)}, + {offsetof(T, threads.head)}, + {offsetof(T, gc)}, + {offsetof(T, imports.modules)}, + {offsetof(T, sysdict)}, + {offsetof(T, builtins)}, + {offsetof(T, _gil)}, + {offsetof(T, id)}, }; } @@ -578,7 +615,7 @@ python_v python_v3_7 = { py_code(), py_frame(), py_thread(), - py_is(), + py_isv7(), py_runtime(), py_gc(), }; @@ -600,7 +637,7 @@ python_v python_v3_8 = { py_code(), py_frame(), py_thread(), - py_is(), + py_isv7(), py_runtime(), py_gc(), }; @@ -622,7 +659,7 @@ python_v python_v3_9 = { py_code(), py_frame(), py_thread(), - py_is(), + py_isv7(), py_runtime(), py_gc(), }; @@ -644,7 +681,7 @@ python_v python_v3_10 = { py_code(), py_frame(), py_thread(), - py_is(), + py_isv7(), py_runtime(), py_gc(), }; @@ -737,7 +774,7 @@ python_v python_v3_14 = { py_codev311(), py_framev314(), py_threadv313(), - py_isv312(), + py_isv314(), py_runtimev313(), py_gc(), py_cframe(), diff --git a/src/pystack/_pystack/version.h b/src/pystack/_pystack/version.h index c56851ac..d9b2b2de 100644 --- a/src/pystack/_pystack/version.h +++ b/src/pystack/_pystack/version.h @@ -241,6 +241,7 @@ struct py_is_v FieldOffset o_sysdict; FieldOffset o_builtins; FieldOffset o_gil_runtime_state; + FieldOffset o_id; }; struct py_gc_v diff --git a/src/pystack/traceback_formatter.py b/src/pystack/traceback_formatter.py index 35637b97..9d9e303b 100644 --- a/src/pystack/traceback_formatter.py +++ b/src/pystack/traceback_formatter.py @@ -12,9 +12,36 @@ from .types import frame_type -def print_thread(thread: PyThread, native_mode: NativeReportingMode) -> None: - for line in format_thread(thread, native_mode): - print(line, file=sys.stdout, flush=True) +class TracebackPrinter: + def __init__( + self, native_mode: NativeReportingMode, include_subinterpreters: bool = False + ): + self.native_mode = native_mode + self.include_subinterpreters = include_subinterpreters + self._current_interpreter_id = -1 + + def print_thread(self, thread: PyThread) -> None: + # Print interpreter header if we've switched interpreters + if self.include_subinterpreters: + if thread.interpreter_id != self._current_interpreter_id: + self._print_interpreter_header(thread.interpreter_id) + self._current_interpreter_id = ( + thread.interpreter_id if thread.interpreter_id is not None else -1 + ) + + # Print the thread with indentation + for line in format_thread(thread, self.native_mode): + if self.include_subinterpreters: + print(" " * 2, end="") + print(line, file=sys.stdout, flush=True) + + def _print_interpreter_header(self, interpreter_id: Optional[int]) -> None: + header = ( + f"Interpreter-{interpreter_id if interpreter_id is not None else 'Unknown'}" + ) + if interpreter_id == 0: + header += " (main)" + print(header, file=sys.stdout, flush=True) def format_frame(frame: PyFrame) -> Iterable[str]: diff --git a/src/pystack/types.py b/src/pystack/types.py index fbd1eb13..5eb77ced 100644 --- a/src/pystack/types.py +++ b/src/pystack/types.py @@ -115,6 +115,7 @@ class PyThread: holds_the_gil: int is_gc_collecting: int python_version: Optional[Tuple[int, int]] + interpreter_id: Optional[int] = None name: Optional[str] = None @property diff --git a/tests/integration/test_subinterpreters.py b/tests/integration/test_subinterpreters.py new file mode 100644 index 00000000..91f08013 --- /dev/null +++ b/tests/integration/test_subinterpreters.py @@ -0,0 +1,349 @@ +import io +from collections import Counter +from contextlib import redirect_stdout +from pathlib import Path +from typing import Set + +import pytest + +from pystack.engine import NativeReportingMode +from pystack.engine import StackMethod +from pystack.engine import get_process_threads +from pystack.engine import get_process_threads_for_core +from pystack.traceback_formatter import TracebackPrinter +from pystack.types import NativeFrame +from pystack.types import frame_type +from tests.utils import ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +from tests.utils import generate_core_file +from tests.utils import spawn_child_process + +NUM_INTERPRETERS = 3 +NUM_INTERPRETERS_WITH_THREADS = 2 +NUM_THREADS_PER_SUBINTERPRETER = 2 + +PROGRAM = f"""\ +import sys +import threading +import time + +from concurrent import interpreters + +NUM_INTERPRETERS = {NUM_INTERPRETERS} + + +def start_interpreter_async(interp, code): + t = threading.Thread(target=interp.exec, args=(code,)) + t.daemon = True + t.start() + return t + + +CODE = '''\\ +import time +while True: + time.sleep(1) +''' + +threads = [] +for _ in range(NUM_INTERPRETERS): + interp = interpreters.create() + t = start_interpreter_async(interp, CODE) + threads.append(t) + +# Give sub-interpreters time to start executing +time.sleep(1) + +fifo = sys.argv[1] +with open(fifo, "w") as f: + f.write("ready") + +while True: + time.sleep(1) +""" + + +PROGRAM_WITH_THREADS = f"""\ +import sys +import threading +import time + +from concurrent import interpreters + +NUM_INTERPRETERS = {NUM_INTERPRETERS_WITH_THREADS} + + +def start_interpreter_async(interp, code): + t = threading.Thread(target=interp.exec, args=(code,)) + t.daemon = True + t.start() + return t + + +CODE = '''\\ +import threading +import time + +NUM_THREADS = {NUM_THREADS_PER_SUBINTERPRETER} + +def worker(): + while True: + time.sleep(1) + +threads = [] +for _ in range(NUM_THREADS): + t = threading.Thread(target=worker) + # daemon threads are disabled in isolated subinterpreters + t.start() + threads.append(t) + +while True: + time.sleep(1) +''' + +threads = [] +for _ in range(NUM_INTERPRETERS): + interp = interpreters.create() + t = start_interpreter_async(interp, CODE) + threads.append(t) + +# Give sub-interpreters and their internal workers time to start. +time.sleep(2) + +fifo = sys.argv[1] +with open(fifo, "w") as f: + f.write("ready") + +while True: + time.sleep(1) +""" + + +def _collect_threads( + python_executable: Path, + tmpdir: Path, + native_mode: NativeReportingMode = NativeReportingMode.OFF, +): + test_file = Path(str(tmpdir)) / "subinterpreters_program.py" + test_file.write_text(PROGRAM) + + with spawn_child_process( + str(python_executable), str(test_file), tmpdir + ) as child_process: + return list( + get_process_threads( + child_process.pid, + stop_process=True, + native_mode=native_mode, + ) + ) + + +def _assert_interpreter_headers( + threads, + native_mode: NativeReportingMode, + interpreter_ids, +) -> str: + printer = TracebackPrinter( + native_mode=native_mode, + include_subinterpreters=True, + ) + output = io.StringIO() + with redirect_stdout(output): + for thread in threads: + printer.print_thread(thread) + + result = output.getvalue() + assert "Interpreter-0 (main)" in result + for interpreter_id in interpreter_ids: + if interpreter_id == 0: + continue + assert f"Interpreter-{interpreter_id}" in result + return result + + +def _count_threads_by_interpreter(threads): + return dict( + Counter( + thread.interpreter_id + for thread in threads + if thread.interpreter_id is not None + ) + ) + + +def _interpreter_ids(threads) -> Set[int]: + return { + thread.interpreter_id for thread in threads if thread.interpreter_id is not None + } + + +def _assert_subinterpreter_coverage(threads) -> Set[int]: + interpreter_ids = _interpreter_ids(threads) + assert 0 in interpreter_ids + assert len(interpreter_ids) == NUM_INTERPRETERS + 1 + return interpreter_ids + + +def _assert_native_eval_symbols(threads) -> None: + eval_frames = [ + frame + for thread in threads + for frame in thread.native_frames + if frame_type(frame, thread.python_version) == NativeFrame.FrameType.EVAL + ] + assert eval_frames + assert all("?" not in frame.symbol for frame in eval_frames) + if any(frame.linenumber == 0 for frame in eval_frames): # pragma: no cover + assert all(frame.linenumber == 0 for frame in eval_frames) + assert all(frame.path == "???" for frame in eval_frames) + else: # pragma: no cover + assert all(frame.linenumber != 0 for frame in eval_frames) + assert any(frame.path and "?" not in frame.path for frame in eval_frames) + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +def test_subinterpreters(python, tmpdir): + _, python_executable = python + + threads = _collect_threads( + python_executable=python_executable, + tmpdir=tmpdir, + native_mode=NativeReportingMode.OFF, + ) + + interpreter_ids = _assert_subinterpreter_coverage(threads) + assert all(not thread.native_frames for thread in threads) + _assert_interpreter_headers( + threads=threads, + native_mode=NativeReportingMode.OFF, + interpreter_ids=interpreter_ids, + ) + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +@pytest.mark.parametrize( + "native_mode", + [ + NativeReportingMode.PYTHON, + NativeReportingMode.LAST, + NativeReportingMode.ALL, + ], + ids=["python", "last", "all"], +) +def test_subinterpreters_with_native(python, tmpdir, native_mode): + _, python_executable = python + + threads = _collect_threads( + python_executable=python_executable, + tmpdir=tmpdir, + native_mode=native_mode, + ) + + interpreter_ids = _assert_subinterpreter_coverage(threads) + assert any(thread.native_frames for thread in threads) + _assert_native_eval_symbols(threads) + + output = _assert_interpreter_headers( + threads=threads, + native_mode=native_mode, + interpreter_ids=interpreter_ids, + ) + assert "(C)" in output or "Unable to merge native stack" in output + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +def test_subinterpreters_many_threads_with_native(python, tmpdir): + _, python_executable = python + + test_file = Path(str(tmpdir)) / "subinterpreters_with_threads_program.py" + test_file.write_text(PROGRAM_WITH_THREADS) + + with spawn_child_process(python_executable, test_file, tmpdir) as child_process: + threads = list( + get_process_threads( + child_process.pid, + stop_process=True, + native_mode=NativeReportingMode.PYTHON, + method=StackMethod.DEBUG_OFFSETS, + ) + ) + + interpreter_ids = _interpreter_ids(threads) + assert 0 in interpreter_ids + assert len(interpreter_ids) == NUM_INTERPRETERS_WITH_THREADS + 1 + + counts_by_interpreter = _count_threads_by_interpreter(threads) + assert all( + counts_by_interpreter.get(interpreter_id, 0) >= 1 + for interpreter_id in interpreter_ids + ) + # At least one sub-interpreter should expose multiple Python threads. + assert any( + count > 1 + for interpreter_id, count in counts_by_interpreter.items() + if interpreter_id != 0 + ) + + assert any(thread.native_frames for thread in threads) + _assert_native_eval_symbols(threads) + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +def test_subinterpreters_for_core(python, tmpdir): + _, python_executable = python + + test_file = Path(str(tmpdir)) / "subinterpreters_program.py" + test_file.write_text(PROGRAM) + + with generate_core_file(python_executable, test_file, tmpdir) as core_file: + threads = list( + get_process_threads_for_core( + core_file, + python_executable, + native_mode=NativeReportingMode.OFF, + ) + ) + + interpreter_ids = _assert_subinterpreter_coverage(threads) + assert all(not thread.native_frames for thread in threads) + _assert_interpreter_headers( + threads=threads, + native_mode=NativeReportingMode.OFF, + interpreter_ids=interpreter_ids, + ) + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +@pytest.mark.parametrize( + "native_mode", + [ + NativeReportingMode.PYTHON, + NativeReportingMode.LAST, + NativeReportingMode.ALL, + ], + ids=["python", "last", "all"], +) +def test_subinterpreters_for_core_with_native(python, tmpdir, native_mode): + _, python_executable = python + + test_file = Path(str(tmpdir)) / "subinterpreters_program.py" + test_file.write_text(PROGRAM) + + with generate_core_file(python_executable, test_file, tmpdir) as core_file: + threads = list( + get_process_threads_for_core( + core_file, + python_executable, + native_mode=native_mode, + ) + ) + + interpreter_ids = _assert_subinterpreter_coverage(threads) + assert any(thread.native_frames for thread in threads) + _assert_native_eval_symbols(threads) + output = _assert_interpreter_headers( + threads=threads, + native_mode=native_mode, + interpreter_ids=interpreter_ids, + ) + assert "(C)" in output or "Unable to merge native stack" in output diff --git a/tests/unit/test_main.py b/tests/unit/test_main.py index 2401d554..6f591dd9 100644 --- a/tests/unit/test_main.py +++ b/tests/unit/test_main.py @@ -191,8 +191,8 @@ def test_process_remote_default(): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -207,8 +207,11 @@ def test_process_remote_default(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + native_mode=NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -224,8 +227,8 @@ def test_process_remote_no_block(): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -240,8 +243,11 @@ def test_process_remote_no_block(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + native_mode=NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -265,8 +271,8 @@ def test_process_remote_native(argument, mode): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -281,7 +287,12 @@ def test_process_remote_native(argument, mode): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [call(thread, mode) for thread in threads] + TracebackPrinterMock.assert_called_once_with( + native_mode=mode, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads + ] def test_process_remote_locals(): @@ -296,8 +307,8 @@ def test_process_remote_locals(): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -312,8 +323,11 @@ def test_process_remote_locals(): locals=True, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + native_mode=NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -329,8 +343,8 @@ def test_process_remote_native_no_block(capsys): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -340,7 +354,8 @@ def test_process_remote_native_no_block(capsys): main() get_process_threads_mock.assert_not_called() - print_thread_mock.assert_not_called() + TracebackPrinterMock.assert_not_called() + TracebackPrinterMock.return_value.print_thread.assert_not_called() def test_process_remote_exhaustive(): @@ -355,8 +370,8 @@ def test_process_remote_exhaustive(): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ): get_process_threads_mock.return_value = threads @@ -371,8 +386,11 @@ def test_process_remote_exhaustive(): locals=False, method=StackMethod.ALL, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + native_mode=NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -389,8 +407,8 @@ def test_process_remote_error(exception, exval, capsys): with patch( "pystack.__main__.get_process_threads" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -403,7 +421,10 @@ def test_process_remote_error(exception, exval, capsys): # THEN get_process_threads_mock.assert_called_once() - print_thread_mock.assert_not_called() + TracebackPrinterMock.assert_called_once_with( + native_mode=NativeReportingMode.OFF, include_subinterpreters=True + ) + TracebackPrinterMock.return_value.print_thread.assert_not_called() capture = capsys.readouterr() assert "Oh no!" in capture.err @@ -420,8 +441,8 @@ def test_process_core_default_without_executable(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -448,8 +469,11 @@ def test_process_core_default_without_executable(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -471,8 +495,8 @@ def test_process_core_default_gzip_without_executable(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -504,8 +528,11 @@ def test_process_core_default_gzip_without_executable(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] gzip_open_mock.assert_called_with(Path("corefile.gz"), "rb") @@ -575,8 +602,8 @@ def test_process_core_default_with_executable(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -600,8 +627,11 @@ def test_process_core_default_with_executable(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -625,8 +655,8 @@ def test_process_core_native(argument, mode): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -650,7 +680,10 @@ def test_process_core_native(argument, mode): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [call(thread, mode) for thread in threads] + TracebackPrinterMock.assert_called_once_with(mode, include_subinterpreters=True) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads + ] def test_process_core_locals(): @@ -665,8 +698,8 @@ def test_process_core_locals(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -690,8 +723,11 @@ def test_process_core_locals(): locals=True, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -714,8 +750,8 @@ def test_process_core_with_search_path(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -739,8 +775,11 @@ def test_process_core_with_search_path(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -756,8 +795,8 @@ def test_process_core_with_search_root(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -789,8 +828,11 @@ def test_process_core_with_search_root(): locals=False, method=StackMethod.AUTO, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -802,7 +844,7 @@ def test_process_core_with_not_readable_search_root(): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("pathlib.Path.exists", return_value=True), patch( "pystack.__main__.CoreFileAnalyzer" ), patch( @@ -826,7 +868,7 @@ def test_process_core_with_invalid_search_root(): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("pathlib.Path.exists", return_value=True), patch( "pystack.__main__.CoreFileAnalyzer" ), patch( @@ -851,8 +893,8 @@ def path_exists(what): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch.object( Path, "exists", path_exists @@ -865,7 +907,8 @@ def path_exists(what): # THEN get_process_threads_mock.assert_not_called() - print_thread_mock.assert_not_called() + TracebackPrinterMock.assert_not_called() + TracebackPrinterMock.return_value.print_thread.assert_not_called() def test_process_core_executable_does_not_exit(): @@ -883,8 +926,8 @@ def does_exit(what): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "pystack.__main__.is_gzip", return_value=False ), patch( "sys.argv", argv @@ -898,7 +941,8 @@ def does_exit(what): # THEN get_process_threads_mock.assert_not_called() - print_thread_mock.assert_not_called() + TracebackPrinterMock.assert_not_called() + TracebackPrinterMock.return_value.print_thread.assert_not_called() @pytest.mark.parametrize( @@ -914,8 +958,8 @@ def test_process_core_error(exception, exval, capsys): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -935,7 +979,10 @@ def test_process_core_error(exception, exval, capsys): # THEN get_process_threads_mock.assert_called_once() - print_thread_mock.assert_not_called() + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + TracebackPrinterMock.return_value.print_thread.assert_not_called() capture = capsys.readouterr() assert "Oh no!" in capture.err @@ -951,8 +998,8 @@ def test_process_core_exhaustive(): with patch( "pystack.__main__.get_process_threads_for_core" ) as get_process_threads_mock, patch( - "pystack.__main__.print_thread" - ) as print_thread_mock, patch( + "pystack.__main__.TracebackPrinter" + ) as TracebackPrinterMock, patch( "sys.argv", argv ), patch( "pathlib.Path.exists", return_value=True @@ -976,8 +1023,11 @@ def test_process_core_exhaustive(): locals=False, method=StackMethod.ALL, ) - assert print_thread_mock.mock_calls == [ - call(thread, NativeReportingMode.OFF) for thread in threads + TracebackPrinterMock.assert_called_once_with( + NativeReportingMode.OFF, include_subinterpreters=True + ) + assert TracebackPrinterMock.return_value.print_thread.mock_calls == [ + call(thread) for thread in threads ] @@ -990,7 +1040,7 @@ def test_default_colored_output(): # WHEN with patch("pystack.__main__.get_process_threads"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ): main() @@ -1008,7 +1058,7 @@ def test_nocolor_output(): # WHEN with patch("pystack.__main__.get_process_threads"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ): main() @@ -1026,7 +1076,7 @@ def test_nocolor_output_at_the_front_for_process(): # WHEN with patch("pystack.__main__.get_process_threads"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ): main() @@ -1043,7 +1093,7 @@ def test_nocolor_output_at_the_front_for_core(): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ), patch( "pathlib.Path.exists", return_value=True ), patch( @@ -1069,7 +1119,7 @@ def test_global_options_can_be_placed_at_any_point(option): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ), patch( "pathlib.Path.exists", return_value=True ), patch( @@ -1092,7 +1142,7 @@ def test_verbose_as_global_options_sets_correctly_the_logger(): # WHEN with patch("pystack.__main__.get_process_threads"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("sys.argv", argv), patch("os.environ", environ), patch( "pathlib.Path.exists", return_value=True ), patch( @@ -1241,7 +1291,7 @@ def test_process_core_does_not_crash_if_core_analyzer_fails(method): # WHEN / THEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("pystack.__main__.is_elf", return_value=True), patch( "pystack.__main__.is_gzip", return_value=False ), patch( @@ -1268,7 +1318,7 @@ def test_core_file_missing_modules_are_logged(caplog, native): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("pystack.__main__.is_elf", return_value=True), patch( "pystack.__main__.is_gzip", return_value=False ), patch( @@ -1301,7 +1351,7 @@ def test_core_file_missing_build_ids_are_logged(caplog, native): # WHEN with patch("pystack.__main__.get_process_threads_for_core"), patch( - "pystack.__main__.print_thread" + "pystack.__main__.TracebackPrinter" ), patch("pystack.__main__.is_elf", return_value=True), patch( "pystack.__main__.is_gzip", return_value=False ), patch( @@ -1342,7 +1392,7 @@ def test_executable_is_not_elf_uses_the_first_map(): with patch( "pystack.__main__.get_process_threads_for_core" - ) as get_process_threads_mock, patch("pystack.__main__.print_thread"), patch( + ) as get_process_threads_mock, patch("pystack.__main__.TracebackPrinter"), patch( "pystack.__main__.is_elf", lambda x: x == real_executable ), patch( "pystack.__main__.is_gzip", return_value=False diff --git a/tests/unit/test_traceback_formatter.py b/tests/unit/test_traceback_formatter.py index 636cfc6b..86c9c3ad 100644 --- a/tests/unit/test_traceback_formatter.py +++ b/tests/unit/test_traceback_formatter.py @@ -4,8 +4,8 @@ import pytest from pystack.engine import NativeReportingMode +from pystack.traceback_formatter import TracebackPrinter from pystack.traceback_formatter import format_thread -from pystack.traceback_formatter import print_thread from pystack.types import SYMBOL_IGNORELIST from pystack.types import LocationInfo from pystack.types import NativeFrame @@ -1205,6 +1205,7 @@ def test_traceback_formatter_native_last(): def test_print_thread(capsys): + printer = TracebackPrinter(NativeReportingMode.OFF) # GIVEN thread = PyThread( tid=1, @@ -1220,7 +1221,9 @@ def test_print_thread(capsys): "pystack.traceback_formatter.format_thread", return_value=("1", "2", "3"), ): - print_thread(thread, NativeReportingMode.OFF) + printer.print_thread( + thread, + ) # THEN @@ -1629,3 +1632,270 @@ def test_native_traceback_with_shim_frames(): colored_mock.assert_any_call("x =", color="blue") colored_mock.assert_any_call('"This is the line 2" ', color="blue") colored_mock.assert_any_call("(1+1)", color="blue") + + +@pytest.mark.parametrize( + "native_mode", + [ + NativeReportingMode.OFF, + NativeReportingMode.ALL, + NativeReportingMode.PYTHON, + NativeReportingMode.LAST, + ], +) +def test_traceback_printer_created_with_native_level(native_mode): + # GIVEN / WHEN + printer = TracebackPrinter(native_mode) + + # THEN + assert printer.native_mode is native_mode + assert printer.include_subinterpreters is False + assert printer._current_interpreter_id == -1 + + +def test_traceback_printer_created_with_subinterpreters(): + # GIVEN / WHEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + + # THEN + assert printer.native_mode is NativeReportingMode.OFF + assert printer.include_subinterpreters is True + + +def test_print_thread_passes_native_mode_to_format_thread(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.ALL) + thread = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1", "line2"), + ) as format_mock: + printer.print_thread(thread) + + # THEN + format_mock.assert_called_once_with(thread, NativeReportingMode.ALL) + captured = capsys.readouterr() + assert captured.out == "line1\nline2\n" + + +def test_print_thread_with_subinterpreters(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=0, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1", "line2"), + ): + printer.print_thread(thread) + + # THEN + captured = capsys.readouterr() + assert "Interpreter-0 (main)" in captured.out + # Lines should be indented with 2 spaces + assert " line1\n" in captured.out + assert " line2\n" in captured.out + + +def test_print_thread_with_subinterpreters_nonzero_interp(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=2, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1",), + ): + printer.print_thread(thread) + + # THEN + captured = capsys.readouterr() + assert "Interpreter-2\n" in captured.out + assert " line1\n" in captured.out + + +def test_print_thread_with_subinterpreters_none_interp(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=None, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1",), + ): + printer.print_thread(thread) + + # THEN + captured = capsys.readouterr() + assert "Interpreter-Unknown\n" in captured.out + + +def test_print_thread_with_subinterpreters_same_interp_no_repeat_header(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread1 = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=1, + ) + thread2 = PyThread( + tid=2, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=1, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1",), + ): + printer.print_thread(thread1) + printer.print_thread(thread2) + + # THEN + captured = capsys.readouterr() + # Header should appear only once + assert captured.out.count("Interpreter-1") == 1 + + +def test_print_thread_with_subinterpreters_main_interp_no_repeat_header(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread1 = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=0, + ) + thread2 = PyThread( + tid=2, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=0, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1",), + ): + printer.print_thread(thread1) + printer.print_thread(thread2) + + # THEN + captured = capsys.readouterr() + # Header should appear only once + assert captured.out.count("Interpreter-0 (main)") == 1 + + +def test_print_thread_with_subinterpreters_different_interps_prints_headers(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=True) + thread1 = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=1, + ) + thread2 = PyThread( + tid=2, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=2, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1",), + ): + printer.print_thread(thread1) + printer.print_thread(thread2) + + # THEN + captured = capsys.readouterr() + assert "Interpreter-1\n" in captured.out + assert "Interpreter-2\n" in captured.out + + +def test_print_thread_without_subinterpreters_no_indentation(capsys): + # GIVEN + printer = TracebackPrinter(NativeReportingMode.OFF, include_subinterpreters=False) + thread = PyThread( + tid=1, + frame=None, + native_frames=[], + holds_the_gil=False, + is_gc_collecting=False, + python_version=(3, 8), + interpreter_id=1, + ) + + # WHEN + with patch( + "pystack.traceback_formatter.format_thread", + return_value=("line1", "line2"), + ): + printer.print_thread(thread) + + # THEN + captured = capsys.readouterr() + # No interpreter header and no indentation + assert "Interpreter" not in captured.out + assert captured.out == "line1\nline2\n" diff --git a/tests/utils.py b/tests/utils.py index dacb18fa..dcf1f639 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -307,6 +307,13 @@ def all_pystack_combinations(corefile=False, native=False): ) +ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS = pytest.mark.parametrize( + "python", + [python[:2] for python in AVAILABLE_PYTHONS if python.version >= (3, 14)], + ids=[python[1].name for python in AVAILABLE_PYTHONS if python.version >= (3, 14)], +) + + ALL_PYTHONS_THAT_DO_NOT_SUPPORT_ELF_DATA = pytest.mark.parametrize( "python", [python[:2] for python in AVAILABLE_PYTHONS if python.version < (3, 10)], From 77a8d895b3351d602acc03eb254605dec928e519 Mon Sep 17 00:00:00 2001 From: Pablo Galindo Salgado Date: Sun, 8 Mar 2026 17:19:52 +0000 Subject: [PATCH 2/2] Fix --native subinterpreter merge for shared TIDs using shim stack anchors When multiple subinterpreters execute on the same OS thread, each PyThread previously received the full native stack for that TID. That made native/Python merging fail because every thread in the group saw the same set of eval frames, so n_eval did not match each thread's entry-frame count. This change makes native merging deterministic for same-TID subinterpreter groups. The game is played like this: - Capture a per-thread stack anchor in the native layer: - add Thread::StackAnchor() and d_stack_anchor. - compute the anchor from the Python frame chain by walking backwards to the nearest stack/shim-owned frame (FRAME_OWNED_BY_INTERPRETER / FRAME_OWNED_BY_CSTACK on 3.14+, FRAME_OWNED_BY_CSTACK on 3.12/3.13). - Thread construction now forwards this anchor into PyThread as stack_anchor. - Switch process/core thread assembly from immediate yielding to collect-then-normalize. - Group Python threads by tid when native mode is enabled. - For groups with more than one thread: - pick a canonical native stack, - sort group members by stack_anchor (stable tie-breaker), - partition eval-frame ownership according to each thread's Python entry-frame count, - slice native frames accordingly per thread. - If counts are inconsistent, keep existing behavior for that group and skip slicing. --- src/pystack/_pystack.pyx | 99 +++++++- src/pystack/_pystack/pythread.cpp | 47 ++++ src/pystack/_pystack/pythread.h | 5 + src/pystack/_pystack/pythread.pxd | 2 + src/pystack/types.py | 1 + tests/integration/test_subinterpreters.py | 285 +++++++++++++++++++++- tests/utils.py | 4 +- 7 files changed, 437 insertions(+), 6 deletions(-) diff --git a/src/pystack/_pystack.pyx b/src/pystack/_pystack.pyx index d9a0488b..4efb04ce 100644 --- a/src/pystack/_pystack.pyx +++ b/src/pystack/_pystack.pyx @@ -66,6 +66,7 @@ from .types import NativeFrame from .types import PyCodeObject from .types import PyFrame from .types import PyThread +from .types import frame_type LOGGER = logging.getLogger(__file__) @@ -490,6 +491,7 @@ cdef object _construct_threads_from_interpreter_state( python_version, interpreter_id, name=get_thread_name(pid, current_thread.Tid()), + stack_anchor=current_thread.StackAnchor(), ) ) current_thread = ( @@ -498,6 +500,91 @@ cdef object _construct_threads_from_interpreter_state( return threads + +def _entry_frame_count(thread: PyThread) -> int: + return sum(1 for frame in thread.all_frames if frame.is_entry) + + +def _eval_frame_positions(thread: PyThread): + if not thread.python_version: + return [] + return [ + index + for index, native_frame in enumerate(thread.native_frames) + if frame_type(native_frame, thread.python_version) == NativeFrame.FrameType.EVAL + ] + + +def _slice_native_stacks_for_same_tid_threads(threads) -> None: + if len(threads) < 2: + return + + canonical = next((thread for thread in threads if thread.native_frames), None) + if canonical is None: + return + + canonical_frames = list(canonical.native_frames) + eval_positions = [ + index + for index, native_frame in enumerate(canonical_frames) + if frame_type(native_frame, canonical.python_version) == NativeFrame.FrameType.EVAL + ] + if not eval_positions: + return + + entry_counts = [_entry_frame_count(thread) for thread in threads] + if sum(entry_counts) != len(eval_positions): + LOGGER.debug( + "Skipping same-tid native slicing for tid %s due to mismatched counts: " + "entry=%s eval=%s", + threads[0].tid, + sum(entry_counts), + len(eval_positions), + ) + return + + ordered_threads = sorted( + enumerate(threads), + key=lambda item: ( + item[1].stack_anchor is None, + -(item[1].stack_anchor or 0), + item[0], + ), + ) + + cursor = 0 + for _, thread in ordered_threads: + required_eval_frames = _entry_frame_count(thread) + if required_eval_frames == 0: + thread.native_frames = [] + continue + + group_start = cursor + group_end = cursor + required_eval_frames + prev_eval = eval_positions[group_start - 1] if group_start > 0 else -1 + next_eval = ( + eval_positions[group_end] + if group_end < len(eval_positions) + else len(canonical_frames) + ) + thread.native_frames = canonical_frames[prev_eval + 1 : next_eval] + cursor = group_end + + +def _normalize_python_threads(threads, native_mode: NativeReportingMode): + if native_mode == NativeReportingMode.OFF: + return threads + + threads_by_tid = {} + for thread in threads: + threads_by_tid.setdefault(thread.tid, []).append(thread) + + for group in threads_by_tid.values(): + if len(group) <= 1: + continue + _slice_native_stacks_for_same_tid_threads(group) + return threads + cdef object _construct_os_thread( shared_ptr[AbstractProcessManager] manager, int pid, int tid ): @@ -625,6 +712,7 @@ def _get_process_threads( ) all_tids = list(manager.get().Tids()) + threads = [] while head: add_native_traces = native_mode != NativeReportingMode.OFF for thread in _construct_threads_from_interpreter_state( @@ -637,9 +725,12 @@ def _get_process_threads( ): if thread.tid in all_tids: all_tids.remove(thread.tid) - yield thread + threads.append(thread) head = InterpreterUtils.getNextInterpreter(manager, head) + for thread in _normalize_python_threads(threads, native_mode): + yield thread + if native_mode == NativeReportingMode.ALL: yield from _construct_os_threads(manager, pid, all_tids) @@ -772,6 +863,7 @@ def _get_process_threads_for_core( ) all_tids = list(manager.get().Tids()) + threads = [] while head: add_native_traces = native_mode != NativeReportingMode.OFF @@ -785,8 +877,11 @@ def _get_process_threads_for_core( ): if thread.tid in all_tids: all_tids.remove(thread.tid) - yield thread + threads.append(thread) head = InterpreterUtils.getNextInterpreter(manager, head) + for thread in _normalize_python_threads(threads, native_mode): + yield thread + if native_mode == NativeReportingMode.ALL: yield from _construct_os_threads(manager, pymanager.pid, all_tids) diff --git a/src/pystack/_pystack/pythread.cpp b/src/pystack/_pystack/pythread.cpp index b6b90a24..ae496e92 100644 --- a/src/pystack/_pystack/pythread.cpp +++ b/src/pystack/_pystack/pythread.cpp @@ -2,6 +2,7 @@ #include #include +#include "cpython/frame.h" #include "cpython/pthread.h" #include "interpreter.h" #include "logging.h" @@ -18,6 +19,7 @@ namespace pystack { Thread::Thread(pid_t pid, pid_t tid) : d_pid(pid) , d_tid(tid) +, d_stack_anchor(0) { } @@ -27,6 +29,12 @@ Thread::Tid() const return d_tid; } +remote_addr_t +Thread::StackAnchor() const +{ + return d_stack_anchor; +} + const std::vector& Thread::NativeFrames() const { @@ -148,6 +156,7 @@ PyThread::PyThread(const std::shared_ptr& manager, << frame_addr; d_first_frame = std::make_unique(manager, frame_addr, 0); } + d_stack_anchor = getStackAnchor(manager, frame_addr); d_addr = addr; remote_addr_t candidate_next_addr = ts.getField(&py_thread_v::o_next); @@ -237,6 +246,44 @@ PyThread::getFrameAddr( } } +remote_addr_t +PyThread::getStackAnchor( + const std::shared_ptr& manager, + remote_addr_t frame_addr) +{ + if (!frame_addr) { + return 0; + } + if (!manager->versionIsAtLeast(3, 12)) { + return frame_addr; + } + + remote_addr_t current_addr = frame_addr; + for (int i = 0; i < 4096 && current_addr; ++i) { + Structure current_frame(manager, current_addr); + auto owner = current_frame.getField(&py_frame_v::o_owner); + + if (manager->versionIsAtLeast(3, 14)) { + if (owner == Python3_14::FRAME_OWNED_BY_INTERPRETER + || owner == Python3_14::FRAME_OWNED_BY_CSTACK) + { + return current_addr; + } + } else { + if (owner == Python3_12::FRAME_OWNED_BY_CSTACK) { + return current_addr; + } + } + + remote_addr_t next_addr = current_frame.getField(&py_frame_v::o_back); + if (next_addr == current_addr) { + break; + } + current_addr = next_addr; + } + return frame_addr; +} + std::shared_ptr PyThread::FirstFrame() const { diff --git a/src/pystack/_pystack/pythread.h b/src/pystack/_pystack/pythread.h index ab02c672..067ff986 100644 --- a/src/pystack/_pystack/pythread.h +++ b/src/pystack/_pystack/pythread.h @@ -16,6 +16,7 @@ class Thread public: Thread(pid_t pid, pid_t tid); pid_t Tid() const; + remote_addr_t StackAnchor() const; const std::vector& NativeFrames() const; // Methods @@ -25,6 +26,7 @@ class Thread // Data members pid_t d_pid; pid_t d_tid; + remote_addr_t d_stack_anchor; std::vector d_native_frames; }; @@ -50,6 +52,9 @@ class PyThread : public Thread static remote_addr_t getFrameAddr( const std::shared_ptr& manager, Structure& ts); + static remote_addr_t getStackAnchor( + const std::shared_ptr& manager, + remote_addr_t frame_addr); private: // Data members diff --git a/src/pystack/_pystack/pythread.pxd b/src/pystack/_pystack/pythread.pxd index 3930a6de..b70f6825 100644 --- a/src/pystack/_pystack/pythread.pxd +++ b/src/pystack/_pystack/pythread.pxd @@ -11,6 +11,7 @@ cdef extern from "pythread.h" namespace "pystack": cdef cppclass NativeThread "pystack::Thread": NativeThread(int, int) except+ int Tid() + remote_addr_t StackAnchor() vector[NativeFrame]& NativeFrames() void populateNativeStackTrace(shared_ptr[AbstractProcessManager] manager) except+ @@ -28,6 +29,7 @@ cdef extern from "pythread.h" namespace "pystack::PyThread": cdef extern from "pythread.h" namespace "pystack": cdef cppclass Thread "pystack::PyThread": int Tid() + remote_addr_t StackAnchor() shared_ptr[FrameObject] FirstFrame() shared_ptr[Thread] NextThread() vector[NativeFrame]& NativeFrames() diff --git a/src/pystack/types.py b/src/pystack/types.py index 5eb77ced..d04c5bef 100644 --- a/src/pystack/types.py +++ b/src/pystack/types.py @@ -117,6 +117,7 @@ class PyThread: python_version: Optional[Tuple[int, int]] interpreter_id: Optional[int] = None name: Optional[str] = None + stack_anchor: Optional[int] = None @property def frames(self) -> Iterable[PyFrame]: diff --git a/tests/integration/test_subinterpreters.py b/tests/integration/test_subinterpreters.py index 91f08013..8fbdedbe 100644 --- a/tests/integration/test_subinterpreters.py +++ b/tests/integration/test_subinterpreters.py @@ -1,4 +1,6 @@ import io +import subprocess +import time from collections import Counter from contextlib import redirect_stdout from pathlib import Path @@ -21,12 +23,32 @@ NUM_INTERPRETERS_WITH_THREADS = 2 NUM_THREADS_PER_SUBINTERPRETER = 2 +# Compatibility shim so test programs work on both 3.13 (_interpreters) +# and 3.14+ (concurrent.interpreters). +_INTERPRETERS_SHIM = """\ +import sys as _sys +try: + from concurrent import interpreters +except ImportError: + import _interpreters as _raw + class _W: + def __init__(self, id): + self.id = id + def exec(self, code): + _raw.exec(self.id, code) + class interpreters: + @staticmethod + def create(): + return _W(_raw.create()) + Interpreter = _W +""" + PROGRAM = f"""\ import sys import threading import time -from concurrent import interpreters +{_INTERPRETERS_SHIM} NUM_INTERPRETERS = {NUM_INTERPRETERS} @@ -67,7 +89,7 @@ def start_interpreter_async(interp, code): import threading import time -from concurrent import interpreters +{_INTERPRETERS_SHIM} NUM_INTERPRETERS = {NUM_INTERPRETERS_WITH_THREADS} @@ -117,6 +139,104 @@ def worker(): time.sleep(1) """ +PROGRAM_NESTED_SAME_THREAD = ( + """\ +import sys +import threading +import time + +""" + + _INTERPRETERS_SHIM + + """ +_SHIM = '''""" + + _INTERPRETERS_SHIM + + """''' + +fifo = sys.argv[1] + +interp_outer = interpreters.create() +interp_inner = interpreters.create() + +inner_code = f'''\\ +import time +with open({fifo!r}, "w") as f: + f.write("ready") +while True: + time.sleep(1) +''' +outer_code = _SHIM + f''' +interpreters.Interpreter({{inner_id}}).exec({{inner_code!r}}) +'''.format(inner_id=interp_inner.id, inner_code=inner_code) + +t = threading.Thread(target=interp_outer.exec, args=(outer_code,)) +t.daemon = True +t.start() + +while True: + time.sleep(1) +""" +) + +PROGRAM_TWO_THREADS_THREE_SUBINTERPRETERS_EACH = ( + """\ +import sys +import threading +import time +from pathlib import Path + +""" + + _INTERPRETERS_SHIM + + """ +_SHIM = '''""" + + _INTERPRETERS_SHIM + + """''' + +signal_file = Path(sys.argv[1]) + + +def make_level3_code(token): + return f'''\\ +import time +from pathlib import Path +Path({str(signal_file)!r}).open("a").write("{token}\\\\n") +while True: + time.sleep(1) +''' + + +def make_level2_code(interp3_id, level3_code): + return _SHIM + f''' +interpreters.Interpreter({interp3_id}).exec({level3_code!r}) +''' + + +def make_level1_code(interp2_id, level2_code): + return _SHIM + f''' +interpreters.Interpreter({interp2_id}).exec({level2_code!r}) +''' + + +def launch_chain(token): + interp1 = interpreters.create() + interp2 = interpreters.create() + interp3 = interpreters.create() + + level3_code = make_level3_code(token) + level2_code = make_level2_code(interp3.id, level3_code) + level1_code = make_level1_code(interp2.id, level2_code) + interp1.exec(level1_code) + + +t1 = threading.Thread(target=launch_chain, args=("chain1",), daemon=True) +t2 = threading.Thread(target=launch_chain, args=("chain2",), daemon=True) +t1.start() +t2.start() + +while True: + time.sleep(1) +""" +) + def _collect_threads( python_executable: Path, @@ -201,6 +321,67 @@ def _assert_native_eval_symbols(threads) -> None: assert any(frame.path and "?" not in frame.path for frame in eval_frames) +def _assert_mergeable_same_tid_groups(threads) -> bool: + groups = {} + for thread in threads: + groups.setdefault(thread.tid, []).append(thread) + + found_shared_tid = False + for group in groups.values(): + interpreter_ids = { + thread.interpreter_id + for thread in group + if thread.interpreter_id is not None + } + if len(group) < 2 or len(interpreter_ids) < 2: + continue + found_shared_tid = True + for thread in group: + eval_frames = [ + frame + for frame in thread.native_frames + if frame_type(frame, thread.python_version) + == NativeFrame.FrameType.EVAL + ] + entry_count = sum(frame.is_entry for frame in thread.all_frames) + assert len(eval_frames) == entry_count + return found_shared_tid + + +def _shared_tid_groups_with_min_interpreters(threads, min_interpreters): + groups = {} + for thread in threads: + groups.setdefault(thread.tid, []).append(thread) + + matching = [] + for tid, group in groups.items(): + interpreter_ids = { + thread.interpreter_id + for thread in group + if thread.interpreter_id is not None + } + if len(interpreter_ids) >= min_interpreters: + matching.append((tid, group)) + return matching + + +def _assert_strict_native_eval_symbols_for_group(group) -> None: + for thread in group: + eval_frames = [ + frame + for frame in thread.native_frames + if frame_type(frame, thread.python_version) == NativeFrame.FrameType.EVAL + ] + assert eval_frames + assert all("?" not in frame.symbol for frame in eval_frames) + if any(frame.linenumber == 0 for frame in eval_frames): + assert all(frame.linenumber == 0 for frame in eval_frames) + assert all(frame.path == "???" for frame in eval_frames) + else: + assert all(frame.linenumber != 0 for frame in eval_frames) + assert any(frame.path and "?" not in frame.path for frame in eval_frames) + + @ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS def test_subinterpreters(python, tmpdir): _, python_executable = python @@ -288,6 +469,106 @@ def test_subinterpreters_many_threads_with_native(python, tmpdir): _assert_native_eval_symbols(threads) +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +def test_subinterpreters_nested_same_thread_with_native(python, tmpdir): + _, python_executable = python + + test_file = Path(str(tmpdir)) / "subinterpreters_nested_same_thread.py" + test_file.write_text(PROGRAM_NESTED_SAME_THREAD) + + with spawn_child_process(python_executable, test_file, tmpdir) as child_process: + threads = list( + get_process_threads( + child_process.pid, + stop_process=True, + native_mode=NativeReportingMode.PYTHON, + method=StackMethod.DEBUG_OFFSETS, + ) + ) + + assert any(thread.native_frames for thread in threads) + _assert_native_eval_symbols(threads) + + has_shared_tid = _assert_mergeable_same_tid_groups(threads) + assert has_shared_tid + + output = _assert_interpreter_headers( + threads=threads, + native_mode=NativeReportingMode.PYTHON, + interpreter_ids=_interpreter_ids(threads), + ) + assert ( + "Unable to merge native stack due to insufficient native information" + not in output + ) + + +@ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS +def test_subinterpreters_two_threads_three_per_thread_with_native(python, tmpdir): + _, python_executable = python + + test_file = Path(str(tmpdir)) / "subinterpreters_two_threads_three_each.py" + signal_file = Path(str(tmpdir)) / "subinterpreters_ready.txt" + signal_file.write_text("") + test_file.write_text(PROGRAM_TWO_THREADS_THREE_SUBINTERPRETERS_EACH) + + with subprocess.Popen( + [str(python_executable), str(test_file), str(signal_file)], + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + ) as child_process: + deadline = time.time() + 10 + while time.time() < deadline: + lines = [line for line in signal_file.read_text().splitlines() if line] + if len(lines) >= 2: + break + time.sleep(0.1) + else: + child_process.terminate() + child_process.kill() + raise AssertionError("Timed out waiting for nested subinterpreter chains") + + threads = list( + get_process_threads( + child_process.pid, + stop_process=True, + native_mode=NativeReportingMode.PYTHON, + method=StackMethod.DEBUG_OFFSETS, + ) + ) + + child_process.terminate() + child_process.kill() + child_process.wait(timeout=5) + + groups = _shared_tid_groups_with_min_interpreters(threads, min_interpreters=3) + assert len(groups) >= 2 + + for _, group in groups: + _assert_strict_native_eval_symbols_for_group(group) + for thread in group: + eval_frames = [ + frame + for frame in thread.native_frames + if frame_type(frame, thread.python_version) + == NativeFrame.FrameType.EVAL + ] + entry_count = sum(frame.is_entry for frame in thread.all_frames) + assert len(eval_frames) == entry_count + assert len(eval_frames) > 0 + + output = _assert_interpreter_headers( + threads=threads, + native_mode=NativeReportingMode.PYTHON, + interpreter_ids=_interpreter_ids(threads), + ) + assert ( + "Unable to merge native stack due to insufficient native information" + not in output + ) + + @ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS def test_subinterpreters_for_core(python, tmpdir): _, python_executable = python diff --git a/tests/utils.py b/tests/utils.py index dcf1f639..ea9a2ec3 100644 --- a/tests/utils.py +++ b/tests/utils.py @@ -309,8 +309,8 @@ def all_pystack_combinations(corefile=False, native=False): ALL_PYTHONS_THAT_SUPPORT_SUBINTERPRETERS = pytest.mark.parametrize( "python", - [python[:2] for python in AVAILABLE_PYTHONS if python.version >= (3, 14)], - ids=[python[1].name for python in AVAILABLE_PYTHONS if python.version >= (3, 14)], + [python[:2] for python in AVAILABLE_PYTHONS if python.version >= (3, 13)], + ids=[python[1].name for python in AVAILABLE_PYTHONS if python.version >= (3, 13)], )